shithub: opus

Download patch

ref: 861f6739a491b4cf58d604e11b2500230159e826
parent: ecb5cbcf30978d7a35ff3fcd19e9df05276e8228
author: jbuethe <jbuethe@amazon.de>
date: Wed Nov 9 06:41:28 EST 2022

added import script for exchange format

--- a/dnn/training_tf2/fec_encoder.py
+++ b/dnn/training_tf2/fec_encoder.py
@@ -64,6 +64,7 @@
 
     parser.add_argument('--dump-data', type=str, default='./dump_data', help='path to dump data executable (default ./dump_data)')
     parser.add_argument('--cond-size', metavar='<units>', default=1024, type=int, help='number of units in conditioning network (default 1024)')
+    parser.add_argument('--quant-levels', type=int, help="number of quantization steps (default: 40)", default=40)
     parser.add_argument('--num-redundancy-frames', default=64, type=int, help='number of redundancy frames (20ms) per packet (default 64)')
     parser.add_argument('--extra-delay', default=0, type=int, help="last features in packet are calculated with the decoder aligned samples, use this option to add extra delay (in samples at 16kHz)")
     parser.add_argument('--lossfile', type=str, help='file containing loss trace (0 for frame received, 1 for lost)')
--- /dev/null
+++ b/dnn/training_tf2/rdovae_import.py
@@ -1,0 +1,123 @@
+"""
+/* Copyright (c) 2022 Amazon
+   Written by Jan Buethe */
+/*
+   Redistribution and use in source and binary forms, with or without
+   modification, are permitted provided that the following conditions
+   are met:
+
+   - Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+   - Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
+   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+"""
+
+
+import argparse
+import os
+import sys
+
+os.environ['CUDA_VISIBLE_DEVICES'] = ""
+
+parser = argparse.ArgumentParser()
+
+parser.add_argument('input', metavar="<input folder>", type=str, help='input exchange folder')
+parser.add_argument('weights', metavar="<weight file>", type=str, help='model weight file in hdf5 format')
+parser.add_argument('--cond-size', type=int, help="conditioning size (default: 256)", default=256)
+parser.add_argument('--latent-dim', type=int, help="dimension of latent space (default: 80)", default=80)
+parser.add_argument('--quant-levels', type=int, help="number of quantization steps (default: 16)", default=16)
+
+args = parser.parse_args()
+
+# now import the heavy stuff
+from rdovae import new_rdovae_model
+from wexchange.tf import load_tf_weights
+
+
+exchange_name = {
+    'enc_dense1'    : 'encoder_stack_layer1_dense',
+    'enc_dense3'    : 'encoder_stack_layer3_dense',
+    'enc_dense5'    : 'encoder_stack_layer5_dense',
+    'enc_dense7'    : 'encoder_stack_layer7_dense',
+    'enc_dense8'    : 'encoder_stack_layer8_dense',
+    'gdense1'       : 'encoder_state_layer1_dense',
+    'gdense2'       : 'encoder_state_layer2_dense',
+    'enc_dense2'    : 'encoder_stack_layer2_gru',
+    'enc_dense4'    : 'encoder_stack_layer4_gru',
+    'enc_dense6'    : 'encoder_stack_layer6_gru',
+    'bits_dense'    : 'encoder_stack_layer9_conv',
+    'qembedding'    : 'statistical_model_embedding',
+    'state1'        : 'decoder_state1_dense',
+    'state2'        : 'decoder_state2_dense',
+    'state3'        : 'decoder_state3_dense',
+    'dec_dense1'    : 'decoder_stack_layer1_dense',
+    'dec_dense3'    : 'decoder_stack_layer3_dense',
+    'dec_dense5'    : 'decoder_stack_layer5_dense',
+    'dec_dense7'    : 'decoder_stack_layer7_dense',
+    'dec_dense8'    : 'decoder_stack_layer8_dense',
+    'dec_final'     : 'decoder_stack_layer9_dense',
+    'dec_dense2'    : 'decoder_stack_layer2_gru',
+    'dec_dense4'    : 'decoder_stack_layer4_gru',
+    'dec_dense6'    : 'decoder_stack_layer6_gru'
+}
+
+if __name__ == "__main__":
+
+    model, encoder, decoder, qembedding = new_rdovae_model(20, args.latent_dim, cond_size=args.cond_size, nb_quant=args.quant_levels)
+    
+    encoder_layers = [
+        'enc_dense1',
+        'enc_dense3',
+        'enc_dense5',
+        'enc_dense7',
+        'enc_dense8',
+        'gdense1',
+        'gdense2',
+        'enc_dense2',
+        'enc_dense4',
+        'enc_dense6',
+        'bits_dense'
+    ]
+    
+    decoder_layers = [
+        'state1',
+        'state2',
+        'state3',
+        'dec_dense1',
+        'dec_dense3',
+        'dec_dense5',
+        'dec_dense7',
+        'dec_dense8',
+        'dec_final',
+        'dec_dense2',
+        'dec_dense4',
+        'dec_dense6'
+    ]
+    
+    for name in encoder_layers:
+        print(f"loading weight for layer {name}...")
+        load_tf_weights(os.path.join(args.input, exchange_name[name]), encoder.get_layer(name))
+    
+    print(f"loading weight for layer qembedding...")
+    load_tf_weights(os.path.join(args.input, exchange_name['qembedding']), qembedding)
+    
+    for name in decoder_layers:
+        print(f"loading weight for layer {name}...")
+        load_tf_weights(os.path.join(args.input, exchange_name[name]), decoder.get_layer(name))
+        
+    model.save(args.weights)
\ No newline at end of file
--