Skip to content

Commit

Permalink
sse2 requantize pack8
Browse files Browse the repository at this point in the history
  • Loading branch information
nihui committed Feb 5, 2025
1 parent 20a9cdb commit 7bf5e3c
Showing 1 changed file with 48 additions and 27 deletions.
75 changes: 48 additions & 27 deletions src/layer/x86/requantize_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,16 +44,17 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_

float scale_in = scale_in_data[0];
#if __SSE2__
__m128 _scale_in = _mm_set1_ps(scale_in);
__m128 _scale_in0 = _mm_set1_ps(scale_in);
#if __AVX__
__m256 _scale_in_avx = _mm256_set1_ps(scale_in);
#if __AVX512F__
__m512 _scale_in_avx512 = _mm512_set1_ps(scale_in);
#endif // __AVX512F__
#else // __AVX__
__m128 _scale_in1 = _scale_in0;
#endif // __AVX__
if (scale_in_data_size > 1)
{
#if __AVX__
#if __AVX512F__
if (elempack == 16)
{
Expand All @@ -62,37 +63,44 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
#endif // __AVX512F__
if (elempack == 8)
{
#if __AVX__
_scale_in_avx = _mm256_loadu_ps((const float*)scale_in_data);
#if __AVX512F__
_scale_in_avx512 = combine8x2_ps(_scale_in_avx, _scale_in_avx);
#endif // __AVX512F__
}
#else // __AVX__
_scale_in0 = _mm_loadu_ps((const float*)scale_in_data);
_scale_in1 = _mm_loadu_ps((const float*)scale_in_data + 4);
#endif // __AVX__
}
if (elempack == 4)
{
_scale_in = _mm_loadu_ps((const float*)scale_in_data);
_scale_in0 = _mm_loadu_ps((const float*)scale_in_data);
#if __AVX__
_scale_in_avx = combine4x2_ps(_scale_in, _scale_in);
_scale_in_avx = combine4x2_ps(_scale_in0, _scale_in0);
#if __AVX512F__
_scale_in_avx512 = combine8x2_ps(_scale_in_avx, _scale_in_avx);
#endif // __AVX512F__
#else // __AVX__
_scale_in1 = _scale_in0;
#endif // __AVX__
}
}
#endif // __SSE2__

float scale_out = scale_out_data[0];
#if __SSE2__
__m128 _scale_out = _mm_set1_ps(scale_out);
__m128 _scale_out0 = _mm_set1_ps(scale_out);
#if __AVX__
__m256 _scale_out_avx = _mm256_set1_ps(scale_out);
#if __AVX512F__
__m512 _scale_out_avx512 = _mm512_set1_ps(scale_out);
#endif // __AVX512F__
#else // __AVX__
__m128 _scale_out1 = _scale_out0;
#endif // __AVX__
if (scale_out_data_size > 1)
{
#if __AVX__
#if __AVX512F__
if (elempack == 16)
{
Expand All @@ -101,20 +109,26 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
#endif // __AVX512F__
if (elempack == 8)
{
#if __AVX__
_scale_out_avx = _mm256_loadu_ps((const float*)scale_out_data);
#if __AVX512F__
_scale_out_avx512 = combine8x2_ps(_scale_out_avx, _scale_out_avx);
#endif // __AVX512F__
}
#else // __AVX__
_scale_out0 = _mm_loadu_ps((const float*)scale_out_data);
_scale_out1 = _mm_loadu_ps((const float*)scale_out_data + 4);
#endif // __AVX__
}
if (elempack == 4)
{
_scale_out = _mm_loadu_ps((const float*)scale_out_data);
_scale_out0 = _mm_loadu_ps((const float*)scale_out_data);
#if __AVX__
_scale_out_avx = combine4x2_ps(_scale_out, _scale_out);
_scale_out_avx = combine4x2_ps(_scale_out0, _scale_out0);
#if __AVX512F__
_scale_out_avx512 = combine8x2_ps(_scale_out_avx, _scale_out_avx);
#endif // __AVX512F__
#else // __AVX__
_scale_out1 = _scale_out0;
#endif // __AVX__
}
}
Expand Down Expand Up @@ -159,12 +173,12 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
#else // __AVX__
__m128 _v0 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr));
__m128 _v1 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)(intptr + 4)));
_v0 = _mm_mul_ps(_v0, _scale_in);
_v1 = _mm_mul_ps(_v1, _scale_in);
_v0 = _mm_mul_ps(_v0, _scale_in0);
_v1 = _mm_mul_ps(_v1, _scale_in1);
_v0 = activation_sse(_v0, activation_type, activation_params);
_v1 = activation_sse(_v1, activation_type, activation_params);
_v0 = _mm_mul_ps(_v0, _scale_out);
_v1 = _mm_mul_ps(_v1, _scale_out);
_v0 = _mm_mul_ps(_v0, _scale_out0);
_v1 = _mm_mul_ps(_v1, _scale_out1);
*(int64_t*)ptr = float2int8_sse(_v0, _v1);
#endif // __AVX__
intptr += 8;
Expand All @@ -173,9 +187,9 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
for (; i + 3 < size; i += 4)
{
__m128 _v = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr));
_v = _mm_mul_ps(_v, _scale_in);
_v = _mm_mul_ps(_v, _scale_in0);
_v = activation_sse(_v, activation_type, activation_params);
_v = _mm_mul_ps(_v, _scale_out);
_v = _mm_mul_ps(_v, _scale_out0);
int32_t v = float2int8_sse(_v);
ptr[0] = (v >> 0) & 0xff;
ptr[1] = (v >> 8) & 0xff;
Expand All @@ -198,16 +212,17 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
{
float bias = bias_data[0];
#if __SSE2__
__m128 _bias = _mm_set1_ps(bias);
__m128 _bias0 = _mm_set1_ps(bias);
#if __AVX__
__m256 _bias_avx = _mm256_set1_ps(bias);
#if __AVX512F__
__m512 _bias_avx512 = _mm512_set1_ps(bias);
#endif // __AVX512F__
#else // __AVX__
__m128 _bias1 = _bias0;
#endif // __AVX__
if (bias_data_size > 1)
{
#if __AVX__
#if __AVX512F__
if (elempack == 16)
{
Expand All @@ -216,20 +231,26 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
#endif // __AVX512F__
if (elempack == 8)
{
#if __AVX__
_bias_avx = _mm256_loadu_ps((const float*)bias_data);
#if __AVX512F__
_bias_avx512 = combine8x2_ps(_bias_avx, _bias_avx);
#endif // __AVX512F__
}
#else // __AVX__
_bias0 = _mm_loadu_ps((const float*)bias_data);
_bias1 = _mm_loadu_ps((const float*)bias_data + 4);
#endif // __AVX__
}
if (elempack == 4)
{
_bias = _mm_loadu_ps((const float*)bias_data);
_bias0 = _mm_loadu_ps((const float*)bias_data);
#if __AVX__
_bias_avx = combine4x2_ps(_bias, _bias);
_bias_avx = combine4x2_ps(_bias0, _bias0);
#if __AVX512F__
_bias_avx512 = combine8x2_ps(_bias_avx, _bias_avx);
#endif // __AVX512F__
#else // __AVX__
_bias1 = _bias0;
#endif // __AVX__
}
}
Expand Down Expand Up @@ -272,12 +293,12 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
#else // __AVX__
__m128 _v0 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr));
__m128 _v1 = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)(intptr + 4)));
_v0 = _mm_comp_fmadd_ps(_v0, _scale_in, _bias);
_v1 = _mm_comp_fmadd_ps(_v1, _scale_in, _bias);
_v0 = _mm_comp_fmadd_ps(_v0, _scale_in0, _bias0);
_v1 = _mm_comp_fmadd_ps(_v1, _scale_in1, _bias1);
_v0 = activation_sse(_v0, activation_type, activation_params);
_v1 = activation_sse(_v1, activation_type, activation_params);
_v0 = _mm_mul_ps(_v0, _scale_out);
_v1 = _mm_mul_ps(_v1, _scale_out);
_v0 = _mm_mul_ps(_v0, _scale_out0);
_v1 = _mm_mul_ps(_v1, _scale_out1);
*(int64_t*)ptr = float2int8_sse(_v0, _v1);
#endif // __AVX__
intptr += 8;
Expand All @@ -286,9 +307,9 @@ static void requantize(const int* intptr, signed char* ptr, const Mat& scale_in_
for (; i + 3 < size; i += 4)
{
__m128 _v = _mm_cvtepi32_ps(_mm_loadu_si128((const __m128i*)intptr));
_v = _mm_comp_fmadd_ps(_v, _scale_in, _bias);
_v = _mm_comp_fmadd_ps(_v, _scale_in0, _bias0);
_v = activation_sse(_v, activation_type, activation_params);
_v = _mm_mul_ps(_v, _scale_out);
_v = _mm_mul_ps(_v, _scale_out0);
int32_t v = float2int8_sse(_v);
ptr[0] = (v >> 0) & 0xff;
ptr[1] = (v >> 8) & 0xff;
Expand Down

0 comments on commit 7bf5e3c

Please sign in to comment.