shithub: opus

Download patch

ref: 4c82d3b4196e1e691f6cac8a60a43c1fcb733b86
parent: cd0993fd8c9b835353e2c0425df7a1ab02a744bf
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Thu Sep 29 23:53:40 EDT 2022

Completely move quantization out of encoder and decoder

--- a/dnn/training_tf2/decode_rdovae.py
+++ b/dnn/training_tf2/decode_rdovae.py
@@ -57,15 +57,18 @@
 
 import tensorflow as tf
 from rdovae import pvq_quantize
+from rdovae import apply_dead_zone
 
 # Try reducing batch_size if you run out of memory on your GPU
 batch_size = args.batch_size
 
-model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
+model, encoder, decoder, qembedding = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
 model.load_weights(args.weights)
 
 lpc_order = 16
+nbits=80
 
+
 bits_file = args.bits
 sequence_size = args.seq_length
 
@@ -72,30 +75,37 @@
 # u for unquantised, load 16 bit PCM samples and convert to mu-law
 
 
-bits = np.memmap(bits_file + "-bits.s16", dtype='int16', mode='r')
+bits = np.memmap(bits_file + "-syms.f32", dtype='float32', mode='r')
 nb_sequences = len(bits)//(40*sequence_size)//batch_size*batch_size
 bits = bits[:nb_sequences*sequence_size*40]
 
 bits = np.reshape(bits, (nb_sequences, sequence_size//2, 20*4))
-bits = bits[:,1::2,:]
 print(bits.shape)
 
-quant = np.memmap(bits_file + "-quant.f32", dtype='float32', mode='r')
-state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
+lambda_val = 0.0007 * np.ones((nb_sequences, sequence_size//2, 1))
+quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
+quant_id = quant_id[:,:,0]
+quant_embed = qembedding(quant_id)
+quant_scale = tf.math.softplus(quant_embed[:,:,:nbits])
+dead_zone = tf.math.softplus(quant_embed[:, :, nbits : 2 * nbits])
 
-quant = np.reshape(quant, (nb_sequences, sequence_size//2, 6*20*4))
-quant = quant[:,1::2,:]
+bits = bits*quant_scale
+bits = np.round(apply_dead_zone([bits, dead_zone]).numpy())
+bits = bits/quant_scale
 
+
+state = np.memmap(bits_file + "-state.f32", dtype='float32', mode='r')
+
 state = np.reshape(state, (nb_sequences, sequence_size//2, 24))
 state = state[:,-1,:]
-state = pvq_quantize(state, 30)
-#state = state/(1e-15+tf.norm(state, axis=-1,keepdims=True))
+#state = pvq_quantize(state, 30)
+state = state/(1e-15+tf.norm(state, axis=-1,keepdims=True))
 
 print("shapes are:")
 print(bits.shape)
-print(quant.shape)
 print(state.shape)
 
-features = decoder.predict([bits, quant, state], batch_size=batch_size)
+bits = bits[:,1::2,:]
+features = decoder.predict([bits, state], batch_size=batch_size)
 
 features.astype('float32').tofile(args.output)
--- a/dnn/training_tf2/encode_rdovae.py
+++ b/dnn/training_tf2/encode_rdovae.py
@@ -58,11 +58,12 @@
 import h5py
 
 import tensorflow as tf
+from rdovae import pvq_quantize
 
 # Try reducing batch_size if you run out of memory on your GPU
 batch_size = args.batch_size
 
-model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
+model, encoder, decoder, qembedding = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
 model.load_weights(args.weights)
 
 lpc_order = 16
@@ -84,17 +85,11 @@
 features = features[:, :, :nb_used_features]
 #features = np.random.randn(73600, 1000, 17)
 
-lambda_val = 0.001 * np.ones((nb_sequences, sequence_size//2, 1))
-quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
-quant_id = quant_id[:,:,0]
 
-
-bits, quant_embed_dec, gru_state_dec = encoder.predict([features, quant_id, lambda_val], batch_size=batch_size)
+bits, gru_state_dec = encoder.predict([features], batch_size=batch_size)
 (gru_state_dec).astype('float32').tofile(args.output + "-state.f32")
 
 
-#quant_out, _, _, model_bits, _ = model.predict([features, quant_id, lambda_val], batch_size=batch_size)
-
 #dist = rdovae.feat_dist_loss(features, quant_out)
 #rate = rdovae.sq1_rate_loss(features, model_bits)
 #rate2 = rdovae.sq_rate_metric(features, model_bits)
@@ -102,20 +97,29 @@
 
 print("shapes are:")
 print(bits.shape)
-print(quant_embed_dec.shape)
 print(gru_state_dec.shape)
 
 features.astype('float32').tofile(args.output + "-input.f32")
 #quant_out.astype('float32').tofile(args.output + "-enc_dec.f32")
 nbits=80
-dead_zone = tf.math.softplus(quant_embed_dec[:, :, nbits : 2 * nbits])
-symbols = apply_dead_zone([bits, dead_zone]).numpy()
-np.round(bits).astype('int16').tofile(args.output + "-bits.s16")
-quant_embed_dec.astype('float32').tofile(args.output + "-quant.f32")
+bits.astype('float32').tofile(args.output + "-syms.f32")
 
+lambda_val = 0.0007 * np.ones((nb_sequences, sequence_size//2, 1))
+quant_id = np.round(10*np.log(lambda_val/.0007)).astype('int16')
+quant_id = quant_id[:,:,0]
+quant_embed = qembedding(quant_id)
+quant_scale = tf.math.softplus(quant_embed[:,:,:nbits])
+dead_zone = tf.math.softplus(quant_embed[:, :, nbits : 2 * nbits])
+
+bits = bits*quant_scale
+bits = np.round(apply_dead_zone([bits, dead_zone]).numpy())
+bits = bits/quant_scale
+
+gru_state_dec = pvq_quantize(gru_state_dec, 30)
+#gru_state_dec = gru_state_dec/(1e-15+tf.norm(gru_state_dec, axis=-1,keepdims=True))
 gru_state_dec = gru_state_dec[:,-1,:]
-dec_out = decoder([bits[:,1::2,:], quant_embed_dec[:,1::2,:], gru_state_dec])
+dec_out = decoder([bits[:,1::2,:], gru_state_dec])
 
 print(dec_out.shape)
 
-dec_out.numpy().astype('float32').tofile(args.output + "-dec_out.f32")
+dec_out.numpy().astype('float32').tofile(args.output + "-unquant_out.f32")
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -200,13 +200,6 @@
 def new_rdovae_encoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
     feat = Input(shape=(None, nb_used_features), batch_size=batch_size)
 
-    quant_id = Input(shape=(None,), batch_size=batch_size)
-    lambda_val = Input(shape=(None, 1), batch_size=batch_size)
-    qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
-    quant_embed = qembedding(quant_id)
-
-    quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed))
-
     gru = CuDNNGRU if training else GRU
     enc_dense1 = Dense(cond_size2, activation='tanh', kernel_constraint=constraint, name='enc_dense1')
     enc_dense2 = gru(cond_size, return_sequences=True, kernel_constraint=constraint, recurrent_constraint=constraint, name='enc_dense2')
@@ -221,8 +214,7 @@
     bits_dense = Conv1D(nb_bits, 4, padding='causal', activation='linear', name='bits_dense')
 
     zero_out = Lambda(lambda x: 0*x)
-    inputs = Concatenate()([Reshape((-1, 2*nb_used_features))(feat), tf.stop_gradient(quant_embed), lambda_val])
-    #inputs = Concatenate()([feat, tf.stop_gradient(quant_embed), lambda_val])
+    inputs = Reshape((-1, 2*nb_used_features))(feat)
     d1 = enc_dense1(inputs)
     d2 = enc_dense2(d1)
     d3 = enc_dense3(d2)
@@ -233,18 +225,15 @@
     d8 = enc_dense8(d7)
     pre_out = Concatenate()([d1, d2, d3, d4, d5, d6, d7, d8])
     enc_out = bits_dense(pre_out)
-    #enc_out = Lambda(lambda x: x[:, bunch//2-1::bunch//2])(enc_out)
-    bits = Multiply()([enc_out, quant_scale])
     global_dense1 = Dense(128, activation='tanh', name='gdense1')
     global_dense2 = Dense(nb_state_dim, activation='tanh', name='gdense2')
     global_bits = global_dense2(global_dense1(pre_out))
 
-    encoder = Model([feat, quant_id, lambda_val], [bits, quant_embed, global_bits], name='encoder')
+    encoder = Model([feat], [enc_out, global_bits], name='encoder')
     return encoder
 
 def new_rdovae_decoder(nb_used_features=20, nb_bits=17, bunch=4, nb_quant=40, batch_size=128, cond_size=128, cond_size2=256, training=False):
     bits_input = Input(shape=(None, nb_bits), batch_size=batch_size, name="dec_bits")
-    quant_embed_input = Input(shape=(None, 6*nb_bits), batch_size=batch_size, name="dec_embed")
     gru_state_input = Input(shape=(nb_state_dim,), batch_size=batch_size, name="dec_state")
 
     
@@ -260,10 +249,8 @@
 
     dec_final = Dense(bunch*nb_used_features, activation='linear', name='dec_final')
 
-    div = Lambda(lambda x: x[0]/x[1])
     time_reverse = Lambda(lambda x: K.reverse(x, 1))
     #time_reverse = Lambda(lambda x: x)
-    quant_scale_dec = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed_dec')(quant_embed_input))
     #gru_state_rep = RepeatVector(64//bunch)(gru_state_input)
 
     #gru_state_rep = Lambda(var_repeat, output_shape=(None, nb_state_dim)) ([gru_state_input, bits_input])
@@ -271,8 +258,7 @@
     gru_state2 = Dense(cond_size, name="state2", activation='tanh')(gru_state_input)
     gru_state3 = Dense(cond_size, name="state3", activation='tanh')(gru_state_input)
 
-    dec_inputs = Concatenate()([div([bits_input,quant_scale_dec]), tf.stop_gradient(quant_embed_input)])
-    dec1 = dec_dense1(time_reverse(dec_inputs))
+    dec1 = dec_dense1(time_reverse(bits_input))
     dec2 = dec_dense2(dec1)
     dec3 = dec_dense3(dec2)
     dec4 = dec_dense4(dec3, initial_state=gru_state1)
@@ -281,7 +267,7 @@
     dec7 = dec_dense7(dec6)
     dec8 = dec_dense8(dec7)
     output = Reshape((-1, nb_used_features))(dec_final(Concatenate()([dec1, dec2, dec3, dec4, dec5, dec6, dec7, dec8])))
-    decoder = Model([bits_input, quant_embed_input, gru_state_input], time_reverse(output), name='decoder')
+    decoder = Model([bits_input, gru_state_input], time_reverse(output), name='decoder')
     decoder.nb_bits = nb_bits
     decoder.bunch = bunch
     return decoder
@@ -290,7 +276,6 @@
     nb_bits = decoder.nb_bits
     bunch = decoder.bunch
     bits_input = Input(shape=(None, nb_bits), name="split_bits")
-    quant_embed_input = Input(shape=(None, 6*nb_bits), name="split_embed")
     gru_state_input = Input(shape=(None,nb_state_dim), name="split_state")
 
     range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
@@ -302,10 +287,9 @@
         end = points[i+1]//bunch
         state = elem_select([gru_state_input, end-1])
         bits = range_select([bits_input, begin, end])
-        embed = range_select([quant_embed_input, begin, end])
-        outputs.append(decoder([bits, embed, state]))
+        outputs.append(decoder([bits, state]))
     output = Concatenate(axis=1)(outputs)
-    split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
+    split = Model([bits_input, gru_state_input], output, name="split")
     return split
 
 def tensor_concat(x):
@@ -328,8 +312,13 @@
     lambda_val = Input(shape=(None, 1), batch_size=batch_size)
     lambda_bunched = AveragePooling1D(pool_size=bunch//2, strides=bunch//2, padding="valid")(lambda_val)
 
+    qembedding = Embedding(nb_quant, 6*nb_bits, name='quant_embed', embeddings_initializer='zeros')
+    quant_embed_dec = qembedding(quant_id)
+    quant_scale = Activation('softplus')(Lambda(lambda x: x[:,:,:nb_bits], name='quant_scale_embed')(quant_embed_dec))
+
     encoder = new_rdovae_encoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
-    ze, quant_embed_dec, gru_state_dec = encoder([feat, quant_id, lambda_val])
+    ze, gru_state_dec = encoder([feat])
+    ze = Multiply()([ze, quant_scale])
 
     decoder = new_rdovae_decoder(nb_used_features, nb_bits, bunch, nb_quant, batch_size, cond_size, cond_size2, training=training)
     split_decoder = new_split_decoder(decoder)
@@ -342,18 +331,22 @@
     hardquant = Lambda(hard_quantize)
     dzone = Lambda(apply_dead_zone)
     dze = dzone([ze,dead_zone])
+    ndze = noisequant(dze)
     
+    div = Lambda(lambda x: x[0]/x[1])
+    dze_unquant = div([dze,quant_scale])
+    ndze_unquant = div([ndze,quant_scale])
+
     mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
     gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
-    ndze = noisequant(dze)
     combined_output = []
     unquantized_output = []
     for i in range(bunch//2):
-        dze_select = mod_select([dze, i])
-        ndze_select = mod_select([ndze, i])
+        dze_select = mod_select([dze_unquant, i])
+        ndze_select = mod_select([ndze_unquant, i])
         state_select = mod_select([gru_state_dec, i])
-        combined_output.append(split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select]))
-        unquantized_output.append(split_decoder([ndze_select, quant_embed_dec, state_select]))
+        combined_output.append(split_decoder([hardquant(dze_select), state_select]))
+        unquantized_output.append(split_decoder([ndze_select, state_select]))
 
     concat = Lambda(tensor_concat, name="output")
     combined_output = concat(combined_output)
@@ -366,5 +359,5 @@
     model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, e, e2], name="end2end")
     model.nb_used_features = nb_used_features
 
-    return model, encoder, decoder
+    return model, encoder, decoder, qembedding
 
--- a/dnn/training_tf2/train_rdovae.py
+++ b/dnn/training_tf2/train_rdovae.py
@@ -99,8 +99,8 @@
 opt = Adam(lr, decay=decay, beta_2=0.99)
 
 with strategy.scope():
-    model, encoder, decoder = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
-    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[0.5, 0.5, 1., .1], metrics={'hard_bits':rdovae.sq_rate_metric})
+    model, encoder, decoder, _ = rdovae.new_rdovae_model(nb_used_features=20, nb_bits=80, batch_size=batch_size, cond_size=args.cond_size)
+    model.compile(optimizer=opt, loss=[rdovae.feat_dist_loss, rdovae.feat_dist_loss, rdovae.sq1_rate_loss, rdovae.sq2_rate_loss], loss_weights=[.1, .9, 1., .1], metrics={'hard_bits':rdovae.sq_rate_metric})
     model.summary()
 
 lpc_order = 16
--