ref: c8b0432505d32820af0c42a94b219aa83eed5db9
parent: 2db85c269bc5479e48ea7cd4fde85236ee0bc347
author: Jonathan Wright <jonathan.wright@arm.com>
date: Tue May 11 09:17:44 EDT 2021
Implement Neon variance functions using UDOT instruction

Accelerate the Neon variance functions by implementing the sum of
squares calculation using the Armv8.4-A UDOT instruction instead of
4 MLAs. The previous implementation is retained for use on CPUs that
do not implement the Armv8.4-A dot product instructions.

Bug: b/181236880
Change-Id: I9ab3d52634278b9b6f0011f39390a1195210bc75
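The gain comes from UDOT's semantics: each u32 lane of the destination
accumulates the dot product of four adjacent u8 lanes from the two
sources. Dotting a vector of absolute differences with itself therefore
accumulates 16 squared differences in one instruction, and dotting with
a vector of ones gives a widening horizontal byte sum. A minimal
standalone sketch of the idea (not part of this patch; assumes an
AArch64 compiler with the extension enabled, e.g.
-march=armv8.2-a+dotprod):

    #include <arm_neon.h>
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
      uint8_t buf[16];
      int i;
      for (i = 0; i < 16; ++i) buf[i] = (uint8_t)(i + 1);

      const uint8x16_t d = vld1q_u8(buf);
      // Each u32 lane accumulates the products of four adjacent u8 lanes.
      const uint32x4_t sq = vdotq_u32(vdupq_n_u32(0), d, d);
      const uint32x4_t sm = vdotq_u32(vdupq_n_u32(0), d, vdupq_n_u8(1));

      // Expect 1^2 + ... + 16^2 = 1496 and 1 + ... + 16 = 136.
      printf("sse = %u, sum = %u\n", vaddvq_u32(sq), vaddvq_u32(sm));
      return 0;
    }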
--- a/vpx_dsp/arm/sum_neon.h
+++ b/vpx_dsp/arm/sum_neon.h
@@ -40,6 +40,34 @@
#endif
}
+static INLINE int32_t horizontal_add_int32x2(const int32x2_t a) {
+#if defined(__aarch64__)
+ return vaddv_s32(a);
+#else
+ return vget_lane_s32(a, 0) + vget_lane_s32(a, 1);
+#endif
+}
+
+static INLINE uint32_t horizontal_add_uint32x2(const uint32x2_t a) {
+#if defined(__aarch64__)
+ return vaddv_u32(a);
+#else
+ return vget_lane_u32(a, 0) + vget_lane_u32(a, 1);
+#endif
+}
+
+static INLINE int32_t horizontal_add_int32x4(const int32x4_t a) {
+#if defined(__aarch64__)
+ return vaddvq_s32(a);
+#else
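+ // Widen pairwise to 64-bit lanes, then add the two halves and take lane 0.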
+ const int64x2_t b = vpaddlq_s32(a);
+ const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
+ vreinterpret_s32_s64(vget_high_s64(b)));
+ return vget_lane_s32(c, 0);
+#endif
+}
+
static INLINE uint32_t horizontal_add_uint32x4(const uint32x4_t a) {
#if defined(__aarch64__)
return vaddvq_u32(a);
--- a/vpx_dsp/arm/variance_neon.c
+++ b/vpx_dsp/arm/variance_neon.c
@@ -19,6 +19,104 @@
#include "vpx_dsp/arm/sum_neon.h"
#include "vpx_ports/mem.h"
+#if defined(__ARM_FEATURE_DOTPROD) && (__ARM_FEATURE_DOTPROD == 1)
+
+// Process a block of width 4, four rows at a time.
+static void variance_neon_w4x4(const uint8_t *src_ptr, int src_stride,
+ const uint8_t *ref_ptr, int ref_stride, int h,
+ uint32_t *sse, int *sum) {
+ int i;
+ uint32x4_t sum_a = vdupq_n_u32(0);
+ uint32x4_t sum_b = vdupq_n_u32(0);
+ uint32x4_t sse_u32 = vdupq_n_u32(0);
+
+ for (i = 0; i < h; i += 4) {
+ const uint8x16_t a = load_unaligned_u8q(src_ptr, src_stride);
+ const uint8x16_t b = load_unaligned_u8q(ref_ptr, ref_stride);
+
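+ // UDOT of abs_diff with itself accumulates all 16 squared differences at once.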
+ const uint8x16_t abs_diff = vabdq_u8(a, b);
+ sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);
+
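+ // UDOT with a vector of ones computes a widening horizontal sum of the bytes.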
+ sum_a = vdotq_u32(sum_a, a, vdupq_n_u8(1));
+ sum_b = vdotq_u32(sum_b, b, vdupq_n_u8(1));
+
+ src_ptr += 4 * src_stride;
+ ref_ptr += 4 * ref_stride;
+ }
+
+ *sum = horizontal_add_int32x4(vreinterpretq_s32_u32(vsubq_u32(sum_a, sum_b)));
+ *sse = horizontal_add_uint32x4(sse_u32);
+}
+
+// Process a block of any size where the width is divisible by 16.
+static void variance_neon_w16(const uint8_t *src_ptr, int src_stride,
+ const uint8_t *ref_ptr, int ref_stride, int w,
+ int h, uint32_t *sse, int *sum) {
+ int i, j;
+ uint32x4_t sum_a = vdupq_n_u32(0);
+ uint32x4_t sum_b = vdupq_n_u32(0);
+ uint32x4_t sse_u32 = vdupq_n_u32(0);
+
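+ // The largest block is 64x64: 4096 * 255^2 < 2^32, so u32 accumulators cannot overflow.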
+ for (i = 0; i < h; ++i) {
+ for (j = 0; j < w; j += 16) {
+ const uint8x16_t a = vld1q_u8(src_ptr + j);
+ const uint8x16_t b = vld1q_u8(ref_ptr + j);
+
+ const uint8x16_t abs_diff = vabdq_u8(a, b);
+ sse_u32 = vdotq_u32(sse_u32, abs_diff, abs_diff);
+
+ sum_a = vdotq_u32(sum_a, a, vdupq_n_u8(1));
+ sum_b = vdotq_u32(sum_b, b, vdupq_n_u8(1));
+ }
+ src_ptr += src_stride;
+ ref_ptr += ref_stride;
+ }
+
+ *sum = horizontal_add_int32x4(vreinterpretq_s32_u32(vsubq_u32(sum_a, sum_b)));
+ *sse = horizontal_add_uint32x4(sse_u32);
+}
+
+// Process a block of width 8, two rows at a time.
+static void variance_neon_w8x2(const uint8_t *src_ptr, int src_stride,
+ const uint8_t *ref_ptr, int ref_stride, int h,
+ uint32_t *sse, int *sum) {
+ int i = 0;
+ uint32x2_t sum_a = vdup_n_u32(0);
+ uint32x2_t sum_b = vdup_n_u32(0);
+ uint32x2_t sse_lo_u32 = vdup_n_u32(0);
+ uint32x2_t sse_hi_u32 = vdup_n_u32(0);
+
+ do {
+ const uint8x8_t a_0 = vld1_u8(src_ptr);
+ const uint8x8_t a_1 = vld1_u8(src_ptr + src_stride);
+ const uint8x8_t b_0 = vld1_u8(ref_ptr);
+ const uint8x8_t b_1 = vld1_u8(ref_ptr + ref_stride);
+
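+ // Separate accumulators per row keep the two UDOT chains independent.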
+ const uint8x8_t abs_diff_0 = vabd_u8(a_0, b_0);
+ const uint8x8_t abs_diff_1 = vabd_u8(a_1, b_1);
+ sse_lo_u32 = vdot_u32(sse_lo_u32, abs_diff_0, abs_diff_0);
+ sse_hi_u32 = vdot_u32(sse_hi_u32, abs_diff_1, abs_diff_1);
+
+ sum_a = vdot_u32(sum_a, a_0, vdup_n_u8(1));
+ sum_b = vdot_u32(sum_b, b_0, vdup_n_u8(1));
+ sum_a = vdot_u32(sum_a, a_1, vdup_n_u8(1));
+ sum_b = vdot_u32(sum_b, b_1, vdup_n_u8(1));
+
+ src_ptr += 2 * src_stride;
+ ref_ptr += 2 * ref_stride;
+ i += 2;
+ } while (i < h);
+
+ *sum = horizontal_add_int32x2(vreinterpret_s32_u32(vsub_u32(sum_a, sum_b)));
+ *sse = horizontal_add_uint32x2(vadd_u32(sse_lo_u32, sse_hi_u32));
+}
+
+#else
+
// The variance helper functions use int16_t for sum. 8 values are accumulated
// and then added (at which point they expand up to int32_t). To avoid overflow,
// there can be no more than 32767 / 255 ~= 128 values accumulated in each
@@ -159,6 +257,8 @@
*sse = horizontal_add_uint32x4(
vreinterpretq_u32_s32(vaddq_s32(sse_lo_s32, sse_hi_s32)));
}
+
+#endif
void vpx_get8x8var_neon(const uint8_t *src_ptr, int src_stride,
const uint8_t *ref_ptr, int ref_stride,