shithub: opus

Download patch

ref: 981d06eefda5ecfaf44bce05d75cccd01a3a1e24
parent: a4f7c157cf879ad768f97c3fa31520ede267bea7
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Fri Sep 9 22:18:30 EDT 2022

Refactoring towards multiple offset decoding

--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -282,27 +282,22 @@
     nb_bits = decoder.nb_bits
     bunch = decoder.bunch
     bits_input = Input(shape=(None, nb_bits), name="split_bits")
-    uqbits_input = Input(shape=(None, nb_bits), name="split_uqbits")
     quant_embed_input = Input(shape=(None, 6*nb_bits), name="split_embed")
     gru_state_input = Input(shape=(None,nb_state_dim), name="split_state")
 
-    range_select = Lambda(lambda x: x[0][:,x[1]+bunch//2-1:x[2]:bunch//2,:])
+    range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
     elem_select = Lambda(lambda x: x[0][:,x[1],:])
     points = [0, 64, 128, 192, 256]
     outputs = []
-    uqbits = []
     for i in range(len(points)-1):
-        begin = points[i]//2
-        end = points[i+1]//2
+        begin = points[i]//bunch
+        end = points[i+1]//bunch
         state = elem_select([gru_state_input, end-1])
         bits = range_select([bits_input, begin, end])
-        uq = range_select([uqbits_input, begin, end])
-        uqbits.append(uq)
         embed = range_select([quant_embed_input, begin, end])
         outputs.append(decoder([bits, embed, state]))
     output = Concatenate(axis=1)(outputs)
-    uqbits = Concatenate(axis=1)(uqbits)
-    split = Model([bits_input, uqbits_input, quant_embed_input, gru_state_input], [output, uqbits], name="split")
+    split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
     return split
 
 
@@ -327,14 +322,20 @@
     hardquant = Lambda(hard_quantize)
     dzone = Lambda(apply_dead_zone)
     dze = dzone([ze,dead_zone])
+    
+    mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
     gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
-    combined_output, uqbits = split_decoder([hardquant(dze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec])
     ndze = noisequant(dze)
-    unquantized_output, uqbits = split_decoder([ndze, dze, quant_embed_dec, gru_state_dec])
-    unquantized_output_dec, uqbits = split_decoder([tf.stop_gradient(ndze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec])
+    for i in [1]:
+        dze_select = mod_select([dze, i])
+        ndze_select = mod_select([ndze, i])
+        state_select = mod_select([gru_state_dec, i])
+        combined_output = split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select])
+        unquantized_output = split_decoder([ndze_select, quant_embed_dec, state_select])
+        unquantized_output_dec = split_decoder([tf.stop_gradient(ndze_select), tf.stop_gradient(quant_embed_dec), state_select])
 
-    e2 = Concatenate(name="hard_bits")([uqbits, hard_distr_embed, lambda_bunched])
-    e = Concatenate(name="soft_bits")([uqbits, soft_distr_embed, lambda_bunched])
+    e2 = Concatenate(name="hard_bits")([dze_select, hard_distr_embed, lambda_bunched])
+    e = Concatenate(name="soft_bits")([dze_select, soft_distr_embed, lambda_bunched])
 
 
     model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")
--