shithub: opus

Download patch

ref: 77594bf158bae48c8267f1f548209caa118ae7d5
parent: 222662dac8bfbc2d764142d178b91f9d928f56cc
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Wed Nov 8 12:32:43 EST 2023

Dumping RDOVAE stats to XML

--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -48,8 +48,27 @@
 from wexchange.torch import dump_torch_weights
 from wexchange.c_export import CWriter, print_vector
 
-
-def dump_statistical_model(writer, w, name):
+def print_xml(xmlout, val, param, anchor, name):
+    """Write 2-D `val` to `xmlout` as an XML <table>: one row per k, one Q column per val.shape[0] entry."""
+    # Derive the header from val.shape[0] so it always matches the row width
+    # (a literal Q0..Q15 header is only correct for exactly 16 levels).
+    q_heads = "".join(f"<th>Q{q}</th>" for q in range(val.shape[0]))
+    xmlout.write(f"""
+            <table anchor="{anchor}_{name}">
+                <name>{param} values for {name}</name>
+                <thead>
+                    <tr><th>k</th>{q_heads}</tr>
+                </thead>
+                <tbody>
+""")
+    for k in range(val.shape[1]):
+        cells = "".join(f"<th>{val[q][k]}</th>" for q in range(val.shape[0]))
+        xmlout.write(f"        <tr><th>{k}</th>{cells}</tr>\n")
+    xmlout.write("""
+                </tbody>
+            </table>
+""")
+def dump_statistical_model(writer, w, name, xmlout):
     levels = w.shape[0]
 
     print("printing statistical model")
@@ -78,6 +97,11 @@
     print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
     print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)
 
+    print_xml(xmlout, quant_scales_q8, "Scale", "scale", name)
+    print_xml(xmlout, dead_zone_q8, "Dead zone", "deadzone", name)
+    print_xml(xmlout, r_q8, "Decay (r)", "decay", name)
+    print_xml(xmlout, p0_q8, "P(0)", "p0", name)
+
     writer.header.write(
 f"""
 extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}];
@@ -98,6 +122,7 @@
     dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message, model_struct_name='RDOVAEDec')
     stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message, enable_binary_blob=False)
     constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True, enable_binary_blob=False)
+    xmlout = open("stats.xml", "w")
 
     # some custom includes
     for writer in [enc_writer, dec_writer]:
@@ -130,8 +155,8 @@
     levels = qembedding.shape[0]
     qembedding = torch.reshape(qembedding, (levels, 6, -1))
 
-    latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
-    state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
+    latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent', xmlout)
+    state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state', xmlout)
 
     padded_latent_dim = (latent_dim+7)//8*8
     latent_pad = padded_latent_dim - latent_dim;
--