shithub: dav1d

Download patch

ref: d322d45170b16b2e9c89ff6063c3c0ea1cb7d9f9
parent: 4f14573cffd640ea54f11dfae8f77a905a48e985
author: Martin Storsjö <martin@martin.st>
date: Thu Sep 26 05:46:30 EDT 2019

arm64: ipred: NEON implementation of the filter function

Use a different layout of the filter_intra_taps depending on
architecture; the current one is optimized for the x86 SIMD
implementation.

Relative speedups over the C code:
                             Cortex A53    A72    A73
intra_pred_filter_w4_8bpc_neon:    6.38   2.81   4.43
intra_pred_filter_w8_8bpc_neon:    9.30   3.62   5.71
intra_pred_filter_w16_8bpc_neon:   9.85   3.98   6.42
intra_pred_filter_w32_8bpc_neon:  10.77   4.08   7.09

--- a/src/arm/64/ipred.S
+++ b/src/arm/64/ipred.S
@@ -1327,6 +1327,166 @@
         .hword L(ipred_smooth_h_tbl) -  40b
 endfunc
 
+// void ipred_filter_neon(pixel *dst, const ptrdiff_t stride,     // 7-tap filter intra, in 4x2 blocks
+//                        const pixel *const topleft,
+//                        const int width, const int height, const int filt_idx,
+//                        const int max_width, const int max_height);  // max_* unused (x6/x7 reused as scratch)
+function ipred_filter_neon, export=1
+        and             w5,  w5,  #511            // w5 = filt_idx & 511
+        movrel          x6,  X(filter_intra_taps) // x6 = tap table
+        lsl             w5,  w5,  #6              // 64 bytes of taps per filter
+        add             x6,  x6,  w5, uxtw        // x6 = taps of the selected filter
+        ld1             {v16.8b, v17.8b, v18.8b, v19.8b}, [x6], #32 // f0-f3, 8 output positions each
+        clz             w9,  w3                   // width 32/16/8/4 -> 26/27/28/29
+        adr             x5,  L(ipred_filter_tbl)  // x5 = jump table
+        ld1             {v20.8b, v21.8b, v22.8b}, [x6] // f4-f6
+        sub             w9,  w9,  #26             // table index 0-3
+        ldrh            w9,  [x5, w9, uxtw #1]    // distance back to the entry point
+        sxtl            v16.8h,  v16.8b           // widen signed 8-bit taps to 16 bit
+        sxtl            v17.8h,  v17.8b
+        sub             x5,  x5,  w9, uxtw        // x5 = entry point for this width
+        sxtl            v18.8h,  v18.8b
+        sxtl            v19.8h,  v19.8b
+        add             x6,  x0,  x1              // x6 = dst of the second row
+        lsl             x1,  x1,  #1              // two rows are written per iteration
+        sxtl            v20.8h,  v20.8b
+        sxtl            v21.8h,  v21.8b
+        sxtl            v22.8h,  v22.8b
+        br              x5                        // dispatch on width
+40:                                               // width 4
+        ldur            s0,  [x2, #1]             // top (0-3)
+        sub             x2,  x2,  #2              // x2 = topleft - 2, i.e. &left[1]
+        mov             x7,  #-2                  // left column is read at decreasing addresses
+        uxtl            v0.8h,   v0.8b            // top (0-3)
+4:                                                // 2 rows x 4 pixels per iteration
+        ld1             {v1.s}[0], [x2], x7       // left (0-1) + topleft (2)
+        mul             v2.8h,   v17.8h,  v0.h[0] // p1(top[0]) * filter(1)
+        mla             v2.8h,   v18.8h,  v0.h[1] // p2(top[1]) * filter(2)
+        mla             v2.8h,   v19.8h,  v0.h[2] // p3(top[2]) * filter(3)
+        uxtl            v1.8h,   v1.8b            // left (0-1) + topleft (2)
+        mla             v2.8h,   v20.8h,  v0.h[3] // p4(top[3]) * filter(4)
+        mla             v2.8h,   v16.8h,  v1.h[2] // p0(topleft) * filter(0)
+        mla             v2.8h,   v21.8h,  v1.h[1] // p5(left[0]) * filter(5)
+        mla             v2.8h,   v22.8h,  v1.h[0] // p6(left[1]) * filter(6)
+        sqrshrun        v2.8b,   v2.8h,   #4      // (acc + 8) >> 4, saturated to 0..255
+        subs            w4,  w4,  #2              // height -= 2
+        st1             {v2.s}[0], [x0], x1       // row 0
+        uxtl            v0.8h,   v2.8b            // widen outputs; row 1 becomes the next top
+        st1             {v2.s}[1], [x6], x1       // row 1
+        ext             v0.16b,  v0.16b,  v0.16b, #8 // move top from [4-7] to [0-3]
+        b.gt            4b
+        ret
+80:                                               // width 8
+        ldur            d0,  [x2, #1]             // top (0-7)
+        sub             x2,  x2,  #2              // x2 = topleft - 2, i.e. &left[1]
+        mov             x7,  #-2                  // left column is read at decreasing addresses
+        uxtl            v0.8h,   v0.8b            // top (0-7)
+8:                                                // 2 rows x 8 pixels (two 4x2 blocks)
+        ld1             {v1.s}[0], [x2], x7       // left (0-1) + topleft (2)
+        mul             v2.8h,   v17.8h,  v0.h[0] // p1(top[0]) * filter(1)
+        mla             v2.8h,   v18.8h,  v0.h[1] // p2(top[1]) * filter(2)
+        mla             v2.8h,   v19.8h,  v0.h[2] // p3(top[2]) * filter(3)
+        uxtl            v1.8h,   v1.8b            // left (0-1) + topleft (2)
+        mla             v2.8h,   v20.8h,  v0.h[3] // p4(top[3]) * filter(4)
+        mla             v2.8h,   v16.8h,  v1.h[2] // p0(topleft) * filter(0)
+        mla             v2.8h,   v21.8h,  v1.h[1] // p5(left[0]) * filter(5)
+        mla             v2.8h,   v22.8h,  v1.h[0] // p6(left[1]) * filter(6)
+        mul             v3.8h,   v17.8h,  v0.h[4] // block 2: p1(top[4]) * filter(1)
+        mla             v3.8h,   v18.8h,  v0.h[5] // p2(top[5]) * filter(2)
+        mla             v3.8h,   v19.8h,  v0.h[6] // p3(top[6]) * filter(3)
+        sqrshrun        v2.8b,   v2.8h,   #4      // block 1: (acc + 8) >> 4, saturated
+        uxtl            v1.8h,   v2.8b            // first block, in 16 bit
+        mla             v3.8h,   v20.8h,  v0.h[7] // p4(top[7]) * filter(4)
+        mla             v3.8h,   v16.8h,  v0.h[3] // p0(topleft = top[3]) * filter(0)
+        mla             v3.8h,   v21.8h,  v1.h[3] // p5(block 1 row 0 col 3) * filter(5)
+        mla             v3.8h,   v22.8h,  v1.h[7] // p6(block 1 row 1 col 3) * filter(6)
+        sqrshrun        v3.8b,   v3.8h,   #4      // block 2: (acc + 8) >> 4, saturated
+        subs            w4,  w4,  #2              // height -= 2
+        st2             {v2.s, v3.s}[0], [x0], x1 // row 0: block 1, then block 2
+        zip2            v0.2s,   v2.2s,   v3.2s   // next top: row 1 of both blocks
+        st2             {v2.s, v3.s}[1], [x6], x1 // row 1
+        uxtl            v0.8h,   v0.8b            // next top, in 16 bit
+        b.gt            8b
+        ret
+160:                                              // width 16
+320:                                              // width 32
+        add             x8,  x2,  #1              // x8 = top row (topleft + 1)
+        sub             x2,  x2,  #2              // x2 = topleft - 2, i.e. &left[1]
+        mov             x7,  #-2                  // left column is read at decreasing addresses
+        sub             x1,  x1,  w3, uxtw        // 2*stride - width: advance after the inner loop
+        mov             w9,  w3                   // w9 = width, restored for every row pair
+
+1:                                                // outer loop: one row pair
+        ld1             {v0.s}[0], [x2], x7       // left (0-1) + topleft (2)
+        uxtl            v0.8h,   v0.8b            // left (0-1) + topleft (2)
+2:                                                // inner loop: 16 pixels (four 4x2 blocks)
+        ld1             {v2.16b}, [x8],   #16     // top(0-15)
+        mul             v3.8h,   v16.8h,  v0.h[2] // p0(topleft) * filter(0)
+        mla             v3.8h,   v21.8h,  v0.h[1] // p5(left[0]) * filter(5)
+        uxtl            v1.8h,   v2.8b            // top(0-7)
+        uxtl2           v2.8h,   v2.16b           // top(8-15)
+        mla             v3.8h,   v22.8h,  v0.h[0] // p6(left[1]) * filter(6)
+        mla             v3.8h,   v17.8h,  v1.h[0] // p1(top[0]) * filter(1)
+        mla             v3.8h,   v18.8h,  v1.h[1] // p2(top[1]) * filter(2)
+        mla             v3.8h,   v19.8h,  v1.h[2] // p3(top[2]) * filter(3)
+        mla             v3.8h,   v20.8h,  v1.h[3] // p4(top[3]) * filter(4)
+
+        mul             v4.8h,   v17.8h,  v1.h[4] // block 2: p1(top[4]) * filter(1)
+        mla             v4.8h,   v18.8h,  v1.h[5] // p2(top[5]) * filter(2)
+        mla             v4.8h,   v19.8h,  v1.h[6] // p3(top[6]) * filter(3)
+        sqrshrun        v3.8b,   v3.8h,   #4      // block 1: (acc + 8) >> 4, saturated
+        uxtl            v0.8h,   v3.8b            // first block, in 16 bit
+        mla             v4.8h,   v20.8h,  v1.h[7] // p4(top[7]) * filter(4)
+        mla             v4.8h,   v16.8h,  v1.h[3] // p0(topleft = top[3]) * filter(0)
+        mla             v4.8h,   v21.8h,  v0.h[3] // p5(block 1 row 0 col 3) * filter(5)
+        mla             v4.8h,   v22.8h,  v0.h[7] // p6(block 1 row 1 col 3) * filter(6)
+
+        mul             v5.8h,   v17.8h,  v2.h[0] // block 3: p1(top[8]) * filter(1)
+        mla             v5.8h,   v18.8h,  v2.h[1] // p2(top[9]) * filter(2)
+        mla             v5.8h,   v19.8h,  v2.h[2] // p3(top[10]) * filter(3)
+        sqrshrun        v4.8b,   v4.8h,   #4      // block 2: (acc + 8) >> 4, saturated
+        uxtl            v0.8h,   v4.8b            // second block, in 16 bit
+        mla             v5.8h,   v20.8h,  v2.h[3] // p4(top[11]) * filter(4)
+        mla             v5.8h,   v16.8h,  v1.h[7] // p0(topleft = top[7]) * filter(0)
+        mla             v5.8h,   v21.8h,  v0.h[3] // p5(block 2 row 0 col 3) * filter(5)
+        mla             v5.8h,   v22.8h,  v0.h[7] // p6(block 2 row 1 col 3) * filter(6)
+
+        mul             v6.8h,   v17.8h,  v2.h[4] // block 4: p1(top[12]) * filter(1)
+        mla             v6.8h,   v18.8h,  v2.h[5] // p2(top[13]) * filter(2)
+        mla             v6.8h,   v19.8h,  v2.h[6] // p3(top[14]) * filter(3)
+        sqrshrun        v5.8b,   v5.8h,   #4      // block 3: (acc + 8) >> 4, saturated
+        uxtl            v0.8h,   v5.8b            // third block, in 16 bit
+        mla             v6.8h,   v20.8h,  v2.h[7] // p4(top[15]) * filter(4)
+        mla             v6.8h,   v16.8h,  v2.h[3] // p0(topleft = top[11]) * filter(0)
+        mla             v6.8h,   v21.8h,  v0.h[3] // p5(block 3 row 0 col 3) * filter(5)
+        mla             v6.8h,   v22.8h,  v0.h[7] // p6(block 3 row 1 col 3) * filter(6)
+
+        subs            w3,  w3,  #16             // width -= 16
+        sqrshrun        v6.8b,   v6.8h,   #4      // block 4: (acc + 8) >> 4, saturated
+
+        ins             v0.h[2], v2.h[7]          // next topleft = top[15]
+        st4             {v3.s, v4.s, v5.s, v6.s}[0], [x0], #16 // row 0 of blocks 1-4
+        ins             v0.b[0], v6.b[7]          // next left[1] = block 4 row 1 col 3 (h[0] high byte is 0 from uxtl)
+        st4             {v3.s, v4.s, v5.s, v6.s}[1], [x6], #16 // row 1 of blocks 1-4
+        ins             v0.b[2], v6.b[3]          // next left[0] = block 4 row 0 col 3 (h[1] high byte is 0 from uxtl)
+        b.gt            2b
+        subs            w4,  w4,  #2              // height -= 2
+        b.le            9f
+        sub             x8,  x6,  w9, uxtw        // next top = start of the row 1 just written
+        add             x0,  x0,  x1              // advance to the next row pair
+        add             x6,  x6,  x1
+        mov             w3,  w9                   // restore width
+        b               1b
+9:
+        ret
+
+L(ipred_filter_tbl):                              // each entry: table address minus entry point
+        .hword L(ipred_filter_tbl) - 320b
+        .hword L(ipred_filter_tbl) - 160b
+        .hword L(ipred_filter_tbl) -  80b
+        .hword L(ipred_filter_tbl) -  40b
+endfunc
+
 // void pal_pred_neon(pixel *dst, const ptrdiff_t stride,
 //                    const uint16_t *const pal, const uint8_t *idx,
 //                    const int w, const int h);
--- a/src/arm/ipred_init_tmpl.c
+++ b/src/arm/ipred_init_tmpl.c
@@ -37,6 +37,7 @@
 decl_angular_ipred_fn(dav1d_ipred_smooth_neon);
 decl_angular_ipred_fn(dav1d_ipred_smooth_v_neon);
 decl_angular_ipred_fn(dav1d_ipred_smooth_h_neon);
+decl_angular_ipred_fn(dav1d_ipred_filter_neon); // NEON filter-intra predictor (src/arm/64/ipred.S)
 
 decl_pal_pred_fn(dav1d_pal_pred_neon);
 
@@ -56,6 +57,7 @@
     c->intra_pred[SMOOTH_PRED]   = dav1d_ipred_smooth_neon;
     c->intra_pred[SMOOTH_V_PRED] = dav1d_ipred_smooth_v_neon;
     c->intra_pred[SMOOTH_H_PRED] = dav1d_ipred_smooth_h_neon;
+    c->intra_pred[FILTER_PRED]   = dav1d_ipred_filter_neon; // reads the non-x86 dav1d_filter_intra_taps layout
 
     c->pal_pred                  = dav1d_pal_pred_neon;
 #endif
--- a/src/ipred_tmpl.c
+++ b/src/ipred_tmpl.c
@@ -597,6 +597,22 @@
     }
 }
 
+#if ARCH_X86 /* x86 tap layout; must stay in sync with dav1d_filter_intra_taps in tables.c */
+#define FILTER(flt_ptr, p0, p1, p2, p3, p4, p5, p6)   \
+    ((flt_ptr)[ 0] * (p0) + (flt_ptr)[ 1] * (p1) +   \
+     (flt_ptr)[16] * (p2) + (flt_ptr)[17] * (p3) +   \
+     (flt_ptr)[32] * (p4) + (flt_ptr)[33] * (p5) +   \
+     (flt_ptr)[48] * (p6))
+#define FLT_INCR 2 /* taps of adjacent output positions are 2 entries apart */
+#else /* each tap k is one row of 8 entries (one per output position) at [8*k] */
+#define FILTER(flt_ptr, p0, p1, p2, p3, p4, p5, p6)   \
+    ((flt_ptr)[ 0] * (p0) + (flt_ptr)[ 8] * (p1) +   \
+     (flt_ptr)[16] * (p2) + (flt_ptr)[24] * (p3) +   \
+     (flt_ptr)[32] * (p4) + (flt_ptr)[40] * (p5) +   \
+     (flt_ptr)[48] * (p6))
+#define FLT_INCR 1 /* taps of adjacent output positions are adjacent */
+#endif /* ARCH_X86 */
+
 /* Up to 32x32 only */
 static void ipred_filter_c(pixel *dst, const ptrdiff_t stride,
                            const pixel *const topleft_in,
@@ -625,11 +641,8 @@
             const int8_t *flt_ptr = filter;
 
             for (int yy = 0; yy < 2; yy++) {
-                for (int xx = 0; xx < 4; xx++, flt_ptr += 2) {
-                    int acc = flt_ptr[ 0] * p0 + flt_ptr[ 1] * p1 +
-                              flt_ptr[16] * p2 + flt_ptr[17] * p3 +
-                              flt_ptr[32] * p4 + flt_ptr[33] * p5 +
-                              flt_ptr[48] * p6;
+                for (int xx = 0; xx < 4; xx++, flt_ptr += FLT_INCR) { // flt_ptr steps to the taps of the next output position
+                    int acc = FILTER(flt_ptr, p0, p1, p2, p3, p4, p5, p6); // 7-tap sum; tap layout selected by ARCH_X86
                     ptr[xx] = iclip_pixel((acc + 8) >> 4);
                 }
                 ptr += PXSTRIDE(stride);
--- a/src/tables.c
+++ b/src/tables.c
@@ -716,52 +716,65 @@
        3        // 87, 177, 267
 };
 
+#if ARCH_X86 /* x86: taps interleaved in pairs at [2*idx]; keep in sync with FILTER()/FLT_INCR in ipred_tmpl.c */
+#define F(idx, f0, f1, f2, f3, f4, f5, f6) /* idx = 4*row + col */ \
+    [2*(idx)+0]  = (f0), [2*(idx)+1]  = (f1),   \
+    [2*(idx)+16] = (f2), [2*(idx)+17] = (f3),   \
+    [2*(idx)+32] = (f4), [2*(idx)+33] = (f5),   \
+    [2*(idx)+48] = (f6)
+#else /* tap k of output position idx lives at [8*k + idx] (one 8-byte row per tap) */
+#define F(idx, f0, f1, f2, f3, f4, f5, f6) /* idx = 4*row + col */ \
+    [1*(idx)+0]  = (f0), [1*(idx)+8]  = (f1),   \
+    [1*(idx)+16] = (f2), [1*(idx)+24] = (f3),   \
+    [1*(idx)+32] = (f4), [1*(idx)+40] = (f5),   \
+    [1*(idx)+48] = (f6)
+#endif /* ARCH_X86 */
 const int8_t ALIGN(dav1d_filter_intra_taps[5][64], 16) = {
     {
-         -6,  10,  -5,   2,  -3,   1,  -3,   1,
-         -4,   6,  -3,   2,  -3,   2,  -3,   1,
-          0,   0,  10,   0,   1,  10,   1,   2,
-          0,   0,   6,   0,   2,   6,   2,   2,
-          0,  12,   0,   9,   0,   7,  10,   5,
-          0,   2,   0,   2,   0,   2,   6,   3,
-          0,   0,   0,   0,   0,   0,   0,   0,
-         12,   0,   9,   0,   7,   0,   5,   0
+        F( 0,  -6, 10,  0,  0,  0, 12,  0 ), // filt_idx 0; idx 0-3 = row 0, cols 0-3; args = taps for p0..p6
+        F( 1,  -5,  2, 10,  0,  0,  9,  0 ),
+        F( 2,  -3,  1,  1, 10,  0,  7,  0 ),
+        F( 3,  -3,  1,  1,  2, 10,  5,  0 ),
+        F( 4,  -4,  6,  0,  0,  0,  2, 12 ), // idx 4-7 = row 1, cols 0-3
+        F( 5,  -3,  2,  6,  0,  0,  2,  9 ),
+        F( 6,  -3,  2,  2,  6,  0,  2,  7 ),
+        F( 7,  -3,  1,  2,  2,  6,  3,  5 ),
     }, {
-        -10,  16,  -6,   0,  -4,   0,  -2,   0,
-        -10,  16,  -6,   0,  -4,   0,  -2,   0,
-          0,   0,  16,   0,   0,  16,   0,   0,
-          0,   0,  16,   0,   0,  16,   0,   0,
-          0,  10,   0,   6,   0,   4,  16,   2,
-          0,   0,   0,   0,   0,   0,  16,   0,
-          0,   0,   0,   0,   0,   0,   0,   0,
-         10,   0,   6,   0,   4,   0,   2,   0
+        F( 0, -10, 16,  0,  0,  0, 10,  0 ), // filt_idx 1; idx 0-3 = row 0
+        F( 1,  -6,  0, 16,  0,  0,  6,  0 ),
+        F( 2,  -4,  0,  0, 16,  0,  4,  0 ),
+        F( 3,  -2,  0,  0,  0, 16,  2,  0 ),
+        F( 4, -10, 16,  0,  0,  0,  0, 10 ), // idx 4-7 = row 1
+        F( 5,  -6,  0, 16,  0,  0,  0,  6 ),
+        F( 6,  -4,  0,  0, 16,  0,  0,  4 ),
+        F( 7,  -2,  0,  0,  0, 16,  0,  2 ),
     }, {
-         -8,   8,  -8,   0,  -8,   0,  -8,   0,
-         -4,   4,  -4,   0,  -4,   0,  -4,   0,
-          0,   0,   8,   0,   0,   8,   0,   0,
-          0,   0,   4,   0,   0,   4,   0,   0,
-          0,  16,   0,  16,   0,  16,   8,  16,
-          0,   0,   0,   0,   0,   0,   4,   0,
-          0,   0,   0,   0,   0,   0,   0,   0,
-         16,   0,  16,   0,  16,   0,  16,   0
+        F( 0,  -8,  8,  0,  0,  0, 16,  0 ), // filt_idx 2; idx 0-3 = row 0
+        F( 1,  -8,  0,  8,  0,  0, 16,  0 ),
+        F( 2,  -8,  0,  0,  8,  0, 16,  0 ),
+        F( 3,  -8,  0,  0,  0,  8, 16,  0 ),
+        F( 4,  -4,  4,  0,  0,  0,  0, 16 ), // idx 4-7 = row 1
+        F( 5,  -4,  0,  4,  0,  0,  0, 16 ),
+        F( 6,  -4,  0,  0,  4,  0,  0, 16 ),
+        F( 7,  -4,  0,  0,  0,  4,  0, 16 ),
     }, {
-         -2,   8,  -1,   3,  -1,   2,   0,   1,
-         -1,   4,  -1,   3,  -1,   2,  -1,   2,
-          0,   0,   8,   0,   3,   8,   2,   3,
-          0,   0,   4,   0,   3,   4,   2,   3,
-          0,  10,   0,   6,   0,   4,   8,   2,
-          0,   3,   0,   4,   0,   4,   4,   3,
-          0,   0,   0,   0,   0,   0,   0,   0,
-         10,   0,   6,   0,   4,   0,   3,   0
+        F( 0,  -2,  8,  0,  0,  0, 10,  0 ), // filt_idx 3; idx 0-3 = row 0
+        F( 1,  -1,  3,  8,  0,  0,  6,  0 ),
+        F( 2,  -1,  2,  3,  8,  0,  4,  0 ),
+        F( 3,   0,  1,  2,  3,  8,  2,  0 ),
+        F( 4,  -1,  4,  0,  0,  0,  3, 10 ), // idx 4-7 = row 1
+        F( 5,  -1,  3,  4,  0,  0,  4,  6 ),
+        F( 6,  -1,  2,  3,  4,  0,  4,  4 ),
+        F( 7,  -1,  2,  2,  3,  4,  3,  3 ),
     }, {
-        -12,  14, -10,   0,  -9,   0,  -8,   0,
-        -10,  12,  -9,   1,  -8,   0,  -7,   0,
-          0,   0,  14,   0,   0,  14,   0,   0,
-          0,   0,  12,   0,   0,  12,   0,   1,
-          0,  14,   0,  12,   0,  11,  14,  10,
-          0,   0,   0,   0,   0,   1,  12,   1,
-          0,   0,   0,   0,   0,   0,   0,   0,
-         14,   0,  12,   0,  11,   0,   9,   0
+        F( 0, -12, 14,  0,  0,  0, 14,  0 ), // filt_idx 4; idx 0-3 = row 0
+        F( 1, -10,  0, 14,  0,  0, 12,  0 ),
+        F( 2,  -9,  0,  0, 14,  0, 11,  0 ),
+        F( 3,  -8,  0,  0,  0, 14, 10,  0 ),
+        F( 4, -10, 12,  0,  0,  0,  0, 14 ), // idx 4-7 = row 1
+        F( 5,  -9,  1, 12,  0,  0,  0, 12 ),
+        F( 6,  -8,  0,  0, 12,  0,  1, 11 ),
+        F( 7,  -7,  0,  0,  1, 12,  1,  9 ),
     }
 };