Skip to content

Commit

Permalink
feat(binary): implement fixed size binary operations
Browse files Browse the repository at this point in the history
  • Loading branch information
f4t4nt committed Jan 14, 2025
1 parent 4a695b8 commit 82660c8
Show file tree
Hide file tree
Showing 9 changed files with 959 additions and 43 deletions.
179 changes: 176 additions & 3 deletions src/daft-core/src/array/ops/binary.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
use std::iter;

use arrow2::bitmap::utils::{BitmapIter, ZipValidity};
use common_error::{DaftError, DaftResult};
use num_traits::Zero;

use crate::{
array::ops::as_arrow::AsArrow,
datatypes::{BinaryArray, DaftIntegerType, DaftNumericType, DataArray, UInt64Array},
datatypes::{
BinaryArray, DaftIntegerType, DaftNumericType, DataArray, FixedSizeBinaryArray, UInt64Array,
},
};

enum BroadcastedBinaryIter<'a> {
Expand All @@ -18,6 +22,22 @@ enum BroadcastedBinaryIter<'a> {
),
}

enum BroadcastedFixedSizeBinaryIter<'a> {
Repeat(std::iter::Take<std::iter::Repeat<Option<&'a [u8]>>>),
NonRepeat(ZipValidity<&'a [u8], std::slice::ChunksExact<'a, u8>, BitmapIter<'a>>),
}

impl<'a> Iterator for BroadcastedFixedSizeBinaryIter<'a> {
type Item = Option<&'a [u8]>;

fn next(&mut self) -> Option<Self::Item> {
match self {
BroadcastedFixedSizeBinaryIter::Repeat(iter) => iter.next(),
BroadcastedFixedSizeBinaryIter::NonRepeat(iter) => iter.next(),
}
}
}

impl<'a> Iterator for BroadcastedBinaryIter<'a> {
type Item = Option<&'a [u8]>;

Expand All @@ -42,8 +62,8 @@ fn create_broadcasted_numeric_iter<'a, T, I>(
len: usize,
) -> Box<dyn Iterator<Item = DaftResult<Option<I>>> + 'a>
where
T: DaftNumericType,
T::Native: TryInto<I>,
T: DaftIntegerType,
T::Native: TryInto<I> + Ord + Zero,
{
if arr.len() == 1 {
let val = arr.as_arrow().iter().next().unwrap();
Expand All @@ -52,6 +72,11 @@ where
.take(len)
.map(|x| -> DaftResult<Option<I>> {
x.map(|x| {
if *x < T::Native::zero() {
return Err(DaftError::ComputeError(
"Start index must be non-negative".to_string(),
));
}
(*x).try_into().map_err(|_| {
DaftError::ComputeError(
"Error in slice: failed to cast value".to_string(),
Expand All @@ -64,6 +89,11 @@ where
} else {
Box::new(arr.as_arrow().iter().map(|x| -> DaftResult<Option<I>> {
x.map(|x| {
if *x < T::Native::zero() {
return Err(DaftError::ComputeError(
"Start index must be non-negative".to_string(),
));
}
(*x).try_into().map_err(|_| {
DaftError::ComputeError("Error in slice: failed to cast value".to_string())
})
Expand Down Expand Up @@ -178,3 +208,146 @@ impl BinaryArray {
Ok(Self::from((self.name(), Box::new(builder.into()))))
}
}

impl FixedSizeBinaryArray {
pub fn length(&self) -> DaftResult<UInt64Array> {
let self_arrow = self.as_arrow();
let size = self_arrow.size();
let arrow_result = arrow2::array::UInt64Array::from_iter(
iter::repeat(Some(size as u64)).take(self_arrow.len()),
)
.with_validity(self_arrow.validity().cloned());
Ok(UInt64Array::from((self.name(), Box::new(arrow_result))))
}

fn create_broadcasted_iter(&self, len: usize) -> BroadcastedFixedSizeBinaryIter<'_> {
let self_arrow = self.as_arrow();
if self_arrow.len() == 1 {
BroadcastedFixedSizeBinaryIter::Repeat(iter::repeat(self_arrow.get(0)).take(len))
} else {
BroadcastedFixedSizeBinaryIter::NonRepeat(self_arrow.iter())
}
}

pub fn binary_slice<I, J>(
&self,
start: &DataArray<I>,
length: Option<&DataArray<J>>,
) -> DaftResult<BinaryArray>
where
I: DaftIntegerType,
<I as DaftNumericType>::Native: Ord + TryInto<usize>,
J: DaftIntegerType,
<J as DaftNumericType>::Native: Ord + TryInto<usize>,
{
let self_arrow = self.as_arrow();
let output_len = if self_arrow.len() == 1 {
std::cmp::max(start.len(), length.map_or(1, |l| l.len()))
} else {
self_arrow.len()
};

let self_iter = self.create_broadcasted_iter(output_len);
let start_iter = create_broadcasted_numeric_iter::<I, usize>(start, output_len);
let length_iter = match length {
Some(length) => create_broadcasted_numeric_iter::<J, usize>(length, output_len),
None => Box::new(iter::repeat_with(|| Ok(None))),
};

let mut builder = arrow2::array::MutableBinaryArray::<i64>::new();
let mut validity = arrow2::bitmap::MutableBitmap::new();

for ((val, start), length) in self_iter.zip(start_iter).zip(length_iter) {
match (val, start?, length?) {
(Some(val), Some(start), Some(length)) => {
if start >= val.len() || length == 0 {
builder.push::<&[u8]>(None);
validity.push(false);
} else {
let end = (start + length).min(val.len());
let slice = &val[start..end];
builder.push(Some(slice));
validity.push(true);
}
}
(Some(val), Some(start), None) => {
if start >= val.len() {
builder.push::<&[u8]>(None);
validity.push(false);
} else {
let slice = &val[start..];
builder.push(Some(slice));
validity.push(true);
}
}
_ => {
builder.push::<&[u8]>(None);
validity.push(false);
}
}
}

Ok(BinaryArray::from((self.name(), Box::new(builder.into()))))
}

pub fn binary_concat(
&self,
other: &FixedSizeBinaryArray,
) -> std::result::Result<FixedSizeBinaryArray, DaftError> {
let self_arrow = self.as_arrow();
let other_arrow = other.as_arrow();
let self_size = self_arrow.size();
let other_size = other_arrow.size();
let combined_size = self_size + other_size;

// Create a new FixedSizeBinaryArray with the combined size
let mut values = Vec::with_capacity(self_arrow.len() * combined_size);
let mut validity = arrow2::bitmap::MutableBitmap::new();

let self_iter = self_arrow.iter();
let other_iter = other_arrow.iter();

for (val1, val2) in self_iter.zip(other_iter) {
match (val1, val2) {
(Some(val1), Some(val2)) => {
values.extend_from_slice(val1);
values.extend_from_slice(val2);
validity.push(true);
}
_ => {
values.extend(std::iter::repeat(0u8).take(combined_size));
validity.push(false);
}
}
}

// Create a new FixedSizeBinaryArray with the combined size
let result = arrow2::array::FixedSizeBinaryArray::try_new(
arrow2::datatypes::DataType::FixedSizeBinary(combined_size),
values.into(),
Some(validity.into()),
)?;

Ok(FixedSizeBinaryArray::from((self.name(), Box::new(result))))
}

pub fn into_binary(&self) -> std::result::Result<BinaryArray, DaftError> {
let mut builder = arrow2::array::MutableBinaryArray::<i64>::new();
let mut validity = arrow2::bitmap::MutableBitmap::new();

for val in self.as_arrow() {
match val {
Some(val) => {
builder.push(Some(val));
validity.push(true);
}
None => {
builder.push::<&[u8]>(None);
validity.push(false);
}
}
}

Ok(BinaryArray::from((self.name(), Box::new(builder.into()))))
}
}
68 changes: 51 additions & 17 deletions src/daft-core/src/series/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ impl Series {
) -> DaftResult<Self> {
match self.data_type() {
DataType::Binary => f(self.binary()?),
DataType::FixedSizeBinary(_) => Err(DaftError::TypeError(format!(
"Operation not implemented for type {}",
self.data_type()
))),
DataType::Null => Ok(self.clone()),
dt => Err(DaftError::TypeError(format!(
"Operation not implemented for type {dt}"
Expand All @@ -21,29 +25,59 @@ impl Series {
}

pub fn binary_length(&self) -> DaftResult<Self> {
self.with_binary_array(|arr| Ok(arr.length()?.into_series()))
match self.data_type() {
DataType::Binary => self.with_binary_array(|arr| Ok(arr.length()?.into_series())),
DataType::FixedSizeBinary(_) => Ok(self.fixed_size_binary()?.length()?.into_series()),
DataType::Null => Ok(self.clone()),
dt => Err(DaftError::TypeError(format!(
"Operation not implemented for type {dt}"
))),
}
}

pub fn binary_concat(&self, other: &Self) -> DaftResult<Self> {
self.with_binary_array(|arr| Ok(arr.binary_concat(other.binary()?)?.into_series()))
}

pub fn binary_slice(&self, start: &Self, length: &Self) -> DaftResult<Self> {
self.with_binary_array(|arr| {
with_match_integer_daft_types!(start.data_type(), |$T| {
if length.data_type().is_integer() {
with_match_integer_daft_types!(length.data_type(), |$U| {
Ok(arr.binary_slice(start.downcast::<<$T as DaftDataType>::ArrayType>()?, Some(length.downcast::<<$U as DaftDataType>::ArrayType>()?))?.into_series())
})
} else if length.data_type().is_null() {
Ok(arr.binary_slice(start.downcast::<<$T as DaftDataType>::ArrayType>()?, None::<&DataArray<Int8Type>>)?.into_series())
} else {
Err(DaftError::TypeError(format!(
"slice not implemented for length type {}",
length.data_type()
)))
}
})
})
match self.data_type() {
DataType::Binary => self.with_binary_array(|arr| {
with_match_integer_daft_types!(start.data_type(), |$T| {
if length.data_type().is_integer() {
with_match_integer_daft_types!(length.data_type(), |$U| {
Ok(arr.binary_slice(start.downcast::<<$T as DaftDataType>::ArrayType>()?, Some(length.downcast::<<$U as DaftDataType>::ArrayType>()?))?.into_series())
})
} else if length.data_type().is_null() {
Ok(arr.binary_slice(start.downcast::<<$T as DaftDataType>::ArrayType>()?, None::<&DataArray<Int8Type>>)?.into_series())
} else {
Err(DaftError::TypeError(format!(
"slice not implemented for length type {}",
length.data_type()
)))
}
})
}),
DataType::FixedSizeBinary(_) => {
let fixed_arr = self.fixed_size_binary()?;
with_match_integer_daft_types!(start.data_type(), |$T| {
if length.data_type().is_integer() {
with_match_integer_daft_types!(length.data_type(), |$U| {
Ok(fixed_arr.binary_slice(start.downcast::<<$T as DaftDataType>::ArrayType>()?, Some(length.downcast::<<$U as DaftDataType>::ArrayType>()?))?.into_series())
})
} else if length.data_type().is_null() {
Ok(fixed_arr.binary_slice(start.downcast::<<$T as DaftDataType>::ArrayType>()?, None::<&DataArray<Int8Type>>)?.into_series())
} else {
Err(DaftError::TypeError(format!(
"slice not implemented for length type {}",
length.data_type()
)))
}
})
}
DataType::Null => Ok(self.clone()),
dt => Err(DaftError::TypeError(format!(
"Operation not implemented for type {dt}"
))),
}
}
}
Loading

0 comments on commit 82660c8

Please sign in to comment.