shithub: opus

Download patch

ref: 0a2d6dfcb656108eaace165310245e7f74a5b360
parent: 38dda0f950e126badf83ecabd8449d27796a9040
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Sep 28 11:34:02 EDT 2022

Use the encoder state as decoder initial state

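This helps reduce the error on the most recent frames. The global state vector
is now computed from the concatenation of all encoder layer outputs (rather
than from the last encoder GRU alone). On the decoder side, it is no longer
repeated and appended to every input frame. Instead, three dense layers project
it into the initial states of the decoder GRUs dec_dense4, dec_dense5 and
dec_dense6. Because the decoder runs over time-reversed inputs, these initial
states apply first to the most recent frames.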

--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -213,7 +213,7 @@
     enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
     enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
     enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
-    enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
+    enc_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
     enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
     enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
 
@@ -228,15 +228,16 @@
     d3 = enc_dense3(d2)
     d4 = enc_dense4(d3)
     d5 = enc_dense5(d4)
-    d6, gru_state = enc_dense6(d5)
+    d6 = enc_dense6(d5)
     d7 = enc_dense7(d6)
     d8 = enc_dense8(d7)
-    enc_out = bits_dense(Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8]))
+    pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])
+    enc_out = bits_dense(pre_out)
     #enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out)
     bits = Multiply()([enc_out, quant_scale])
     global_dense1 = Dense(128, activation='tanh', name='gdense1')
     global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
-    global_bits = global_dense2(global_dense1(d6))
+    global_bits = global_dense2(global_dense1(pre_out))
 
     encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
     return encoder
@@ -265,15 +266,18 @@
     quant_scale_dec = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed_dec')(quant_embed_input))
     #gru_state_rep = RepeatVector(64//bunch)(gru_state_input)
 
-    gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
+    #gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
+    gru_state1 = Dense(cond_size, name="state1", activation='tanh')(gru_state_input)
+    gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input)
+    gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)
 
-    dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input), gru_state_rep])
+    dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input)])
     dec1 = dec_dense1(time_reverse(dec_inputs))
     dec2 = dec_dense2(dec1)
     dec3 = dec_dense3(dec2)
-    dec4 = dec_dense4(dec3)
-    dec5 = dec_dense5(dec4)
-    dec6 = dec_dense6(dec5)
+    dec4 = dec_dense4(dec3, initial_state=gru_state1)
+    dec5 = dec_dense5(dec4, initial_state=gru_state2)
+    dec6 = dec_dense6(dec5, initial_state=gru_state3)
     dec7 = dec_dense7(dec6)
     dec8 = dec_dense8(dec7)
     output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
--