shithub: opus

Download patch

ref: 79d1a916d0715f7bcd188819abb53631842eeb2d
parent: 0b018637325bca70b5f3a727dd1b1e83f87c2829
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sun Oct 2 08:34:42 EDT 2022

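Weighting loss by 1/sqrt(lambda)

The feature distortion loss (feat_dist_loss) is now scaled by 1/sqrt(lambda),
while sq1_rate_loss and sq2_rate_loss take sqrt(lambda) as their lambda value.
To make lambda available to the distortion loss, it is upsampled by 2 and
concatenated onto each decoder output as the last channel. The unused
lambda_val in sq_rate_metric is removed.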

--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -94,14 +94,16 @@
     return log2_e*tf.math.log(eps+x)
 
 def feat_dist_loss(y_true,y_pred):
+    lambda_1 = 1./K.sqrt(y_pred[:,:,:,-1])
+    y_pred = y_pred[:,:,:,:-1]
     ceps = y_pred[:,:,:,:18] - y_true[:,:,:18]
     pitch = 2*(y_pred[:,:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
     corr = y_pred[:,:,:,19:] - y_true[:,:,19:]
     pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
-    return K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr))
+    return K.mean(lambda_1*K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr), axis=-1))
 
 def sq1_rate_loss(y_true,y_pred):
-    lambda_val = y_pred[:,:,-1]
+    lambda_val = K.sqrt(y_pred[:,:,-1])
     y_pred = y_pred[:,:,:-1]
     log2_e = 1.4427
     n = y_pred.shape[-1]//3
@@ -120,7 +122,7 @@
     return K.mean(rate)
 
 def sq2_rate_loss(y_true,y_pred):
-    lambda_val = y_pred[:,:,-1]
+    lambda_val = K.sqrt(y_pred[:,:,-1])
     y_pred = y_pred[:,:,:-1]
     log2_e = 1.4427
     n = y_pred.shape[-1]//3
@@ -136,7 +138,6 @@
     return K.mean(rate)
 
 def sq_rate_metric(y_true,y_pred, reduce=True):
-    lambda_val = y_pred[:,:,-1]
     y_pred = y_pred[:,:,:-1]
     log2_e = 1.4427
     n = y_pred.shape[-1]//3
@@ -311,6 +312,7 @@
     quant_id = Input(shape=(None,), batch_size=batch_size)
     lambda_val = Input(shape=(None, 1), batch_size=batch_size)
     lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
+    lambda_up = Lambda(lambda x: K.repeat_elements(x, 2, axis=-2))(lambda_val)
 
     qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
     quant_embed_dec = qembedding(quant_id)
@@ -341,12 +343,19 @@
     gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
     combined_output = []
     unquantized_output = []
+    cat = Concatenate(name="out_cat")
     for i in range(bunch//2):
         dze_select = mod_select([dze_unquant, i])
         ndze_select = mod_select([ndze_unquant, i])
         state_select = mod_select([gru_state_dec, i])
-        combined_output.append(split_decoder([hardquant(dze_select), state_select]))
-        unquantized_output.append(split_decoder([ndze_select, state_select]))
+
+        tmp = split_decoder([hardquant(dze_select), state_select])
+        tmp = cat([tmp, lambda_up])
+        combined_output.append(tmp)
+
+        tmp = split_decoder([ndze_select, state_select])
+        tmp = cat([tmp, lambda_up])        
+        unquantized_output.append(tmp)
 
     concat = Lambda(tensor_concat, name="output")
     combined_output = concat(combined_output)
--