ref: 1d5c1a493742503838b12f108e041644d4838656
parent: 058ca08d081f0c59910daceb854d6a94b4e6f9f0
author: Martin Storsjö <martin@martin.st>
date: Sat Apr 27 20:08:02 EDT 2019
arm64: msac: Implement NEON msac_decode_symbol_adapt

                                 Cortex A53    A72    A73
msac_decode_symbol_adapt4_c:          107.6   57.1   67.8
msac_decode_symbol_adapt4_neon:        70.4   56.4   55.1
msac_decode_symbol_adapt8_c:          157.1   74.5   90.3
msac_decode_symbol_adapt8_neon:        75.6   57.2   56.9
msac_decode_symbol_adapt16_c:         257.4  106.6  135.9
msac_decode_symbol_adapt16_neon:      101.8   62.0   65.2
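For orientation before the assembly: the NEON routines evaluate every candidate symbol of a CDF in parallel instead of looping over them one at a time. The C sketch below is reconstructed from the inline comments in msac.S (the constants, struct fields and the rate formula are inferred from the shift amounts, tables and offsets used there); it is an illustration of the decode-and-adapt step, not the exact dav1d_msac_decode_symbol_adapt_c reference, and it leaves out the renormalization/refill tail, which is sketched separately after the assembly.

    #include <stddef.h>
    #include <stdint.h>

    #define EC_WIN_SIZE   64   /* dif is kept in a 64-bit window                */
    #define EC_PROB_SHIFT  6   /* implied by the "cdf >> EC_PROB_SHIFT" ushr #6 */
    #define EC_MIN_PROB    4   /* implied by the steps of 4 in the coeffs table */

    typedef struct {           /* only the fields this sketch needs; the real   */
        uint64_t dif;          /* MsacContext matches the BUF_POS/.../CNT       */
        unsigned rng;          /* offsets #defined at the top of msac.S         */
        int      cnt;
        int      allow_update_cdf;
    } MsacSketchContext;

    static unsigned decode_symbol_adapt_sketch(MsacSketchContext *s,
                                               uint16_t *cdf, size_t n_symbols)
    {
        const unsigned c = s->dif >> (EC_WIN_SIZE - 16); /* top 16 bits of dif */
        const unsigned r = s->rng >> 8;
        unsigned u, v = s->rng, ret = (unsigned)-1;

        /* The NEON code computes every candidate v at once (umull/shrn/add)
         * and derives ret from the cmhs lane mask via addv/rbit/clz. */
        do {
            ret++;
            u = v;
            v  = (r * (cdf[ret] >> EC_PROB_SHIFT)) >> (7 - EC_PROB_SHIFT);
            v += EC_MIN_PROB * (unsigned)(n_symbols - ret);
        } while (c < v);

        if (s->allow_update_cdf) {
            const unsigned count = cdf[n_symbols];
            const unsigned rate  = 4 + (count >> 4) + (n_symbols > 3);
            for (size_t i = 0; i < n_symbols; i++)    /* decoded side moves up, */
                cdf[i] = i < ret ? cdf[i] + ((32768 - cdf[i]) >> rate) /* rest down */
                                 : cdf[i] - (cdf[i] >> rate);
            cdf[n_symbols] = count + (count < 32);
        }

        /* u and v then feed the renormalization: rng = u - v, dif -= v << 48;
         * see the L(renorm) sketch after the assembly. */
        (void)u;
        return ret;
    }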
--- /dev/null
+++ b/src/arm/64/msac.S
@@ -0,0 +1,280 @@
+/*
+ * Copyright © 2019, VideoLAN and dav1d authors
+ * Copyright © 2019, Martin Storsjo
+ * All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include "src/arm/asm.S"
+#include "util.S"
+
+#define BUF_POS 0
+#define BUF_END 8
+#define DIF 16
+#define RNG 24
+#define CNT 28
+#define ALLOW_UPDATE_CDF 32
+
+const coeffs
+ .short 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4, 0
+ .short 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
+endconst
+
+const bits
+ .short 0x1, 0x2, 0x4, 0x8, 0x10, 0x20, 0x40, 0x80
+ .short 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000
+endconst
+
+.macro ld1_n d0, d1, src, sz, n
+.if \n <= 8
+ ld1 {\d0\sz}, [\src]
+.else
+ ld1 {\d0\sz, \d1\sz}, [\src]
+.endif
+.endm
+
+.macro st1_n s0, s1, dst, sz, n
+.if \n <= 8
+ st1 {\s0\sz}, [\dst]
+.else
+ st1 {\s0\sz, \s1\sz}, [\dst]
+.endif
+.endm
+
+.macro ushr_n d0, d1, s0, s1, shift, sz, n
+ ushr \d0\sz, \s0\sz, \shift
+.if \n == 16
+ ushr \d1\sz, \s1\sz, \shift
+.endif
+.endm
+
+.macro add_n d0, d1, s0, s1, s2, s3, sz, n
+ add \d0\sz, \s0\sz, \s2\sz
+.if \n == 16
+ add \d1\sz, \s1\sz, \s3\sz
+.endif
+.endm
+
+.macro sub_n d0, d1, s0, s1, s2, s3, sz, n
+ sub \d0\sz, \s0\sz, \s2\sz
+.if \n == 16
+ sub \d1\sz, \s1\sz, \s3\sz
+.endif
+.endm
+
+.macro and_n d0, d1, s0, s1, s2, s3, sz, n
+ and \d0\sz, \s0\sz, \s2\sz
+.if \n == 16
+ and \d1\sz, \s1\sz, \s3\sz
+.endif
+.endm
+
+.macro cmhs_n d0, d1, s0, s1, s2, s3, sz, n
+ cmhs \d0\sz, \s0\sz, \s2\sz
+.if \n == 16
+ cmhs \d1\sz, \s1\sz, \s3\sz
+.endif
+.endm
+
+.macro urhadd_n d0, d1, s0, s1, s2, s3, sz, n
+ urhadd \d0\sz, \s0\sz, \s2\sz
+.if \n == 16
+ urhadd \d1\sz, \s1\sz, \s3\sz
+.endif
+.endm
+
+.macro sshl_n d0, d1, s0, s1, s2, s3, sz, n
+ sshl \d0\sz, \s0\sz, \s2\sz
+.if \n == 16
+ sshl \d1\sz, \s1\sz, \s3\sz
+.endif
+.endm
+
+.macro umull_n d0, d1, d2, d3, s0, s1, s2, s3, n
+ umull \d0\().4s, \s0\().4h, \s2\().4h
+.if \n >= 8
+ umull2 \d1\().4s, \s0\().8h, \s2\().8h
+.endif
+.if \n == 16
+ umull \d2\().4s, \s1\().4h, \s3\().4h
+ umull2 \d3\().4s, \s1\().8h, \s3\().8h
+.endif
+.endm
+
+.macro shrn_n d0, d1, s0, s1, s2, s3, shift, n
+ shrn \d0\().4h, \s0\().4s, \shift
+.if \n >= 8
+ shrn2 \d0\().8h, \s1\().4s, \shift
+.endif
+.if \n == 16
+ shrn \d1\().4h, \s2\().4s, \shift
+ shrn2 \d1\().8h, \s3\().4s, \shift
+.endif
+.endm
+
+.macro str_n idx0, idx1, dstreg, dstoff, n
+ str q\idx0, [\dstreg, \dstoff]
+.if \n == 16
+ str q\idx1, [\dstreg, \dstoff + 16]
+.endif
+.endm
+
+// unsigned dav1d_msac_decode_symbol_adapt4_neon(MsacContext *s, uint16_t *cdf,
+// size_t n_symbols);
+
+function msac_decode_symbol_adapt4_neon, export=1
+.macro decode_update sz, szb, n
+ sub sp, sp, #48
+ add x8, x0, #RNG
+ ld1_n v0, v1, x1, \sz, \n // cdf
+ ld1r {v4\sz}, [x8] // rng
+ movrel x9, coeffs, 32
+ sub x9, x9, x2, lsl #1
+ ushr_n v2, v3, v0, v1, #6, \sz, \n // cdf >> EC_PROB_SHIFT
+ str h4, [sp, #14] // store original u = s->rng
+ ushr v4\sz, v4\sz, #8 // r = rng >> 8
+
+ umull_n v16, v17, v18, v19, v4, v4, v2, v3, \n // r * (cdf >> EC_PROB_SHIFT)
+ ld1_n v4, v5, x9, \sz, \n // EC_MIN_PROB * (n_symbols - ret)
+ shrn_n v2, v3, v16, v17, v18, v19, #1, \n // v >>= 7 - EC_PROB_SHIFT
+ add x8, x0, #DIF + 6
+
+ add_n v4, v5, v2, v3, v4, v5, \sz, \n // v += EC_MIN_PROB * (n_symbols - ret)
+
+ ld1r {v6.8h}, [x8] // dif >> (EC_WIN_SIZE - 16)
+ movrel x8, bits
+ str_n 4, 5, sp, #16, \n // store v values to allow indexed access
+
+ ld1_n v16, v17, x8, .8h, \n
+
+ cmhs_n v2, v3, v6, v6, v4, v5, .8h, \n // c >= v
+
+ and_n v6, v7, v2, v3, v16, v17, .16b, \n // One bit per halfword set in the mask
+.if \n == 16
+ add v6.8h, v6.8h, v7.8h
+.endif
+ addv h6, v6.8h // Aggregate mask bits
+ ldr w4, [x0, #ALLOW_UPDATE_CDF]
+ umov w3, v6.h[0]
+ rbit w3, w3
+ clz w15, w3 // ret
+
+ cbz w4, L(renorm)
+ // update_cdf
+ ldrh w3, [x1, x2, lsl #1] // count = cdf[n_symbols]
+ movi v5\szb, #0xff
+ cmp x2, #4 // set C if n_symbols >= 4 (n_symbols > 3)
+ mov w14, #4
+ lsr w4, w3, #4 // count >> 4
+ urhadd_n v4, v5, v5, v5, v2, v3, \sz, \n // i >= val ? -1 : 32768
+ adc w4, w4, w14 // (count >> 4) + (n_symbols > 3) + 4
+ neg w4, w4 // -rate
+ sub_n v4, v5, v4, v5, v0, v1, \sz, \n // (32768 - cdf[i]) or (-1 - cdf[i])
+ dup v6.8h, w4 // -rate
+
+ sub w3, w3, w3, lsr #5 // count - (count == 32)
+ sub_n v0, v1, v0, v1, v2, v3, \sz, \n // cdf + (i >= val ? 1 : 0)
+ sshl_n v4, v5, v4, v5, v6, v6, \sz, \n // ({32768,-1} - cdf[i]) >> rate
+ add w3, w3, #1 // count + (count < 32)
+ add_n v0, v1, v0, v1, v4, v5, \sz, \n // cdf + (32768 - cdf[i]) >> rate
+ st1_n v0, v1, x1, \sz, \n
+ strh w3, [x1, x2, lsl #1]
+.endm
+
+ decode_update .4h, .8b, 4
+
+L(renorm):
+ add x8, sp, #16
+ add x8, x8, w15, uxtw #1
+ ldrh w3, [x8] // v
+ ldurh w4, [x8, #-2] // u
+ ldr w6, [x0, #CNT]
+ ldr x7, [x0, #DIF]
+ sub w4, w4, w3 // rng = u - v
+ clz w5, w4 // clz(rng)
+ eor w5, w5, #16 // d = clz(rng) ^ 16
+ mvn x7, x7 // ~dif
+ add x7, x7, x3, lsl #48 // ~dif + (v << 48)
+ lsl w4, w4, w5 // rng << d
+ subs w6, w6, w5 // cnt -= d
+ lsl x7, x7, x5 // (~dif + (v << 48)) << d
+ str w4, [x0, #RNG]
+ mvn x7, x7 // ~dif
+ b.ge 9f
+
+ // refill
+ ldr x3, [x0, #BUF_POS]
+ ldr x4, [x0, #BUF_END]
+ add x5, x3, #8
+ cmp x5, x4
+ b.gt 2f
+
+ ldr x3, [x3] // next_bits
+ add w8, w6, #23 // shift_bits = cnt + 23
+ add w6, w6, #16 // cnt += 16
+ rev x3, x3 // next_bits = bswap(next_bits)
+ sub x5, x5, x8, lsr #3 // buf_pos -= shift_bits >> 3
+ and w8, w8, #24 // shift_bits &= 24
+ lsr x3, x3, x8 // next_bits >>= shift_bits
+ sub w8, w8, w6 // shift_bits -= 16 + cnt
+ str x5, [x0, #BUF_POS]
+ lsl x3, x3, x8 // next_bits <<= shift_bits
+ mov w4, #48
+ sub w6, w4, w8 // cnt = cnt + 64 - shift_bits
+ eor x7, x7, x3 // dif ^= next_bits
+ b 9f
+
+2: // refill_eob
+ mov w14, #40
+ sub w5, w14, w6 // c = 40 - cnt
+3:
+ cmp x3, x4
+ b.ge 4f
+ ldrb w8, [x3], #1
+ lsl x8, x8, x5
+ eor x7, x7, x8
+ subs w5, w5, #8
+ b.ge 3b
+
+4: // refill_eob_end
+ str x3, [x0, #BUF_POS]
+ sub w6, w14, w5 // cnt = 40 - c
+
+9:
+ str w6, [x0, #CNT]
+ str x7, [x0, #DIF]
+
+ mov w0, w15
+ add sp, sp, #48
+ ret
+endfunc
+
+function msac_decode_symbol_adapt8_neon, export=1
+ decode_update .8h, .16b, 8
+ b L(renorm)
+endfunc
+
+function msac_decode_symbol_adapt16_neon, export=1
+ decode_update .8h, .16b, 16
+ b L(renorm)
+endfunc
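The shared L(renorm)/refill tail corresponds roughly to the following scalar logic, again reconstructed from the inline comments above (MsacSketchContext is the illustrative struct from the sketch before the diff; __builtin_clz stands in for dav1d's clz helper). The mvn/add/lsl/mvn sequence is how the code subtracts v << 48 from dif and then shifts ones, rather than zeros, into the vacated low bits; the byte-exact refill, including the refill_eob path used within 8 bytes of the end of the buffer, is only summarized in a comment:

    /* Conceptual sketch of L(renorm); not the exact dav1d C helper. */
    static void renorm_sketch(MsacSketchContext *s, unsigned u, unsigned v)
    {
        const unsigned rng = u - v;                    /* rng = u - v           */
        const int      d   = __builtin_clz(rng) ^ 16;  /* shifts to renormalize */
        s->rng = rng << d;
        /* dif -= v << 48, then << d with ones shifted into the low bits:
         * ~((~dif + (v << 48)) << d) == ((dif - (v << 48)) << d) | ((1 << d) - 1) */
        s->dif = ~((~s->dif + ((uint64_t)v << 48)) << d);
        s->cnt -= d;
        if (s->cnt < 0) {
            /* refill: normally load 8 bytes at buf_pos, byte-swap them and
             * xor them into dif below the bits still in use; within 8 bytes
             * of buf_end, fall back to the byte-by-byte refill_eob loop. */
        }
    }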
--- a/src/meson.build
+++ b/src/meson.build
@@ -96,6 +96,7 @@
'arm/64/loopfilter.S',
'arm/64/looprestoration.S',
'arm/64/mc.S',
+ 'arm/64/msac.S',
)
elif host_machine.cpu_family().startswith('arm')
libdav1d_sources += files(
--- a/src/msac.h
+++ b/src/msac.h
@@ -55,7 +55,17 @@
int dav1d_msac_decode_uniform(MsacContext *s, unsigned n);
/* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
-#if ARCH_X86_64 && HAVE_ASM
+#if ARCH_AARCH64 && HAVE_ASM
+unsigned dav1d_msac_decode_symbol_adapt4_neon(MsacContext *s, uint16_t *cdf,
+ size_t n_symbols);
+unsigned dav1d_msac_decode_symbol_adapt8_neon(MsacContext *s, uint16_t *cdf,
+ size_t n_symbols);
+unsigned dav1d_msac_decode_symbol_adapt16_neon(MsacContext *s, uint16_t *cdf,
+ size_t n_symbols);
+#define dav1d_msac_decode_symbol_adapt4 dav1d_msac_decode_symbol_adapt4_neon
+#define dav1d_msac_decode_symbol_adapt8 dav1d_msac_decode_symbol_adapt8_neon
+#define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon
+#elif ARCH_X86_64 && HAVE_ASM
unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
size_t n_symbols);
unsigned dav1d_msac_decode_symbol_adapt8_sse2(MsacContext *s, uint16_t *cdf,
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -103,7 +103,13 @@
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt_c;
c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
-#if ARCH_X86_64 && HAVE_ASM
+#if ARCH_AARCH64 && HAVE_ASM
+ if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
+ c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_neon;
+ c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_neon;
+ c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_neon;
+ }
+#elif ARCH_X86_64 && HAVE_ASM
if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_SSE2) {
c.symbol_adapt4 = dav1d_msac_decode_symbol_adapt4_sse2;
c.symbol_adapt8 = dav1d_msac_decode_symbol_adapt8_sse2;