diff --git a/changes/2813.feature.rst b/changes/2813.feature.rst new file mode 100644 index 0000000000..8a28f75082 --- /dev/null +++ b/changes/2813.feature.rst @@ -0,0 +1 @@ +Add `zarr.testing.strategies.array_metadata` to generate ArrayV2Metadata and ArrayV3Metadata instances. diff --git a/changes/2817.bugfix.rst b/changes/2817.bugfix.rst new file mode 100644 index 0000000000..b1c0fa9220 --- /dev/null +++ b/changes/2817.bugfix.rst @@ -0,0 +1 @@ +Fix fancy indexing (e.g. arr[5, [0, 1]]) with the sharding codec \ No newline at end of file diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index 459805d808..42b1313fac 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -531,7 +531,11 @@ async def _decode_partial_single( ], out, ) - return out + + if hasattr(indexer, "sel_shape"): + return out.reshape(indexer.sel_shape) + else: + return out async def _encode_single( self, diff --git a/src/zarr/testing/strategies.py b/src/zarr/testing/strategies.py index 0883d79bf0..84edb04c83 100644 --- a/src/zarr/testing/strategies.py +++ b/src/zarr/testing/strategies.py @@ -1,5 +1,5 @@ import sys -from typing import Any +from typing import Any, Literal import hypothesis.extra.numpy as npst import hypothesis.strategies as st @@ -8,9 +8,13 @@ from hypothesis.strategies import SearchStrategy import zarr -from zarr.abc.store import RangeByteRequest +from zarr.abc.store import RangeByteRequest, Store +from zarr.codecs.bytes import BytesCodec from zarr.core.array import Array +from zarr.core.chunk_grids import RegularChunkGrid +from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding from zarr.core.common import ZarrFormat +from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.sync import sync from zarr.storage import MemoryStore, StoreLike from zarr.storage._common import _dereference_path @@ -67,6 +71,11 @@ def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]: ) +def clear_store(x: Store) -> Store: + sync(x.clear()) + return x + + # From https://zarr-specs.readthedocs.io/en/latest/v3/core/v3.0.html#node-names # 1. must not be the empty string ("") # 2. must not include the character "/" @@ -85,12 +94,59 @@ def safe_unicode_for_dtype(dtype: np.dtype[np.str_]) -> st.SearchStrategy[str]: # st.builds will only call a new store constructor for different keyword arguments # i.e. stores.examples() will always return the same object per Store class. # So we map a clear to reset the store. -stores = st.builds(MemoryStore, st.just({})).map(lambda x: sync(x.clear())) +stores = st.builds(MemoryStore, st.just({})).map(clear_store) compressors = st.sampled_from([None, "default"]) zarr_formats: st.SearchStrategy[ZarrFormat] = st.sampled_from([2, 3]) array_shapes = npst.array_shapes(max_dims=4, min_side=0) +@st.composite # type: ignore[misc] +def dimension_names(draw: st.DrawFn, *, ndim: int | None = None) -> list[None | str] | None: + simple_text = st.text(zarr_key_chars, min_size=0) + return draw(st.none() | st.lists(st.none() | simple_text, min_size=ndim, max_size=ndim)) # type: ignore[no-any-return] + + +@st.composite # type: ignore[misc] +def array_metadata( + draw: st.DrawFn, + *, + array_shapes: st.SearchStrategy[tuple[int, ...]] = npst.array_shapes, + zarr_formats: st.SearchStrategy[Literal[2, 3]] = zarr_formats, + attributes: st.SearchStrategy[dict[str, Any]] = attrs, +) -> ArrayV2Metadata | ArrayV3Metadata: + zarr_format = draw(zarr_formats) + # separator = draw(st.sampled_from(['/', '\\'])) + shape = draw(array_shapes()) + ndim = len(shape) + chunk_shape = draw(array_shapes(min_dims=ndim, max_dims=ndim)) + dtype = draw(v3_dtypes()) + fill_value = draw(npst.from_dtype(dtype)) + if zarr_format == 2: + return ArrayV2Metadata( + shape=shape, + chunks=chunk_shape, + dtype=dtype, + fill_value=fill_value, + order=draw(st.sampled_from(["C", "F"])), + attributes=draw(attributes), + dimension_separator=draw(st.sampled_from([".", "/"])), + filters=None, + compressor=None, + ) + else: + return ArrayV3Metadata( + shape=shape, + data_type=dtype, + chunk_grid=RegularChunkGrid(chunk_shape=chunk_shape), + fill_value=fill_value, + attributes=draw(attributes), + dimension_names=draw(dimension_names(ndim=ndim)), + chunk_key_encoding=DefaultChunkKeyEncoding(separator="/"), # FIXME + codecs=[BytesCodec()], + storage_transformers=(), + ) + + @st.composite # type: ignore[misc] def numpy_arrays( draw: st.DrawFn, diff --git a/tests/test_array.py b/tests/test_array.py index 6aaf1072ba..4838129561 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1429,3 +1429,18 @@ def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkser results = pool.starmap(_index_array, [(arr, slice(len(data)))]) assert all(np.array_equal(r, data) for r in results) + + +async def test_sharding_coordinate_selection() -> None: + store = MemoryStore() + g = zarr.open_group(store, mode="w") + arr = g.create_array( + name="a", + shape=(2, 3, 4), + chunks=(1, 2, 2), + overwrite=True, + dtype=np.float32, + shards=(2, 4, 4), + ) + arr[:] = np.arange(2 * 3 * 4).reshape((2, 3, 4)) + assert (arr[1, [0, 1]] == np.array([[12, 13, 14, 15], [16, 17, 18, 19]])).all() # type: ignore[index] diff --git a/tests/test_properties.py b/tests/test_properties.py index cfa6a706d8..bf98f9d162 100644 --- a/tests/test_properties.py +++ b/tests/test_properties.py @@ -2,17 +2,23 @@ import pytest from numpy.testing import assert_array_equal +from zarr.core.buffer import default_buffer_prototype + pytest.importorskip("hypothesis") import hypothesis.extra.numpy as npst import hypothesis.strategies as st from hypothesis import given +from zarr.abc.store import Store +from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.testing.strategies import ( + array_metadata, arrays, basic_indices, numpy_arrays, orthogonal_indices, + stores, zarr_formats, ) @@ -64,6 +70,17 @@ def test_vindex(data: st.DataObject) -> None: assert_array_equal(nparray[indexer], actual) +@given(store=stores, meta=array_metadata()) # type: ignore[misc] +async def test_roundtrip_array_metadata( + store: Store, meta: ArrayV2Metadata | ArrayV3Metadata +) -> None: + asdict = meta.to_buffer_dict(prototype=default_buffer_prototype()) + for key, expected in asdict.items(): + await store.set(f"0/{key}", expected) + actual = await store.get(f"0/{key}", prototype=default_buffer_prototype()) + assert actual == expected + + # @st.composite # def advanced_indices(draw, *, shape): # basic_idxr = draw(