From 19e90e323284ecccbe795f5711f8a35b712db4fd Mon Sep 17 00:00:00 2001 From: ilan-gold Date: Sun, 2 Feb 2025 14:05:10 +0100 Subject: [PATCH] (fix): `object` dtypes in rust --- python/zarrs/pipeline.py | 8 ++++---- src/chunk_item.rs | 18 ++++++++++++++++-- tests/test_v2.py | 4 ++-- 3 files changed, 22 insertions(+), 8 deletions(-) diff --git a/python/zarrs/pipeline.py b/python/zarrs/pipeline.py index c0c9e04..946d913 100644 --- a/python/zarrs/pipeline.py +++ b/python/zarrs/pipeline.py @@ -71,6 +71,8 @@ def codecs_to_dict(codecs: Iterable[Codec]) -> Generator[dict[str, Any], None, N }, } # TODO: get the endianness added to V2Codec API + # TODO: how to handle this with strings, which don't need this but zarrs + # complains about its absence if its not there yield BytesCodec().to_dict() else: yield codec.to_dict() @@ -220,10 +222,8 @@ def _raise_error_on_batch_info_error( tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple] ], ): - if any( - info.dtype in ["object"] or info.dtype.kind in {"V", "S"} - for (_, info, _, _) in batch_info - ): + # https://github.com/LDeakin/zarrs/blob/0532fe983b7b42b59dbf84e50a2fe5e6f7bad4ce/zarrs_metadata/src/v2_to_v3.rs#L289-L293 + if any(info.dtype.kind in {"V", "S"} for (_, info, _, _) in batch_info): raise UnsupportedDataTypeError() if any(info.fill_value is None for (_, info, _, _) in batch_info): raise FillValueNoneError() diff --git a/src/chunk_item.rs b/src/chunk_item.rs index 8ad36ad..8ead5fd 100644 --- a/src/chunk_item.rs +++ b/src/chunk_item.rs @@ -3,7 +3,7 @@ use std::num::NonZeroU64; use pyo3::{ exceptions::{PyRuntimeError, PyValueError}, pyclass, pymethods, - types::{PyAnyMethods as _, PyBytes, PyBytesMethods, PySlice, PySliceMethods as _}, + types::{PyAnyMethods, PyBytes, PyBytesMethods, PyInt, PySlice, PySliceMethods as _}, Bound, PyAny, PyErr, PyResult, }; use pyo3_stub_gen::derive::{gen_stub_pyclass, gen_stub_pymethods}; @@ -40,7 +40,7 @@ impl Basic { let path: String = byte_interface.getattr("path")?.extract()?; let chunk_shape = chunk_spec.getattr("shape")?.extract()?; - let dtype: String = chunk_spec + let mut dtype: String = chunk_spec .getattr("dtype")? .call_method0("__str__")? .extract()?; @@ -48,6 +48,20 @@ impl Basic { let fill_value_bytes: Vec; if let Ok(fill_value_downcast) = fill_value.downcast::() { fill_value_bytes = fill_value_downcast.as_bytes().to_vec(); + } else if let Ok(fill_value_downcast) = fill_value.downcast::() { + let fill_value_usize: usize = fill_value_downcast.extract()?; + if fill_value_usize == (0 as usize) && dtype == "object" { + // https://github.com/LDeakin/zarrs/pull/140 + fill_value_bytes = "".as_bytes().to_vec(); + // zarrs doesn't understand `object` which is the output of `np.dtype("|O").__str__()` + // but maps it to "string" internally https://github.com/LDeakin/zarrs/blob/0532fe983b7b42b59dbf84e50a2fe5e6f7bad4ce/zarrs_metadata/src/v2_to_v3.rs#L288 + dtype = String::from("string"); + } else { + return Err(PyErr::new::(format!( + "Cannot understand non-zero integer {:?} fill value for dtype {:?}", + fill_value_usize, dtype + ))); + } } else if fill_value.hasattr("tobytes")? { fill_value_bytes = fill_value.call_method0("tobytes")?.extract()?; } else { diff --git a/tests/test_v2.py b/tests/test_v2.py index b2f33ae..31748cc 100644 --- a/tests/test_v2.py +++ b/tests/test_v2.py @@ -157,8 +157,8 @@ def test_v2_encode_decode_with_data(dtype_value, tmp_path): @pytest.mark.parametrize("dtype", [str, "str"]) -async def test_create_dtype_str(dtype: Any) -> None: - arr = zarr.create(shape=3, dtype=dtype, zarr_format=2) +async def test_create_dtype_str(dtype: Any, tmp_path: Path) -> None: + arr = zarr.create(store=tmp_path, shape=3, dtype=dtype, zarr_format=2) assert arr.dtype.kind == "O" assert arr.metadata.to_dict()["dtype"] == "|O" assert arr.metadata.filters == (numcodecs.vlen.VLenBytes(),)