diff --git a/faiss/cppcontrib/amx/distances_dnnl.h b/faiss/cppcontrib/amx/distances_dnnl.h index e767c9070f..a3004d0ffd 100644 --- a/faiss/cppcontrib/amx/distances_dnnl.h +++ b/faiss/cppcontrib/amx/distances_dnnl.h @@ -24,7 +24,7 @@ namespace faiss { // block sizes for oneDNN/AMX distance computations FAISS_API int distance_compute_dnnl_query_bs = 10240; FAISS_API int distance_compute_dnnl_database_bs = 10240; - + /* Find the nearest neighbors for nx queries in a set of ny vectors using oneDNN/AMX */ template void exhaustive_inner_product_seq_dnnl( @@ -33,11 +33,11 @@ void exhaustive_inner_product_seq_dnnl( size_t d, size_t nx, size_t ny, - BlockResultHandler& res) { + BlockResultHandler& res) { using SingleResultHandler = typename BlockResultHandler::SingleResultHandler; [[maybe_unused]] int nt = std::min(int(nx), omp_get_max_threads()); - + std::unique_ptr res_arr(new float[nx * ny]); comput_f32bf16f32_inner_product( @@ -50,21 +50,21 @@ void exhaustive_inner_product_seq_dnnl( res_arr.get()); #pragma omp parallel num_threads(nt) - { - SingleResultHandler resi(res); + { + SingleResultHandler resi(res); #pragma omp for - for (size_t i = 0; i < nx; i++) { - resi.begin(i); - for (size_t j = 0; j < ny; j++) { - float ip = res_arr[i * ny + j]; - resi.add_result(ip, j); - } - resi.end(); + for (size_t i = 0; i < nx; i++) { + resi.begin(i); + for (size_t j = 0; j < ny; j++) { + float ip = res_arr[i * ny + j]; + resi.add_result(ip, j); } + resi.end(); } + } } -/** Find the nearest neighbors for nx queries in a set of ny vectors using oneDNN/AMX */ +/* Find the nearest neighbors for nx queries in a set of ny vectors using oneDNN/AMX */ template void exhaustive_inner_product_blas_dnnl( const float* x, @@ -72,8 +72,8 @@ void exhaustive_inner_product_blas_dnnl( size_t d, size_t nx, size_t ny, - BlockResultHandler& res) { - /* block sizes */ + BlockResultHandler& res) { + /* block sizes */ const size_t bs_x = distance_compute_dnnl_query_bs; const size_t bs_y = distance_compute_dnnl_database_bs; std::unique_ptr ip_block(new float[bs_x * bs_y]); @@ -100,7 +100,6 @@ void exhaustive_inner_product_blas_dnnl( const_cast(y + j0 * d), ip_block.get()); - res.add_results(j0, j1, ip_block.get()); } res.end_multiple(); @@ -108,4 +107,4 @@ void exhaustive_inner_product_blas_dnnl( } } -}// namespace faiss \ No newline at end of file +} // namespace faiss \ No newline at end of file