ref: 0a2d6dfcb656108eaace165310245e7f74a5b360
parent: 38dda0f950e126badf83ecabd8449d27796a9040
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Sep 28 11:34:02 EDT 2022
Use the encoder state as decoder initial state. Helps reduce the error on the most recent frames.
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -213,7 +213,7 @@
enc_dense3 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense3')
enc_dense4 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense4')
enc_dense5 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense5')
- enc_dense6 = gru(cond_size, return_sequences=True, return_state=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
+ enc_dense6 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense6')
enc_dense7 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense7')
enc_dense8 = Dense(cond_size, activation='tanh', kernel_constraint=constraint, name='enc_dense8')
@@ -228,15 +228,16 @@
d3 = enc_dense3(d2)
d4 = enc_dense4(d3)
d5 = enc_dense5(d4)
- d6, gru_state = enc_dense6(d5)
+ d6 = enc_dense6(d5)
d7 = enc_dense7(d6)
d8 = enc_dense8(d7)
- enc_out = bits_dense(Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8]))
+ pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])
+ enc_out = bits_dense(pre_out)
#enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out)
bits = Multiply()([enc_out, quant_scale])
global_dense1 = Dense(128, activation='tanh', name='gdense1')
global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
- global_bits = global_dense2(global_dense1(d6))
+ global_bits = global_dense2(global_dense1(pre_out))
encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
return encoder
@@ -265,15 +266,18 @@
quant_scale_dec = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed_dec')(quant_embed_input))#gru_state_rep = RepeatVector(64//bunch)(gru_state_input)
- gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
+ #gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
+ gru_state1 = Dense(cond_size, name="state1", activation='tanh')(gru_state_input)
+ gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input)
+ gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)
- dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input), gru_state_rep])
+ dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input)])
dec1 = dec_dense1(time_reverse(dec_inputs))
dec2 = dec_dense2(dec1)
dec3 = dec_dense3(dec2)
- dec4 = dec_dense4(dec3)
- dec5 = dec_dense5(dec4)
- dec6 = dec_dense6(dec5)
+ dec4 = dec_dense4(dec3, initial_state=gru_state1)
+ dec5 = dec_dense5(dec4, initial_state=gru_state2)
+ dec6 = dec_dense6(dec5, initial_state=gru_state3)
dec7 = dec_dense7(dec6)
dec8 = dec_dense8(dec7)
output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
--
⑨