shithub: opus

Download patch

ref: 0ab0640d4ad41d765ab2b9916f7146c67fe56a3c
parent: 2386a60ec644fadc437155cd6e5f6d4c561940d4
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Mon Nov 6 12:49:18 EST 2023

Split DRED stats into latent and state tables and prune degenerate dimensions

--- a/autogen.sh
+++ b/autogen.sh
@@ -9,7 +9,7 @@
 srcdir=`dirname $0`
 test -n "$srcdir" && cd "$srcdir"
 
-dnn/download_model.sh 98b8be0
+dnn/download_model.sh 2386a60
 
 echo "Updating build configuration files, please wait...."
 
--- a/dnn/dred_rdovae.c
+++ b/dnn/dred_rdovae.c
@@ -77,24 +77,3 @@
 {
     dred_rdovae_decode_qframe(h, model, qframe, z);
 }
-
-
-const opus_uint8 * DRED_rdovae_get_p0_pointer(void)
-{
-    return &dred_p0_q8[0];
-}
-
-const opus_uint16 * DRED_rdovae_get_dead_zone_pointer(void)
-{
-    return &dred_dead_zone_q10[0];
-}
-
-const opus_uint8 * DRED_rdovae_get_r_pointer(void)
-{
-    return &dred_r_q8[0];
-}
-
-const opus_uint16 * DRED_rdovae_get_quant_scales_pointer(void)
-{
-    return &dred_quant_scales_q8[0];
-}
--- a/dnn/dred_rdovae_enc.c
+++ b/dnn/dred_rdovae_enc.c
@@ -34,6 +34,7 @@
 
 #include "dred_rdovae_enc.h"
 #include "os_support.h"
+#include "dred_rdovae_constants.h"
 
 static void conv1_cond_init(float *mem, int len, int dilation, int *init)
 {
@@ -52,6 +53,8 @@
     const float *input              /* i: double feature frame (concatenated) */
     )
 {
+    float padded_latents[DRED_PADDED_LATENT_DIM];
+    float padded_state[DRED_PADDED_STATE_DIM];
     float buffer[ENC_DENSE1_OUT_SIZE + ENC_GRU1_OUT_SIZE + ENC_GRU2_OUT_SIZE + ENC_GRU3_OUT_SIZE + ENC_GRU4_OUT_SIZE + ENC_GRU5_OUT_SIZE
                + ENC_CONV1_OUT_SIZE + ENC_CONV2_OUT_SIZE + ENC_CONV3_OUT_SIZE + ENC_CONV4_OUT_SIZE + ENC_CONV5_OUT_SIZE];
     float state_hidden[GDENSE1_OUT_SIZE];
@@ -96,9 +99,11 @@
     compute_generic_conv1d_dilation(&model->enc_conv5, &buffer[output_index], enc_state->conv5_state, buffer, output_index, 2, ACTIVATION_TANH);
     output_index += ENC_CONV5_OUT_SIZE;
 
-    compute_generic_dense(&model->enc_zdense, latents, buffer, ACTIVATION_LINEAR);
+    compute_generic_dense(&model->enc_zdense, padded_latents, buffer, ACTIVATION_LINEAR);
+    OPUS_COPY(latents, padded_latents, DRED_LATENT_DIM);
 
     /* next, calculate initial state */
     compute_generic_dense(&model->gdense1, state_hidden, buffer, ACTIVATION_TANH);
-    compute_generic_dense(&model->gdense2, initial_state, state_hidden, ACTIVATION_LINEAR);
+    compute_generic_dense(&model->gdense2, padded_state, state_hidden, ACTIVATION_LINEAR);
+    OPUS_COPY(initial_state, padded_state, DRED_STATE_DIM);
 }
--- a/dnn/torch/rdovae/export_rdovae_weights.py
+++ b/dnn/torch/rdovae/export_rdovae_weights.py
@@ -49,37 +49,43 @@
 from wexchange.c_export import CWriter, print_vector
 
 
-def dump_statistical_model(writer, qembedding):
-    w = qembedding.weight.detach()
-    levels, dim = w.shape
-    N = dim // 6
+def dump_statistical_model(writer, w, name):
+    """Write dred_{name}_* stat tables for w (levels x 6 x dim) minus pruned dims; return (kept dim count N, keep mask)."""
+    levels = w.shape[0]

     print("printing statistical model")
-    quant_scales    = torch.nn.functional.softplus(w[:, : N]).numpy()
-    dead_zone       = 0.05 * torch.nn.functional.softplus(w[:, N : 2 * N]).numpy()
-    r               = torch.sigmoid(w[:, 5 * N : 6 * N]).numpy()
-    p0              = torch.sigmoid(w[:, 4 * N : 5 * N]).numpy()
+    quant_scales    = torch.nn.functional.softplus(w[:, 0, :]).numpy()
+    dead_zone       = 0.05 * torch.nn.functional.softplus(w[:, 1, :]).numpy()
+    r               = torch.sigmoid(w[:, 5, :]).numpy()
+    p0              = torch.sigmoid(w[:, 4, :]).numpy()
     p0              = 1 - r ** (0.5 + 0.5 * p0)

     quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
     dead_zone_q10   = np.round(dead_zone * 2**10).astype(np.uint16)
-    r_q15           = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
-    p0_q15          = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)
+    r_q8            = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
+    p0_q8           = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint8)

-    print_vector(writer.source, quant_scales_q8, 'dred_quant_scales_q8', dtype='opus_uint16', static=False)
-    print_vector(writer.source, dead_zone_q10, 'dred_dead_zone_q10', dtype='opus_uint16', static=False)
-    print_vector(writer.source, r_q15, 'dred_r_q8', dtype='opus_uint8', static=False)
-    print_vector(writer.source, p0_q15, 'dred_p0_q8', dtype='opus_uint8', static=False)
+    mask = (np.max(r_q8, axis=0) > 0) & (np.min(p0_q8, axis=0) < 255)  # drop dims where r_q8 == 0 at all levels or p0_q8 == 255 at all levels
+    quant_scales_q8 = quant_scales_q8[:, mask]
+    dead_zone_q10 = dead_zone_q10[:, mask]
+    r_q8 = r_q8[:, mask]
+    p0_q8 = p0_q8[:, mask]
+    N = r_q8.shape[-1]  # number of dims kept after pruning

+    print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint16', static=False)
+    print_vector(writer.source, dead_zone_q10, f'dred_{name}_dead_zone_q10', dtype='opus_uint16', static=False)
+    print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
+    print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)
     writer.header.write(
 f"""
-extern const opus_uint16 dred_quant_scales_q8[{levels * N}];
-extern const opus_uint16 dred_dead_zone_q10[{levels * N}];
-extern const opus_uint8 dred_r_q8[{levels * N}];
-extern const opus_uint8 dred_p0_q8[{levels * N}];
+extern const opus_uint16 dred_{name}_quant_scales_q8[{levels * N}];
+extern const opus_uint16 dred_{name}_dead_zone_q10[{levels * N}];
+extern const opus_uint8 dred_{name}_r_q8[{levels * N}];
+extern const opus_uint8 dred_{name}_p0_q8[{levels * N}];

 """
     )
+    return N, mask  # caller uses mask to prune matching encoder outputs / decoder inputs
 
 
 def c_export(args, model):
@@ -113,6 +119,41 @@
 """
         )
 
+    latent_out = model.get_submodule('core_encoder.module.z_dense')
+    state_out = model.get_submodule('core_encoder.module.state_dense_2')
+    orig_latent_dim = latent_out.weight.shape[0]
+    orig_state_dim = state_out.weight.shape[0]
+    # statistical model
+    qembedding = model.statistical_model.quant_embedding.weight.detach()
+    levels = qembedding.shape[0]
+    qembedding = torch.reshape(qembedding, (levels, 6, -1))
+
+    latent_dim, latent_mask = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent')
+    state_dim, state_mask = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state')
+
+    padded_latent_dim = (latent_dim+7)//8*8
+    latent_pad = padded_latent_dim - latent_dim;
+    w = latent_out.weight[latent_mask,:]
+    w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0)
+    b = latent_out.bias[latent_mask]
+    b = torch.cat([b, torch.zeros(latent_pad)], dim=0)
+    latent_out.weight = torch.nn.Parameter(w)
+    latent_out.bias = torch.nn.Parameter(b)
+
+    padded_state_dim = (state_dim+7)//8*8
+    state_pad = padded_state_dim - state_dim;
+    w = state_out.weight[state_mask,:]
+    w = torch.cat([w, torch.zeros(state_pad, w.shape[1])], dim=0)
+    b = state_out.bias[state_mask]
+    b = torch.cat([b, torch.zeros(state_pad)], dim=0)
+    state_out.weight = torch.nn.Parameter(w)
+    state_out.bias = torch.nn.Parameter(b)
+
+    latent_in = model.get_submodule('core_decoder.module.dense_1')
+    state_in = model.get_submodule('core_decoder.module.hidden_init')
+    latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask])
+    state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask])
+
     # encoder
     encoder_dense_layers = [
         ('core_encoder.module.dense_1'       , 'enc_dense1',   'TANH', False,),
@@ -187,10 +228,6 @@
 
     del dec_writer
 
-    # statistical model
-    qembedding = model.statistical_model.quant_embedding
-    dump_statistical_model(stats_writer, qembedding)
-
     del stats_writer
 
     # constants
@@ -198,9 +235,13 @@
 f"""
 #define DRED_NUM_FEATURES {model.feature_dim}
 
-#define DRED_LATENT_DIM {model.latent_dim}
+#define DRED_LATENT_DIM {latent_dim}
 
-#define DRED_STATE_DIME {model.state_dim}
+#define DRED_STATE_DIM {state_dim}
+
+#define DRED_PADDED_LATENT_DIM {padded_latent_dim}
+
+#define DRED_PADDED_STATE_DIM {padded_state_dim}
 
 #define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}
 
--- a/dnn/torch/weight-exchange/wexchange/c_export/common.py
+++ b/dnn/torch/weight-exchange/wexchange/c_export/common.py
@@ -124,6 +124,7 @@
     return diag, B
 
 def quantize_weight(weight, scale):
+    scale = scale + 1e-30
     Aq = np.round(weight / scale).astype('int')
     if Aq.max() > 127 or Aq.min() <= -128:
         raise ValueError("value out of bounds in quantize_weight")
@@ -227,7 +228,7 @@
 
     nb_inputs, nb_outputs = weight.shape
 
-    if scale is None:
+    if scale is None and quantize:
         scale = compute_scaling(weight)
 
 
@@ -359,4 +360,4 @@
     writer.header.write(f"\n#define {name.upper()}_OUT_SIZE {N}\n")
     writer.header.write(f"\n#define {name.upper()}_STATE_SIZE {N}\n")
 
-    return N
\ No newline at end of file
+    return N
--- a/silk/dred_config.h
+++ b/silk/dred_config.h
@@ -39,9 +39,6 @@
 #define DRED_MIN_BYTES 16
 
 /* these are inpart duplicates to the values defined in dred_rdovae_constants.h */
-#define DRED_NUM_FEATURES 20
-#define DRED_LATENT_DIM 80
-#define DRED_STATE_DIM 80
 #define DRED_SILK_ENCODER_DELAY (79+12-80)
 #define DRED_FRAME_SIZE 160
 #define DRED_DFRAME_SIZE (2 * (DRED_FRAME_SIZE))
--- a/silk/dred_decoder.c
+++ b/silk/dred_decoder.c
@@ -36,6 +36,8 @@
 #include "dred_coding.h"
 #include "celt/entdec.h"
 #include "celt/laplace.h"
+#include "dred_rdovae_stats_data.h"
+#include "dred_rdovae_constants.h"
 
 /* From http://graphics.stanford.edu/~seander/bithacks.html#FixedSignExtend */
 static int sign_extend(int x, int b) {
@@ -55,9 +57,6 @@
 
 int dred_ec_decode(OpusDRED *dec, const opus_uint8 *bytes, int num_bytes, int min_feature_frames)
 {
-  const opus_uint8 *p0              = DRED_rdovae_get_p0_pointer();
-  const opus_uint16 *quant_scales    = DRED_rdovae_get_quant_scales_pointer();
-  const opus_uint8 *r               = DRED_rdovae_get_r_pointer();
   ec_dec ec;
   int q_level;
   int i;
@@ -78,13 +77,13 @@
   /*printf("%d %d %d\n", dred_offset, q0, dQ);*/
 
   //dred_decode_state(&ec, dec->state);
-  state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_LATENT_DIM;
+  state_qoffset = q0*DRED_STATE_DIM;
   dred_decode_latents(
       &ec,
       dec->state,
-      quant_scales + state_qoffset,
-      r + state_qoffset,
-      p0 + state_qoffset,
+      dred_state_quant_scales_q8 + state_qoffset,
+      dred_state_r_q8 + state_qoffset,
+      dred_state_p0_q8 + state_qoffset,
       DRED_STATE_DIM);
 
   /* decode newest to oldest and store oldest to newest */
@@ -94,13 +93,13 @@
       if (8*num_bytes - ec_tell(&ec) <= 7)
          break;
       q_level = compute_quantizer(q0, dQ, i/2);
-      offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
+      offset = q_level*DRED_LATENT_DIM;
       dred_decode_latents(
           &ec,
           &dec->latents[(i/2)*DRED_LATENT_DIM],
-          quant_scales + offset,
-          r + offset,
-          p0 + offset,
+          dred_latent_quant_scales_q8 + offset,
+          dred_latent_r_q8 + offset,
+          dred_latent_p0_q8 + offset,
           DRED_LATENT_DIM
           );
 
--- a/silk/dred_decoder.h
+++ b/silk/dred_decoder.h
@@ -32,6 +32,7 @@
 #include "dred_config.h"
 #include "dred_rdovae.h"
 #include "entcode.h"
+#include "dred_rdovae_constants.h"
 
 struct OpusDRED {
     float        fec_features[2*DRED_NUM_REDUNDANCY_FRAMES*DRED_NUM_FEATURES];
--- a/silk/dred_encoder.c
+++ b/silk/dred_encoder.c
@@ -44,6 +44,7 @@
 #include "float_cast.h"
 #include "os_support.h"
 #include "celt/laplace.h"
+#include "dred_rdovae_stats_data.h"
 
 
 int dred_encoder_load_model(DREDEnc* enc, const unsigned char *data, int len)
@@ -244,10 +245,6 @@
 }
 
 int dred_encode_silk_frame(const DREDEnc *enc, unsigned char *buf, int max_chunks, int max_bytes) {
-    const opus_uint16 *dead_zone       = DRED_rdovae_get_dead_zone_pointer();
-    const opus_uint8 *p0              = DRED_rdovae_get_p0_pointer();
-    const opus_uint16 *quant_scales    = DRED_rdovae_get_quant_scales_pointer();
-    const opus_uint8 *r               = DRED_rdovae_get_r_pointer();
     ec_enc ec_encoder;
 
     int q_level;
@@ -265,14 +262,14 @@
     ec_enc_uint(&ec_encoder, enc->dred_offset, 32);
     ec_enc_uint(&ec_encoder, q0, 16);
     ec_enc_uint(&ec_encoder, dQ, 8);
-    state_qoffset = q0*(DRED_LATENT_DIM+DRED_STATE_DIM) + DRED_LATENT_DIM;
+    state_qoffset = q0*DRED_STATE_DIM;
     dred_encode_latents(
         &ec_encoder,
         enc->initial_state,
-        quant_scales + state_qoffset,
-        dead_zone + state_qoffset,
-        r + state_qoffset,
-        p0 + state_qoffset,
+        dred_state_quant_scales_q8 + state_qoffset,
+        dred_state_dead_zone_q10 + state_qoffset,
+        dred_state_r_q8 + state_qoffset,
+        dred_state_p0_q8 + state_qoffset,
         DRED_STATE_DIM);
     if (ec_tell(&ec_encoder) > 8*max_bytes) {
       return 0;
@@ -283,15 +280,15 @@
         ec_bak = ec_encoder;
 
         q_level = compute_quantizer(q0, dQ, i/2);
-        offset = q_level * (DRED_LATENT_DIM+DRED_STATE_DIM);
+        offset = q_level * DRED_LATENT_DIM;
 
         dred_encode_latents(
             &ec_encoder,
             enc->latents_buffer + (i+enc->latent_offset) * DRED_LATENT_DIM,
-            quant_scales + offset,
-            dead_zone + offset,
-            r + offset,
-            p0 + offset,
+            dred_latent_quant_scales_q8 + offset,
+            dred_latent_dead_zone_q10 + offset,
+            dred_latent_r_q8 + offset,
+            dred_latent_p0_q8 + offset,
             DRED_LATENT_DIM
         );
         if (ec_tell(&ec_encoder) > 8*max_bytes) {
--