ref: 405aa7cf6962d164125e3da1ed0724241da1281a
parent: 981d06eefda5ecfaf44bce05d75cccd01a3a1e24
	author: Jean-Marc Valin <jmvalin@amazon.com>
	date: Sun Sep 11 00:13:24 EDT 2022
	
WIP: training with different alignment
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -94,9 +94,9 @@
return log2_e*tf.math.log(eps+x)
def feat_dist_loss(y_true,y_pred):
- ceps = y_pred[:,:,:18] - y_true[:,:,:18]
- pitch = 2*(y_pred[:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
- corr = y_pred[:,:,19:] - y_true[:,:,19:]
+ ceps = y_pred[:,:,:,:18] - y_true[:,:,:18]
+ pitch = 2*(y_pred[:,:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
+ corr = y_pred[:,:,:,19:] - y_true[:,:,19:]
pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
return K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr))
@@ -300,7 +300,19 @@
split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
return split
+def tensor_concat(x):
+ #n = x[1]//2
+ #x = x[0]
+ n=2
+ y = []
+ for i in range(n-1):
+ offset = n-1-i
+ tmp = K.concatenate([x[i][:, offset:, :], x[-1][:, -offset:, :]], axis=-2)
+ y.append(tf.expand_dims(tmp, axis=0))
+ y.append(tf.expand_dims(x[-1], axis=0))
+ return Concatenate(axis=0)(y)
+
def new_rdovae_model(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256):
feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
@@ -315,8 +327,8 @@
split_decoder = new_split_decoder(decoder)
     dead_zone = Activation('softplus')(Lambda(lambda x: x[:,:,nb_bits:2*nb_bits], name='dead_zone_embed')(quant_embed_dec))-    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))-    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,::2,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))+    soft_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,2*nb_bits:4*nb_bits], name='soft_distr_embed')(quant_embed_dec))+    hard_distr_embed = Activation('sigmoid')(Lambda(lambda x: x[:,:,4*nb_bits:], name='hard_distr_embed')(quant_embed_dec))noisequant = UniformNoise()
hardquant = Lambda(hard_quantize)
@@ -326,19 +338,24 @@
mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
ndze = noisequant(dze)
- for i in [1]:
+ combined_output = []
+ unquantized_output = []
+ for i in range(bunch//2):
dze_select = mod_select([dze, i])
ndze_select = mod_select([ndze, i])
state_select = mod_select([gru_state_dec, i])
- combined_output = split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select])
- unquantized_output = split_decoder([ndze_select, quant_embed_dec, state_select])
- unquantized_output_dec = split_decoder([tf.stop_gradient(ndze_select), tf.stop_gradient(quant_embed_dec), state_select])
+ combined_output.append(split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select]))
+ unquantized_output.append(split_decoder([ndze_select, quant_embed_dec, state_select]))
- e2 = Concatenate(name="hard_bits")([dze_select, hard_distr_embed, lambda_bunched])
- e = Concatenate(name="soft_bits")([dze_select, soft_distr_embed, lambda_bunched])
+ concat = Lambda(tensor_concat, name="output")
+ combined_output = concat(combined_output)
+ unquantized_output = concat(unquantized_output)
+
+ e2 = Concatenate(name="hard_bits")([dze, hard_distr_embed, lambda_val])
+ e = Concatenate(name="soft_bits")([dze, soft_distr_embed, lambda_val])
- model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")
+ model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, e, e2], name="end2end")
model.nb_used_features = nb_used_features
return model, encoder, decoder
--- a/dnn/training_tf2/train_rdovae.py
+++ b/dnn/training_tf2/train_rdovae.py
@@ -100,7 +100,7 @@
with strategy.scope():
model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
-    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[0.5, 0.5, 0., 1., .1], metrics={'split':'mse', 'hard_bits':rdovae.sq_rate_metric})+    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[0.5, 0.5, 1., .1], metrics={'hard_bits':rdovae.sq_rate_metric})model.summary()
lpc_order = 16
@@ -147,4 +147,4 @@
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
callbacks.append(tensorboard_callback)
-model.fit([features, quant_id, lambda_val], [features, features, features, features, features], batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
+model.fit([features, quant_id, lambda_val], [features, features, features, features], batch_size=batch_size, epochs=nb_epochs, validation_split=0.0, callbacks=callbacks)
--
⑨