ref: 38dda0f950e126badf83ecabd8449d27796a9040
parent: b43f077ba84a4d1e843377a4208aae30b3d43011
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Sep 28 11:33:20 EDT 2022
Oops, forgot to run PVQ quantization for the state
--- a/dnn/training_tf2/decode_rdovae.py
+++ b/dnn/training_tf2/decode_rdovae.py
@@ -56,6 +56,7 @@
import h5py
import tensorflow as tf
+from rdovae import pvq_quantize
# Try reducing batch_size if you run out of memory on your GPU
batch_size = args.batch_size
@@ -87,6 +88,8 @@
state = np.reshape(state, (nb_sequences, sequence_size//2, 24))
state = state[:,-1,:]
+state = pvq_quantize(state, 30)
+#state = state/(1e-15+tf.norm(state, axis=-1,keepdims=True))
print("shapes are:")print(bits.shape)
--
⑨