Skip to content

Commit

Permalink
impl portable pulp simd variants and make the other ones (that are st…
Browse files Browse the repository at this point in the history
…ill used) safe (#16)
  • Loading branch information
sarah-quinones authored Dec 16, 2024
1 parent d3e17da commit 788408e
Show file tree
Hide file tree
Showing 8 changed files with 599 additions and 228 deletions.
19 changes: 17 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,12 @@ categories = ["algorithms", "science"]

[dependencies]
argh = "0.1.12"
bytemuck = "1.20.0"
env_logger = "0.11.5"
faer = { version = "0.19.4", default-features = false, features = ["std"] }
log = "0.4.22"
num-traits = "0.2.19"
pulp = "0.21.0"
rand = "0.8.5"
rand_distr = "0.4.3"
rayon = "1.10.0"
Expand Down
23 changes: 22 additions & 1 deletion benches/bench.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use gathers::distance::{
native_argmin, native_dot_produce, native_l2_norm, native_squared_euclidean,
};
use gathers::simd::{argmin, dot_product, l2_norm, l2_squared_distance};
use gathers::simd::{self, argmin, dot_product, l2_norm, l2_squared_distance};
use pulp::x86::V3;
use rand::{thread_rng, Rng};

pub fn l2_norm_benchmark(c: &mut Criterion) {
Expand All @@ -18,6 +19,11 @@ pub fn l2_norm_benchmark(c: &mut Criterion) {
group.bench_with_input(BenchmarkId::new("simd", dim), &x, |b, input| {
b.iter(|| unsafe { l2_norm(&input) })
});
if let Some(simd) = V3::try_new() {
group.bench_with_input(BenchmarkId::new("pulp", dim), &x, |b, input| {
b.iter(|| simd::pulp::l2_norm(simd, &input))
});
}
}
group.finish();
}
Expand All @@ -35,6 +41,11 @@ pub fn argmin_benchmark(c: &mut Criterion) {
group.bench_with_input(BenchmarkId::new("simd", dim), &x, |b, input| {
b.iter(|| unsafe { argmin(&input) })
});
if let Some(simd) = V3::try_new() {
group.bench_with_input(BenchmarkId::new("pulp", dim), &x, |b, input| {
b.iter(|| simd::pulp::argmin(simd, &input))
});
}
}
}

Expand All @@ -54,6 +65,11 @@ pub fn l2_distance_benchmark(c: &mut Criterion) {
group.bench_with_input(BenchmarkId::new("simd", dim), &(&lhs, &rhs), |b, input| {
b.iter(|| unsafe { l2_squared_distance(&input.0, &input.1) })
});
if let Some(simd) = V3::try_new() {
group.bench_with_input(BenchmarkId::new("pulp", dim), &(&lhs, &rhs), |b, input| {
b.iter(|| simd::pulp::l2_squared_distance(simd, &input.0, &input.1))
});
}
}
group.finish();
}
Expand All @@ -74,6 +90,11 @@ pub fn ip_distance_benchmark(c: &mut Criterion) {
group.bench_with_input(BenchmarkId::new("simd", dim), &(&lhs, &rhs), |b, input| {
b.iter(|| unsafe { dot_product(&input.0, &input.1) })
});
if let Some(simd) = V3::try_new() {
group.bench_with_input(BenchmarkId::new("pulp", dim), &(&lhs, &rhs), |b, input| {
b.iter(|| simd::pulp::dot_product(simd, &input.0, &input.1))
});
}
}
group.finish();
}
Expand Down
21 changes: 18 additions & 3 deletions python/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

94 changes: 54 additions & 40 deletions src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,21 @@ pub fn native_l2_norm(vec: &[f32]) -> f32 {
/// Compute the L2 norm of the vector.
#[inline]
pub fn l2_norm(vec: &[f32]) -> f32 {
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { crate::simd::l2_norm(vec) }
} else {
native_l2_norm(vec)
}
struct Impl<'a> {
vec: &'a [f32],
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
{
native_l2_norm(vec)

impl pulp::WithSimd for Impl<'_> {
type Output = f32;

#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
let Self { vec } = self;
crate::simd::pulp::l2_norm(simd, vec)
}
}

pulp::Arch::new().dispatch(Impl { vec })
}

/// Native implementation of squared euclidean distance.
Expand All @@ -46,18 +49,22 @@ pub fn native_squared_euclidean(lhs: &[f32], rhs: &[f32]) -> f32 {
/// Compute the squared Euclidean distance between two vectors.
#[inline]
pub fn squared_euclidean(lhs: &[f32], rhs: &[f32]) -> f32 {
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { crate::simd::l2_squared_distance(lhs, rhs) }
} else {
native_squared_euclidean(lhs, rhs)
}
struct Impl<'a> {
lhs: &'a [f32],
rhs: &'a [f32],
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
{
native_squared_euclidean(lhs, rhs)

impl pulp::WithSimd for Impl<'_> {
type Output = f32;

#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
let Self { lhs, rhs } = self;
crate::simd::pulp::l2_squared_distance(simd, lhs, rhs)
}
}

pulp::Arch::new().dispatch(Impl { lhs, rhs })
}

/// Native implementation of negative dot product.
Expand All @@ -72,18 +79,22 @@ pub fn native_dot_produce(lhs: &[f32], rhs: &[f32]) -> f32 {
/// Compute the negative dot product between two vectors.
#[inline]
pub fn neg_dot_product(lhs: &[f32], rhs: &[f32]) -> f32 {
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { -crate::simd::dot_product(lhs, rhs) }
} else {
-native_dot_produce(lhs, rhs)
}
struct Impl<'a> {
lhs: &'a [f32],
rhs: &'a [f32],
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
{
-native_dot_produce(lhs, rhs)

impl pulp::WithSimd for Impl<'_> {
type Output = f32;

#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
let Self { lhs, rhs } = self;
-crate::simd::pulp::dot_product(simd, lhs, rhs)
}
}

pulp::Arch::new().dispatch(Impl { lhs, rhs })
}

/// Native implementation of argmin.
Expand All @@ -103,18 +114,21 @@ pub fn native_argmin(vec: &[f32]) -> usize {
/// Find the index of the minimum value in the vector.
#[inline]
pub fn argmin(vec: &[f32]) -> usize {
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
{
if is_x86_feature_detected!("avx2") {
unsafe { crate::simd::argmin(vec) }
} else {
native_argmin(vec)
}
struct Impl<'a> {
vec: &'a [f32],
}
#[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
{
native_argmin(vec)

impl pulp::WithSimd for Impl<'_> {
type Output = usize;

#[inline(always)]
fn with_simd<S: pulp::Simd>(self, simd: S) -> Self::Output {
let Self { vec } = self;
crate::simd::pulp::argmin(simd, vec)
}
}

pulp::Arch::new().dispatch(Impl { vec })
}

#[cfg(test)]
Expand Down
Loading

0 comments on commit 788408e

Please sign in to comment.