ref: 159da408904ceb3f13756a1349e2f3f47a375222
parent: 818a0496d5171e09d498c8eb17fd3cd9ed12c516
author: Jan Buethe <jbuethe@amazon.de>
date: Tue Oct 25 08:59:17 EDT 2022
fixed calculation of p0
--- a/dnn/training_tf2/dump_rdovae.py
+++ b/dnn/training_tf2/dump_rdovae.py
@@ -58,9 +58,9 @@
print("dumping statistical model")quant_scales = tf.math.softplus(w[:, : N]).numpy()
dead_zone = 0.05 * tf.math.softplus(w[:, N : 2 * N]).numpy()
- theta = 0.5 + 0.5 * tf.math.sigmoid(w[:, 4 * N : 5 * N]).numpy()
r = tf.math.sigmoid(w[:, 5 * N : 6 * N]).numpy()
- p0 = 1 - r ** (0.5 + 0.5 * theta)
+ p0 = tf.math.sigmoid(w[:, 4 * N : 5 * N]).numpy()
+ p0 = 1 - r ** (0.5 + 0.5 * p0)
quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.int16)
dead_zone_q10 = np.round(dead_zone * 2**10).astype(np.int16)
--
⑨