// See the License for the specific language governing permissions and
// limitations under the License.

+ #include <immintrin.h>
module;

- #include <cmath>
#include "../header.h"
+ #include <cmath>
+ #include <cstdint>

import stl;

@@ -91,9 +93,7 @@ export float F32CosAVX512(const float *pv1, const float *pv2, size_t dim) {
    return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
}

- export float F32CosAVX512Residual(const float *pv1, const float *pv2, size_t dim) {
-     return F32CosAVX512(pv1, pv2, dim);
- }
+ export float F32CosAVX512Residual(const float *pv1, const float *pv2, size_t dim) { return F32CosAVX512(pv1, pv2, dim); }

#endif

@@ -151,9 +151,7 @@ export float F32CosAVX(const float *pv1, const float *pv2, size_t dim) {
    return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
}

- export float F32CosAVXResidual(const float *pv1, const float *pv2, size_t dim) {
-     return F32CosAVX(pv1, pv2, dim);
- }
+ export float F32CosAVXResidual(const float *pv1, const float *pv2, size_t dim) { return F32CosAVX(pv1, pv2, dim); }

#endif

@@ -211,7 +209,7 @@ export float F32CosSSE(const float *pv1, const float *pv2, size_t dim) {
    _mm_store_ps(V1TmpRes, norm_v1);
    _mm_store_ps(V2TmpRes, norm_v2);

-     float mul_res = MulTmpRes[0] + MulTmpRes[1] + MulTmpRes[2] + MulTmpRes[3];
+     float mul_res = MulTmpRes[0] + MulTmpRes[1] + MulTmpRes[2] + MulTmpRes[3];
    float v1_res = V1TmpRes[0] + V1TmpRes[1] + V1TmpRes[2] + V1TmpRes[3];
    float v2_res = V2TmpRes[0] + V2TmpRes[1] + V2TmpRes[2] + V2TmpRes[3];

@@ -227,9 +225,7 @@ export float F32CosSSE(const float *pv1, const float *pv2, size_t dim) {
    return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
}

- export float F32CosSSEResidual(const float *pv1, const float *pv2, size_t dim) {
-     return F32CosSSE(pv1, pv2, dim);
- }
+ export float F32CosSSEResidual(const float *pv1, const float *pv2, size_t dim) { return F32CosSSE(pv1, pv2, dim); }

#endif

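All three SIMD cosine kernels above compute the same quantity as a plain scalar loop; the sketch below is illustrative only (the name F32CosBF is assumed here and is not part of this diff), and is useful as the baseline the vectorized paths should match.

float F32CosBF(const float *pv1, const float *pv2, size_t dim) {
    float mul_res = 0, v1_res = 0, v2_res = 0;
    for (size_t i = 0; i < dim; i++) {
        mul_res += pv1[i] * pv2[i];   // dot product
        v1_res += pv1[i] * pv1[i];    // squared norm of pv1
        v2_res += pv2[i] * pv2[i];    // squared norm of pv2
    }
    return mul_res != 0 ? mul_res / sqrt(v1_res * v2_res) : 0;
}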
@@ -354,6 +350,148 @@ export int32_t I8IPSSEResidual(const int8_t *pv1, const int8_t *pv2, size_t dim)

// ------------------------------//------------------------------//------------------------------

+ export signed char I8L2BF(const int8_t *pv1, const int8_t *pv2, size_t dim) {
+     int32_t res = 0;
+     for (size_t i = 0; i < dim; i++) {
+         int32_t t = pv1[i] - pv2[i];
+         res += t * t;
+     }
+     return (signed char)res;
+ }
+ 
+ #if defined(USE_AVX512)
+ 
+ export signed char I8L2AVX512(const int8_t *pv1, const int8_t *pv2, size_t dim) {
+     int32_t PORTABLE_ALIGN64 TmpRes[16];
+     size_t dim64 = dim >> 6;
+ 
+     const int8_t *pEnd1 = pv1 + (dim64 << 6);
+ 
+     __m512i v1, v2, diff_lo, diff_hi;
+     __m512i sum = _mm512_setzero_si512();
+ 
+     while (pv1 < pEnd1) {
+         v1 = _mm512_loadu_si512(pv1);
+         pv1 += 64;
+         v2 = _mm512_loadu_si512(pv2);
+         pv2 += 64;
+         // There is no 8-bit multiply: sign-extend each 32-byte half to 16-bit lanes,
+         // subtract, then square and pair-sum into 32-bit lanes with _mm512_madd_epi16.
+         diff_lo = _mm512_sub_epi16(_mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(v1, 0)),
+                                    _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(v2, 0)));
+         diff_hi = _mm512_sub_epi16(_mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(v1, 1)),
+                                    _mm512_cvtepi8_epi16(_mm512_extracti64x4_epi64(v2, 1)));
+         sum = _mm512_add_epi32(sum, _mm512_madd_epi16(diff_lo, diff_lo));
+         sum = _mm512_add_epi32(sum, _mm512_madd_epi16(diff_hi, diff_hi));
+     }
+ 
+     _mm512_store_si512(TmpRes, sum);
+     int32_t res = 0;
+     for (size_t i = 0; i < 16; i++) {
+         res += TmpRes[i];
+     }
+ 
+     return (signed char)res;
+ }
+ 
+ export signed char I8L2AVX512Residual(const int8_t *pv1, const int8_t *pv2, size_t dim) {
+     return I8L2AVX512(pv1, pv2, dim) + I8L2BF(pv1 + (dim & ~63), pv2 + (dim & ~63), dim & 63);
+ }
+ #endif
+ 
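The Residual wrapper splits the vector at the largest multiple of the 64-element AVX512 block: the prefix goes through the vectorized kernel, the short tail through the scalar I8L2BF; the AVX and SSE variants below do the same with 32- and 16-element blocks. A minimal worked example of the index arithmetic, with dim = 100 assumed purely for illustration:

// dim = 100:
//   dim & ~63 = 64  -> first 64 elements: I8L2AVX512(pv1, pv2, 100) only processes this prefix
//   dim & 63  = 36  -> last 36 elements:  I8L2BF(pv1 + 64, pv2 + 64, 36)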
+ #if defined(USE_AVX)
+ 
+ export signed char I8L2AVX(const int8_t *pv1, const int8_t *pv2, size_t dim) {
+     int32_t PORTABLE_ALIGN32 TmpRes[8];
+     size_t dim32 = dim >> 5;
+ 
+     const int8_t *pEnd1 = pv1 + (dim32 << 5);
+ 
+     __m256i v1, v2, diff_lo, diff_hi;
+     __m256i sum = _mm256_setzero_si256();
+ 
+     while (pv1 < pEnd1) {
+         v1 = _mm256_loadu_si256((const __m256i *)pv1);
+         pv1 += 32;
+         v2 = _mm256_loadu_si256((const __m256i *)pv2);
+         pv2 += 32;
+         // Sign-extend each 16-byte half to 16-bit lanes, subtract, then square and
+         // pair-sum into 32-bit lanes with _mm256_madd_epi16.
+         diff_lo = _mm256_sub_epi16(_mm256_cvtepi8_epi16(_mm256_castsi256_si128(v1)),
+                                    _mm256_cvtepi8_epi16(_mm256_castsi256_si128(v2)));
+         diff_hi = _mm256_sub_epi16(_mm256_cvtepi8_epi16(_mm256_extracti128_si256(v1, 1)),
+                                    _mm256_cvtepi8_epi16(_mm256_extracti128_si256(v2, 1)));
+         sum = _mm256_add_epi32(sum, _mm256_madd_epi16(diff_lo, diff_lo));
+         sum = _mm256_add_epi32(sum, _mm256_madd_epi16(diff_hi, diff_hi));
+     }
+ 
+     _mm256_store_si256((__m256i *)TmpRes, sum);
+     int32_t res = 0;
+     for (size_t i = 0; i < 8; i++) {
+         res += TmpRes[i];
+     }
+     return (signed char)res;
+ }
+ 
+ export signed char I8L2AVXResidual(const int8_t *pv1, const int8_t *pv2, size_t dim) {
+     return I8L2AVX(pv1, pv2, dim) + I8L2BF(pv1 + (dim & ~31), pv2 + (dim & ~31), dim & 31);
+ }
+ 
+ #endif
+ 
+ #if defined(USE_SSE)
+ 
+ export signed char I8L2SSE(const int8_t *pv1, const int8_t *pv2, size_t dim) {
+     alignas(16) int32_t TmpRes[4];
+     size_t dim16 = dim >> 4;
+ 
+     const int8_t *pEnd1 = pv1 + (dim16 << 4);
+ 
+     __m128i v1, v2, diff_lo, diff_hi;
+     __m128i sum = _mm_setzero_si128();
+ 
+     while (pv1 < pEnd1) {
+         v1 = _mm_loadu_si128((const __m128i *)pv1);
+         pv1 += 16;
+         v2 = _mm_loadu_si128((const __m128i *)pv2);
+         pv2 += 16;
+         // Sign-extend each 8-byte half to 16-bit lanes, subtract, then square and
+         // pair-sum into 32-bit lanes with _mm_madd_epi16.
+         diff_lo = _mm_sub_epi16(_mm_cvtepi8_epi16(v1), _mm_cvtepi8_epi16(v2));
+         diff_hi = _mm_sub_epi16(_mm_cvtepi8_epi16(_mm_srli_si128(v1, 8)),
+                                 _mm_cvtepi8_epi16(_mm_srli_si128(v2, 8)));
+         sum = _mm_add_epi32(sum, _mm_madd_epi16(diff_lo, diff_lo));
+         sum = _mm_add_epi32(sum, _mm_madd_epi16(diff_hi, diff_hi));
+     }
+ 
+     _mm_store_si128((__m128i *)TmpRes, sum);
+     int32_t res = TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3];
+     return (signed char)res;
+ }
+ 
+ export signed char I8L2SSEResidual(const int8_t *pv1, const int8_t *pv2, size_t dim) {
+     return I8L2SSE(pv1, pv2, dim) + I8L2BF(pv1 + (dim & ~15), pv2 + (dim & ~15), dim & 15);
+ }
+ 
+ #endif
+ 
+ // ------------------------------//------------------------------//------------------------------
+ 
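Not part of this commit, but a minimal cross-check sketch (the function name, test values, and scaffolding are assumptions): for a dimension that is a multiple of every block width, each SIMD kernel covers the whole vector and should agree exactly with the scalar I8L2BF baseline.

#include <cassert>

void SmokeTestI8L2() {
    constexpr size_t dim = 128;  // multiple of 16, 32, and 64
    int8_t a[dim], b[dim];
    for (size_t i = 0; i < dim; i++) {
        a[i] = (int8_t)((int)i * 7 - 64);
        b[i] = (int8_t)((int)i * 3 - 90);
    }
#if defined(USE_SSE)
    assert(I8L2SSE(a, b, dim) == I8L2BF(a, b, dim));
#endif
#if defined(USE_AVX)
    assert(I8L2AVX(a, b, dim) == I8L2BF(a, b, dim));
#endif
#if defined(USE_AVX512)
    assert(I8L2AVX512(a, b, dim) == I8L2BF(a, b, dim));
#endif
}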
export float F32L2BF(const float *pv1, const float *pv2, size_t dim) {
    float res = 0;
    for (size_t i = 0; i < dim; i++) {