shithub: opus

Download patch

ref: 9629ea6a7015ea64f3c8d4d78cd2453bc8379795
parent: 0f7fe64d5a438db9c4cf6b6640a43c98d8565191
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Thu Oct 13 21:01:39 EDT 2022

Fine-tuning the scripts

--- a/dnn/training_tf2/fec_encoder.py
+++ b/dnn/training_tf2/fec_encoder.py
@@ -108,8 +108,8 @@
 features = features[:, :num_subframes, :]
 
 #variable quantizer depending on the delay
-q0 = 2
-q1 = 10
+q0 = 3
+q1 = 15
 quant_id = np.round(q1 + (q0-q1)*np.arange(args.num_redundancy_frames//2)/args.num_redundancy_frames).astype('int16')
 #print(quant_id)
 
@@ -154,7 +154,10 @@
 fake_lambda = np.ones((sym_batch.shape[0], sym_batch.shape[1], 1), dtype='float32')
 rate_input = np.concatenate((sym_batch, hard_distr_embed, fake_lambda), axis=-1)
 rates = sq_rate_metric(None, rate_input, reduce=False).numpy()
-print("rate = ", np.mean(rates))
+print(rates.shape)
+print("average rate = ", np.mean(rates[args.num_redundancy_frames:,:]))
+
+#sym_batch.tofile('qsyms.f32')
 
 sym_batch = sym_batch / quant_scale
 print(sym_batch.shape, quant_state.shape)
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -167,8 +167,8 @@
         abs_kx = tf.abs(kx)
         kk=tf.reduce_sum(abs_y, axis=-1)
         #print("sums = ", kk)
-        plus = 1.0001*tf.reduce_min((abs_y+.5)/(abs_kx+1e-15), axis=-1)
-        minus = .9999*tf.reduce_max((abs_y-.5)/(abs_kx+1e-15), axis=-1)
+        plus = 1.000001*tf.reduce_min((abs_y+.5)/(abs_kx+1e-15), axis=-1)
+        minus = .999999*tf.reduce_max((abs_y-.5)/(abs_kx+1e-15), axis=-1)
         #print("plus = ", plus)
         #print("minus = ", minus)
         factor = tf.where(kk>k, minus, plus)
@@ -183,7 +183,7 @@
         y = tf.round(kx)
 
     #print(y)
-    
+    #print(K.mean(K.sum(K.abs(y), axis=-1)))
     return y
 
 def pvq_quantize(x, k):
@@ -281,7 +281,7 @@
 
     range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
     elem_select = Lambda(lambda x: x[0][:,x[1],:])
-    points = [0, 64, 128, 192, 256]
+    points = [0, 100, 200, 300, 400]
     outputs = []
     for i in range(len(points)-1):
         begin = points[i]//bunch
--