ref: a8170986ecc6259b54d1ae30404b833a48acbc26
parent: eab9472d0d2199a5a7404c30db97a2b05e595293
author: Jan Buethe <jbuethe@amazon.de>
date: Mon Oct 31 11:21:12 EDT 2022
updated rdovae_exchange
--- a/dnn/training_tf2/rdovae_exchange.py
+++ b/dnn/training_tf2/rdovae_exchange.py
@@ -29,10 +29,8 @@
import argparse
-from ftplib import parse150
import os
import sys
-sys.path.append('/Users/jbuethe/Projects/DRED')
os.environ['CUDA_VISIBLE_DEVICES'] = ""
@@ -46,13 +44,10 @@
args = parser.parse_args()
# now import the heavy stuff
-import tensorflow as tf
-import numpy as np
from rdovae import new_rdovae_model
-from exchange.tf import dump_tf_gru_weights, dump_tf_conv1d_weights, dump_tf_dense_weights, dump_tf_embedding_weights
+from wexchange.tf import dump_tf_weights, load_tf_weights
-
exchange_name = {'enc_dense1' : 'encoder_stack_layer1_dense',
'enc_dense3' : 'encoder_stack_layer3_dense',
@@ -109,21 +104,14 @@
'bits_dense'
]
- for name in encoder_dense_names:
 print(f"writing layer {exchange_name[name]}...")
- dump_tf_dense_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name))
- for name in encoder_gru_names:
+ for name in encoder_dense_names + encoder_gru_names + encoder_conv1d_names:
print(f"writing layer {exchange_name[name]}...")
- dump_tf_gru_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name))
+ dump_tf_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name))
- for name in encoder_conv1d_names:
- print(f"writing layer {exchange_name[name]}...")
- dump_tf_conv1d_weights(os.path.join(args.output, exchange_name[name]), encoder.get_layer(name))
-
# qembedding
print(f"writing layer {exchange_name['qembedding']}...")
- dump_tf_embedding_weights(os.path.join(args.output, exchange_name['qembedding']), qembedding)
+ dump_tf_weights(os.path.join(args.output, exchange_name['qembedding']), qembedding)
# decoder
decoder_dense_names = [
@@ -144,10 +132,6 @@
'dec_dense6'
]
- for name in decoder_dense_names:
+ for name in decoder_dense_names + decoder_gru_names:
print(f"writing layer {exchange_name[name]}...")
- dump_tf_dense_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name))
-
- for name in decoder_gru_names:
- print(f"writing layer {exchange_name[name]}...")
- dump_tf_gru_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name))
\ No newline at end of file
+ dump_tf_weights(os.path.join(args.output, exchange_name[name]), decoder.get_layer(name))
--
⑨