shithub: dav1d

Download patch

ref: e16e2726e8c1019bfdd62b73caa8cba255b7ee1c
parent: e25ed5550ef7304e8324fa2de981522e7cb14ec4
author: Henrik Gramner <gramner@twoorioles.com>
date: Tue May 14 16:13:06 EDT 2019

x86-64: Add msac_decode_bool and msac_decode_bool_adapt asm

--- a/src/msac.c
+++ b/src/msac.c
@@ -85,7 +85,7 @@
 /* Decode a single binary value.
  * f: The probability that the bit is one
  * Return: The value decoded (0 or 1). */
-unsigned dav1d_msac_decode_bool(MsacContext *const s, const unsigned f) {
+unsigned dav1d_msac_decode_bool_c(MsacContext *const s, const unsigned f) {
     ec_win vw, dif = s->dif;
     unsigned ret, v, r = s->rng;
     assert((dif >> (EC_WIN_SIZE - 16)) < r);
@@ -155,8 +155,8 @@
     return val;
 }
 
-unsigned dav1d_msac_decode_bool_adapt(MsacContext *const s,
-                                      uint16_t *const cdf)
+unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *const s,
+                                        uint16_t *const cdf)
 {
     const unsigned bit = dav1d_msac_decode_bool(s, *cdf);
 
@@ -164,11 +164,10 @@
         // update_cdf() specialized for boolean CDFs
         const unsigned count = cdf[1];
         const int rate = (count >> 4) | 4;
-        if (bit) {
+        if (bit)
             cdf[0] += (32768 - cdf[0]) >> rate;
-        } else {
+        else
             cdf[0] -= cdf[0] >> rate;
-        }
         cdf[1] = count + (count < 32);
     }
 
--- a/src/msac.h
+++ b/src/msac.h
@@ -48,9 +48,9 @@
                      int disable_cdf_update_flag);
 unsigned dav1d_msac_decode_symbol_adapt_c(MsacContext *s, uint16_t *cdf,
                                           size_t n_symbols);
+unsigned dav1d_msac_decode_bool_adapt_c(MsacContext *s, uint16_t *cdf);
 unsigned dav1d_msac_decode_bool_equi_c(MsacContext *s);
-unsigned dav1d_msac_decode_bool(MsacContext *s, unsigned f);
-unsigned dav1d_msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
+unsigned dav1d_msac_decode_bool_c(MsacContext *s, unsigned f);
 int dav1d_msac_decode_subexp(MsacContext *s, int ref, int n, unsigned k);
 
 /* Supported n_symbols ranges: adapt4: 1-5, adapt8: 1-8, adapt16: 4-16 */
@@ -64,7 +64,9 @@
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt4_neon
 #define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt8_neon
 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_neon
+#define dav1d_msac_decode_bool_adapt     dav1d_msac_decode_bool_adapt_c
 #define dav1d_msac_decode_bool_equi      dav1d_msac_decode_bool_equi_c
+#define dav1d_msac_decode_bool           dav1d_msac_decode_bool_c
 #elif ARCH_X86_64 && HAVE_ASM
 unsigned dav1d_msac_decode_symbol_adapt4_sse2(MsacContext *s, uint16_t *cdf,
                                               size_t n_symbols);
@@ -72,16 +74,22 @@
                                               size_t n_symbols);
 unsigned dav1d_msac_decode_symbol_adapt16_sse2(MsacContext *s, uint16_t *cdf,
                                                size_t n_symbols);
+unsigned dav1d_msac_decode_bool_adapt_sse2(MsacContext *s, uint16_t *cdf);
 unsigned dav1d_msac_decode_bool_equi_sse2(MsacContext *s);
+unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f);
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt4_sse2
 #define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt8_sse2
 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
+#define dav1d_msac_decode_bool_adapt     dav1d_msac_decode_bool_adapt_sse2
 #define dav1d_msac_decode_bool_equi      dav1d_msac_decode_bool_equi_sse2
+#define dav1d_msac_decode_bool           dav1d_msac_decode_bool_sse2
 #else
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt_c
 #define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt_c
 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt_c
+#define dav1d_msac_decode_bool_adapt     dav1d_msac_decode_bool_adapt_c
 #define dav1d_msac_decode_bool_equi      dav1d_msac_decode_bool_equi_c
+#define dav1d_msac_decode_bool           dav1d_msac_decode_bool_c
 #endif
 
 static inline unsigned dav1d_msac_decode_bools(MsacContext *const s, unsigned n) {
--- a/src/x86/msac.asm
+++ b/src/x86/msac.asm
@@ -114,6 +114,7 @@
 .renorm3:
     mov           r1d, [sq+msac.cnt]
     movifnidn      t0, sq
+.renorm4:
     bsr           ecx, r2d
     xor           ecx, 15  ; d
     shl           r2d, cl
@@ -285,6 +286,58 @@
 %endif
     jmp m(msac_decode_symbol_adapt4).renorm2
 
+cglobal msac_decode_bool_adapt, 2, 7, 0, s, cdf ; decode one bit using cdf[0], then adapt cdf[0]/cdf[1] if s->update_cdf
+    movzx         eax, word [cdfq]          ; eax = cdf[0]
+    movzx         r3d, byte [sq+msac.rng+1] ; r >> 8
+    mov            r4, [sq+msac.dif]        ; r4 = dif
+    mov           r2d, [sq+msac.rng]        ; r2d = r
+    mov           r5d, eax                  ; keep cdf[0] for the CDF update below
+    and           eax, ~63                  ; clear the low 6 bits of cdf[0]
+    imul          eax, r3d                  ; (cdf[0] & ~63) * (r >> 8)
+%if UNIX64
+    mov            r7, r4                   ; keep the unmodified dif for the cmovb below
+%endif
+    shr           eax, 7
+    add           eax, 4   ; v
+    mov           r3d, eax                  ; r3d = v
+    shl           rax, 48  ; vw
+    sub           r2d, r3d ; r - v
+    sub            r4, rax ; dif - vw
+    cmovb         r2d, r3d                  ; CF from the 64-bit sub (dif < vw): range becomes v
+    mov           r3d, [sq+msac.update_cdf] ; r3d = update_cdf; mov leaves the flags untouched
+%if UNIX64
+    cmovb          r4, r7                   ; dif < vw: restore the original dif
+%else
+    cmovb          r4, [sq+msac.dif]        ; dif < vw: reload the original dif from memory
+%endif
+    setb           al                       ; eax = bit = (dif < vw); the shl by 48 zeroed the low 32 bits of rax
+    not            r4                       ; r4 = ~dif
+    test          r3d, r3d
+    jz m(msac_decode_symbol_adapt4).renorm3 ; update_cdf == 0: skip the CDF update
+%if WIN64
+    push           r7                       ; r7 is used as scratch below and popped before leaving
+%endif
+    movzx         r7d, word [cdfq+2]        ; r7d = count = cdf[1]
+    movifnidn      t0, sq                   ; copy s into t0 unless both names are the same register
+    lea           ecx, [r7+64]              ; ecx = count + 64, becomes the rate after the shr below
+    cmp           r7d, 32                   ; CF = (count < 32)
+    adc           r7d, 0                    ; count += (count < 32)
+    mov      [cdfq+2], r7w                  ; cdf[1] = updated count
+    imul          r7d, eax, -32769          ; r7d = bit ? -32769 : 0
+    shr           ecx, 4   ; rate
+    add           r7d, r5d ; if (bit)
+    sub           r5d, eax ;     cdf[0] -= ((cdf[0] - 32769) >> rate) + 1;
+    sar           r7d, cl  ; else
+    sub           r5d, r7d ;     cdf[0] -= cdf[0] >> rate;
+    mov        [cdfq], r5w                  ; store the adapted cdf[0]
+%if WIN64
+    mov           r1d, [t0+msac.cnt]        ; the cnt load .renorm3 would do, via t0 (NOTE(review): presumably sq aliases rcx, clobbered above - confirm)
+    pop            r7                       ; restore r7
+    jmp m(msac_decode_symbol_adapt4).renorm4 ; enter after .renorm3's load and movifnidn
+%else
+    jmp m(msac_decode_symbol_adapt4).renorm3
+%endif
+
 cglobal msac_decode_bool_equi, 1, 7, 0, s
     mov           r1d, [sq+msac.rng]
     mov            r4, [sq+msac.dif]
@@ -299,6 +352,25 @@
     cmovb         r2d, r1d
     cmovb          r4, r3
     setb           al ; the upper 32 bits contains garbage but that's OK
+    not            r4
+    jmp m(msac_decode_symbol_adapt4).renorm3
+
+cglobal msac_decode_bool, 2, 7, 0, s, f ; decode one bit with the caller-supplied value f; no CDF is touched
+    movzx         eax, byte [sq+msac.rng+1] ; r >> 8
+    mov            r4, [sq+msac.dif]        ; r4 = dif
+    mov           r2d, [sq+msac.rng]        ; r2d = r
+    and           r1d, ~63                  ; clear the low 6 bits of f
+    imul          eax, r1d                  ; (r >> 8) * (f & ~63)
+    mov            r3, r4                   ; keep the unmodified dif for the cmovb below
+    shr           eax, 7
+    add           eax, 4   ; v
+    mov           r1d, eax                  ; r1d = v (f is no longer needed)
+    shl           rax, 48  ; vw
+    sub           r2d, r1d ; r - v
+    sub            r4, rax ; dif - vw
+    cmovb         r2d, r1d                  ; CF from the 64-bit sub (dif < vw): range becomes v
+    cmovb          r4, r3                   ; dif < vw: restore the original dif
+    setb           al                       ; eax = (dif < vw); the shl by 48 zeroed the low 32 bits of rax
     not            r4
     jmp m(msac_decode_symbol_adapt4).renorm3
 
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -37,13 +37,17 @@
 /* The normal code doesn't use function pointers */
 typedef unsigned (*decode_symbol_adapt_fn)(MsacContext *s, uint16_t *cdf,
                                            size_t n_symbols);
+typedef unsigned (*decode_bool_adapt_fn)(MsacContext *s, uint16_t *cdf);
 typedef unsigned (*decode_bool_equi_fn)(MsacContext *s);
+typedef unsigned (*decode_bool_fn)(MsacContext *s, unsigned f);
 
 typedef struct {
     decode_symbol_adapt_fn symbol_adapt4;
     decode_symbol_adapt_fn symbol_adapt8;
     decode_symbol_adapt_fn symbol_adapt16;
+    decode_bool_adapt_fn   bool_adapt; /* implementation checked as "msac_decode_bool_adapt" */
     decode_bool_equi_fn    bool_equi;
+    decode_bool_fn         bool;       /* checked as "msac_decode_bool"; NOTE(review): name clashes with the bool macro if <stdbool.h> is in scope */
 } MsacDSPContext;
 
 static void randomize_cdf(uint16_t *const cdf, int n) {
@@ -85,9 +89,7 @@
     }                                                                      \
 } while (0)
 
-static void check_decode_symbol_adapt(MsacDSPContext *const c,
-                                      uint8_t *const buf)
-{
+static void check_decode_symbol(MsacDSPContext *const c, uint8_t *const buf) {
     /* Use an aligned CDF buffer for more consistent benchmark
      * results, and a misaligned one for checking correctness. */
     ALIGN_STK_16(uint16_t, cdf, 2, [17]);
@@ -97,16 +99,36 @@
     CHECK_SYMBOL_ADAPT( 4, 1,  5);
     CHECK_SYMBOL_ADAPT( 8, 1,  8);
     CHECK_SYMBOL_ADAPT(16, 4, 16);
-    report("decode_symbol_adapt");
+    report("decode_symbol");
 }
 
-static void check_decode_bool_equi(MsacDSPContext *const c,
-                                   uint8_t *const buf)
-{
-    declare_func(unsigned, MsacContext *s);
+static void check_decode_bool(MsacDSPContext *const c, uint8_t *const buf) {
+    MsacContext s_c, s_a;
+
+    if (check_func(c->bool_adapt, "msac_decode_bool_adapt")) {
+        declare_func(unsigned, MsacContext *s, uint16_t *cdf);
+        uint16_t cdf[2][2];
+        for (int cdf_update = 0; cdf_update <= 1; cdf_update++) {
+            dav1d_msac_init(&s_c, buf, BUF_SIZE, !cdf_update);
+            s_a = s_c;
+            cdf[0][0] = cdf[1][0] = rnd() % 32767 + 1;
+            cdf[0][1] = cdf[1][1] = 0;
+            for (int i = 0; i < 64; i++) {
+                unsigned c_res = call_ref(&s_c, cdf[0]);
+                unsigned a_res = call_new(&s_a, cdf[1]);
+                if (c_res != a_res || msac_cmp(&s_c, &s_a) ||
+                    memcmp(cdf[0], cdf[1], sizeof(*cdf)))
+                {
+                    fail();
+                }
+            }
+            if (cdf_update)
+                bench_new(&s_a, cdf[0]);
+        }
+    }
 
     if (check_func(c->bool_equi, "msac_decode_bool_equi")) {
-        MsacContext s_c, s_a;
+        declare_func(unsigned, MsacContext *s);
         dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
         s_a = s_c;
         for (int i = 0; i < 64; i++) {
@@ -118,7 +140,21 @@
         bench_new(&s_a);
     }
 
-    report("decode_bool_equi");
+    if (check_func(c->bool, "msac_decode_bool")) {
+        declare_func(unsigned, MsacContext *s, unsigned f);
+        dav1d_msac_init(&s_c, buf, BUF_SIZE, 1);
+        s_a = s_c;
+        for (int i = 0; i < 64; i++) {
+            const unsigned f = rnd() & 0x7fff;
+            unsigned c_res = call_ref(&s_c, f);
+            unsigned a_res = call_new(&s_a, f);
+            if (c_res != a_res || msac_cmp(&s_c, &s_a))
+                fail();
+        }
+        bench_new(&s_a, 16384);
+    }
+
+    report("decode_bool");
 }
 
 void checkasm_check_msac(void) {
@@ -126,7 +162,9 @@
     c.symbol_adapt4  = dav1d_msac_decode_symbol_adapt_c;
     c.symbol_adapt8  = dav1d_msac_decode_symbol_adapt_c;
     c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt_c;
+    c.bool_adapt     = dav1d_msac_decode_bool_adapt_c;
     c.bool_equi      = dav1d_msac_decode_bool_equi_c;
+    c.bool           = dav1d_msac_decode_bool_c;
 
 #if ARCH_AARCH64 && HAVE_ASM
     if (dav1d_get_cpu_flags() & DAV1D_ARM_CPU_FLAG_NEON) {
@@ -139,7 +177,9 @@
         c.symbol_adapt4  = dav1d_msac_decode_symbol_adapt4_sse2;
         c.symbol_adapt8  = dav1d_msac_decode_symbol_adapt8_sse2;
         c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
+        c.bool_adapt     = dav1d_msac_decode_bool_adapt_sse2;
         c.bool_equi      = dav1d_msac_decode_bool_equi_sse2;
+        c.bool           = dav1d_msac_decode_bool_sse2;
     }
 #endif
 
@@ -147,6 +187,6 @@
     for (int i = 0; i < BUF_SIZE; i++)
         buf[i] = rnd();
 
-    check_decode_symbol_adapt(&c, buf);
-    check_decode_bool_equi(&c, buf);
+    check_decode_symbol(&c, buf);
+    check_decode_bool(&c, buf);
 }