shithub: opus

Download patch

ref: 222662dac8bfbc2d764142d178b91f9d928f56cc
parent: 4e104555e98c8227464f02ee388d983d387612b6
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Tue Nov 7 12:46:38 EST 2023

DRED: quantize scale and dead zone to 8 bits

--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@
 srcdir=`dirname $0`
 test -n "$srcdir" && cd "$srcdir"
 
-dnn/download_model.sh 2386a60
+dnn/download_model.sh b6095cf
 
 echo "Updating build configuration files, please wait...."
 
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -59,33 +59,35 @@
     p0              = torch.sigmoid(w[:, 4 , :]).numpy()
     p0              = 1 - r ** (0.5 + 0.5 * p0)
 
+    scales_norm = 255./256./(1e-15+np.max(quant_scales,axis=0))
+    quant_scales = quant_scales*scales_norm
     quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
-    dead_zone_q10   = np.round(dead_zone * 2**10).astype(np.uint16)
+    dead_zone_q8   = np.clip(np.round(dead_zone * 2**8), 0, 255).astype(np.uint16)
     r_q8           = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
     p0_q8          = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
 
     mask = (np.max(r_q8,axis=0) > 0) * (np.min(p0_q8,axis=0) < 255)
     quant_scales_q8 = quant_scales_q8[:, mask]
-    dead_zone_q10 = dead_zone_q10[:, mask]
+    dead_zone_q8 = dead_zone_q8[:, mask]
     r_q8 = r_q8[:, mask]
     p0_q8 = p0_q8[:, mask]
     N = r_q8.shape[-1]
 
-    print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False)
-    print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False)
+    print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint8', static=False)
+    print_vector(writer.source, dead_zone_q8, f'dred_{name}_dead_zone_q8', dtype='opus_uint8', static=False)
     print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
     print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)
 
     writer.header.write(
 f"""
-extern const opus_uint16 dred_{name}_quant_scales_q8[{levels * N}];
-extern const opus_uint16 dred_{name}_dead_zone_q10[{levels * N}];
+extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}];
+extern const opus_uint8 dred_{name}_dead_zone_q8[{levels * N}];
 extern const opus_uint8 dred_{name}_r_q8[{levels * N}];
 extern const opus_uint8 dred_{name}_p0_q8[{levels * N}];
 
 """
     )
-    return N, mask
+    return N, mask, torch.tensor(scales_norm[mask])
 
 
 def c_export(args, model):
@@ -128,14 +130,16 @@
     levels = qembedding.shape[0]
     qembedding = torch.reshape(qembedding, (levels, 6, -1))
 
-    latent_dim, latent_mask = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
-    state_dim, state_mask = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
+    latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
+    state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
 
     padded_latent_dim = (latent_dim+7)//8*8
     latent_pad = padded_latent_dim - latent_dim;
     w = latent_out.weight[latent_mask,:]
+    w = w/latent_scale[:, None]
     w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0)
     b = latent_out.bias[latent_mask]
+    b = b/latent_scale
     b = torch.cat([b, torch.zeros(latent_pad)], dim=0)
     latent_out.weight = torch.nn.Parameter(w)
     latent_out.bias = torch.nn.Parameter(b)
@@ -143,8 +147,10 @@
     padded_state_dim = (state_dim+7)//8*8
     state_pad = padded_state_dim - state_dim;
     w = state_out.weight[state_mask,:]
+    w = w/state_scale[:, None]
     w = torch.cat([w, torch.zeros(state_pad, w.shape[1])], dim=0)
     b = state_out.bias[state_mask]
+    b = b/state_scale
     b = torch.cat([b, torch.zeros(state_pad)], dim=0)
     state_out.weight = torch.nn.Parameter(w)
     state_out.bias = torch.nn.Parameter(b)
@@ -151,8 +157,8 @@
 
     latent_in = model.get_submodule('core_decoder.module.dense_1')
     state_in = model.get_submodule('core_decoder.module.hidden_init')
-    latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask])
-    state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask])
+    latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask]*latent_scale)
+    state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask]*state_scale)
 
     # encoder
     encoder_dense_layers = [
--- a/silk/dred_decoder.c
+++ b/silk/dred_decoder.c
@@ -45,7 +45,7 @@
   return (x ^ m) - m;
 }
 
-static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint16 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
+static void dred_decode_latents(ec_dec *dec, float *x, const opus_uint8 *scale, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
     int i;
     for (i=0;i<dim;i++) {
         int q;
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -223,7 +223,7 @@
     }
 }
 
-static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint16 *scale, const opus_uint16 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
+static void dred_encode_latents(ec_enc *enc, const float *x, const opus_uint8 *scale, const opus_uint8 *dzone, const opus_uint8 *r, const opus_uint8 *p0, int dim) {
     int i;
     int q[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
     float xq[IMAX(DRED_LATENT_DIM,DRED_STATE_DIM)];
@@ -233,7 +233,7 @@
     /* This is split into multiple loops (with temporary arrays) so that the compiler
        can vectorize all of it, and so we can call the vector tanh(). */
     for (i=0;i<dim;i++) {
-        delta[i] = dzone[i]*(1.f/1024.f);
+        delta[i] = dzone[i]*(1.f/256.f);
         xq[i] = x[i]*scale[i]*(1.f/256.f);
         deadzone[i] = xq[i]/(delta[i]+eps);
     }
@@ -272,7 +272,7 @@
         &ec_encoder,
         enc->initial_state,
         dred_state_quant_scales_q8 + state_qoffset,
-        dred_state_dead_zone_q10 + state_qoffset,
+        dred_state_dead_zone_q8 + state_qoffset,
         dred_state_r_q8 + state_qoffset,
         dred_state_p0_q8 + state_qoffset,
         DRED_STATE_DIM);
@@ -291,7 +291,7 @@
             &ec_encoder,
             enc->latents_buffer + (i+enc->latent_offset) * DRED_LATENT_DIM,
             dred_latent_quant_scales_q8 + offset,
-            dred_latent_dead_zone_q10 + offset,
+            dred_latent_dead_zone_q8 + offset,
             dred_latent_r_q8 + offset,
             dred_latent_p0_q8 + offset,
             DRED_LATENT_DIM
--