author | Linfeng Zhang <linfengz@google.com> | 2018-04-17 13:43:16 -0700 |
---|---|---|
committer | Linfeng Zhang <linfengz@google.com> | 2018-04-17 13:43:16 -0700 |
commit | 78ba83bb917484336ee782f03adde4a6e5afb626 (patch) | |
tree | cf9bc37260155c924d4e1fd27807038803c5309e /vpx_dsp | |
parent | 55ca875e6bc2da9539fac2839611b3945f262e49 (diff) | |
Update variance avx2 functions
Old vs New
Variance 64x64 time: 1145 ms 797 ms
Variance 64x32 time: 1200 ms 831 ms
Variance 32x32 time: 1228 ms 1135 ms
Variance 32x16 time: 1374 ms 1491 ms
Variance 16x16 time: 1688 ms 1571 ms
sse2 vs avx2
Variance 32x64 time: 1645 ms 957 ms
Variance 16x32 time: 2031 ms 1243 ms
Variance 16x8 time: 3071 ms 2275 ms
Change-Id: I0202a556e4629977d647e219c2e897e1ab6accb2
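All of the timed functions compute the same quantity: the kernels accumulate a sum of differences and a sum of squared differences (SSE), and the returned variance is derived from those two totals. A minimal scalar sketch of that final step is shown below (the `variance_from_sse_sum` helper name is illustrative, not part of libvpx); the shift is log2(w*h), which is why the functions in the diff use >> 7 for 16x8, >> 8 for 16x16, and so on up to >> 12 for 64x64.

```c
#include <stdint.h>

/* variance = SSE - sum^2 / (w * h); since w * h is a power of two, the
 * division is a right shift, matching the return statements in the patch. */
static uint32_t variance_from_sse_sum(uint32_t sse, int sum, int shift) {
  return sse - (uint32_t)(((int64_t)sum * sum) >> shift);
}
```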
Diffstat (limited to 'vpx_dsp')
-rw-r--r-- | vpx_dsp/vpx_dsp_rtcd_defs.pl | 8 |
-rw-r--r-- | vpx_dsp/x86/variance_avx2.c | 363 |
2 files changed, 216 insertions, 155 deletions
diff --git a/vpx_dsp/vpx_dsp_rtcd_defs.pl b/vpx_dsp/vpx_dsp_rtcd_defs.pl
index a51761cd3..4dbee088b 100644
--- a/vpx_dsp/vpx_dsp_rtcd_defs.pl
+++ b/vpx_dsp/vpx_dsp_rtcd_defs.pl
@@ -1088,7 +1088,7 @@ add_proto qw/unsigned int vpx_variance64x32/, "const uint8_t *src_ptr, int sourc
   specialize qw/vpx_variance64x32 sse2 avx2 neon msa mmi/;
 
 add_proto qw/unsigned int vpx_variance32x64/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
-  specialize qw/vpx_variance32x64 sse2 neon msa mmi/;
+  specialize qw/vpx_variance32x64 sse2 avx2 neon msa mmi/;
 
 add_proto qw/unsigned int vpx_variance32x32/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
   specialize qw/vpx_variance32x32 sse2 avx2 neon msa mmi/;
@@ -1097,13 +1097,13 @@ add_proto qw/unsigned int vpx_variance32x16/, "const uint8_t *src_ptr, int sourc
   specialize qw/vpx_variance32x16 sse2 avx2 neon msa mmi/;
 
 add_proto qw/unsigned int vpx_variance16x32/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
-  specialize qw/vpx_variance16x32 sse2 neon msa mmi/;
+  specialize qw/vpx_variance16x32 sse2 avx2 neon msa mmi/;
 
 add_proto qw/unsigned int vpx_variance16x16/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
   specialize qw/vpx_variance16x16 sse2 avx2 neon msa mmi/;
 
 add_proto qw/unsigned int vpx_variance16x8/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
-  specialize qw/vpx_variance16x8 sse2 neon msa mmi/;
+  specialize qw/vpx_variance16x8 sse2 avx2 neon msa mmi/;
 
 add_proto qw/unsigned int vpx_variance8x16/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int ref_stride, unsigned int *sse";
   specialize qw/vpx_variance8x16 sse2 neon msa mmi/;
@@ -1133,7 +1133,7 @@ add_proto qw/unsigned int vpx_mse16x16/, "const uint8_t *src_ptr, int source_st
   specialize qw/vpx_mse16x16 sse2 avx2 neon msa mmi/;
 
 add_proto qw/unsigned int vpx_mse16x8/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
-  specialize qw/vpx_mse16x8 sse2 msa mmi/;
+  specialize qw/vpx_mse16x8 sse2 avx2 msa mmi/;
 
 add_proto qw/unsigned int vpx_mse8x16/, "const uint8_t *src_ptr, int source_stride, const uint8_t *ref_ptr, int recon_stride, unsigned int *sse";
   specialize qw/vpx_mse8x16 sse2 msa mmi/;
diff --git a/vpx_dsp/x86/variance_avx2.c b/vpx_dsp/x86/variance_avx2.c
index d15a89c74..d938b81ea 100644
--- a/vpx_dsp/x86/variance_avx2.c
+++ b/vpx_dsp/x86/variance_avx2.c
@@ -38,130 +38,140 @@ DECLARE_ALIGNED(32, static const int8_t, adjacent_sub_avx2[32]) = {
 };
 /* clang-format on */
 
-void vpx_get16x16var_avx2(const unsigned char *src_ptr, int source_stride,
-                          const unsigned char *ref_ptr, int recon_stride,
-                          unsigned int *sse, int *sum) {
-  unsigned int i, src_2strides, ref_2strides;
-  __m256i sum_reg = _mm256_setzero_si256();
-  __m256i sse_reg = _mm256_setzero_si256();
-  // process two 16 byte locations in a 256 bit register
-  src_2strides = source_stride << 1;
-  ref_2strides = recon_stride << 1;
-  for (i = 0; i < 8; ++i) {
-    // convert up values in 128 bit registers across lanes
-    const __m256i src0 =
-        _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i const *)(src_ptr)));
-    const __m256i src1 = _mm256_cvtepu8_epi16(
-        _mm_loadu_si128((__m128i const *)(src_ptr + source_stride)));
-    const __m256i ref0 =
-        _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i const *)(ref_ptr)));
-    const __m256i ref1 = _mm256_cvtepu8_epi16(
-        _mm_loadu_si128((__m128i const *)(ref_ptr + recon_stride)));
-    const __m256i diff0 = _mm256_sub_epi16(src0, ref0);
-    const __m256i diff1 = _mm256_sub_epi16(src1, ref1);
-    const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
-    const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);
-
-    // add to the running totals
-    sum_reg = _mm256_add_epi16(sum_reg, _mm256_add_epi16(diff0, diff1));
-    sse_reg = _mm256_add_epi32(sse_reg, _mm256_add_epi32(madd0, madd1));
-
-    src_ptr += src_2strides;
-    ref_ptr += ref_2strides;
-  }
-  {
-    // extract the low lane and add it to the high lane
-    const __m128i sum_reg_128 = _mm_add_epi16(
-        _mm256_castsi256_si128(sum_reg), _mm256_extractf128_si256(sum_reg, 1));
-    const __m128i sse_reg_128 = _mm_add_epi32(
-        _mm256_castsi256_si128(sse_reg), _mm256_extractf128_si256(sse_reg, 1));
-
-    // sum upper and lower 64 bits together and convert up to 32 bit values
-    const __m128i sum_reg_64 =
-        _mm_add_epi16(sum_reg_128, _mm_srli_si128(sum_reg_128, 8));
-    const __m128i sum_int32 = _mm_cvtepi16_epi32(sum_reg_64);
-
-    // unpack sse and sum registers and add
-    const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, sum_int32);
-    const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, sum_int32);
-    const __m128i sse_sum = _mm_add_epi32(sse_sum_lo, sse_sum_hi);
-
-    // perform the final summation and extract the results
-    const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
-    *((int *)sse) = _mm_cvtsi128_si32(res);
-    *((int *)sum) = _mm_extract_epi32(res, 1);
+static INLINE void variance_kernel_avx2(const __m256i src, const __m256i ref,
+                                        __m256i *const sse,
+                                        __m256i *const sum) {
+  const __m256i adj_sub = _mm256_load_si256((__m256i const *)adjacent_sub_avx2);
+
+  // unpack into pairs of source and reference values
+  const __m256i src_ref0 = _mm256_unpacklo_epi8(src, ref);
+  const __m256i src_ref1 = _mm256_unpackhi_epi8(src, ref);
+
+  // subtract adjacent elements using src*1 + ref*-1
+  const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
+  const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
+  const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
+  const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);
+
+  // add to the running totals
+  *sum = _mm256_add_epi16(*sum, _mm256_add_epi16(diff0, diff1));
+  *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(madd0, madd1));
+}
+
+static INLINE void variance_final_from_32bit_sum_avx2(__m256i vsse,
+                                                      __m128i vsum,
+                                                      unsigned int *const sse,
+                                                      int *const sum) {
+  // extract the low lane and add it to the high lane
+  const __m128i sse_reg_128 = _mm_add_epi32(_mm256_castsi256_si128(vsse),
+                                            _mm256_extractf128_si256(vsse, 1));
+
+  // unpack sse and sum registers and add
+  const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, vsum);
+  const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, vsum);
+  const __m128i sse_sum = _mm_add_epi32(sse_sum_lo, sse_sum_hi);
+
+  // perform the final summation and extract the results
+  const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
+  *((int *)sse) = _mm_cvtsi128_si32(res);
+  *((int *)sum) = _mm_extract_epi32(res, 1);
+}
+
+static INLINE void variance_final_from_16bit_sum_avx2(__m256i vsse,
+                                                      __m256i vsum,
+                                                      unsigned int *const sse,
+                                                      int *const sum) {
+  // extract the low lane and add it to the high lane
+  const __m128i sum_reg_128 = _mm_add_epi16(_mm256_castsi256_si128(vsum),
+                                            _mm256_extractf128_si256(vsum, 1));
+  const __m128i sum_reg_64 =
+      _mm_add_epi16(sum_reg_128, _mm_srli_si128(sum_reg_128, 8));
+  const __m128i sum_int32 = _mm_cvtepi16_epi32(sum_reg_64);
+
+  variance_final_from_32bit_sum_avx2(vsse, sum_int32, sse, sum);
+}
+
+static INLINE __m256i sum_to_32bit_avx2(const __m256i sum) {
+  const __m256i sum_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum));
+  const __m256i sum_hi =
+      _mm256_cvtepi16_epi32(_mm256_extractf128_si256(sum, 1));
+  return _mm256_add_epi32(sum_lo, sum_hi);
+}
+
+static INLINE void variance16_kernel_avx2(
+    const uint8_t *const src, const int src_stride, const uint8_t *const ref,
+    const int ref_stride, __m256i *const sse, __m256i *const sum) {
+  const __m128i s0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride));
+  const __m128i s1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride));
+  const __m128i r0 = _mm_loadu_si128((__m128i const *)(ref + 0 * ref_stride));
+  const __m128i r1 = _mm_loadu_si128((__m128i const *)(ref + 1 * ref_stride));
+  const __m256i s = _mm256_inserti128_si256(_mm256_castsi128_si256(s0), s1, 1);
+  const __m256i r = _mm256_inserti128_si256(_mm256_castsi128_si256(r0), r1, 1);
+  variance_kernel_avx2(s, r, sse, sum);
+}
+
+static INLINE void variance32_kernel_avx2(const uint8_t *const src,
+                                          const uint8_t *const ref,
+                                          __m256i *const sse,
+                                          __m256i *const sum) {
+  const __m256i s = _mm256_loadu_si256((__m256i const *)(src));
+  const __m256i r = _mm256_loadu_si256((__m256i const *)(ref));
+  variance_kernel_avx2(s, r, sse, sum);
+}
+
+static INLINE void variance16_avx2(const uint8_t *src, const int src_stride,
+                                   const uint8_t *ref, const int ref_stride,
+                                   const int h, __m256i *const vsse,
+                                   __m256i *const vsum) {
+  int i;
+  *vsum = _mm256_setzero_si256();
+  *vsse = _mm256_setzero_si256();
+
+  for (i = 0; i < h; i += 2) {
+    variance16_kernel_avx2(src, src_stride, ref, ref_stride, vsse, vsum);
+    src += 2 * src_stride;
+    ref += 2 * ref_stride;
   }
 }
 
-static void get32x16var_avx2(const unsigned char *src_ptr, int source_stride,
-                             const unsigned char *ref_ptr, int recon_stride,
-                             unsigned int *sse, int *sum) {
-  unsigned int i, src_2strides, ref_2strides;
-  const __m256i adj_sub = _mm256_load_si256((__m256i const *)adjacent_sub_avx2);
-  __m256i sum_reg = _mm256_setzero_si256();
-  __m256i sse_reg = _mm256_setzero_si256();
+static INLINE void variance32_avx2(const uint8_t *src, const int src_stride,
+                                   const uint8_t *ref, const int ref_stride,
+                                   const int h, __m256i *const vsse,
+                                   __m256i *const vsum) {
+  int i;
+  *vsum = _mm256_setzero_si256();
+  *vsse = _mm256_setzero_si256();
 
-  // process 64 elements in an iteration
-  src_2strides = source_stride << 1;
-  ref_2strides = recon_stride << 1;
-  for (i = 0; i < 8; i++) {
-    const __m256i src0 = _mm256_loadu_si256((__m256i const *)(src_ptr));
-    const __m256i src1 =
-        _mm256_loadu_si256((__m256i const *)(src_ptr + source_stride));
-    const __m256i ref0 = _mm256_loadu_si256((__m256i const *)(ref_ptr));
-    const __m256i ref1 =
-        _mm256_loadu_si256((__m256i const *)(ref_ptr + recon_stride));
-
-    // unpack into pairs of source and reference values
-    const __m256i src_ref0 = _mm256_unpacklo_epi8(src0, ref0);
-    const __m256i src_ref1 = _mm256_unpackhi_epi8(src0, ref0);
-    const __m256i src_ref2 = _mm256_unpacklo_epi8(src1, ref1);
-    const __m256i src_ref3 = _mm256_unpackhi_epi8(src1, ref1);
-
-    // subtract adjacent elements using src*1 + ref*-1
-    const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
-    const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
-    const __m256i diff2 = _mm256_maddubs_epi16(src_ref2, adj_sub);
-    const __m256i diff3 = _mm256_maddubs_epi16(src_ref3, adj_sub);
-    const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
-    const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);
-    const __m256i madd2 = _mm256_madd_epi16(diff2, diff2);
-    const __m256i madd3 = _mm256_madd_epi16(diff3, diff3);
-
-    // add to the running totals
-    sum_reg = _mm256_add_epi16(sum_reg, _mm256_add_epi16(diff0, diff1));
-    sum_reg = _mm256_add_epi16(sum_reg, _mm256_add_epi16(diff2, diff3));
-    sse_reg = _mm256_add_epi32(sse_reg, _mm256_add_epi32(madd0, madd1));
-    sse_reg = _mm256_add_epi32(sse_reg, _mm256_add_epi32(madd2, madd3));
-
-    src_ptr += src_2strides;
-    ref_ptr += ref_2strides;
+  for (i = 0; i < h; i++) {
+    variance32_kernel_avx2(src, ref, vsse, vsum);
+    src += src_stride;
+    ref += ref_stride;
   }
+}
+
+static INLINE void variance64_avx2(const uint8_t *src, const int src_stride,
+                                   const uint8_t *ref, const int ref_stride,
+                                   const int h, __m256i *const vsse,
+                                   __m256i *const vsum) {
+  int i;
+  *vsum = _mm256_setzero_si256();
 
-  {
-    // extract the low lane and add it to the high lane
-    const __m128i sum_reg_128 = _mm_add_epi16(
-        _mm256_castsi256_si128(sum_reg), _mm256_extractf128_si256(sum_reg, 1));
-    const __m128i sse_reg_128 = _mm_add_epi32(
-        _mm256_castsi256_si128(sse_reg), _mm256_extractf128_si256(sse_reg, 1));
-
-    // sum upper and lower 64 bits together and convert up to 32 bit values
-    const __m128i sum_reg_64 =
-        _mm_add_epi16(sum_reg_128, _mm_srli_si128(sum_reg_128, 8));
-    const __m128i sum_int32 = _mm_cvtepi16_epi32(sum_reg_64);
-
-    // unpack sse and sum registers and add
-    const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, sum_int32);
-    const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, sum_int32);
-    const __m128i sse_sum = _mm_add_epi32(sse_sum_lo, sse_sum_hi);
-
-    // perform the final summation and extract the results
-    const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
-    *((int *)sse) = _mm_cvtsi128_si32(res);
-    *((int *)sum) = _mm_extract_epi32(res, 1);
+  for (i = 0; i < h; i++) {
+    variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
+    variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
+    src += src_stride;
+    ref += ref_stride;
   }
 }
 
+void vpx_get16x16var_avx2(const uint8_t *src, int src_stride,
+                          const uint8_t *ref, int ref_stride, unsigned int *sse,
+                          int *sum) {
+  __m256i vsse, vsum;
+  variance16_avx2(src, src_stride, ref, ref_stride, 16, &vsse, &vsum);
+  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, sum);
+}
+
 #define FILTER_SRC(filter)                               \
   /* filter the source */                                \
   exp_src_lo = _mm256_maddubs_epi16(exp_src_lo, filter); \
   exp_src_hi = _mm256_maddubs_epi16(exp_src_hi, filter); \
@@ -593,50 +603,43 @@ typedef void (*get_var_avx2)(const uint8_t *src, int src_stride,
                              const uint8_t *ref, int ref_stride,
                              unsigned int *sse, int *sum);
 
-static void variance_avx2(const uint8_t *src, int src_stride,
-                          const uint8_t *ref, int ref_stride, int w, int h,
-                          unsigned int *sse, int *sum, get_var_avx2 var_fn,
-                          int block_size) {
-  int i, j;
-
-  *sse = 0;
-  *sum = 0;
-
-  for (i = 0; i < h; i += 16) {
-    for (j = 0; j < w; j += block_size) {
-      unsigned int sse0;
-      int sum0;
-      var_fn(&src[src_stride * i + j], src_stride, &ref[ref_stride * i + j],
-             ref_stride, &sse0, &sum0);
-      *sse += sse0;
-      *sum += sum0;
-    }
-  }
+unsigned int vpx_variance16x8_avx2(const uint8_t *src, int src_stride,
+                                   const uint8_t *ref, int ref_stride,
+                                   unsigned int *sse) {
+  int sum;
+  __m256i vsse, vsum;
+  variance16_avx2(src, src_stride, ref, ref_stride, 8, &vsse, &vsum);
+  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
+  return *sse - (uint32_t)(((int64_t)sum * sum) >> 7);
 }
 
 unsigned int vpx_variance16x16_avx2(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride,
                                     unsigned int *sse) {
   int sum;
-  variance_avx2(src, src_stride, ref, ref_stride, 16, 16, sse, &sum,
-                vpx_get16x16var_avx2, 16);
+  __m256i vsse, vsum;
+  variance16_avx2(src, src_stride, ref, ref_stride, 16, &vsse, &vsum);
+  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
   return *sse - (uint32_t)(((int64_t)sum * sum) >> 8);
 }
 
-unsigned int vpx_mse16x16_avx2(const uint8_t *src, int src_stride,
-                               const uint8_t *ref, int ref_stride,
-                               unsigned int *sse) {
+unsigned int vpx_variance16x32_avx2(const uint8_t *src, int src_stride,
                                    const uint8_t *ref, int ref_stride,
                                    unsigned int *sse) {
   int sum;
-  vpx_get16x16var_avx2(src, src_stride, ref, ref_stride, sse, &sum);
-  return *sse;
+  __m256i vsse, vsum;
+  variance16_avx2(src, src_stride, ref, ref_stride, 32, &vsse, &vsum);
+  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
+  return *sse - (uint32_t)(((int64_t)sum * sum) >> 9);
 }
 
 unsigned int vpx_variance32x16_avx2(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride,
                                     unsigned int *sse) {
   int sum;
-  variance_avx2(src, src_stride, ref, ref_stride, 32, 16, sse, &sum,
-                get32x16var_avx2, 32);
+  __m256i vsse, vsum;
+  variance32_avx2(src, src_stride, ref, ref_stride, 16, &vsse, &vsum);
+  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
   return *sse - (uint32_t)(((int64_t)sum * sum) >> 9);
 }
 
@@ -644,29 +647,87 @@ unsigned int vpx_variance32x32_avx2(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride,
                                     unsigned int *sse) {
   int sum;
-  variance_avx2(src, src_stride, ref, ref_stride, 32, 32, sse, &sum,
-                get32x16var_avx2, 32);
+  __m256i vsse, vsum;
+  __m128i vsum_128;
+  variance32_avx2(src, src_stride, ref, ref_stride, 32, &vsse, &vsum);
+  vsum_128 = _mm_add_epi16(_mm256_castsi256_si128(vsum),
+                           _mm256_extractf128_si256(vsum, 1));
+  vsum_128 = _mm_add_epi32(_mm_cvtepi16_epi32(vsum_128),
+                           _mm_cvtepi16_epi32(_mm_srli_si128(vsum_128, 8)));
+  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
   return *sse - (uint32_t)(((int64_t)sum * sum) >> 10);
 }
 
-unsigned int vpx_variance64x64_avx2(const uint8_t *src, int src_stride,
+unsigned int vpx_variance32x64_avx2(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride,
                                     unsigned int *sse) {
   int sum;
-  variance_avx2(src, src_stride, ref, ref_stride, 64, 64, sse, &sum,
-                get32x16var_avx2, 32);
-  return *sse - (uint32_t)(((int64_t)sum * sum) >> 12);
+  __m256i vsse, vsum;
+  __m128i vsum_128;
+  variance32_avx2(src, src_stride, ref, ref_stride, 64, &vsse, &vsum);
+  vsum = sum_to_32bit_avx2(vsum);
+  vsum_128 = _mm_add_epi32(_mm256_castsi256_si128(vsum),
+                           _mm256_extractf128_si256(vsum, 1));
+  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
+  return *sse - (uint32_t)(((int64_t)sum * sum) >> 11);
 }
 
 unsigned int vpx_variance64x32_avx2(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride,
                                     unsigned int *sse) {
+  __m256i vsse = _mm256_setzero_si256();
+  __m256i vsum = _mm256_setzero_si256();
+  __m128i vsum_128;
   int sum;
-  variance_avx2(src, src_stride, ref, ref_stride, 64, 32, sse, &sum,
-                get32x16var_avx2, 32);
+  variance64_avx2(src, src_stride, ref, ref_stride, 32, &vsse, &vsum);
+  vsum = sum_to_32bit_avx2(vsum);
+  vsum_128 = _mm_add_epi32(_mm256_castsi256_si128(vsum),
+                           _mm256_extractf128_si256(vsum, 1));
+  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
   return *sse - (uint32_t)(((int64_t)sum * sum) >> 11);
 }
 
+unsigned int vpx_variance64x64_avx2(const uint8_t *src, int src_stride,
+                                    const uint8_t *ref, int ref_stride,
+                                    unsigned int *sse) {
+  __m256i vsse = _mm256_setzero_si256();
+  __m256i vsum = _mm256_setzero_si256();
+  __m128i vsum_128;
+  int sum;
+  int i = 0;
+
+  for (i = 0; i < 2; i++) {
+    __m256i vsum16;
+    variance64_avx2(src + 32 * i * src_stride, src_stride,
+                    ref + 32 * i * ref_stride, ref_stride, 32, &vsse, &vsum16);
+    vsum = _mm256_add_epi32(vsum, sum_to_32bit_avx2(vsum16));
+  }
+  vsum_128 = _mm_add_epi32(_mm256_castsi256_si128(vsum),
+                           _mm256_extractf128_si256(vsum, 1));
+  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
+  return *sse - (unsigned int)(((int64_t)sum * sum) >> 12);
+}
+
+unsigned int vpx_mse16x8_avx2(const uint8_t *src, int src_stride,
+                              const uint8_t *ref, int ref_stride,
+                              unsigned int *sse) {
+  int sum;
+  __m256i vsse, vsum;
+  variance16_avx2(src, src_stride, ref, ref_stride, 8, &vsse, &vsum);
+  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
+  return *sse;
+}
+
+unsigned int vpx_mse16x16_avx2(const uint8_t *src, int src_stride,
+                               const uint8_t *ref, int ref_stride,
+                               unsigned int *sse) {
+  int sum;
+  __m256i vsse, vsum;
+  variance16_avx2(src, src_stride, ref, ref_stride, 16, &vsse, &vsum);
+  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
+  return *sse;
+}
+
 unsigned int vpx_sub_pixel_variance64x64_avx2(const uint8_t *src,
                                               int src_stride, int x_offset,
                                               int y_offset, const uint8_t *dst,
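For cross-checking the new AVX2 paths, a plain scalar reference with the same contract as the `vpx_variance*` prototypes above can be useful. The sketch below is not part of this patch and the `variance_ref_c` name is ours; it simply mirrors the convention used throughout variance_avx2.c (return the variance, write the sum of squared differences through `*sse`):

```c
#include <stdint.h>

/* Hypothetical scalar reference, not part of libvpx: computes the SSE and
 * the sum of differences over a w x h block, then derives the variance. */
static unsigned int variance_ref_c(const uint8_t *src, int src_stride,
                                   const uint8_t *ref, int ref_stride,
                                   int w, int h, unsigned int *sse) {
  int64_t sum = 0;
  uint64_t sse64 = 0;
  int i, j;

  for (i = 0; i < h; ++i) {
    for (j = 0; j < w; ++j) {
      const int diff = (int)src[j] - (int)ref[j];
      sum += diff;
      sse64 += (uint64_t)(diff * diff);
    }
    src += src_stride;
    ref += ref_stride;
  }
  *sse = (unsigned int)sse64;
  return (unsigned int)(sse64 - (uint64_t)((sum * sum) / (w * h)));
}
```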