@@ -14,6 +14,8 @@

module;

+#include <cmath>
+
#if defined(__GNUC__) && (defined(__x86_64__) || defined(__i386__))
#include <immintrin.h>
#elif defined(__GNUC__) && defined(__aarch64__)
@@ -32,105 +34,158 @@ namespace infinity {

// x = ( x7, x6, x5, x4, x3, x2, x1, x0 )
float calc_256_sum_8(__m256 x) {
-    // high_quad = ( x7, x6, x5, x4 )
-    const __m128 high_quad = _mm256_extractf128_ps(x, 1);
-    // low_quad = ( x3, x2, x1, x0 )
-    const __m128 low_quad = _mm256_castps256_ps128(x);
-    // sum_quad = ( x3 + x7, x2 + x6, x1 + x5, x0 + x4 )
-    const __m128 sum_quad = _mm_add_ps(low_quad, high_quad);
-    // low_dual = ( -, -, x1 + x5, x0 + x4 )
-    const __m128 low_dual = sum_quad;
-    // high_dual = ( -, -, x3 + x7, x2 + x6 )
-    const __m128 high_dual = _mm_movehl_ps(sum_quad, sum_quad);
-    // sum_dual = ( -, -, x1 + x3 + x5 + x7, x0 + x2 + x4 + x6 )
-    const __m128 sum_dual = _mm_add_ps(low_dual, high_dual);
-    // low = ( -, -, -, x0 + x2 + x4 + x6 )
-    const __m128 low = sum_dual;
-    // high = ( -, -, -, x1 + x3 + x5 + x7 )
-    const __m128 high = _mm_shuffle_ps(sum_dual, sum_dual, 0x1);
-    // sum = ( -, -, -, x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7 )
-    const __m128 sum = _mm_add_ss(low, high);
-    return _mm_cvtss_f32(sum);
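+    // horizontal sum: fold the 256-bit register into 128 bits, then 4 lanes into 2, then 2 into 1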
+    // high_quad = ( x7, x6, x5, x4 )
+    const __m128 high_quad = _mm256_extractf128_ps(x, 1);
+    // low_quad = ( x3, x2, x1, x0 )
+    const __m128 low_quad = _mm256_castps256_ps128(x);
+    // sum_quad = ( x3 + x7, x2 + x6, x1 + x5, x0 + x4 )
+    const __m128 sum_quad = _mm_add_ps(low_quad, high_quad);
+    // low_dual = ( -, -, x1 + x5, x0 + x4 )
+    const __m128 low_dual = sum_quad;
+    // high_dual = ( -, -, x3 + x7, x2 + x6 )
+    const __m128 high_dual = _mm_movehl_ps(sum_quad, sum_quad);
+    // sum_dual = ( -, -, x1 + x3 + x5 + x7, x0 + x2 + x4 + x6 )
+    const __m128 sum_dual = _mm_add_ps(low_dual, high_dual);
+    // low = ( -, -, -, x0 + x2 + x4 + x6 )
+    const __m128 low = sum_dual;
+    // high = ( -, -, -, x1 + x3 + x5 + x7 )
+    const __m128 high = _mm_shuffle_ps(sum_dual, sum_dual, 0x1);
+    // sum = ( -, -, -, x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7 )
+    const __m128 sum = _mm_add_ss(low, high);
+    return _mm_cvtss_f32(sum);
}

#endif

#if defined(__AVX2__)

export f32 L2Distance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
-    u32 i = 0;
-    __m256 sum_1 = _mm256_setzero_ps();
-    __m256 sum_2 = _mm256_setzero_ps();
-    _mm_prefetch(vector1, _MM_HINT_NTA);
-    _mm_prefetch(vector2, _MM_HINT_NTA);
-    for (; i + 16 <= dimension; i += 16) {
-        _mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
-        _mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
-        auto diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
-        auto diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
-        auto mul_1 = _mm256_mul_ps(diff_1, diff_1);
-        auto mul_2 = _mm256_mul_ps(diff_2, diff_2);
-        // add mul to sum
-        sum_1 = _mm256_add_ps(sum_1, mul_1);
-        sum_2 = _mm256_add_ps(sum_2, mul_2);
-    }
-    if (i + 8 <= dimension) {
-        auto diff = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
-        auto mul = _mm256_mul_ps(diff, diff);
-        sum_1 = _mm256_add_ps(sum_1, mul);
-        i += 8;
-    }
-    f32 distance = calc_256_sum_8(sum_1) + calc_256_sum_8(sum_2);
-    for (; i < dimension; ++i) {
-        auto diff = vector1[i] - vector2[i];
-        distance += diff * diff;
-    }
-    return distance;
+    u32 i = 0;
+    __m256 sum_1 = _mm256_setzero_ps();
+    __m256 sum_2 = _mm256_setzero_ps();
+    _mm_prefetch(vector1, _MM_HINT_NTA);
+    _mm_prefetch(vector2, _MM_HINT_NTA);
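+    // main loop: 16 floats per iteration, squared differences kept in two independent accumulators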
+    for (; i + 16 <= dimension; i += 16) {
+        _mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
+        _mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
+        auto diff_1 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
+        auto diff_2 = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
+        auto mul_1 = _mm256_mul_ps(diff_1, diff_1);
+        auto mul_2 = _mm256_mul_ps(diff_2, diff_2);
+        // add mul to sum
+        sum_1 = _mm256_add_ps(sum_1, mul_1);
+        sum_2 = _mm256_add_ps(sum_2, mul_2);
+    }
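+    // tail: one 8-float block, then the horizontal sums, then a scalar remainder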
+    if (i + 8 <= dimension) {
+        auto diff = _mm256_sub_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
+        auto mul = _mm256_mul_ps(diff, diff);
+        sum_1 = _mm256_add_ps(sum_1, mul);
+        i += 8;
+    }
+    f32 distance = calc_256_sum_8(sum_1) + calc_256_sum_8(sum_2);
+    for (; i < dimension; ++i) {
+        auto diff = vector1[i] - vector2[i];
+        distance += diff * diff;
+    }
+    return distance;
}

-#elif defined(__SSE__)
+#elif defined(__SSE__)

-export f32 L2Distance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
-    return F32L2SSEResidual(vector1, vector2, dimension);
-}
+export f32 L2Distance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) { return F32L2SSEResidual(vector1, vector2, dimension); }

#endif

#if defined(__AVX2__)

-export f32 IPDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
-    u32 i = 0;
-    __m256 sum_1 = _mm256_setzero_ps();
-    __m256 sum_2 = _mm256_setzero_ps();
-    _mm_prefetch(vector1, _MM_HINT_NTA);
-    _mm_prefetch(vector2, _MM_HINT_NTA);
-    for (; i + 16 <= dimension; i += 16) {
-        _mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
-        _mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
-        auto mul_1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
-        auto mul_2 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
-        // add mul to sum
-        sum_1 = _mm256_add_ps(sum_1, mul_1);
-        sum_2 = _mm256_add_ps(sum_2, mul_2);
-    }
-    if (i + 8 <= dimension) {
-        auto mul = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
-        sum_1 = _mm256_add_ps(sum_1, mul);
-        i += 8;
-    }
-    f32 distance = calc_256_sum_8(sum_1) + calc_256_sum_8(sum_2);
-    for (; i < dimension; ++i) {
-        distance += vector1[i] * vector2[i];
-    }
-    return distance;
+export f32 CosineDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
+    u32 i = 0;
+    __m256 dot_sum_1 = _mm256_setzero_ps();
+    __m256 dot_sum_2 = _mm256_setzero_ps();
+    __m256 norm_v1_1 = _mm256_setzero_ps();
+    __m256 norm_v1_2 = _mm256_setzero_ps();
+    __m256 norm_v2_1 = _mm256_setzero_ps();
+    __m256 norm_v2_2 = _mm256_setzero_ps();
+    _mm_prefetch(vector1, _MM_HINT_NTA);
+    _mm_prefetch(vector2, _MM_HINT_NTA);
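+    // main loop: 16 floats per iteration; the dot product and both squared norms each use two independent accumulators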
+    for (; i + 16 <= dimension; i += 16) {
+        _mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
+        _mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
+        auto dot_mul_1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
+        auto dot_mul_2 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
+        auto norm_mul_v1_1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector1 + i));
+        auto norm_mul_v1_2 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector1 + i + 8));
+        auto norm_mul_v2_1 = _mm256_mul_ps(_mm256_loadu_ps(vector2 + i), _mm256_loadu_ps(vector2 + i));
+        auto norm_mul_v2_2 = _mm256_mul_ps(_mm256_loadu_ps(vector2 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
+        // add mul to sum
+        dot_sum_1 = _mm256_add_ps(dot_sum_1, dot_mul_1);
+        dot_sum_2 = _mm256_add_ps(dot_sum_2, dot_mul_2);
+        norm_v1_1 = _mm256_add_ps(norm_v1_1, norm_mul_v1_1);
+        norm_v1_2 = _mm256_add_ps(norm_v1_2, norm_mul_v1_2);
+        norm_v2_1 = _mm256_add_ps(norm_v2_1, norm_mul_v2_1);
+        norm_v2_2 = _mm256_add_ps(norm_v2_2, norm_mul_v2_2);
+    }
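+    // tail: one 8-float block, then the horizontal sums, then a scalar remainder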
+    if (i + 8 <= dimension) {
+        auto dot_mul = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
+        auto norm_mul_v1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector1 + i));
+        auto norm_mul_v2 = _mm256_mul_ps(_mm256_loadu_ps(vector2 + i), _mm256_loadu_ps(vector2 + i));
+
+        dot_sum_1 = _mm256_add_ps(dot_sum_1, dot_mul);
+        norm_v1_1 = _mm256_add_ps(norm_v1_1, norm_mul_v1);
+        norm_v2_1 = _mm256_add_ps(norm_v2_1, norm_mul_v2);
+        i += 8;
+    }
+
+    f32 dot = calc_256_sum_8(dot_sum_1) + calc_256_sum_8(dot_sum_2);
+    f32 norm_v1 = calc_256_sum_8(norm_v1_1) + calc_256_sum_8(norm_v1_2);
+    f32 norm_v2 = calc_256_sum_8(norm_v2_1) + calc_256_sum_8(norm_v2_2);
+    for (; i < dimension; ++i) {
+        dot += vector1[i] * vector2[i];
+        norm_v1 += vector1[i] * vector1[i];
+        norm_v2 += vector2[i] * vector2[i];
+    }
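+    // cosine: dot / sqrt(|v1|^2 * |v2|^2); the dot == 0 guard also avoids 0/0 for zero-length inputs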
+    return dot != 0 ? dot / sqrt(norm_v1 * norm_v2) : 0;
}

-#elif defined(__SSE__)
+#elif defined(__SSE__)
+
+export f32 CosineDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) { return F32CosSSEResidual(vector1, vector2, dimension); }
+
+#endif
+
+#if defined(__AVX2__)

export f32 IPDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) {
-    return F32IPSSEResidual(vector1, vector2, dimension);
+    u32 i = 0;
+    __m256 sum_1 = _mm256_setzero_ps();
+    __m256 sum_2 = _mm256_setzero_ps();
+    _mm_prefetch(vector1, _MM_HINT_NTA);
+    _mm_prefetch(vector2, _MM_HINT_NTA);
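+    // main loop: 16 floats per iteration, products kept in two independent accumulators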
+    for (; i + 16 <= dimension; i += 16) {
+        _mm_prefetch(vector1 + i + 16, _MM_HINT_NTA);
+        _mm_prefetch(vector2 + i + 16, _MM_HINT_NTA);
+        auto mul_1 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
+        auto mul_2 = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i + 8), _mm256_loadu_ps(vector2 + i + 8));
+        // add mul to sum
+        sum_1 = _mm256_add_ps(sum_1, mul_1);
+        sum_2 = _mm256_add_ps(sum_2, mul_2);
+    }
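+    // tail: one 8-float block, then the horizontal sums, then a scalar remainder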
+    if (i + 8 <= dimension) {
+        auto mul = _mm256_mul_ps(_mm256_loadu_ps(vector1 + i), _mm256_loadu_ps(vector2 + i));
+        sum_1 = _mm256_add_ps(sum_1, mul);
+        i += 8;
+    }
+    f32 distance = calc_256_sum_8(sum_1) + calc_256_sum_8(sum_2);
+    for (; i < dimension; ++i) {
+        distance += vector1[i] * vector2[i];
+    }
+    return distance;
}

+#elif defined(__SSE__)
+
+export f32 IPDistance_simd(const f32 *vector1, const f32 *vector2, u32 dimension) { return F32IPSSEResidual(vector1, vector2, dimension); }
+
#endif

} // namespace infinity