ref: eab9472d0d2199a5a7404c30db97a2b05e595293
parent: a223122b8946c7929081de33d8dada66646b35aa
author: Jan Buethe <jbuethe@amazon.de>
date: Mon Oct 31 08:49:20 EDT 2022
added script for exporting RDOVAE weights (external dependency not added yet)
--- /dev/null
+++ b/dnn/training_tf2/rdovae_exchange.py
@@ -1,0 +1,153 @@
+"""
+/* Copyright (c) 2022 Amazon
+ Written by Jan Buethe */
+/*
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions
+ are met:
+
+ - Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ - Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+ ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+ OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+ LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import argparse
+from ftplib import parse150
+import os
+import sys
+# Default location of the external checkout that (presumably) provides the
+# `exchange` package imported below; the commit notes that this dependency
+# is not part of the repository yet.  Override with --dred-path.
+DEFAULT_DRED_PATH = '/Users/jbuethe/Projects/DRED'
+
+# Empty CUDA_VISIBLE_DEVICES hides all GPUs from TensorFlow.  This must be
+# set before tensorflow is imported further down.
+os.environ['CUDA_VISIBLE_DEVICES'] = ""
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
+parser.add_argument('output', metavar="<output folder>", type=str, help='output exchange folder')
+parser.add_argument('--cond-size', type=int, help="conditioning size (default: 256)", default=256)
+parser.add_argument('--latent-dim', type=int, help="dimension of latent space (default: 80)", default=80)
+parser.add_argument('--dred-path', type=str, default=DEFAULT_DRED_PATH,
+                    help=f"directory appended to sys.path to locate the 'exchange' package (default: {DEFAULT_DRED_PATH})")
+
+args = parser.parse_args()
+
+# Has to happen before the `exchange.tf` import below.
+sys.path.append(args.dred_path)
+
+# now import the heavy stuff
+import tensorflow as tf
+import numpy as np
+from rdovae import new_rdovae_model
+from exchange.tf import dump_tf_gru_weights, dump_tf_conv1d_weights, dump_tf_dense_weights, dump_tf_embedding_weights
+
+
+
+# Maps Keras layer names of the RDOVAE (sub)models to the names under which
+# their weights are written; each value is joined onto the output folder.
+exchange_name = {
+    # encoder: dense layers
+    'enc_dense1' : 'encoder_stack_layer1_dense',
+    'enc_dense3' : 'encoder_stack_layer3_dense',
+    'enc_dense5' : 'encoder_stack_layer5_dense',
+    'enc_dense7' : 'encoder_stack_layer7_dense',
+    'enc_dense8' : 'encoder_stack_layer8_dense',
+    'gdense1' : 'encoder_state_layer1_dense',
+    'gdense2' : 'encoder_state_layer2_dense',
+    # encoder: layers dumped as GRUs (despite the *_dense Keras names)
+    'enc_dense2' : 'encoder_stack_layer2_gru',
+    'enc_dense4' : 'encoder_stack_layer4_gru',
+    'enc_dense6' : 'encoder_stack_layer6_gru',
+    # encoder: layer dumped as a 1-D convolution
+    'bits_dense' : 'encoder_stack_layer9_conv',
+    # embedding used by the statistical model
+    'qembedding' : 'statistical_model_embedding',
+    # decoder: dense layers
+    'state1' : 'decoder_state1_dense',
+    'state2' : 'decoder_state2_dense',
+    'state3' : 'decoder_state3_dense',
+    'dec_dense1' : 'decoder_stack_layer1_dense',
+    'dec_dense3' : 'decoder_stack_layer3_dense',
+    'dec_dense5' : 'decoder_stack_layer5_dense',
+    'dec_dense7' : 'decoder_stack_layer7_dense',
+    'dec_dense8' : 'decoder_stack_layer8_dense',
+    'dec_final' : 'decoder_stack_layer9_dense',
+    # decoder: layers dumped as GRUs
+    'dec_dense2' : 'decoder_stack_layer2_gru',
+    'dec_dense4' : 'decoder_stack_layer4_gru',
+    'dec_dense6' : 'decoder_stack_layer6_gru',
+}
+
+
+if __name__ == "__main__":
+
+    # NOTE(review): the leading positional argument is presumably the number
+    # of input features the model was trained with; confirm against
+    # new_rdovae_model() before changing it.
+    nb_used_features = 20
+    model, encoder, decoder, qembedding = new_rdovae_model(nb_used_features, args.latent_dim, cond_size=args.cond_size)
+    model.load_weights(args.weights)
+
+    os.makedirs(args.output, exist_ok=True)
+
+    def _dump_layers(names, dump_fn, submodel):
+        """Write each layer `name` of `submodel` to args.output via `dump_fn`.
+
+        The output path is args.output joined with exchange_name[name]; a
+        progress line is printed before each layer is written.
+        """
+        for name in names:
+            target = exchange_name[name]
+            print(f"writing layer {target}...")
+            dump_fn(os.path.join(args.output, target), submodel.get_layer(name))
+
+    # encoder
+    encoder_dense_names = [
+        'enc_dense1',
+        'enc_dense3',
+        'enc_dense5',
+        'enc_dense7',
+        'enc_dense8',
+        'gdense1',
+        'gdense2',
+    ]
+
+    encoder_gru_names = [
+        'enc_dense2',
+        'enc_dense4',
+        'enc_dense6',
+    ]
+
+    encoder_conv1d_names = [
+        'bits_dense',
+    ]
+
+    _dump_layers(encoder_dense_names, dump_tf_dense_weights, encoder)
+    _dump_layers(encoder_gru_names, dump_tf_gru_weights, encoder)
+    _dump_layers(encoder_conv1d_names, dump_tf_conv1d_weights, encoder)
+
+    # qembedding is returned directly by new_rdovae_model(), so it is passed
+    # as-is instead of being looked up with get_layer().
+    print(f"writing layer {exchange_name['qembedding']}...")
+    dump_tf_embedding_weights(os.path.join(args.output, exchange_name['qembedding']), qembedding)
+
+    # decoder
+    decoder_dense_names = [
+        'state1',
+        'state2',
+        'state3',
+        'dec_dense1',
+        'dec_dense3',
+        'dec_dense5',
+        'dec_dense7',
+        'dec_dense8',
+        'dec_final',
+    ]
+
+    decoder_gru_names = [
+        'dec_dense2',
+        'dec_dense4',
+        'dec_dense6',
+    ]
+
+    _dump_layers(decoder_dense_names, dump_tf_dense_weights, decoder)
+    _dump_layers(decoder_gru_names, dump_tf_gru_weights, decoder)
\ No newline at end of file
--
⑨