Skip to content

Commit

Permalink
(refactor): handle None fill-value more gracefully
Browse files Browse the repository at this point in the history
  • Loading branch information
ilan-gold committed Feb 2, 2025
1 parent 4a59ec1 commit 45efee1
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 16 deletions.
19 changes: 8 additions & 11 deletions python/zarrs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,11 @@
from .utils import (
CollapsedDimensionError,
DiscontiguousArrayError,
FillValueNoneError,
make_chunk_info_for_rust_with_indices,
)


class FillValueNoneError(Exception):
pass


class UnsupportedDataTypeError(Exception):
pass

Expand Down Expand Up @@ -184,7 +181,7 @@ async def read(
if not out.dtype.isnative:
raise RuntimeError("Non-native byte order not supported")
try:
self._raise_error_on_batch_info_error(batch_info)
self._raise_error_on_unsupported_batch_dtype(batch_info)
chunks_desc = make_chunk_info_for_rust_with_indices(
batch_info, drop_axes, out.shape
)
Expand Down Expand Up @@ -214,7 +211,7 @@ async def write(
drop_axes: tuple[int, ...] = (),
) -> None:
try:
self._raise_error_on_batch_info_error(batch_info)
self._raise_error_on_unsupported_batch_dtype(batch_info)
chunks_desc = make_chunk_info_for_rust_with_indices(
batch_info, drop_axes, value.shape
)
Expand All @@ -240,15 +237,15 @@ async def write(
)
return None

def _raise_error_on_batch_info_error(
def _raise_error_on_unsupported_batch_dtype(
self,
batch_info: Iterable[
tuple[ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]
],
):
# https://github.com/LDeakin/zarrs/blob/0532fe983b7b42b59dbf84e50a2fe5e6f7bad4ce/zarrs_metadata/src/v2_to_v3.rs#L289-L293
if any(info.dtype.kind in {"V", "S"} for (_, info, _, _) in batch_info):
if any(
info.dtype.kind in {"V", "S", "U", "M", "m"}
for (_, info, _, _) in batch_info
):
raise UnsupportedDataTypeError()
# TODO: is there some sort of default documented somewhere?
if any(info.fill_value is None for (_, info, _, _) in batch_info):
raise FillValueNoneError()
37 changes: 34 additions & 3 deletions python/zarrs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
import operator
import os
from functools import reduce
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import numpy as np
from zarr.core.array_spec import ArraySpec
from zarr.core.indexing import SelectorTuple, is_integer
from zarr.core.metadata.v3 import DataType, parse_fill_value

from zarrs._internal import Basic, WithSubset

Expand All @@ -15,7 +17,6 @@
from types import EllipsisType

from zarr.abc.store import ByteGetter, ByteSetter
from zarr.core.array_spec import ArraySpec


# adapted from https://docs.python.org/3/library/concurrent.futures.html#concurrent.futures.ThreadPoolExecutor
Expand All @@ -31,6 +32,10 @@ class CollapsedDimensionError(Exception):
pass


class FillValueNoneError(Exception):
pass


# This is a (mostly) copy of the function from zarr.core.indexing that fixes:
# DeprecationWarning: Conversion of an array with ndim > 0 to a scalar is deprecated
# TODO: Upstream this fix
Expand Down Expand Up @@ -134,6 +139,24 @@ def get_shape_for_selector(
return resulting_shape_from_index(shape, selector_tuple, drop_axes, pad=pad)


def get_implicit_fill_value(dtype: np.dtype, fill_value: Any):
if fill_value is not None:
return fill_value
dtype_str = str(dtype)
if dtype_str == "bool":
fill_value = False
elif np.issubdtype(dtype, np.integer):
fill_value = 0
elif np.issubdtype(dtype, np.floating):
fill_value = 0.0
elif dtype_str == "object":
# v2 object dtype used 0 as a sentinel value for fill values to actually indicate ""
return 0
else:
raise FillValueNoneError()
return parse_fill_value(fill_value, DataType.parse(dtype))


def make_chunk_info_for_rust_with_indices(
batch_info: Iterable[
tuple[ByteGetter | ByteSetter, ArraySpec, SelectorTuple, SelectorTuple]
Expand All @@ -144,6 +167,14 @@ def make_chunk_info_for_rust_with_indices(
shape = shape if shape else (1,) # constant array
chunk_info_with_indices: list[WithSubset] = []
for byte_getter, chunk_spec, chunk_selection, out_selection in batch_info:
if chunk_spec.fill_value is None:
chunk_spec = ArraySpec(
chunk_spec.shape,
chunk_spec.dtype,
get_implicit_fill_value(chunk_spec.dtype, chunk_spec.fill_value),
chunk_spec.config,
chunk_spec.prototype,
)
chunk_info = Basic(byte_getter, chunk_spec)
out_selection_as_slices = selector_tuple_to_slice_selection(out_selection)
chunk_selection_as_slices = selector_tuple_to_slice_selection(chunk_selection)
Expand All @@ -168,4 +199,4 @@ def make_chunk_info_for_rust_with_indices(
shape=shape,
)
)
return chunk_info_with_indices
return chunk_info_with_indices
4 changes: 2 additions & 2 deletions src/chunk_item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ impl Basic {
let fill_value_bytes: Vec<u8>;
if let Ok(fill_value_downcast) = fill_value.downcast::<PyBytes>() {
fill_value_bytes = fill_value_downcast.as_bytes().to_vec();
} else if fill_value.hasattr("tobytes")? {
fill_value_bytes = fill_value.call_method0("tobytes")?.extract()?;
} else if let Ok(fill_value_downcast) = fill_value.downcast::<PyInt>() {
let fill_value_usize: usize = fill_value_downcast.extract()?;
if fill_value_usize == (0 as usize) && dtype == "object" {
Expand All @@ -62,8 +64,6 @@ impl Basic {
fill_value_usize, dtype
)));
}
} else if fill_value.hasattr("tobytes")? {
fill_value_bytes = fill_value.call_method0("tobytes")?.extract()?;
} else {
return Err(PyErr::new::<PyValueError, _>(format!(
"Unsupported fill value {:?}",
Expand Down

0 comments on commit 45efee1

Please sign in to comment.