ref: 981d06eefda5ecfaf44bce05d75cccd01a3a1e24
parent: a4f7c157cf879ad768f97c3fa31520ede267bea7
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Fri Sep 9 22:18:30 EDT 2022
Refactoring towards multiple offset decoding
--- a/dnn/training_tf2/rdovae.py
+++ b/dnn/training_tf2/rdovae.py
@@ -282,27 +282,22 @@
nb_bits = decoder.nb_bits
bunch = decoder.bunch
bits_input = Input(shape=(None, nb_bits), name="split_bits")
- uqbits_input = Input(shape=(None, nb_bits), name="split_uqbits")
quant_embed_input = Input(shape=(None, 6*nb_bits), name="split_embed")
gru_state_input = Input(shape=(None,nb_state_dim), name="split_state")
- range_select = Lambda(lambda x: x[0][:,x[1]+bunch//2-1:x[2]:bunch//2,:])
+ range_select = Lambda(lambda x: x[0][:,x[1]:x[2],:])
elem_select = Lambda(lambda x: x[0][:,x[1],:])
points = [0, 64, 128, 192, 256]
outputs = []
- uqbits = []
for i in range(len(points)-1):
- begin = points[i]//2
- end = points[i+1]//2
+ begin = points[i]//bunch
+ end = points[i+1]//bunch
state = elem_select([gru_state_input, end-1])
bits = range_select([bits_input, begin, end])
- uq = range_select([uqbits_input, begin, end])
- uqbits.append(uq)
embed = range_select([quant_embed_input, begin, end])
outputs.append(decoder([bits, embed, state]))
output = Concatenate(axis=1)(outputs)
- uqbits = Concatenate(axis=1)(uqbits)
- split = Model([bits_input, uqbits_input, quant_embed_input, gru_state_input], [output, uqbits], name="split")
+ split = Model([bits_input, quant_embed_input, gru_state_input], output, name="split")
return split
@@ -327,14 +322,20 @@
hardquant = Lambda(hard_quantize)
dzone = Lambda(apply_dead_zone)
dze = dzone([ze,dead_zone])
+
+ mod_select = Lambda(lambda x: x[0][:,x[1]::bunch//2,:])
gru_state_dec = Lambda(lambda x: pvq_quantize(x, 30))(gru_state_dec)
- combined_output, uqbits = split_decoder([hardquant(dze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec])
ndze = noisequant(dze)
- unquantized_output, uqbits = split_decoder([ndze, dze, quant_embed_dec, gru_state_dec])
- unquantized_output_dec, uqbits = split_decoder([tf.stop_gradient(ndze), dze, tf.stop_gradient(quant_embed_dec), gru_state_dec])
+ for i in [1]:
+ dze_select = mod_select([dze, i])
+ ndze_select = mod_select([ndze, i])
+ state_select = mod_select([gru_state_dec, i])
+ combined_output = split_decoder([hardquant(dze_select), tf.stop_gradient(quant_embed_dec), state_select])
+ unquantized_output = split_decoder([ndze_select, quant_embed_dec, state_select])
+ unquantized_output_dec = split_decoder([tf.stop_gradient(ndze_select), tf.stop_gradient(quant_embed_dec), state_select])
- e2 = Concatenate(name="hard_bits")([uqbits, hard_distr_embed, lambda_bunched])
- e = Concatenate(name="soft_bits")([uqbits, soft_distr_embed, lambda_bunched])
+ e2 = Concatenate(name="hard_bits")([dze_select, hard_distr_embed, lambda_bunched])
+ e = Concatenate(name="soft_bits")([dze_select, soft_distr_embed, lambda_bunched])
model = Model([feat, quant_id, lambda_val], [combined_output, unquantized_output, unquantized_output_dec, e, e2], name="end2end")
--
⑨