Skip to content

Commit

Permalink
Add oneDNN/AMX optimization for distance calculation using Blas for I…
Browse files Browse the repository at this point in the history
…ndexFlatIP
  • Loading branch information
guangzegu committed Mar 19, 2024
1 parent 781f178 commit 2f3fdf9
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 3 deletions.
8 changes: 8 additions & 0 deletions c_api/utils/distances_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,11 @@ void faiss_set_distance_compute_min_k_reservoir(int value) {
/// C API getter: reports the current value of the global
/// faiss::distance_compute_min_k_reservoir setting.
int faiss_get_distance_compute_min_k_reservoir() {
    const int min_k_reservoir = faiss::distance_compute_min_k_reservoir;
    return min_k_reservoir;
}

/// C API setter: stores `value` into the global
/// faiss::distance_compute_dnnl_query_bs (query-side block size read by the
/// oneDNN/AMX distance computation path).
void faiss_set_distance_compute_dnnl_query_bs(int value) {
    faiss::distance_compute_dnnl_query_bs = value;
}

/// C API getter: returns the current faiss::distance_compute_dnnl_query_bs.
int faiss_get_distance_compute_dnnl_query_bs() {
    return faiss::distance_compute_dnnl_query_bs;
}

/// C API setter: stores `value` into the global
/// faiss::distance_compute_dnnl_database_bs (database-side block size read by
/// the oneDNN/AMX distance computation path).
///
/// This definition backs the declaration in distances_c.h; without it any
/// caller of the declared function fails at link time.
void faiss_set_distance_compute_dnnl_database_bs(int value) {
    faiss::distance_compute_dnnl_database_bs = value;
}

/// C API getter: returns the current faiss::distance_compute_dnnl_database_bs.
///
/// This definition backs the declaration in distances_c.h; without it any
/// caller of the declared function fails at link time.
int faiss_get_distance_compute_dnnl_database_bs() {
    return faiss::distance_compute_dnnl_database_bs;
}
12 changes: 12 additions & 0 deletions c_api/utils/distances_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,18 @@ void faiss_set_distance_compute_min_k_reservoir(int value);
/// rather than a heap
int faiss_get_distance_compute_min_k_reservoir();

/// Setter of the query block size used by oneDNN/AMX distance computations
/// (writes faiss::distance_compute_dnnl_query_bs)
void faiss_set_distance_compute_dnnl_query_bs(int value);

/// Getter of the query block size used by oneDNN/AMX distance computations
/// (reads faiss::distance_compute_dnnl_query_bs)
int faiss_get_distance_compute_dnnl_query_bs();

/// Setter of the database block size used by oneDNN/AMX distance computations
/// NOTE(review): presumably writes faiss::distance_compute_dnnl_database_bs —
/// confirm against the definition in distances_c.cpp
void faiss_set_distance_compute_dnnl_database_bs(int value);

/// Getter of the database block size used by oneDNN/AMX distance computations
/// NOTE(review): presumably reads faiss::distance_compute_dnnl_database_bs —
/// confirm against the definition in distances_c.cpp
int faiss_get_distance_compute_dnnl_database_bs();

#ifdef __cplusplus
}
#endif
Expand Down
29 changes: 26 additions & 3 deletions faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,16 @@ void exhaustive_inner_product_blas(
return;

/* block sizes */
const size_t bs_x = distance_compute_blas_query_bs;
const size_t bs_y = distance_compute_blas_database_bs;
size_t prov_bs_x = distance_compute_blas_query_bs;
size_t prov_bs_y = distance_compute_blas_database_bs;
#ifdef ENABLE_DNNL
if (is_amxbf16_supported()) {
prov_bs_x = distance_compute_dnnl_query_bs;
prov_bs_y = distance_compute_dnnl_database_bs;
}
#endif
const size_t bs_x = prov_bs_x;
const size_t bs_y = prov_bs_y;
std::unique_ptr<float[]> ip_block(new float[bs_x * bs_y]);

for (size_t i0 = 0; i0 < nx; i0 += bs_x) {
Expand All @@ -269,7 +277,20 @@ void exhaustive_inner_product_blas(
size_t j1 = j0 + bs_y;
if (j1 > ny)
j1 = ny;
/* compute the actual dot products */
/* compute the actual dot products */
#ifdef ENABLE_DNNL
if (is_amxbf16_supported()) {
FINTEGER nyi = j1 - j0, nxi = i1 - i0;
comput_f32bf16f32_inner_product(
nxi,
d,
nyi,
d,
const_cast<float*>(x + i0 * d),
const_cast<float*>(y + j0 * d),
ip_block.get());
} else
#endif
{
float one = 1, zero = 0;
FINTEGER nyi = j1 - j0, nxi = i1 - i0, di = d;
Expand Down Expand Up @@ -688,6 +709,8 @@ int distance_compute_blas_threshold = 20;
// Global tuning knobs for the distance computations (declared extern in
// distances.h).
int distance_compute_blas_threshold{20};
// Query / database block sizes read by the BLAS inner-product path.
int distance_compute_blas_query_bs{4096};
int distance_compute_blas_database_bs{1024};
// Above this number of results a reservoir is used to collect results rather
// than a heap.
int distance_compute_min_k_reservoir{100};
// Query / database block sizes that replace the BLAS ones when ENABLE_DNNL is
// defined and is_amxbf16_supported() returns true.
int distance_compute_dnnl_query_bs{10240};
int distance_compute_dnnl_database_bs{10240};

void knn_inner_product(
const float* x,
Expand Down
4 changes: 4 additions & 0 deletions faiss/utils/distances.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ FAISS_API extern int distance_compute_blas_threshold;
FAISS_API extern int distance_compute_blas_query_bs;
FAISS_API extern int distance_compute_blas_database_bs;

// block sizes for oneDNN/AMX distance computations; exhaustive_inner_product_blas
// uses them in place of the BLAS block sizes when ENABLE_DNNL is defined and
// is_amxbf16_supported() returns true
FAISS_API extern int distance_compute_dnnl_query_bs;
FAISS_API extern int distance_compute_dnnl_database_bs;

// above this number of results we switch to a reservoir to collect results
// rather than a heap
FAISS_API extern int distance_compute_min_k_reservoir;
Expand Down

0 comments on commit 2f3fdf9

Please sign in to comment.