From 3e1e38d1176c34f71a87f8402c07cdcc2e20083e Mon Sep 17 00:00:00 2001 From: Jonathan Wright Date: Thu, 4 May 2023 16:33:38 +0100 Subject: Add 2D-specific Neon horizontal convolution functions 2D 8-tap convolution filtering is performed in two passes - horizontal and vertical. The horizontal pass must produce enough input data for the subsequent vertical pass - 3 rows above and 4 rows below, in addition to the actual block height. At present, all Neon horizontal convolution algorithms process 4 rows at a time, but this means we end up doing at least 1 row too much work in the 2D first pass case where we need h + 7, not h + 8 rows of output. This patch adds additional dot-product (SDOT and USDOT) Neon paths that process h + 7 rows of data exactly, saving the work of the unnecessary extra row. It is impractical to take a similar approach for the Armv8.0 MLA paths since we have to transpose the data block both before and after calling the convolution helper functions. vpx_convolve_neon performance impact: we observe a speedup of ~9% for smaller (and wider) blocks, and a speedup of 0-3% for larger blocks. This is to be expected since the proportion of redundant work decreases as the block height increases. Change-Id: Ie77ad1848707d2d48bb8851345a469aae9d097e1 --- vpx_dsp/arm/mem_neon.h | 20 ++++ vpx_dsp/arm/vpx_convolve8_neon.c | 221 ++++++++++++++++++++++++++++++++++++++- vpx_dsp/arm/vpx_convolve8_neon.h | 9 ++ vpx_dsp/arm/vpx_convolve_neon.c | 55 ++++++++++ 4 files changed, 301 insertions(+), 4 deletions(-) diff --git a/vpx_dsp/arm/mem_neon.h b/vpx_dsp/arm/mem_neon.h index 1a20da70e..586bfb85a 100644 --- a/vpx_dsp/arm/mem_neon.h +++ b/vpx_dsp/arm/mem_neon.h @@ -263,6 +263,16 @@ static INLINE void store_u8(uint8_t *buf, ptrdiff_t stride, const uint8x8_t a) { vst1_lane_u32((uint32_t *)buf, a_u32, 1); } +static INLINE void store_u8_8x3(uint8_t *s, const ptrdiff_t p, + const uint8x8_t s0, const uint8x8_t s1, + const uint8x8_t s2) { + vst1_u8(s, s0); + s += p; + vst1_u8(s, s1); + s += p; + vst1_u8(s, s2); +} + static INLINE void load_u8_8x4(const uint8_t *s, const ptrdiff_t p, uint8x8_t *const s0, uint8x8_t *const s1, uint8x8_t *const s2, uint8x8_t *const s3) { @@ -287,6 +297,16 @@ static INLINE void store_u8_8x4(uint8_t *s, const ptrdiff_t p, vst1_u8(s, s3); } +static INLINE void load_u8_16x3(const uint8_t *s, const ptrdiff_t p, + uint8x16_t *const s0, uint8x16_t *const s1, + uint8x16_t *const s2) { + *s0 = vld1q_u8(s); + s += p; + *s1 = vld1q_u8(s); + s += p; + *s2 = vld1q_u8(s); +} + static INLINE void load_u8_16x4(const uint8_t *s, const ptrdiff_t p, uint8x16_t *const s0, uint8x16_t *const s1, uint8x16_t *const s2, uint8x16_t *const s3) { diff --git a/vpx_dsp/arm/vpx_convolve8_neon.c b/vpx_dsp/arm/vpx_convolve8_neon.c index f217a3f35..505d0672f 100644 --- a/vpx_dsp/arm/vpx_convolve8_neon.c +++ b/vpx_dsp/arm/vpx_convolve8_neon.c @@ -57,6 +57,111 @@ DECLARE_ALIGNED(16, static const uint8_t, dot_prod_merge_block_tbl[48]) = { #if defined(__ARM_FEATURE_MATMUL_INT8) +void vpx_convolve8_2d_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const InterpKernel *filter, int x0_q4, + int x_step_q4, int y0_q4, int y_step_q4, int w, + int h) { + const int8x8_t filters = vmovn_s16(vld1q_s16(filter[x0_q4])); + uint8x16_t s0, s1, s2, s3; + + assert((intptr_t)dst % 4 == 0); + assert(dst_stride % 4 == 0); + assert(x_step_q4 == 16); + assert(h % 4 == 3); + + (void)x_step_q4; + (void)y0_q4; + (void)y_step_q4; + + src -= 3; + + if (w == 4) { + const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl); + int16x4_t d0, d1, d2, d3; + uint8x8_t d01, d23; + + do { + load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3); + + d0 = convolve8_4_usdot(s0, filters, perm_tbl); + d1 = convolve8_4_usdot(s1, filters, perm_tbl); + d2 = convolve8_4_usdot(s2, filters, perm_tbl); + d3 = convolve8_4_usdot(s3, filters, perm_tbl); + d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); + + store_u8(dst + 0 * dst_stride, dst_stride, d01); + store_u8(dst + 2 * dst_stride, dst_stride, d23); + + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h > 3); + + /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for + * further details on possible values of block height. */ + load_u8_16x3(src, src_stride, &s0, &s1, &s2); + + d0 = convolve8_4_usdot(s0, filters, perm_tbl); + d1 = convolve8_4_usdot(s1, filters, perm_tbl); + d2 = convolve8_4_usdot(s2, filters, perm_tbl); + d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(d2, vdup_n_s16(0)), FILTER_BITS); + + store_u8(dst + 0 * dst_stride, dst_stride, d01); + store_u8_4x1(dst + 2 * dst_stride, d23); + } else { + const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl); + const uint8_t *s; + uint8_t *d; + int width; + uint8x8_t d0, d1, d2, d3; + + do { + width = w; + s = src; + d = dst; + do { + load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3); + + d0 = convolve8_8_usdot(s0, filters, perm_tbl); + d1 = convolve8_8_usdot(s1, filters, perm_tbl); + d2 = convolve8_8_usdot(s2, filters, perm_tbl); + d3 = convolve8_8_usdot(s3, filters, perm_tbl); + + store_u8_8x4(d, dst_stride, d0, d1, d2, d3); + + s += 8; + d += 8; + width -= 8; + } while (width > 0); + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h > 3); + + /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for + * further details on possible values of block height. */ + width = w; + s = src; + d = dst; + do { + load_u8_16x3(s, src_stride, &s0, &s1, &s2); + + d0 = convolve8_8_usdot(s0, filters, perm_tbl); + d1 = convolve8_8_usdot(s1, filters, perm_tbl); + d2 = convolve8_8_usdot(s2, filters, perm_tbl); + + store_u8_8x3(d, dst_stride, d0, d1, d2); + + s += 8; + d += 8; + width -= 8; + } while (width > 0); + } +} + void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, ptrdiff_t dst_stride, const InterpKernel *filter, int x0_q4, @@ -96,7 +201,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, src += 4 * src_stride; dst += 4 * dst_stride; h -= 4; - } while (h > 0); + } while (h != 0); } else { const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl); const uint8_t *s; @@ -125,7 +230,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, src += 4 * src_stride; dst += 4 * dst_stride; h -= 4; - } while (h > 0); + } while (h != 0); } } @@ -611,6 +716,114 @@ void vpx_convolve8_avg_vert_neon(const uint8_t *src, ptrdiff_t src_stride, #else // !defined(__ARM_FEATURE_MATMUL_INT8) +void vpx_convolve8_2d_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const InterpKernel *filter, int x0_q4, + int x_step_q4, int y0_q4, int y_step_q4, int w, + int h) { + const int8x8_t filters = vmovn_s16(vld1q_s16(filter[x0_q4])); + const int16x8_t correct_tmp = vmulq_n_s16(vld1q_s16(filter[x0_q4]), 128); + const int32x4_t correction = vdupq_n_s32((int32_t)vaddvq_s16(correct_tmp)); + const uint8x16_t range_limit = vdupq_n_u8(128); + uint8x16_t s0, s1, s2, s3; + + assert((intptr_t)dst % 4 == 0); + assert(dst_stride % 4 == 0); + assert(x_step_q4 == 16); + assert(h % 4 == 3); + + (void)x_step_q4; + (void)y0_q4; + (void)y_step_q4; + + src -= 3; + + if (w == 4) { + const uint8x16x2_t perm_tbl = vld1q_u8_x2(dot_prod_permute_tbl); + int16x4_t d0, d1, d2, d3; + uint8x8_t d01, d23; + + do { + load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3); + + d0 = convolve8_4_sdot(s0, filters, correction, range_limit, perm_tbl); + d1 = convolve8_4_sdot(s1, filters, correction, range_limit, perm_tbl); + d2 = convolve8_4_sdot(s2, filters, correction, range_limit, perm_tbl); + d3 = convolve8_4_sdot(s3, filters, correction, range_limit, perm_tbl); + d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS); + + store_u8(dst + 0 * dst_stride, dst_stride, d01); + store_u8(dst + 2 * dst_stride, dst_stride, d23); + + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h > 3); + + /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for + * further details on possible values of block height. */ + load_u8_16x3(src, src_stride, &s0, &s1, &s2); + + d0 = convolve8_4_sdot(s0, filters, correction, range_limit, perm_tbl); + d1 = convolve8_4_sdot(s1, filters, correction, range_limit, perm_tbl); + d2 = convolve8_4_sdot(s2, filters, correction, range_limit, perm_tbl); + d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS); + d23 = vqrshrun_n_s16(vcombine_s16(d2, vdup_n_s16(0)), FILTER_BITS); + + store_u8(dst + 0 * dst_stride, dst_stride, d01); + store_u8_4x1(dst + 2 * dst_stride, d23); + } else { + const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl); + const uint8_t *s; + uint8_t *d; + int width; + uint8x8_t d0, d1, d2, d3; + + do { + width = w; + s = src; + d = dst; + do { + load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3); + + d0 = convolve8_8_sdot(s0, filters, correction, range_limit, perm_tbl); + d1 = convolve8_8_sdot(s1, filters, correction, range_limit, perm_tbl); + d2 = convolve8_8_sdot(s2, filters, correction, range_limit, perm_tbl); + d3 = convolve8_8_sdot(s3, filters, correction, range_limit, perm_tbl); + + store_u8_8x4(d, dst_stride, d0, d1, d2, d3); + + s += 8; + d += 8; + width -= 8; + } while (width != 0); + src += 4 * src_stride; + dst += 4 * dst_stride; + h -= 4; + } while (h > 3); + + /* Process final three rows (h % 4 == 3). See vpx_convolve_neon.c for + * further details on possible values of block height. */ + width = w; + s = src; + d = dst; + do { + load_u8_16x3(s, src_stride, &s0, &s1, &s2); + + d0 = convolve8_8_sdot(s0, filters, correction, range_limit, perm_tbl); + d1 = convolve8_8_sdot(s1, filters, correction, range_limit, perm_tbl); + d2 = convolve8_8_sdot(s2, filters, correction, range_limit, perm_tbl); + + store_u8_8x3(d, dst_stride, d0, d1, d2); + + s += 8; + d += 8; + width -= 8; + } while (width != 0); + } +} + void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, ptrdiff_t dst_stride, const InterpKernel *filter, int x0_q4, @@ -653,7 +866,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, src += 4 * src_stride; dst += 4 * dst_stride; h -= 4; - } while (h > 0); + } while (h != 0); } else { const uint8x16x3_t perm_tbl = vld1q_u8_x3(dot_prod_permute_tbl); const uint8_t *s; @@ -682,7 +895,7 @@ void vpx_convolve8_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, src += 4 * src_stride; dst += 4 * dst_stride; h -= 4; - } while (h > 0); + } while (h != 0); } } diff --git a/vpx_dsp/arm/vpx_convolve8_neon.h b/vpx_dsp/arm/vpx_convolve8_neon.h index c838d4047..2f78583af 100644 --- a/vpx_dsp/arm/vpx_convolve8_neon.h +++ b/vpx_dsp/arm/vpx_convolve8_neon.h @@ -17,6 +17,15 @@ #include "./vpx_dsp_rtcd.h" #include "vpx_dsp/vpx_filter.h" +#if VPX_ARCH_AARCH64 && \ + (defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)) +void vpx_convolve8_2d_horiz_neon(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const InterpKernel *filter, int x0_q4, + int x_step_q4, int y0_q4, int y_step_q4, int w, + int h); +#endif + #if VPX_ARCH_AARCH64 && defined(__ARM_FEATURE_DOTPROD) static INLINE int16x4_t convolve8_4_sdot_partial(const int8x16_t samples_lo, diff --git a/vpx_dsp/arm/vpx_convolve_neon.c b/vpx_dsp/arm/vpx_convolve_neon.c index 830f3176d..f7db3e6a9 100644 --- a/vpx_dsp/arm/vpx_convolve_neon.c +++ b/vpx_dsp/arm/vpx_convolve_neon.c @@ -14,6 +14,57 @@ #include "vpx_dsp/vpx_dsp_common.h" #include "vpx_ports/mem.h" +#if VPX_ARCH_AARCH64 && \ + (defined(__ARM_FEATURE_DOTPROD) || defined(__ARM_FEATURE_MATMUL_INT8)) +#include "vpx_dsp/arm/vpx_convolve8_neon.h" + +void vpx_convolve8_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, + ptrdiff_t dst_stride, const InterpKernel *filter, + int x0_q4, int x_step_q4, int y0_q4, int y_step_q4, + int w, int h) { + /* Given our constraints: w <= 64, h <= 64, taps == 8 we can reduce the + * maximum buffer size to 64 * (64 + 7). */ + uint8_t temp[64 * 71]; + + /* Account for the vertical phase needing 3 lines prior and 4 lines post. */ + const int intermediate_height = h + 7; + + assert(y_step_q4 == 16); + assert(x_step_q4 == 16); + + /* Filter starting 3 lines back. */ + vpx_convolve8_2d_horiz_neon(src - src_stride * 3, src_stride, temp, w, filter, + x0_q4, x_step_q4, y0_q4, y_step_q4, w, + intermediate_height); + + /* Step into the temp buffer 3 lines to get the actual frame data */ + vpx_convolve8_vert_neon(temp + w * 3, w, dst, dst_stride, filter, x0_q4, + x_step_q4, y0_q4, y_step_q4, w, h); +} + +void vpx_convolve8_avg_neon(const uint8_t *src, ptrdiff_t src_stride, + uint8_t *dst, ptrdiff_t dst_stride, + const InterpKernel *filter, int x0_q4, + int x_step_q4, int y0_q4, int y_step_q4, int w, + int h) { + uint8_t temp[64 * 71]; + const int intermediate_height = h + 7; + + assert(y_step_q4 == 16); + assert(x_step_q4 == 16); + + vpx_convolve8_2d_horiz_neon(src - src_stride * 3, src_stride, temp, w, filter, + x0_q4, x_step_q4, y0_q4, y_step_q4, w, + intermediate_height); + + vpx_convolve8_avg_vert_neon(temp + w * 3, w, dst, dst_stride, filter, x0_q4, + x_step_q4, y0_q4, y_step_q4, w, h); +} + +#else // !(VPX_ARCH_AARCH64 && + // (defined(__ARM_FEATURE_DOTPROD) || + // defined(__ARM_FEATURE_MATMUL_INT8))) + void vpx_convolve8_neon(const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst, ptrdiff_t dst_stride, const InterpKernel *filter, int x0_q4, int x_step_q4, int y0_q4, int y_step_q4, @@ -63,3 +114,7 @@ void vpx_convolve8_avg_neon(const uint8_t *src, ptrdiff_t src_stride, vpx_convolve8_avg_vert_neon(temp + w * 3, w, dst, dst_stride, filter, x0_q4, x_step_q4, y0_q4, y_step_q4, w, h); } + +#endif // #if VPX_ARCH_AARCH64 && + // (defined(__ARM_FEATURE_DOTPROD) || + // defined(__ARM_FEATURE_MATMUL_INT8)) -- cgit v1.2.3