shithub: dav1d

Download patch

ref: b33f46e8d98b49dc0d6f6d205027d15c0c8e05c1
parent: aff9a2105583b33cef55935e62faba3814d12013
author: Martin Storsjö <martin@martin.st>
date: Wed Feb 12 19:23:02 EST 2020

arm: cdef: Do an 8 bit implementation for cases with all edges present

This increases the code size by around 3 KB on arm64.

Before:
ARM32:                    Cortex A7      A8      A9     A53     A72     A73
cdef_filter_4x4_8bpc_neon:    807.1   517.0   617.7   506.6   429.9   357.8
cdef_filter_4x8_8bpc_neon:   1407.9   899.3  1054.6   862.3   726.5   628.1
cdef_filter_8x8_8bpc_neon:   2394.9  1456.8  1676.8  1461.2  1084.4  1101.2
ARM64:
cdef_filter_4x4_8bpc_neon:                            460.7   301.8   308.0
cdef_filter_4x8_8bpc_neon:                            831.6   547.0   555.2
cdef_filter_8x8_8bpc_neon:                           1454.6   935.6   960.4

After:
ARM32:
cdef_filter_4x4_8bpc_neon:    669.3   541.3   524.4   424.9   322.7   298.1
cdef_filter_4x8_8bpc_neon:   1159.1   922.9   881.1   709.2   538.3   514.1
cdef_filter_8x8_8bpc_neon:   1888.8  1285.4  1358.5  1152.9   839.3   871.2
ARM64:
cdef_filter_4x4_8bpc_neon:                            383.6   262.1   259.9
cdef_filter_4x8_8bpc_neon:                            684.9   472.2   464.7
cdef_filter_8x8_8bpc_neon:                           1160.0   756.8   788.0

(The checkasm benchmark averages three different cases; the fully
edged case is only one of those three, even though it is the most
common case in actual video. The difference is much bigger when
benchmarking only that particular case.)

This apparently makes the code slightly slower for the w=4 cases on
Cortex A8, while it is a significant speedup on all other cores.

--- a/src/arm/32/cdef.S
+++ b/src/arm/32/cdef.S
@@ -162,6 +162,8 @@
         push            {r4-r7,lr}
         ldrd            r4,  r5,  [sp, #20]
         ldr             r6,  [sp, #28]
+        cmp             r6,  #0xf // fully edged: all four edge flag bits set
+        beq             cdef_padding\w\()_edged_8bpc_neon // returns via pop {r4-r7,pc}
         vmov.i16        q3,  #0x8000
         tst             r6,  #4 // CDEF_HAVE_TOP
         bne             1f
@@ -266,6 +268,65 @@
 padding_func 8, 16, d0, q0, d2, q1, 128
 padding_func 4, 8,  s0, d0, s4, d2, 64
 
+// void cdef_paddingX_edged_8bpc_neon(uint8_t *tmp, const pixel *src,
+//                                    ptrdiff_t src_stride, const pixel (*left)[2],
+//                                    const pixel *const top, int h,
+//                                    enum CdefEdgeFlags edges);
+// Entered via beq from cdef_paddingX_8bpc_neon with r4 = top and r5 = h loaded.
+.macro padding_func_edged w, stride, reg, align // align is currently unused
+function cdef_padding\w\()_edged_8bpc_neon
+        sub             r0,  r0,  #(2*\stride) // r0 = start of row -2
+        // Rows -2 and -1: top and top + src_stride, with 2 pixels on each side.
+        ldrh            r12, [r4, #-2]          // 2 pixels left of the row
+        vldr            \reg, [r4]              // w pixels; NOTE(review): vldr needs a word-aligned address
+        add             r7,  r4,  r2            // r7 = second top row
+        strh            r12, [r0, #-2]
+        ldrh            r12, [r4, #\w]          // 2 pixels right of the row
+        vstr            \reg, [r0]
+        strh            r12, [r0, #\w]
+
+        ldrh            r12, [r7, #-2]
+        vldr            \reg, [r7]
+        strh            r12, [r0, #\stride-2]
+        ldrh            r12, [r7, #\w]
+        vstr            \reg, [r0, #\stride]
+        strh            r12, [r0, #\stride+\w]
+        add             r0,  r0,  #2*\stride    // r0 = start of row 0
+        // Rows 0..h-1: left[y] | src row | 2 pixels right of the src row.
+0:
+        ldrh            r12, [r3], #2           // left[y][0..1], advance left
+        vldr            \reg, [r1]
+        str             r12, [r0, #-2]          // 4-byte store; its 2 zero bytes are overwritten by the vstr below
+        ldrh            r12, [r1, #\w]          // 2 pixels right of the src row
+        add             r1,  r1,  r2            // src += src_stride
+        subs            r5,  r5,  #1            // h--
+        vstr            \reg, [r0]
+        str             r12, [r0, #\w]          // 4-byte store; 2 zero bytes spill into unused cols (w=8) or the next row left pad (w=4), written later
+        add             r0,  r0,  #\stride
+        bgt             0b                      // flags from the subs on r5
+        // Rows h and h+1: the two src rows below the block.
+        ldrh            r12, [r1, #-2]
+        vldr            \reg, [r1]
+        add             r7,  r1,  r2
+        strh            r12, [r0, #-2]
+        ldrh            r12, [r1, #\w]
+        vstr            \reg, [r0]
+        strh            r12, [r0, #\w]
+
+        ldrh            r12, [r7, #-2]
+        vldr            \reg, [r7]
+        strh            r12, [r0, #\stride-2]
+        ldrh            r12, [r7, #\w]
+        vstr            \reg, [r0, #\stride]
+        strh            r12, [r0, #\stride+\w]
+
+        pop             {r4-r7,pc}              // frame pushed by cdef_paddingX_8bpc_neon
+endfunc
+.endm
+
+padding_func_edged 8, 16, d0, 64
+padding_func_edged 4, 8,  s0, 32
+
 .macro dir_table w, stride
 const directions\w
         .byte           -1 * \stride + 1, -2 * \stride + 2
@@ -339,9 +400,11 @@
 // void dav1d_cdef_filterX_8bpc_neon(pixel *dst, ptrdiff_t dst_stride,
 //                                   const uint16_t *tmp, int pri_strength,
 //                                   int sec_strength, int dir, int damping,
-//                                   int h);
+//                                   int h, size_t edges);
 .macro filter_func w, pri, sec, min, suffix
 function cdef_filter\w\suffix\()_neon
+        cmp             r8,  #0xf            // edges == 0xf: all edges present
+        beq             cdef_filter\w\suffix\()_edged_neon // 8-bit fast path
 .if \pri
         movrel_local    r8,  pri_taps
         and             r9,  r3,  #1
@@ -473,6 +536,7 @@
         vpush           {q4-q7}
         ldrd            r4,  r5,  [sp, #92]
         ldrd            r6,  r7,  [sp, #100]
+        ldr             r8,  [sp, #108]      // edges (9th argument)
         cmp             r3,  #0 // pri_strength
         bne             1f
         b               cdef_filter\w\()_sec_neon // only sec
@@ -487,6 +551,211 @@
 
 filter 8
 filter 4
+// load_px_8: p0 = tmp[x + off], p1 = tmp[x - off] for all 16 pixels; r9 = off on entry; clobbers r6, r9.
+.macro load_px_8 d11, d12, d21, d22, w
+.if \w == 8
+        add             r6,  r2,  r9         // x + off
+        sub             r9,  r2,  r9         // x - off
+        vld1.8          {\d11}, [r6]         // p0, row 0
+        add             r6,  r6,  #16        // += stride
+        vld1.8          {\d21}, [r9]         // p1, row 0
+        add             r9,  r9,  #16        // += stride
+        vld1.8          {\d12}, [r6]         // p0, row 1
+        vld1.8          {\d22}, [r9]         // p1, row 1
+.else
+        add             r6,  r2,  r9         // x + off
+        sub             r9,  r2,  r9         // x - off
+        vld1.32         {\d11[0]}, [r6]      // p0, row 0
+        add             r6,  r6,  #8         // += stride
+        vld1.32         {\d21[0]}, [r9]      // p1, row 0
+        add             r9,  r9,  #8         // += stride
+        vld1.32         {\d11[1]}, [r6]      // p0, row 1
+        add             r6,  r6,  #8         // += stride
+        vld1.32         {\d21[1]}, [r9]      // p1, row 1
+        add             r9,  r9,  #8         // += stride
+        vld1.32         {\d12[0]}, [r6]      // p0, row 2
+        add             r6,  r6,  #8         // += stride
+        vld1.32         {\d22[0]}, [r9]      // p1, row 2
+        add             r9,  r9,  #8         // += stride
+        vld1.32         {\d12[1]}, [r6]      // p0, row 3
+        vld1.32         {\d22[1]}, [r9]      // p1, row 3
+.endif
+.endm
+.macro handle_pixel_8 s1, s2, thresh_vec, shift, tap, min // q1/q2 += tap * constrain(p - px) for p in s1, s2
+.if \min
+        vmin.u8         q3,  q3,  \s1        // running min
+        vmax.u8         q4,  q4,  \s1        // running max
+        vmin.u8         q3,  q3,  \s2        // running min
+        vmax.u8         q4,  q4,  \s2        // running max
+.endif
+        vabd.u8         q8,  q0,  \s1        // abs(diff)
+        vabd.u8         q11, q0,  \s2        // abs(diff)
+        vshl.u8         q9,  q8,  \shift     // abs(diff) >> shift
+        vshl.u8         q12, q11, \shift     // abs(diff) >> shift
+        vqsub.u8        q9,  \thresh_vec, q9 // clip = imax(0, threshold - (abs(diff) >> shift))
+        vqsub.u8        q12, \thresh_vec, q12// clip = imax(0, threshold - (abs(diff) >> shift))
+        vcgt.u8         q10, q0,  \s1        // px > p0
+        vcgt.u8         q13, q0,  \s2        // px > p1
+        vmin.u8         q9,  q9,  q8         // imin(abs(diff), clip)
+        vmin.u8         q12, q12, q11        // imin(abs(diff), clip)
+        vneg.s8         q8,  q9              // -imin()
+        vneg.s8         q11, q12             // -imin()
+        vbsl            q10, q8,  q9         // constrain() = imax(imin(diff, clip), -clip)
+        vdup.8          d18, \tap            // taps[k]; d18 (q9) is free after the vbsl above
+        vbsl            q13, q11, q12        // constrain() = imax(imin(diff, clip), -clip)
+        vmlal.s8        q1,  d20, d18        // sum += taps[k] * constrain()
+        vmlal.s8        q1,  d26, d18        // sum += taps[k] * constrain()
+        vmlal.s8        q2,  d21, d18        // sum += taps[k] * constrain()
+        vmlal.s8        q2,  d27, d18        // sum += taps[k] * constrain()
+.endm
+// 8-bit variant; entered via beq from cdef_filterX_{pri,sec,pri_sec}_neon when edges == 0xf.
+// void cdef_filterX_edged_neon(pixel *dst, ptrdiff_t dst_stride,
+//                              const uint8_t *tmp, int pri_strength,
+//                              int sec_strength, int dir, int damping,
+//                              int h, size_t edges);
+.macro filter_func_8 w, pri, sec, min, suffix
+function cdef_filter\w\suffix\()_edged_neon
+.if \pri
+        movrel_local    r8,  pri_taps
+        and             r9,  r3,  #1
+        add             r8,  r8,  r9, lsl #1 // r8 = pri_taps + 2 * (pri_strength & 1)
+.endif
+        movrel_local    r9,  directions\w
+        add             r5,  r9,  r5, lsl #1 // r5 = directions + 2 * dir
+        vmov.u8         d17, #7              // for ulog2 = 7 - clz(threshold)
+        vdup.8          d16, r6              // damping
+
+        vmov.8          d8[0], r3            // pri_strength
+        vmov.8          d8[1], r4            // sec_strength
+        vclz.i8         d8,  d8              // clz(threshold)
+        vsub.i8         d8,  d17, d8         // ulog2(threshold)
+        vqsub.u8        d8,  d16, d8         // shift = imax(0, damping - ulog2(threshold))
+        vneg.s8         d8,  d8              // -shift
+.if \sec
+        vdup.8          q6,  d8[1]           // -sec shift, all lanes
+.endif
+.if \pri
+        vdup.8          q5,  d8[0]           // -pri shift, all lanes
+.endif
+        // d8 is free from here on; with min tracking, q4 (d8-d9) holds the max.
+1:                                           // 16 pixels per iteration: 2 rows of 8 or 4 rows of 4
+.if \w == 8
+        add             r12, r2,  #16
+        vld1.8          {d0},  [r2,  :64]    // px
+        vld1.8          {d1},  [r12, :64]    // px
+.else
+        add             r12, r2,  #8
+        vld1.32         {d0[0]},  [r2,  :32] // px
+        add             r9,  r2,  #2*8
+        vld1.32         {d0[1]},  [r12, :32] // px
+        add             r12, r12, #2*8
+        vld1.32         {d1[0]},  [r9,  :32] // px
+        vld1.32         {d1[1]},  [r12, :32] // px
+.endif
+
+        vmov.u8         q1,  #0              // sum, 16-bit lanes for pixels 0-7
+        vmov.u8         q2,  #0              // sum, 16-bit lanes for pixels 8-15
+.if \min
+        vmov.u16        q3,  q0              // min
+        vmov.u16        q4,  q0              // max
+.endif
+
+        // Instead of loading sec_taps 2, 1 from memory, just set it
+        // to 2 initially and decrease for the second round.
+        // This is also used as loop counter.
+        mov             lr,  #2              // sec_taps[0]
+
+2:                                           // two rounds, k = 0 and k = 1
+.if \pri
+        ldrsb           r9,  [r5]            // off1
+
+        load_px_8       d28, d29, d30, d31, \w
+.endif
+
+.if \sec
+        add             r5,  r5,  #4         // +2*2
+        ldrsb           r9,  [r5]            // off2
+.endif
+
+.if \pri
+        ldrb            r12, [r8]            // *pri_taps
+        vdup.8          q7,  r3              // pri threshold
+
+        handle_pixel_8  q14, q15, q7,  q5,  r12, \min
+.endif
+
+.if \sec
+        load_px_8       d28, d29, d30, d31, \w // p at +-off2
+
+        add             r5,  r5,  #8         // +2*4
+        ldrsb           r9,  [r5]            // off3
+
+        vdup.8          q7,  r4              // sec threshold
+
+        handle_pixel_8  q14, q15, q7,  q6,  lr, \min // sec tap = lr (2, then 1)
+
+        load_px_8       d28, d29, d30, d31, \w // p at +-off3
+
+        handle_pixel_8  q14, q15, q7,  q6,  lr, \min
+
+        sub             r5,  r5,  #11        // r5 -= 2*(2+4); r5 += 1;
+.else
+        add             r5,  r5,  #1         // r5 += 1
+.endif
+        subs            lr,  lr,  #1         // sec_tap-- (value)
+.if \pri
+        add             r8,  r8,  #1         // pri_taps++ (pointer)
+.endif
+        bne             2b                   // flags from the subs on lr
+
+        vshr.s16        q14, q1,  #15        // -(sum < 0)
+        vshr.s16        q15, q2,  #15        // -(sum < 0)
+        vadd.i16        q1,  q1,  q14        // sum - (sum < 0)
+        vadd.i16        q2,  q2,  q15        // sum - (sum < 0)
+        vrshr.s16       q1,  q1,  #4         // (8 + sum - (sum < 0)) >> 4
+        vrshr.s16       q2,  q2,  #4         // (8 + sum - (sum < 0)) >> 4
+        vaddw.u8        q1,  q1,  d0         // px + (8 + sum ...) >> 4
+        vaddw.u8        q2,  q2,  d1         // px + (8 + sum ...) >> 4
+        vqmovun.s16     d0,  q1
+        vqmovun.s16     d1,  q2
+.if \min
+        vmin.u8         q0,  q0,  q4
+        vmax.u8         q0,  q0,  q3         // iclip(px + .., min, max)
+.endif
+.if \w == 8
+        vst1.8          {d0}, [r0, :64], r1
+        add             r2,  r2,  #2*16      // tmp += 2*tmp_stride
+        subs            r7,  r7,  #2         // h -= 2
+        vst1.8          {d1}, [r0, :64], r1
+.else
+        vst1.32         {d0[0]}, [r0, :32], r1
+        add             r2,  r2,  #4*8       // tmp += 4*tmp_stride
+        vst1.32         {d0[1]}, [r0, :32], r1
+        subs            r7,  r7,  #4         // h -= 4
+        vst1.32         {d1[0]}, [r0, :32], r1
+        vst1.32         {d1[1]}, [r0, :32], r1
+.endif
+
+        // Reset pri_taps and directions back to the original point
+        sub             r5,  r5,  #2
+.if \pri
+        sub             r8,  r8,  #2
+.endif
+
+        bgt             1b                   // flags from the subs on r7 (h)
+        vpop            {q4-q7}              // frame set up by dav1d_cdef_filterX_8bpc_neon
+        pop             {r4-r9,pc}           // frame set up by dav1d_cdef_filterX_8bpc_neon
+endfunc
+.endm
+
+.macro filter_8 w
+filter_func_8 \w, pri=1, sec=0, min=0, suffix=_pri
+filter_func_8 \w, pri=0, sec=1, min=0, suffix=_sec
+filter_func_8 \w, pri=1, sec=1, min=1, suffix=_pri_sec
+.endm
+
+filter_8 8
+filter_8 4
 
 const div_table, align=4
         .short         840, 420, 280, 210, 168, 140, 120, 105
--- a/src/arm/64/cdef.S
+++ b/src/arm/64/cdef.S
@@ -145,6 +145,8 @@
 
 .macro padding_func w, stride, rn, rw
 function cdef_padding\w\()_8bpc_neon, export=1
+        cmp             w6,  #0xf // fully edged: all four edge flag bits set
+        b.eq            cdef_padding\w\()_edged_8bpc_neon // tail branch; it returns with ret
         movi            v30.8h,  #0x80, lsl #8
         mov             v31.16b, v30.16b
         sub             x0,  x0,  #2*(2*\stride+2)
@@ -242,6 +244,67 @@
 padding_func 8, 16, d, q
 padding_func 4, 8,  s, d
 
+// void cdef_paddingX_edged_8bpc_neon(uint8_t *tmp, const pixel *src,
+//                                    ptrdiff_t src_stride, const pixel (*left)[2],
+//                                    const pixel *const top, int h,
+//                                    enum CdefEdgeFlags edges);
+// tmp rows are stride bytes apart; each row holds 2 left | w | 2 right pixels.
+.macro padding_func_edged w, stride, reg
+function cdef_padding\w\()_edged_8bpc_neon, export=1
+        sub             x4,  x4,  #2                // top -= 2: include the 2 left pixels
+        sub             x0,  x0,  #(2*\stride+2)    // x0 = row -2, col -2
+
+.if \w == 4
+        ldr             d0, [x4]                    // 8 bytes = 2 + 4 + 2 pixels
+        ldr             d1, [x4, x2]                // second top row
+        st1             {v0.8b, v1.8b}, [x0], #16   // stride == 8: both rows in one store
+.else
+        add             x9,  x4,  x2                // second top row
+        ldr             d0, [x4]                    // 12 pixels = 2 + 8 + 2, as d + s
+        ldr             s1, [x4, #8]
+        ldr             d2, [x9]
+        ldr             s3, [x9, #8]
+        str             d0, [x0]
+        str             s1, [x0, #8]
+        str             d2, [x0, #\stride]
+        str             s3, [x0, #\stride+8]
+        add             x0,  x0,  #2*\stride        // x0 = row 0, col -2
+.endif
+
+0:
+        ld1             {v0.h}[0], [x3], #2         // left[y][0..1]
+        ldr             h2,      [x1, #\w]          // 2 pixels right of the row
+        load_n_incr     v1,  x1,  x2,  \w           // w pixels, then x1 += src_stride (macro defined elsewhere)
+        subs            w5,  w5,  #1
+        str             h0,      [x0]               // cols -2..-1
+        stur            \reg\()1, [x0, #2]          // cols 0..w-1
+        str             h2,      [x0, #2+\w]        // cols w..w+1
+        add             x0,  x0,  #\stride
+        b.gt            0b                          // flags from the subs on w5
+
+        sub             x1,  x1,  #2                // two src rows below the block, from col -2
+.if \w == 4
+        ldr             d0, [x1]
+        ldr             d1, [x1, x2]
+        st1             {v0.8b, v1.8b}, [x0], #16
+.else
+        add             x9,  x1,  x2
+        ldr             d0, [x1]
+        ldr             s1, [x1, #8]
+        ldr             d2, [x9]
+        ldr             s3, [x9, #8]
+        str             d0, [x0]
+        str             s1, [x0, #8]
+        str             d2, [x0, #\stride]
+        str             s3, [x0, #\stride+8]
+.endif
+        ret
+endfunc
+.endm
+
+padding_func_edged 8, 16, d
+padding_func_edged 4, 8,  s
+
 tables
 
 filter 8, 8
@@ -248,3 +311,207 @@
 filter 4, 8
 
 find_dir 8
+// load_px_8: p0 = tmp[x + off] into the first vector, p1 = tmp[x - off] into the second; w9 = off byte; clobbers x6, x9.
+.macro load_px_8 d1, d2, w
+.if \w == 8
+        add             x6,  x2,  w9, sxtb          // x + off (offset byte is signed)
+        sub             x9,  x2,  w9, sxtb          // x - off
+        ld1             {\d1\().d}[0], [x6]         // p0, row 0
+        add             x6,  x6,  #16               // += stride
+        ld1             {\d2\().d}[0], [x9]         // p1, row 0
+        add             x9,  x9,  #16               // += stride
+        ld1             {\d1\().d}[1], [x6]         // p0, row 1
+        ld1             {\d2\().d}[1], [x9]         // p1, row 1
+.else
+        add             x6,  x2,  w9, sxtb          // x + off (offset byte is signed)
+        sub             x9,  x2,  w9, sxtb          // x - off
+        ld1             {\d1\().s}[0], [x6]         // p0, row 0
+        add             x6,  x6,  #8                // += stride
+        ld1             {\d2\().s}[0], [x9]         // p1, row 0
+        add             x9,  x9,  #8                // += stride
+        ld1             {\d1\().s}[1], [x6]         // p0, row 1
+        add             x6,  x6,  #8                // += stride
+        ld1             {\d2\().s}[1], [x9]         // p1, row 1
+        add             x9,  x9,  #8                // += stride
+        ld1             {\d1\().s}[2], [x6]         // p0, row 2
+        add             x6,  x6,  #8                // += stride
+        ld1             {\d2\().s}[2], [x9]         // p1, row 2
+        add             x9,  x9,  #8                // += stride
+        ld1             {\d1\().s}[3], [x6]         // p0, row 3
+        ld1             {\d2\().s}[3], [x9]         // p1, row 3
+.endif
+.endm
+.macro handle_pixel_8 s1, s2, thresh_vec, shift, tap, min // v1/v2 += tap * constrain(p - px) for p in s1, s2
+.if \min
+        umin            v3.16b,  v3.16b,  \s1\().16b  // running min
+        umax            v4.16b,  v4.16b,  \s1\().16b  // running max
+        umin            v3.16b,  v3.16b,  \s2\().16b  // running min
+        umax            v4.16b,  v4.16b,  \s2\().16b  // running max
+.endif
+        uabd            v16.16b, v0.16b,  \s1\().16b  // abs(diff)
+        uabd            v20.16b, v0.16b,  \s2\().16b  // abs(diff)
+        ushl            v17.16b, v16.16b, \shift      // abs(diff) >> shift
+        ushl            v21.16b, v20.16b, \shift      // abs(diff) >> shift
+        uqsub           v17.16b, \thresh_vec, v17.16b // clip = imax(0, threshold - (abs(diff) >> shift))
+        uqsub           v21.16b, \thresh_vec, v21.16b // clip = imax(0, threshold - (abs(diff) >> shift))
+        cmhi            v18.16b, v0.16b,  \s1\().16b  // px > p0
+        cmhi            v22.16b, v0.16b,  \s2\().16b  // px > p1
+        umin            v17.16b, v17.16b, v16.16b     // imin(abs(diff), clip)
+        umin            v21.16b, v21.16b, v20.16b     // imin(abs(diff), clip)
+        dup             v19.16b, \tap                 // taps[k]
+        neg             v16.16b, v17.16b              // -imin()
+        neg             v20.16b, v21.16b              // -imin()
+        bsl             v18.16b, v16.16b, v17.16b     // constrain() = apply_sign()
+        bsl             v22.16b, v20.16b, v21.16b     // constrain() = apply_sign()
+        smlal           v1.8h,   v18.8b,  v19.8b      // sum += taps[k] * constrain()
+        smlal           v1.8h,   v22.8b,  v19.8b      // sum += taps[k] * constrain()
+        smlal2          v2.8h,   v18.16b, v19.16b     // sum += taps[k] * constrain()
+        smlal2          v2.8h,   v22.16b, v19.16b     // sum += taps[k] * constrain()
+.endm
+// 8-bit variant; entered via b.eq from cdef_filterX_{pri,sec,pri_sec}_8bpc_neon when edges == 0xf.
+// void cdef_filterX_edged_8bpc_neon(pixel *dst, ptrdiff_t dst_stride,
+//                                   const uint8_t *tmp, int pri_strength,
+//                                   int sec_strength, int dir, int damping,
+//                                   int h, size_t edges);
+.macro filter_func_8 w, pri, sec, min, suffix
+function cdef_filter\w\suffix\()_edged_8bpc_neon
+.if \pri
+        movrel          x8,  pri_taps
+        and             w9,  w3,  #1
+        add             x8,  x8,  w9, uxtw #1       // x8 = pri_taps + 2 * (pri_strength & 1)
+.endif
+        movrel          x9,  directions\w
+        add             x5,  x9,  w5, uxtw #1       // x5 = directions + 2 * dir
+        movi            v30.8b,  #7                 // for ulog2 = 7 - clz(threshold)
+        dup             v28.8b,  w6                 // damping
+
+.if \pri
+        dup             v25.16b, w3                 // pri threshold
+.endif
+.if \sec
+        dup             v27.16b, w4                 // sec threshold
+.endif
+        trn1            v24.8b,  v25.8b, v27.8b     // b[0] = pri, b[1] = sec; a lane whose source is unset is unused
+        clz             v24.8b,  v24.8b             // clz(threshold)
+        sub             v24.8b,  v30.8b, v24.8b     // ulog2(threshold)
+        uqsub           v24.8b,  v28.8b, v24.8b     // shift = imax(0, damping - ulog2(threshold))
+        neg             v24.8b,  v24.8b             // -shift
+.if \sec
+        dup             v26.16b, v24.b[1]           // -sec shift, all lanes
+.endif
+.if \pri
+        dup             v24.16b, v24.b[0]           // -pri shift, all lanes
+.endif
+
+1:                                                  // 16 pixels per iteration: 2 rows of 8 or 4 rows of 4
+.if \w == 8
+        add             x12, x2,  #16
+        ld1             {v0.d}[0], [x2]             // px
+        ld1             {v0.d}[1], [x12]            // px
+.else
+        add             x12, x2,  #1*8
+        add             x13, x2,  #2*8
+        add             x14, x2,  #3*8
+        ld1             {v0.s}[0], [x2]             // px
+        ld1             {v0.s}[1], [x12]            // px
+        ld1             {v0.s}[2], [x13]            // px
+        ld1             {v0.s}[3], [x14]            // px
+.endif
+
+        movi            v1.8h,  #0                  // sum, pixels 0-7
+        movi            v2.8h,  #0                  // sum, pixels 8-15
+.if \min
+        mov             v3.16b, v0.16b              // min
+        mov             v4.16b, v0.16b              // max
+.endif
+
+        // Instead of loading sec_taps 2, 1 from memory, just set it
+        // to 2 initially and decrease for the second round.
+        // This is also used as loop counter.
+        mov             w11, #2                     // sec_taps[0]
+
+2:                                                  // two rounds, k = 0 and k = 1
+.if \pri
+        ldrb            w9,  [x5]                   // off1 (sign-extended in load_px_8)
+
+        load_px_8       v5,  v6, \w
+.endif
+
+.if \sec
+        add             x5,  x5,  #4                // +2*2
+        ldrb            w9,  [x5]                   // off2
+        load_px_8       v28, v29, \w                // v28 no longer holds damping here
+.endif
+
+.if \pri
+        ldrb            w10, [x8]                   // *pri_taps
+
+        handle_pixel_8  v5,  v6,  v25.16b, v24.16b, w10, \min
+.endif
+
+.if \sec
+        add             x5,  x5,  #8                // +2*4
+        ldrb            w9,  [x5]                   // off3
+        load_px_8       v5,  v6,  \w
+
+        handle_pixel_8  v28, v29, v27.16b, v26.16b, w11, \min
+
+        handle_pixel_8  v5,  v6,  v27.16b, v26.16b, w11, \min
+
+        sub             x5,  x5,  #11               // x5 -= 2*(2+4); x5 += 1;
+.else
+        add             x5,  x5,  #1                // x5 += 1
+.endif
+        subs            w11, w11, #1                // sec_tap-- (value)
+.if \pri
+        add             x8,  x8,  #1                // pri_taps++ (pointer)
+.endif
+        b.ne            2b                          // flags from the subs on w11
+
+        sshr            v5.8h,   v1.8h,   #15       // -(sum < 0)
+        sshr            v6.8h,   v2.8h,   #15       // -(sum < 0)
+        add             v1.8h,   v1.8h,   v5.8h     // sum - (sum < 0)
+        add             v2.8h,   v2.8h,   v6.8h     // sum - (sum < 0)
+        srshr           v1.8h,   v1.8h,   #4        // (8 + sum - (sum < 0)) >> 4
+        srshr           v2.8h,   v2.8h,   #4        // (8 + sum - (sum < 0)) >> 4
+        uaddw           v1.8h,   v1.8h,   v0.8b     // px + (8 + sum ...) >> 4
+        uaddw2          v2.8h,   v2.8h,   v0.16b    // px + (8 + sum ...) >> 4
+        sqxtun          v0.8b,   v1.8h
+        sqxtun2         v0.16b,  v2.8h
+.if \min
+        umin            v0.16b,  v0.16b,  v4.16b
+        umax            v0.16b,  v0.16b,  v3.16b    // iclip(px + .., min, max)
+.endif
+.if \w == 8
+        st1             {v0.d}[0], [x0], x1
+        add             x2,  x2,  #2*16             // tmp += 2*tmp_stride
+        subs            w7,  w7,  #2                // h -= 2
+        st1             {v0.d}[1], [x0], x1
+.else
+        st1             {v0.s}[0], [x0], x1
+        add             x2,  x2,  #4*8              // tmp += 4*tmp_stride
+        st1             {v0.s}[1], [x0], x1
+        subs            w7,  w7,  #4                // h -= 4
+        st1             {v0.s}[2], [x0], x1
+        st1             {v0.s}[3], [x0], x1
+.endif
+
+        // Reset pri_taps and directions back to the original point
+        sub             x5,  x5,  #2
+.if \pri
+        sub             x8,  x8,  #2
+.endif
+
+        b.gt            1b                          // flags from the subs on w7 (h)
+        ret
+endfunc
+.endm
+
+.macro filter_8 w
+filter_func_8 \w, pri=1, sec=0, min=0, suffix=_pri
+filter_func_8 \w, pri=0, sec=1, min=0, suffix=_sec
+filter_func_8 \w, pri=1, sec=1, min=1, suffix=_pri_sec
+.endm
+
+filter_8 8
+filter_8 4
--- a/src/arm/64/cdef_tmpl.S
+++ b/src/arm/64/cdef_tmpl.S
@@ -103,13 +103,18 @@
 // void dav1d_cdef_filterX_Ybpc_neon(pixel *dst, ptrdiff_t dst_stride,
 //                                   const uint16_t *tmp, int pri_strength,
 //                                   int sec_strength, int dir, int damping,
-//                                   int h);
+//                                   int h, size_t edges);
 .macro filter_func w, bpc, pri, sec, min, suffix
 function cdef_filter\w\suffix\()_\bpc\()bpc_neon
+.if \bpc == 8
+        ldr             w8,  [sp]                   // edges (9th argument, first stack slot)
+        cmp             w8,  #0xf                   // all edges present?
+        b.eq            cdef_filter\w\suffix\()_edged_8bpc_neon
+.endif
 .if \pri
 .if \bpc == 16
-        ldr             w8,  [sp]                   // bitdepth_max
-        clz             w9,  w8
+        ldr             w9,  [sp, #8]               // bitdepth_max, after the 8-byte edges slot
+        clz             w9,  w9
         sub             w9,  w9,  #24               // -bitdepth_min_8
         neg             w9,  w9                     // bitdepth_min_8
 .endif
--- a/src/arm/cdef_init_tmpl.c
+++ b/src/arm/cdef_init_tmpl.c
@@ -39,14 +39,17 @@
                                    const pixel *const top, int h,
                                    enum CdefEdgeFlags edges);
 
+// edges is passed so that the 8 bpc asm can switch to a faster variant for
+// the fully edged case (edges == 0xf). It is a size_t to avoid ABI differences
+// in packing several stack arguments; arm64 16 bpc reads bitdepth_max at [sp, #8].
 void BF(dav1d_cdef_filter4, neon)(pixel *dst, ptrdiff_t dst_stride,
                                   const uint16_t *tmp, int pri_strength,
-                                  int sec_strength, int dir, int damping, int h
-                                  HIGHBD_DECL_SUFFIX);
+                                  int sec_strength, int dir, int damping, int h,
+                                  size_t edges HIGHBD_DECL_SUFFIX);
 void BF(dav1d_cdef_filter8, neon)(pixel *dst, ptrdiff_t dst_stride,
                                   const uint16_t *tmp, int pri_strength,
-                                  int sec_strength, int dir, int damping, int h
-                                  HIGHBD_DECL_SUFFIX);
+                                  int sec_strength, int dir, int damping, int h,
+                                  size_t edges HIGHBD_DECL_SUFFIX);
 
 #define DEFINE_FILTER(w, h, tmp_stride)                                      \
 static void                                                                  \
@@ -62,7 +65,7 @@
     uint16_t *tmp = tmp_buf + 2 * tmp_stride + 8;                            \
     BF(dav1d_cdef_padding##w, neon)(tmp, dst, stride, left, top, h, edges);  \
     BF(dav1d_cdef_filter##w, neon)(dst, stride, tmp, pri_strength,           \
-                                   sec_strength, dir, damping, h             \
+                                   sec_strength, dir, damping, h, edges      \
                                    HIGHBD_TAIL_SUFFIX);                      \
 }