shithub: dav1d

Download patch

ref: c76933864d5b66f628e8358d054e375fcad26f46
parent: d322d45170b16b2e9c89ff6063c3c0ea1cb7d9f9
author: Martin Storsjö <martin@martin.st>
date: Mon Sep 30 10:32:07 EDT 2019

arm64: ipred: NEON implementation of the cfl_pred functions

Relative speedup over the C code:
                             Cortex A53    A72    A73
cfl_pred_cfl_128_w4_8bpc_neon:    10.81   7.90   9.80
cfl_pred_cfl_128_w8_8bpc_neon:    18.38  11.15  13.24
cfl_pred_cfl_128_w16_8bpc_neon:   16.52  10.83  16.00
cfl_pred_cfl_128_w32_8bpc_neon:    3.27   3.60   3.70
cfl_pred_cfl_left_w4_8bpc_neon:    9.82   7.38   8.76
cfl_pred_cfl_left_w8_8bpc_neon:   17.22  10.63  11.97
cfl_pred_cfl_left_w16_8bpc_neon:  16.03  10.49  15.66
cfl_pred_cfl_left_w32_8bpc_neon:   3.28   3.61   3.72
cfl_pred_cfl_top_w4_8bpc_neon:     9.74   7.39   9.29
cfl_pred_cfl_top_w8_8bpc_neon:    17.48  10.89  12.58
cfl_pred_cfl_top_w16_8bpc_neon:   16.01  10.62  15.31
cfl_pred_cfl_top_w32_8bpc_neon:    3.25   3.62   3.75
cfl_pred_cfl_w4_8bpc_neon:         8.39   6.34   8.04
cfl_pred_cfl_w8_8bpc_neon:        15.99  10.12  12.42
cfl_pred_cfl_w16_8bpc_neon:       15.25  10.40  15.12
cfl_pred_cfl_w32_8bpc_neon:        3.23   3.58   3.71

The C code gets autovectorized for w >= 32, which is why the
relative speedup for the w32 cases is so much lower than for the
narrower widths (the absolute performance of the NEON functions
is as expected).

--- a/src/arm/64/ipred.S
+++ b/src/arm/64/ipred.S
@@ -1577,3 +1577,371 @@
         .hword L(pal_pred_tbl) -  8b
         .hword L(pal_pred_tbl) -  4b
 endfunc
+
+// void ipred_cfl_128_neon(pixel *dst, const ptrdiff_t stride,   // x0, x1
+//                         const pixel *const topleft,            // x2 (not read by this function)
+//                         const int width, const int height,     // w3, w4
+//                         const int16_t *ac, const int alpha);   // x5, w6
+function ipred_cfl_128_neon, export=1
+        clz             w9,  w3                  // clz(width): 26 for width 32 ... 29 for width 4
+        adr             x7,  L(ipred_cfl_128_tbl)
+        sub             w9,  w9,  #26            // jump table index 0..3
+        ldrh            w9,  [x7, w9, uxtw #1]   // 16-bit entry = table address - target address
+        movi            v0.8h,   #128 // dc = 128 (fixed; no neighbouring pixels are read)
+        dup             v1.8h,   w6   // alpha
+        sub             x7,  x7,  w9, uxtw       // x7 = splat routine for this width
+        add             x6,  x0,  x1             // x6 = dst + stride (odd rows); w6/alpha is already in v1
+        lsl             x1,  x1,  #1             // stride *= 2: x0 and x6 each step two rows at a time
+        br              x7
+L(ipred_cfl_splat_w4):                           // shared tail, expects v0 = dc, v1 = alpha; 4 rows of 4 pixels per pass
+        ld1             {v2.8h, v3.8h}, [x5], #32 // 16 consecutive ac values = 4 rows of 4
+        mul             v2.8h,   v2.8h,   v1.8h  // diff = ac * alpha
+        mul             v3.8h,   v3.8h,   v1.8h
+        sshr            v4.8h,   v2.8h,   #15    // sign = diff >> 15
+        sshr            v5.8h,   v3.8h,   #15
+        add             v2.8h,   v2.8h,   v4.8h  // diff + sign
+        add             v3.8h,   v3.8h,   v5.8h
+        srshr           v2.8h,   v2.8h,   #6     // (diff + sign + 32) >> 6 = apply_sign()
+        srshr           v3.8h,   v3.8h,   #6
+        add             v2.8h,   v2.8h,   v0.8h  // dc + apply_sign()
+        add             v3.8h,   v3.8h,   v0.8h
+        sqxtun          v2.8b,   v2.8h           // iclip_pixel(dc + apply_sign())
+        sqxtun          v3.8b,   v3.8h
+        st1             {v2.s}[0],  [x0], x1     // row 0 via x0
+        st1             {v2.s}[1],  [x6], x1     // row 1 via x6
+        subs            w4,  w4,  #4             // height -= 4
+        st1             {v3.s}[0],  [x0], x1     // row 2 via x0
+        st1             {v3.s}[1],  [x6], x1     // row 3 via x6
+        b.gt            L(ipred_cfl_splat_w4)    // flags come from the subs above
+        ret
+L(ipred_cfl_splat_w8):                           // shared tail, expects v0 = dc, v1 = alpha; 4 rows of 8 pixels per pass
+        ld1             {v2.8h, v3.8h, v4.8h, v5.8h}, [x5], #64 // 32 consecutive ac values = 4 rows of 8
+        mul             v2.8h,   v2.8h,   v1.8h  // diff = ac * alpha
+        mul             v3.8h,   v3.8h,   v1.8h
+        mul             v4.8h,   v4.8h,   v1.8h
+        mul             v5.8h,   v5.8h,   v1.8h
+        sshr            v16.8h,  v2.8h,   #15    // sign = diff >> 15
+        sshr            v17.8h,  v3.8h,   #15
+        sshr            v18.8h,  v4.8h,   #15
+        sshr            v19.8h,  v5.8h,   #15
+        add             v2.8h,   v2.8h,   v16.8h // diff + sign
+        add             v3.8h,   v3.8h,   v17.8h
+        add             v4.8h,   v4.8h,   v18.8h
+        add             v5.8h,   v5.8h,   v19.8h
+        srshr           v2.8h,   v2.8h,   #6     // (diff + sign + 32) >> 6 = apply_sign()
+        srshr           v3.8h,   v3.8h,   #6
+        srshr           v4.8h,   v4.8h,   #6
+        srshr           v5.8h,   v5.8h,   #6
+        add             v2.8h,   v2.8h,   v0.8h  // dc + apply_sign()
+        add             v3.8h,   v3.8h,   v0.8h
+        add             v4.8h,   v4.8h,   v0.8h
+        add             v5.8h,   v5.8h,   v0.8h
+        sqxtun          v2.8b,   v2.8h           // iclip_pixel(dc + apply_sign())
+        sqxtun          v3.8b,   v3.8h
+        sqxtun          v4.8b,   v4.8h
+        sqxtun          v5.8b,   v5.8h
+        st1             {v2.8b},  [x0], x1       // row 0 via x0
+        st1             {v3.8b},  [x6], x1       // row 1 via x6
+        subs            w4,  w4,  #4             // height -= 4
+        st1             {v4.8b},  [x0], x1       // row 2 via x0
+        st1             {v5.8b},  [x6], x1       // row 3 via x6
+        b.gt            L(ipred_cfl_splat_w8)    // flags come from the subs above
+        ret
+L(ipred_cfl_splat_w16):                          // shared tail for width 16 and 32, expects v0 = dc, v1 = alpha
+        add             x7,  x5,  w3, uxtw #1    // x7 = ac + width int16 elements: ac for the odd row
+        sub             x1,  x1,  w3, uxtw       // x1 = 2*stride - width: x0/x6 advance by width inside the loop
+        mov             w9,  w3                  // w9 = width, kept to reload the column counter
+1:
+        ld1             {v2.8h, v3.8h}, [x5], #32 // 16 ac values of the even row
+        ld1             {v4.8h, v5.8h}, [x7], #32 // 16 ac values of the odd row
+        mul             v2.8h,   v2.8h,   v1.8h  // diff = ac * alpha
+        mul             v3.8h,   v3.8h,   v1.8h
+        mul             v4.8h,   v4.8h,   v1.8h
+        mul             v5.8h,   v5.8h,   v1.8h
+        sshr            v16.8h,  v2.8h,   #15    // sign = diff >> 15
+        sshr            v17.8h,  v3.8h,   #15
+        sshr            v18.8h,  v4.8h,   #15
+        sshr            v19.8h,  v5.8h,   #15
+        add             v2.8h,   v2.8h,   v16.8h // diff + sign
+        add             v3.8h,   v3.8h,   v17.8h
+        add             v4.8h,   v4.8h,   v18.8h
+        add             v5.8h,   v5.8h,   v19.8h
+        srshr           v2.8h,   v2.8h,   #6     // (diff + sign + 32) >> 6 = apply_sign()
+        srshr           v3.8h,   v3.8h,   #6
+        srshr           v4.8h,   v4.8h,   #6
+        srshr           v5.8h,   v5.8h,   #6
+        add             v2.8h,   v2.8h,   v0.8h  // dc + apply_sign()
+        add             v3.8h,   v3.8h,   v0.8h
+        add             v4.8h,   v4.8h,   v0.8h
+        add             v5.8h,   v5.8h,   v0.8h
+        sqxtun          v2.8b,   v2.8h           // iclip_pixel(dc + apply_sign())
+        sqxtun          v3.8b,   v3.8h
+        sqxtun          v4.8b,   v4.8h
+        sqxtun          v5.8b,   v5.8h
+        subs            w3,  w3,  #16            // 16 columns of both rows done
+        st1             {v2.8b, v3.8b},  [x0], #16 // 16 pixels of the even row
+        st1             {v4.8b, v5.8b},  [x6], #16 // 16 pixels of the odd row
+        b.gt            1b                       // more columns left in this row pair (width 32)
+        subs            w4,  w4,  #2             // height -= 2
+        add             x5,  x5,  w9, uxtw #1    // skip the ac row that x7 just consumed
+        add             x7,  x7,  w9, uxtw #1    // skip the ac row that x5 consumes next
+        add             x0,  x0,  x1             // start of the next even row
+        add             x6,  x6,  x1             // start of the next odd row
+        mov             w3,  w9                  // reload the column counter
+        b.gt            1b                       // flags still from "subs w4": add/mov above do not set flags
+        ret
+
+L(ipred_cfl_128_tbl):
+L(ipred_cfl_splat_tbl):                          // second name for the same table; ipred_cfl_left_neon indexes it
+        .hword L(ipred_cfl_128_tbl) - L(ipred_cfl_splat_w16) // width 32
+        .hword L(ipred_cfl_128_tbl) - L(ipred_cfl_splat_w16) // width 16
+        .hword L(ipred_cfl_128_tbl) - L(ipred_cfl_splat_w8)  // width 8
+        .hword L(ipred_cfl_128_tbl) - L(ipred_cfl_splat_w4)  // width 4
+endfunc
+
+// void ipred_cfl_top_neon(pixel *dst, const ptrdiff_t stride,   // x0, x1
+//                         const pixel *const topleft,            // x2
+//                         const int width, const int height,     // w3, w4
+//                         const int16_t *ac, const int alpha);   // x5, w6
+function ipred_cfl_top_neon, export=1
+        clz             w9,  w3                  // clz(width): 26 for width 32 ... 29 for width 4
+        adr             x7,  L(ipred_cfl_top_tbl)
+        sub             w9,  w9,  #26            // jump table index 0..3
+        ldrh            w9,  [x7, w9, uxtw #1]   // 16-bit entry = table address - target address
+        dup             v1.8h,   w6   // alpha
+        add             x2,  x2,  #1             // x2 = topleft + 1: the `width` pixels averaged below start here
+        sub             x7,  x7,  w9, uxtw       // x7 = dc routine for this width
+        add             x6,  x0,  x1             // x6 = dst + stride (odd rows); w6/alpha is already in v1
+        lsl             x1,  x1,  #1             // stride *= 2, as the shared splat code expects
+        br              x7
+4:                                               // width 4
+        ld1r            {v0.2s},  [x2]           // the 4 pixels, replicated into both 32-bit lanes
+        uaddlv          h0,      v0.8b           // 2 * (sum of the 4 pixels)
+        urshr           v0.8h,   v0.8h,   #3     // (2*sum + 4) >> 3 = rounded average of 4
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        b               L(ipred_cfl_splat_w4)    // shared tail inside ipred_cfl_128_neon
+8:                                               // width 8
+        ld1             {v0.8b},  [x2]           // 8 pixels
+        uaddlv          h0,      v0.8b           // sum
+        urshr           v0.8h,   v0.8h,   #3     // (sum + 4) >> 3
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        b               L(ipred_cfl_splat_w8)    // shared tail inside ipred_cfl_128_neon
+16:                                              // width 16
+        ld1             {v0.16b}, [x2]           // 16 pixels
+        uaddlv          h0,      v0.16b          // sum
+        urshr           v0.8h,   v0.8h,   #4     // (sum + 8) >> 4
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        b               L(ipred_cfl_splat_w16)   // shared tail inside ipred_cfl_128_neon
+32:                                              // width 32
+        ld1             {v2.16b, v3.16b}, [x2]   // 32 pixels
+        uaddlv          h2,      v2.16b          // sum of the first 16
+        uaddlv          h3,      v3.16b          // sum of the last 16
+        add             v2.4h,   v2.4h,   v3.4h  // total in lane 0
+        urshr           v2.8h,   v2.8h,   #5     // (sum + 16) >> 5
+        dup             v0.8h,   v2.h[0]         // dc in every lane
+        b               L(ipred_cfl_splat_w16)   // the w16 tail also handles width 32
+
+L(ipred_cfl_top_tbl):
+        .hword L(ipred_cfl_top_tbl) - 32b        // width 32
+        .hword L(ipred_cfl_top_tbl) - 16b        // width 16
+        .hword L(ipred_cfl_top_tbl) -  8b        // width 8
+        .hword L(ipred_cfl_top_tbl) -  4b        // width 4
+endfunc
+
+// void ipred_cfl_left_neon(pixel *dst, const ptrdiff_t stride,  // x0, x1
+//                          const pixel *const topleft,           // x2
+//                          const int width, const int height,    // w3, w4
+//                          const int16_t *ac, const int alpha);  // x5, w6
+function ipred_cfl_left_neon, export=1
+        sub             x2,  x2,  w4, uxtw       // x2 = topleft - height: the `height` pixels averaged below start here
+        clz             w9,  w3                  // clz(width)
+        clz             w8,  w4                  // clz(height)
+        adr             x10, L(ipred_cfl_splat_tbl) // width-indexed table defined in ipred_cfl_128_neon
+        adr             x7,  L(ipred_cfl_left_tbl) // height-indexed table at the end of this function
+        sub             w9,  w9,  #26            // width index 0..3 (32, 16, 8, 4)
+        sub             w8,  w8,  #26            // height index 0..3 (32, 16, 8, 4)
+        ldrh            w9,  [x10, w9, uxtw #1]  // offset of the splat routine
+        ldrh            w8,  [x7,  w8, uxtw #1]  // offset of the dc routine
+        dup             v1.8h,   w6   // alpha
+        sub             x9,  x10, w9, uxtw       // x9 = splat routine for this width (second jump)
+        sub             x7,  x7,  w8, uxtw       // x7 = dc routine for this height (first jump)
+        add             x6,  x0,  x1             // x6 = dst + stride (odd rows); w6/alpha is already in v1
+        lsl             x1,  x1,  #1             // stride *= 2, as the shared splat code expects
+        br              x7
+
+L(ipred_cfl_left_h4):
+        ld1r            {v0.2s},  [x2]           // the 4 pixels, replicated into both 32-bit lanes
+        uaddlv          h0,      v0.8b           // 2 * (sum of the 4 pixels)
+        urshr           v0.8h,   v0.8h,   #3     // (2*sum + 4) >> 3 = rounded average of 4
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        br              x9                       // on to the width-specific splat routine
+
+L(ipred_cfl_left_h8):
+        ld1             {v0.8b},  [x2]           // 8 pixels
+        uaddlv          h0,      v0.8b           // sum
+        urshr           v0.8h,   v0.8h,   #3     // (sum + 4) >> 3
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        br              x9                       // on to the width-specific splat routine
+
+L(ipred_cfl_left_h16):
+        ld1             {v0.16b}, [x2]           // 16 pixels
+        uaddlv          h0,      v0.16b          // sum
+        urshr           v0.8h,   v0.8h,   #4     // (sum + 8) >> 4
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        br              x9                       // on to the width-specific splat routine
+
+L(ipred_cfl_left_h32):
+        ld1             {v2.16b, v3.16b}, [x2]   // 32 pixels
+        uaddlv          h2,      v2.16b          // sum of the first 16
+        uaddlv          h3,      v3.16b          // sum of the last 16
+        add             v2.4h,   v2.4h,   v3.4h  // total in lane 0
+        urshr           v2.8h,   v2.8h,   #5     // (sum + 16) >> 5
+        dup             v0.8h,   v2.h[0]         // dc in every lane
+        br              x9                       // on to the width-specific splat routine
+
+L(ipred_cfl_left_tbl):
+        .hword L(ipred_cfl_left_tbl) - L(ipred_cfl_left_h32) // height 32
+        .hword L(ipred_cfl_left_tbl) - L(ipred_cfl_left_h16) // height 16
+        .hword L(ipred_cfl_left_tbl) - L(ipred_cfl_left_h8)  // height 8
+        .hword L(ipred_cfl_left_tbl) - L(ipred_cfl_left_h4)  // height 4
+endfunc
+
+// void ipred_cfl_neon(pixel *dst, const ptrdiff_t stride,       // x0, x1
+//                     const pixel *const topleft,                // x2
+//                     const int width, const int height,         // w3, w4
+//                     const int16_t *ac, const int alpha);       // x5, w6
+function ipred_cfl_neon, export=1
+        sub             x2,  x2,  w4, uxtw       // x2 = topleft - height: the height stage reads `height` pixels from here
+        add             w8,  w3,  w4             // width + height
+        dup             v1.8h,   w6              // alpha
+        clz             w9,  w3                  // clz(width)
+        clz             w6,  w4                  // clz(height); w6 is free, alpha is already in v1
+        dup             v16.8h, w8               // width + height
+        adr             x7,  L(ipred_cfl_tbl)
+        rbit            w8,  w8                  // rbit(width + height)
+        sub             w9,  w9,  #22            // clz(width) - 26 + 4: the width entries are table slots 4..7
+        sub             w6,  w6,  #26            // clz(height) - 26: the height entries are table slots 0..3
+        clz             w8,  w8                  // ctz(width + height)
+        ldrh            w9,  [x7, w9, uxtw #1]   // offset of the width-stage routine
+        ldrh            w6,  [x7, w6, uxtw #1]   // offset of the height-stage routine
+        neg             w8,  w8                  // -ctz(width + height)
+        sub             x9,  x7,  w9, uxtw       // x9 = width-stage routine (second jump)
+        sub             x7,  x7,  w6, uxtw       // x7 = height-stage routine (first jump)
+        ushr            v16.8h,  v16.8h,  #1     // (width + height) >> 1
+        dup             v17.8h,  w8              // -ctz(width + height)
+        add             x6,  x0,  x1             // x6 = dst + stride (odd rows)
+        lsl             x1,  x1,  #1             // stride *= 2, as the shared splat code expects
+        br              x7
+
+L(ipred_cfl_h4):                                 // height stage, height 4
+        ld1             {v0.s}[0],  [x2], #4     // 4 pixels; post-increment brings x2 back to topleft
+        ins             v0.s[1], wzr             // clear lane 1 so the 8-byte sum only covers the 4 loaded pixels
+        uaddlv          h0,      v0.8b           // v0.h[0] = sum of the height-stage pixels
+        br              x9                       // on to the width stage
+L(ipred_cfl_w4):                                 // width stage, width 4; v0.h[0] holds the height-stage sum
+        add             x2,  x2,  #1             // x2 = topleft + 1
+        ld1             {v2.s}[0],  [x2]         // 4 pixels
+        ins             v2.s[1], wzr             // clear lane 1, as in the height stage
+        add             v0.4h,   v0.4h,   v16.4h // + (width + height) >> 1, the rounding term
+        uaddlv          h2,      v2.8b           // sum of the width-stage pixels
+        cmp             w4,  #4                  // height == width?
+        add             v0.4h,   v0.4h,   v2.4h  // total = both sums + rounding
+        ushl            v0.4h,   v0.4h,   v17.4h // total >> ctz(width + height) (negative ushl count shifts right)
+        b.eq            1f                       // square: width + height is a power of two, nothing left to divide
+        // h = 8/16
+        mov             w16, #(0x3334/2)         // low half: for the remaining /5 (h = 16, width + height = 4*5)
+        movk            w16, #(0x5556/2), lsl #16 // high half: for the remaining /3 (h = 8, width + height = 4*3)
+        add             w17, w4,  w4  // w17 = 2*h = 16 or 32
+        lsr             w16, w16, w17            // shift count is taken mod 32: 16 -> high half, 32 -> 0 keeps low half
+        dup             v16.4h,  w16             // low 16 bits of w16 to every lane
+        sqdmulh         v0.4h,   v0.4h,   v16.4h // (2 * v0 * K/2) >> 16 = (v0 * K) >> 16, K ~= 65536/3 or 65536/5
+1:
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        b               L(ipred_cfl_splat_w4)    // shared tail inside ipred_cfl_128_neon
+
+L(ipred_cfl_h8):                                 // height stage, height 8
+        ld1             {v0.8b},  [x2], #8       // 8 pixels; post-increment brings x2 back to topleft
+        uaddlv          h0,      v0.8b           // v0.h[0] = sum of the height-stage pixels
+        br              x9                       // on to the width stage
+L(ipred_cfl_w8):                                 // width stage, width 8; v0.h[0] holds the height-stage sum
+        add             x2,  x2,  #1             // x2 = topleft + 1
+        ld1             {v2.8b},  [x2]           // 8 pixels
+        add             v0.4h,   v0.4h,   v16.4h // + (width + height) >> 1, the rounding term
+        uaddlv          h2,      v2.8b           // sum of the width-stage pixels
+        cmp             w4,  #8                  // height == width?
+        add             v0.4h,   v0.4h,   v2.4h  // total = both sums + rounding
+        ushl            v0.4h,   v0.4h,   v17.4h // total >> ctz(width + height)
+        b.eq            1f                       // square: division already complete
+        // h = 4/16/32
+        cmp             w4,  #32                 // h = 32 is the 4:1 case (8 + 32 = 8*5)
+        mov             w16, #(0x3334/2)         // multiplier for the remaining /5
+        mov             w17, #(0x5556/2)         // multiplier for the remaining /3 (h = 4 or 16)
+        csel            w16, w16, w17, eq        // h == 32 ? /5 : /3
+        dup             v16.4h,  w16
+        sqdmulh         v0.4h,   v0.4h,   v16.4h // (v0 * K) >> 16, K ~= 65536/3 or 65536/5
+1:
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        b               L(ipred_cfl_splat_w8)    // shared tail inside ipred_cfl_128_neon
+
+L(ipred_cfl_h16):                                // height stage, height 16
+        ld1             {v0.16b}, [x2], #16      // 16 pixels; post-increment brings x2 back to topleft
+        uaddlv          h0,      v0.16b          // v0.h[0] = sum of the height-stage pixels
+        br              x9                       // on to the width stage
+L(ipred_cfl_w16):                                // width stage, width 16; v0.h[0] holds the height-stage sum
+        add             x2,  x2,  #1             // x2 = topleft + 1
+        ld1             {v2.16b}, [x2]           // 16 pixels
+        add             v0.4h,   v0.4h,   v16.4h // + (width + height) >> 1, the rounding term
+        uaddlv          h2,      v2.16b          // sum of the width-stage pixels
+        cmp             w4,  #16                 // height == width?
+        add             v0.4h,   v0.4h,   v2.4h  // total = both sums + rounding
+        ushl            v0.4h,   v0.4h,   v17.4h // total >> ctz(width + height)
+        b.eq            1f                       // square: division already complete
+        // h = 4/8/32
+        cmp             w4,  #4                  // h = 4 is the 4:1 case (16 + 4 = 4*5)
+        mov             w16, #(0x3334/2)         // multiplier for the remaining /5
+        mov             w17, #(0x5556/2)         // multiplier for the remaining /3 (h = 8 or 32)
+        csel            w16, w16, w17, eq        // h == 4 ? /5 : /3
+        dup             v16.4h,  w16
+        sqdmulh         v0.4h,   v0.4h,   v16.4h // (v0 * K) >> 16, K ~= 65536/3 or 65536/5
+1:
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        b               L(ipred_cfl_splat_w16)   // shared tail inside ipred_cfl_128_neon
+
+L(ipred_cfl_h32):                                // height stage, height 32
+        ld1             {v2.16b, v3.16b}, [x2], #32 // 32 pixels; post-increment brings x2 back to topleft
+        uaddlv          h2,      v2.16b          // sum of the first 16
+        uaddlv          h3,      v3.16b          // sum of the last 16
+        add             v0.4h,   v2.4h,   v3.4h  // v0.h[0] = sum of the height-stage pixels
+        br              x9                       // on to the width stage
+L(ipred_cfl_w32):                                // width stage, width 32; v0.h[0] holds the height-stage sum
+        add             x2,  x2,  #1             // x2 = topleft + 1
+        ld1             {v2.16b, v3.16b}, [x2]   // 32 pixels
+        add             v0.4h,   v0.4h,   v16.4h // + (width + height) >> 1, the rounding term
+        uaddlv          h2,      v2.16b          // sum of the first 16
+        uaddlv          h3,      v3.16b          // sum of the last 16
+        cmp             w4,  #32                 // height == width?
+        add             v0.4h,   v0.4h,   v2.4h
+        add             v0.4h,   v0.4h,   v3.4h  // total = both sums + rounding
+        ushl            v0.4h,   v0.4h,   v17.4h // total >> ctz(width + height)
+        b.eq            1f                       // square: division already complete
+        // h = 8/16
+        mov             w16, #(0x5556/2)         // low half: for the remaining /3 (h = 16, width + height = 16*3)
+        movk            w16, #(0x3334/2), lsl #16 // high half: for the remaining /5 (h = 8, width + height = 8*5)
+        add             w17, w4,  w4  // w17 = 2*h = 16 or 32
+        lsr             w16, w16, w17            // shift count is taken mod 32: 16 -> high half, 32 -> 0 keeps low half
+        dup             v16.4h,  w16             // low 16 bits of w16 to every lane
+        sqdmulh         v0.4h,   v0.4h,   v16.4h // (v0 * K) >> 16, K ~= 65536/3 or 65536/5
+1:
+        dup             v0.8h,   v0.h[0]         // dc in every lane
+        b               L(ipred_cfl_splat_w16)   // the w16 tail also handles width 32
+
+L(ipred_cfl_tbl):
+        .hword L(ipred_cfl_tbl) - L(ipred_cfl_h32) // slot 0: height 32
+        .hword L(ipred_cfl_tbl) - L(ipred_cfl_h16) // slot 1: height 16
+        .hword L(ipred_cfl_tbl) - L(ipred_cfl_h8)  // slot 2: height 8
+        .hword L(ipred_cfl_tbl) - L(ipred_cfl_h4)  // slot 3: height 4
+        .hword L(ipred_cfl_tbl) - L(ipred_cfl_w32) // slot 4: width 32
+        .hword L(ipred_cfl_tbl) - L(ipred_cfl_w16) // slot 5: width 16
+        .hword L(ipred_cfl_tbl) - L(ipred_cfl_w8)  // slot 6: width 8
+        .hword L(ipred_cfl_tbl) - L(ipred_cfl_w4)  // slot 7: width 4
+endfunc
--- a/src/arm/ipred_init_tmpl.c
+++ b/src/arm/ipred_init_tmpl.c
@@ -39,6 +39,11 @@
 decl_angular_ipred_fn(dav1d_ipred_smooth_h_neon);
 decl_angular_ipred_fn(dav1d_ipred_filter_neon);
 
+decl_cfl_pred_fn(dav1d_ipred_cfl_neon);
+decl_cfl_pred_fn(dav1d_ipred_cfl_128_neon);
+decl_cfl_pred_fn(dav1d_ipred_cfl_top_neon);
+decl_cfl_pred_fn(dav1d_ipred_cfl_left_neon);
+
 decl_pal_pred_fn(dav1d_pal_pred_neon);
 
 COLD void bitfn(dav1d_intra_pred_dsp_init_arm)(Dav1dIntraPredDSPContext *const c) {
@@ -58,6 +63,11 @@
     c->intra_pred[SMOOTH_V_PRED] = dav1d_ipred_smooth_v_neon;
     c->intra_pred[SMOOTH_H_PRED] = dav1d_ipred_smooth_h_neon;
     c->intra_pred[FILTER_PRED]   = dav1d_ipred_filter_neon;
+
+    c->cfl_pred[DC_PRED]         = dav1d_ipred_cfl_neon;
+    c->cfl_pred[DC_128_PRED]     = dav1d_ipred_cfl_128_neon;
+    c->cfl_pred[TOP_DC_PRED]     = dav1d_ipred_cfl_top_neon;
+    c->cfl_pred[LEFT_DC_PRED]    = dav1d_ipred_cfl_left_neon;
 
     c->pal_pred                  = dav1d_pal_pred_neon;
 #endif