ref: aa474553b53254edc017e4921d661aa2a6876995
parent: a8673d0e253c9946b86c27ac95070f929501c775
author: jbuethe <jbuethe@amazon.de>
date: Fri Jan 13 06:48:04 EST 2023
updated torch framework to include quantization
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -85,7 +85,7 @@
enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message)
dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message)
- stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats"), message=message)
+ stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message)
constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True)
# some custom includes
@@ -99,7 +99,7 @@
#include "nnet.h"
"""
)
-
+
# encoder
encoder_dense_layers = [
('core_encoder.module.dense_1' , 'enc_dense1', 'TANH'),
@@ -122,7 +122,8 @@
('core_encoder.module.gru_3' , 'enc_dense6', 'TANH')]
- enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in encoder_gru_layers])
+ enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, activation, verbose=True, input_sparse=True, dotp=True)
+ for name, export_name, activation in encoder_gru_layers])
encoder_conv_layers = [
@@ -158,7 +159,8 @@
('core_decoder.module.gru_3' , 'dec_dense6', 'TANH')]
- dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True) for name, export_name, activation in decoder_gru_layers])
+ dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, activation, verbose=True, input_sparse=True, dotp=True)
+ for name, export_name, activation in decoder_gru_layers])
del dec_writer
Binary files /dev/null and b/dnn/torch/rdovae/libs/wexchange-1.2-py3-none-any.whl differ
--- a/dnn/torch/rdovae/requirements.txt
+++ b/dnn/torch/rdovae/requirements.txt
@@ -2,4 +2,4 @@
scipy
torch
tqdm
-libs/wexchange-1.0-py3-none-any.whl
\ No newline at end of file
+libs/wexchange-1.2-py3-none-any.whl
\ No newline at end of file
--
⑨