Skip to content

Commit

Permalink
Merge branch 'feat/batch-creation' of github.com:d-v-b/zarr-python in…
Browse files Browse the repository at this point in the history
…to feat/batch-creation
  • Loading branch information
d-v-b committed Jan 22, 2025
2 parents d07435b + 787d6bf commit 29ecce7
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 30 deletions.
81 changes: 51 additions & 30 deletions src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
from zarr.errors import MetadataValidationError
from zarr.storage import StoreLike, StorePath
from zarr.storage._common import ensure_no_existing_node, make_store_path
from zarr.storage._utils import normalize_path

if TYPE_CHECKING:
from collections.abc import (
Expand Down Expand Up @@ -2984,7 +2985,8 @@ def _get_roots(
data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata],
) -> tuple[str, ...]:
"""
Return the keys of the root(s) of the hierarchy
Return the keys of the root(s) of the hierarchy. A root is a key with the fewest number of
path segments.
"""
if "" in data:
return ("",)
Expand Down Expand Up @@ -3012,8 +3014,8 @@ def _parse_hierarchy_dict(
then return an identical copy of that dict. Otherwise, return a version of the input dict
with groups added where they are needed to make the hierarchy explicit.
For example, an input of {'a/b/c': ...} will result in a return value of
{'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}.
For example, an input of {'a/b/c': ArrayMetadata} will result in a return value of
{'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ArrayMetadata}.
The input is also checked for the following conditions, and an error is raised if any
of them are violated:
Expand All @@ -3024,8 +3026,6 @@ def _parse_hierarchy_dict(
This function ensures that the input is transformed into a specification of a complete and valid
Zarr hierarchy.
"""
# Create a copy of the input dict
out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data}

observed_zarr_formats: dict[ZarrFormat, list[str | None]] = {2: [], 3: []}

Expand All @@ -3041,38 +3041,59 @@ def _parse_hierarchy_dict(
f"The following keys map to Zarr v3 nodes: {observed_zarr_formats.get(3)}."
"Ensure that all nodes have the same Zarr format."
)

raise ValueError(msg)

out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {}

for k, v in data.items():
if k is None:
# root node
pass
else:
if k.startswith("/"):
msg = f"Keys of hierarchy dicts must be relative paths, i.e. they cannot start with '/'. Got {k}, which violates this rule."
raise ValueError(k)
# TODO: ensure that the key is a valid path
# Split the key into its path components
key_split = k.split("/")

# Iterate over the intermediate path components
*subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}")
for subpath in subpaths:
# If a component is not already in the output dict, add a group
if subpath not in out:
out[subpath] = GroupMetadata(zarr_format=v.zarr_format)
else:
if not isinstance(out[subpath], GroupMetadata):
msg = (
f"The node at {subpath} contains other nodes, but it is not a Zarr group. "
"This is invalid. Only Zarr groups can contain other nodes."
)
raise ValueError(msg)
# TODO: ensure that the key is a valid path
key_split = k.split("/")
*subpaths, _ = accumulate(key_split, lambda a, b: "/".join([a, b]))

for subpath in subpaths:
# If a component is not already in the output dict, add a group
if subpath not in out:
out[subpath] = GroupMetadata(zarr_format=v.zarr_format)
else:
if not isinstance(out[subpath], GroupMetadata):
msg = (
f"The node at {subpath} contains other nodes, but it is not a Zarr group. "
"This is invalid. Only Zarr groups can contain other nodes."
)
raise ValueError(msg)

return out


def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]:
    """
    Apply the zarr node path normalization scheme to each element of ``paths``.

    The normalized paths are returned as a tuple, in the order in which the corresponding
    inputs were encountered. Two inputs that normalize to the same value are treated as a
    collision, and a ValueError naming both of the offending inputs is raised.
    """
    # Maps each normalized path to the input that produced it, so that a collision can be
    # reported in terms of the values the caller actually supplied.
    seen: dict[str, str] = {}
    for raw in paths:
        normalized = normalize_path(raw)
        if normalized not in seen:
            seen[normalized] = raw
            continue

        earlier = seen[normalized]
        msg = (
            f"After normalization, the value '{raw}' collides with '{earlier}'. "
            f"Both '{raw}' and '{earlier}' normalize to the same value: '{normalized}'. "
            f"You should use either '{raw}' or '{earlier}', but not both."
        )
        raise ValueError(msg)
    return tuple(seen)


def _normalize_path_keys(data: dict[str, T]) -> dict[str, T]:
    """
    Return a new dict that holds the values of ``data`` under normalized keys.

    Keys are normalized according to the normalization scheme used for zarr node paths. If any
    two keys in the input normalize to the same value, a ValueError is raised.
    """
    values = data.values()
    # Iterating a dict yields its keys, in the same order as ``data.values()``.
    normalized_keys = _normalize_paths(data)
    return dict(zip(normalized_keys, values, strict=False))


async def _getitem_semaphore(
node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None
) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup:
Expand Down
34 changes: 34 additions & 0 deletions tests/test_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@
GroupMetadata,
_from_flat,
_join_paths,
_normalize_path_keys,
_normalize_paths,
create_hierarchy,
create_nodes,
)
from zarr.core.sync import _collect_aiterator, sync
from zarr.errors import ContainsArrayError, ContainsGroupError
from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore
from zarr.storage._common import make_store_path
from zarr.storage._utils import normalize_path
from zarr.testing.store import LatencyStore

from .conftest import meta_from_array, parse_store
Expand Down Expand Up @@ -1631,6 +1634,37 @@ async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: s
assert members_observed_meta == members_expected_meta_relative


@pytest.mark.parametrize("paths", [("a", "/a"), ("", "/"), ("b/", "b")])
def test_normalize_paths_invalid(paths: tuple[str, str]) -> None:
    """
    Ensure that calling _normalize_paths on values that will normalize to the same value
    will generate a ValueError.
    """
    import re

    a, b = paths
    msg = f"After normalization, the value '{b}' collides with '{a}'. "
    # ``match`` is applied as a regular expression search, so the expected text is escaped to
    # make characters such as '.' match literally instead of acting as wildcards.
    with pytest.raises(ValueError, match=re.escape(msg)):
        _normalize_paths(paths)


@pytest.mark.parametrize(
    "paths", [("/a", "a/b"), ("a", "a/b"), ("a/", "a///b"), ("/a/", "//a/b///")]
)
def test_normalize_paths_valid(paths: tuple[str, str]):
    """
    Check that inputs which normalize to distinct values are returned by _normalize_paths as a
    tuple of their normalized forms, in the same order as the inputs.
    """
    expected = tuple(normalize_path(path) for path in paths)
    observed = _normalize_paths(paths)
    assert observed == expected


def test_normalize_path_keys():
    """
    Check that _normalize_path_keys returns the input values keyed by the normalized form of
    each input key.
    """
    data = {"": 10, "a": "hello", "a/b": None, "/a/b/c/d": None}
    expected = {}
    for key, value in data.items():
        expected[normalize_path(key)] = value
    assert _normalize_path_keys(data) == expected


@pytest.mark.parametrize("store", ["memory"], indirect=True)
def test_group_members_performance(store: Store) -> None:
"""
Expand Down

0 comments on commit 29ecce7

Please sign in to comment.