summaryrefslogtreecommitdiff
path: root/vpx_dsp
diff options
context:
space:
mode:
authorKonstantinos Margaritis <konma@vectorcamp.gr>2022-10-06 13:05:01 +0000
committerKonstantinos Margaritis <konma@vectorcamp.gr>2022-10-12 08:11:53 +0000
commitf538a022441bdb760c3b8ad835e209a71e31d8b9 (patch)
treea8f928c25712e5b77b4ff026be059eb909ecda3f /vpx_dsp
parent5c9d20cf4492da573a4aa797c382cdfd6ce175bd (diff)
downloadlibvpx-f538a022441bdb760c3b8ad835e209a71e31d8b9.tar
libvpx-f538a022441bdb760c3b8ad835e209a71e31d8b9.tar.gz
libvpx-f538a022441bdb760c3b8ad835e209a71e31d8b9.tar.bz2
libvpx-f538a022441bdb760c3b8ad835e209a71e31d8b9.zip
[NEON] Move helper functions for reuse
Move all butterfly functions to fdct_neon.h Slightly optimize load/scale/cross functions in fdct 16x16. These will be reused in highbd variants. Change-Id: I28b6e0cc240304bab6b94d9c3f33cca77b8cb073
Diffstat (limited to 'vpx_dsp')
-rw-r--r--vpx_dsp/arm/fdct16x16_neon.c12
-rw-r--r--vpx_dsp/arm/fdct16x16_neon.h135
-rw-r--r--vpx_dsp/arm/fdct32x32_neon.c231
-rw-r--r--vpx_dsp/arm/fdct4x4_neon.c (renamed from vpx_dsp/arm/fdct_neon.c)0
-rw-r--r--vpx_dsp/arm/fdct8x8_neon.c (renamed from vpx_dsp/arm/fwd_txfm_neon.c)0
-rw-r--r--vpx_dsp/arm/fdct_neon.h130
-rw-r--r--vpx_dsp/vpx_dsp.mk4
7 files changed, 231 insertions, 281 deletions
diff --git a/vpx_dsp/arm/fdct16x16_neon.c b/vpx_dsp/arm/fdct16x16_neon.c
index 5cccb6a64..0b0ce223d 100644
--- a/vpx_dsp/arm/fdct16x16_neon.c
+++ b/vpx_dsp/arm/fdct16x16_neon.c
@@ -35,13 +35,13 @@ void vpx_fdct16x16_neon(const int16_t *input, tran_low_t *output, int stride) {
int16x8_t temp3[16];
// Left half.
- load(input, stride, temp0);
- cross_input(temp0, temp1, 0);
+ load_cross(input, stride, temp0);
+ scale_input(temp0, temp1);
vpx_fdct16x16_body(temp1, temp0);
// Right half.
- load(input + 8, stride, temp1);
- cross_input(temp1, temp2, 0);
+ load_cross(input + 8, stride, temp1);
+ scale_input(temp1, temp2);
vpx_fdct16x16_body(temp2, temp1);
// Transpose top left and top right quarters into one contiguous location to
@@ -49,7 +49,7 @@ void vpx_fdct16x16_neon(const int16_t *input, tran_low_t *output, int stride) {
transpose_s16_8x8_new(&temp0[0], &temp2[0]);
transpose_s16_8x8_new(&temp1[0], &temp2[8]);
partial_round_shift(temp2);
- cross_input(temp2, temp3, 1);
+ cross_input(temp2, temp3);
vpx_fdct16x16_body(temp3, temp2);
transpose_s16_8x8(&temp2[0], &temp2[1], &temp2[2], &temp2[3], &temp2[4],
&temp2[5], &temp2[6], &temp2[7]);
@@ -65,7 +65,7 @@ void vpx_fdct16x16_neon(const int16_t *input, tran_low_t *output, int stride) {
transpose_s16_8x8(&temp1[8], &temp1[9], &temp1[10], &temp1[11], &temp1[12],
&temp1[13], &temp1[14], &temp1[15]);
partial_round_shift(temp1);
- cross_input(temp1, temp0, 1);
+ cross_input(temp1, temp0);
vpx_fdct16x16_body(temp0, temp1);
transpose_s16_8x8(&temp1[0], &temp1[1], &temp1[2], &temp1[3], &temp1[4],
&temp1[5], &temp1[6], &temp1[7]);
diff --git a/vpx_dsp/arm/fdct16x16_neon.h b/vpx_dsp/arm/fdct16x16_neon.h
index 5ce74cdf4..7fc2c6e7e 100644
--- a/vpx_dsp/arm/fdct16x16_neon.h
+++ b/vpx_dsp/arm/fdct16x16_neon.h
@@ -13,6 +13,8 @@
#include <arm_neon.h>
+#include "fdct_neon.h"
+
static INLINE void load(const int16_t *a, int stride, int16x8_t *b /*[16]*/) {
b[0] = vld1q_s16(a);
a += stride;
@@ -72,45 +74,67 @@ static INLINE void store(tran_low_t *a, const int16x8_t *b /*[8]*/) {
// To maybe reduce register usage this could be combined with the load() step to
// get the first 4 and last 4 values, cross those, then load the middle 8 values
// and cross them.
+static INLINE void scale_input(const int16x8_t *a /*[16]*/,
+ int16x8_t *b /*[16]*/) {
+ b[0] = vshlq_n_s16(a[0], 2);
+ b[1] = vshlq_n_s16(a[1], 2);
+ b[2] = vshlq_n_s16(a[2], 2);
+ b[3] = vshlq_n_s16(a[3], 2);
+ b[4] = vshlq_n_s16(a[4], 2);
+ b[5] = vshlq_n_s16(a[5], 2);
+ b[6] = vshlq_n_s16(a[6], 2);
+ b[7] = vshlq_n_s16(a[7], 2);
+
+ b[8] = vshlq_n_s16(a[8], 2);
+ b[9] = vshlq_n_s16(a[9], 2);
+ b[10] = vshlq_n_s16(a[10], 2);
+ b[11] = vshlq_n_s16(a[11], 2);
+ b[12] = vshlq_n_s16(a[12], 2);
+ b[13] = vshlq_n_s16(a[13], 2);
+ b[14] = vshlq_n_s16(a[14], 2);
+ b[15] = vshlq_n_s16(a[15], 2);
+}
+
static INLINE void cross_input(const int16x8_t *a /*[16]*/,
- int16x8_t *b /*[16]*/, const int pass) {
- if (pass == 0) {
- b[0] = vshlq_n_s16(vaddq_s16(a[0], a[15]), 2);
- b[1] = vshlq_n_s16(vaddq_s16(a[1], a[14]), 2);
- b[2] = vshlq_n_s16(vaddq_s16(a[2], a[13]), 2);
- b[3] = vshlq_n_s16(vaddq_s16(a[3], a[12]), 2);
- b[4] = vshlq_n_s16(vaddq_s16(a[4], a[11]), 2);
- b[5] = vshlq_n_s16(vaddq_s16(a[5], a[10]), 2);
- b[6] = vshlq_n_s16(vaddq_s16(a[6], a[9]), 2);
- b[7] = vshlq_n_s16(vaddq_s16(a[7], a[8]), 2);
+ int16x8_t *b /*[16]*/) {
+ b[0] = vaddq_s16(a[0], a[15]);
+ b[1] = vaddq_s16(a[1], a[14]);
+ b[2] = vaddq_s16(a[2], a[13]);
+ b[3] = vaddq_s16(a[3], a[12]);
+ b[4] = vaddq_s16(a[4], a[11]);
+ b[5] = vaddq_s16(a[5], a[10]);
+ b[6] = vaddq_s16(a[6], a[9]);
+ b[7] = vaddq_s16(a[7], a[8]);
+
+ b[8] = vsubq_s16(a[7], a[8]);
+ b[9] = vsubq_s16(a[6], a[9]);
+ b[10] = vsubq_s16(a[5], a[10]);
+ b[11] = vsubq_s16(a[4], a[11]);
+ b[12] = vsubq_s16(a[3], a[12]);
+ b[13] = vsubq_s16(a[2], a[13]);
+ b[14] = vsubq_s16(a[1], a[14]);
+ b[15] = vsubq_s16(a[0], a[15]);
+}
- b[8] = vshlq_n_s16(vsubq_s16(a[7], a[8]), 2);
- b[9] = vshlq_n_s16(vsubq_s16(a[6], a[9]), 2);
- b[10] = vshlq_n_s16(vsubq_s16(a[5], a[10]), 2);
- b[11] = vshlq_n_s16(vsubq_s16(a[4], a[11]), 2);
- b[12] = vshlq_n_s16(vsubq_s16(a[3], a[12]), 2);
- b[13] = vshlq_n_s16(vsubq_s16(a[2], a[13]), 2);
- b[14] = vshlq_n_s16(vsubq_s16(a[1], a[14]), 2);
- b[15] = vshlq_n_s16(vsubq_s16(a[0], a[15]), 2);
- } else {
- b[0] = vaddq_s16(a[0], a[15]);
- b[1] = vaddq_s16(a[1], a[14]);
- b[2] = vaddq_s16(a[2], a[13]);
- b[3] = vaddq_s16(a[3], a[12]);
- b[4] = vaddq_s16(a[4], a[11]);
- b[5] = vaddq_s16(a[5], a[10]);
- b[6] = vaddq_s16(a[6], a[9]);
- b[7] = vaddq_s16(a[7], a[8]);
+static INLINE void load_cross(const int16_t *a, int stride,
+ int16x8_t *b /*[16]*/) {
+ b[0] = vaddq_s16(vld1q_s16(a + 0 * stride), vld1q_s16(a + 15 * stride));
+ b[1] = vaddq_s16(vld1q_s16(a + 1 * stride), vld1q_s16(a + 14 * stride));
+ b[2] = vaddq_s16(vld1q_s16(a + 2 * stride), vld1q_s16(a + 13 * stride));
+ b[3] = vaddq_s16(vld1q_s16(a + 3 * stride), vld1q_s16(a + 12 * stride));
+ b[4] = vaddq_s16(vld1q_s16(a + 4 * stride), vld1q_s16(a + 11 * stride));
+ b[5] = vaddq_s16(vld1q_s16(a + 5 * stride), vld1q_s16(a + 10 * stride));
+ b[6] = vaddq_s16(vld1q_s16(a + 6 * stride), vld1q_s16(a + 9 * stride));
+ b[7] = vaddq_s16(vld1q_s16(a + 7 * stride), vld1q_s16(a + 8 * stride));
- b[8] = vsubq_s16(a[7], a[8]);
- b[9] = vsubq_s16(a[6], a[9]);
- b[10] = vsubq_s16(a[5], a[10]);
- b[11] = vsubq_s16(a[4], a[11]);
- b[12] = vsubq_s16(a[3], a[12]);
- b[13] = vsubq_s16(a[2], a[13]);
- b[14] = vsubq_s16(a[1], a[14]);
- b[15] = vsubq_s16(a[0], a[15]);
- }
+ b[8] = vsubq_s16(vld1q_s16(a + 7 * stride), vld1q_s16(a + 8 * stride));
+ b[9] = vsubq_s16(vld1q_s16(a + 6 * stride), vld1q_s16(a + 9 * stride));
+ b[10] = vsubq_s16(vld1q_s16(a + 5 * stride), vld1q_s16(a + 10 * stride));
+ b[11] = vsubq_s16(vld1q_s16(a + 4 * stride), vld1q_s16(a + 11 * stride));
+ b[12] = vsubq_s16(vld1q_s16(a + 3 * stride), vld1q_s16(a + 12 * stride));
+ b[13] = vsubq_s16(vld1q_s16(a + 2 * stride), vld1q_s16(a + 13 * stride));
+ b[14] = vsubq_s16(vld1q_s16(a + 1 * stride), vld1q_s16(a + 14 * stride));
+ b[15] = vsubq_s16(vld1q_s16(a + 0 * stride), vld1q_s16(a + 15 * stride));
}
// Quarter round at the beginning of the second pass. Can't use vrshr (rounding)
@@ -135,45 +159,6 @@ static INLINE void partial_round_shift(int16x8_t *a /*[16]*/) {
a[15] = vshrq_n_s16(vaddq_s16(a[15], one), 2);
}
-// fdct_round_shift((a +/- b) * c)
-static INLINE void butterfly_one_coeff(const int16x8_t a, const int16x8_t b,
- const tran_high_t c, int16x8_t *add,
- int16x8_t *sub) {
- const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), c);
- const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), c);
- const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), c);
- const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), c);
- const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), c);
- const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), c);
- const int16x4_t rounded0 = vqrshrn_n_s32(sum0, 14);
- const int16x4_t rounded1 = vqrshrn_n_s32(sum1, 14);
- const int16x4_t rounded2 = vqrshrn_n_s32(diff0, 14);
- const int16x4_t rounded3 = vqrshrn_n_s32(diff1, 14);
- *add = vcombine_s16(rounded0, rounded1);
- *sub = vcombine_s16(rounded2, rounded3);
-}
-
-// fdct_round_shift(a * c0 +/- b * c1)
-static INLINE void butterfly_two_coeff(const int16x8_t a, const int16x8_t b,
- const tran_coef_t c0,
- const tran_coef_t c1, int16x8_t *add,
- int16x8_t *sub) {
- const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), c0);
- const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), c0);
- const int32x4_t a2 = vmull_n_s16(vget_low_s16(a), c1);
- const int32x4_t a3 = vmull_n_s16(vget_high_s16(a), c1);
- const int32x4_t sum0 = vmlal_n_s16(a2, vget_low_s16(b), c0);
- const int32x4_t sum1 = vmlal_n_s16(a3, vget_high_s16(b), c0);
- const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), c1);
- const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), c1);
- const int16x4_t rounded0 = vqrshrn_n_s32(sum0, 14);
- const int16x4_t rounded1 = vqrshrn_n_s32(sum1, 14);
- const int16x4_t rounded2 = vqrshrn_n_s32(diff0, 14);
- const int16x4_t rounded3 = vqrshrn_n_s32(diff1, 14);
- *add = vcombine_s16(rounded0, rounded1);
- *sub = vcombine_s16(rounded2, rounded3);
-}
-
// Main body of fdct16x16.
static void vpx_fdct16x16_body(const int16x8_t *in /*[16]*/,
int16x8_t *out /*[16]*/) {
diff --git a/vpx_dsp/arm/fdct32x32_neon.c b/vpx_dsp/arm/fdct32x32_neon.c
index de74e6630..51d81bd08 100644
--- a/vpx_dsp/arm/fdct32x32_neon.c
+++ b/vpx_dsp/arm/fdct32x32_neon.c
@@ -15,6 +15,7 @@
#include "vpx_dsp/txfm_common.h"
#include "vpx_dsp/arm/mem_neon.h"
#include "vpx_dsp/arm/transpose_neon.h"
+#include "vpx_dsp/arm/fdct_neon.h"
// Most gcc 4.9 distributions outside of Android do not generate correct code
// for this function.
@@ -194,54 +195,6 @@ static INLINE void store(tran_low_t *a, const int16x8_t *b) {
#undef STORE_S16
-// fdct_round_shift((a +/- b) * c)
-static INLINE void butterfly_one_coeff(const int16x8_t a, const int16x8_t b,
- const tran_high_t constant,
- int16x8_t *add, int16x8_t *sub) {
- const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant);
- const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant);
- const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), constant);
- const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), constant);
- const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant);
- const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant);
- const int16x4_t rounded0 = vqrshrn_n_s32(sum0, DCT_CONST_BITS);
- const int16x4_t rounded1 = vqrshrn_n_s32(sum1, DCT_CONST_BITS);
- const int16x4_t rounded2 = vqrshrn_n_s32(diff0, DCT_CONST_BITS);
- const int16x4_t rounded3 = vqrshrn_n_s32(diff1, DCT_CONST_BITS);
- *add = vcombine_s16(rounded0, rounded1);
- *sub = vcombine_s16(rounded2, rounded3);
-}
-
-// fdct_round_shift(a * c0 +/- b * c1)
-static INLINE void butterfly_two_coeff(const int16x8_t a, const int16x8_t b,
- const tran_coef_t constant0,
- const tran_coef_t constant1,
- int16x8_t *add, int16x8_t *sub) {
- const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant0);
- const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant0);
- const int32x4_t a2 = vmull_n_s16(vget_low_s16(a), constant1);
- const int32x4_t a3 = vmull_n_s16(vget_high_s16(a), constant1);
- const int32x4_t sum0 = vmlal_n_s16(a2, vget_low_s16(b), constant0);
- const int32x4_t sum1 = vmlal_n_s16(a3, vget_high_s16(b), constant0);
- const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant1);
- const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant1);
- const int16x4_t rounded0 = vqrshrn_n_s32(sum0, DCT_CONST_BITS);
- const int16x4_t rounded1 = vqrshrn_n_s32(sum1, DCT_CONST_BITS);
- const int16x4_t rounded2 = vqrshrn_n_s32(diff0, DCT_CONST_BITS);
- const int16x4_t rounded3 = vqrshrn_n_s32(diff1, DCT_CONST_BITS);
- *add = vcombine_s16(rounded0, rounded1);
- *sub = vcombine_s16(rounded2, rounded3);
-}
-
-// Add 2 if positive, 1 if negative, and shift by 2.
-// In practice, subtract the sign bit, then shift with rounding.
-static INLINE int16x8_t sub_round_shift(const int16x8_t a) {
- const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
- const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
- const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
- return vrshrq_n_s16(vsubq_s16(a, a_sign_s16), 2);
-}
-
static void dct_body_first_pass(const int16x8_t *in, int16x8_t *out) {
int16x8_t a[32];
int16x8_t b[32];
@@ -562,23 +515,6 @@ static void dct_body_first_pass(const int16x8_t *in, int16x8_t *out) {
b##_hi[b_index] = vsubq_s32(a##_hi[left_index], a##_hi[right_index]); \
} while (0)
-// Like butterfly_one_coeff, but don't narrow results.
-static INLINE void butterfly_one_coeff_s16_s32(
- const int16x8_t a, const int16x8_t b, const tran_high_t constant,
- int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
- int32x4_t *sub_hi) {
- const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant);
- const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant);
- const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), constant);
- const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), constant);
- const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant);
- const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant);
- *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
- *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
- *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
- *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
-}
-
#define BUTTERFLY_ONE_S16_S32(a, left_index, right_index, constant, b, \
add_index, sub_index) \
do { \
@@ -587,23 +523,6 @@ static INLINE void butterfly_one_coeff_s16_s32(
&b##_lo[sub_index], &b##_hi[sub_index]); \
} while (0)
-// Like butterfly_one_coeff, but with s32.
-static INLINE void butterfly_one_coeff_s32(
- const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
- const int32x4_t b_hi, const int32_t constant, int32x4_t *add_lo,
- int32x4_t *add_hi, int32x4_t *sub_lo, int32x4_t *sub_hi) {
- const int32x4_t a_lo_0 = vmulq_n_s32(a_lo, constant);
- const int32x4_t a_hi_0 = vmulq_n_s32(a_hi, constant);
- const int32x4_t sum0 = vmlaq_n_s32(a_lo_0, b_lo, constant);
- const int32x4_t sum1 = vmlaq_n_s32(a_hi_0, b_hi, constant);
- const int32x4_t diff0 = vmlsq_n_s32(a_lo_0, b_lo, constant);
- const int32x4_t diff1 = vmlsq_n_s32(a_hi_0, b_hi, constant);
- *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
- *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
- *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
- *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
-}
-
#define BUTTERFLY_ONE_S32(a, left_index, right_index, constant, b, add_index, \
sub_index) \
do { \
@@ -613,26 +532,6 @@ static INLINE void butterfly_one_coeff_s32(
&b##_lo[sub_index], &b##_hi[sub_index]); \
} while (0)
-// Like butterfly_two_coeff, but with s32.
-static INLINE void butterfly_two_coeff_s32(
- const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
- const int32x4_t b_hi, const int32_t constant0, const int32_t constant1,
- int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
- int32x4_t *sub_hi) {
- const int32x4_t a0 = vmulq_n_s32(a_lo, constant0);
- const int32x4_t a1 = vmulq_n_s32(a_hi, constant0);
- const int32x4_t a2 = vmulq_n_s32(a_lo, constant1);
- const int32x4_t a3 = vmulq_n_s32(a_hi, constant1);
- const int32x4_t sum0 = vmlaq_n_s32(a2, b_lo, constant0);
- const int32x4_t sum1 = vmlaq_n_s32(a3, b_hi, constant0);
- const int32x4_t diff0 = vmlsq_n_s32(a0, b_lo, constant1);
- const int32x4_t diff1 = vmlsq_n_s32(a1, b_hi, constant1);
- *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
- *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
- *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
- *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
-}
-
#define BUTTERFLY_TWO_S32(a, left_index, right_index, left_constant, \
right_constant, b, add_index, sub_index) \
do { \
@@ -643,24 +542,6 @@ static INLINE void butterfly_two_coeff_s32(
&b##_hi[sub_index]); \
} while (0)
-// Add 1 if positive, 2 if negative, and shift by 2.
-// In practice, add 1, then add the sign bit, then shift without rounding.
-static INLINE int16x8_t add_round_shift_s32(const int32x4_t a_lo,
- const int32x4_t a_hi) {
- const int32x4_t one = vdupq_n_s32(1);
- const uint32x4_t a_lo_u32 = vreinterpretq_u32_s32(a_lo);
- const uint32x4_t a_lo_sign_u32 = vshrq_n_u32(a_lo_u32, 31);
- const int32x4_t a_lo_sign_s32 = vreinterpretq_s32_u32(a_lo_sign_u32);
- const int16x4_t b_lo =
- vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_lo, a_lo_sign_s32), one), 2);
- const uint32x4_t a_hi_u32 = vreinterpretq_u32_s32(a_hi);
- const uint32x4_t a_hi_sign_u32 = vshrq_n_u32(a_hi_u32, 31);
- const int32x4_t a_hi_sign_s32 = vreinterpretq_s32_u32(a_hi_sign_u32);
- const int16x4_t b_hi =
- vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_hi, a_hi_sign_s32), one), 2);
- return vcombine_s16(b_lo, b_hi);
-}
-
static void dct_body_second_pass(const int16x8_t *in, int16x8_t *out) {
int16x8_t a[32];
int16x8_t b[32];
@@ -967,16 +848,6 @@ static void dct_body_second_pass(const int16x8_t *in, int16x8_t *out) {
out[3] = add_round_shift_s32(d_lo[3], d_hi[3]);
}
-// Add 1 if positive, 2 if negative, and shift by 2.
-// In practice, add 1, then add the sign bit, then shift without rounding.
-static INLINE int16x8_t add_round_shift_s16(const int16x8_t a) {
- const int16x8_t one = vdupq_n_s16(1);
- const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
- const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
- const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
- return vshrq_n_s16(vaddq_s16(vaddq_s16(a, a_sign_s16), one), 2);
-}
-
static void dct_body_second_pass_rd(const int16x8_t *in, int16x8_t *out) {
int16x8_t a[32];
int16x8_t b[32];
@@ -1279,42 +1150,6 @@ static void dct_body_second_pass_rd(const int16x8_t *in, int16x8_t *out) {
#undef BUTTERFLY_ONE_S32
#undef BUTTERFLY_TWO_S32
-// Transpose 8x8 to a new location. Don't use transpose_neon.h because those
-// are all in-place.
-// TODO(johannkoenig): share with other fdcts.
-static INLINE void transpose_8x8(const int16x8_t *a, int16x8_t *b) {
- // Swap 16 bit elements.
- const int16x8x2_t c0 = vtrnq_s16(a[0], a[1]);
- const int16x8x2_t c1 = vtrnq_s16(a[2], a[3]);
- const int16x8x2_t c2 = vtrnq_s16(a[4], a[5]);
- const int16x8x2_t c3 = vtrnq_s16(a[6], a[7]);
-
- // Swap 32 bit elements.
- const int32x4x2_t d0 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[0]),
- vreinterpretq_s32_s16(c1.val[0]));
- const int32x4x2_t d1 = vtrnq_s32(vreinterpretq_s32_s16(c0.val[1]),
- vreinterpretq_s32_s16(c1.val[1]));
- const int32x4x2_t d2 = vtrnq_s32(vreinterpretq_s32_s16(c2.val[0]),
- vreinterpretq_s32_s16(c3.val[0]));
- const int32x4x2_t d3 = vtrnq_s32(vreinterpretq_s32_s16(c2.val[1]),
- vreinterpretq_s32_s16(c3.val[1]));
-
- // Swap 64 bit elements
- const int16x8x2_t e0 = vpx_vtrnq_s64_to_s16(d0.val[0], d2.val[0]);
- const int16x8x2_t e1 = vpx_vtrnq_s64_to_s16(d1.val[0], d3.val[0]);
- const int16x8x2_t e2 = vpx_vtrnq_s64_to_s16(d0.val[1], d2.val[1]);
- const int16x8x2_t e3 = vpx_vtrnq_s64_to_s16(d1.val[1], d3.val[1]);
-
- b[0] = e0.val[0];
- b[1] = e1.val[0];
- b[2] = e2.val[0];
- b[3] = e3.val[0];
- b[4] = e0.val[1];
- b[5] = e1.val[1];
- b[6] = e2.val[1];
- b[7] = e3.val[1];
-}
-
void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
int16x8_t temp0[32];
int16x8_t temp1[32];
@@ -1337,10 +1172,10 @@ void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
dct_body_first_pass(temp0, temp4);
// Generate the top row by munging the first set of 8 from each one together.
- transpose_8x8(&temp1[0], &temp0[0]);
- transpose_8x8(&temp2[0], &temp0[8]);
- transpose_8x8(&temp3[0], &temp0[16]);
- transpose_8x8(&temp4[0], &temp0[24]);
+ transpose_s16_8x8_new(&temp1[0], &temp0[0]);
+ transpose_s16_8x8_new(&temp2[0], &temp0[8]);
+ transpose_s16_8x8_new(&temp3[0], &temp0[16]);
+ transpose_s16_8x8_new(&temp4[0], &temp0[24]);
dct_body_second_pass(temp0, temp5);
@@ -1355,10 +1190,10 @@ void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
store(output, temp5);
// Second row of 8x32.
- transpose_8x8(&temp1[8], &temp0[0]);
- transpose_8x8(&temp2[8], &temp0[8]);
- transpose_8x8(&temp3[8], &temp0[16]);
- transpose_8x8(&temp4[8], &temp0[24]);
+ transpose_s16_8x8_new(&temp1[8], &temp0[0]);
+ transpose_s16_8x8_new(&temp2[8], &temp0[8]);
+ transpose_s16_8x8_new(&temp3[8], &temp0[16]);
+ transpose_s16_8x8_new(&temp4[8], &temp0[24]);
dct_body_second_pass(temp0, temp5);
@@ -1373,10 +1208,10 @@ void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
store(output + 8 * 32, temp5);
// Third row of 8x32
- transpose_8x8(&temp1[16], &temp0[0]);
- transpose_8x8(&temp2[16], &temp0[8]);
- transpose_8x8(&temp3[16], &temp0[16]);
- transpose_8x8(&temp4[16], &temp0[24]);
+ transpose_s16_8x8_new(&temp1[16], &temp0[0]);
+ transpose_s16_8x8_new(&temp2[16], &temp0[8]);
+ transpose_s16_8x8_new(&temp3[16], &temp0[16]);
+ transpose_s16_8x8_new(&temp4[16], &temp0[24]);
dct_body_second_pass(temp0, temp5);
@@ -1391,10 +1226,10 @@ void vpx_fdct32x32_neon(const int16_t *input, tran_low_t *output, int stride) {
store(output + 16 * 32, temp5);
// Final row of 8x32.
- transpose_8x8(&temp1[24], &temp0[0]);
- transpose_8x8(&temp2[24], &temp0[8]);
- transpose_8x8(&temp3[24], &temp0[16]);
- transpose_8x8(&temp4[24], &temp0[24]);
+ transpose_s16_8x8_new(&temp1[24], &temp0[0]);
+ transpose_s16_8x8_new(&temp2[24], &temp0[8]);
+ transpose_s16_8x8_new(&temp3[24], &temp0[16]);
+ transpose_s16_8x8_new(&temp4[24], &temp0[24]);
dct_body_second_pass(temp0, temp5);
@@ -1432,10 +1267,10 @@ void vpx_fdct32x32_rd_neon(const int16_t *input, tran_low_t *output,
dct_body_first_pass(temp0, temp4);
// Generate the top row by munging the first set of 8 from each one together.
- transpose_8x8(&temp1[0], &temp0[0]);
- transpose_8x8(&temp2[0], &temp0[8]);
- transpose_8x8(&temp3[0], &temp0[16]);
- transpose_8x8(&temp4[0], &temp0[24]);
+ transpose_s16_8x8_new(&temp1[0], &temp0[0]);
+ transpose_s16_8x8_new(&temp2[0], &temp0[8]);
+ transpose_s16_8x8_new(&temp3[0], &temp0[16]);
+ transpose_s16_8x8_new(&temp4[0], &temp0[24]);
dct_body_second_pass_rd(temp0, temp5);
@@ -1450,10 +1285,10 @@ void vpx_fdct32x32_rd_neon(const int16_t *input, tran_low_t *output,
store(output, temp5);
// Second row of 8x32.
- transpose_8x8(&temp1[8], &temp0[0]);
- transpose_8x8(&temp2[8], &temp0[8]);
- transpose_8x8(&temp3[8], &temp0[16]);
- transpose_8x8(&temp4[8], &temp0[24]);
+ transpose_s16_8x8_new(&temp1[8], &temp0[0]);
+ transpose_s16_8x8_new(&temp2[8], &temp0[8]);
+ transpose_s16_8x8_new(&temp3[8], &temp0[16]);
+ transpose_s16_8x8_new(&temp4[8], &temp0[24]);
dct_body_second_pass_rd(temp0, temp5);
@@ -1468,10 +1303,10 @@ void vpx_fdct32x32_rd_neon(const int16_t *input, tran_low_t *output,
store(output + 8 * 32, temp5);
// Third row of 8x32
- transpose_8x8(&temp1[16], &temp0[0]);
- transpose_8x8(&temp2[16], &temp0[8]);
- transpose_8x8(&temp3[16], &temp0[16]);
- transpose_8x8(&temp4[16], &temp0[24]);
+ transpose_s16_8x8_new(&temp1[16], &temp0[0]);
+ transpose_s16_8x8_new(&temp2[16], &temp0[8]);
+ transpose_s16_8x8_new(&temp3[16], &temp0[16]);
+ transpose_s16_8x8_new(&temp4[16], &temp0[24]);
dct_body_second_pass_rd(temp0, temp5);
@@ -1486,10 +1321,10 @@ void vpx_fdct32x32_rd_neon(const int16_t *input, tran_low_t *output,
store(output + 16 * 32, temp5);
// Final row of 8x32.
- transpose_8x8(&temp1[24], &temp0[0]);
- transpose_8x8(&temp2[24], &temp0[8]);
- transpose_8x8(&temp3[24], &temp0[16]);
- transpose_8x8(&temp4[24], &temp0[24]);
+ transpose_s16_8x8_new(&temp1[24], &temp0[0]);
+ transpose_s16_8x8_new(&temp2[24], &temp0[8]);
+ transpose_s16_8x8_new(&temp3[24], &temp0[16]);
+ transpose_s16_8x8_new(&temp4[24], &temp0[24]);
dct_body_second_pass_rd(temp0, temp5);
diff --git a/vpx_dsp/arm/fdct_neon.c b/vpx_dsp/arm/fdct4x4_neon.c
index 2827791f1..2827791f1 100644
--- a/vpx_dsp/arm/fdct_neon.c
+++ b/vpx_dsp/arm/fdct4x4_neon.c
diff --git a/vpx_dsp/arm/fwd_txfm_neon.c b/vpx_dsp/arm/fdct8x8_neon.c
index d9161c6d3..d9161c6d3 100644
--- a/vpx_dsp/arm/fwd_txfm_neon.c
+++ b/vpx_dsp/arm/fdct8x8_neon.c
diff --git a/vpx_dsp/arm/fdct_neon.h b/vpx_dsp/arm/fdct_neon.h
index 28d7d86bf..056cae408 100644
--- a/vpx_dsp/arm/fdct_neon.h
+++ b/vpx_dsp/arm/fdct_neon.h
@@ -13,6 +13,136 @@
#include <arm_neon.h>
+// fdct_round_shift((a +/- b) * c)
+static INLINE void butterfly_one_coeff(const int16x8_t a, const int16x8_t b,
+ const tran_high_t constant,
+ int16x8_t *add, int16x8_t *sub) {
+ const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant);
+ const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant);
+ const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), constant);
+ const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), constant);
+ const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant);
+ const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant);
+ const int16x4_t rounded0 = vqrshrn_n_s32(sum0, DCT_CONST_BITS);
+ const int16x4_t rounded1 = vqrshrn_n_s32(sum1, DCT_CONST_BITS);
+ const int16x4_t rounded2 = vqrshrn_n_s32(diff0, DCT_CONST_BITS);
+ const int16x4_t rounded3 = vqrshrn_n_s32(diff1, DCT_CONST_BITS);
+ *add = vcombine_s16(rounded0, rounded1);
+ *sub = vcombine_s16(rounded2, rounded3);
+}
+
+// fdct_round_shift(a * c0 +/- b * c1)
+static INLINE void butterfly_two_coeff(const int16x8_t a, const int16x8_t b,
+ const tran_coef_t constant0,
+ const tran_coef_t constant1,
+ int16x8_t *add, int16x8_t *sub) {
+ const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant0);
+ const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant0);
+ const int32x4_t a2 = vmull_n_s16(vget_low_s16(a), constant1);
+ const int32x4_t a3 = vmull_n_s16(vget_high_s16(a), constant1);
+ const int32x4_t sum0 = vmlal_n_s16(a2, vget_low_s16(b), constant0);
+ const int32x4_t sum1 = vmlal_n_s16(a3, vget_high_s16(b), constant0);
+ const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant1);
+ const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant1);
+ const int16x4_t rounded0 = vqrshrn_n_s32(sum0, DCT_CONST_BITS);
+ const int16x4_t rounded1 = vqrshrn_n_s32(sum1, DCT_CONST_BITS);
+ const int16x4_t rounded2 = vqrshrn_n_s32(diff0, DCT_CONST_BITS);
+ const int16x4_t rounded3 = vqrshrn_n_s32(diff1, DCT_CONST_BITS);
+ *add = vcombine_s16(rounded0, rounded1);
+ *sub = vcombine_s16(rounded2, rounded3);
+}
+
+// Add 2 if positive, 1 if negative, and shift by 2.
+// In practice, subtract the sign bit, then shift with rounding.
+static INLINE int16x8_t sub_round_shift(const int16x8_t a) {
+ const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
+ const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
+ const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
+ return vrshrq_n_s16(vsubq_s16(a, a_sign_s16), 2);
+}
+
+// Like butterfly_one_coeff, but don't narrow results.
+static INLINE void butterfly_one_coeff_s16_s32(
+ const int16x8_t a, const int16x8_t b, const tran_high_t constant,
+ int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
+ int32x4_t *sub_hi) {
+ const int32x4_t a0 = vmull_n_s16(vget_low_s16(a), constant);
+ const int32x4_t a1 = vmull_n_s16(vget_high_s16(a), constant);
+ const int32x4_t sum0 = vmlal_n_s16(a0, vget_low_s16(b), constant);
+ const int32x4_t sum1 = vmlal_n_s16(a1, vget_high_s16(b), constant);
+ const int32x4_t diff0 = vmlsl_n_s16(a0, vget_low_s16(b), constant);
+ const int32x4_t diff1 = vmlsl_n_s16(a1, vget_high_s16(b), constant);
+ *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
+ *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
+ *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
+ *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
+}
+
+// Like butterfly_one_coeff, but with s32.
+static INLINE void butterfly_one_coeff_s32(
+ const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
+ const int32x4_t b_hi, const int32_t constant, int32x4_t *add_lo,
+ int32x4_t *add_hi, int32x4_t *sub_lo, int32x4_t *sub_hi) {
+ const int32x4_t a_lo_0 = vmulq_n_s32(a_lo, constant);
+ const int32x4_t a_hi_0 = vmulq_n_s32(a_hi, constant);
+ const int32x4_t sum0 = vmlaq_n_s32(a_lo_0, b_lo, constant);
+ const int32x4_t sum1 = vmlaq_n_s32(a_hi_0, b_hi, constant);
+ const int32x4_t diff0 = vmlsq_n_s32(a_lo_0, b_lo, constant);
+ const int32x4_t diff1 = vmlsq_n_s32(a_hi_0, b_hi, constant);
+ *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
+ *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
+ *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
+ *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
+}
+
+// Like butterfly_two_coeff, but with s32.
+static INLINE void butterfly_two_coeff_s32(
+ const int32x4_t a_lo, const int32x4_t a_hi, const int32x4_t b_lo,
+ const int32x4_t b_hi, const int32_t constant0, const int32_t constant1,
+ int32x4_t *add_lo, int32x4_t *add_hi, int32x4_t *sub_lo,
+ int32x4_t *sub_hi) {
+ const int32x4_t a0 = vmulq_n_s32(a_lo, constant0);
+ const int32x4_t a1 = vmulq_n_s32(a_hi, constant0);
+ const int32x4_t a2 = vmulq_n_s32(a_lo, constant1);
+ const int32x4_t a3 = vmulq_n_s32(a_hi, constant1);
+ const int32x4_t sum0 = vmlaq_n_s32(a2, b_lo, constant0);
+ const int32x4_t sum1 = vmlaq_n_s32(a3, b_hi, constant0);
+ const int32x4_t diff0 = vmlsq_n_s32(a0, b_lo, constant1);
+ const int32x4_t diff1 = vmlsq_n_s32(a1, b_hi, constant1);
+ *add_lo = vrshrq_n_s32(sum0, DCT_CONST_BITS);
+ *add_hi = vrshrq_n_s32(sum1, DCT_CONST_BITS);
+ *sub_lo = vrshrq_n_s32(diff0, DCT_CONST_BITS);
+ *sub_hi = vrshrq_n_s32(diff1, DCT_CONST_BITS);
+}
+
+// Add 1 if positive, 2 if negative, and shift by 2.
+// In practice, add 1, then add the sign bit, then shift without rounding.
+static INLINE int16x8_t add_round_shift_s16(const int16x8_t a) {
+ const int16x8_t one = vdupq_n_s16(1);
+ const uint16x8_t a_u16 = vreinterpretq_u16_s16(a);
+ const uint16x8_t a_sign_u16 = vshrq_n_u16(a_u16, 15);
+ const int16x8_t a_sign_s16 = vreinterpretq_s16_u16(a_sign_u16);
+ return vshrq_n_s16(vaddq_s16(vaddq_s16(a, a_sign_s16), one), 2);
+}
+
+// Add 1 if positive, 2 if negative, and shift by 2.
+// In practice, add 1, then add the sign bit, then shift without rounding.
+static INLINE int16x8_t add_round_shift_s32(const int32x4_t a_lo,
+ const int32x4_t a_hi) {
+ const int32x4_t one = vdupq_n_s32(1);
+ const uint32x4_t a_lo_u32 = vreinterpretq_u32_s32(a_lo);
+ const uint32x4_t a_lo_sign_u32 = vshrq_n_u32(a_lo_u32, 31);
+ const int32x4_t a_lo_sign_s32 = vreinterpretq_s32_u32(a_lo_sign_u32);
+ const int16x4_t b_lo =
+ vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_lo, a_lo_sign_s32), one), 2);
+ const uint32x4_t a_hi_u32 = vreinterpretq_u32_s32(a_hi);
+ const uint32x4_t a_hi_sign_u32 = vshrq_n_u32(a_hi_u32, 31);
+ const int32x4_t a_hi_sign_s32 = vreinterpretq_s32_u32(a_hi_sign_u32);
+ const int16x4_t b_hi =
+ vshrn_n_s32(vqaddq_s32(vqaddq_s32(a_hi, a_hi_sign_s32), one), 2);
+ return vcombine_s16(b_lo, b_hi);
+}
+
static INLINE void vpx_fdct4x4_pass1_neon(int16x4_t *in) {
const int16x8_t input_01 = vcombine_s16(in[0], in[1]);
const int16x8_t input_32 = vcombine_s16(in[3], in[2]);
diff --git a/vpx_dsp/vpx_dsp.mk b/vpx_dsp/vpx_dsp.mk
index 32d21e03f..1fd9495cf 100644
--- a/vpx_dsp/vpx_dsp.mk
+++ b/vpx_dsp/vpx_dsp.mk
@@ -227,11 +227,11 @@ ifeq ($(VPX_ARCH_X86_64),yes)
DSP_SRCS-$(HAVE_SSSE3) += x86/fwd_txfm_ssse3_x86_64.asm
endif
DSP_SRCS-$(HAVE_AVX2) += x86/fwd_dct32x32_impl_avx2.h
-DSP_SRCS-$(HAVE_NEON) += arm/fdct_neon.c
+DSP_SRCS-$(HAVE_NEON) += arm/fdct4x4_neon.c
+DSP_SRCS-$(HAVE_NEON) += arm/fdct8x8_neon.c
DSP_SRCS-$(HAVE_NEON) += arm/fdct16x16_neon.c
DSP_SRCS-$(HAVE_NEON) += arm/fdct32x32_neon.c
DSP_SRCS-$(HAVE_NEON) += arm/fdct_partial_neon.c
-DSP_SRCS-$(HAVE_NEON) += arm/fwd_txfm_neon.c
DSP_SRCS-$(HAVE_MSA) += mips/fwd_txfm_msa.h
DSP_SRCS-$(HAVE_MSA) += mips/fwd_txfm_msa.c
DSP_SRCS-$(HAVE_LSX) += loongarch/fwd_txfm_lsx.h