
Add support for masked loads & stores #399

Merged: 1 commit into rust-lang:master on Mar 14, 2024

Conversation

farnoy (Contributor) commented Feb 29, 2024

Continuation of #374, but simplified.

  1. Got rid of the optimization variants USE_BRANCH and USE_BITMASK:
     - We no longer branch to check whether all lanes are enabled, and we no longer calculate bitmasks with integer code.
     - This won't be the optimal approach until LLVM learns to optimize these patterns further.
  2. Added documentation for the load operations. Let me know if this complies with established norms and I'll document the stores as well.
  3. Simplified the public bounds to the bare essentials.
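
For context, here is a rough usage sketch of what the new masked loads enable, using the load_or / load_or_default names that come up later in this thread (an illustrative sketch on nightly portable_simd, not code taken from the PR):

#![feature(portable_simd)]
use std::simd::Simd;

fn load_tail(buf: &[i32]) -> Simd<i32, 8> {
    // Unlike from_slice, this does not panic when buf has fewer than 8
    // elements: lanes past the end of the slice get the default value (0).
    let _with_default = Simd::<i32, 8>::load_or_default(buf);

    // load_or takes an explicit fallback vector for the out-of-bounds lanes.
    Simd::<i32, 8>::load_or(buf, Simd::splat(-1))
}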

farnoy force-pushed the masked-load-store-simple branch from eccbf0f to 3692383 on February 29, 2024 16:33
calebzulawski (Member):

Can you add to the documentation for from_slice that you can use load_or or load_or_default for a non-panicking version?

farnoy (PR author) commented Mar 1, 2024

Not sure why the new doctests are failing. The errors about the store functions being undocumented are valid; I haven't added that documentation yet.

{
    let index = lane_indices::<i8, N>();
    let lt = index.simd_lt(Simd::splat(i8::try_from(len).unwrap_or(i8::MAX)));
    Mask::<M, N>::from_bitmask_vector(lt.to_bitmask_vector())
}

Maybe this has already been considered, but is there a way to avoid calling lane_indices (and avoid running the for loop inside that function every time mask_up_to is called), by doing:

fn mask_up_to<...>() -> Mask<M,N> ... {
    Mask::<M, N>::from_bitmask((1_u64 << len) - 1)
}

Please let me know if I misunderstood the function :)
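
For intuition, a small worked example of that bitmask construction (illustrative only, assuming nightly portable_simd):

#![feature(portable_simd)]
use std::simd::Mask;

fn main() {
    let len = 3usize;
    // (1 << 3) - 1 == 0b0000_0111, so lane i is enabled exactly when i < len.
    let m = Mask::<i8, 8>::from_bitmask((1_u64 << len) - 1);
    assert_eq!(
        m.to_array(),
        [true, true, true, false, false, false, false, false]
    );
}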

farnoy (PR author):

The previous version of this PR had the option to use this exact bitmask approach, and indeed it tends to produce better-performing code: #374 (comment)

Plus, it's just simpler, so I'm in favor. Any objections @calebzulawski? This would not require adding any extra trait bounds ever since you removed ToBitMask.

farnoy (PR author):

Oh, and it would have to be this to prevent an overflow on the shift:

fn mask_up_to<...>() -> Mask<M,N> ... {
    if len >= 64 {
        Mask::<M, N>::splat(true)
    } else {
        Mask::<M, N>::from_bitmask((1_u64 << len) - 1)
    }
}


If you want to avoid the literal 64, it might be:

 if len >= u64::BITS as usize { ...

Member:

I think you need to condition that on x86, because that will be worse on other architectures. Also, you should condition that on the lane count being no more than 64, because we will support larger counts in the future. Alternatively, use a bitmask vector, which works for more than 64 lanes.


Any comments on adding a shortcut for when len is not less than the lane count?

fn mask_up_to<...>() -> Mask<M,N> ... {
    if len >= N {
        Mask::<M,N>::splat(true)
    } else {
        let index = lane_indices::<i8, N>();
        let lt = index.simd_lt(Simd::splat(i8::try_from(len).unwrap_or(i8::MAX)));
        // .... Mask::<M, N>::from_bitmask_vector(lt.to_bitmask_vector())
    }
}


I think the comparison-based version also generates better code on non-AVX-512 machines: https://rust.godbolt.org/z/6edK1rx9M

Regarding the initial question: since the length parameter is a const generic, we would expect LLVM to always precompute the mask vector, so at runtime it is a single vector load.

farnoy (PR author):

We're at risk of having the same conversation as in the previous PR, which we wanted to avoid with this spin-off PR.

Let's go with the bitmask vector implementation, which is more generic. We can revisit this later, or LLVM can learn to optimize this better, or users who are performance-conscious today can use the unsafe functions to provide their own mask, calculated in a way that best fits their architecture and use case.
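
As a sketch of that escape hatch: a performance-conscious caller could compute their own enable mask (here using the x86-friendly bitmask trick from above, which assumes N <= 64) and hand it to one of the mask-taking loads. load_select is used below as an assumed name for such an entry point, so treat the exact function name and signature as an assumption rather than the merged API:

#![feature(portable_simd)]
use std::simd::{LaneCount, Mask, Simd, SupportedLaneCount};

// Bitmask-based tail mask; only valid for N <= 64, as discussed earlier.
fn custom_tail_mask<const N: usize>(len: usize) -> Mask<i32, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    if len >= N {
        Mask::splat(true)
    } else {
        Mask::from_bitmask((1_u64 << len) - 1)
    }
}

fn load_tail(buf: &[i32]) -> Simd<i32, 8> {
    let enable = custom_tail_mask::<8>(buf.len());
    // Assumed mask-taking load: disabled lanes come from the fallback vector.
    Simd::load_select(buf, enable, Simd::splat(0))
}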

farnoy force-pushed the masked-load-store-simple branch 2 times, most recently from 83c4eea to 280aa8d, on March 9, 2024 23:41
farnoy (PR author) commented Mar 9, 2024

Should be green now. Please review

calebzulawski (Member) left a comment:

Looks good to me other than these nits

farnoy force-pushed the masked-load-store-simple branch from 280aa8d to f204737 on March 10, 2024 10:33
    LaneCount<N>: SupportedLaneCount,
    M: MaskElement,
{
    let index = lane_indices::<i8, N>();
Member:

This should be isize instead of i8, since we will eventually have types with more than 127 elements.

Member:

I expect LLVM to be able to optimize it back to the 8-bit-per-element version if you use len.min(N), so LLVM knows the splatted value is <= N.

Member:

Well, apparently LLVM is more terrible than I thought: https://rust.godbolt.org/z/vj75dsT51

Member:

What about using M directly? If type inference doesn't play nice, using intrinsics directly can probably get past it.

farnoy (PR author):

Doing it on M would require M: core::convert::TryFrom<usize>, which would appear on the public interface, and then there's still no way to convert len: usize to M safely, nor is there a core::ops::Max-style trait that would give you a generic ::MAX constant. Am I missing anything?

farnoy (PR author):

@programmerjake I don't think I can add Sealed::Unsigned: Into<u64> because usize does not implement that trait.

error[E0277]: the trait bound `u64: From<usize>` is not satisfied
  --> crates\core_simd\src\masks.rs:96:24
   |
96 | impl_element! { isize, usize }
   |                        ^^^^^ the trait `From<usize>` is not implemented for `u64`, which is required by `usize: Into<u64>`
   |

Member:

I don't think where bounds are picked up when you use the trait; they only constrain it. I would just add functions to Sealed anyway, to avoid leaking bounds (which we couldn't remove in the future if we wanted to).
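
A minimal sketch of that pattern (names illustrative; the real trait lives in masks.rs), using the max_unsigned helper that the final code below ends up calling: the conversion is a method on the crate-private Sealed trait, so no Into<u64>-style bound ever appears on the public interface.

mod sealed {
    pub trait Sealed {
        // Kept on the private trait so the bound never leaks into public
        // APIs and can be changed or removed later.
        fn max_unsigned() -> u64;
    }
}

impl sealed::Sealed for i8 {
    fn max_unsigned() -> u64 {
        u8::MAX as u64
    }
}

impl sealed::Sealed for isize {
    fn max_unsigned() -> u64 {
        usize::MAX as u64
    }
}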

Member:

I originally had that as u128 instead of u64, though as Caleb suggests, a Sealed method may be best.

Member:

> I don't think where bounds are picked up when you use the trait; they only constrain it.

Associated type bounds do seem to be picked up: https://play.rust-lang.org/?version=stable&mode=debug&edition=2021&gist=4d88c89dd7ee65b133b92a3ccedca39d
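
A minimal standalone illustration of that point (not the actual masks.rs code): the Into<u64> bound declared on the associated type is implied at the use site, so the generic function needs no extra where clause.

trait Element {
    // Bound declared once, on the associated type.
    type Unsigned: Into<u64>;
    fn to_unsigned(self) -> Self::Unsigned;
}

impl Element for i8 {
    type Unsigned = u8;
    fn to_unsigned(self) -> u8 {
        self as u8
    }
}

// No extra `where T::Unsigned: Into<u64>` is needed here; the bound is
// picked up from the trait definition.
fn widen<T: Element>(x: T) -> u64 {
    x.to_unsigned().into()
}

fn main() {
    assert_eq!(widen(-1i8), 255);
}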

farnoy (PR author):

Pushed the change. I barely understand it but I think it should work?

farnoy force-pushed the masked-load-store-simple branch from 3ca5682 to 4f0ba1a on March 11, 2024 17:54
farnoy (PR author) commented Mar 11, 2024

Squashed to clean up the history

Comment on lines +1193 to +1225
#[inline]
fn lane_indices<const N: usize>() -> Simd<usize, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    let mut index = [0; N];
    for i in 0..N {
        index[i] = i;
    }
    Simd::from_array(index)
}

#[inline]
fn mask_up_to<M, const N: usize>(len: usize) -> Mask<M, N>
where
    LaneCount<N>: SupportedLaneCount,
    M: MaskElement,
{
    let index = lane_indices::<N>();
    let max_value: u64 = M::max_unsigned();
    macro_rules! case {
        ($ty:ty) => {
            if N < <$ty>::MAX as usize && max_value as $ty as u64 == max_value {
                return index.cast().simd_lt(Simd::splat(len.min(N) as $ty)).cast();
            }
        };
    }
    case!(u8);
    case!(u16);
    case!(u32);
    case!(u64);
    index.simd_lt(Simd::splat(len)).cast()
}
Member:

How about adding a from_usize and doing something like this to avoid casts?

Suggested change (replacing the two functions above with):
#[inline]
fn lane_indices<M, const N: usize>() -> Simd<M, N>
where
    M: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    Simd::from_array(core::array::from_fn(M::from_usize))
}

#[inline]
fn mask_up_to<M, const N: usize>(len: usize) -> Mask<M, N>
where
    LaneCount<N>: SupportedLaneCount,
    M: MaskElement,
{
    if N as u64 < M::max_unsigned() {
        lane_indices::<M, N>().simd_lt(Simd::splat(M::from_usize(len)))
    } else {
        lane_indices::<usize, N>().simd_lt(Simd::splat(len)).cast()
    }
}

Member:

Part of the case! logic is picking the minimum viable element size, since that's usually more efficient, so it would pick 16-bit elements for Simd<u8, 256> and 32-bit elements for Simd<u8, 65536> (yes, AFAICT that is theoretically possible in a single vector on RISC-V V).

Your suggestion would just pick the mask's element size and, if that isn't big enough, give up and use full usize elements, even if 16-bit elements would suffice for all realistic vector types. Unfortunately, LLVM doesn't seem to be able to optimize vectors to use a smaller element size unless you explicitly cast to that size before using simd_lt.

Member:

Also, lane_indices::<M, N>().simd_lt(Simd::splat(M::from_usize(len))) is just plain incorrect: if M is i8 and len is 256, then it will wrap around (or panic if from_usize checks) and act as if len == 0.

Member:

I think for the usual case (vectors with well under 256 elements) codegen is likely to be better without a cast? How about keeping it as is, but trying to use M first if it fits?

Member:

Maybe? The casts optimized away completely when I tried...

Member:

Interesting, I've had bad luck with casts, but that's usually with more complex operations. If we think the casts won't matter, it's fine with me.


> Part of the case! logic is picking the minimum viable element size, since that's usually more efficient, so it would pick 16-bit elements for Simd<u8, 256> and 32-bit elements for Simd<u8, 65536>.

Would it be sufficient to consider only the value of N when determining the element size, i.e. use 8-bit elements for Simd<u64, 16>?

e.g.

    macro_rules! case {
        ($ty:ty) => {
            if N < <$ty>::MAX as usize {
                return index.cast().simd_lt(Simd::splat(len.min(N) as $ty)).cast();
            }
        };
    }

Member:

> Would it be sufficient to consider only the value of N when determining the element size, i.e. use 8-bit elements for Simd<u64, 16>?

No, because IIRC LLVM doesn't optimize out the element-size conversions then.

calebzulawski merged commit 50e8ae8 into rust-lang:master on Mar 14, 2024
66 checks passed