shithub: dav1d

Download patch

ref: 45d4fde63754a685a6ba69c012b6f4db528664a0
parent: 250204038a3799d2a2377b9be5a87c5d4213e9ba
author: Nathan E. Egge <unlord@xiph.org>
date: Mon Nov 19 03:56:01 EST 2018

Add msac_decode_bool_equi() function

When decoding an equi-probable bit (i.e. prob = 1/2), we can simplify the
decode function by replacing the probability multiply with a shift.

--- a/src/decode.c
+++ b/src/decode.c
@@ -413,7 +413,7 @@
     // find reused cache entries
     int i = 0;
     for (int n = 0; n < n_cache && i < pal_sz; n++)
-        if (msac_decode_bool(&ts->msac, EC_BOOL_EPROB))
+        if (msac_decode_bool_equi(&ts->msac))
             used_cache[i++] = cache[n];
     const int n_used_cache = i;
 
@@ -477,13 +477,13 @@
     uint16_t *const pal = f->frame_thread.pass ?
         f->frame_thread.pal[((t->by >> 1) + (t->bx & 1)) * (f->b4_stride >> 1) +
                             ((t->bx >> 1) + (t->by & 1))][2] : t->pal[2];
-    if (msac_decode_bool(&ts->msac, EC_BOOL_EPROB)) {
+    if (msac_decode_bool_equi(&ts->msac)) {
         const int bits = f->cur.p.bpc - 4 + msac_decode_bools(&ts->msac, 2);
         int prev = pal[0] = msac_decode_bools(&ts->msac, f->cur.p.bpc);
         const int max = (1 << f->cur.p.bpc) - 1;
         for (int i = 1; i < b->pal_sz[1]; i++) {
             int delta = msac_decode_bools(&ts->msac, bits);
-            if (delta && msac_decode_bool(&ts->msac, EC_BOOL_EPROB)) delta = -delta;
+            if (delta && msac_decode_bool_equi(&ts->msac)) delta = -delta;
             prev = pal[i] = (prev + delta) & max;
         }
     } else {
@@ -927,7 +927,7 @@
                 delta_q = msac_decode_bools(&ts->msac, n_bits) + 1 + (1 << n_bits);
             }
             if (delta_q) {
-                if (msac_decode_bool(&ts->msac, EC_BOOL_EPROB)) delta_q = -delta_q;
+                if (msac_decode_bool_equi(&ts->msac)) delta_q = -delta_q;
                 delta_q *= 1 << f->frame_hdr->delta.q.res_log2;
             }
             ts->last_qidx = iclip(ts->last_qidx + delta_q, 1, 255);
@@ -949,7 +949,7 @@
                                    1 + (1 << n_bits);
                     }
                     if (delta_lf) {
-                        if (msac_decode_bool(&ts->msac, EC_BOOL_EPROB))
+                        if (msac_decode_bool_equi(&ts->msac))
                             delta_lf = -delta_lf;
                         delta_lf *= 1 << f->frame_hdr->delta.lf.res_log2;
                     }
@@ -1572,7 +1572,7 @@
                 } else {
                     b->comp_type = COMP_INTER_SEG;
                 }
-                b->mask_sign = msac_decode_bool(&ts->msac, EC_BOOL_EPROB);
+                b->mask_sign = msac_decode_bool_equi(&ts->msac);
                 if (DEBUG_BLOCK_INFO)
                     printf("Post-seg/wedge[%d,wedge_idx=%d,sign=%d]: r=%d\n",
                            b->comp_type == COMP_INTER_WEDGE,
--- a/src/msac.c
+++ b/src/msac.c
@@ -91,6 +91,22 @@
     return ret - 1;
 }
 
+unsigned msac_decode_bool_equi(MsacContext *const s) { // decode one bit, P = 1/2
+    ec_win v, vw, dif = s->dif; // work on local copies of the decoder state
+    uint16_t r = s->rng;
+    unsigned ret;
+    assert((dif >> (EC_WIN_SIZE - 16)) < r); // top 16 bits of dif must be < rng
+    // For P = 1/2, f = 16384 >> EC_PROB_SHIFT = 256 (the removed EC_BOOL_EPROB),
+    // so the multiply by f collapses to the fixed shift below.
+    v = ((r >> 8) << 7) + EC_MIN_PROB;
+    vw   = v << (EC_WIN_SIZE - 16); // split point aligned to the window's top bits
+    ret  = dif >= vw;              // 1 when dif is at or above the split point
+    dif -= ret*vw;                 // branchless: rebase dif only when ret == 1
+    v   += ret*(r - 2*v);          // new range: v if ret == 0, else r - v
+    ctx_norm(s, dif, v);           // hand the updated dif and range to ctx_norm
+    return !ret;                   // below the split (ret == 0) decodes as 1
+}
+
 /* Decode a single binary value.
  * f: The probability that the bit is one
  * Return: The value decoded (0 or 1). */
@@ -111,7 +127,7 @@
 unsigned msac_decode_bools(MsacContext *const c, const unsigned l) {
     int v = 0;
     for (int n = (int) l - 1; n >= 0; n--)
-        v = (v << 1) | msac_decode_bool(c, EC_BOOL_EPROB);
+        v = (v << 1) | msac_decode_bool_equi(c); // l equi-probable bits, MSB first
     return v;
 }
 
@@ -122,7 +138,7 @@
     int a = 0;
     int b = k;
     while ((2 << b) < n) {
-        if (!msac_decode_bool(c, EC_BOOL_EPROB)) break;
+        if (!msac_decode_bool_equi(c)) break;
         b = k + i++;
         a = (1 << b);
     }
@@ -137,7 +153,7 @@
     assert(l > 1);
     const unsigned m = (1 << l) - n;
     const unsigned v = msac_decode_bools(c, l - 1);
-    return v < m ? v : (v << 1) - m + msac_decode_bool(c, EC_BOOL_EPROB);
+    return v < m ? v : (v << 1) - m + msac_decode_bool_equi(c);
 }
 
 static void update_cdf(uint16_t *const cdf, const unsigned val,
--- a/src/msac.h
+++ b/src/msac.h
@@ -44,7 +44,6 @@
 } MsacContext;
 
 #define EC_PROB_SHIFT 6
-#define EC_BOOL_EPROB 256
 
 void msac_init(MsacContext *c, const uint8_t *data, size_t sz, int disable_cdf_update_flag);
 unsigned msac_decode_symbol(MsacContext *s, const uint16_t *cdf,
@@ -51,6 +50,7 @@
                             const unsigned n_symbols);
 unsigned msac_decode_symbol_adapt(MsacContext *s, uint16_t *cdf,
                                   const unsigned n_symbols);
+unsigned msac_decode_bool_equi(MsacContext *s); // decode one bit with P = 1/2
 unsigned msac_decode_bool(MsacContext *s, unsigned f);
 unsigned msac_decode_bool_adapt(MsacContext *s, uint16_t *cdf);
 unsigned msac_decode_bools(MsacContext *c, unsigned l);
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -50,8 +50,8 @@
     int len = 0;
     unsigned val = 1;
 
-    while (!msac_decode_bool(msac, EC_BOOL_EPROB) && len < 32) len++;
-    while (len--) val = (val << 1) | msac_decode_bool(msac, EC_BOOL_EPROB);
+    while (!msac_decode_bool_equi(msac) && len < 32) len++;
+    while (len--) val = (val << 1) | msac_decode_bool_equi(msac);
 
     return val - 1;
 }
@@ -152,7 +152,7 @@
         unsigned mask = eob >> 1;
         if (eob_hi_bit) eob |= mask;
         for (mask >>= 1; mask; mask >>= 1) {
-            const int eob_bit = msac_decode_bool(&ts->msac, EC_BOOL_EPROB);
+            const int eob_bit = msac_decode_bool_equi(&ts->msac);
             if (eob_bit) eob |= mask;
         }
         if (dbg)
@@ -231,7 +231,7 @@
             dc_sign = sign ? 0 : 2;
             dq = (dq_tbl[0] * qm_tbl[0] + 16) >> 5;
         } else {
-            sign = msac_decode_bool(&ts->msac, EC_BOOL_EPROB);
+            sign = msac_decode_bool_equi(&ts->msac);
             if (dbg)
             printf("Post-sign[%d=%d=%d]: r=%d\n", i, rc, sign, ts->msac.rng);
             dq = (dq_tbl[1] * qm_tbl[rc] + 16) >> 5;