shithub: opus

Download patch

ref: 62b546436fc07035802eb998f61702ee2716db60
parent: 61fb3b16894c8fff523efb4255247d151ed5bad5
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Sun Oct 29 20:08:53 EDT 2023

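Speed up general case of float matrix-vector multiply (sgemv)

When rows was not a multiple of 8, sgemv() fell back to scalar code for
every row. Merge the 16-row, 8-row and scalar paths into one function so
that only the last rows%8 rows are computed with scalar code.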

--- a/dnn/vec_avx.h
+++ b/dnn/vec_avx.h
@@ -666,67 +666,61 @@
 #error "No optimizations in vec_avx.h. This should never happen. "
 #endif
 
-static inline void sgemv16x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
+/* out[i] = sum_j weights[j*col_stride + i]*x[j], for 0 <= i < rows and 0 <= j < cols:
+   column j of the matrix starts at weights[j*col_stride]. Rows are processed 16 and
+   then 8 at a time with AVX/FMA; only the last rows%8 rows use scalar code. */
+static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
 {
    int i, j;
-   for (i=0;i<rows;i+=16)
+   for (i=0;i<rows-15;i+=16)
    {
       float *y;
       __m256 vy0, vy8;
       y = &out[i];
       vy0 = _mm256_setzero_ps();
       vy8 = _mm256_setzero_ps();
       for (j=0;j<cols;j++)
       {
          __m256 vxj;
          __m256 vw;
          vxj = _mm256_broadcast_ss(&x[j]);

          vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);

          vw = _mm256_loadu_ps(&weights[j*col_stride + i + 8]);
          vy8 = _mm256_fmadd_ps(vw, vxj, vy8);
       }
       _mm256_storeu_ps (&y[0], vy0);
       _mm256_storeu_ps (&y[8], vy8);
    }
-}
-
-static inline void sgemv8x1(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   int i, j;
-   for (i=0;i<rows;i+=8)
+   /* Fewer than 16 rows remain here, so this runs at most once. */
+   for (;i<rows-7;i+=8)
    {
       float *y;
       __m256 vy0;
       y = &out[i];
       vy0 = _mm256_setzero_ps();
       for (j=0;j<cols;j++)
       {
          __m256 vxj;
          __m256 vw;
          vxj = _mm256_broadcast_ss(&x[j]);

          vw = _mm256_loadu_ps(&weights[j*col_stride + i]);
          vy0 = _mm256_fmadd_ps(vw, vxj, vy0);
       }
       _mm256_storeu_ps (&y[0], vy0);
    }
-}
-
-static inline void sgemv(float *out, const float *weights, int rows, int cols, int col_stride, const float *x)
-{
-   if ((rows&0xf) == 0) sgemv16x1(out, weights, rows, cols, col_stride, x);
-   else if ((rows&0x7) == 0) sgemv8x1(out, weights, rows, cols, col_stride, x);
-   else {
-      int i, j;
-      for (i=0;i<rows;i++)
-      {
-         out[i] = 0;
-         for (j=0;j<cols;j++) out[i] += weights[j*col_stride + i]*x[j];
-      }
-   }
+   /* Scalar tail. Accumulate in a local rather than in out[i]: out, weights and x
+      are not restrict-qualified, so the compiler would otherwise have to store
+      to out[i] on every iteration in case it aliases the inputs. */
+   for (;i<rows;i++)
+   {
+      float sum = 0;
+      for (j=0;j<cols;j++) sum += weights[j*col_stride + i]*x[j];
+      out[i] = sum;
+   }
 }
 
 static inline void sparse_sgemv8x4(float *out, const float *weights, const int *idx, int rows, const float *x)
--