Skip to content

Commit

Permalink
ENH: Support Zarr v3
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Jan 9, 2025
1 parent 46b1acc commit 5ff810e
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 29 deletions.
43 changes: 28 additions & 15 deletions pipefunc/map/_storage_array/_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
if TYPE_CHECKING:
from pipefunc.map._types import ShapeTuple

# Detect the installed Zarr major version and import the matching (v2 or v3) array-creation and store names
_ZARR_VERSION = zarr.__version__
_ZARR_MAJOR = int(_ZARR_VERSION.split(".", 1)[0])
if _ZARR_MAJOR < 3: # noqa: PLR2004
from zarr import open as create_array
from zarr.storage import DirectoryStore as LocalStore
from zarr.storage import Store
else:
from zarr import create_array
from zarr.abc.store import Store
from zarr.storage import LocalStore


class ZarrFileArray(StorageBase):
"""Array interface to a Zarr store.
Expand All @@ -38,7 +50,7 @@ def __init__(
internal_shape: ShapeTuple | None = None,
shape_mask: tuple[bool, ...] | None = None,
*,
store: zarr.storage.Store | str | Path | None = None,
store: Store | str | Path | None = None,
object_codec: Any = None,
) -> None:
"""Initialize the ZarrFileArray."""
Expand All @@ -49,8 +61,8 @@ def __init__(
msg = "shape_mask must have the same length as shape + internal_shape"
raise ValueError(msg)
self.folder = Path(folder) if folder is not None else folder
if not isinstance(store, zarr.storage.Store):
store = zarr.DirectoryStore(str(self.folder))
if not isinstance(store, Store):
store = LocalStore(str(self.folder))
self.store = store
self.shape = tuple(shape)
self.shape_mask = tuple(shape_mask) if shape_mask is not None else (True,) * len(shape)
Expand All @@ -60,24 +72,25 @@ def __init__(
object_codec = CloudPickleCodec()

chunks = select_by_mask(self.shape_mask, (1,) * len(self.shape), self.internal_shape)
self.array = zarr.open(
zarr_kwargs = {"mode": "a"} if _ZARR_MAJOR < 3 else {} # noqa: PLR2004
self.array = create_array(
self.store,
mode="a",
path="/array",
name="/array",
shape=self.full_shape,
dtype=object,
object_codec=object_codec,
chunks=chunks,
**zarr_kwargs,
)
self._mask = zarr.open(
self._mask = create_array(
self.store,
mode="a",
path="/mask",
name="/mask",
shape=self.shape,
dtype=bool,
fill_value=True,
object_codec=object_codec,
chunks=1,
**zarr_kwargs,
)

@property
Expand Down Expand Up @@ -222,7 +235,7 @@ def dump_in_subprocess(self) -> bool:
return True


class _SharedDictStore(zarr.storage.KVStore):
class _SharedDictStore(Store):
"""Custom Store subclass using a shared dictionary."""

def __init__(self, shared_dict: multiprocessing.managers.DictProxy | None = None) -> None:
Expand Down Expand Up @@ -256,12 +269,12 @@ def __init__(
internal_shape: tuple[int, ...] | None = None,
shape_mask: tuple[bool, ...] | None = None,
*,
store: zarr.storage.Store | None = None,
store: Store | None = None,
object_codec: Any = None,
) -> None:
"""Initialize the ZarrMemoryArray."""
if store is None:
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
super().__init__(
folder=folder,
shape=shape,
Expand All @@ -273,11 +286,11 @@ def __init__(
self.load()

@property
def persistent_store(self) -> zarr.storage.Store | None:
def persistent_store(self) -> Store | None:
"""Return the persistent store."""
if self.folder is None: # pragma: no cover
return None
return zarr.DirectoryStore(self.folder)
return LocalStore(self.folder)

def persist(self) -> None:
"""Persist the memory storage to disk."""
Expand Down Expand Up @@ -315,7 +328,7 @@ def __init__(
internal_shape: tuple[int, ...] | None = None,
shape_mask: tuple[bool, ...] | None = None,
*,
store: zarr.storage.Store | None = None,
store: Store | None = None,
object_codec: Any = None,
) -> None:
"""Initialize the ZarrSharedMemoryArray."""
Expand Down
2 changes: 1 addition & 1 deletion tests/map/storage/test_all_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _array_type(shape, internal_shape=None, shape_mask=None):

from pipefunc.map import ZarrFileArray

store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
return ZarrFileArray(None, shape, internal_shape, shape_mask, store=store)
elif request.param == "dict":

Expand Down
26 changes: 13 additions & 13 deletions tests/map/storage/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

def test_zarr_array_init():
shape = (2, 3, 4)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
assert arr.shape == shape
assert arr.strides == (12, 4, 1)


def test_zarr_array_properties():
shape = (2, 3, 4)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
assert arr.size == 24
assert arr.rank == 3
Expand All @@ -25,7 +25,7 @@ def test_zarr_array_properties():

def test_zarr_array_getitem():
shape = (2, 3)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
arr.dump((0, 0), {"a": 1})
arr.dump((1, 2), {"b": 2})
Expand All @@ -40,7 +40,7 @@ def test_zarr_array_getitem():

def test_zarr_array_to_array():
shape = (2, 3)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
arr.dump((0, 0), {"a": 1})
arr.dump((1, 2), {"b": 2})
Expand All @@ -55,7 +55,7 @@ def test_zarr_array_to_array():

def test_zarr_array_dump():
shape = (2, 3)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
arr.dump((0, 0), {"a": 1})
arr.dump((1, 2), {"b": 2})
Expand All @@ -67,7 +67,7 @@ def test_zarr_array_dump():

def test_zarr_array_getitem_with_slicing():
shape = (2, 3, 4)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
arr.dump((0, 0, 0), {"a": 1})
arr.dump((0, 1, 0), {"b": 2})
Expand All @@ -93,7 +93,7 @@ def test_zarr_array_getitem_with_slicing():


def test_zarr_array_with_internal_arrays():
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
shape = (2, 2)
internal_shape = (3, 3, 4)
shape_mask = (True, True, False, False, False)
Expand Down Expand Up @@ -135,7 +135,7 @@ def test_zarr_array_with_internal_arrays():


def test_zarr_array_with_internal_arrays_slicing():
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
shape = (2, 2)
internal_shape = (3, 3, 4)
shape_mask = (True, True, False, False, False)
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_zarr_array_with_internal_arrays_slicing():

def test_zarr_array_set_and_get_single_item():
shape = (2, 3)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)

arr.dump((0, 0), {"a": 1})
Expand All @@ -194,7 +194,7 @@ def test_zarr_array_set_and_get_single_item_with_internal_shape():
shape = (2, 2)
internal_shape = (3, 3)
shape_mask = (True, True, False, False)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(
folder=None,
store=store,
Expand All @@ -216,7 +216,7 @@ def test_zarr_array_set_and_get_single_item_with_internal_shape_and_indexing():
shape = (2, 2)
internal_shape = (3, 3)
shape_mask = (True, True, False, False)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(
folder=None,
store=store,
Expand All @@ -236,7 +236,7 @@ def test_zarr_array_set_and_get_single_item_with_internal_shape_and_indexing():

def test_zarr_array_set_and_get_slice():
shape = (2, 3, 4)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)

data1 = 42
Expand All @@ -252,7 +252,7 @@ def test_zarr_array_set_and_get_slice_with_internal_shape():
shape = (2, 2)
internal_shape = (3, 3, 4)
shape_mask = (True, True, False, False, False)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(
folder=None,
store=store,
Expand Down

0 comments on commit 5ff810e

Please sign in to comment.