shithub: dav1d

Download patch

ref: b3f0c9844be8610e23b0aa29e52f499de4eda083
parent: d1c56da1d1c65767924d6752e802380409a38d17
author: Martin Storsjö <martin@martin.st>
date: Fri Feb 8 09:19:55 EST 2019

arm64: cdef: NEON implementation of the dir function

Speedup vs C code:
                Cortex A53    A72    A73
cdef_dir_8bpc_neon:   4.43   3.51   4.39

--- a/src/arm/64/cdef.S
+++ b/src/arm/64/cdef.S
@@ -423,3 +423,193 @@
 
 filter 8
 filter 4
+
+const div_table
+        .short         840, 420, 280, 210, 168, 140, 120, 105 // 840/n for n = 1..8; per-lane weights for the squared sum_diag terms
+endconst
+
+const alt_fact
+        .short         420, 210, 140, 105, 105, 105, 105, 105, 140, 210, 420, 0 // 11 weights for the squared sum_alt terms + one zero pad (loaded as 3 x 4 halfwords)
+endconst
+
+// int dav1d_cdef_find_dir_neon(const pixel *img, const ptrdiff_t stride,
+//                              unsigned *const var)   x0 = img, x1 = stride, x2 = var; result (best_dir) in w0
+function cdef_find_dir_neon, export=1
+        sub             sp,  sp,  #32 // cost[8]: room for eight 32-bit costs
+        mov             w3,  #8       // NOTE(review): never read; w3 is reloaded with 1 before its first use
+        movi            v31.16b, #128 // bias subtracted from every pixel
+        movi            v30.16b, #0   // all-zero vector, used as fill for ext/addp
+        movi            v1.8h,   #0 // v0-v1 sum_diag[0]
+        movi            v3.8h,   #0 // v2-v3 sum_diag[1]
+        movi            v5.8h,   #0 // v4-v5 sum_hv[0-1]
+        movi            v7.8h,   #0 // v6-v7 sum_alt[0]
+        movi            v17.8h,  #0 // v16-v17 sum_alt[1]
+        movi            v18.8h,  #0 // v18-v19 sum_alt[2]
+        movi            v19.8h,  #0
+        movi            v21.8h,  #0 // v20-v21 sum_alt[3]
+// Accumulate all partial sums over 8 rows of 8 pixels; fully unrolled, i is the row index y.
+.irpc i, 01234567
+        ld1             {v26.8b}, [x0], x1            // 8 pixels of row y, then img += stride
+        usubl           v26.8h,  v26.8b, v31.8b       // widen to 16 bit: px - 128
+
+        addv            h25,     v26.8h               // [y]
+        rev64           v27.8h,  v26.8h               // first half of the lane reversal
+        addp            v28.8h,  v26.8h,  v30.8h      // [(x >> 1)]
+        add             v5.8h,   v5.8h,   v26.8h      // sum_hv[1]
+        ext             v27.16b, v27.16b, v27.16b, #8 // [-x]
+        rev64           v29.4h,  v28.4h               // [-(x >> 1)]
+        ins             v4.h[\i], v25.h[0]            // sum_hv[0]; all 8 lanes get written over the 8 rows
+
+.if \i == 0
+        mov             v0.16b,  v26.16b              // sum_diag[0]
+        mov             v2.16b,  v27.16b              // sum_diag[1]
+        mov             v6.16b,  v28.16b              // sum_alt[0]
+        mov             v16.16b, v29.16b              // sum_alt[1]
+.else
+        ext             v22.16b, v30.16b, v26.16b, #(16-2*\i) // row moved up by i lanes, zero filled (low 8 entries)
+        ext             v23.16b, v26.16b, v30.16b, #(16-2*\i) // the i lanes that spilled over (high 8 entries)
+        ext             v24.16b, v30.16b, v27.16b, #(16-2*\i) // same two steps for the reversed row
+        ext             v25.16b, v27.16b, v30.16b, #(16-2*\i)
+        add             v0.8h,   v0.8h,   v22.8h      // sum_diag[0]
+        add             v1.8h,   v1.8h,   v23.8h      // sum_diag[0]
+        add             v2.8h,   v2.8h,   v24.8h      // sum_diag[1]
+        add             v3.8h,   v3.8h,   v25.8h      // sum_diag[1]
+        ext             v22.16b, v30.16b, v28.16b, #(16-2*\i) // pair sums moved up by i lanes
+        ext             v23.16b, v28.16b, v30.16b, #(16-2*\i)
+        ext             v24.16b, v30.16b, v29.16b, #(16-2*\i) // reversed pair sums moved up by i lanes
+        ext             v25.16b, v29.16b, v30.16b, #(16-2*\i)
+        add             v6.8h,   v6.8h,   v22.8h      // sum_alt[0]
+        add             v7.8h,   v7.8h,   v23.8h      // sum_alt[0]
+        add             v16.8h,  v16.8h,  v24.8h      // sum_alt[1]
+        add             v17.8h,  v17.8h,  v25.8h      // sum_alt[1]
+.endif
+.if \i < 6
+        ext             v22.16b, v30.16b, v26.16b, #(16-2*(3-(\i/2))) // row moved up by 3 - i/2 lanes
+        ext             v23.16b, v26.16b, v30.16b, #(16-2*(3-(\i/2)))
+        add             v18.8h,  v18.8h,  v22.8h      // sum_alt[2]
+        add             v19.8h,  v19.8h,  v23.8h      // sum_alt[2]
+.else
+        add             v18.8h,  v18.8h,  v26.8h      // sum_alt[2]
+.endif
+.if \i == 0
+        mov             v20.16b, v26.16b              // sum_alt[3]
+.elseif \i == 1
+        add             v20.8h,  v20.8h,  v26.8h      // sum_alt[3]
+.else
+        ext             v24.16b, v30.16b, v26.16b, #(16-2*(\i/2)) // row moved up by i/2 lanes
+        ext             v25.16b, v26.16b, v30.16b, #(16-2*(\i/2))
+        add             v20.8h,  v20.8h,  v24.8h      // sum_alt[3]
+        add             v21.8h,  v21.8h,  v25.8h      // sum_alt[3]
+.endif
+.endr
+// cost[2] / cost[6]: 105 * sum of the squared sum_hv[0] / sum_hv[1] entries.
+        movi            v31.4s,  #105                 // v31 is reused: the pixel bias is no longer needed
+
+        smull           v26.4s,  v4.4h,   v4.4h       // sum_hv[0]*sum_hv[0]
+        smlal2          v26.4s,  v4.8h,   v4.8h       // add the squares of the upper 4 entries
+        smull           v27.4s,  v5.4h,   v5.4h       // sum_hv[1]*sum_hv[1]
+        smlal2          v27.4s,  v5.8h,   v5.8h       // add the squares of the upper 4 entries
+        mul             v26.4s,  v26.4s,  v31.4s      // cost[2] *= 105
+        mul             v27.4s,  v27.4s,  v31.4s      // cost[6] *= 105
+        addv            s4,  v26.4s                   // cost[2]
+        addv            s5,  v27.4s                   // cost[6]
+// Mirror the upper sum_diag halves so that lane n pairs entry n with entry 14-n.
+        rev64           v1.8h,   v1.8h
+        rev64           v3.8h,   v3.8h
+        ext             v1.16b,  v1.16b,  v1.16b, #8  // sum_diag[0][15-n]
+        ext             v3.16b,  v3.16b,  v3.16b, #8  // sum_diag[1][15-n]
+        ext             v1.16b,  v1.16b,  v1.16b, #2  // sum_diag[0][14-n]
+        ext             v3.16b,  v3.16b,  v3.16b, #2  // sum_diag[1][14-n]
+
+        str             s4,  [sp, #2*4]               // cost[2]
+        str             s5,  [sp, #6*4]               // cost[6]
+
+        movrel          x4,  div_table
+        ld1             {v31.8h}, [x4]                // 8 x 16-bit weights
+
+        smull           v22.4s,  v0.4h,   v0.4h       // sum_diag[0]*sum_diag[0]
+        smull2          v23.4s,  v0.8h,   v0.8h
+        smlal           v22.4s,  v1.4h,   v1.4h       // add the squares of the mirrored entries
+        smlal2          v23.4s,  v1.8h,   v1.8h
+        smull           v24.4s,  v2.4h,   v2.4h       // sum_diag[1]*sum_diag[1]
+        smull2          v25.4s,  v2.8h,   v2.8h
+        smlal           v24.4s,  v3.4h,   v3.4h       // add the squares of the mirrored entries
+        smlal2          v25.4s,  v3.8h,   v3.8h
+        uxtl            v30.4s,  v31.4h               // div_table
+        uxtl2           v31.4s,  v31.8h               // v30/v31 now hold the weights as 32-bit lanes (v30 is no longer zero)
+        mul             v22.4s,  v22.4s,  v30.4s      // cost[0]
+        mla             v22.4s,  v23.4s,  v31.4s      // cost[0]
+        mul             v24.4s,  v24.4s,  v30.4s      // cost[4]
+        mla             v24.4s,  v25.4s,  v31.4s      // cost[4]
+        addv            s0,  v22.4s                   // cost[0]
+        addv            s2,  v24.4s                   // cost[4]
+
+        movrel          x5,  alt_fact
+        ld1             {v29.4h, v30.4h, v31.4h}, [x5]// div_table[2*m+1] + 105
+
+        str             s0,  [sp, #0*4]               // cost[0]
+        str             s2,  [sp, #4*4]               // cost[4]
+
+        uxtl            v29.4s,  v29.4h               // div_table[2*m+1] + 105
+        uxtl            v30.4s,  v30.4h
+        uxtl            v31.4s,  v31.4h
+// cost_alt: two weighted sums of squares at once; s1/s2 and s3/s4 are the register pairs of two sum_alt arrays, d1/d2 receive the results.
+.macro cost_alt d1, d2, s1, s2, s3, s4
+        smull           v22.4s,  \s1\().4h, \s1\().4h // sum_alt[n]*sum_alt[n]
+        smull2          v23.4s,  \s1\().8h, \s1\().8h
+        smull           v24.4s,  \s2\().4h, \s2\().4h // second register: only its low 4 lanes are used
+        smull           v25.4s,  \s3\().4h, \s3\().4h // sum_alt[n]*sum_alt[n]
+        smull2          v26.4s,  \s3\().8h, \s3\().8h
+        smull           v27.4s,  \s4\().4h, \s4\().4h // second register: only its low 4 lanes are used
+        mul             v22.4s,  v22.4s,  v29.4s      // sum_alt[n]^2*fact
+        mla             v22.4s,  v23.4s,  v30.4s
+        mla             v22.4s,  v24.4s,  v31.4s
+        mul             v25.4s,  v25.4s,  v29.4s      // sum_alt[n]^2*fact
+        mla             v25.4s,  v26.4s,  v30.4s
+        mla             v25.4s,  v27.4s,  v31.4s
+        addv            \d1, v22.4s                   // *cost_ptr
+        addv            \d2, v25.4s                   // *cost_ptr
+.endm
+        cost_alt        s6,  s16, v6,  v7,  v16, v17  // cost[1], cost[3]
+        str             s6,  [sp, #1*4]               // cost[1]
+        str             s16, [sp, #3*4]               // cost[3]
+        cost_alt        s18, s20, v18, v19, v20, v21  // cost[5], cost[7]
+        str             s18, [sp, #5*4]               // cost[5]
+        str             s20, [sp, #7*4]               // cost[7]
+// Search for the largest cost; the costs are still live in lane 0 of v0, v6, v4, v16, v2, v18, v5, v20 (cost[0..7]).
+        mov             w0,  #0                       // best_dir
+        mov             w1,  v0.s[0]                  // best_cost
+        mov             w3,  #1                       // n
+
+        mov             w4,  v6.s[0]                  // cost[1], the first candidate
+// find_best: test cost[n] (already in w4); when s2/s3 are given also test cost[n+1] from s2 and preload w4 from s3. s1 is not referenced in the body.
+.macro find_best s1, s2, s3
+.ifnb \s2
+        mov             w5,  \s2\().s[0]              // cost[n+1]
+.endif
+        cmp             w4,  w1                       // cost[n] > best_cost
+        csel            w0,  w3,  w0,  gt             // best_dir = n
+        csel            w1,  w4,  w1,  gt             // best_cost = cost[n]
+.ifnb \s2
+        add             w3,  w3,  #1                  // n++
+        cmp             w5,  w1                       // cost[n] > best_cost
+        mov             w4,  \s3\().s[0]              // candidate for the next invocation
+        csel            w0,  w3,  w0,  gt             // best_dir = n
+        csel            w1,  w5,  w1,  gt             // best_cost = cost[n]
+        add             w3,  w3,  #1                  // n++
+.endif
+.endm
+        find_best       v6,  v4, v16                  // n = 1, 2
+        find_best       v16, v2, v18                  // n = 3, 4
+        find_best       v18, v5, v20                  // n = 5, 6
+        find_best       v20                           // n = 7
+// *var = (best_cost - cost[best_dir ^ 4]) >> 10
+        eor             w3,  w0,  #4                  // best_dir ^4
+        ldr             w4,  [sp, w3, uxtw #2]        // cost[best_dir ^ 4] from the stack array
+        sub             w1,  w1,  w4                  // best_cost - cost[best_dir ^ 4]
+        lsr             w1,  w1,  #10
+        str             w1,  [x2]                     // *var
+
+        add             sp,  sp,  #32                 // release cost[]
+        ret                                           // w0 = best_dir
+endfunc
--- a/src/arm/cdef_init_tmpl.c
+++ b/src/arm/cdef_init_tmpl.c
@@ -29,6 +29,8 @@
 #include "src/cdef.h"
 
 #if BITDEPTH == 8 && ARCH_AARCH64
+decl_cdef_dir_fn(dav1d_cdef_find_dir_neon);
+
 void dav1d_cdef_padding4_neon(uint16_t *tmp, const pixel *src,
                               ptrdiff_t src_stride, const pixel (*left)[2],
                               /*const*/ pixel *const top[2], int h,
@@ -76,6 +78,7 @@
     if (!(flags & DAV1D_ARM_CPU_FLAG_NEON)) return;
 
 #if BITDEPTH == 8 && ARCH_AARCH64
+    c->dir = dav1d_cdef_find_dir_neon;
     c->fb[0] = cdef_filter_8x8_neon;
     c->fb[1] = cdef_filter_4x8_neon;
     c->fb[2] = cdef_filter_4x4_neon;