shithub: opus

Download patch

ref: aa474553b53254edc017e4921d661aa2a6876995
parent: a8673d0e253c9946b86c27ac95070f929501c775
author: jbuethe <jbuethe@amazon.de>
date: Fri Jan 13 06:48:04 EST 2023

updated torch framework to include quantization

--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -85,7 +85,7 @@
     
     enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message)
     dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message)
-    stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats"), message=message)
+    stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message)
     constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True)
     
     # some custom includes
@@ -99,7 +99,7 @@
 #include "nnet.h"
 """
         )
-
+        
     # encoder
     encoder_dense_layers = [
         ('core_encoder.module.dense_1'       , 'enc_dense1',   'TANH'), 
@@ -122,7 +122,8 @@
         ('core_encoder.module.gru_3'         , 'enc_dense6',   'TANH')
     ]
  
-    enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_gru_layers])
+    enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True, input_sparse=True, dotp=True)
+                             for name, export_name, activation in encoder_gru_layers])
  
     
     encoder_conv_layers = [   
@@ -158,7 +159,8 @@
         ('core_decoder.module.gru_3'         , 'dec_dense6',    'TANH')
     ]
     
-    dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in decoder_gru_layers])
+    dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True, input_sparse=True, dotp=True)
+                             for name, export_name, activation in decoder_gru_layers])
         
     del dec_writer
     
binary files /dev/null b/dnn/torch/rdovae/libs/wexchange-1.2-py3-none-any.whl differ
--- a/dnn/torch/rdovae/requirements.txt
+++ b/dnn/torch/rdovae/requirements.txt
@@ -2,4 +2,4 @@
 scipy
 torch
 tqdm
-libs/wexchange-1.0-py3-none-any.whl
\ No newline at end of file
+libs/wexchange-1.2-py3-none-any.whl
\ No newline at end of file
--