shithub: opus

Download patch

ref: 5f830b4578c5f0e5dcb6ddc88db8d1d50b270416
parent: fc4f594e25ee608f6cf806809622b29bbf1eed2e
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Mar 11 08:04:36 EDT 2019

3-bit interpolation

--- a/dnn/dump_data.c
+++ b/dnn/dump_data.c
@@ -99,7 +99,7 @@
     for (i=0;i<NB_BANDS_1;i++) {
         x[i] = ceps_codebook1[id*NB_BANDS_1 + i] + ceps_codebook2[id2*NB_BANDS_1 + i] + ceps_codebook3[id3*NB_BANDS_1 + i];
     }
-    if (1) {
+    if (0) {
         float err = 0;
         for (i=0;i<NB_BANDS_1;i++) {
             err += (x[i]-ref[i])*(x[i]-ref[i]);
@@ -174,7 +174,8 @@
     for (i=0;i<NB_BANDS;i++) {
       x[i] = pred[(id&MULTI_MASK)*NB_BANDS + i] + s*codebook[id*NB_BANDS + i];
     }
-    if (1) {
+    //printf("%d %f ", id&MULTI_MASK, s);
+    if (0) {
         float err = 0;
         for (i=0;i<NB_BANDS;i++) {
             err += (x[i]-ref[i])*(x[i]-ref[i]);
@@ -185,6 +186,31 @@
     return id;
 }
 
+#define FORBIDDEN_INTERP 7 // joint index 3*i+j that double_interp_search() never selects, leaving 8 usable joint choices
+// Score three predictors of x (0: average of left and right, 1: left, 2: right) by squared error over NB_BANDS values and pick the smallest.
+int interp_search(const float *x, const float *left, const float *right, float *dist_out) // dist_out[0..2] receives the three squared errors
+{
+    int i, k;
+    float min_dist = 1e15; // sentinel: a predictor is accepted only if its distance is strictly below the current minimum
+    int best_pred = 0;
+    float pred[4*NB_BANDS]; // slots 0 and 1 both hold the average; the search below only visits slots 1..3
+    for (i=0;i<NB_BANDS;i++) pred[i] = pred[NB_BANDS+i] = .5*(left[i] + right[i]);
+    for (i=0;i<NB_BANDS;i++) pred[2*NB_BANDS+i] = left[i];
+    for (i=0;i<NB_BANDS;i++) pred[3*NB_BANDS+i] = right[i];
+
+    for (k=1;k<4;k++) { // slot 0 is skipped, so everything reported to the caller is indexed by k-1
+      float dist = 0;
+      for (i=0;i<NB_BANDS;i++) dist += (x[i] - pred[k*NB_BANDS+i])*(x[i] - pred[k*NB_BANDS+i]);
+      dist_out[k-1] = dist;
+      if (dist < min_dist) { // strict comparison: on a tie the earlier predictor wins
+        min_dist = dist;
+        best_pred = k;
+      }
+    }
+    return best_pred - 1; // 0..2 for the winning predictor; -1 if no distance was below 1e15 (best_pred stays 0)
+}
+
+
 void interp_diff(float *x, float *left, float *right, float *codebook, int bits, int sign)
 {
     int i, k;
@@ -208,10 +234,11 @@
         best_pred = k;
       }
     }
+    //printf("%d ", best_pred);
     for (i=0;i<NB_BANDS;i++) {
       x[i] = pred[best_pred*NB_BANDS + i];
     }
-    if (1) {
+    if (0) {
         float err = 0;
         for (i=0;i<NB_BANDS;i++) {
             err += (x[i]-ref[i])*(x[i]-ref[i]);
@@ -220,7 +247,62 @@
     }
 }
 
+int double_interp_search(const float features[4][NB_FEATURES], const float *mem) { // jointly choose predictors for features[0] and features[2]; returns a code in 0..7
+    int i, j;
+    int id0, id1; // per-frame winners; only referenced by the commented-out debug printf below
+    int best_id=0;
+    float min_dist = 1e15;
+    float dist[2][3]; // dist[f][p]: squared error of predictor p for features[0] (f=0) or features[2] (f=1)
+    id0 = interp_search(features[0], mem, features[1], dist[0]); // features[0] is predicted from mem and features[1]
+    id1 = interp_search(features[2], features[1], features[3], dist[1]); // features[2] is predicted from features[1] and features[3]
+    for (i=0;i<3;i++) {
+        for (j=0;j<3;j++) {
+            float d;
+            int id;
+            id = 3*i + j; // joint index of the pair (i, j), in 0..8
+            d = dist[0][i] + dist[1][j]; // total squared error if predictor i and predictor j are used together
+            if (d < min_dist && id != FORBIDDEN_INTERP) { // the forbidden pair is never a candidate, even when it has the lowest error
+                min_dist = d;
+                best_id = id;
+            }
+        }
+    }
+    //printf("%d %d %f    %d %f\n", id0, id1, dist[0][id0] + dist[1][id1], best_id, min_dist);
+    return best_id - (best_id >= FORBIDDEN_INTERP); // close the gap left by the forbidden index so the returned codes are contiguous
+}
 
+static void single_interp(float *x, const float *left, const float *right, int id) // overwrite x[0..NB_BANDS) with predictor id: 0 average, 1 left, 2 right
+{
+    int i;
+    float ref[NB_BANDS]; // snapshot of the incoming x; read only by the disabled error report below
+    float pred[3*NB_BANDS]; // three candidate vectors stored back to back, NB_BANDS values each
+    RNN_COPY(ref, x, NB_BANDS); // RNN_COPY is defined outside this hunk; presumably an element-wise copy of x into ref
+    for (i=0;i<NB_BANDS;i++) pred[i] = .5*(left[i] + right[i]);
+    for (i=0;i<NB_BANDS;i++) pred[NB_BANDS+i] = left[i];
+    for (i=0;i<NB_BANDS;i++) pred[2*NB_BANDS+i] = right[i];
+    for (i=0;i<NB_BANDS;i++) {
+      x[i] = pred[id*NB_BANDS + i]; // id is not range-checked here; it must be 0, 1 or 2 to stay inside pred
+    }
+    if (0) { // debug only: when enabled, prints the RMS difference between ref and the replaced x
+        float err = 0;
+        for (i=0;i<NB_BANDS;i++) {
+            err += (x[i]-ref[i])*(x[i]-ref[i]);
+        }
+        printf("%f\n", sqrt(err/NB_BANDS));
+    }
+}
+
+void perform_double_interp(float features[4][NB_FEATURES], const float *mem) { // replace features[0] and features[2] with their jointly selected predictors
+    int id0, id1;
+    int best_id;
+    best_id = double_interp_search(features, mem); // compacted code that skips FORBIDDEN_INTERP
+    best_id += (best_id >= FORBIDDEN_INTERP); // undo the compaction done by double_interp_search() to recover the 3*i+j joint index
+    id0 = best_id / 3; // predictor for features[0]
+    id1 = best_id % 3; // predictor for features[2]
+    single_interp(features[0], mem, features[1], id0); // features[0] is rebuilt from mem and features[1]
+    single_interp(features[2], features[1], features[3], id1); // features[2] is rebuilt from features[1] and features[3]
+}
+
 typedef struct {
   float analysis_mem[OVERLAP_SIZE];
   float cepstral_mem[CEPS_MEM][NB_BANDS];
@@ -462,9 +544,15 @@
   st->features[3][0] = floor(.5 + st->features[3][0]*5)/5;
   quantize_2stage(&st->features[3][1]);
   quantize_diff(&st->features[1][0], vq_mem, &st->features[3][0], ceps_codebook_diff4, 11, 1);
+  double_interp_search(st->features, vq_mem);
   //quantize_2stage(&st->features[1][1]);
+#if 0
   interp_diff(&st->features[0][0], vq_mem, &st->features[1][0], ceps_codebook_diff2, 6, 0);
   interp_diff(&st->features[2][0], &st->features[1][0], &st->features[3][0], ceps_codebook_diff2, 6, 0);
+#else
+  perform_double_interp(st->features, vq_mem);
+#endif
+  //printf("\n");
   RNN_COPY(vq_mem, &st->features[3][0], NB_BANDS);
   for (i=0;i<4;i++) {
     fwrite(st->features[i], sizeof(float), NB_FEATURES, ffeat);
--