shithub: dav1d

ref: c89eb564366dfe11da016a738e4fc609f17710c1
parent: c2a2e6ee187f854a03973318cd00edbf0269725d
author: Martin Storsjö <martin@martin.st>
date: Wed Feb 5 05:43:19 EST 2020

arm64: looprestoration: NEON implementation of wiener filter for 16 bpc

Checkasm benchmarks:     Cortex A53       A72       A73
wiener_chroma_16bpc_neon:  190288.4  129369.5  127284.1
wiener_luma_16bpc_neon:    195108.4  129387.8  127042.7

The corresponding numbers for 8 bpc for comparison:
wiener_chroma_8bpc_neon:   150586.9  101647.1   97709.9
wiener_luma_8bpc_neon:     146297.4  101593.2   97670.5
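
For context, here is a rough scalar model of what the two NEON functions in the
patch compute per output pixel, pieced together from the constants and comments
in the assembly (the 1 << (bitdepth + 6) bias, the 8192 subtraction, the rounding
shifts, and the +128 centre-tap handling). The helper names are hypothetical and
the shift amounts assume the 10/12 bpc cases this path is built for; it is a
reading aid, not dav1d's reference implementation.

#include <stddef.h>
#include <stdint.h>

static inline int iclip(const int v, const int lo, const int hi) {
    return v < lo ? lo : v > hi ? hi : v;
}

// One output of the horizontal pass: 7-tap filter plus a 128x centre term,
// biased by 1 << (bitdepth + 6), rounded down by round_bits_h, clipped to
// [0, 0x7fff], then stored with round_offset precompensated (minus 8192).
static int16_t wiener_h_one(const uint16_t *const src, const int16_t fh[7],
                            const int bitdepth /* 10 or 12 */) {
    const int round_bits_h = bitdepth - 7; // 3 for 10 bpc, 5 for 12 bpc
    int sum = (1 << (bitdepth + 6)) + 128 * src[3];
    for (int k = 0; k < 7; k++)
        sum += src[k] * fh[k];
    sum = (sum + (1 << (round_bits_h - 1))) >> round_bits_h;
    return (int16_t)(iclip(sum, 0, 0x7fff) - 8192);
}

// One output of the vertical pass: 7-tap filter over the mid rows with the
// centre tap raised by 128; the reference round_offset is already
// precompensated in the mid values (see the comment in
// looprestoration_init_tmpl.c), so only the rounding shift and clip remain.
static uint16_t wiener_v_one(const int16_t *const mid, const ptrdiff_t mid_stride,
                             const int16_t fv[7],
                             const int bitdepth /* 10 or 12 */) {
    const int round_bits_v = 21 - bitdepth; // 11 for 10 bpc, 9 for 12 bpc
    int sum = 0;
    for (int k = 0; k < 7; k++)
        sum += mid[k * mid_stride] * (fv[k] + (k == 3 ? 128 : 0));
    sum = (sum + (1 << (round_bits_v - 1))) >> round_bits_v;
    return (uint16_t)iclip(sum, 0, (1 << bitdepth) - 1);
}

checkasm verifies the NEON functions against dav1d's C reference, so the
bit-exact behaviour is defined there; the sketch above only mirrors the
arithmetic the assembly performs for a single output pixel.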

--- /dev/null
+++ b/src/arm/64/looprestoration16.S
@@ -1,0 +1,680 @@
+/*
+ * Copyright © 2018, VideoLAN and dav1d authors
+ * Copyright © 2020, Martin Storsjo
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ *    list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ *    this list of conditions and the following disclaimer in the documentation
+ *    and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include "src/arm/asm.S"
+#include "util.S"
+
+// void dav1d_wiener_filter_h_16bpc_neon(int16_t *dst, const pixel (*left)[4],
+//                                       const pixel *src, ptrdiff_t stride,
+//                                       const int16_t fh[7], const intptr_t w,
+//                                       int h, enum LrEdgeFlags edges,
+//                                       const int bitdepth_max);
+function wiener_filter_h_16bpc_neon, export=1
+        ldr             w8,  [sp]      // bitdepth_max
+        ld1             {v0.8h},  [x4]
+        clz             w8,  w8
+        movi            v30.4s,  #1
+        sub             w9,  w8,  #38  // -(bitdepth + 6)
+        sub             w8,  w8,  #25  // -round_bits_h
+        neg             w9,  w9        // bitdepth + 6
+        dup             v1.4s,   w9
+        dup             v29.4s,  w8    // -round_bits_h
+        movi            v31.8h,  #0x20, lsl #8  // 1 << 13 = 8192
+        ushl            v30.4s,  v30.4s,  v1.4s // 1 << (bitdepth + 6)
+        mov             w8,  w5
+        // Calculate mid_stride
+        add             w10, w5,  #7
+        bic             w10, w10, #7
+        lsl             w10, w10, #1
+
+        // Clear the last unused element of v0, to allow filtering a single
+        // pixel with one plain mul+addv.
+        ins             v0.h[7], wzr
+
+        // Set up pointers for reading/writing alternate rows
+        add             x12, x0,  x10
+        lsl             w10, w10, #1
+        add             x13, x2,  x3
+        lsl             x3,  x3,  #1
+
+        // Subtract the width from mid_stride
+        sub             x10, x10, w5, uxtw #1
+
+        // For w >= 8, we read (w+5)&~7+8 pixels, for w < 8 we read 16 pixels.
+        cmp             w5,  #8
+        add             w11, w5,  #13
+        bic             w11, w11, #7
+        b.ge            1f
+        mov             w11, #16
+1:
+        sub             x3,  x3,  w11, uxtw #1
+
+        // Set up the src pointers to include the left edge, for LR_HAVE_LEFT, left == NULL
+        tst             w7,  #1 // LR_HAVE_LEFT
+        b.eq            2f
+        // LR_HAVE_LEFT
+        cbnz            x1,  0f
+        // left == NULL
+        sub             x2,  x2,  #6
+        sub             x13, x13, #6
+        b               1f
+0:      // LR_HAVE_LEFT, left != NULL
+2:      // !LR_HAVE_LEFT, increase the stride.
+        // For this case we don't read the left 3 pixels from the src pointer,
+        // but shift it as if we had done that.
+        add             x3,  x3,  #6
+
+1:      // Loop vertically
+        ld1             {v2.8h, v3.8h},  [x2],  #32
+        ld1             {v4.8h, v5.8h},  [x13], #32
+
+        tst             w7,  #1 // LR_HAVE_LEFT
+        b.eq            0f
+        cbz             x1,  2f
+        // LR_HAVE_LEFT, left != NULL
+        ld1             {v1.d}[1],  [x1], #8
+        // Move x2/x13 back to account for the last 3 pixels we loaded earlier,
+        // which we'll shift out.
+        sub             x2,  x2,  #6
+        sub             x13, x13, #6
+        ld1             {v6.d}[1],  [x1], #8
+        ext             v3.16b,  v2.16b,  v3.16b,  #10
+        ext             v2.16b,  v1.16b,  v2.16b,  #10
+        ext             v5.16b,  v4.16b,  v5.16b,  #10
+        ext             v4.16b,  v6.16b,  v4.16b,  #10
+        b               2f
+0:
+        // !LR_HAVE_LEFT, fill v1 with the leftmost pixel
+        // and shift v2/v3 to have 3x the first pixel at the front.
+        dup             v1.8h,   v2.h[0]
+        dup             v6.8h,   v4.h[0]
+        // Move x2 back to account for the last 3 pixels we loaded before,
+        // which we shifted out.
+        sub             x2,  x2,  #6
+        sub             x13, x13, #6
+        ext             v3.16b,  v2.16b,  v3.16b,  #10
+        ext             v2.16b,  v1.16b,  v2.16b,  #10
+        ext             v5.16b,  v4.16b,  v5.16b,  #10
+        ext             v4.16b,  v6.16b,  v4.16b,  #10
+
+2:
+
+        tst             w7,  #2 // LR_HAVE_RIGHT
+        b.ne            4f
+        // If we'll need to pad the right edge, load that pixel to pad with
+        // here since we can find it pretty easily from here.
+        sub             w9,  w5,  #14
+        ldr             h27, [x2,  w9, sxtw #1]
+        ldr             h28, [x13, w9, sxtw #1]
+        // Fill v27/v28 with the right padding pixel
+        dup             v27.8h,  v27.h[0]
+        dup             v28.8h,  v28.h[0]
+3:      // !LR_HAVE_RIGHT
+        // If we'll have to pad the right edge we need to quit early here.
+        cmp             w5,  #11
+        b.ge            4f   // If w >= 11, all used input pixels are valid
+        cmp             w5,  #7
+        b.ge            5f   // If w >= 7, we can filter 4 pixels
+        b               6f
+
+4:      // Loop horizontally
+.macro ushll_sz d0, d1, src, shift, wd
+        ushll           \d0\().4s,  \src\().4h,  \shift
+.ifc \wd, .8h
+        ushll2          \d1\().4s,  \src\().8h,  \shift
+.endif
+.endm
+.macro add_sz d0, d1, s0, s1, c, wd
+        add             \d0\().4s,  \s0\().4s,   \c\().4s
+.ifc \wd, .8h
+        add             \d1\().4s,  \s1\().4s,   \c\().4s
+.endif
+.endm
+.macro srshl_sz d0, d1, s0, s1, c, wd
+        srshl           \d0\().4s,  \s0\().4s,   \c\().4s
+.ifc \wd, .8h
+        srshl           \d1\().4s,  \s1\().4s,   \c\().4s
+.endif
+.endm
+.macro sqxtun_sz dst, s0, s1, wd
+        sqxtun          \dst\().4h, \s0\().4s
+.ifc \wd, .8h
+        sqxtun2         \dst\().8h, \s1\().4s
+.endif
+.endm
+
+.macro filter wd
+        // Interleaving the mul/mla chains actually hurts performance
+        // significantly on Cortex A53, thus keeping mul/mla tightly
+        // chained like this.
+        ext             v16.16b, v2.16b,  v3.16b, #2
+        ext             v17.16b, v2.16b,  v3.16b, #4
+        ext             v18.16b, v2.16b,  v3.16b, #6
+        ext             v19.16b, v2.16b,  v3.16b, #8
+        ext             v20.16b, v2.16b,  v3.16b, #10
+        ext             v21.16b, v2.16b,  v3.16b, #12
+        ushll_sz        v6,  v7,  v18, #7, \wd
+        smlal           v6.4s,   v2.4h,   v0.h[0]
+        smlal           v6.4s,   v16.4h,  v0.h[1]
+        smlal           v6.4s,   v17.4h,  v0.h[2]
+        smlal           v6.4s,   v18.4h,  v0.h[3]
+        smlal           v6.4s,   v19.4h,  v0.h[4]
+        smlal           v6.4s,   v20.4h,  v0.h[5]
+        smlal           v6.4s,   v21.4h,  v0.h[6]
+.ifc \wd, .8h
+        smlal2          v7.4s,   v2.8h,   v0.h[0]
+        smlal2          v7.4s,   v16.8h,  v0.h[1]
+        smlal2          v7.4s,   v17.8h,  v0.h[2]
+        smlal2          v7.4s,   v18.8h,  v0.h[3]
+        smlal2          v7.4s,   v19.8h,  v0.h[4]
+        smlal2          v7.4s,   v20.8h,  v0.h[5]
+        smlal2          v7.4s,   v21.8h,  v0.h[6]
+.endif
+        ext             v19.16b, v4.16b,  v5.16b, #2
+        ext             v20.16b, v4.16b,  v5.16b, #4
+        ext             v21.16b, v4.16b,  v5.16b, #6
+        ext             v22.16b, v4.16b,  v5.16b, #8
+        ext             v23.16b, v4.16b,  v5.16b, #10
+        ext             v24.16b, v4.16b,  v5.16b, #12
+        ushll_sz        v16, v17, v21, #7, \wd
+        smlal           v16.4s,  v4.4h,   v0.h[0]
+        smlal           v16.4s,  v19.4h,  v0.h[1]
+        smlal           v16.4s,  v20.4h,  v0.h[2]
+        smlal           v16.4s,  v21.4h,  v0.h[3]
+        smlal           v16.4s,  v22.4h,  v0.h[4]
+        smlal           v16.4s,  v23.4h,  v0.h[5]
+        smlal           v16.4s,  v24.4h,  v0.h[6]
+.ifc \wd, .8h
+        smlal2          v17.4s,  v4.8h,   v0.h[0]
+        smlal2          v17.4s,  v19.8h,  v0.h[1]
+        smlal2          v17.4s,  v20.8h,  v0.h[2]
+        smlal2          v17.4s,  v21.8h,  v0.h[3]
+        smlal2          v17.4s,  v22.8h,  v0.h[4]
+        smlal2          v17.4s,  v23.8h,  v0.h[5]
+        smlal2          v17.4s,  v24.8h,  v0.h[6]
+.endif
+        mvni            v24\wd,  #0x80, lsl #8 // 0x7fff = (1 << 15) - 1
+        add_sz          v6,  v7,  v6,  v7,  v30, \wd
+        add_sz          v16, v17, v16, v17, v30, \wd
+        srshl_sz        v6,  v7,  v6,  v7,  v29, \wd
+        srshl_sz        v16, v17, v16, v17, v29, \wd
+        sqxtun_sz       v6,  v6,  v7,  \wd
+        sqxtun_sz       v7,  v16, v17, \wd
+        umin            v6\wd,   v6\wd,   v24\wd
+        umin            v7\wd,   v7\wd,   v24\wd
+        sub             v6\wd,   v6\wd,   v31\wd
+        sub             v7\wd,   v7\wd,   v31\wd
+.endm
+        filter          .8h
+        st1             {v6.8h},  [x0],  #16
+        st1             {v7.8h},  [x12], #16
+
+        subs            w5,  w5,  #8
+        b.le            9f
+        tst             w7,  #2 // LR_HAVE_RIGHT
+        mov             v2.16b,  v3.16b
+        mov             v4.16b,  v5.16b
+        ld1             {v3.8h},  [x2],  #16
+        ld1             {v5.8h},  [x13], #16
+        b.ne            4b // If we don't need to pad, just keep filtering.
+        b               3b // If we need to pad, check how many pixels we have left.
+
+5:      // Filter 4 pixels, 7 <= w < 11
+        filter          .4h
+        st1             {v6.4h},  [x0],  #8
+        st1             {v7.4h},  [x12], #8
+
+        subs            w5,  w5,  #4 // 3 <= w < 7
+        ext             v2.16b,  v2.16b,  v3.16b, #8
+        ext             v3.16b,  v3.16b,  v3.16b, #8
+        ext             v4.16b,  v4.16b,  v5.16b, #8
+        ext             v5.16b,  v5.16b,  v5.16b, #8
+
+6:      // Pad the right edge and filter the last few pixels.
+        // w < 7, w+3 pixels valid in v2-v3
+        cmp             w5,  #5
+        b.lt            7f
+        b.gt            8f
+        // w == 5, 8 pixels valid in v2, v3 invalid
+        mov             v3.16b,  v27.16b
+        mov             v5.16b,  v28.16b
+        b               88f
+
+7:      // 1 <= w < 5, 4-7 pixels valid in v2
+        sub             w9,  w5,  #1
+        // w9 = (pixels valid - 4)
+        adr             x11, L(variable_shift_tbl)
+        ldrh            w9,  [x11, w9, uxtw #1]
+        sub             x11, x11, w9, uxth
+        mov             v3.16b,  v27.16b
+        mov             v5.16b,  v28.16b
+        br              x11
+44:     // 4 pixels valid in v2/v4, fill the high half with padding.
+        ins             v2.d[1], v3.d[0]
+        ins             v4.d[1], v5.d[0]
+        b               88f
+        // Shift v2 right, shifting out invalid pixels,
+        // shift v2 left to the original offset, shifting in padding pixels.
+55:     // 5 pixels valid
+        ext             v2.16b,  v2.16b,  v2.16b,  #10
+        ext             v2.16b,  v2.16b,  v3.16b,  #6
+        ext             v4.16b,  v4.16b,  v4.16b,  #10
+        ext             v4.16b,  v4.16b,  v5.16b,  #6
+        b               88f
+66:     // 6 pixels valid, fill the upper 2 pixels with padding.
+        ins             v2.s[3], v3.s[0]
+        ins             v4.s[3], v5.s[0]
+        b               88f
+77:     // 7 pixels valid, fill the last pixel with padding.
+        ins             v2.h[7], v3.h[0]
+        ins             v4.h[7], v5.h[0]
+        b               88f
+
+L(variable_shift_tbl):
+        .hword L(variable_shift_tbl) - 44b
+        .hword L(variable_shift_tbl) - 55b
+        .hword L(variable_shift_tbl) - 66b
+        .hword L(variable_shift_tbl) - 77b
+
+8:      // w > 5, w == 6, 9 pixels valid in v2-v3, 1 pixel valid in v3
+        ins             v27.h[0],  v3.h[0]
+        ins             v28.h[0],  v5.h[0]
+        mov             v3.16b,  v27.16b
+        mov             v5.16b,  v28.16b
+
+88:
+        // w < 7, v2-v3 padded properly
+        cmp             w5,  #4
+        b.lt            888f
+
+        // w >= 4, filter 4 pixels
+        filter          .4h
+        st1             {v6.4h},  [x0],  #8
+        st1             {v7.4h},  [x12], #8
+        subs            w5,  w5,  #4 // 0 <= w < 4
+        ext             v2.16b,  v2.16b,  v3.16b, #8
+        ext             v4.16b,  v4.16b,  v5.16b, #8
+        b.eq            9f
+888:    // 1 <= w < 4, filter 1 pixel at a time
+        smull           v6.4s,   v2.4h,   v0.4h
+        smull2          v7.4s,   v2.8h,   v0.8h
+        smull           v16.4s,  v4.4h,   v0.4h
+        smull2          v17.4s,  v4.8h,   v0.8h
+        add             v6.4s,   v6.4s,   v7.4s
+        add             v16.4s,  v16.4s,  v17.4s
+        addv            s6,      v6.4s
+        addv            s7,      v16.4s
+        dup             v16.4h,  v2.h[3]
+        ins             v16.h[1], v4.h[3]
+        ins             v6.s[1], v7.s[0]
+        mvni            v24.4h,  #0x80, lsl #8 // 0x7fff = (1 << 15) - 1
+        ushll           v16.4s,  v16.4h,  #7
+        add             v6.4s,   v6.4s,   v30.4s
+        add             v6.4s,   v6.4s,   v16.4s
+        srshl           v6.4s,   v6.4s,   v29.4s
+        sqxtun          v6.4h,   v6.4s
+        umin            v6.4h,   v6.4h,   v24.4h
+        sub             v6.4h,   v6.4h,   v31.4h
+        st1             {v6.h}[0], [x0],  #2
+        st1             {v6.h}[1], [x12], #2
+        subs            w5,  w5,  #1
+        ext             v2.16b,  v2.16b,  v3.16b,  #2
+        ext             v4.16b,  v4.16b,  v5.16b,  #2
+        b.gt            888b
+
+9:
+        subs            w6,  w6,  #2
+        b.le            0f
+        // Jump to the next row and loop horizontally
+        add             x0,  x0,  x10
+        add             x12, x12, x10
+        add             x2,  x2,  x3
+        add             x13, x13, x3
+        mov             w5,  w8
+        b               1b
+0:
+        ret
+.purgem filter
+endfunc
+
+// void dav1d_wiener_filter_v_16bpc_neon(pixel *dst, ptrdiff_t stride,
+//                                       const int16_t *mid, int w, int h,
+//                                       const int16_t fv[7], enum LrEdgeFlags edges,
+//                                       ptrdiff_t mid_stride, const int bitdepth_max);
+function wiener_filter_v_16bpc_neon, export=1
+        ldr             w8,  [sp]       // bitdepth_max
+        ld1             {v0.8h},  [x5]
+        dup             v31.8h,  w8
+        clz             w8,  w8
+        movi            v1.8h,   #128
+        sub             w8,  w8,  #11   // round_bits_v
+        add             v1.8h,   v1.8h,   v0.8h
+        dup             v30.4s,  w8
+        mov             w8,  w4
+        neg             v30.4s,  v30.4s // -round_bits_v
+
+        // Calculate the number of rows to move back when looping vertically
+        mov             w11, w4
+        tst             w6,  #4 // LR_HAVE_TOP
+        b.eq            0f
+        sub             x2,  x2,  x7,  lsl #1
+        add             w11, w11, #2
+0:
+        tst             w6,  #8 // LR_HAVE_BOTTOM
+        b.eq            1f
+        add             w11, w11, #2
+
+1:      // Start of horizontal loop; start one vertical filter slice.
+        // Load rows into v16-v19 and pad properly.
+        tst             w6,  #4 // LR_HAVE_TOP
+        ld1             {v16.8h}, [x2], x7
+        b.eq            2f
+        // LR_HAVE_TOP
+        ld1             {v18.8h}, [x2], x7
+        mov             v17.16b, v16.16b
+        ld1             {v19.8h}, [x2], x7
+        b               3f
+2:      // !LR_HAVE_TOP
+        mov             v17.16b, v16.16b
+        mov             v18.16b, v16.16b
+        mov             v19.16b, v16.16b
+
+3:
+        cmp             w4,  #4
+        b.lt            5f
+        // Start filtering normally; fill in v20-v22 with unique rows.
+        ld1             {v20.8h}, [x2], x7
+        ld1             {v21.8h}, [x2], x7
+        ld1             {v22.8h}, [x2], x7
+
+4:
+.macro filter compare
+        subs            w4,  w4,  #1
+        // Interleaving the mul/mla chains actually hurts performance
+        // significantly on Cortex A53, thus keeping mul/mla tightly
+        // chained like this.
+        smull           v2.4s,  v16.4h,  v0.h[0]
+        smlal           v2.4s,  v17.4h,  v0.h[1]
+        smlal           v2.4s,  v18.4h,  v0.h[2]
+        smlal           v2.4s,  v19.4h,  v1.h[3]
+        smlal           v2.4s,  v20.4h,  v0.h[4]
+        smlal           v2.4s,  v21.4h,  v0.h[5]
+        smlal           v2.4s,  v22.4h,  v0.h[6]
+        smull2          v3.4s,  v16.8h,  v0.h[0]
+        smlal2          v3.4s,  v17.8h,  v0.h[1]
+        smlal2          v3.4s,  v18.8h,  v0.h[2]
+        smlal2          v3.4s,  v19.8h,  v1.h[3]
+        smlal2          v3.4s,  v20.8h,  v0.h[4]
+        smlal2          v3.4s,  v21.8h,  v0.h[5]
+        smlal2          v3.4s,  v22.8h,  v0.h[6]
+        srshl           v2.4s,  v2.4s,   v30.4s // round_bits_v
+        srshl           v3.4s,  v3.4s,   v30.4s
+        sqxtun          v2.4h,  v2.4s
+        sqxtun2         v2.8h,  v3.4s
+        umin            v2.8h,  v2.8h,   v31.8h // bitdepth_max
+        st1             {v2.8h}, [x0], x1
+.if \compare
+        cmp             w4,  #4
+.else
+        b.le            9f
+.endif
+        mov             v16.16b,  v17.16b
+        mov             v17.16b,  v18.16b
+        mov             v18.16b,  v19.16b
+        mov             v19.16b,  v20.16b
+        mov             v20.16b,  v21.16b
+        mov             v21.16b,  v22.16b
+.endm
+        filter          1
+        b.lt            7f
+        ld1             {v22.8h}, [x2], x7
+        b               4b
+
+5:      // Less than 4 rows in total; not all of v20-v21 are filled yet.
+        tst             w6,  #8 // LR_HAVE_BOTTOM
+        b.eq            6f
+        // LR_HAVE_BOTTOM
+        cmp             w4,  #2
+        // We load at least 2 rows in all cases.
+        ld1             {v20.8h}, [x2], x7
+        ld1             {v21.8h}, [x2], x7
+        b.gt            53f // 3 rows in total
+        b.eq            52f // 2 rows in total
+51:     // 1 row in total, v19 already loaded, load edge into v20-v22.
+        mov             v22.16b,  v21.16b
+        b               8f
+52:     // 2 rows in total, v19 already loaded, load v20 with content data
+        // and 2 rows of edge.
+        ld1             {v22.8h}, [x2], x7
+        mov             v23.16b,  v22.16b
+        b               8f
+53:
+        // 3 rows in total, v19 already loaded, load v20 and v21 with content
+        // and 2 rows of edge.
+        ld1             {v22.8h}, [x2], x7
+        ld1             {v23.8h}, [x2], x7
+        mov             v24.16b,  v23.16b
+        b               8f
+
+6:
+        // !LR_HAVE_BOTTOM
+        cmp             w4,  #2
+        b.gt            63f // 3 rows in total
+        b.eq            62f // 2 rows in total
+61:     // 1 row in total, v19 already loaded, pad that into v20-v22.
+        mov             v20.16b,  v19.16b
+        mov             v21.16b,  v19.16b
+        mov             v22.16b,  v19.16b
+        b               8f
+62:     // 2 rows in total, v19 already loaded, load v20 and pad that into v21-v23.
+        ld1             {v20.8h}, [x2], x7
+        mov             v21.16b,  v20.16b
+        mov             v22.16b,  v20.16b
+        mov             v23.16b,  v20.16b
+        b               8f
+63:
+        // 3 rows in total, v19 already loaded, load v20 and v21 and pad v21 into v22-v24.
+        ld1             {v20.8h}, [x2], x7
+        ld1             {v21.8h}, [x2], x7
+        mov             v22.16b,  v21.16b
+        mov             v23.16b,  v21.16b
+        mov             v24.16b,  v21.16b
+        b               8f
+
+7:
+        // All registers up to v21 are filled already, 3 valid rows left.
+        // < 4 valid rows left; fill in padding and filter the last
+        // few rows.
+        tst             w6,  #8 // LR_HAVE_BOTTOM
+        b.eq            71f
+        // LR_HAVE_BOTTOM; load 2 rows of edge.
+        ld1             {v22.8h}, [x2], x7
+        ld1             {v23.8h}, [x2], x7
+        mov             v24.16b,  v23.16b
+        b               8f
+71:
+        // !LR_HAVE_BOTTOM, pad 3 rows
+        mov             v22.16b,  v21.16b
+        mov             v23.16b,  v21.16b
+        mov             v24.16b,  v21.16b
+
+8:      // At this point, all registers up to v22-v24 are loaded with
+        // edge/padding (depending on how many rows are left).
+        filter          0 // This branches to 9f when done
+        mov             v22.16b,  v23.16b
+        mov             v23.16b,  v24.16b
+        b               8b
+
+9:      // End of one vertical slice.
+        subs            w3,  w3,  #8
+        b.le            0f
+        // Move pointers back up to the top and loop horizontally.
+        msub            x0,  x1,  x8,  x0
+        msub            x2,  x7,  x11, x2
+        add             x0,  x0,  #16
+        add             x2,  x2,  #16
+        mov             w4,  w8
+        b               1b
+
+0:
+        ret
+.purgem filter
+endfunc
+
+// void dav1d_copy_narrow_16bpc_neon(pixel *dst, ptrdiff_t stride,
+//                                   const pixel *src, int w, int h);
+function copy_narrow_16bpc_neon, export=1
+        adr             x5,  L(copy_narrow_tbl)
+        ldrh            w6,  [x5, w3, uxtw #1]
+        sub             x5,  x5,  w6, uxth
+        br              x5
+10:
+        add             x7,  x0,  x1
+        lsl             x1,  x1,  #1
+18:
+        subs            w4,  w4,  #8
+        b.lt            110f
+        ld1             {v0.8h}, [x2], #16
+        st1             {v0.h}[0], [x0], x1
+        st1             {v0.h}[1], [x7], x1
+        st1             {v0.h}[2], [x0], x1
+        st1             {v0.h}[3], [x7], x1
+        st1             {v0.h}[4], [x0], x1
+        st1             {v0.h}[5], [x7], x1
+        st1             {v0.h}[6], [x0], x1
+        st1             {v0.h}[7], [x7], x1
+        b.le            0f
+        b               18b
+110:
+        add             w4,  w4,  #8
+        asr             x1,  x1,  #1
+11:
+        subs            w4,  w4,  #1
+        ld1             {v0.h}[0], [x2], #2
+        st1             {v0.h}[0], [x0], x1
+        b.gt            11b
+0:
+        ret
+
+20:
+        add             x7,  x0,  x1
+        lsl             x1,  x1,  #1
+24:
+        subs            w4,  w4,  #4
+        b.lt            210f
+        ld1             {v0.4s}, [x2], #16
+        st1             {v0.s}[0], [x0], x1
+        st1             {v0.s}[1], [x7], x1
+        st1             {v0.s}[2], [x0], x1
+        st1             {v0.s}[3], [x7], x1
+        b.le            0f
+        b               24b
+210:
+        add             w4,  w4,  #4
+        asr             x1,  x1,  #1
+22:
+        subs            w4,  w4,  #1
+        ld1             {v0.s}[0], [x2], #4
+        st1             {v0.s}[0], [x0], x1
+        b.gt            22b
+0:
+        ret
+
+30:
+        ldr             w5,  [x2]
+        ldrh            w6,  [x2, #4]
+        add             x2,  x2,  #6
+        subs            w4,  w4,  #1
+        str             w5,  [x0]
+        strh            w6,  [x0, #4]
+        add             x0,  x0,  x1
+        b.gt            30b
+        ret
+
+40:
+        add             x7,  x0,  x1
+        lsl             x1,  x1,  #1
+42:
+        subs            w4,  w4,  #2
+        b.lt            41f
+        ld1             {v0.2d}, [x2], #16
+        st1             {v0.d}[0], [x0], x1
+        st1             {v0.d}[1], [x7], x1
+        b.le            0f
+        b               42b
+41:
+        ld1             {v0.4h}, [x2]
+        st1             {v0.4h}, [x0]
+0:
+        ret
+
+50:
+        ldr             x5,  [x2]
+        ldrh            w6,  [x2, #8]
+        add             x2,  x2,  #10
+        subs            w4,  w4,  #1
+        str             x5,  [x0]
+        strh            w6,  [x0, #8]
+        add             x0,  x0,  x1
+        b.gt            50b
+        ret
+
+60:
+        ldr             x5,  [x2]
+        ldr             w6,  [x2, #8]
+        add             x2,  x2,  #12
+        subs            w4,  w4,  #1
+        str             x5,  [x0]
+        str             w6,  [x0, #8]
+        add             x0,  x0,  x1
+        b.gt            60b
+        ret
+
+70:
+        ldr             x5,  [x2]
+        ldr             w6,  [x2, #8]
+        ldrh            w7,  [x2, #12]
+        add             x2,  x2,  #14
+        subs            w4,  w4,  #1
+        str             x5,  [x0]
+        str             w6,  [x0, #8]
+        strh            w7,  [x0, #12]
+        add             x0,  x0,  x1
+        b.gt            70b
+        ret
+
+L(copy_narrow_tbl):
+        .hword 0
+        .hword L(copy_narrow_tbl) - 10b
+        .hword L(copy_narrow_tbl) - 20b
+        .hword L(copy_narrow_tbl) - 30b
+        .hword L(copy_narrow_tbl) - 40b
+        .hword L(copy_narrow_tbl) - 50b
+        .hword L(copy_narrow_tbl) - 60b
+        .hword L(copy_narrow_tbl) - 70b
+endfunc
--- a/src/arm/looprestoration_init_tmpl.c
+++ b/src/arm/looprestoration_init_tmpl.c
@@ -29,9 +29,9 @@
 #include "src/looprestoration.h"
 #include "src/tables.h"
 
-#if BITDEPTH == 8
-// This calculates things slightly differently than the reference C version.
-// This version calculates roughly this:
+#if BITDEPTH == 8 || ARCH_AARCH64
+// The 8bpc version calculates things slightly differently than the reference
+// C version. That version calculates roughly this:
 // int16_t sum = 0;
 // for (int i = 0; i < 7; i++)
 //     sum += src[idx] * fh[i];
@@ -41,6 +41,9 @@
 // Compared to the reference C version, this is the output of the first pass
 // _subtracted_ by 1 << (bitdepth + 6 - round_bits_h) = 2048, i.e.
 // with round_offset precompensated.
+// The 16bpc version calculates things pretty much the same way as the
+// reference C version, but with the end result subtracted by
+// 1 << (bitdepth + 6 - round_bits_h).
 void BF(dav1d_wiener_filter_h, neon)(int16_t *dst, const pixel (*left)[4],
                                      const pixel *src, ptrdiff_t stride,
                                      const int16_t fh[7], const intptr_t w,
@@ -101,7 +104,9 @@
         BF(dav1d_copy_narrow, neon)(dst + (w & ~7), dst_stride, tmp, w & 7, h);
     }
 }
+#endif
 
+#if BITDEPTH == 8
 void dav1d_sgr_box3_h_neon(int32_t *sumsq, int16_t *sum,
                            const pixel (*left)[4],
                            const pixel *src, const ptrdiff_t stride,
@@ -270,8 +275,10 @@
 
     if (!(flags & DAV1D_ARM_CPU_FLAG_NEON)) return;
 
-#if BITDEPTH == 8
+#if BITDEPTH == 8 || ARCH_AARCH64
     c->wiener = wiener_filter_neon;
+#endif
+#if BITDEPTH == 8
     c->selfguided = sgr_filter_neon;
 #endif
 }
--- a/src/meson.build
+++ b/src/meson.build
@@ -118,6 +118,7 @@
 
             if dav1d_bitdepths.contains('16')
                 libdav1d_sources += files(
+                    'arm/64/looprestoration16.S',
                     'arm/64/mc16.S',
                 )
             endif