ref: 79d1a916d0715f7bcd188819abb53631842eeb2d
parent: 0b018637325bca70b5f3a727dd1b1e83f87c2829
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sun Oct 2 08:34:42 EDT 2022
Weighting loss by 1/sqrt(lambda)
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -94,14 +94,16 @@
return log2_e*tf.math.log(eps+x)
def feat_dist_loss(y_true,y_pred):
+ lambda_1 = 1./K.sqrt(y_pred[:,:,:,-1])
+ y_pred = y_pred[:,:,:,:-1]
ceps = y_pred[:,:,:,:18] - y_true[:,:,:18]
pitch = 2*(y_pred[:,:,:,18:19] - y_true[:,:,18:19])/(y_true[:,:,18:19] + 2)
corr = y_pred[:,:,:,19:] - y_true[:,:,19:]
pitch_weight = K.square(K.maximum(0., y_true[:,:,19:]+.5))
- return K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr))
+ return K.mean(lambda_1*K.mean(K.square(ceps) + 10*(1/18.)*K.abs(pitch)*pitch_weight + (1/18.)*K.square(corr), axis=-1))
def sq1_rate_loss(y_true,y_pred):
- lambda_val = y_pred[:,:,-1]
+ lambda_val = K.sqrt(y_pred[:,:,-1])
y_pred = y_pred[:,:,:-1]
log2_e = 1.4427
n = y_pred.shape[-1]//3
@@ -120,7 +122,7 @@
return K.mean(rate)
def sq2_rate_loss(y_true,y_pred):
- lambda_val = y_pred[:,:,-1]
+ lambda_val = K.sqrt(y_pred[:,:,-1])
y_pred = y_pred[:,:,:-1]
log2_e = 1.4427
n = y_pred.shape[-1]//3
@@ -136,7 +138,6 @@
return K.mean(rate)
def sq_rate_metric(y_true,y_pred, reduce=True):
- lambda_val = y_pred[:,:,-1]
y_pred = y_pred[:,:,:-1]
log2_e = 1.4427
n = y_pred.shape[-1]//3
@@ -311,6 +312,7 @@
quant_id = Input(shape=(None,), batch_size=batch_size)
lambda_val = Input(shape=(None, 1), batch_size=batch_size)
lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
+ lambda_up = Lambda(lambda x: K.repeat_elements(x, 2, axis=-2))(lambda_val)
qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
quant_embed_dec = qembedding(quant_id)
@@ -341,12 +343,19 @@
gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
combined_output = []
unquantized_output = []
+ cat = Concatenate(name="out_cat")
for i in range(bunch//2):
dze_select = mod_select([dze_unquant, i])
ndze_select = mod_select([ndze_unquant, i])
state_select = mod_select([gru_state_dec, i])
- combined_output.append(split_decoder([hardquant(dze_select), state_select]))
- unquantized_output.append(split_decoder([ndze_select, state_select]))
+
+ tmp = split_decoder([hardquant(dze_select), state_select])
+ tmp = cat([tmp, lambda_up])
+ combined_output.append(tmp)
+
+ tmp = split_decoder([ndze_select, state_select])
+ tmp = cat([tmp, lambda_up])
+ unquantized_output.append(tmp)
concat = Lambda(tensor_concat, name="output")
combined_output = concat(combined_output)
--
⑨