From 9a131c563ec2b1d8ebb4cb93513b1dafc2bb56e6 Mon Sep 17 00:00:00 2001 From: Kornel Date: Tue, 30 Jul 2024 18:43:12 +0100 Subject: [PATCH] Make BlobReader a slice newtype --- recapn/src/data.rs | 19 +++++------------- recapn/src/ptr.rs | 48 ++++++++++++++++++++++++++++++---------------- recapn/src/text.rs | 25 +++++++++++++++--------- 3 files changed, 52 insertions(+), 40 deletions(-) diff --git a/recapn/src/data.rs b/recapn/src/data.rs index 076d03c..4476cda 100644 --- a/recapn/src/data.rs +++ b/recapn/src/data.rs @@ -59,16 +59,7 @@ impl<'a> Reader<'a> { /// If the slice is too large to be in a message, this function panics. #[inline] pub const fn from_slice(slice: &'a [u8]) -> Self { - let len = slice.len(); - if len > ElementCount::MAX_VALUE as usize { - panic!("slice is too large to be contained within a cap'n proto message") - } - - let count = ElementCount::new(len as u32).unwrap(); - unsafe { - let ptr = NonNull::new_unchecked(slice.as_ptr().cast_mut()); - Self(ptr::Reader::new(ptr, count)) - } + Self(ptr::Reader::new(slice).expect("slice is too large to be contained within a cap'n proto message")) } #[inline] @@ -83,9 +74,7 @@ impl<'a> Reader<'a> { #[inline] pub const fn as_slice(&self) -> &'a [u8] { - let data = self.0.data().as_ptr().cast_const(); - let len = self.len() as usize; - unsafe { core::slice::from_raw_parts(data, len) } + self.0.as_slice() } } @@ -164,7 +153,9 @@ impl<'a> Builder<'a> { #[inline] pub fn as_reader<'b>(&'b self) -> Reader<'b> { - Data(ptr::Reader::new(self.0.data(), self.0.len())) + Data(unsafe { + ptr::Reader::new_unchecked(self.0.data(), self.0.len()) + }) } #[inline] diff --git a/recapn/src/ptr.rs b/recapn/src/ptr.rs index 595c830..66480ea 100644 --- a/recapn/src/ptr.rs +++ b/recapn/src/ptr.rs @@ -2256,7 +2256,9 @@ impl<'a, T: Table> PtrReader<'a, T> { } }?; - Ok(BlobReader::new(ptr.as_inner().cast(), element_count)) + Ok(unsafe { + BlobReader::new_unchecked(ptr.as_inner().cast(), element_count) + }) } #[inline] @@ -2915,32 +2917,43 @@ impl<'a, T: Table> Capable for ListReader<'a, T> { #[derive(Clone, Copy)] pub struct BlobReader<'a> { - a: PhantomData<&'a [u8]>, - ptr: NonNull, - len: ElementCount, + slice: &'a [u8], } impl<'a> BlobReader<'a> { - pub(crate) const fn new(ptr: NonNull, len: ElementCount) -> Self { - Self { a: PhantomData, ptr, len } + #[inline] + pub(crate) const fn new(slice: &'a [u8]) -> Option { + if slice.len() < ElementCount::MAX_VALUE as usize { + Some(Self { slice }) + } else { + None + } + } + + pub(crate) const unsafe fn new_unchecked(ptr: NonNull, len: ElementCount) -> Self { + Self { + slice: std::slice::from_raw_parts(ptr.as_ptr(), len.get() as usize), + } } pub const fn empty() -> Self { - Self::new(NonNull::dangling(), ElementCount::ZERO) + Self { slice: &[] } } pub const fn data(&self) -> NonNull { - self.ptr + unsafe { + NonNull::new_unchecked(self.slice.as_ptr().cast_mut()) + } } pub const fn len(&self) -> ElementCount { - self.len + unsafe { + ElementCount::new_unchecked(self.slice.len() as _) + } } pub const fn as_slice(&self) -> &'a [u8] { - unsafe { - core::slice::from_raw_parts(self.ptr.as_ptr().cast_const(), self.len.get() as usize) - } + self.slice } } @@ -5268,18 +5281,19 @@ impl BlobBuilder<'_> { #[inline] pub const fn as_reader(&self) -> BlobReader { - BlobReader::new(self.ptr, self.len) + unsafe { + BlobReader::new_unchecked(self.ptr, self.len) + } } #[inline] pub(crate) fn copy_from(&mut self, other: BlobReader) { - assert_eq!(self.len, other.len); + assert_eq!(self.len, other.len()); let dst = self.ptr.as_ptr(); - let src = other.ptr.as_ptr().cast_const(); - let len = other.len.get() as usize; + let src = other.as_slice(); unsafe { - ptr::copy_nonoverlapping(src, dst, len) + ptr::copy_nonoverlapping(src.as_ptr(), dst, src.len()) } } diff --git a/recapn/src/text.rs b/recapn/src/text.rs index 0419c66..057dd38 100644 --- a/recapn/src/text.rs +++ b/recapn/src/text.rs @@ -72,13 +72,15 @@ impl<'a> Reader<'a> { #[inline] pub const fn from_slice(s: &'a [u8]) -> Self { match s { - [.., 0] if s.len() <= ByteCount::MAX_VALUE as usize => { - let ptr = unsafe { NonNull::new_unchecked(s.as_ptr().cast_mut()) }; - let len = ElementCount::new(s.len() as u32).unwrap(); - Self(ptr::Reader::new(ptr, len)) + [.., 0] => { + match ptr::Reader::new(s) { + Some(r) => Some(Self(r)), + None => None, + } }, - _ => panic!("attempted to make invalid text blob from slice"), + _ => None, } + .expect("attempted to make invalid text blob from slice") } pub const fn byte_count(&self) -> ByteCount { @@ -99,14 +101,19 @@ impl<'a> Reader<'a> { /// Returns the bytes of the text field without the null terminator #[inline] pub const fn as_bytes(&self) -> &'a [u8] { - let (_, remainder) = self.as_bytes_with_nul().split_last().unwrap(); - remainder + match self.as_bytes_with_nul().split_last() { + Some((_, remainder)) => remainder, + _ => { + debug_assert!(false, "this shouldn't happen, it's to avoid panic code in release"); + EMPTY_SLICE + }, + } } /// Returns the bytes of the text field with the null terminator #[inline] pub const fn as_bytes_with_nul(&self) -> &'a [u8] { - unsafe { slice::from_raw_parts(self.0.data().as_ptr().cast_const(), self.len() as usize) } + self.0.as_slice() } #[inline] @@ -254,4 +261,4 @@ impl PartialEq> for str { fn eq(&self, other: &Builder<'_>) -> bool { self.as_bytes() == other.as_bytes() } -} \ No newline at end of file +}