ref: 3fc183df5575ef4e50a570f324420635efbcd272
parent: 56820f0d10e984eae608e4ffff0fe6f1601366ba
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Mon Feb 18 19:14:36 EST 2019
Adjust cepstrum quantization bitrates and add sign-symmetric multi-VQ search
--- a/dnn/ceps_vq_train.c
+++ b/dnn/ceps_vq_train.c
@@ -51,7 +51,7 @@
return nearest;
}
-int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist)
+int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist, int sign)
{int i, j;
float min_dist = 1e15;
@@ -70,6 +70,21 @@
nearest = i;
}
}
+ if (sign) {+ for (i=0;i<nb_entries;i++)
+ {+ int offset;
+ float dist=0;
+ offset = (i&MULTI_MASK)*ndim;
+ for (j=0;j<ndim;j++)
+ dist += (x[offset+j]+codebook[i*ndim+j])*(x[offset+j]+codebook[i*ndim+j]);
+ if (dist<min_dist)
+ {+ min_dist = dist;
+ nearest = i+nb_entries;
+ }
+ }
+ }
if (dist)
*dist = min_dist;
return nearest;
@@ -231,7 +246,7 @@
//fprintf(stderr, "%f / %d\n", 1./w2, nb_entries);
}
-void update_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
+void update_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim, int sign)
{int i,j;
int count[nb_entries];
@@ -244,7 +259,7 @@
for (i=0;i<nb_vectors;i++)
{float dist;
- nearest[i] = find_nearest_multi(codebook, nb_entries, data+MULTI*i*ndim, ndim, &dist);
+ nearest[i] = find_nearest_multi(codebook, nb_entries, data+MULTI*i*ndim, ndim, &dist, sign);
err += dist;
}
printf("RMS error = %f\n", sqrt(err/nb_vectors/ndim));@@ -253,10 +268,11 @@
for (i=0;i<nb_vectors;i++)
{- int n = nearest[i];
+ int n = nearest[i] % nb_entries;
+ float sign = nearest[i] < nb_entries ? 1 : -1;
count[n]++;
for (j=0;j<ndim;j++)
- codebook[n*ndim+j] += data[(MULTI*i + (n&MULTI_MASK))*ndim+j];
+ codebook[n*ndim+j] += sign*data[(MULTI*i + (n&MULTI_MASK))*ndim+j];
}
float w2=0;
@@ -334,11 +350,11 @@
for (j=0;j<4;j++)
update(data, nb_vectors, codebook, e, ndim);
}
- for (j=0;j<ndim*2;j++)
+ for (j=0;j<10;j++)
update(data, nb_vectors, codebook, e, ndim);
}
-void vq_train_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim)
+void vq_train_multi(float *data, int nb_vectors, float *codebook, int nb_entries, int ndim, int sign)
{int i, j, e;
for (e=0;e<MULTI;e++) {@@ -355,7 +371,7 @@
}
e = MULTI;
for (j=0;j<10;j++)
- update_multi(data, nb_vectors, codebook, e, ndim);
+ update_multi(data, nb_vectors, codebook, e, ndim, sign);
while (e < nb_entries)
{@@ -363,10 +379,10 @@
e<<=1;
fprintf(stderr, "%d\n", e);
for (j=0;j<4;j++)
- update_multi(data, nb_vectors, codebook, e, ndim);
+ update_multi(data, nb_vectors, codebook, e, ndim, sign);
}
- for (j=0;j<ndim*2;j++)
- update_multi(data, nb_vectors, codebook, e, ndim);
+ for (j=0;j<10;j++)
+ update_multi(data, nb_vectors, codebook, e, ndim, sign);
}
@@ -402,7 +418,7 @@
int main(int argc, char **argv)
{int i,j;
- int nb_vectors, nb_entries, nb_entries2, ndim, ndim0, total_dim;
+ int nb_vectors, nb_entries, nb_entries1, nb_entries2a, nb_entries2b, ndim, ndim0, total_dim;
float *data, *pred, *multi_data, *multi_data2, *qdata;
float *codebook, *codebook2, *codebook_diff2, *codebook_diff4;
float *delta;
@@ -414,7 +430,9 @@
total_dim = atoi(argv[2]);
nb_vectors = atoi(argv[3]);
nb_entries = 1<<atoi(argv[4]);
- nb_entries2 = 64;
+ nb_entries1 = 256;
+ nb_entries2a = 2048;
+ nb_entries2b = 256;
data = malloc((nb_vectors*ndim+total_dim)*sizeof(*data));
qdata = malloc((nb_vectors*ndim+total_dim)*sizeof(*qdata));
@@ -422,9 +440,9 @@
multi_data = malloc(MULTI*nb_vectors*ndim*sizeof(*multi_data));
multi_data2 = malloc(MULTI*nb_vectors*ndim*sizeof(*multi_data));
codebook = malloc(nb_entries*ndim0*sizeof(*codebook));
- codebook2 = malloc(nb_entries*ndim0*sizeof(*codebook2));
- codebook_diff4 = malloc(nb_entries*ndim*sizeof(*codebook_diff4));
- codebook_diff2 = malloc(nb_entries2*ndim*sizeof(*codebook_diff2));
+ codebook2 = malloc(nb_entries1*ndim0*sizeof(*codebook2));
+ codebook_diff4 = malloc(nb_entries2a*ndim*sizeof(*codebook_diff4));
+ codebook_diff2 = malloc(nb_entries2b*ndim*sizeof(*codebook_diff2));
for (i=0;i<nb_vectors;i++)
{@@ -465,13 +483,13 @@
}
fprintf(stderr, "Cepstrum RMS error: %f\n", sqrt(err/nb_vectors/ndim));
- vq_train(delta, nb_vectors, codebook2, nb_entries, ndim0);
+ vq_train(delta, nb_vectors, codebook2, nb_entries1, ndim0);
err=0;
for (i=0;i<nb_vectors;i++)
{int n1;
- n1 = find_nearest(codebook2, nb_entries, &delta[i*ndim0], ndim0, NULL);
+ n1 = find_nearest(codebook2, nb_entries1, &delta[i*ndim0], ndim0, NULL);
for (j=0;j<ndim0;j++)
{qdata[i*ndim+j+1] += codebook2[n1*ndim0+j];
@@ -506,10 +524,10 @@
multi_data2[(MULTI*i+3)*ndim+j] = data[(i+2)*ndim+j] - qdata[(i+4)*ndim+j];
}
- vq_train_multi(multi_data2, nb_vectors-4, codebook_diff4, nb_entries, ndim);
+ vq_train_multi(multi_data2, nb_vectors-4, codebook_diff4, nb_entries2a, ndim, 1);
printf("done\n");- vq_train_multi(multi_data, nb_vectors-4, codebook_diff2, 64, ndim);
+ vq_train_multi(multi_data, nb_vectors-4, codebook_diff2, nb_entries2b, ndim, 0);
fout = fopen("ceps_codebooks.c", "w");@@ -524,8 +542,8 @@
}
fprintf(fout, "};\n\n");
- fprintf(fout, "float ceps_codebook2[%d*%d] = {\n",nb_entries, ndim0);- for (i=0;i<nb_entries;i++)
+ fprintf(fout, "float ceps_codebook2[%d*%d] = {\n",nb_entries1, ndim0);+ for (i=0;i<nb_entries1;i++)
{for (j=0;j<ndim0;j++)
fprintf(fout, "%f, ", codebook2[i*ndim0+j]);
@@ -533,8 +551,8 @@
}
fprintf(fout, "};\n\n");
- fprintf(fout, "float ceps_codebook_diff4[%d*%d] = {\n",nb_entries, ndim);- for (i=0;i<nb_entries;i++)
+ fprintf(fout, "float ceps_codebook_diff4[%d*%d] = {\n",nb_entries2a, ndim);+ for (i=0;i<nb_entries2a;i++)
{for (j=0;j<ndim;j++)
fprintf(fout, "%f, ", codebook_diff4[i*ndim+j]);
@@ -542,8 +560,8 @@
}
fprintf(fout, "};\n\n");
- fprintf(fout, "float ceps_codebook_diff2[%d*%d] = {\n",nb_entries2, ndim);- for (i=0;i<nb_entries2;i++)
+ fprintf(fout, "float ceps_codebook_diff2[%d*%d] = {\n",nb_entries2b, ndim);+ for (i=0;i<nb_entries2b;i++)
{for (j=0;j<ndim;j++)
fprintf(fout, "%f, ", codebook_diff2[i*ndim+j]);
--- a/dnn/dump_data.c
+++ b/dnn/dump_data.c
@@ -91,7 +91,7 @@
for (i=0;i<NB_BANDS_1;i++) {x[i] -= ceps_codebook1[id*NB_BANDS_1 + i];
}
- id2 = vq_quantize(ceps_codebook2, 1024, x, NB_BANDS_1, NULL);
+ id2 = vq_quantize(ceps_codebook2, 256, x, NB_BANDS_1, NULL);
for (i=0;i<NB_BANDS_1;i++) {x[i] = ceps_codebook2[id2*NB_BANDS_1 + i];
}
@@ -109,7 +109,7 @@
return id;
}
-static int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist)
+static int find_nearest_multi(const float *codebook, int nb_entries, const float *x, int ndim, float *dist, int sign)
{int i, j;
float min_dist = 1e15;
@@ -128,12 +128,28 @@
nearest = i;
}
}
+ if (sign) {+ for (i=0;i<nb_entries;i++)
+ {+ int offset;
+ float dist=0;
+ offset = (i&MULTI_MASK)*ndim;
+ for (j=0;j<ndim;j++)
+ dist += (x[offset+j]+codebook[i*ndim+j])*(x[offset+j]+codebook[i*ndim+j]);
+ if (dist<min_dist)
+ {+ min_dist = dist;
+ nearest = i+nb_entries;
+ }
+ }
+ }
if (dist)
*dist = min_dist;
return nearest;
}
-int quantize_diff(float *x, float *left, float *right, float *codebook, int bits)
+
+int quantize_diff(float *x, float *left, float *right, float *codebook, int bits, int sign)
{int i;
int nb_entries;
@@ -141,6 +157,7 @@
float ref[NB_BANDS];
float pred[4*NB_BANDS];
float target[4*NB_BANDS];
+ float s = 1;
nb_entries = 1<<bits;
RNN_COPY(ref, x, NB_BANDS);
for (i=0;i<NB_BANDS;i++) pred[i] = pred[NB_BANDS+i] = .5*(left[i] + right[i]);
@@ -148,9 +165,13 @@
for (i=0;i<NB_BANDS;i++) pred[3*NB_BANDS+i] = right[i];
for (i=0;i<4*NB_BANDS;i++) target[i] = x[i%NB_BANDS] - pred[i];
- id = find_nearest_multi(codebook, nb_entries, target, NB_BANDS, NULL);
+ id = find_nearest_multi(codebook, nb_entries, target, NB_BANDS, NULL, sign);
+ if (id >= 1<<bits) {+ s = -1;
+ id -= (1<<bits);
+ }
for (i=0;i<NB_BANDS;i++) {- x[i] = pred[(id&MULTI_MASK)*NB_BANDS + i] + codebook[id*NB_BANDS + i];
+ x[i] = pred[(id&MULTI_MASK)*NB_BANDS + i] + s*codebook[id*NB_BANDS + i];
}
if (1) {float err = 0;
@@ -402,10 +423,10 @@
//printf("%f\n", st->features[3][0]);st->features[3][0] = floor(.5 + st->features[3][0]*5)/5;
quantize_2stage(&st->features[3][1]);
- quantize_diff(&st->features[1][0], vq_mem, &st->features[3][0], ceps_codebook_diff4, 10);
+ quantize_diff(&st->features[1][0], vq_mem, &st->features[3][0], ceps_codebook_diff4, 11, 1);
//quantize_2stage(&st->features[1][1]);
- quantize_diff(&st->features[0][0], vq_mem, &st->features[1][0], ceps_codebook_diff2, 6);
- quantize_diff(&st->features[2][0], &st->features[1][0], &st->features[3][0], ceps_codebook_diff2, 6);
+ quantize_diff(&st->features[0][0], vq_mem, &st->features[1][0], ceps_codebook_diff2, 8, 0);
+ quantize_diff(&st->features[2][0], &st->features[1][0], &st->features[3][0], ceps_codebook_diff2, 8, 0);
RNN_COPY(vq_mem, &st->features[3][0], NB_BANDS);
for (i=0;i<4;i++) {fwrite(st->features[i], sizeof(float), NB_FEATURES, ffeat);
--
⑨