diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 65dfd5442e..3592b5ce1a 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -61,6 +61,7 @@ from zarr.errors import MetadataValidationError from zarr.storage import StoreLike, StorePath from zarr.storage._common import ensure_no_existing_node, make_store_path +from zarr.storage._utils import normalize_path if TYPE_CHECKING: from collections.abc import ( @@ -2984,7 +2985,8 @@ def _get_roots( data: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata], ) -> tuple[str, ...]: """ - Return the keys of the root(s) of the hierarchy + Return the keys of the root(s) of the hierarchy. A root is a key with the fewest number of + path segments. """ if "" in data: return ("",) @@ -3012,8 +3014,8 @@ def _parse_hierarchy_dict( then return an identical copy of that dict. Otherwise, return a version of the input dict with groups added where they are needed to make the hierarchy explicit. - For example, an input of {'a/b/c': ...} will result in a return value of - {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ...}. + For example, an input of {'a/b/c': ArrayMetadata} will result in a return value of + {'a': GroupMetadata, 'a/b': GroupMetadata, 'a/b/c': ArrayMetadata}. The input is also checked for the following conditions, and an error is raised if any of them are violated: @@ -3024,8 +3026,6 @@ def _parse_hierarchy_dict( This function ensures that the input is transformed into a specification of a complete and valid Zarr hierarchy. """ - # Create a copy of the input dict - out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data} observed_zarr_formats: dict[ZarrFormat, list[str | None]] = {2: [], 3: []} @@ -3041,38 +3041,59 @@ def _parse_hierarchy_dict( f"The following keys map to Zarr v3 nodes: {observed_zarr_formats.get(3)}." "Ensure that all nodes have the same Zarr format." ) - raise ValueError(msg) + out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {} + for k, v in data.items(): - if k is None: - # root node - pass - else: - if k.startswith("/"): - msg = f"Keys of hierarchy dicts must be relative paths, i.e. they cannot start with '/'. Got {k}, which violates this rule." - raise ValueError(k) - # TODO: ensure that the key is a valid path - # Split the key into its path components - key_split = k.split("/") - - # Iterate over the intermediate path components - *subpaths, _ = accumulate(key_split, lambda a, b: f"{a}/{b}") - for subpath in subpaths: - # If a component is not already in the output dict, add a group - if subpath not in out: - out[subpath] = GroupMetadata(zarr_format=v.zarr_format) - else: - if not isinstance(out[subpath], GroupMetadata): - msg = ( - f"The node at {subpath} contains other nodes, but it is not a Zarr group. " - "This is invalid. Only Zarr groups can contain other nodes." - ) - raise ValueError(msg) + # TODO: ensure that the key is a valid path + key_split = k.split("/") + *subpaths, _ = accumulate(key_split, lambda a, b: "/".join([a, b])) + + for subpath in subpaths: + # If a component is not already in the output dict, add a group + if subpath not in out: + out[subpath] = GroupMetadata(zarr_format=v.zarr_format) + else: + if not isinstance(out[subpath], GroupMetadata): + msg = ( + f"The node at {subpath} contains other nodes, but it is not a Zarr group. " + "This is invalid. Only Zarr groups can contain other nodes." + ) + raise ValueError(msg) return out +def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]: + """ + Normalize the input paths according to the normalization scheme used for zarr node paths. + If any two paths normalize to the same value, raise a ValueError. + """ + path_map: dict[str, str] = {} + for path in paths: + parsed = normalize_path(path) + if parsed in path_map: + msg = ( + f"After normalization, the value '{path}' collides with '{path_map[parsed]}'. " + f"Both '{path}' and '{path_map[parsed]}' normalize to the same value: '{parsed}'. " + f"You should use either '{path}' or '{path_map[parsed]}', but not both." + ) + raise ValueError(msg) + path_map[parsed] = path + return tuple(path_map.keys()) + + +def _normalize_path_keys(data: dict[str, T]) -> dict[str, T]: + """ + Normalize the keys of the input dict according to the normalization scheme used for zarr node + paths. If any two keys in the input normalize to the value, raise a ValueError. Return the + values of data with the normalized keys. + """ + parsed_keys = _normalize_paths(data.keys()) + return dict(zip(parsed_keys, data.values(), strict=False)) + + async def _getitem_semaphore( node: AsyncGroup, key: str, semaphore: asyncio.Semaphore | None ) -> AsyncArray[ArrayV3Metadata] | AsyncArray[ArrayV2Metadata] | AsyncGroup: diff --git a/tests/test_group.py b/tests/test_group.py index fe9eadabf6..def4fc554a 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -26,6 +26,8 @@ GroupMetadata, _from_flat, _join_paths, + _normalize_path_keys, + _normalize_paths, create_hierarchy, create_nodes, ) @@ -33,6 +35,7 @@ from zarr.errors import ContainsArrayError, ContainsGroupError from zarr.storage import LocalStore, MemoryStore, StorePath, ZipStore from zarr.storage._common import make_store_path +from zarr.storage._utils import normalize_path from zarr.testing.store import LatencyStore from .conftest import meta_from_array, parse_store @@ -1631,6 +1634,37 @@ async def test_group_from_flat(store: Store, zarr_format, path: str, root_key: s assert members_observed_meta == members_expected_meta_relative +@pytest.mark.parametrize("paths", [("a", "/a"), ("", "/"), ("b/", "b")]) +def test_normalize_paths_invalid(paths: tuple[str, str]): + """ + Ensure that calling _normalize_paths on values that will normalize to the same value + will generate a ValueError. + """ + a, b = paths + msg = f"After normalization, the value '{b}' collides with '{a}'. " + with pytest.raises(ValueError, match=msg): + _normalize_paths(paths) + + +@pytest.mark.parametrize( + "paths", [("/a", "a/b"), ("a", "a/b"), ("a/", "a///b"), ("/a/", "//a/b///")] +) +def test_normalize_paths_valid(paths: tuple[str, str]): + """ + Ensure that calling _normalize_paths on values that normalize to distinct values + returns a tuple of those normalized values. + """ + expected = tuple(map(normalize_path, paths)) + assert _normalize_paths(paths) == expected + + +def test_normalize_path_keys(): + data = {"": 10, "a": "hello", "a/b": None, "/a/b/c/d": None} + observed = _normalize_path_keys(data) + expected = {normalize_path(k): v for k, v in data.items()} + assert observed == expected + + @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_members_performance(store: Store) -> None: """