diff --git a/src/array.rs b/src/array.rs index 599f40b..8abe90d 100644 --- a/src/array.rs +++ b/src/array.rs @@ -1,8 +1,10 @@ use pyo3::exceptions::{PyIndexError, PyTypeError, PyValueError}; use pyo3::prelude::*; use zarrs::array::{Array as RustArray}; +use zarrs::array_subset::ArraySubset; use zarrs::storage::ReadableStorageTraits; -use pyo3::types::PySlice; +use pyo3::types::{PyInt, PyList, PySlice}; +use std::ops::Range; #[pyclass] pub struct Array { @@ -11,38 +13,54 @@ pub struct Array { impl Array { - fn bound_slice(&self, slice: &Bound) -> PyResult> { - let start: i32 = slice.getattr("start")?.extract().map_or(0, |x| x); - let mut start_u64: u64 = start as u64; - if start < 0 { - if self.arr.shape()[0] as i32 + start < 0 { - return Err(PyIndexError::new_err(format!("{0} out of bounds", start))) - } - start_u64 = u64::try_from(start).map_err(|_| PyErr::new::("Failed to extract start"))?; - } - let stop: i32 = slice.getattr("stop")?.extract().map_or(self.arr.shape()[0] as i32, |x| x); - let mut stop_u64: u64 = stop as u64; - if stop < 0 { - if self.arr.shape()[0] as i32 + stop < 0 { - return Err(PyIndexError::new_err(format!("{0} out of bounds", stop))) + fn maybe_convert_u64(&self, ind: i32, axis: usize) -> PyResult { + let mut ind_u64: u64 = ind as u64; + if ind < 0 { + if self.arr.shape()[axis] as i32 + ind < 0 { + return Err(PyIndexError::new_err(format!("{0} out of bounds", ind))) } - stop_u64 = u64::try_from(stop).map_err(|_| PyErr::new::("Failed to extract stop"))?; + ind_u64 = u64::try_from(ind).map_err(|_| PyIndexError::new_err("Failed to extract start"))?; } - let _step: u64 = slice.getattr("step")?.extract().map_or(1, |x| x); - let selection: Vec = (start_u64..stop_u64).step_by(_step.try_into().unwrap()).collect(); + return Ok(ind_u64); + } + + fn bound_slice(&self, slice: &Bound) -> PyResult> { + let start: i32 = slice.getattr("start")?.extract().map_or(0, |x| x); + let stop: i32 = slice.getattr("stop")?.extract().map_or(0, |x| x); + let start_u64 = self.maybe_convert_u64(start, 0)?; + let stop_u64 = self.maybe_convert_u64(stop, 0)?; + // let _step: u64 = slice.getattr("step")?.extract().map_or(1, |x| x); // there is no way to use step it seems with zarrs? + let selection = start_u64..stop_u64; return Ok(selection) } + + fn fill_from_slices(&self, slices: Vec>) -> PyResult>> { + Ok(self.arr.shape().iter().enumerate().map(|(index, &value)| { if index < slices.len() { slices[index].clone() } else { 0..value } }).collect()) + } } #[pymethods] impl Array { pub fn __getitem__(&self, key: &Bound<'_, PyAny>) -> PyResult> { + let selection: ArraySubset; if let Ok(slice) = key.downcast::() { - let selection = self.bound_slice(slice)?; - return self.arr.retrieve_chunk(&selection[..]).map_err(|x| PyErr::new::(x.to_string())); + selection = ArraySubset::new_with_ranges(&self.fill_from_slices(vec![self.bound_slice(slice)?])?); + } else if let Ok(list) = key.downcast::(){ + let ranges: Vec> = list.into_iter().enumerate().map(|(index, val)| { + if let Ok(int) = val.downcast::() { + let end = self.maybe_convert_u64(int.extract()?, index)?; + Ok(end..(end + 1)) + } else if let Ok(slice) = val.downcast::() { + Ok(self.bound_slice(slice)?) + } else { + return Err(PyValueError::new_err(format!("Cannot take {0}, must be int or slice", val.to_string()))); + } + }).collect::>, _>>()?; + selection = ArraySubset::new_with_ranges(&self.fill_from_slices(ranges)?); } else { return Err(PyTypeError::new_err("Unsupported type")); } + return self.arr.retrieve_chunks(&selection).map_err(|x| PyErr::new::(x.to_string())); } }