shithub: opus

Download patch

ref: 60c48ade0a9d192b8535023bf7c7db40341ece1e
parent: e0ca05b1ec5ef4abbfed5f70623ed3e5ea77dd6b
author: Jean-Marc Valin <jmvalin@amazon.com>
date: Thu Jun 30 12:44:34 EDT 2022

Estimate the inner product accuracy to fix check-asm

Estimate the rounding error so that we can have a useful margin of
error when checking the asm against the C code even when the float
operations get reordered due to -ffast-math.

--- a/celt/arm/pitch_neon_intr.c
+++ b/celt/arm/pitch_neon_intr.c
@@ -137,9 +137,10 @@
 /* celt_inner_prod_neon_float_c_simulation() simulates the floating-point   */
 /* operations of celt_inner_prod_neon(), and both functions should have bit */
 /* exact output.                                                            */
-static opus_val32 celt_inner_prod_neon_float_c_simulation(const opus_val16 *x, const opus_val16 *y, int N)
+static opus_val32 celt_inner_prod_neon_float_c_simulation(const opus_val16 *x, const opus_val16 *y, float *err, int N)
 {
    int i;
+   *err = 0;
    opus_val32 xy, xy0 = 0, xy1 = 0, xy2 = 0, xy3 = 0;
    for (i = 0; i < N - 3; i += 4) {
       xy0 = MAC16_16(xy0, x[i + 0], y[i + 0]);
@@ -146,13 +147,17 @@
       xy1 = MAC16_16(xy1, x[i + 1], y[i + 1]);
       xy2 = MAC16_16(xy2, x[i + 2], y[i + 2]);
       xy3 = MAC16_16(xy3, x[i + 3], y[i + 3]);
+      *err += ABS32(xy0)+ABS32(xy1)+ABS32(xy2)+ABS32(xy3);
    }
    xy0 += xy2;
    xy1 += xy3;
    xy = xy0 + xy1;
+   *err += ABS32(xy1)+ABS32(xy0)+ABS32(xy);
    for (; i < N; i++) {
       xy = MAC16_16(xy, x[i], y[i]);
+      *err += ABS32(xy);
    }
+   *err = *err*2e-7 + N*1e-37;
    return xy;
 }
 
@@ -160,32 +165,10 @@
 /* operations of dual_inner_prod_neon(), and both functions should have bit */
 /* exact output.                                                            */
 static void dual_inner_prod_neon_float_c_simulation(const opus_val16 *x, const opus_val16 *y01, const opus_val16 *y02,
-      int N, opus_val32 *xy1, opus_val32 *xy2)
+      int N, opus_val32 *xy1, opus_val32 *xy2, float *err)
 {
-   int i;
-   opus_val32 xy01, xy02, xy01_0 = 0, xy01_1 = 0, xy01_2 = 0, xy01_3 = 0, xy02_0 = 0, xy02_1 = 0, xy02_2 = 0, xy02_3 = 0;
-   for (i = 0; i < N - 3; i += 4) {
-      xy01_0 = MAC16_16(xy01_0, x[i + 0], y01[i + 0]);
-      xy01_1 = MAC16_16(xy01_1, x[i + 1], y01[i + 1]);
-      xy01_2 = MAC16_16(xy01_2, x[i + 2], y01[i + 2]);
-      xy01_3 = MAC16_16(xy01_3, x[i + 3], y01[i + 3]);
-      xy02_0 = MAC16_16(xy02_0, x[i + 0], y02[i + 0]);
-      xy02_1 = MAC16_16(xy02_1, x[i + 1], y02[i + 1]);
-      xy02_2 = MAC16_16(xy02_2, x[i + 2], y02[i + 2]);
-      xy02_3 = MAC16_16(xy02_3, x[i + 3], y02[i + 3]);
-   }
-   xy01_0 += xy01_2;
-   xy02_0 += xy02_2;
-   xy01_1 += xy01_3;
-   xy02_1 += xy02_3;
-   xy01 = xy01_0 + xy01_1;
-   xy02 = xy02_0 + xy02_1;
-   for (; i < N; i++) {
-      xy01 = MAC16_16(xy01, x[i], y01[i]);
-      xy02 = MAC16_16(xy02, x[i], y02[i]);
-   }
-   *xy1 = xy01;
-   *xy2 = xy02;
+   *xy1 = celt_inner_prod_neon_float_c_simulation(x, y01, &err[0], N);
+   *xy2 = celt_inner_prod_neon_float_c_simulation(x, y02, &err[1], N);
 }
 
 #endif /* OPUS_CHECK_ASM */
@@ -225,7 +208,12 @@
     }
 
 #ifdef OPUS_CHECK_ASM
-    celt_assert(ABS32(celt_inner_prod_neon_float_c_simulation(x, y, N) - xy) <= VERY_SMALL);
+    {
+        float err, res;
+        res = celt_inner_prod_neon_float_c_simulation(x, y, &err, N);
+        /*if (ABS32(res - xy) > err) fprintf(stderr, "%g %g %g\n", res, xy, err);*/
+        celt_assert(ABS32(res - xy) <= err);
+    }
 #endif
 
     return xy;
@@ -280,9 +268,12 @@
 #ifdef OPUS_CHECK_ASM
     {
         opus_val32 xy1_c, xy2_c;
-        dual_inner_prod_neon_float_c_simulation(x, y01, y02, N, &xy1_c, &xy2_c);
-        celt_assert(ABS32(xy1_c - *xy1) <= VERY_SMALL);
-        celt_assert(ABS32(xy2_c - *xy2) <= VERY_SMALL);
+        float err[2];
+        dual_inner_prod_neon_float_c_simulation(x, y01, y02, N, &xy1_c, &xy2_c, err);
+        /*if (ABS32(xy1_c - *xy1) > err[0]) fprintf(stderr, "dual1 fail: %g %g %g\n", xy1_c, *xy1, err[0]);
+        if (ABS32(xy2_c - *xy2) > err[1]) fprintf(stderr, "dual2 fail: %g %g %g\n", xy2_c, *xy2, err[1]);*/
+        celt_assert(ABS32(xy1_c - *xy1) <= err[0]);
+        celt_assert(ABS32(xy2_c - *xy2) <= err[1]);
     }
 #endif
 }