Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Support Zarr v3 #523

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions pipefunc/map/_storage_array/_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
if TYPE_CHECKING:
from pipefunc.map._types import ShapeTuple

# get zarr version
_ZARR_VERSION = zarr.__version__
_ZARR_MAJOR = int(_ZARR_VERSION.split(".", 1)[0])
if _ZARR_MAJOR < 3: # noqa: PLR2004
from zarr import open as create_array
from zarr.storage import DirectoryStore as LocalStore
from zarr.storage import Store
else:
from zarr import create_array
from zarr.abc.store import Store
from zarr.storage import LocalStore


class ZarrFileArray(StorageBase):
"""Array interface to a Zarr store.
Expand All @@ -38,7 +50,7 @@ def __init__(
internal_shape: ShapeTuple | None = None,
shape_mask: tuple[bool, ...] | None = None,
*,
store: zarr.storage.Store | str | Path | None = None,
store: Store | str | Path | None = None,
object_codec: Any = None,
) -> None:
"""Initialize the ZarrFileArray."""
Expand All @@ -49,8 +61,8 @@ def __init__(
msg = "shape_mask must have the same length as shape + internal_shape"
raise ValueError(msg)
self.folder = Path(folder) if folder is not None else folder
if not isinstance(store, zarr.storage.Store):
store = zarr.DirectoryStore(str(self.folder))
if not isinstance(store, Store):
store = LocalStore(str(self.folder))
self.store = store
self.shape = tuple(shape)
self.shape_mask = tuple(shape_mask) if shape_mask is not None else (True,) * len(shape)
Expand All @@ -60,24 +72,25 @@ def __init__(
object_codec = CloudPickleCodec()

chunks = select_by_mask(self.shape_mask, (1,) * len(self.shape), self.internal_shape)
self.array = zarr.open(
zarr_kwargs = {"mode": "a"} if _ZARR_MAJOR < 3 else {} # noqa: PLR2004
self.array = create_array(
self.store,
mode="a",
path="/array",
name="/array",
shape=self.full_shape,
dtype=object,
object_codec=object_codec,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

chunks=chunks,
**zarr_kwargs,
)
self._mask = zarr.open(
self._mask = create_array(
self.store,
mode="a",
path="/mask",
name="/mask",
shape=self.shape,
dtype=bool,
fill_value=True,
object_codec=object_codec,
chunks=1,
**zarr_kwargs,
)

@property
Expand Down Expand Up @@ -222,7 +235,7 @@ def dump_in_subprocess(self) -> bool:
return True


class _SharedDictStore(zarr.storage.KVStore):
class _SharedDictStore(Store):
"""Custom Store subclass using a shared dictionary."""

def __init__(self, shared_dict: multiprocessing.managers.DictProxy | None = None) -> None:
Expand Down Expand Up @@ -256,12 +269,12 @@ def __init__(
internal_shape: tuple[int, ...] | None = None,
shape_mask: tuple[bool, ...] | None = None,
*,
store: zarr.storage.Store | None = None,
store: Store | None = None,
object_codec: Any = None,
) -> None:
"""Initialize the ZarrMemoryArray."""
if store is None:
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
super().__init__(
folder=folder,
shape=shape,
Expand All @@ -273,11 +286,11 @@ def __init__(
self.load()

@property
def persistent_store(self) -> zarr.storage.Store | None:
def persistent_store(self) -> Store | None:
"""Return the persistent store."""
if self.folder is None: # pragma: no cover
return None
return zarr.DirectoryStore(self.folder)
return LocalStore(self.folder)

def persist(self) -> None:
"""Persist the memory storage to disk."""
Expand Down Expand Up @@ -315,7 +328,7 @@ def __init__(
internal_shape: tuple[int, ...] | None = None,
shape_mask: tuple[bool, ...] | None = None,
*,
store: zarr.storage.Store | None = None,
store: Store | None = None,
object_codec: Any = None,
) -> None:
"""Initialize the ZarrSharedMemoryArray."""
Expand Down
2 changes: 1 addition & 1 deletion tests/map/storage/test_all_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _array_type(shape, internal_shape=None, shape_mask=None):

from pipefunc.map import ZarrFileArray

store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
return ZarrFileArray(None, shape, internal_shape, shape_mask, store=store)
elif request.param == "dict":

Expand Down
26 changes: 13 additions & 13 deletions tests/map/storage/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

def test_zarr_array_init():
shape = (2, 3, 4)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
assert arr.shape == shape
assert arr.strides == (12, 4, 1)


def test_zarr_array_properties():
shape = (2, 3, 4)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
assert arr.size == 24
assert arr.rank == 3
Expand All @@ -25,7 +25,7 @@ def test_zarr_array_properties():

def test_zarr_array_getitem():
shape = (2, 3)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
arr.dump((0, 0), {"a": 1})
arr.dump((1, 2), {"b": 2})
Expand All @@ -40,7 +40,7 @@ def test_zarr_array_getitem():

def test_zarr_array_to_array():
shape = (2, 3)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
arr.dump((0, 0), {"a": 1})
arr.dump((1, 2), {"b": 2})
Expand All @@ -55,7 +55,7 @@ def test_zarr_array_to_array():

def test_zarr_array_dump():
shape = (2, 3)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
arr.dump((0, 0), {"a": 1})
arr.dump((1, 2), {"b": 2})
Expand All @@ -67,7 +67,7 @@ def test_zarr_array_dump():

def test_zarr_array_getitem_with_slicing():
shape = (2, 3, 4)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)
arr.dump((0, 0, 0), {"a": 1})
arr.dump((0, 1, 0), {"b": 2})
Expand All @@ -93,7 +93,7 @@ def test_zarr_array_getitem_with_slicing():


def test_zarr_array_with_internal_arrays():
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
shape = (2, 2)
internal_shape = (3, 3, 4)
shape_mask = (True, True, False, False, False)
Expand Down Expand Up @@ -135,7 +135,7 @@ def test_zarr_array_with_internal_arrays():


def test_zarr_array_with_internal_arrays_slicing():
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
shape = (2, 2)
internal_shape = (3, 3, 4)
shape_mask = (True, True, False, False, False)
Expand Down Expand Up @@ -180,7 +180,7 @@ def test_zarr_array_with_internal_arrays_slicing():

def test_zarr_array_set_and_get_single_item():
shape = (2, 3)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)

arr.dump((0, 0), {"a": 1})
Expand All @@ -194,7 +194,7 @@ def test_zarr_array_set_and_get_single_item_with_internal_shape():
shape = (2, 2)
internal_shape = (3, 3)
shape_mask = (True, True, False, False)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(
folder=None,
store=store,
Expand All @@ -216,7 +216,7 @@ def test_zarr_array_set_and_get_single_item_with_internal_shape_and_indexing():
shape = (2, 2)
internal_shape = (3, 3)
shape_mask = (True, True, False, False)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(
folder=None,
store=store,
Expand All @@ -236,7 +236,7 @@ def test_zarr_array_set_and_get_single_item_with_internal_shape_and_indexing():

def test_zarr_array_set_and_get_slice():
shape = (2, 3, 4)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(folder=None, store=store, shape=shape)

data1 = 42
Expand All @@ -252,7 +252,7 @@ def test_zarr_array_set_and_get_slice_with_internal_shape():
shape = (2, 2)
internal_shape = (3, 3, 4)
shape_mask = (True, True, False, False, False)
store = zarr.MemoryStore()
store = zarr.storage.MemoryStore()
arr = ZarrFileArray(
folder=None,
store=store,
Expand Down
Loading