From 7bf5e3cc7545fe261b79c18d2e09ff7df1ff117d Mon Sep 17 00:00:00 2001
From: nihuini
Date: Wed, 5 Feb 2025 17:19:43 +0800
Subject: [PATCH] sse2 requantize pack8

---
 src/layer/x86/requantize_x86.cpp | 75 ++++++++++++++++++++------------
 1 file changed, 48 insertions(+), 27 deletions(-)

diff --git a/src/layer/x86/requantize_x86.cpp b/src/layer/x86/requantize_x86.cpp
index 996681e5e42..6b64f86967d 100644
--- a/src/layer/x86/requantize_x86.cpp
+++ b/src/layer/x86/requantize_x86.cpp
@@ -44,16 +44,17 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
     float scale_in = scale_in_data[0];
 #if __SSE2__
-    __m128 _scale_in = _mm_set1_ps(scale_in);
+    __m128 _scale_in0 = _mm_set1_ps(scale_in);
 #if __AVX__
     __m256 _scale_in_avx = _mm256_set1_ps(scale_in);
 #if __AVX512F__
     __m512 _scale_in_avx512 = _mm512_set1_ps(scale_in);
 #endif // __AVX512F__
+#else // __AVX__
+    __m128 _scale_in1 = _scale_in0;
 #endif // __AVX__
 
     if (scale_in_data_size > 1)
     {
-#if __AVX__
 #if __AVX512F__
         if (elempack == 16)
         {
@@ -62,20 +63,26 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
 #endif // __AVX512F__
         if (elempack == 8)
         {
+#if __AVX__
             _scale_in_avx = _mm256_loadu_ps((const float*)scale_in_data);
 #if __AVX512F__
             _scale_in_avx512 = combine8x2_ps(_scale_in_avx, _scale_in_avx);
 #endif // __AVX512F__
-        }
+#else // __AVX__
+            _scale_in0 = _mm_loadu_ps((const float*)scale_in_data);
+            _scale_in1 = _mm_loadu_ps((const float*)scale_in_data + 4);
 #endif // __AVX__
+        }
         if (elempack == 4)
         {
-            _scale_in = _mm_loadu_ps((const float*)scale_in_data);
+            _scale_in0 = _mm_loadu_ps((const float*)scale_in_data);
 #if __AVX__
-            _scale_in_avx = combine4x2_ps(_scale_in, _scale_in);
+            _scale_in_avx = combine4x2_ps(_scale_in0, _scale_in0);
 #if __AVX512F__
             _scale_in_avx512 = combine8x2_ps(_scale_in_avx, _scale_in_avx);
 #endif // __AVX512F__
+#else // __AVX__
+            _scale_in1 = _scale_in0;
 #endif // __AVX__
         }
     }
@@ -83,16 +90,17 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
     float scale_out = scale_out_data[0];
 #if __SSE2__
-    __m128 _scale_out = _mm_set1_ps(scale_out);
+    __m128 _scale_out0 = _mm_set1_ps(scale_out);
 #if __AVX__
     __m256 _scale_out_avx = _mm256_set1_ps(scale_out);
 #if __AVX512F__
     __m512 _scale_out_avx512 = _mm512_set1_ps(scale_out);
 #endif // __AVX512F__
+#else // __AVX__
+    __m128 _scale_out1 = _scale_out0;
 #endif // __AVX__
 
     if (scale_out_data_size > 1)
     {
-#if __AVX__
 #if __AVX512F__
         if (elempack == 16)
         {
@@ -101,20 +109,26 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
 #endif // __AVX512F__
         if (elempack == 8)
         {
+#if __AVX__
             _scale_out_avx = _mm256_loadu_ps((const float*)scale_out_data);
 #if __AVX512F__
             _scale_out_avx512 = combine8x2_ps(_scale_out_avx, _scale_out_avx);
 #endif // __AVX512F__
-        }
+#else // __AVX__
+            _scale_out0 = _mm_loadu_ps((const float*)scale_out_data);
+            _scale_out1 = _mm_loadu_ps((const float*)scale_out_data + 4);
 #endif // __AVX__
+        }
         if (elempack == 4)
        {
-            _scale_out = _mm_loadu_ps((const float*)scale_out_data);
+            _scale_out0 = _mm_loadu_ps((const float*)scale_out_data);
 #if __AVX__
-            _scale_out_avx = combine4x2_ps(_scale_out, _scale_out);
+            _scale_out_avx = combine4x2_ps(_scale_out0, _scale_out0);
 #if __AVX512F__
             _scale_out_avx512 = combine8x2_ps(_scale_out_avx, _scale_out_avx);
 #endif // __AVX512F__
+#else // __AVX__
+            _scale_out1 = _scale_out0;
 #endif // __AVX__
         }
     }
@@ -159,12 +173,12 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
 #else // __AVX__
             __m128 _v0 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr));
             __m128 _v1 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)(intptr + 4)));
-            _v0 = _mm_mul_ps(_v0, _scale_in);
-            _v1 = _mm_mul_ps(_v1, _scale_in);
+            _v0 = _mm_mul_ps(_v0, _scale_in0);
+            _v1 = _mm_mul_ps(_v1, _scale_in1);
             _v0 = activation_sse(_v0, activation_type, activation_params);
             _v1 = activation_sse(_v1, activation_type, activation_params);
-            _v0 = _mm_mul_ps(_v0, _scale_out);
-            _v1 = _mm_mul_ps(_v1, _scale_out);
+            _v0 = _mm_mul_ps(_v0, _scale_out0);
+            _v1 = _mm_mul_ps(_v1, _scale_out1);
             *(int64_t*)ptr = float2int8_sse(_v0, _v1);
 #endif // __AVX__
             intptr += 8;
@@ -173,9 +187,9 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
         for (; i + 3 < size; i += 4)
         {
             __m128 _v = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr));
-            _v = _mm_mul_ps(_v, _scale_in);
+            _v = _mm_mul_ps(_v, _scale_in0);
             _v = activation_sse(_v, activation_type, activation_params);
-            _v = _mm_mul_ps(_v, _scale_out);
+            _v = _mm_mul_ps(_v, _scale_out0);
             int32_t v = float2int8_sse(_v);
             ptr[0] = (v >> 0) & 0xff;
             ptr[1] = (v >> 8) & 0xff;
@@ -198,16 +212,17 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
     {
         float bias = bias_data[0];
 #if __SSE2__
-        __m128 _bias = _mm_set1_ps(bias);
+        __m128 _bias0 = _mm_set1_ps(bias);
 #if __AVX__
         __m256 _bias_avx = _mm256_set1_ps(bias);
 #if __AVX512F__
         __m512 _bias_avx512 = _mm512_set1_ps(bias);
 #endif // __AVX512F__
+#else // __AVX__
+        __m128 _bias1 = _bias0;
 #endif // __AVX__
         if (bias_data_size > 1)
         {
-#if __AVX__
 #if __AVX512F__
             if (elempack == 16)
             {
@@ -216,20 +231,26 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
 #endif // __AVX512F__
             if (elempack == 8)
             {
+#if __AVX__
                 _bias_avx = _mm256_loadu_ps((const float*)bias_data);
 #if __AVX512F__
                 _bias_avx512 = combine8x2_ps(_bias_avx, _bias_avx);
 #endif // __AVX512F__
-            }
+#else // __AVX__
+                _bias0 = _mm_loadu_ps((const float*)bias_data);
+                _bias1 = _mm_loadu_ps((const float*)bias_data + 4);
 #endif // __AVX__
+            }
             if (elempack == 4)
             {
-                _bias = _mm_loadu_ps((const float*)bias_data);
+                _bias0 = _mm_loadu_ps((const float*)bias_data);
 #if __AVX__
-                _bias_avx = combine4x2_ps(_bias, _bias);
+                _bias_avx = combine4x2_ps(_bias0, _bias0);
 #if __AVX512F__
                 _bias_avx512 = combine8x2_ps(_bias_avx, _bias_avx);
 #endif // __AVX512F__
+#else // __AVX__
+                _bias1 = _bias0;
 #endif // __AVX__
             }
         }
@@ -272,12 +293,12 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
 #else // __AVX__
             __m128 _v0 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr));
             __m128 _v1 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)(intptr + 4)));
-            _v0 = _mm_comp_fmadd_ps(_v0, _scale_in, _bias);
-            _v1 = _mm_comp_fmadd_ps(_v1, _scale_in, _bias);
+            _v0 = _mm_comp_fmadd_ps(_v0, _scale_in0, _bias0);
+            _v1 = _mm_comp_fmadd_ps(_v1, _scale_in1, _bias1);
             _v0 = activation_sse(_v0, activation_type, activation_params);
             _v1 = activation_sse(_v1, activation_type, activation_params);
-            _v0 = _mm_mul_ps(_v0, _scale_out);
-            _v1 = _mm_mul_ps(_v1, _scale_out);
+            _v0 = _mm_mul_ps(_v0, _scale_out0);
+            _v1 = _mm_mul_ps(_v1, _scale_out1);
             *(int64_t*)ptr = float2int8_sse(_v0, _v1);
 #endif // __AVX__
             intptr += 8;
@@ -286,9 +307,9 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
         for (; i + 3 < size; i += 4)
         {
             __m128 _v = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr));
-            _v = _mm_comp_fmadd_ps(_v, _scale_in, _bias);
+            _v = _mm_comp_fmadd_ps(_v, _scale_in0, _bias0);
             _v = activation_sse(_v, activation_type, activation_params);
-            _v = _mm_mul_ps(_v, _scale_out);
+            _v = _mm_mul_ps(_v, _scale_out0);
             int32_t v = float2int8_sse(_v);
             ptr[0] = (v >> 0) & 0xff;
             ptr[1] = (v >> 8) & 0xff;