Add support for masked loads & stores #399

Merged 1 commit on Mar 14, 2024
6 changes: 6 additions & 0 deletions crates/core_simd/src/masks.rs
@@ -34,6 +34,7 @@ mod sealed {
fn eq(self, other: Self) -> bool;

fn to_usize(self) -> usize;
fn max_unsigned() -> u64;

type Unsigned: SimdElement;

@@ -78,6 +79,11 @@ macro_rules! impl_element {
self as usize
}

#[inline]
fn max_unsigned() -> u64 {
<$unsigned>::MAX as u64
}

type Unsigned = $unsigned;

const TRUE: Self = -1;
244 changes: 244 additions & 0 deletions crates/core_simd/src/vector.rs
@@ -1,5 +1,6 @@
use crate::simd::{
cmp::SimdPartialOrd,
num::SimdUint,
ptr::{SimdConstPtr, SimdMutPtr},
LaneCount, Mask, MaskElement, SupportedLaneCount, Swizzle,
};
@@ -261,6 +262,7 @@ where
/// # Panics
///
/// Panics if the slice's length is less than the vector's `Simd::N`.
/// Use `load_or_default` for an alternative that does not panic.
///
/// # Example
///
@@ -314,6 +316,143 @@ where
unsafe { self.store(slice.as_mut_ptr().cast()) }
}

/// Reads contiguous elements from `slice`. Elements are read so long as they're in-bounds for
/// the `slice`. Otherwise, the default value for the element type is returned.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Simd, Mask};
/// let vec: Vec<i32> = vec![10, 11];
///
/// let result = Simd::<i32, 4>::load_or_default(&vec);
/// assert_eq!(result, Simd::from_array([10, 11, 0, 0]));
/// ```
#[must_use]
#[inline]
pub fn load_or_default(slice: &[T]) -> Self
where
T: Default,
{
Self::load_or(slice, Default::default())
}

/// Reads contiguous elements from `slice`. Elements are read so long as they're in-bounds for
/// the `slice`. Otherwise, the corresponding value from `or` is passed through.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Simd, Mask};
/// let vec: Vec<i32> = vec![10, 11];
/// let or = Simd::from_array([-5, -4, -3, -2]);
///
/// let result = Simd::load_or(&vec, or);
/// assert_eq!(result, Simd::from_array([10, 11, -3, -2]));
/// ```
#[must_use]
#[inline]
pub fn load_or(slice: &[T], or: Self) -> Self {
Self::load_select(slice, Mask::splat(true), or)
}

/// Reads contiguous elements from `slice`. Each element is read from memory if its
/// corresponding element in `enable` is `true`.
///
/// When the element is disabled or out of bounds for the slice, that memory location
/// is not accessed and the default value for the element type is used instead.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Simd, Mask};
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
/// let enable = Mask::from_array([true, true, false, true]);
///
/// let result = Simd::<i32, 4>::load_select_or_default(&vec, enable);
/// assert_eq!(result, Simd::from_array([10, 11, 0, 13]));
/// ```
#[must_use]
#[inline]
pub fn load_select_or_default(slice: &[T], enable: Mask<<T as SimdElement>::Mask, N>) -> Self
where
T: Default,
{
Self::load_select(slice, enable, Default::default())
}

/// Reads contiguous elements from `slice`. Each element is read from memory if its
/// corresponding element in `enable` is `true`.
///
/// When the element is disabled or out of bounds for the slice, that memory location
/// is not accessed and the corresponding value from `or` is passed through.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Simd, Mask};
/// let vec: Vec<i32> = vec![10, 11, 12, 13, 14, 15, 16, 17, 18];
/// let enable = Mask::from_array([true, true, false, true]);
/// let or = Simd::from_array([-5, -4, -3, -2]);
///
/// let result = Simd::load_select(&vec, enable, or);
/// assert_eq!(result, Simd::from_array([10, 11, -3, 13]));
/// ```
#[must_use]
#[inline]
pub fn load_select(
slice: &[T],
mut enable: Mask<<T as SimdElement>::Mask, N>,
or: Self,
) -> Self {
enable &= mask_up_to(slice.len());
// SAFETY: We performed the bounds check by updating the mask. &[T] is properly aligned to
// the element.
unsafe { Self::load_select_ptr(slice.as_ptr(), enable, or) }
}

/// Reads contiguous elements from `slice`. Each element is read from memory if its
/// corresponding element in `enable` is `true`.
///
/// When the element is disabled, that memory location is not accessed and the corresponding
/// value from `or` is passed through.
///
/// # Safety
///
/// Every enabled element must be in bounds for the `slice`.
#[must_use]
#[inline]
pub unsafe fn load_select_unchecked(
slice: &[T],
enable: Mask<<T as SimdElement>::Mask, N>,
or: Self,
) -> Self {
let ptr = slice.as_ptr();
// SAFETY: The safety of reading elements from `slice` is ensured by the caller.
unsafe { Self::load_select_ptr(ptr, enable, or) }
}
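(Aside, not part of the diff: a minimal usage sketch of the unchecked variant, assuming the methods added by this PR are available through std::simd on a nightly toolchain.)

```rust
#![feature(portable_simd)]
use std::simd::{Mask, Simd};

/// Loads the first two lanes of `data` without the bounds-check mask update,
/// filling the remaining lanes with zero.
fn head2(data: &[i32]) -> Simd<i32, 4> {
    assert!(data.len() >= 2);
    let enable = Mask::from_array([true, true, false, false]);
    // SAFETY: only lanes 0 and 1 are enabled, and `data` has at least 2 elements.
    unsafe { Simd::load_select_unchecked(data, enable, Simd::splat(0)) }
}
```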

/// Reads contiguous elements starting at `ptr`. Each element is read from memory if its
/// corresponding element in `enable` is `true`.
///
/// When the element is disabled, that memory location is not accessed and the corresponding
/// value from `or` is passed through.
///
/// # Safety
///
/// Each enabled element must satisfy the same conditions as [`core::ptr::read`].
#[must_use]
#[inline]
pub unsafe fn load_select_ptr(
ptr: *const T,
enable: Mask<<T as SimdElement>::Mask, N>,
or: Self,
) -> Self {
// SAFETY: The safety of reading elements through `ptr` is ensured by the caller.
unsafe { core::intrinsics::simd::simd_masked_load(enable.to_int(), ptr, or) }
}

/// Reads from potentially discontiguous indices in `slice` to construct a SIMD vector.
/// If an index is out-of-bounds, the element is instead selected from the `or` vector.
///
@@ -492,6 +631,77 @@ where
unsafe { core::intrinsics::simd::simd_gather(or, source, enable.to_int()) }
}

/// Conditionally write contiguous elements to `slice`. The `enable` mask controls
/// which elements are written, as long as they're in-bounds of the `slice`.
/// If the element is disabled or out of bounds, no memory access to that location
/// is made.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Simd, Mask};
/// let mut arr = [0i32; 4];
/// let write = Simd::from_array([-5, -4, -3, -2]);
/// let enable = Mask::from_array([false, true, true, true]);
///
/// write.store_select(&mut arr[..3], enable);
/// assert_eq!(arr, [0, -4, -3, 0]);
/// ```
#[inline]
pub fn store_select(self, slice: &mut [T], mut enable: Mask<<T as SimdElement>::Mask, N>) {
enable &= mask_up_to(slice.len());
// SAFETY: We performed the bounds check by updating the mask. &[T] is properly aligned to
// the element.
unsafe { self.store_select_ptr(slice.as_mut_ptr(), enable) }
}

/// Conditionally write contiguous elements to `slice`. The `enable` mask controls
/// which elements are written.
///
/// # Safety
///
/// Every enabled element must be in bounds for the `slice`.
///
/// # Examples
/// ```
/// # #![feature(portable_simd)]
/// # #[cfg(feature = "as_crate")] use core_simd::simd;
/// # #[cfg(not(feature = "as_crate"))] use core::simd;
/// # use simd::{Simd, Mask};
/// let mut arr = [0i32; 4];
/// let write = Simd::from_array([-5, -4, -3, -2]);
/// let enable = Mask::from_array([false, true, true, true]);
///
/// unsafe { write.store_select_unchecked(&mut arr, enable) };
/// assert_eq!(arr, [0, -4, -3, -2]);
/// ```
#[inline]
pub unsafe fn store_select_unchecked(
self,
slice: &mut [T],
enable: Mask<<T as SimdElement>::Mask, N>,
) {
let ptr = slice.as_mut_ptr();
// SAFETY: The safety of writing elements in `slice` is ensured by the caller.
unsafe { self.store_select_ptr(ptr, enable) }
}

/// Conditionally write contiguous elements starting from `ptr`.
/// The `enable` mask controls which elements are written.
/// When disabled, the memory location corresponding to that element is not accessed.
///
/// # Safety
///
/// Memory addresses for each element are calculated using [`core::ptr::wrapping_offset`], and
/// each enabled element must satisfy the same conditions as [`core::ptr::write`].
#[inline]
pub unsafe fn store_select_ptr(self, ptr: *mut T, enable: Mask<<T as SimdElement>::Mask, N>) {
// SAFETY: The safety of writing elements through `ptr` is ensured by the caller.
unsafe { core::intrinsics::simd::simd_masked_store(enable.to_int(), ptr, self) }
}
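(Aside, not part of the diff: a small sketch of how the raw-pointer store might be used, assuming the nightly portable_simd feature and the methods added here.)

```rust
#![feature(portable_simd)]
use std::simd::{Mask, Simd};

/// Writes the first `len` lanes of `v` through `dst`, leaving the rest untouched.
///
/// # Safety
/// `dst` must be valid for writes of at least `len.min(8)` elements.
unsafe fn store_prefix(v: Simd<u8, 8>, dst: *mut u8, len: usize) {
    // Enable only the lanes that fall inside the destination buffer.
    let enable: Mask<i8, 8> = Mask::from_array(core::array::from_fn(|i| i < len));
    // SAFETY: disabled lanes are never written; enabled lanes are in bounds per the contract above.
    unsafe { v.store_select_ptr(dst, enable) };
}
```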

/// Writes the values in a SIMD vector to potentially discontiguous indices in `slice`.
/// If an index is out-of-bounds, the write is suppressed without panicking.
/// If two elements in the scattered vector would write to the same index
@@ -979,3 +1189,37 @@ where
{
type Mask = isize;
}

#[inline]
fn lane_indices<const N: usize>() -> Simd<usize, N>
where
LaneCount<N>: SupportedLaneCount,
{
let mut index = [0; N];
for i in 0..N {
index[i] = i;
}
Simd::from_array(index)
}

#[inline]
fn mask_up_to<M, const N: usize>(len: usize) -> Mask<M, N>
where
LaneCount<N>: SupportedLaneCount,
M: MaskElement,
{
let index = lane_indices::<N>();
let max_value: u64 = M::max_unsigned();
macro_rules! case {
($ty:ty) => {
if N < <$ty>::MAX as usize && max_value as $ty as u64 == max_value {
return index.cast().simd_lt(Simd::splat(len.min(N) as $ty)).cast();
}
};
}
case!(u8);
case!(u16);
case!(u32);
case!(u64);
index.simd_lt(Simd::splat(len)).cast()
}
Comment on lines +1193 to +1225
Member:

How about adding a from_usize and doing something like this to avoid casts?

Suggested change (replacing the lane_indices and mask_up_to helpers above):
#[inline]
fn lane_indices<M, const N: usize>() -> Simd<M, N>
where
M: MaskElement,
LaneCount<N>: SupportedLaneCount,
{
Simd::from_array(core::array::from_fn(M::from_usize))
}
#[inline]
fn mask_up_to<M, const N: usize>(len: usize) -> Mask<M, N>
where
LaneCount<N>: SupportedLaneCount,
M: MaskElement,
{
if N as u64 < M::max_unsigned() {
lane_indices::<M, N>().simd_lt(Simd::splat(M::from_usize(len)))
} else {
lane_indices::<usize, N>().simd_lt(Simd::splat(len)).cast()
}
}

Member:

part of the case! logic is picking the minimum viable element size, since that's usually more efficient, so it would pick 16-bit elements for Simd<u8, 256> and 32-bit elements for Simd<u8, 65536> (yes, afaict that is theoretically possible in a single vector on RISC-V V).

your suggestion would just pick the mask's element size and if that isn't big enough, give up and use full usize elements even if 16-bit elements would suffice for all realistic vector types. Unfortunately LLVM doesn't seem to be able to optimize vectors to use a smaller element size unless you explicitly cast to that size before using simd_lt

Member:

also lane_indices::<M, N>().simd_lt(Simd::splat(M::from_usize(len))) is just plain incorrect, if M is i8 and len is 256 then it will wrap around (or panic if from_usize checks) and act as if len == 0
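(Aside: a one-line check of the wrap-around described above.)

```rust
fn main() {
    // Casting len = 256 to an i8 mask element wraps to 0, so the comparison
    // against the lane indices would behave as if len == 0.
    assert_eq!(256usize as i8, 0);
}
```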

Member:

I think for the usual case (vectors with well under 256 elements) codegen is likely to be better without a cast? How about keeping as is, but trying to use M first if it fits?

Member:

maybe? the casts optimized away completely when I tried...

Member:

Interesting, I've had bad luck with casts, but that's usually with more complex operations. If we think the casts won't matter it's fine with me

Comment:

> part of the case! logic is picking the minimum viable element size, since that's usually more efficient, so it would pick 16-bit elements for Simd<u8, 256> and 32-bit elements for Simd<u8, 65536> (yes, afaict that is theoretically possible in a single vector on RISC-V V).

Would it be sufficient to consider only the value of N when determining the element size? i.e. use 8-bit elements for Simd<u64,16>

e.g.

    macro_rules! case {
        ($ty:ty) => {
            if N < <$ty>::MAX as usize {
                return index.cast().simd_lt(Simd::splat(len.min(N) as $ty)).cast();
            }
        };
    }

Member:

> Would it be sufficient to consider only the value of N when determining the element size? i.e. use 8-bit elements for Simd<u64,16>

no, because iirc LLVM doesn't optimize-out the element-size conversions then.
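(Aside: a standalone sketch of the narrow-element comparison being discussed, my own illustration rather than code from the PR. It builds the same kind of prefix mask by comparing u8 indices when N fits in u8.)

```rust
#![feature(portable_simd)]
use std::simd::{prelude::*, LaneCount, SupportedLaneCount};

/// Builds a "first `len` lanes enabled" mask by comparing u8 indices,
/// mirroring the narrow-element comparison the case! macro selects when
/// the index range fits in u8.
fn prefix_mask<const N: usize>(len: usize) -> Mask<i8, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    debug_assert!(N < u8::MAX as usize);
    let index: Simd<u8, N> = Simd::from_array(core::array::from_fn(|i| i as u8));
    index.simd_lt(Simd::splat(len.min(N) as u8))
}
```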

35 changes: 35 additions & 0 deletions crates/core_simd/tests/masked_load_store.rs
@@ -0,0 +1,35 @@
#![feature(portable_simd)]
use core_simd::simd::prelude::*;

#[cfg(target_arch = "wasm32")]
use wasm_bindgen_test::*;

#[cfg(target_arch = "wasm32")]
wasm_bindgen_test_configure!(run_in_browser);

#[test]
#[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
fn masked_load_store() {
let mut arr = [u8::MAX; 7];

u8x4::splat(0).store_select(&mut arr[5..], Mask::from_array([false, true, false, true]));
// write to index 8 is OOB and dropped
assert_eq!(arr, [255u8, 255, 255, 255, 255, 255, 0]);

u8x4::from_array([0, 1, 2, 3]).store_select(&mut arr[1..], Mask::splat(true));
assert_eq!(arr, [255u8, 0, 1, 2, 3, 255, 0]);

// read from index 7 is OOB and dropped
assert_eq!(
u8x4::load_or(&arr[4..], u8x4::splat(42)),
u8x4::from_array([3, 255, 0, 42])
);
assert_eq!(
u8x4::load_select(
&arr[4..],
Mask::from_array([true, false, true, true]),
u8x4::splat(42)
),
u8x4::from_array([3, 42, 0, 42])
);
}
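(Aside, not part of the test file: a sketch of the tail-handling pattern these APIs enable, assuming the nightly portable_simd feature and the methods added in this PR.)

```rust
#![feature(portable_simd)]
use std::simd::prelude::*;

/// Adds 1 to every element, 4 lanes at a time, using a masked load/store
/// for the final partial chunk instead of a scalar cleanup loop.
fn increment_all(data: &mut [i32]) {
    let mut chunks = data.chunks_exact_mut(4);
    for chunk in &mut chunks {
        let v = Simd::<i32, 4>::from_slice(&*chunk) + Simd::splat(1);
        v.copy_to_slice(chunk);
    }
    let tail = chunks.into_remainder();
    if !tail.is_empty() {
        // Out-of-bounds lanes load the default (0) and are masked off on the store.
        let v = Simd::<i32, 4>::load_or_default(&*tail) + Simd::splat(1);
        v.store_select(tail, Mask::splat(true));
    }
}
```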