shithub: dav1d

Download patch

ref: 3c8110a947e28d7a82087aadc520cd25ac720fe4
parent: 35ab85bbbc6f8270b5ba973cf8452cb5aa2b32eb
author: Lynne <dev@lynne.ee>
date: Mon Jan 13 06:06:05 EST 2020

x86/msac: add an avx2 version for msac_decode_symbol_adapt16

msac_decode_symbol_adapt16_c: 55.1
msac_decode_symbol_adapt16_sse2: 30.3
msac_decode_symbol_adapt16_avx2: 28.0

Most of the code was written by Henrik Gramner.

--- a/src/x86/msac.asm
+++ b/src/x86/msac.asm
@@ -67,7 +67,7 @@
     .update_cdf: resd 1
 endstruc
 
-%define m(x) mangle(private_prefix %+ _ %+ x %+ SUFFIX)
+%define m(x, y) mangle(private_prefix %+ _ %+ x %+ y) ; m(func, suffix): mangled symbol for func with an explicit ISA suffix (SUFFIX, _sse2, ...)
 
 SECTION .text
 
@@ -240,7 +240,7 @@
     pcmpeqw        m1, m2
     pmovmskb      eax, m1
     test          t3d, t3d
-    jz m(msac_decode_symbol_adapt4).renorm
+    jz m(msac_decode_symbol_adapt4, SUFFIX).renorm
     movzx         t3d, word [t1+t4*2]
     pcmpeqw        m2, m2
     mov           t2d, t3d
@@ -257,7 +257,7 @@
     paddw          m0, m2
     mova         [t1], m0
     mov     [t1+t4*2], t2w
-    jmp m(msac_decode_symbol_adapt4).renorm
+    jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm
 
 cglobal msac_decode_symbol_adapt16, 0, 6, 6
     DECODE_SYMBOL_ADAPT_INIT
@@ -330,7 +330,7 @@
 %if WIN64
     add           rsp, 48
 %endif
-    jmp m(msac_decode_symbol_adapt4).renorm2
+    jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm2
 
 cglobal msac_decode_bool_adapt, 0, 6, 0
     movifnidn      t1, r1mp
@@ -366,7 +366,7 @@
 %endif
     not            t4
     test          t3d, t3d
-    jz m(msac_decode_symbol_adapt4).renorm3
+    jz m(msac_decode_symbol_adapt4, SUFFIX).renorm3
 %if UNIX64 == 0
     push           t6
 %endif
@@ -390,13 +390,13 @@
 %if WIN64
     mov           t1d, [t7+msac.cnt]
     pop            t6
-    jmp m(msac_decode_symbol_adapt4).renorm4
+    jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm4
 %else
 %if ARCH_X86_64 == 0
     pop            t5
     pop            t6
 %endif
-    jmp m(msac_decode_symbol_adapt4).renorm3
+    jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm3
 %endif
 
 cglobal msac_decode_bool_equi, 0, 6, 0
@@ -418,7 +418,7 @@
 %if ARCH_X86_64 == 0
     movzx         eax, al
 %endif
-    jmp m(msac_decode_symbol_adapt4).renorm3
+    jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm3
 
 cglobal msac_decode_bool, 0, 6, 0
     movifnidn      t0, r0mp
@@ -442,7 +442,7 @@
 %if ARCH_X86_64 == 0
     movzx         eax, al
 %endif
-    jmp m(msac_decode_symbol_adapt4).renorm3
+    jmp m(msac_decode_symbol_adapt4, SUFFIX).renorm3
 
 %macro HI_TOK 1 ; update_cdf
 %if ARCH_X86_64 == 0
@@ -598,3 +598,71 @@
     HI_TOK          1
 .no_update_cdf:
     HI_TOK          0
+
+%if ARCH_X86_64                                 ; the AVX2 version is assembled for x86-64 only
+INIT_YMM avx2                                   ; m# = 256-bit ymm (16 words), xm# = its low 128 bits
+cglobal msac_decode_symbol_adapt16, 3, 6, 6     ; unsigned (MsacContext *s, uint16_t *cdf, size_t n_symbols): t0 = s, t1 = cdf, t2 = n_symbols
+    lea           rax, [pw_0xff00]              ; rax = constant base; also indexed with a negative offset below
+    vpbroadcastw   m2, [t0+msac.rng]            ; m2 = rng in every word
+    mova           m0, [t1]                     ; m0 = cdf[0..15]; aligned load, so cdf must be 32-byte aligned
+    vpbroadcastw   m3, [t0+msac.dif+6]          ; m3 = dif_hi = top word of the 64-bit dif, in every word
+    vbroadcasti128 m4, [rax]                    ; m4 = the 16 bytes at pw_0xff00 in both 128-bit lanes
+    mov           t3d, [t0+msac.update_cdf]     ; t3 = update_cdf flag
+    mov           t4d, t2d                      ; t4 = n_symbols; cdf[n_symbols] holds the adaptation count
+    not            t2                           ; t2 = ~n_symbols = -(n_symbols + 1)
+%if STACK_ALIGNMENT < 32                        ; the scratch buffer must be 32-byte aligned for mova [buf]
+    mov            r5, rsp                      ; r5 = rsp (WIN64: restored at exit; otherwise rounded down below)
+%if WIN64
+    and           rsp, ~31                      ; realign the stack; undone by mov rsp, r5 at exit
+    sub           rsp, 40                       ; NOTE(review): buf is not defined in this block for WIN64; confirm the earlier definition is 32-byte aligned here
+%else
+    and            r5, ~31                      ; r5 = rsp rounded down to a multiple of 32
+    %define buf r5-32                           ; aligned 32-byte buffer below rsp (red zone, rsp unchanged)
+%endif
+%elif WIN64
+    sub           rsp, 64                       ; scratch space, released with add rsp, 64 at exit
+%else
+    %define buf rsp-56                          ; below rsp (red zone); 32-byte aligned when rsp was 32-aligned before the call
+%endif
+    psrlw          m1, m0, 6                    ; m1 = cdf >> 6
+    movd      [buf-4], xm2                      ; two copies of rng just below buf, read back as v[-1]
+    pand           m2, m4                       ; m2 = rng & pw_0xff00
+    psllw          m1, 7                        ; m1 = (cdf >> 6) << 7
+    pmulhuw        m1, m2                       ; m1 = (m1 * m2) >> 16, unsigned
+    paddw          m1, [rax+t2*2]               ; v = m1 + the 16 words starting 2*(n_symbols+1) bytes before pw_0xff00
+    mova        [buf], m1                       ; spill v[0..15] for the scalar lookup at .renorm
+    pmaxuw         m1, m3                       ; max(v, dif_hi), unsigned
+    pcmpeqw        m1, m3                       ; mask = 0xffff where v <= dif_hi, else 0
+    pmovmskb      eax, m1                       ; 2 bits per word; overwrites the constant base in rax
+    test          t3d, t3d
+    jz .renorm                                  ; update_cdf == 0: leave cdf untouched
+    movzx         t3d, word [t1+t4*2]           ; t3 = count = cdf[n_symbols]
+    pcmpeqw        m2, m2                       ; m2 = 0xffff in every word
+    lea           t2d, [t3+80]                  ; t2 = (count + 80) >> 4 = 5 + (count >> 4)
+    shr           t2d, 4
+    cmp           t3d, 32                       ; CF = (count < 32)
+    adc           t3d, 0                        ; count += (count < 32): stops incrementing at 32
+    movd          xm3, t2d                      ; shift amount for psraw
+    pavgw          m2, m1                       ; 0xffff where the mask is set, 0x8000 elsewhere
+    psubw          m2, m0                       ; m2 -= cdf
+    psubw          m0, m1                       ; cdf - mask, i.e. cdf + 1 where the mask is set
+    psraw          m2, xm3                      ; arithmetic shift by the rate in t2
+    paddw          m0, m2                       ; adapted cdf[0..15]
+    mova         [t1], m0                       ; store cdf[0..15]; covers cdf[n_symbols] when n_symbols < 16
+    mov     [t1+t4*2], t3w                      ; so the count must be written after the vector store
+.renorm:                                        ; eax = comparison mask, t0 = s
+    tzcnt         eax, eax                      ; eax = 2 * idx of the first word whose mask is set
+    mov            t4, [t0+msac.dif]            ; t4 = dif
+    movzx         t1d, word [buf+rax-0]         ; t1 = v[idx]
+    movzx         t2d, word [buf+rax-2]         ; t2 = v[idx - 1], or rng when idx == 0
+    shr           eax, 1                        ; eax = idx
+%if WIN64
+%if STACK_ALIGNMENT < 32
+    mov           rsp, r5                       ; restore the original rsp
+%else
+    add           rsp, 64                       ; release the scratch space
+%endif
+%endif
+    vzeroupper                                  ; clear upper ymm state before continuing in SSE2 code
+    jmp m(msac_decode_symbol_adapt4, _sse2).renorm2 ; shared tail; explicit _sse2 because SUFFIX is _avx2 here
+%endif
--- a/src/x86/msac.h
+++ b/src/x86/msac.h
@@ -39,6 +39,10 @@
 unsigned dav1d_msac_decode_bool_sse2(MsacContext *s, unsigned f);
 unsigned dav1d_msac_decode_hi_tok_sse2(MsacContext *s, uint16_t *cdf);
 
+/* AVX2 version, only assembled on x86-64 (see msac.asm); declared here for msac_init.c and checkasm */
+unsigned dav1d_msac_decode_symbol_adapt16_avx2(MsacContext *s, uint16_t *cdf,
+                                               size_t n_symbols);
+
 #if ARCH_X86_64 || defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
 #define dav1d_msac_decode_symbol_adapt4  dav1d_msac_decode_symbol_adapt4_sse2
 #define dav1d_msac_decode_symbol_adapt8  dav1d_msac_decode_symbol_adapt8_sse2
@@ -49,7 +53,9 @@
 #define dav1d_msac_decode_bool_equi      dav1d_msac_decode_bool_equi_sse2
 #define dav1d_msac_decode_bool           dav1d_msac_decode_bool_sse2
 
-#if defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2)
+#if ARCH_X86_64 /* runtime dispatch through the pointer set by dav1d_msac_init_x86() */
+#define dav1d_msac_decode_symbol_adapt16(ctx, cdf, symb) ((ctx)->symbol_adapt16(ctx, cdf, symb)) /* note: evaluates ctx twice */
+#elif defined(__SSE2__) || (defined(_M_IX86_FP) && _M_IX86_FP >= 2) /* SSE2 guaranteed at compile time: bind statically */
 #define dav1d_msac_decode_symbol_adapt16 dav1d_msac_decode_symbol_adapt16_sse2
 #endif
 
--- a/src/x86/msac_init.c
+++ b/src/x86/msac_init.c
@@ -28,14 +28,15 @@
 #include "src/msac.h"
 #include "src/x86/msac.h"
 
-unsigned dav1d_msac_decode_symbol_adapt16_avx2(MsacContext *s, uint16_t *cdf,
-                                               size_t n_symbols);
-
 void dav1d_msac_init_x86(MsacContext *const s) {
     const unsigned flags = dav1d_get_cpu_flags();

     if (flags & DAV1D_X86_CPU_FLAG_SSE2) {
         s->symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_sse2;
+    }
+    /* Checked after SSE2 so that the AVX2 version wins when both are available. */
+    if (flags & DAV1D_X86_CPU_FLAG_AVX2) { /* NOTE(review): _avx2 is only assembled on x86-64; confirm this file is x86-64-only */
+        s->symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2;
     }
 }
 
--- a/tests/checkasm/msac.c
+++ b/tests/checkasm/msac.c
@@ -258,6 +258,12 @@
         c.bool           = dav1d_msac_decode_bool_sse2;
         c.hi_tok         = dav1d_msac_decode_hi_tok_sse2;
     }
+
+#if ARCH_X86_64
+    if (dav1d_get_cpu_flags() & DAV1D_X86_CPU_FLAG_AVX2) {
+        c.symbol_adapt16 = dav1d_msac_decode_symbol_adapt16_avx2;
+    }
+#endif
 #endif
 
     uint8_t buf[BUF_SIZE];