shithub: dav1d

ref: 2737c05eac98c0f4c99572614714a033617f8f3f
parent: fdf1570e13b9360d7f3d224e1f77655e34980350
author: Henrik Gramner <gramner@twoorioles.com>
date: Wed Dec 2 09:10:52 EST 2020

Add miscellaneous minor Wiener filter optimizations

Combine the horizontal and vertical filter pointers into a single
parameter when calling the Wiener DSP function.
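
For illustration, a minimal C sketch of the consolidated callback type after
this change (the pixel typedef and the LrEdgeFlags values follow dav1d's
conventions, but the HIGHBD parameter machinery is omitted here):

#include <stddef.h>
#include <stdint.h>

typedef uint8_t pixel;  /* 8bpc build; the 16bpc template uses uint16_t */

enum LrEdgeFlags {
    LR_HAVE_LEFT   = 1 << 0,
    LR_HAVE_RIGHT  = 1 << 1,
    LR_HAVE_TOP    = 1 << 2,
    LR_HAVE_BOTTOM = 1 << 3,
};

/* The two separate 7-entry coefficient arrays become one 16-byte-aligned
 * [2][8] block: filter[0] holds the horizontal taps, filter[1] the vertical
 * taps. Padding each row to 8 entries makes a row exactly 16 bytes, so SIMD
 * code can fetch a whole row with one aligned 128-bit load (the [r4, :128]
 * hints below), and the vertical taps sit at a fixed +16 byte offset from
 * the base pointer (the [fltq+16] loads in the x86 changes). */
typedef void (*wienerfilter_fn)(pixel *dst, ptrdiff_t dst_stride,
                                const pixel (*left)[4],
                                const pixel *lpf, ptrdiff_t lpf_stride,
                                int w, int h,
                                const int16_t filter[2][8],
                                enum LrEdgeFlags edges);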

Eliminate the separate +128 center-tap handling where possible by folding
it into the filter coefficients at setup time.
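
As a rough sketch of the new coefficient setup (mirroring the
src/lr_apply_tmpl.c hunk below; BITDEPTH is dav1d's per-template macro, and
the filter_h/filter_v parameters stand in for the three unique taps stored
in the restoration unit):

#include <stdint.h>

/* The Wiener filter is 7-tap and symmetric, and its taps sum to 128 (1 << 7),
 * so the center tap equals 128 minus twice the sum of the three outer taps.
 * Previously the arrays stored only the -2 * sum part and the DSP kernels
 * added the remaining "center sample * 128" term at runtime; this patch folds
 * the 128 into the stored coefficient wherever that cannot overflow. */
static void build_wiener_filter(int16_t filter[2][8],
                                const int16_t filter_h[3],
                                const int16_t filter_v[3])
{
    filter[0][0] = filter[0][6] = filter_h[0];
    filter[0][1] = filter[0][5] = filter_h[1];
    filter[0][2] = filter[0][4] = filter_h[2];
    filter[0][3] = -(filter[0][0] + filter[0][1] + filter[0][2]) * 2;
#if BITDEPTH != 8
    /* For 8-bit SIMD it's beneficial to handle the +128 separately in
     * order to avoid overflows, so it is only folded in for 10/12 bpc. */
    filter[0][3] += 128;
#endif

    filter[1][0] = filter[1][6] = filter_v[0];
    filter[1][1] = filter[1][5] = filter_v[1];
    filter[1][2] = filter[1][4] = filter_v[2];
    filter[1][3] = 128 - (filter[1][0] + filter[1][1] + filter[1][2]) * 2;
}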

--- a/src/arm/32/looprestoration.S
+++ b/src/arm/32/looprestoration.S
@@ -30,7 +30,7 @@
 
 // void dav1d_wiener_filter_h_8bpc_neon(int16_t *dst, const pixel (*left)[4],
 //                                      const pixel *src, ptrdiff_t stride,
-//                                      const int16_t fh[7], const intptr_t w,
+//                                      const int16_t fh[8], intptr_t w,
 //                                      int h, enum LrEdgeFlags edges);
 function wiener_filter_h_8bpc_neon, export=1
         push            {r4-r11,lr}
@@ -38,7 +38,7 @@
         ldrd            r4,  r5,  [sp, #52]
         ldrd            r6,  r7,  [sp, #60]
         mov             r8,  r5
-        vld1.16         {q0},  [r4]
+        vld1.16         {q0},  [r4, :128]
         movw            r9,  #(1 << 14) - (1 << 2)
         vdup.16         q14, r9
         vmov.s16        q15, #2048
@@ -358,7 +358,7 @@
 
 // void dav1d_wiener_filter_v_8bpc_neon(pixel *dst, ptrdiff_t stride,
 //                                      const int16_t *mid, int w, int h,
-//                                      const int16_t fv[7], enum LrEdgeFlags edges,
+//                                      const int16_t fv[8], enum LrEdgeFlags edges,
 //                                      ptrdiff_t mid_stride);
 function wiener_filter_v_8bpc_neon, export=1
         push            {r4-r7,lr}
@@ -365,11 +365,7 @@
         ldrd            r4,  r5,  [sp, #20]
         ldrd            r6,  r7,  [sp, #28]
         mov             lr,  r4
-        vmov.s16        q1,  #0
-        mov             r12, #128
-        vld1.16         {q0},  [r5]
-        vmov.s16        d2[3], r12
-        vadd.s16        q0,  q0,  q1
+        vld1.16         {q0},  [r5, :128]
 
         // Calculate the number of rows to move back when looping vertically
         mov             r12, r4
--- a/src/arm/32/looprestoration16.S
+++ b/src/arm/32/looprestoration16.S
@@ -39,7 +39,7 @@
         ldrd            r4,  r5,  [sp, #100]
         ldrd            r6,  r7,  [sp, #108]
         ldr             r8,       [sp, #116] // bitdepth_max
-        vld1.16         {q0}, [r4]
+        vld1.16         {q0}, [r4, :128]
         clz             r8,  r8
         vmov.i32        q14, #1
         sub             r9,  r8,  #38  // -(bitdepth + 6)
@@ -151,16 +151,14 @@
         b               6f
 
 4:      // Loop horizontally
-        vext.8          q10, q2,  q3,  #6
         vext.8          q8,  q2,  q3,  #2
         vext.8          q9,  q2,  q3,  #4
-        vshll.u16       q6,  d20, #7
-        vshll.u16       q7,  d21, #7
-        vmlal.s16       q6,  d4,  d0[0]
+        vext.8          q10, q2,  q3,  #6
+        vmull.s16       q6,  d4,  d0[0]
         vmlal.s16       q6,  d16, d0[1]
         vmlal.s16       q6,  d18, d0[2]
         vmlal.s16       q6,  d20, d0[3]
-        vmlal.s16       q7,  d5,  d0[0]
+        vmull.s16       q7,  d5,  d0[0]
         vmlal.s16       q7,  d17, d0[1]
         vmlal.s16       q7,  d19, d0[2]
         vmlal.s16       q7,  d21, d0[3]
@@ -173,14 +171,12 @@
         vmlal.s16       q7,  d17, d1[0]
         vmlal.s16       q7,  d19, d1[1]
         vmlal.s16       q7,  d21, d1[2]
-        vext.8          q10, q4,  q5,  #6
         vext.8          q2,  q4,  q5,  #2
-        vshll.u16       q8,  d20, #7
-        vshll.u16       q9,  d21, #7
-        vmlal.s16       q8,  d8,  d0[0]
+        vext.8          q10, q4,  q5,  #6
+        vmull.s16       q8,  d8,  d0[0]
         vmlal.s16       q8,  d4,  d0[1]
         vmlal.s16       q8,  d20, d0[3]
-        vmlal.s16       q9,  d9,  d0[0]
+        vmull.s16       q9,  d9,  d0[0]
         vmlal.s16       q9,  d5,  d0[1]
         vmlal.s16       q9,  d21, d0[3]
         vext.8          q2,  q4,  q5,  #4
@@ -233,8 +229,7 @@
         vext.8          d17, d4,  d5,  #4
         vext.8          d19, d5,  d6,  #2
         vext.8          d20, d5,  d6,  #4
-        vshll.u16       q6,  d18, #7
-        vmlal.s16       q6,  d4,  d0[0]
+        vmull.s16       q6,  d4,  d0[0]
         vmlal.s16       q6,  d16, d0[1]
         vmlal.s16       q6,  d17, d0[2]
         vmlal.s16       q6,  d18, d0[3]
@@ -247,8 +242,7 @@
         vext.8          d17, d8,  d9,  #4
         vext.8          d19, d9,  d10, #2
         vext.8          d20, d9,  d10, #4
-        vshll.u16       q7,  d18, #7
-        vmlal.s16       q7,  d8,  d0[0]
+        vmull.s16       q7,  d8,  d0[0]
         vmlal.s16       q7,  d16, d0[1]
         vmlal.s16       q7,  d17, d0[2]
         vmlal.s16       q7,  d18, d0[3]
@@ -356,14 +350,9 @@
         vadd.i32        q8,  q9
         vpadd.i32       d12, d12, d13
         vpadd.i32       d13, d16, d17
-        vdup.16         d14, d4[3]
-        vdup.16         d15, d8[3]
         vpadd.i32       d12, d12, d13
-        vtrn.16         d14, d15
         vadd.i32        d12, d12, d28
-        vshll.u16       q7,  d14, #7
         vmvn.i16        d20, #0x8000 // 0x7fff = (1 << 15) - 1
-        vadd.i32        d12, d12, d14
         vrshl.s32       d12, d12, d26
         vqmovun.s32     d12, q6
         vmin.u16        d12, d12, d20
@@ -401,14 +390,10 @@
         ldrd            r4,  r5,  [sp, #52]
         ldrd            r6,  r7,  [sp, #60]
         ldr             lr,       [sp, #68] // bitdepth_max
-        vmov.i16        q1,  #0
-        mov             r12, #128
-        vld1.16         {q0},  [r5]
+        vld1.16         {q0},  [r5, :128]
         vdup.16         q5,  lr
         clz             lr,  lr
-        vmov.i16        d2[3], r12
         sub             lr,  lr,  #11   // round_bits_v
-        vadd.i16        q0,  q0,  q1
         vdup.32         q4,  lr
         mov             lr,  r4
         vneg.s32        q4,  q4         // -round_bits_v
--- a/src/arm/64/looprestoration.S
+++ b/src/arm/64/looprestoration.S
@@ -30,7 +30,7 @@
 
 // void dav1d_wiener_filter_h_8bpc_neon(int16_t *dst, const pixel (*left)[4],
 //                                      const pixel *src, ptrdiff_t stride,
-//                                      const int16_t fh[7], const intptr_t w,
+//                                      const int16_t fh[8], intptr_t w,
 //                                      int h, enum LrEdgeFlags edges);
 function wiener_filter_h_8bpc_neon, export=1
         mov             w8,  w5
@@ -308,13 +308,11 @@
 
 // void dav1d_wiener_filter_v_8bpc_neon(pixel *dst, ptrdiff_t stride,
 //                                      const int16_t *mid, int w, int h,
-//                                      const int16_t fv[7], enum LrEdgeFlags edges,
+//                                      const int16_t fv[8], enum LrEdgeFlags edges,
 //                                      ptrdiff_t mid_stride);
 function wiener_filter_v_8bpc_neon, export=1
         mov             w8,  w4
         ld1             {v0.8h},  [x5]
-        movi            v1.8h, #128
-        add             v1.8h,  v1.8h,  v0.8h
 
         // Calculate the number of rows to move back when looping vertically
         mov             w11, w4
@@ -359,7 +357,7 @@
         smull           v2.4s,  v16.4h,  v0.h[0]
         smlal           v2.4s,  v17.4h,  v0.h[1]
         smlal           v2.4s,  v18.4h,  v0.h[2]
-        smlal           v2.4s,  v19.4h,  v1.h[3]
+        smlal           v2.4s,  v19.4h,  v0.h[3]
         smlal           v2.4s,  v20.4h,  v0.h[4]
         smlal           v2.4s,  v21.4h,  v0.h[5]
         smlal           v2.4s,  v22.4h,  v0.h[6]
@@ -366,7 +364,7 @@
         smull2          v3.4s,  v16.8h,  v0.h[0]
         smlal2          v3.4s,  v17.8h,  v0.h[1]
         smlal2          v3.4s,  v18.8h,  v0.h[2]
-        smlal2          v3.4s,  v19.8h,  v1.h[3]
+        smlal2          v3.4s,  v19.8h,  v0.h[3]
         smlal2          v3.4s,  v20.8h,  v0.h[4]
         smlal2          v3.4s,  v21.8h,  v0.h[5]
         smlal2          v3.4s,  v22.8h,  v0.h[6]
--- a/src/arm/64/looprestoration16.S
+++ b/src/arm/64/looprestoration16.S
@@ -143,12 +143,6 @@
         b               6f
 
 4:      // Loop horizontally
-.macro ushll_sz d0, d1, src, shift, wd
-        ushll           \d0\().4s,  \src\().4h,  \shift
-.ifc \wd, .8h
-        ushll2          \d1\().4s,  \src\().8h,  \shift
-.endif
-.endm
 .macro add_sz d0, d1, s0, s1, c, wd
         add             \d0\().4s,  \s0\().4s,   \c\().4s
 .ifc \wd, .8h
@@ -172,14 +166,13 @@
         // Interleaving the mul/mla chains actually hurts performance
         // significantly on Cortex A53, thus keeping mul/mla tightly
         // chained like this.
-        ext             v18.16b, v2.16b,  v3.16b, #6
         ext             v16.16b, v2.16b,  v3.16b, #2
         ext             v17.16b, v2.16b,  v3.16b, #4
+        ext             v18.16b, v2.16b,  v3.16b, #6
         ext             v19.16b, v2.16b,  v3.16b, #8
         ext             v20.16b, v2.16b,  v3.16b, #10
-        ushll_sz        v6,  v7,  v18, #7, \wd
         ext             v21.16b, v2.16b,  v3.16b, #12
-        smlal           v6.4s,   v2.4h,   v0.h[0]
+        smull           v6.4s,   v2.4h,   v0.h[0]
         smlal           v6.4s,   v16.4h,  v0.h[1]
         smlal           v6.4s,   v17.4h,  v0.h[2]
         smlal           v6.4s,   v18.4h,  v0.h[3]
@@ -187,7 +180,7 @@
         smlal           v6.4s,   v20.4h,  v0.h[5]
         smlal           v6.4s,   v21.4h,  v0.h[6]
 .ifc \wd, .8h
-        smlal2          v7.4s,   v2.8h,   v0.h[0]
+        smull2          v7.4s,   v2.8h,   v0.h[0]
         smlal2          v7.4s,   v16.8h,  v0.h[1]
         smlal2          v7.4s,   v17.8h,  v0.h[2]
         smlal2          v7.4s,   v18.8h,  v0.h[3]
@@ -195,14 +188,13 @@
         smlal2          v7.4s,   v20.8h,  v0.h[5]
         smlal2          v7.4s,   v21.8h,  v0.h[6]
 .endif
-        ext             v21.16b, v4.16b,  v5.16b, #6
         ext             v19.16b, v4.16b,  v5.16b, #2
         ext             v20.16b, v4.16b,  v5.16b, #4
+        ext             v21.16b, v4.16b,  v5.16b, #6
         ext             v22.16b, v4.16b,  v5.16b, #8
         ext             v23.16b, v4.16b,  v5.16b, #10
-        ushll_sz        v16, v17, v21, #7, \wd
         ext             v24.16b, v4.16b,  v5.16b, #12
-        smlal           v16.4s,  v4.4h,   v0.h[0]
+        smull           v16.4s,  v4.4h,   v0.h[0]
         smlal           v16.4s,  v19.4h,  v0.h[1]
         smlal           v16.4s,  v20.4h,  v0.h[2]
         smlal           v16.4s,  v21.4h,  v0.h[3]
@@ -210,7 +202,7 @@
         smlal           v16.4s,  v23.4h,  v0.h[5]
         smlal           v16.4s,  v24.4h,  v0.h[6]
 .ifc \wd, .8h
-        smlal2          v17.4s,  v4.8h,   v0.h[0]
+        smull2          v17.4s,  v4.8h,   v0.h[0]
         smlal2          v17.4s,  v19.8h,  v0.h[1]
         smlal2          v17.4s,  v20.8h,  v0.h[2]
         smlal2          v17.4s,  v21.8h,  v0.h[3]
@@ -329,13 +321,9 @@
         add             v16.4s,  v16.4s,  v17.4s
         addv            s6,      v6.4s
         addv            s7,      v16.4s
-        dup             v16.4h,  v2.h[3]
-        ins             v16.h[1], v4.h[3]
         ins             v6.s[1], v7.s[0]
         mvni            v24.4h,  #0x80, lsl #8 // 0x7fff = (1 << 15) - 1
-        ushll           v16.4s,  v16.4h,  #7
         add             v6.2s,   v6.2s,   v30.2s
-        add             v6.2s,   v6.2s,   v16.2s
         srshl           v6.2s,   v6.2s,   v29.2s
         sqxtun          v6.4h,   v6.4s
         umin            v6.4h,   v6.4h,   v24.4h
@@ -371,9 +359,7 @@
         ld1             {v0.8h},  [x5]
         dup             v31.8h,  w8
         clz             w8,  w8
-        movi            v1.8h,   #128
         sub             w8,  w8,  #11   // round_bits_v
-        add             v1.8h,   v1.8h,   v0.8h
         dup             v30.4s,  w8
         mov             w8,  w4
         neg             v30.4s,  v30.4s // -round_bits_v
@@ -421,7 +407,7 @@
         smull           v2.4s,  v16.4h,  v0.h[0]
         smlal           v2.4s,  v17.4h,  v0.h[1]
         smlal           v2.4s,  v18.4h,  v0.h[2]
-        smlal           v2.4s,  v19.4h,  v1.h[3]
+        smlal           v2.4s,  v19.4h,  v0.h[3]
         smlal           v2.4s,  v20.4h,  v0.h[4]
         smlal           v2.4s,  v21.4h,  v0.h[5]
         smlal           v2.4s,  v22.4h,  v0.h[6]
@@ -428,7 +414,7 @@
         smull2          v3.4s,  v16.8h,  v0.h[0]
         smlal2          v3.4s,  v17.8h,  v0.h[1]
         smlal2          v3.4s,  v18.8h,  v0.h[2]
-        smlal2          v3.4s,  v19.8h,  v1.h[3]
+        smlal2          v3.4s,  v19.8h,  v0.h[3]
         smlal2          v3.4s,  v20.8h,  v0.h[4]
         smlal2          v3.4s,  v21.8h,  v0.h[5]
         smlal2          v3.4s,  v22.8h,  v0.h[6]
--- a/src/arm/looprestoration_init_tmpl.c
+++ b/src/arm/looprestoration_init_tmpl.c
@@ -45,12 +45,11 @@
 // 1 << (bitdepth + 6 - round_bits_h).
 void BF(dav1d_wiener_filter_h, neon)(int16_t *dst, const pixel (*left)[4],
                                      const pixel *src, ptrdiff_t stride,
-                                     const int16_t fh[7], const intptr_t w,
+                                     const int16_t fh[8], intptr_t w,
                                      int h, enum LrEdgeFlags edges
                                      HIGHBD_DECL_SUFFIX);
 // This calculates things slightly differently than the reference C version.
 // This version calculates roughly this:
-// fv[3] += 128;
 // int32_t sum = 0;
 // for (int i = 0; i < 7; i++)
 //     sum += mid[idx] * fv[i];
@@ -58,7 +57,7 @@
 // This function assumes that the width is a multiple of 8.
 void BF(dav1d_wiener_filter_v, neon)(pixel *dst, ptrdiff_t stride,
                                      const int16_t *mid, int w, int h,
-                                     const int16_t fv[7], enum LrEdgeFlags edges,
+                                     const int16_t fv[8], enum LrEdgeFlags edges,
                                      ptrdiff_t mid_stride HIGHBD_DECL_SUFFIX);
 void BF(dav1d_copy_narrow, neon)(pixel *dst, ptrdiff_t stride,
                                  const pixel *src, int w, int h);
@@ -66,9 +65,9 @@
 static void wiener_filter_neon(pixel *const dst, const ptrdiff_t dst_stride,
                                const pixel (*const left)[4],
                                const pixel *lpf, const ptrdiff_t lpf_stride,
-                               const int w, const int h, const int16_t fh[7],
-                               const int16_t fv[7], const enum LrEdgeFlags edges
-                               HIGHBD_DECL_SUFFIX)
+                               const int w, const int h,
+                               const int16_t filter[2][8],
+                               const enum LrEdgeFlags edges HIGHBD_DECL_SUFFIX)
 {
     ALIGN_STK_16(int16_t, mid, 68 * 384,);
     int mid_stride = (w + 7) & ~7;
@@ -75,20 +74,21 @@
 
     // Horizontal filter
     BF(dav1d_wiener_filter_h, neon)(&mid[2 * mid_stride], left, dst, dst_stride,
-                                    fh, w, h, edges HIGHBD_TAIL_SUFFIX);
+                                    filter[0], w, h, edges HIGHBD_TAIL_SUFFIX);
     if (edges & LR_HAVE_TOP)
         BF(dav1d_wiener_filter_h, neon)(mid, NULL, lpf, lpf_stride,
-                                        fh, w, 2, edges HIGHBD_TAIL_SUFFIX);
+                                        filter[0], w, 2, edges
+                                        HIGHBD_TAIL_SUFFIX);
     if (edges & LR_HAVE_BOTTOM)
         BF(dav1d_wiener_filter_h, neon)(&mid[(2 + h) * mid_stride], NULL,
                                         lpf + 6 * PXSTRIDE(lpf_stride),
-                                        lpf_stride, fh, w, 2, edges
+                                        lpf_stride, filter[0], w, 2, edges
                                         HIGHBD_TAIL_SUFFIX);
 
     // Vertical filter
     if (w >= 8)
         BF(dav1d_wiener_filter_v, neon)(dst, dst_stride, &mid[2*mid_stride],
-                                        w & ~7, h, fv, edges,
+                                        w & ~7, h, filter[1], edges,
                                         mid_stride * sizeof(*mid)
                                         HIGHBD_TAIL_SUFFIX);
     if (w & 7) {
@@ -97,7 +97,7 @@
         ALIGN_STK_16(pixel, tmp, 64 * 8,);
         BF(dav1d_wiener_filter_v, neon)(tmp, (w & 7) * sizeof(pixel),
                                         &mid[2*mid_stride + (w & ~7)],
-                                        w & 7, h, fv, edges,
+                                        w & 7, h, filter[1], edges,
                                         mid_stride * sizeof(*mid)
                                         HIGHBD_TAIL_SUFFIX);
         BF(dav1d_copy_narrow, neon)(dst + (w & ~7), dst_stride, tmp, w & 7, h);
--- a/src/looprestoration.h
+++ b/src/looprestoration.h
@@ -54,9 +54,8 @@
 void (name)(pixel *dst, ptrdiff_t dst_stride, \
             const_left_pixel_row left, \
             const pixel *lpf, ptrdiff_t lpf_stride, \
-            int w, int h, const int16_t filterh[7], \
-            const int16_t filterv[7], enum LrEdgeFlags edges \
-            HIGHBD_DECL_SUFFIX)
+            int w, int h, const int16_t filter[2][8], \
+            enum LrEdgeFlags edges HIGHBD_DECL_SUFFIX)
 typedef decl_wiener_filter_fn(*wienerfilter_fn);
 
 #define decl_selfguided_filter_fn(name) \
--- a/src/looprestoration_tmpl.c
+++ b/src/looprestoration_tmpl.c
@@ -135,7 +135,7 @@
                      const pixel (*const left)[4],
                      const pixel *lpf, const ptrdiff_t lpf_stride,
                      const int w, const int h,
-                     const int16_t filterh[7], const int16_t filterv[7],
+                     const int16_t filter[2][8],
                      const enum LrEdgeFlags edges HIGHBD_DECL_SUFFIX)
 {
     // Wiener filtering is applied to a maximum stripe height of 64 + 3 pixels
@@ -156,10 +156,13 @@
     const int clip_limit = 1 << (bitdepth + 1 + 7 - round_bits_h);
     for (int j = 0; j < h + 6; j++) {
         for (int i = 0; i < w; i++) {
-            int sum = (tmp_ptr[i + 3] << 7) + (1 << (bitdepth + 6));
+            int sum = (1 << (bitdepth + 6));
+#if BITDEPTH == 8
+            sum += tmp_ptr[i + 3] * 128;
+#endif
 
             for (int k = 0; k < 7; k++) {
-                sum += tmp_ptr[i + k] * filterh[k];
+                sum += tmp_ptr[i + k] * filter[0][k];
             }
 
             hor_ptr[i] =
@@ -174,10 +177,10 @@
     const int round_offset = 1 << (bitdepth + (round_bits_v - 1));
     for (int j = 0; j < h; j++) {
         for (int i = 0; i < w; i++) {
-            int sum = (hor[(j + 3) * REST_UNIT_STRIDE + i] << 7) - round_offset;
+            int sum = -round_offset;
 
             for (int k = 0; k < 7; k++) {
-                sum += hor[(j + k) * REST_UNIT_STRIDE + i] * filterv[k];
+                sum += hor[(j + k) * REST_UNIT_STRIDE + i] * filter[1][k];
             }
 
             p[j * PXSTRIDE(p_stride) + i] =
--- a/src/lr_apply_tmpl.c
+++ b/src/lr_apply_tmpl.c
@@ -162,18 +162,24 @@
     // The first stripe of the frame is shorter by 8 luma pixel rows.
     int stripe_h = imin((64 - 8 * !y) >> ss_ver, row_h - y);
 
-    // FIXME [8] might be easier for SIMD
-    int16_t filterh[7], filterv[7];
+    ALIGN_STK_16(int16_t, filter, 2, [8]);
     if (lr->type == DAV1D_RESTORATION_WIENER) {
-        filterh[0] = filterh[6] = lr->filter_h[0];
-        filterh[1] = filterh[5] = lr->filter_h[1];
-        filterh[2] = filterh[4] = lr->filter_h[2];
-        filterh[3] = -((filterh[0] + filterh[1] + filterh[2]) * 2);
+        filter[0][0] = filter[0][6] = lr->filter_h[0];
+        filter[0][1] = filter[0][5] = lr->filter_h[1];
+        filter[0][2] = filter[0][4] = lr->filter_h[2];
+        filter[0][3] = -(filter[0][0] + filter[0][1] + filter[0][2]) * 2;
+#if BITDEPTH != 8
+        /* For 8-bit SIMD it's beneficial to handle the +128 separately
+         * in order to avoid overflows. */
+        filter[0][3] += 128;
+#endif
 
-        filterv[0] = filterv[6] = lr->filter_v[0];
-        filterv[1] = filterv[5] = lr->filter_v[1];
-        filterv[2] = filterv[4] = lr->filter_v[2];
-        filterv[3] = -((filterv[0] + filterv[1] + filterv[2]) * 2);
+        filter[1][0] = filter[1][6] = lr->filter_v[0];
+        filter[1][1] = filter[1][5] = lr->filter_v[1];
+        filter[1][2] = filter[1][4] = lr->filter_v[2];
+        filter[1][3] = 128 - (filter[1][0] + filter[1][1] + filter[1][2]) * 2;
+    } else {
+        assert(lr->type == DAV1D_RESTORATION_SGRPROJ);
     }
 
     while (y + stripe_h <= row_h) {
@@ -181,9 +187,8 @@
         edges ^= (-(y + stripe_h != row_h) ^ edges) & LR_HAVE_BOTTOM;
         if (lr->type == DAV1D_RESTORATION_WIENER) {
             dsp->lr.wiener(p, p_stride, left, lpf, lpf_stride, unit_w, stripe_h,
-                           filterh, filterv, edges HIGHBD_CALL_SUFFIX);
+                           filter, edges HIGHBD_CALL_SUFFIX);
         } else {
-            assert(lr->type == DAV1D_RESTORATION_SGRPROJ);
             dsp->lr.selfguided(p, p_stride, left, lpf, lpf_stride, unit_w, stripe_h,
                                lr->sgr_idx, lr->sgr_weights, edges HIGHBD_CALL_SUFFIX);
         }
--- a/src/ppc/looprestoration_init_tmpl.c
+++ b/src/ppc/looprestoration_init_tmpl.c
@@ -49,7 +49,7 @@
 
 static void wiener_filter_h_vsx(int32_t *hor_ptr,
                                 uint8_t *tmp_ptr,
-                                const int16_t filterh[7],
+                                const int16_t filterh[8],
                                 const int w, const int h)
 {
     static const i32x4 zerov = vec_splats(0);
@@ -149,14 +149,10 @@
 } while (0)
 
 #define LOAD_AND_APPLY_FILTER_V(sumpixelv, hor) do { \
-    i32x4 v_1 = (i32x4) vec_ld( 0, &hor[(j + 3) * REST_UNIT_STRIDE + i]); \
-    i32x4 v_2 = (i32x4) vec_ld(16, &hor[(j + 3) * REST_UNIT_STRIDE + i]); \
-    i32x4 v_3 = (i32x4) vec_ld(32, &hor[(j + 3) * REST_UNIT_STRIDE + i]); \
-    i32x4 v_4 = (i32x4) vec_ld(48, &hor[(j + 3) * REST_UNIT_STRIDE + i]); \
-    i32x4 sum1 = -round_offset_vec; \
-    i32x4 sum2 = -round_offset_vec; \
-    i32x4 sum3 = -round_offset_vec; \
-    i32x4 sum4 = -round_offset_vec; \
+    i32x4 sum1 = round_vec; \
+    i32x4 sum2 = round_vec; \
+    i32x4 sum3 = round_vec; \
+    i32x4 sum4 = round_vec; \
     APPLY_FILTER_V(0, filterv0); \
     APPLY_FILTER_V(1, filterv1); \
     APPLY_FILTER_V(2, filterv2); \
@@ -164,31 +160,25 @@
     APPLY_FILTER_V(4, filterv4); \
     APPLY_FILTER_V(5, filterv5); \
     APPLY_FILTER_V(6, filterv6); \
-    sum1 = (v_1 << seven_vec) + sum1 + rounding_off_vec; \
-    sum2 = (v_2 << seven_vec) + sum2 + rounding_off_vec; \
-    sum3 = (v_3 << seven_vec) + sum3 + rounding_off_vec; \
-    sum4 = (v_4 << seven_vec) + sum4 + rounding_off_vec; \
     sum1 = sum1 >> round_bits_vec; \
     sum2 = sum2 >> round_bits_vec; \
     sum3 = sum3 >> round_bits_vec; \
     sum4 = sum4 >> round_bits_vec; \
-    i16x8 sum_short_packed_1 = (i16x8) vec_pack( sum1, sum2 ); \
-    i16x8 sum_short_packed_2 = (i16x8) vec_pack( sum3, sum4 ); \
+    i16x8 sum_short_packed_1 = (i16x8) vec_pack(sum1, sum2); \
+    i16x8 sum_short_packed_2 = (i16x8) vec_pack(sum3, sum4); \
     sum_short_packed_1 = iclip_u8_vec(sum_short_packed_1); \
     sum_short_packed_2 = iclip_u8_vec(sum_short_packed_2); \
-    sum_pixel = (u8x16) vec_pack(sum_short_packed_1, sum_short_packed_2 ); \
+    sum_pixel = (u8x16) vec_pack(sum_short_packed_1, sum_short_packed_2); \
 } while (0)
 
 static inline void wiener_filter_v_vsx(uint8_t *p,
                                        const ptrdiff_t p_stride,
                                        const int32_t *hor,
-                                       const int16_t filterv[7],
+                                       const int16_t filterv[8],
                                        const int w, const int h)
 {
     static const i32x4 round_bits_vec = vec_splats(11);
-    static const i32x4 rounding_off_vec = vec_splats(1 << 10);
-    static const i32x4 round_offset_vec = vec_splats(1 << 18);
-    static const i32x4 seven_vec = vec_splats(7);
+    static const i32x4 round_vec = vec_splats((1 << 10) - (1 << 18));
 
     i32x4 filterv0 =  vec_splats((int32_t) filterv[0]);
     i32x4 filterv1 =  vec_splats((int32_t) filterv[1]);
@@ -319,8 +309,7 @@
                               const uint8_t *lpf,
                               const ptrdiff_t lpf_stride,
                               const int w, const int h,
-                              const int16_t filterh[7],
-                              const int16_t filterv[7],
+                              const int16_t filter[2][8],
                               const enum LrEdgeFlags edges HIGHBD_DECL_SUFFIX)
 {
     // Wiener filtering is applied to a maximum stripe height of 64 + 3 pixels
@@ -329,8 +318,8 @@
     padding(tmp, p, p_stride, left, lpf, lpf_stride, w, h, edges);
     ALIGN_STK_16(int32_t, hor, 70 /*(64 + 3 + 3)*/ * REST_UNIT_STRIDE + 64,);
 
-    wiener_filter_h_vsx(hor, tmp, filterh, w, h);
-    wiener_filter_v_vsx(p, p_stride, hor, filterv, w, h);
+    wiener_filter_h_vsx(hor, tmp, filter[0], w, h);
+    wiener_filter_v_vsx(p, p_stride, hor, filter[1], w, h);
 
 }
 #endif
--- a/src/x86/looprestoration.asm
+++ b/src/x86/looprestoration.asm
@@ -40,7 +40,6 @@
 pw_256: times 2 dw 256
 pw_2048: times 2 dw 2048
 pw_16380: times 2 dw 16380
-pw_0_128: dw 0, 128
 pw_5_6: dw 5, 6
 pd_6: dd 6
 pd_1024: dd 1024
@@ -52,14 +51,14 @@
 SECTION .text
 
 INIT_YMM avx2
-cglobal wiener_filter_h, 5, 12, 16, dst, left, src, stride, fh, w, h, edge
+cglobal wiener_filter_h, 5, 12, 16, dst, left, src, stride, flt, w, h, edge
     mov        edged, edgem
-    vpbroadcastb m15, [fhq+0]
+    vpbroadcastb m15, [fltq+0]
     movifnidn     wd, wm
-    vpbroadcastb m14, [fhq+2]
+    vpbroadcastb m14, [fltq+2]
     mov           hd, hm
-    vpbroadcastb m13, [fhq+4]
-    vpbroadcastw m12, [fhq+6]
+    vpbroadcastb m13, [fltq+4]
+    vpbroadcastw m12, [fltq+6]
     vpbroadcastd m11, [pw_2048]
     vpbroadcastd m10, [pw_16380]
     lea          r11, [pb_right_ext_mask]
@@ -207,18 +206,16 @@
     jg .loop
     RET
 
-cglobal wiener_filter_v, 4, 10, 13, dst, stride, mid, w, h, fv, edge
-    movifnidn    fvq, fvmp
+cglobal wiener_filter_v, 4, 10, 13, dst, stride, mid, w, h, flt, edge
+    movifnidn   fltq, fltmp
     mov        edged, edgem
     movifnidn     hd, hm
-    vpbroadcastd m10, [fvq]
-    vpbroadcastd m11, [fvq+4]
-    vpbroadcastd  m0, [pw_0_128]
+    vpbroadcastd m10, [fltq+16]
+    vpbroadcastd m11, [fltq+20]
     vpbroadcastd m12, [pd_1024]
 
     DEFINE_ARGS dst, stride, mid, w, h, ylim, edge, y, mptr, dstptr
     rorx       ylimd, edged, 2
-    paddw        m11, m0
     and        ylimd, 2 ; have_bottom
     sub        ylimd, 3
 
--- a/src/x86/looprestoration_init_tmpl.c
+++ b/src/x86/looprestoration_init_tmpl.c
@@ -47,32 +47,34 @@
 \
 void dav1d_wiener_filter_h_##ext(int16_t *dst, const pixel (*left)[4], \
                                  const pixel *src, ptrdiff_t stride, \
-                                 const int16_t fh[7], const intptr_t w, \
+                                 const int16_t filter[2][8], const intptr_t w, \
                                  int h, enum LrEdgeFlags edges); \
 void dav1d_wiener_filter_v_##ext(pixel *dst, ptrdiff_t stride, \
                                  const int16_t *mid, int w, int h, \
-                                 const int16_t fv[7], enum LrEdgeFlags edges); \
+                                 const int16_t filter[2][8], \
+                                 enum LrEdgeFlags edges); \
 \
 static void wiener_filter_##ext(pixel *const dst, const ptrdiff_t dst_stride, \
                                 const pixel (*const left)[4], \
                                 const pixel *lpf, const ptrdiff_t lpf_stride, \
-                                const int w, const int h, const int16_t fh[7], \
-                                const int16_t fv[7], const enum LrEdgeFlags edges) \
+                                const int w, const int h, \
+                                const int16_t filter[2][8], \
+                                const enum LrEdgeFlags edges) \
 { \
     ALIGN_STK_32(int16_t, mid, 68 * 384,); \
 \
     /* horizontal filter */ \
     dav1d_wiener_filter_h_##ext(&mid[2 * 384], left, dst, dst_stride, \
-                               fh, w, h, edges); \
+                                filter, w, h, edges); \
     if (edges & LR_HAVE_TOP) \
         dav1d_wiener_filter_h_##ext(mid, NULL, lpf, lpf_stride, \
-                                   fh, w, 2, edges); \
+                                    filter, w, 2, edges); \
     if (edges & LR_HAVE_BOTTOM) \
         dav1d_wiener_filter_h_##ext(&mid[(2 + h) * 384], NULL, \
-                                   lpf + 6 * PXSTRIDE(lpf_stride), lpf_stride, \
-                                   fh, w, 2, edges); \
+                                    lpf + 6 * PXSTRIDE(lpf_stride), lpf_stride, \
+                                    filter, w, 2, edges); \
 \
-    dav1d_wiener_filter_v_##ext(dst, dst_stride, &mid[2*384], w, h, fv, edges); \
+    dav1d_wiener_filter_v_##ext(dst, dst_stride, &mid[2*384], w, h, filter, edges); \
 }
 
 #define SGR_FILTER(ext) \
--- a/src/x86/looprestoration_ssse3.asm
+++ b/src/x86/looprestoration_ssse3.asm
@@ -52,7 +52,6 @@
 pw_2048: times 8 dw 2048
 pw_16380: times 8 dw 16380
 pw_5_6: times 4 dw 5, 6
-pw_0_128: times 4 dw 0, 128
 pd_1024: times 4 dd 1024
 %if ARCH_X86_32
 pd_256: times 4 dd 256
@@ -129,12 +128,12 @@
 
 %macro WIENER_H 0
 %if ARCH_X86_64
-cglobal wiener_filter_h, 5, 15, 16, dst, left, src, stride, fh, w, h, edge
+cglobal wiener_filter_h, 5, 15, 16, dst, left, src, stride, flt, w, h, edge
     mov        edged, edgem
     movifnidn     wd, wm
     mov           hd, hm
 %else
-cglobal wiener_filter_h, 5, 7, 8, -84, dst, left, src, stride, fh, w, h, edge
+cglobal wiener_filter_h, 5, 7, 8, -84, dst, left, src, stride, flt, w, h, edge
     mov           r5, edgem
     mov     [esp+12], r5
     mov           wd, wm
@@ -146,7 +145,7 @@
  %define m12 m3
 %endif
 
-    movq         m15, [fhq]
+    movq         m15, [fltq]
 %if cpuflag(ssse3)
     pshufb       m12, m15, [PIC_sym(pb_6_7)]
     pshufb       m13, m15, [PIC_sym(pb_4)]
@@ -438,14 +437,13 @@
 
 %macro WIENER_V 0
 %if ARCH_X86_64
-cglobal wiener_filter_v, 4, 10, 16, dst, stride, mid, w, h, fv, edge
+cglobal wiener_filter_v, 4, 10, 16, dst, stride, mid, w, h, flt, edge
     mov        edged, edgem
-    movifnidn    fvq, fvmp
+    movifnidn   fltq, fltmp
     movifnidn     hd, hm
-    movq         m15, [fvq]
+    movq         m15, [fltq+16]
     pshufd       m14, m15, q1111
     pshufd       m15, m15, q0000
-    paddw        m14, [pw_0_128]
     mova         m12, [pd_1024]
 
     DEFINE_ARGS dst, stride, mid, w, h, y, edge, ylim, mptr, dstptr
@@ -455,7 +453,7 @@
     shr        ylimd, 2
     sub        ylimd, 3
 %else
-cglobal wiener_filter_v, 5, 7, 8, -96, dst, stride, mid, w, h, fv, edge
+cglobal wiener_filter_v, 5, 7, 8, -96, dst, stride, mid, w, h, flt, edge
  %define ylimd [esp+12]
 
     mov          r5d, edgem
@@ -463,15 +461,14 @@
     shr          r5d, 2
     sub          r5d, 3
     mov        ylimd, r5d
-    mov          fvq, fvmp
+    mov         fltq, fltmp
     mov        edged, edgem
 
     SETUP_PIC edged
 
-    movq          m0, [fvq]
+    movq          m0, [fltq+16]
     pshufd        m1, m0, q1111
     pshufd        m0, m0, q0000
-    paddw         m1, [PIC_sym(pw_0_128)]
     mova  [esp+0x50], m0
     mova  [esp+0x40], m1
 
--- a/tests/checkasm/looprestoration.c
+++ b/tests/checkasm/looprestoration.c
@@ -47,38 +47,32 @@
     ALIGN_STK_64(pixel, c_dst, 448 * 64,);
     ALIGN_STK_64(pixel, a_dst, 448 * 64,);
     ALIGN_STK_64(pixel, h_edge, 448 * 8,);
+    ALIGN_STK_16(int16_t, filter, 2, [8]);
     pixel left[64][4];
 
     declare_func(void, pixel *dst, ptrdiff_t dst_stride,
                  const pixel (*const left)[4],
                  const pixel *lpf, ptrdiff_t lpf_stride,
-                 int w, int h, const int16_t filterh[7],
-                 const int16_t filterv[7], enum LrEdgeFlags edges
-                 HIGHBD_DECL_SUFFIX);
+                 int w, int h, const int16_t filter[2][8],
+                 enum LrEdgeFlags edges HIGHBD_DECL_SUFFIX);
 
     for (int pl = 0; pl < 2; pl++) {
         if (check_func(c->wiener, "wiener_%s_%dbpc",
                        pl ? "chroma" : "luma", bpc))
         {
-            int16_t filter[2][3], filter_v[7], filter_h[7];
+            filter[0][0] = filter[0][6] = pl ? 0 : (rnd() & 15) - 5;
+            filter[0][1] = filter[0][5] = (rnd() & 31) - 23;
+            filter[0][2] = filter[0][4] = (rnd() & 63) - 17;
+            filter[0][3] = -(filter[0][0] + filter[0][1] + filter[0][2]) * 2;
+#if BITDEPTH != 8
+            filter[0][3] += 128;
+#endif
 
-            filter[0][0] = pl ? 0 : (rnd() & 15) - 5;
-            filter[0][1] = (rnd() & 31) - 23;
-            filter[0][2] = (rnd() & 63) - 17;
-            filter[1][0] = pl ? 0 : (rnd() & 15) - 5;
-            filter[1][1] = (rnd() & 31) - 23;
-            filter[1][2] = (rnd() & 63) - 17;
+            filter[1][0] = filter[1][6] = pl ? 0 : (rnd() & 15) - 5;
+            filter[1][1] = filter[1][5] = (rnd() & 31) - 23;
+            filter[1][2] = filter[1][4] = (rnd() & 63) - 17;
+            filter[1][3] = 128 - (filter[1][0] + filter[1][1] + filter[1][2]) * 2;
 
-            filter_h[0] = filter_h[6] = filter[0][0];
-            filter_h[1] = filter_h[5] = filter[0][1];
-            filter_h[2] = filter_h[4] = filter[0][2];
-            filter_h[3] = -((filter_h[0] + filter_h[1] + filter_h[2]) * 2);
-
-            filter_v[0] = filter_v[6] = filter[1][0];
-            filter_v[1] = filter_v[5] = filter[1][1];
-            filter_v[2] = filter_v[4] = filter[1][2];
-            filter_v[3] = -((filter_v[0] + filter_v[1] + filter_v[2]) * 2);
-
             const int base_w = 1 + (rnd() % 384);
             const int base_h = 1 + (rnd() & 63);
             const int bitdepth_max = (1 << bpc) - 1;
@@ -95,10 +89,10 @@
 
                 call_ref(c_dst + 32, 448 * sizeof(pixel), left,
                          h_edge + 32, 448 * sizeof(pixel),
-                         w, h, filter_h, filter_v, edges HIGHBD_TAIL_SUFFIX);
+                         w, h, filter, edges HIGHBD_TAIL_SUFFIX);
                 call_new(a_dst + 32, 448 * sizeof(pixel), left,
                          h_edge + 32, 448 * sizeof(pixel),
-                         w, h, filter_h, filter_v, edges HIGHBD_TAIL_SUFFIX);
+                         w, h, filter, edges HIGHBD_TAIL_SUFFIX);
                 checkasm_check_pixel(c_dst + 32, 448 * sizeof(pixel),
                                      a_dst + 32, 448 * sizeof(pixel),
                                      w, h, "dst");
@@ -105,7 +99,7 @@
             }
             bench_new(a_dst + 32, 448 * sizeof(pixel), left,
                       h_edge + 32, 448 * sizeof(pixel),
-                      256, 64, filter_h, filter_v, 0xf HIGHBD_TAIL_SUFFIX);
+                      256, 64, filter, 0xf HIGHBD_TAIL_SUFFIX);
         }
     }
 }
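
As a closing illustration of why dropping the separate +128 term is safe, the
reference C vertical pass before and after can be modeled by the following
standalone functions (hypothetical names; mid is one column of the seven
intermediate rows, and round_offset is the diff's
1 << (bitdepth + round_bits_v - 1)). The two variants return identical sums
whenever fv_new[3] == fv_old[3] + 128:

/* Old scheme: the taps exclude the +128, which is applied as an explicit
 * "center sample << 7" term. */
static int wiener_v_old(const int16_t mid[7], const int16_t fv[7],
                        int round_offset)
{
    int sum = (mid[3] << 7) - round_offset;
    for (int k = 0; k < 7; k++)
        sum += mid[k] * fv[k];
    return sum;
}

/* New scheme: the 128 is pre-added to fv[3] at setup time, so the seven
 * multiply-accumulates are the whole sum. */
static int wiener_v_new(const int16_t mid[7], const int16_t fv[8],
                        int round_offset)
{
    int sum = -round_offset;
    for (int k = 0; k < 7; k++)
        sum += mid[k] * fv[k];
    return sum;
}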