shithub: opus

Download patch

ref: 11736ca9e3654397bfe472e51f2ec96433764efe
parent: 1657bae0240e6c0043070580dac93b2893e90ad3
author: Jean-Marc Valin <jmvalin@jmvalin.ca>
date: Wed Dec 23 22:47:42 EST 2020

WIP: 8-bit mul

--- a/dnn/nnet.h
+++ b/dnn/nnet.h
@@ -34,6 +34,12 @@
 #define ACTIVATION_RELU    3
 #define ACTIVATION_SOFTMAX 4
 
+#ifndef DOT_PROD
+typedef float qweight;        /* unquantized floating-point weights */
+#else
+typedef signed char qweight;  /* 8-bit weights: round(128*w) clamped to [-128,127] by dump_lpcnet.py */
+#endif
+
 typedef struct {
   const float *bias;
   const float *input_weights;
@@ -65,7 +71,7 @@
 typedef struct {
   const float *bias;
   const float *diag_weights;
-  const float *recurrent_weights;
+  const qweight *recurrent_weights;
   const int *idx;
   int nb_neurons;
   int activation;
--- a/dnn/training_tf2/dump_lpcnet.py
+++ b/dnn/training_tf2/dump_lpcnet.py
@@ -81,9 +81,10 @@
                 W = np.concatenate([W, vblock])
         idx[pos] = nb_nonzero
     f.write('#ifdef DOT_PROD\n')
-    printVector(f, W, name)
+    W = np.minimum(127, np.maximum(-128, np.round(W*128)))  # 8-bit quantization: round(128*w) clamped to [-128, 127]
+    printVector(f, W.astype('int'), name, dtype='qweight')
     f.write('#else /*DOT_PROD*/\n')
-    printVector(f, W0, name)
+    printVector(f, W0, name, dtype='qweight')  # qweight is typedef'd to float when DOT_PROD is off (nnet.h)
     f.write('#endif /*DOT_PROD*/\n')
     #idx = np.tile(np.concatenate([np.array([N]), np.arange(N)]), 3*N//16)
     printVector(f, idx, name + '_idx', dtype='int')
@@ -232,7 +233,8 @@
 hf = open(hfile, 'w')
 
 
-f.write('/*This file is automatically generated from a Keras model*/\n\n')
+f.write('/*This file is automatically generated from a Keras model*/\n')
+f.write('/*based on model {}*/\n\n'.format(sys.argv[1]))
 f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "nnet.h"\n#include "{}"\n\n'.format(hfile))
 
 hf.write('/*This file is automatically generated from a Keras model*/\n\n')
--- a/dnn/vec.h
+++ b/dnn/vec.h
@@ -25,6 +25,9 @@
    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
+
+#include "nnet.h"
+
 /* No AVX2/FMA support */
 #ifndef LPCNET_TEST
 static float celt_exp2(float x)
@@ -164,8 +167,8 @@
 }
 
 #ifdef DOT_PROD
-
-static void sparse_sgemv_accum8x4(float *out, const float *w, int rows, const int *idx, const float *x)
+#define SCALE_1 (1.f/128.f/127.f) /* undoes the x128 weight and x127 input quantization scales */
+static void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, const int *idx, const float *x)
 {
    int i, j;
    for (i=0;i<rows;i+=8)
@@ -176,48 +179,21 @@
       {
          int pos;
          float * restrict y;
-         float xj0, xj1, xj2, xj3;
+         int xj0, xj1, xj2, xj3;
          pos = 4 * (*idx++);
-         xj0 = x[pos+0];
-         xj1 = x[pos+1];
-         xj2 = x[pos+2];
-         xj3 = x[pos+3];
+         xj0 = floor(.5+127*x[pos+0]);
+         xj1 = floor(.5+127*x[pos+1]);
+         xj2 = floor(.5+127*x[pos+2]);
+         xj3 = floor(.5+127*x[pos+3]);
          y = &out[i];
-         y[0] += w[0]*xj0;
-         y[1] += w[4]*xj0;
-         y[2] += w[8]*xj0;
-         y[3] += w[12]*xj0;
-         y[4] += w[16]*xj0;
-         y[5] += w[20]*xj0;
-         y[6] += w[24]*xj0;
-         y[7] += w[28]*xj0;
-
-         y[0] += w[1]*xj1;
-         y[1] += w[5]*xj1;
-         y[2] += w[9]*xj1;
-         y[3] += w[13]*xj1;
-         y[4] += w[17]*xj1;
-         y[5] += w[21]*xj1;
-         y[6] += w[25]*xj1;
-         y[7] += w[29]*xj1;
-
-         y[0] += w[2]*xj2;
-         y[1] += w[6]*xj2;
-         y[2] += w[10]*xj2;
-         y[3] += w[14]*xj2;
-         y[4] += w[18]*xj2;
-         y[5] += w[22]*xj2;
-         y[6] += w[26]*xj2;
-         y[7] += w[30]*xj2;
-
-         y[0] += w[3]*xj3;
-         y[1] += w[7]*xj3;
-         y[2] += w[11]*xj3;
-         y[3] += w[15]*xj3;
-         y[4] += w[19]*xj3;
-         y[5] += w[23]*xj3;
-         y[6] += w[27]*xj3;
-         y[7] += w[31]*xj3;
+         y[0] += SCALE_1*(w[0]*xj0+w[1]*xj1+w[2]*xj2+w[3]*xj3);
+         y[1] += SCALE_1*(w[4]*xj0+w[5]*xj1+w[6]*xj2+w[7]*xj3);
+         y[2] += SCALE_1*(w[8]*xj0+w[9]*xj1+w[10]*xj2+w[11]*xj3);
+         y[3] += SCALE_1*(w[12]*xj0+w[13]*xj1+w[14]*xj2+w[15]*xj3);
+         y[4] += SCALE_1*(w[16]*xj0+w[17]*xj1+w[18]*xj2+w[19]*xj3);
+         y[5] += SCALE_1*(w[20]*xj0+w[21]*xj1+w[22]*xj2+w[23]*xj3);
+         y[6] += SCALE_1*(w[24]*xj0+w[25]*xj1+w[26]*xj2+w[27]*xj3);
+         y[7] += SCALE_1*(w[28]*xj0+w[29]*xj1+w[30]*xj2+w[31]*xj3);
          w += 32;
       }
    }
@@ -224,7 +200,7 @@
 }
 
 #else
-static void sparse_sgemv_accum8x4(float *out, const float *w, int rows, const int *idx, const float *x)
+static void sparse_sgemv_accum8x4(float *out, const qweight *w, int rows, const int *idx, const float *x)
 {
    int i, j;
    for (i=0;i<rows;i+=8)
--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -219,8 +219,80 @@
 }
 
 #ifdef DOT_PROD
+#ifndef SCALE_1
+#define SCALE_1 (1.f/128.f/127.f)
+#endif
+
+/* Widen the low 4 bytes of lo and of hi (one row of four signed 8-bit
+   weights each) to floats, with lo's row in the low 128-bit lane and hi's
+   row in the high lane. */
+static __m256 q8_row_pair_to_ps(__m128i lo, __m128i hi)
+{
+   __m256i w32;
+   w32 = _mm256_castsi128_si256(_mm_cvtepi8_epi32(lo));
+   w32 = _mm256_insertf128_si256(w32, _mm_cvtepi8_epi32(hi), 1);
+   return _mm256_cvtepi32_ps(w32);
+}
+
+/* Accumulates out += W*x for a block-sparse matrix stored as 8x4 blocks of
+   8-bit weights, following the generic DOT_PROD kernel in vec.h:
+   - for each group of 8 rows, idx holds the number of blocks followed by
+     one index per block; the block's 4 inputs start at x[4*index];
+   - each block is 32 weights, row-major (row r uses weights[4*r..4*r+3]);
+   - inputs are quantized with floor(.5+127*x) and the integer dot products
+     are rescaled by SCALE_1.
+   The products and 4-term sums are small integers that float represents
+   exactly (assuming |x| <= 1, so |xj| <= 127 -- TODO confirm), so the result
+   should match the scalar kernel. Needs only AVX and SSE4.1 intrinsics.
+   NOTE(review): like the scalar kernel, relies on floor() being declared
+   (math.h) by the including file. */
+static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, const int *idx, const float *x)
+{
+   int i, j;
+   __m256 vscale;
+   vscale = _mm256_set1_ps(SCALE_1);
+   for (i=0;i<rows;i+=8)
+   {
+      float * restrict y;
+      int cols;
+      __m256 vy0;
+      y = &out[i];
+      vy0 = _mm256_loadu_ps(&y[0]);
+      cols = *idx++;
+      for (j=0;j<cols;j++)
+      {
+         int pos;
+         int xj0, xj1, xj2, xj3;
+         __m128i wlo, whi;
+         __m256 vxj, p0, p1, p2, p3, vsum;
+         pos = 4 * (*idx++);
+         /* Same rounding as the scalar kernel (double-precision floor). */
+         xj0 = floor(.5+127*x[pos+0]);
+         xj1 = floor(.5+127*x[pos+1]);
+         xj2 = floor(.5+127*x[pos+2]);
+         xj3 = floor(.5+127*x[pos+3]);
+         /* Each 128-bit lane multiplies one row by (xj0..xj3). */
+         vxj = _mm256_setr_ps((float)xj0, (float)xj1, (float)xj2, (float)xj3,
+                              (float)xj0, (float)xj1, (float)xj2, (float)xj3);
+         /* Rows 0-3 are weights[0..15], rows 4-7 are weights[16..31]. */
+         wlo = _mm_loadu_si128((const __m128i *)&weights[0]);
+         whi = _mm_loadu_si128((const __m128i *)&weights[16]);
+         /* pN holds row N (low lane) and row N+4 (high lane). */
+         p0 = _mm256_mul_ps(q8_row_pair_to_ps(wlo, whi), vxj);
+         p1 = _mm256_mul_ps(q8_row_pair_to_ps(_mm_srli_si128(wlo, 4), _mm_srli_si128(whi, 4)), vxj);
+         p2 = _mm256_mul_ps(q8_row_pair_to_ps(_mm_srli_si128(wlo, 8), _mm_srli_si128(whi, 8)), vxj);
+         p3 = _mm256_mul_ps(q8_row_pair_to_ps(_mm_srli_si128(wlo, 12), _mm_srli_si128(whi, 12)), vxj);
+         /* Two rounds of in-lane horizontal adds leave rows 0..7 in order. */
+         vsum = _mm256_hadd_ps(_mm256_hadd_ps(p0, p1), _mm256_hadd_ps(p2, p3));
+         vy0 = _mm256_add_ps(vy0, _mm256_mul_ps(vscale, vsum));
+         weights += 32;
+      }
+      _mm256_storeu_ps (&y[0], vy0);
+   }
+}
+
 #else
-static void sparse_sgemv_accum8x4(float *out, const float *weights, int rows, const int *idx, const float *x)
+static void sparse_sgemv_accum8x4(float *out, const qweight *weights, int rows, const int *idx, const float *x)
 {
    int i, j;
    for (i=0;i<rows;i+=8)
--