ref: 11736ca9e3654397bfe472e51f2ec96433764efe
parent: 1657bae0240e6c0043070580dac93b2893e90ad3
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Wed Dec 23 22:47:42 EST 2020
WIP: 8-bit mul
--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -34,6 +34,12 @@
#define ACTIVATION_RELU 3
#define ACTIVATION_SOFTMAX 4
+#ifdef DOT_PROD
+typedef signed char qweight;
+#else
+typedef float qweight;
+#endif
+
typedef struct {
const float *bias;
const float *input_weights;
@@ -65,7 +71,7 @@
typedef struct {
const float *bias;
const float *diag_weights;
- const float *recurrent_weights;
+ const qweight *recurrent_weights;
const int *idx;
int nb_neurons;
int activation;
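
Note: the qweight type only switches the element type of the generated weight tables; the dump_lpcnet.py change below produces the matching data by rounding the float weights to signed 8 bits at a scale of 128 and clamping them to the representable range. A minimal sketch of that quantization step, with a hypothetical helper name, could look like:

#include <math.h>

/* Hypothetical helper mirroring the dump script's
   W = min(127, max(-128, round(W*128))) quantization. */
static signed char quantize_weight(float w)
{
   float q = (float)floor(.5 + 128*w);  /* scale by 128, round to nearest */
   if (q > 127) q = 127;                /* clamp to the signed 8-bit range */
   if (q < -128) q = -128;
   return (signed char)q;
}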
--- a/dnn/training_tf2/dump_lpcnet.py
+++ b/dnn/training_tf2/dump_lpcnet.py
@@ -81,9 +81,10 @@
W = np.concatenate([W, vblock])
idx[pos] = nb_nonzero
f.write('#ifdef DOT_PROD\n')
- printVector(f, W, name)
+ W = np.minimum(127, np.maximum(-128, np.round(W*128)))
+ printVector(f, W.astype('int'), name, dtype='qweight')
f.write('#else /*DOT_PROD*/\n')
- printVector(f, W0, name)
+ printVector(f, W0, name, dtype='nnet_weight')
f.write('#endif /*DOT_PROD*/\n')
#idx = np.tile(np.concatenate([np.array([N]), np.arange(N)]), 3*N//16)
printVector(f, idx, name + '_idx', dtype='int')
@@ -232,7 +233,8 @@
hf = open(hfile, 'w')
-f.write('/*This file is automatically generated from a Keras model*/\n\n')
+f.write('/*This file is automatically generated from a Keras model*/\n')
+f.write('/*based on model {}*/\n\n'.format(sys.argv[1]))
f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "nnet.h"\n#include "{}"\n\n'.format(hfile))
hf.write('/*This file is automatically generated from a Keras model*/\n\n')
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -25,6 +25,9 @@
NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
+
+#include "nnet.h"
+
/* No AVX2/FMA support */
#ifndef LPCNET_TEST
static float celt_exp2(float x)
@@ -164,8 +167,8 @@
}
#ifdef DOT_PROD
-
-static void sparse_sgemv_accum8x4(float *out, const float *w, int rows, const int *idx, const float *x)
+#define SCALE_1 (1.f/128.f/127.f)
+static void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, const int *idx, const float *x)
{
int i, j;
for (i=0;i<rows;i+=8)
@@ -176,48 +179,21 @@
{
int pos;
float * restrict y;
- float xj0, xj1, xj2, xj3;
+ int xj0, xj1, xj2, xj3;
pos = 4 * (*idx++);
- xj0 = x[pos+0];
- xj1 = x[pos+1];
- xj2 = x[pos+2];
- xj3 = x[pos+3];
+ xj0 = floor(.5+127*x[pos+0]);
+ xj1 = floor(.5+127*x[pos+1]);
+ xj2 = floor(.5+127*x[pos+2]);
+ xj3 = floor(.5+127*x[pos+3]);
y = &out[i];
- y[0] += w[0]*xj0;
- y[1] += w[4]*xj0;
- y[2] += w[8]*xj0;
- y[3] += w[12]*xj0;
- y[4] += w[16]*xj0;
- y[5] += w[20]*xj0;
- y[6] += w[24]*xj0;
- y[7] += w[28]*xj0;
-
- y[0] += w[1]*xj1;
- y[1] += w[5]*xj1;
- y[2] += w[9]*xj1;
- y[3] += w[13]*xj1;
- y[4] += w[17]*xj1;
- y[5] += w[21]*xj1;
- y[6] += w[25]*xj1;
- y[7] += w[29]*xj1;
-
- y[0] += w[2]*xj2;
- y[1] += w[6]*xj2;
- y[2] += w[10]*xj2;
- y[3] += w[14]*xj2;
- y[4] += w[18]*xj2;
- y[5] += w[22]*xj2;
- y[6] += w[26]*xj2;
- y[7] += w[30]*xj2;
-
- y[0] += w[3]*xj3;
- y[1] += w[7]*xj3;
- y[2] += w[11]*xj3;
- y[3] += w[15]*xj3;
- y[4] += w[19]*xj3;
- y[5] += w[23]*xj3;
- y[6] += w[27]*xj3;
- y[7] += w[31]*xj3;
+ y[0] += SCALE_1*(w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+ y[1] += SCALE_1*(w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+ y[2] += SCALE_1*(w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+ y[3] += SCALE_1*(w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+ y[4] += SCALE_1*(w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+ y[5] += SCALE_1*(w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+ y[6] += SCALE_1*(w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+ y[7] += SCALE_1*(w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
w += 32;
}
}
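
Note: SCALE_1 undoes both quantization steps at once: the weights were scaled by 128 when dumped and the activations are scaled by 127 just above, so each integer dot product is 128*127 = 16256 times too large. As a quick check, a weight of 0.5 stored as 64 times an activation of 1.0 stored as 127 gives 64*127 = 8128, and 8128*SCALE_1 = 8128/16256 = 0.5, matching the float path.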
@@ -224,7 +200,7 @@
}
#else
-static void sparse_sgemv_accum8x4(float *out, const float *w, int rows, const int *idx, const float *x)
+static void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, const int *idx, const float *x)
{
int i, j;
for (i=0;i<rows;i+=8)
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -219,8 +219,34 @@
}
#ifdef DOT_PROD
+static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, const int *idx, const float *x)
+{
+ int i, j;
+ for (i=0;i<rows;i+=8)
+ {
+ float * restrict y;
+ int cols;
+ __m256 vy0;
+ y = &out[i];
+ vy0 = _mm256_loadu_ps(&y[0]);
+ cols = *idx++;
+ for (j=0;j<cols;j++)
+ {
+ int id;
+ __m256 vxj;
+ __m256 vw;
+ id = *idx++;
+
+ //kernel goes here
+
+ weights += 32;
+ }
+ _mm256_storeu_ps (&y[0], vy0);
+ }
+}
+
#else
-static void sparse_sgemv_accum8x4(float *out, const float *weights, int rows, const int *idx, const float *x)
+static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, const int *idx, const float *x)
{int i, j;
for (i=0;i<rows;i+=8)
--
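
Note: the kernel body is still left as a placeholder in this WIP commit. One possible AVX2 realization, sketched here only as an illustration and not as the kernel this commit ships, keeps the same assumptions as the scalar DOT_PROD path in vec.h: each 32-byte weight block holds 8 rows by 4 columns of signed 8-bit weights, the four shared activations are quantized to -127..127, and the result is rescaled by SCALE_1 = 1/(128*127). The function name is hypothetical.

#include <math.h>
#include <string.h>
#include <immintrin.h>

#define SCALE_1 (1.f/128.f/127.f)

static void sparse_sgemv_accum8x4_sketch(float *out, const signed char *weights,
                                         int rows, const int *idx, const float *x)
{
   int i, j;
   const __m256i ones = _mm256_set1_epi16(1);
   const __m256 vscale = _mm256_set1_ps(SCALE_1);
   for (i=0;i<rows;i+=8)
   {
      float *y;
      int cols;
      __m256 vy0;
      y = &out[i];
      vy0 = _mm256_loadu_ps(&y[0]);
      cols = *idx++;
      for (j=0;j<cols;j++)
      {
         int pos, xpack;
         signed char xq[4];
         __m256i vx, vw, vwu, vxs, sum16, sum32;
         pos = 4 * (*idx++);
         /* quantize the four shared activations exactly like the scalar path */
         xq[0] = (signed char)floor(.5 + 127*x[pos+0]);
         xq[1] = (signed char)floor(.5 + 127*x[pos+1]);
         xq[2] = (signed char)floor(.5 + 127*x[pos+2]);
         xq[3] = (signed char)floor(.5 + 127*x[pos+3]);
         memcpy(&xpack, xq, sizeof(xpack));   /* pack the 4 bytes into one 32-bit word */
         vx = _mm256_set1_epi32(xpack);       /* replicate x0..x3 across the register */
         vw = _mm256_loadu_si256((const __m256i *)weights); /* 8 rows x 4 cols of int8 */
         /* maddubs needs an unsigned first operand: take |w| (a -128 weight is then
            read as unsigned 128) and move each weight's sign onto the activation byte */
         vwu = _mm256_abs_epi8(vw);
         vxs = _mm256_sign_epi8(vx, vw);
         /* each 16-bit pair sum is at most 2*128*127 = 32512, so no saturation */
         sum16 = _mm256_maddubs_epi16(vwu, vxs);
         sum32 = _mm256_madd_epi16(sum16, ones);  /* one int32 dot product per row */
         vy0 = _mm256_add_ps(vy0, _mm256_mul_ps(vscale, _mm256_cvtepi32_ps(sum32)));
         weights += 32;
      }
      _mm256_storeu_ps(&y[0], vy0);
   }
}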