shithub: opus

Download patch

ref: 3fc183df5575ef4e50a570f324420635efbcd272
parent: 56820f0d10e984eae608e4ffff0fe6f1601366ba
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Feb 18 19:14:36 EST 2019

adjusting quantization bitrate

--- a/dnn/ceps_vq_train.c
+++ b/dnn/ceps_vq_train.c
@@ -51,7 +51,7 @@
   return nearest;
 }
 
-int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist)
+int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist, int sign)
 {
   int i, j;
   float min_dist = 1e15;
@@ -70,6 +70,21 @@
       nearest = i;
     }
   }
+  if (sign) {
+    for (i=0;i<nb_entries;i++)
+    {
+      int offset;
+      float dist=0;
+      offset = (i&MULTI_MASK)*ndim;
+      for (j=0;j<ndim;j++)
+        dist += (x[offset+j]+codebook[i*ndim+j])*(x[offset+j]+codebook[i*ndim+j]);
+      if (dist<min_dist)
+      {
+        min_dist = dist;
+        nearest = i+nb_entries;
+      }
+    }
+  }
   if (dist)
     *dist = min_dist;
   return nearest;
@@ -231,7 +246,7 @@
   //fprintf(stderr, "%f / %d\n", 1./w2, nb_entries);
 }
 
-void update_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
+void update_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim, int sign)
 {
   int i,j;
   int count[nb_entries];
@@ -244,7 +259,7 @@
   for (i=0;i<nb_vectors;i++)
   {
     float dist;
-    nearest[i] = find_nearest_multi(codebook, nb_entries, data+MULTI*i*ndim, ndim, &dist);
+    nearest[i] = find_nearest_multi(codebook, nb_entries, data+MULTI*i*ndim, ndim, &dist, sign);
     err += dist;
   }
   printf("RMS error = %f\n", sqrt(err/nb_vectors/ndim));
@@ -253,10 +268,11 @@
 
   for (i=0;i<nb_vectors;i++)
   {
-    int n = nearest[i];
+    int n = nearest[i] % nb_entries;
+    float sign = nearest[i] < nb_entries ? 1 : -1;
     count[n]++;
     for (j=0;j<ndim;j++)
-      codebook[n*ndim+j] += data[(MULTI*i + (n&MULTI_MASK))*ndim+j];
+      codebook[n*ndim+j] += sign*data[(MULTI*i + (n&MULTI_MASK))*ndim+j];
   }
 
   float w2=0;
@@ -334,11 +350,11 @@
     for (j=0;j<4;j++)
       update(data, nb_vectors, codebook, e, ndim);
   }
-  for (j=0;j<ndim*2;j++)
+  for (j=0;j<10;j++)
     update(data, nb_vectors, codebook, e, ndim);
 }
 
-void vq_train_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
+void vq_train_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim, int sign)
 {
   int i, j, e;
   for (e=0;e<MULTI;e++) {
@@ -355,7 +371,7 @@
   }
   e = MULTI;
   for (j=0;j<10;j++)
-    update_multi(data, nb_vectors, codebook, e, ndim);
+    update_multi(data, nb_vectors, codebook, e, ndim, sign);
 
   while (e < nb_entries)
   {
@@ -363,10 +379,10 @@
     e<<=1;
     fprintf(stderr, "%d\n", e);
     for (j=0;j<4;j++)
-      update_multi(data, nb_vectors, codebook, e, ndim);
+      update_multi(data, nb_vectors, codebook, e, ndim, sign);
   }
-  for (j=0;j<ndim*2;j++)
-    update_multi(data, nb_vectors, codebook, e, ndim);
+  for (j=0;j<10;j++)
+    update_multi(data, nb_vectors, codebook, e, ndim, sign);
 }
 
 
@@ -402,7 +418,7 @@
 int main(int argc, char **argv)
 {
   int i,j;
-  int nb_vectors, nb_entries, nb_entries2, ndim, ndim0, total_dim;
+  int nb_vectors, nb_entries, nb_entries1, nb_entries2a, nb_entries2b, ndim, ndim0, total_dim;
   float *data, *pred, *multi_data, *multi_data2, *qdata;
   float *codebook, *codebook2, *codebook_diff2, *codebook_diff4;
   float *delta;
@@ -414,7 +430,9 @@
   total_dim = atoi(argv[2]);
   nb_vectors = atoi(argv[3]);
   nb_entries = 1<<atoi(argv[4]);
-  nb_entries2 = 64;
+  nb_entries1 = 256;
+  nb_entries2a = 2048;
+  nb_entries2b = 256;
   
   data = malloc((nb_vectors*ndim+total_dim)*sizeof(*data));
   qdata = malloc((nb_vectors*ndim+total_dim)*sizeof(*qdata));
@@ -422,9 +440,9 @@
   multi_data = malloc(MULTI*nb_vectors*ndim*sizeof(*multi_data));
   multi_data2 = malloc(MULTI*nb_vectors*ndim*sizeof(*multi_data));
   codebook = malloc(nb_entries*ndim0*sizeof(*codebook));
-  codebook2 = malloc(nb_entries*ndim0*sizeof(*codebook2));
-  codebook_diff4 = malloc(nb_entries*ndim*sizeof(*codebook_diff4));
-  codebook_diff2 = malloc(nb_entries2*ndim*sizeof(*codebook_diff2));
+  codebook2 = malloc(nb_entries1*ndim0*sizeof(*codebook2));
+  codebook_diff4 = malloc(nb_entries2a*ndim*sizeof(*codebook_diff4));
+  codebook_diff2 = malloc(nb_entries2b*ndim*sizeof(*codebook_diff2));
   
   for (i=0;i<nb_vectors;i++)
   {
@@ -465,13 +483,13 @@
   }
   fprintf(stderr, "Cepstrum RMS error: %f\n", sqrt(err/nb_vectors/ndim));
 
-  vq_train(delta, nb_vectors, codebook2, nb_entries, ndim0);
+  vq_train(delta, nb_vectors, codebook2, nb_entries1, ndim0);
   
   err=0;
   for (i=0;i<nb_vectors;i++)
   {
     int n1;
-    n1 = find_nearest(codebook2, nb_entries, &delta[i*ndim0], ndim0, NULL);
+    n1 = find_nearest(codebook2, nb_entries1, &delta[i*ndim0], ndim0, NULL);
     for (j=0;j<ndim0;j++)
     {
       qdata[i*ndim+j+1] += codebook2[n1*ndim0+j];
@@ -506,10 +524,10 @@
       multi_data2[(MULTI*i+3)*ndim+j] = data[(i+2)*ndim+j] - qdata[(i+4)*ndim+j];
   }
 
-  vq_train_multi(multi_data2, nb_vectors-4, codebook_diff4, nb_entries, ndim);
+  vq_train_multi(multi_data2, nb_vectors-4, codebook_diff4, nb_entries2a, ndim, 1);
 
   printf("done\n");
-  vq_train_multi(multi_data, nb_vectors-4, codebook_diff2, 64, ndim);
+  vq_train_multi(multi_data, nb_vectors-4, codebook_diff2, nb_entries2b, ndim, 0);
 
 
   fout = fopen("ceps_codebooks.c", "w");
@@ -524,8 +542,8 @@
   }
   fprintf(fout, "};\n\n");
 
-  fprintf(fout, "float ceps_codebook2[%d*%d] = {\n",nb_entries, ndim0);
-  for (i=0;i<nb_entries;i++)
+  fprintf(fout, "float ceps_codebook2[%d*%d] = {\n",nb_entries1, ndim0);
+  for (i=0;i<nb_entries1;i++)
   {
     for (j=0;j<ndim0;j++)
       fprintf(fout, "%f, ", codebook2[i*ndim0+j]);
@@ -533,8 +551,8 @@
   }
   fprintf(fout, "};\n\n");
 
-  fprintf(fout, "float ceps_codebook_diff4[%d*%d] = {\n",nb_entries, ndim);
-  for (i=0;i<nb_entries;i++)
+  fprintf(fout, "float ceps_codebook_diff4[%d*%d] = {\n",nb_entries2a, ndim);
+  for (i=0;i<nb_entries2a;i++)
   {
     for (j=0;j<ndim;j++)
       fprintf(fout, "%f, ", codebook_diff4[i*ndim+j]);
@@ -542,8 +560,8 @@
   }
   fprintf(fout, "};\n\n");
 
-  fprintf(fout, "float ceps_codebook_diff2[%d*%d] = {\n",nb_entries2, ndim);
-  for (i=0;i<nb_entries2;i++)
+  fprintf(fout, "float ceps_codebook_diff2[%d*%d] = {\n",nb_entries2b, ndim);
+  for (i=0;i<nb_entries2b;i++)
   {
     for (j=0;j<ndim;j++)
       fprintf(fout, "%f, ", codebook_diff2[i*ndim+j]);
--- a/dnn/dump_data.c
+++ b/dnn/dump_data.c
@@ -91,7 +91,7 @@
     for (i=0;i<NB_BANDS_1;i++) {
         x[i] -= ceps_codebook1[id*NB_BANDS_1 + i];
     }
-    id2 = vq_quantize(ceps_codebook2, 1024, x, NB_BANDS_1, NULL);
+    id2 = vq_quantize(ceps_codebook2, 256, x, NB_BANDS_1, NULL);
     for (i=0;i<NB_BANDS_1;i++) {
         x[i] = ceps_codebook2[id2*NB_BANDS_1 + i];
     }
@@ -109,7 +109,7 @@
     return id;
 }
 
-static int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist)
+static int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist, int sign)
 {
   int i, j;
   float min_dist = 1e15;
@@ -128,12 +128,28 @@
       nearest = i;
     }
   }
+  if (sign) {
+    for (i=0;i<nb_entries;i++)
+    {
+      int offset;
+      float dist=0;
+      offset = (i&MULTI_MASK)*ndim;
+      for (j=0;j<ndim;j++)
+        dist += (x[offset+j]+codebook[i*ndim+j])*(x[offset+j]+codebook[i*ndim+j]);
+      if (dist<min_dist)
+      {
+        min_dist = dist;
+        nearest = i+nb_entries;
+      }
+    }
+  }
   if (dist)
     *dist = min_dist;
   return nearest;
 }
 
-int quantize_diff(float *x, float *left, float *right, float *codebook, int bits)
+
+int quantize_diff(float *x, float *left, float *right, float *codebook, int bits, int sign)
 {
     int i;
     int nb_entries;
@@ -141,6 +157,7 @@
     float ref[NB_BANDS];
     float pred[4*NB_BANDS];
     float target[4*NB_BANDS];
+    float s = 1;
     nb_entries = 1<<bits;
     RNN_COPY(ref, x, NB_BANDS);
     for (i=0;i<NB_BANDS;i++) pred[i] = pred[NB_BANDS+i] = .5*(left[i] + right[i]);
@@ -148,9 +165,13 @@
     for (i=0;i<NB_BANDS;i++) pred[3*NB_BANDS+i] = right[i];
     for (i=0;i<4*NB_BANDS;i++) target[i] = x[i%NB_BANDS] - pred[i];
 
-    id = find_nearest_multi(codebook, nb_entries, target, NB_BANDS, NULL);
+    id = find_nearest_multi(codebook, nb_entries, target, NB_BANDS, NULL, sign);
+    if (id >= 1<<bits) {
+      s = -1;
+      id -= (1<<bits);
+    }
     for (i=0;i<NB_BANDS;i++) {
-      x[i] = pred[(id&MULTI_MASK)*NB_BANDS + i] + codebook[id*NB_BANDS + i];
+      x[i] = pred[(id&MULTI_MASK)*NB_BANDS + i] + s*codebook[id*NB_BANDS + i];
     }
     if (1) {
         float err = 0;
@@ -402,10 +423,10 @@
   //printf("%f\n", st->features[3][0]);
   st->features[3][0] = floor(.5 + st->features[3][0]*5)/5;
   quantize_2stage(&st->features[3][1]);
-  quantize_diff(&st->features[1][0], vq_mem, &st->features[3][0], ceps_codebook_diff4, 10);
+  quantize_diff(&st->features[1][0], vq_mem, &st->features[3][0], ceps_codebook_diff4, 11, 1);
   //quantize_2stage(&st->features[1][1]);
-  quantize_diff(&st->features[0][0], vq_mem, &st->features[1][0], ceps_codebook_diff2, 6);
-  quantize_diff(&st->features[2][0], &st->features[1][0], &st->features[3][0], ceps_codebook_diff2, 6);
+  quantize_diff(&st->features[0][0], vq_mem, &st->features[1][0], ceps_codebook_diff2, 8, 0);
+  quantize_diff(&st->features[2][0], &st->features[1][0], &st->features[3][0], ceps_codebook_diff2, 8, 0);
   RNN_COPY(vq_mem, &st->features[3][0], NB_BANDS);
   for (i=0;i<4;i++) {
     fwrite(st->features[i], sizeof(float), NB_FEATURES, ffeat);
--