shithub: dav1d

Download patch

ref: a495179a8a6d2e6f49773c9a2566b23236fd55a2
parent: aca57bf3db00c29e90605656f1015561d1d67c2d
author: Ronald S. Bultje <rsbultje@gmail.com>
date: Sat Feb 1 10:04:08 EST 2020

Use union refmvs_mvpair { mv mv[2]; uint64_t n; } for MV pairs

This allows the two MV comparisons for compound refs to be combined into a single 64-bit comparison.

--- a/src/decode.c
+++ b/src/decode.c
@@ -300,8 +300,8 @@
 #define add_sample(dx, dy, sx, sy, rp) do { \
     pts[np][0][0] = 16 * (2 * dx + sx * bs(rp)[0]) - 8; \
     pts[np][0][1] = 16 * (2 * dy + sy * bs(rp)[1]) - 8; \
-    pts[np][1][0] = pts[np][0][0] + (rp)->mv[0].x; \
-    pts[np][1][1] = pts[np][0][1] + (rp)->mv[0].y; \
+    pts[np][1][0] = pts[np][0][0] + (rp)->mv.mv[0].x; \
+    pts[np][1][1] = pts[np][0][1] + (rp)->mv.mv[0].y; \
     np++; \
 } while (0)
 
@@ -796,13 +796,13 @@
                 refmvs_block *const r = &t->rt.r[(t->by & 31) + 5 + bh4 - 1][t->bx];
                 for (int x = 0; x < bw4; x++) {
                     r[x].ref.ref[0] = b->ref[0] + 1;
-                    r[x].mv[0] = b->mv[0];
+                    r[x].mv.mv[0] = b->mv[0];
                     r[x].bs = bs;
                 }
                 refmvs_block *const *rr = &t->rt.r[(t->by & 31) + 5];
                 for (int y = 0; y < bh4 - 1; y++) {
                     rr[y][t->bx + bw4 - 1].ref.ref[0] = b->ref[0] + 1;
-                    rr[y][t->bx + bw4 - 1].mv[0] = b->mv[0];
+                    rr[y][t->bx + bw4 - 1].mv.mv[0] = b->mv[0];
                     rr[y][t->bx + bw4 - 1].bs = bs;
                 }
             }
@@ -1305,10 +1305,10 @@
                           (union refmvs_refpair) { .ref = { 0, -1 }},
                           bs, intra_edge_flags, t->by, t->bx);
 
-        if (mvstack[0].mv[0].n)
-            b->mv[0] = mvstack[0].mv[0];
-        else if (mvstack[1].mv[0].n)
-            b->mv[0] = mvstack[1].mv[0];
+        if (mvstack[0].mv.mv[0].n)
+            b->mv[0] = mvstack[0].mv.mv[0];
+        else if (mvstack[1].mv.mv[0].n)
+            b->mv[0] = mvstack[1].mv.mv[0];
         else {
             if (t->by - (16 << f->seq_hdr->sb128) < ts->tiling.row_start) {
                 b->mv[0].y = 0;
@@ -1381,7 +1381,7 @@
         if (DEBUG_BLOCK_INFO)
             printf("Post-dmv[%d/%d,ref=%d/%d|%d/%d]: r=%d\n",
                    b->mv[0].y, b->mv[0].x, ref.y, ref.x,
-                   mvstack[0].mv[0].y, mvstack[0].mv[0].x, ts->msac.rng);
+                   mvstack[0].mv.mv[0].y, mvstack[0].mv.mv[0].x, ts->msac.rng);
         read_vartx_tree(t, b, bs, bx4, by4);
 
         // reconstruction
@@ -1448,8 +1448,8 @@
                                     b->ref[0] + 1, b->ref[1] + 1 }},
                               bs, intra_edge_flags, t->by, t->bx);
 
-            b->mv[0] = mvstack[0].mv[0];
-            b->mv[1] = mvstack[0].mv[1];
+            b->mv[0] = mvstack[0].mv.mv[0];
+            b->mv[1] = mvstack[0].mv.mv[1];
             fix_mv_precision(f->frame_hdr, &b->mv[0]);
             fix_mv_precision(f->frame_hdr, &b->mv[1]);
             if (DEBUG_BLOCK_INFO)
@@ -1571,7 +1571,7 @@
             switch (im[idx]) { \
             case NEARMV: \
             case NEARESTMV: \
-                b->mv[idx] = mvstack[b->drl_idx].mv[idx]; \
+                b->mv[idx] = mvstack[b->drl_idx].mv.mv[idx]; \
                 fix_mv_precision(f->frame_hdr, &b->mv[idx]); \
                 break; \
             case GLOBALMV: \
@@ -1582,7 +1582,7 @@
                 fix_mv_precision(f->frame_hdr, &b->mv[idx]); \
                 break; \
             case NEWMV: \
-                b->mv[idx] = mvstack[b->drl_idx].mv[idx]; \
+                b->mv[idx] = mvstack[b->drl_idx].mv.mv[idx]; \
                 read_mv_residual(t, &b->mv[idx], &ts->cdf.mv, \
                                  !f->frame_hdr->force_integer_mv); \
                 break; \
@@ -1739,7 +1739,7 @@
                         b->drl_idx = NEAREST_DRL;
                     }
                     assert(b->drl_idx >= NEAREST_DRL && b->drl_idx <= NEARISH_DRL);
-                    b->mv[0] = mvstack[b->drl_idx].mv[0];
+                    b->mv[0] = mvstack[b->drl_idx].mv.mv[0];
                     if (b->drl_idx < NEAR_DRL)
                         fix_mv_precision(f->frame_hdr, &b->mv[0]);
                 }
@@ -1764,10 +1764,10 @@
                 }
                 assert(b->drl_idx >= NEAREST_DRL && b->drl_idx <= NEARISH_DRL);
                 if (n_mvs > 1) {
-                    b->mv[0] = mvstack[b->drl_idx].mv[0];
+                    b->mv[0] = mvstack[b->drl_idx].mv.mv[0];
                 } else {
                     assert(!b->drl_idx);
-                    b->mv[0] = mvstack[0].mv[0];
+                    b->mv[0] = mvstack[0].mv.mv[0];
                     fix_mv_precision(f->frame_hdr, &b->mv[0]);
                 }
                 if (DEBUG_BLOCK_INFO)
@@ -1938,7 +1938,8 @@
         // context updates
         if (is_comp) {
             splat_tworef_mv(&t->rt, t->by, t->bx, bs, b->inter_mode,
-                            b->ref[0], b->ref[1], b->mv);
+                            (refmvs_refpair) { .ref = { b->ref[0], b->ref[1] }},
+                            (refmvs_mvpair) { .mv = { [0] = b->mv[0], [1] = b->mv[1] }});
         } else {
             splat_oneref_mv(&t->rt, t->by, t->bx, bs, b->inter_mode,
                             b->ref[0], b->mv[0], b->interintra_type);
--- a/src/recon_tmpl.c
+++ b/src/recon_tmpl.c
@@ -1018,7 +1018,7 @@
                 const int ow4 = iclip(a_b_dim[0], 2, b_dim[0]);
                 const int oh4 = imin(b_dim[1], 16) >> 1;
                 res = mc(t, lap, NULL, ow4 * h_mul * sizeof(pixel), ow4, (oh4 * 3 + 3) >> 2,
-                         t->bx + x, t->by, pl, a_r->mv[0],
+                         t->bx + x, t->by, pl, a_r->mv.mv[0],
                          &f->refp[a_r->ref.ref[0] - 1], a_r->ref.ref[0] - 1,
                          dav1d_filter_2d[t->a->filter[1][bx4 + x + 1]][t->a->filter[0][bx4 + x + 1]]);
                 if (res) return res;
@@ -1040,7 +1040,7 @@
                 const int ow4 = imin(b_dim[0], 16) >> 1;
                 const int oh4 = iclip(l_b_dim[1], 2, b_dim[1]);
                 res = mc(t, lap, NULL, h_mul * ow4 * sizeof(pixel), ow4, oh4,
-                         t->bx, t->by + y, pl, l_r->mv[0],
+                         t->bx, t->by + y, pl, l_r->mv.mv[0],
                          &f->refp[l_r->ref.ref[0] - 1], l_r->ref.ref[0] - 1,
                          dav1d_filter_2d[t->l.filter[1][by4 + y + 1]][t->l.filter[0][by4 + y + 1]]);
                 if (res) return res;
@@ -1631,7 +1631,7 @@
                     res = mc(t, ((pixel *) f->cur.data[1 + pl]) + uvdstoff,
                              NULL, f->cur.stride[1],
                              bw4, bh4, t->bx - 1, t->by - 1, 1 + pl,
-                             r[-1][t->bx - 1].mv[0],
+                             r[-1][t->bx - 1].mv.mv[0],
                              &f->refp[r[-1][t->bx - 1].ref.ref[0] - 1],
                              r[-1][t->bx - 1].ref.ref[0] - 1,
                              f->frame_thread.pass != 2 ? t->tl_4x4_filter :
@@ -1647,7 +1647,7 @@
                 for (int pl = 0; pl < 2; pl++) {
                     res = mc(t, ((pixel *) f->cur.data[1 + pl]) + uvdstoff + v_off, NULL,
                              f->cur.stride[1], bw4, bh4, t->bx - 1,
-                             t->by, 1 + pl, r[0][t->bx - 1].mv[0],
+                             t->by, 1 + pl, r[0][t->bx - 1].mv.mv[0],
                              &f->refp[r[0][t->bx - 1].ref.ref[0] - 1],
                              r[0][t->bx - 1].ref.ref[0] - 1,
                              f->frame_thread.pass != 2 ? left_filter_2d :
@@ -1662,7 +1662,7 @@
                 for (int pl = 0; pl < 2; pl++) {
                     res = mc(t, ((pixel *) f->cur.data[1 + pl]) + uvdstoff + h_off, NULL,
                              f->cur.stride[1], bw4, bh4, t->bx, t->by - 1,
-                             1 + pl, r[-1][t->bx].mv[0],
+                             1 + pl, r[-1][t->bx].mv.mv[0],
                              &f->refp[r[-1][t->bx].ref.ref[0] - 1],
                              r[-1][t->bx].ref.ref[0] - 1,
                              f->frame_thread.pass != 2 ? top_filter_2d :
--- a/src/refmvs.c
+++ b/src/refmvs.c
@@ -43,17 +43,17 @@
                                   int *const have_newmv_match,
                                   int *const have_refmv_match)
 {
-    if (b->mv[0].n == INVALID_MV) return; // intra block, no intrabc
+    if (b->mv.mv[0].n == INVALID_MV) return; // intra block, no intrabc
 
     if (ref.ref[1] == -1) {
         for (int n = 0; n < 2; n++) {
             if (b->ref.ref[n] == ref.ref[0]) {
                 const mv cand_mv = ((b->mf & 1) && gmv[0].n != INVALID_MV) ?
-                                   gmv[0] : b->mv[n];
+                                   gmv[0] : b->mv.mv[n];
 
                 const int last = *cnt;
                 for (int m = 0; m < last; m++)
-                    if (mvstack[m].mv[0].n == cand_mv.n) {
+                    if (mvstack[m].mv.mv[0].n == cand_mv.n) {
                         mvstack[m].weight += weight;
                         *have_refmv_match = 1;
                         *have_newmv_match |= b->mf >> 1;
@@ -61,7 +61,7 @@
                     }
 
                 if (last < 8) {
-                    mvstack[last].mv[0] = cand_mv;
+                    mvstack[last].mv.mv[0] = cand_mv;
                     mvstack[last].weight = weight;
                     *cnt = last + 1;
                 }
@@ -71,16 +71,14 @@
             }
         }
     } else if (b->ref.pair == ref.pair) {
-        const mv cand_mv[2] = {
-            [0] = ((b->mf & 1) && gmv[0].n != INVALID_MV) ? gmv[0] : b->mv[0],
-            [1] = ((b->mf & 1) && gmv[1].n != INVALID_MV) ? gmv[1] : b->mv[1],
-        };
+        const refmvs_mvpair cand_mv = { .mv = {
+            [0] = ((b->mf & 1) && gmv[0].n != INVALID_MV) ? gmv[0] : b->mv.mv[0],
+            [1] = ((b->mf & 1) && gmv[1].n != INVALID_MV) ? gmv[1] : b->mv.mv[1],
+        }};
 
         const int last = *cnt;
         for (int n = 0; n < last; n++)
-            if (mvstack[n].mv[0].n == cand_mv[0].n &&
-                mvstack[n].mv[1].n == cand_mv[1].n)
-            {
+            if (mvstack[n].mv.n == cand_mv.n) {
                 mvstack[n].weight += weight;
                 *have_refmv_match = 1;
                 *have_newmv_match |= b->mf >> 1;
@@ -88,8 +86,7 @@
             }
 
         if (last < 8) {
-            mvstack[last].mv[0] = cand_mv[0];
-            mvstack[last].mv[1] = cand_mv[1];
+            mvstack[last].mv = cand_mv;
             mvstack[last].weight = weight;
             *cnt = last + 1;
         }
@@ -208,27 +205,29 @@
             *globalmv_ctx = (abs(mv.x - gmv[0].x) | abs(mv.y - gmv[0].y)) >= 16;
 
         for (int n = 0; n < last; n++)
-            if (mvstack[n].mv[0].n == mv.n) {
+            if (mvstack[n].mv.mv[0].n == mv.n) {
                 mvstack[n].weight += 2;
                 return;
             }
         if (last < 8) {
-            mvstack[last].mv[0] = mv;
+            mvstack[last].mv.mv[0] = mv;
             mvstack[last].weight = 2;
             *cnt = last + 1;
         }
     } else {
-        union mv mv2 = mv_projection(rb->mv, rf->pocdiff[ref.ref[1] - 1], rb->ref);
-        fix_mv_precision(rf->frm_hdr, &mv2);
+        refmvs_mvpair mvp = { .mv = {
+            [0] = mv,
+            [1] = mv_projection(rb->mv, rf->pocdiff[ref.ref[1] - 1], rb->ref),
+        }};
+        fix_mv_precision(rf->frm_hdr, &mvp.mv[1]);
 
         for (int n = 0; n < last; n++)
-            if (mvstack[n].mv[0].n == mv.n && mvstack[n].mv[1].n == mv2.n) {
+            if (mvstack[n].mv.n == mvp.n) {
                 mvstack[n].weight += 2;
                 return;
             }
         if (last < 8) {
-            mvstack[last].mv[0] = mv;
-            mvstack[last].mv[1] = mv2;
+            mvstack[last].mv = mvp;
             mvstack[last].weight = 2;
             *cnt = last + 1;
         }
@@ -250,26 +249,26 @@
 
         if (cand_ref <= 0) break;
 
-        mv cand_mv = cand_b->mv[n];
+        mv cand_mv = cand_b->mv.mv[n];
         if (cand_ref == ref.ref[0]) {
             if (same_count[0] < 2)
-                same[same_count[0]++].mv[0] = cand_mv;
+                same[same_count[0]++].mv.mv[0] = cand_mv;
             if (diff_count[1] < 2) {
                 if (sign1 ^ sign_bias[cand_ref - 1]) {
                     cand_mv.y = -cand_mv.y;
                     cand_mv.x = -cand_mv.x;
                 }
-                diff[diff_count[1]++].mv[1] = cand_mv;
+                diff[diff_count[1]++].mv.mv[1] = cand_mv;
             }
         } else if (cand_ref == ref.ref[1]) {
             if (same_count[1] < 2)
-                same[same_count[1]++].mv[1] = cand_mv;
+                same[same_count[1]++].mv.mv[1] = cand_mv;
             if (diff_count[0] < 2) {
                 if (sign0 ^ sign_bias[cand_ref - 1]) {
                     cand_mv.y = -cand_mv.y;
                     cand_mv.x = -cand_mv.x;
                 }
-                diff[diff_count[0]++].mv[0] = cand_mv;
+                diff[diff_count[0]++].mv.mv[0] = cand_mv;
             }
         } else {
             mv i_cand_mv = (union mv) {
@@ -278,13 +277,13 @@
             };
 
             if (diff_count[0] < 2) {
-                diff[diff_count[0]++].mv[0] =
+                diff[diff_count[0]++].mv.mv[0] =
                     sign0 ^ sign_bias[cand_ref - 1] ?
                     i_cand_mv : cand_mv;
             }
 
             if (diff_count[1] < 2) {
-                diff[diff_count[1]++].mv[1] =
+                diff[diff_count[1]++].mv.mv[1] =
                     sign1 ^ sign_bias[cand_ref - 1] ?
                     i_cand_mv : cand_mv;
             }
@@ -306,7 +305,7 @@
         // FIXME if scan_{row,col}() returned a mask for the nearest
         // edge, we could skip the appropriate ones here
 
-        mv cand_mv = cand_b->mv[n];
+        mv cand_mv = cand_b->mv.mv[n];
         if (sign ^ sign_bias[cand_ref - 1]) {
             cand_mv.y = -cand_mv.y;
             cand_mv.x = -cand_mv.x;
@@ -315,10 +314,10 @@
         int m;
         const int last = *cnt;
         for (m = 0; m < last; m++)
-            if (cand_mv.n == mvstack[m].mv[0].n)
+            if (cand_mv.n == mvstack[m].mv.mv[0].n)
                 break;
         if (m == last) {
-            mvstack[m].mv[0] = cand_mv;
+            mvstack[m].mv.mv[0] = cand_mv;
             mvstack[m].weight = 2; // "minimal"
             *cnt = last + 1;
         }
@@ -561,15 +560,15 @@
 
                 const int l = diff_count[n];
                 if (l) {
-                    same[m].mv[n] = diff[0].mv[n];
+                    same[m].mv.mv[n] = diff[0].mv.mv[n];
                     if (++m == 2) continue;
                     if (l == 2) {
-                        same[1].mv[n] = diff[1].mv[n];
+                        same[1].mv.mv[n] = diff[1].mv.mv[n];
                         continue;
                     }
                 }
                 do {
-                    same[m].mv[n] = tgmv[n];
+                    same[m].mv.mv[n] = tgmv[n];
                 } while (++m < 2);
             }
 
@@ -576,12 +575,8 @@
             // if the first extended was the same as the non-extended one,
             // then replace it with the second extended one
             int n = *cnt;
-            if (n == 1 && mvstack[0].mv[0].n == same[0].mv[0].n &&
-                mvstack[0].mv[1].n == same[0].mv[1].n)
-            {
-                mvstack[1].mv[0] = mvstack[2].mv[0];
-                mvstack[1].mv[1] = mvstack[2].mv[1];
-            }
+            if (n == 1 && mvstack[0].mv.n == same[0].mv.n)
+                mvstack[1].mv = mvstack[2].mv;
             do {
                 mvstack[n].weight = 2;
             } while (++n < 2);
@@ -597,10 +592,10 @@
         const int n_refmvs = *cnt;
         int n = 0;
         do {
-            mvstack[n].mv[0].x = iclip(mvstack[n].mv[0].x, left, right);
-            mvstack[n].mv[0].y = iclip(mvstack[n].mv[0].y, top, bottom);
-            mvstack[n].mv[1].x = iclip(mvstack[n].mv[1].x, left, right);
-            mvstack[n].mv[1].y = iclip(mvstack[n].mv[1].y, top, bottom);
+            mvstack[n].mv.mv[0].x = iclip(mvstack[n].mv.mv[0].x, left, right);
+            mvstack[n].mv.mv[0].y = iclip(mvstack[n].mv.mv[0].y, top, bottom);
+            mvstack[n].mv.mv[1].x = iclip(mvstack[n].mv.mv[1].x, left, right);
+            mvstack[n].mv.mv[1].y = iclip(mvstack[n].mv.mv[1].y, top, bottom);
         } while (++n < n_refmvs);
 
         switch (refmv_ctx >> 1) {
@@ -646,13 +641,13 @@
 
         int n = 0;
         do {
-            mvstack[n].mv[0].x = iclip(mvstack[n].mv[0].x, left, right);
-            mvstack[n].mv[0].y = iclip(mvstack[n].mv[0].y, top, bottom);
+            mvstack[n].mv.mv[0].x = iclip(mvstack[n].mv.mv[0].x, left, right);
+            mvstack[n].mv.mv[0].y = iclip(mvstack[n].mv.mv[0].y, top, bottom);
         } while (++n < n_refmvs);
     }
 
     for (int n = *cnt; n < 2; n++)
-        mvstack[n].mv[0] = tgmv[0];
+        mvstack[n].mv.mv[0] = tgmv[0];
 
     *ctx = (refmv_ctx << 4) | (globalmv_ctx << 3) | newmv_ctx;
 }
@@ -786,16 +781,16 @@
             const int bw8 = (dav1d_block_dimensions[cand_b->bs][0] + 1) >> 1;
 
             if (cand_b->ref.ref[1] > 0 && ref_sign[cand_b->ref.ref[1] - 1] &&
-                (abs(cand_b->mv[1].y) | abs(cand_b->mv[1].x)) < 4096)
+                (abs(cand_b->mv.mv[1].y) | abs(cand_b->mv.mv[1].x)) < 4096)
             {
                 for (int n = 0; n < bw8; n++, x++)
-                    rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv[1],
+                    rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv.mv[1],
                                                       .ref = cand_b->ref.ref[1] };
             } else if (cand_b->ref.ref[0] > 0 && ref_sign[cand_b->ref.ref[0] - 1] &&
-                       (abs(cand_b->mv[0].y) | abs(cand_b->mv[0].x)) < 4096)
+                       (abs(cand_b->mv.mv[0].y) | abs(cand_b->mv.mv[0].x)) < 4096)
             {
                 for (int n = 0; n < bw8; n++, x++)
-                    rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv[0],
+                    rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv.mv[0],
                                                       .ref = cand_b->ref.ref[0] };
             } else {
                 for (int n = 0; n < bw8; n++, x++)
--- a/src/refmvs.h
+++ b/src/refmvs.h
@@ -50,13 +50,13 @@
     uint16_t pair;
 } refmvs_refpair;
 
-// would be nice to have a mvpair also, so double mv comparisons in
-// add_{spatial,temporal}_candidate() can be done in a single comparison,
-// but that would extend the size of refmvs_block to 16 byte (from 12)
-// (on x86-64) which we probably don't want to do.
+typedef union refmvs_mvpair {
+    mv mv[2];
+    uint64_t n;
+} refmvs_mvpair;
 
 typedef struct refmvs_block {
-    mv mv[2];
+    refmvs_mvpair mv;
     refmvs_refpair ref;
     uint8_t bs, mf; // 1 = globalmv+affine, 2 = newmv
 } refmvs_block;
@@ -93,7 +93,7 @@
 } refmvs_tile;
 
 typedef struct refmvs_candidate {
-    mv mv[2];
+    refmvs_mvpair mv;
     int weight;
 } refmvs_candidate;
 
@@ -150,7 +150,7 @@
 
     const refmvs_block tmpl = (refmvs_block) {
         .ref.ref = { ref + 1, is_interintra ? 0 : -1 },
-        .mv = { mv },
+        .mv.mv[0] = mv,
         .bs = bs,
         .mf = (mode == GLOBALMV && imin(bw4, bh4) >= 2) | ((mode == NEWMV) * 2),
     };
@@ -171,7 +171,7 @@
 
     const refmvs_block tmpl = (refmvs_block) {
         .ref.ref = { 0, -1 },
-        .mv = { mv },
+        .mv.mv[0] = mv,
         .bs = bs,
         .mf = 0,
     };
@@ -187,8 +187,8 @@
                                    const int by4, const int bx4,
                                    const enum BlockSize bs,
                                    const enum CompInterPredMode mode,
-                                   const int ref1, const int ref2,
-                                   const mv mv[2])
+                                   const refmvs_refpair ref,
+                                   const refmvs_mvpair mv)
 {
     const int bw4 = dav1d_block_dimensions[bs][0];
     int bh4 = dav1d_block_dimensions[bs][1];
@@ -196,8 +196,8 @@
 
     assert(bw4 >= 2 && bh4 >= 2);
     const refmvs_block tmpl = (refmvs_block) {
-        .ref.ref = { ref1 + 1, ref2 + 1 },
-        .mv = { mv[0], mv[1] },
+        .ref.pair = ref.pair + 0x0101,
+        .mv = mv,
         .bs = bs,
         .mf = (mode == GLOBALMV_GLOBALMV) | !!((1 << mode) & (0xbc)) * 2,
     };
@@ -218,7 +218,7 @@
 
     const refmvs_block tmpl = (refmvs_block) {
         .ref.ref = { 0, -1 },
-        .mv = { [0] = { .n = INVALID_MV } },
+        .mv.mv[0].n = INVALID_MV,
         .bs = bs,
         .mf = 0,
     };