shithub: libvpx

ref: 0f563e5fadbccb10fabd6ac80c256a4321401e22
parent: f7364c05748b70a1e0fd57849665a9d9f0990803
author: Jonathan Wright <jonathan.wright@arm.com>
date: Fri May 7 09:25:51 EDT 2021

Optimize Neon reductions in sum_neon.h using ADDV instruction

Use the AArch64-only ADDV and ADDLV instructions to accelerate
reductions that add across a Neon vector in sum_neon.h. This commit
also refactors the inline functions to return a scalar instead of a
vector, which allows the surrounding code at each call site to be
optimized.

Bug: b/181236880
Change-Id: Ieed2a2dd3c74f8a52957bf404141ffc044bd5d79

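For readers less familiar with the intrinsics involved, here is a
minimal standalone sketch of the reduction pattern the patch adopts
(the helper name example_reduce_u16x8 is hypothetical; the real helpers
live in vpx_dsp/arm/sum_neon.h):

    #include <arm_neon.h>
    #include <stdint.h>

    static inline uint32_t example_reduce_u16x8(uint16x8_t v) {
    #if defined(__aarch64__)
      /* UADDLV: widen each u16 lane and add across the whole vector in
       * a single instruction, yielding a u32 scalar directly. */
      return vaddlvq_u16(v);
    #else
      /* Armv7 fallback: pairwise widening adds, then combine halves. */
      const uint32x4_t b = vpaddlq_u16(v);
      const uint64x2_t c = vpaddlq_u32(b);
      const uint32x2_t d =
          vadd_u32(vreinterpret_u32_u64(vget_low_u64(c)),
                   vreinterpret_u32_u64(vget_high_u64(c)));
      return vget_lane_u32(d, 0);
    #endif
    }
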
--- a/vpx_dsp/arm/avg_neon.c
+++ b/vpx_dsp/arm/avg_neon.c
@@ -22,8 +22,7 @@
 uint32_t vpx_avg_4x4_neon(const uint8_t *a, int a_stride) {
   const uint8x16_t b = load_unaligned_u8q(a, a_stride);
   const uint16x8_t c = vaddl_u8(vget_low_u8(b), vget_high_u8(b));
-  const uint32x2_t d = horizontal_add_uint16x8(c);
-  return vget_lane_u32(vrshr_n_u32(d, 4), 0);
+  return (horizontal_add_uint16x8(c) + (1 << 3)) >> 4;
 }
 
 uint32_t vpx_avg_8x8_neon(const uint8_t *a, int a_stride) {
@@ -30,7 +29,6 @@
   int i;
   uint8x8_t b, c;
   uint16x8_t sum;
-  uint32x2_t d;
   b = vld1_u8(a);
   a += a_stride;
   c = vld1_u8(a);
@@ -43,9 +41,7 @@
     sum = vaddw_u8(sum, d);
   }
 
-  d = horizontal_add_uint16x8(sum);
-
-  return vget_lane_u32(vrshr_n_u32(d, 6), 0);
+  return (horizontal_add_uint16x8(sum) + (1 << 5)) >> 6;
 }
 
 // coeff: 16 bits, dynamic range [-32640, 32640].
@@ -139,8 +135,7 @@
     ref += 16;
   }
 
-  return vget_lane_s16(vreinterpret_s16_u32(horizontal_add_uint16x8(vec_sum)),
-                       0);
+  return (int16_t)horizontal_add_uint16x8(vec_sum);
 }
 
 // ref, src = [0, 510] - max diff = 16-bits
--- a/vpx_dsp/arm/fdct_partial_neon.c
+++ b/vpx_dsp/arm/fdct_partial_neon.c
@@ -15,19 +15,10 @@
 #include "vpx_dsp/arm/mem_neon.h"
 #include "vpx_dsp/arm/sum_neon.h"
 
-static INLINE tran_low_t get_lane(const int32x2_t a) {
-#if CONFIG_VP9_HIGHBITDEPTH
-  return vget_lane_s32(a, 0);
-#else
-  return vget_lane_s16(vreinterpret_s16_s32(a), 0);
-#endif  // CONFIG_VP9_HIGHBITDETPH
-}
-
 void vpx_fdct4x4_1_neon(const int16_t *input, tran_low_t *output, int stride) {
   int16x4_t a0, a1, a2, a3;
   int16x8_t b0, b1;
   int16x8_t c;
-  int32x2_t d;
 
   a0 = vld1_s16(input);
   input += stride;
@@ -42,9 +33,7 @@
 
   c = vaddq_s16(b0, b1);
 
-  d = horizontal_add_int16x8(c);
-
-  output[0] = get_lane(vshl_n_s32(d, 1));
+  output[0] = (tran_low_t)(horizontal_add_int16x8(c) << 1);
   output[1] = 0;
 }
 
@@ -57,7 +46,7 @@
     sum = vaddq_s16(sum, input_00);
   }
 
-  output[0] = get_lane(horizontal_add_int16x8(sum));
+  output[0] = (tran_low_t)horizontal_add_int16x8(sum);
   output[1] = 0;
 }
 
@@ -66,7 +55,7 @@
   int r;
   int16x8_t left = vld1q_s16(input);
   int16x8_t right = vld1q_s16(input + 8);
-  int32x2_t sum;
+  int32_t sum;
   input += stride;
 
   for (r = 1; r < 16; ++r) {
@@ -77,9 +66,9 @@
     right = vaddq_s16(right, b);
   }
 
-  sum = vadd_s32(horizontal_add_int16x8(left), horizontal_add_int16x8(right));
+  sum = horizontal_add_int16x8(left) + horizontal_add_int16x8(right);
 
-  output[0] = get_lane(vshr_n_s32(sum, 1));
+  output[0] = (tran_low_t)(sum >> 1);
   output[1] = 0;
 }
 
@@ -90,7 +79,7 @@
   int16x8_t a1 = vld1q_s16(input + 8);
   int16x8_t a2 = vld1q_s16(input + 16);
   int16x8_t a3 = vld1q_s16(input + 24);
-  int32x2_t sum;
+  int32_t sum;
   input += stride;
 
   for (r = 1; r < 32; ++r) {
@@ -105,9 +94,10 @@
     a3 = vaddq_s16(a3, b3);
   }
 
-  sum = vadd_s32(horizontal_add_int16x8(a0), horizontal_add_int16x8(a1));
-  sum = vadd_s32(sum, horizontal_add_int16x8(a2));
-  sum = vadd_s32(sum, horizontal_add_int16x8(a3));
-  output[0] = get_lane(vshr_n_s32(sum, 3));
+  sum = horizontal_add_int16x8(a0);
+  sum += horizontal_add_int16x8(a1);
+  sum += horizontal_add_int16x8(a2);
+  sum += horizontal_add_int16x8(a3);
+  output[0] = (tran_low_t)(sum >> 3);
   output[1] = 0;
 }
--- a/vpx_dsp/arm/sad_neon.c
+++ b/vpx_dsp/arm/sad_neon.c
@@ -23,7 +23,7 @@
   const uint8x16_t ref_u8 = load_unaligned_u8q(ref_ptr, ref_stride);
   uint16x8_t abs = vabdl_u8(vget_low_u8(src_u8), vget_low_u8(ref_u8));
   abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(ref_u8));
-  return vget_lane_u32(horizontal_add_uint16x8(abs), 0);
+  return horizontal_add_uint16x8(abs);
 }
 
 uint32_t vpx_sad4x4_avg_neon(const uint8_t *src_ptr, int src_stride,
@@ -35,7 +35,7 @@
   const uint8x16_t avg = vrhaddq_u8(ref_u8, second_pred_u8);
   uint16x8_t abs = vabdl_u8(vget_low_u8(src_u8), vget_low_u8(avg));
   abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(avg));
-  return vget_lane_u32(horizontal_add_uint16x8(abs), 0);
+  return horizontal_add_uint16x8(abs);
 }
 
 uint32_t vpx_sad4x8_neon(const uint8_t *src_ptr, int src_stride,
@@ -51,7 +51,7 @@
     abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(ref_u8));
   }
 
-  return vget_lane_u32(horizontal_add_uint16x8(abs), 0);
+  return horizontal_add_uint16x8(abs);
 }
 
 uint32_t vpx_sad4x8_avg_neon(const uint8_t *src_ptr, int src_stride,
@@ -71,7 +71,7 @@
     abs = vabal_u8(abs, vget_high_u8(src_u8), vget_high_u8(avg));
   }
 
-  return vget_lane_u32(horizontal_add_uint16x8(abs), 0);
+  return horizontal_add_uint16x8(abs);
 }
 
 static INLINE uint16x8_t sad8x(const uint8_t *src_ptr, int src_stride,
@@ -114,7 +114,7 @@
   uint32_t vpx_sad8x##n##_neon(const uint8_t *src_ptr, int src_stride,         \
                                const uint8_t *ref_ptr, int ref_stride) {       \
     const uint16x8_t abs = sad8x(src_ptr, src_stride, ref_ptr, ref_stride, n); \
-    return vget_lane_u32(horizontal_add_uint16x8(abs), 0);                     \
+    return horizontal_add_uint16x8(abs);                                       \
   }                                                                            \
                                                                                \
   uint32_t vpx_sad8x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,     \
@@ -122,7 +122,7 @@
                                    const uint8_t *second_pred) {               \
     const uint16x8_t abs =                                                     \
         sad8x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n);   \
-    return vget_lane_u32(horizontal_add_uint16x8(abs), 0);                     \
+    return horizontal_add_uint16x8(abs);                                       \
   }
 
 sad8xN(4);
@@ -172,7 +172,7 @@
                                 const uint8_t *ref_ptr, int ref_stride) {     \
     const uint16x8_t abs =                                                    \
         sad16x(src_ptr, src_stride, ref_ptr, ref_stride, n);                  \
-    return vget_lane_u32(horizontal_add_uint16x8(abs), 0);                    \
+    return horizontal_add_uint16x8(abs);                                      \
   }                                                                           \
                                                                               \
   uint32_t vpx_sad16x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,   \
@@ -180,7 +180,7 @@
                                     const uint8_t *second_pred) {             \
     const uint16x8_t abs =                                                    \
         sad16x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \
-    return vget_lane_u32(horizontal_add_uint16x8(abs), 0);                    \
+    return horizontal_add_uint16x8(abs);                                      \
   }
 
 sad16xN(8);
@@ -240,7 +240,7 @@
                                 const uint8_t *ref_ptr, int ref_stride) {     \
     const uint16x8_t abs =                                                    \
         sad32x(src_ptr, src_stride, ref_ptr, ref_stride, n);                  \
-    return vget_lane_u32(horizontal_add_uint16x8(abs), 0);                    \
+    return horizontal_add_uint16x8(abs);                                      \
   }                                                                           \
                                                                               \
   uint32_t vpx_sad32x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,   \
@@ -248,7 +248,7 @@
                                     const uint8_t *second_pred) {             \
     const uint16x8_t abs =                                                    \
         sad32x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \
-    return vget_lane_u32(horizontal_add_uint16x8(abs), 0);                    \
+    return horizontal_add_uint16x8(abs);                                      \
   }
 
 sad32xN(16);
@@ -338,7 +338,7 @@
                                 const uint8_t *ref_ptr, int ref_stride) {     \
     const uint32x4_t abs =                                                    \
         sad64x(src_ptr, src_stride, ref_ptr, ref_stride, n);                  \
-    return vget_lane_u32(horizontal_add_uint32x4(abs), 0);                    \
+    return horizontal_add_uint32x4(abs);                                      \
   }                                                                           \
                                                                               \
   uint32_t vpx_sad64x##n##_avg_neon(const uint8_t *src_ptr, int src_stride,   \
@@ -346,7 +346,7 @@
                                     const uint8_t *second_pred) {             \
     const uint32x4_t abs =                                                    \
         sad64x_avg(src_ptr, src_stride, ref_ptr, ref_stride, second_pred, n); \
-    return vget_lane_u32(horizontal_add_uint32x4(abs), 0);                    \
+    return horizontal_add_uint32x4(abs);                                      \
   }
 
 sad64xN(32);
--- a/vpx_dsp/arm/sum_neon.h
+++ b/vpx_dsp/arm/sum_neon.h
@@ -16,23 +16,38 @@
 #include "./vpx_config.h"
 #include "vpx/vpx_integer.h"
 
-static INLINE int32x2_t horizontal_add_int16x8(const int16x8_t a) {
+static INLINE int32_t horizontal_add_int16x8(const int16x8_t a) {
+#if defined(__aarch64__)
+  return vaddlvq_s16(a);
+#else
   const int32x4_t b = vpaddlq_s16(a);
   const int64x2_t c = vpaddlq_s32(b);
-  return vadd_s32(vreinterpret_s32_s64(vget_low_s64(c)),
-                  vreinterpret_s32_s64(vget_high_s64(c)));
+  const int32x2_t d = vadd_s32(vreinterpret_s32_s64(vget_low_s64(c)),
+                               vreinterpret_s32_s64(vget_high_s64(c)));
+  return vget_lane_s32(d, 0);
+#endif
 }
 
-static INLINE uint32x2_t horizontal_add_uint16x8(const uint16x8_t a) {
+static INLINE uint32_t horizontal_add_uint16x8(const uint16x8_t a) {
+#if defined(__aarch64__)
+  return vaddlvq_u16(a);
+#else
   const uint32x4_t b = vpaddlq_u16(a);
   const uint64x2_t c = vpaddlq_u32(b);
-  return vadd_u32(vreinterpret_u32_u64(vget_low_u64(c)),
-                  vreinterpret_u32_u64(vget_high_u64(c)));
+  const uint32x2_t d = vadd_u32(vreinterpret_u32_u64(vget_low_u64(c)),
+                                vreinterpret_u32_u64(vget_high_u64(c)));
+  return vget_lane_u32(d, 0);
+#endif
 }
 
-static INLINE uint32x2_t horizontal_add_uint32x4(const uint32x4_t a) {
+static INLINE uint32_t horizontal_add_uint32x4(const uint32x4_t a) {
+#if defined(__aarch64__)
+  return vaddvq_u32(a);
+#else
   const uint64x2_t b = vpaddlq_u32(a);
-  return vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)),
-                  vreinterpret_u32_u64(vget_high_u64(b)));
+  const uint32x2_t c = vadd_u32(vreinterpret_u32_u64(vget_low_u64(b)),
+                                vreinterpret_u32_u64(vget_high_u64(b)));
+  return vget_lane_u32(c, 0);
+#endif
 }
 #endif  // VPX_VPX_DSP_ARM_SUM_NEON_H_
--- a/vpx_dsp/arm/variance_neon.c
+++ b/vpx_dsp/arm/variance_neon.c
@@ -66,10 +66,9 @@
     ref_ptr += 4 * ref_stride;
   }
 
-  *sum = vget_lane_s32(horizontal_add_int16x8(sum_s16), 0);
-  *sse = vget_lane_u32(horizontal_add_uint32x4(vreinterpretq_u32_s32(
-                           vaddq_s32(sse_lo_s32, sse_hi_s32))),
-                       0);
+  *sum = horizontal_add_int16x8(sum_s16);
+  *sse = horizontal_add_uint32x4(
+      vreinterpretq_u32_s32(vaddq_s32(sse_lo_s32, sse_hi_s32)));
 }
 
 // Process a block of any size where the width is divisible by 16.
@@ -115,10 +114,9 @@
     ref_ptr += ref_stride;
   }
 
-  *sum = vget_lane_s32(horizontal_add_int16x8(sum_s16), 0);
-  *sse = vget_lane_u32(horizontal_add_uint32x4(vreinterpretq_u32_s32(
-                           vaddq_s32(sse_lo_s32, sse_hi_s32))),
-                       0);
+  *sum = horizontal_add_int16x8(sum_s16);
+  *sse = horizontal_add_uint32x4(
+      vreinterpretq_u32_s32(vaddq_s32(sse_lo_s32, sse_hi_s32)));
 }
 
 // Process a block of width 8 two rows at a time.
@@ -157,10 +155,9 @@
     i += 2;
   } while (i < h);
 
-  *sum = vget_lane_s32(horizontal_add_int16x8(sum_s16), 0);
-  *sse = vget_lane_u32(horizontal_add_uint32x4(vreinterpretq_u32_s32(
-                           vaddq_s32(sse_lo_s32, sse_hi_s32))),
-                       0);
+  *sum = horizontal_add_int16x8(sum_s16);
+  *sse = horizontal_add_uint32x4(
+      vreinterpretq_u32_s32(vaddq_s32(sse_lo_s32, sse_hi_s32)));
 }
 
 void vpx_get8x8var_neon(const uint8_t *src_ptr, int src_stride,
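
A side note on the call-site simplification in avg_neon.c
(illustrative, not part of the patch): vrshr_n_u32(d, 4) performs a
rounding right shift, i.e. (d + (1 << 3)) >> 4, so once the reduction
returns a scalar the same round-to-nearest division by 16 can be
written directly in C:

    /* Hypothetical helper: round-to-nearest mean of the 16-pixel sum,
     * matching the old vget_lane_u32(vrshr_n_u32(d, 4), 0) result. */
    static inline uint32_t round_avg_4x4(uint32_t sum) {
      return (sum + (1 << 3)) >> 4;  /* add half the divisor, then shift */
    }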