diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py index 3592b5ce1a..7b1bfe5f77 100644 --- a/src/zarr/core/group.py +++ b/src/zarr/core/group.py @@ -58,7 +58,12 @@ from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata from zarr.core.metadata.v3 import V3JsonEncoder from zarr.core.sync import SyncMixin, sync -from zarr.errors import MetadataValidationError +from zarr.errors import ( + ContainsArrayError, + ContainsGroupError, + MetadataValidationError, + RootedHierarchyError, +) from zarr.storage import StoreLike, StorePath from zarr.storage._common import ensure_no_existing_node, make_store_path from zarr.storage._utils import normalize_path @@ -683,53 +688,14 @@ async def getitem( """ store_path = self.store_path / key logger.debug("key=%s, store_path=%s", key, store_path) - metadata: ArrayV2Metadata | ArrayV3Metadata | GroupMetadata # Consolidated metadata lets us avoid some I/O operations so try that first. if self.metadata.consolidated_metadata is not None: return self._getitem_consolidated(store_path, key, prefix=self.name) - - # Note: - # in zarr-python v2, we first check if `key` references an Array, else if `key` references - # a group,using standalone `contains_array` and `contains_group` functions. These functions - # are reusable, but for v3 they would perform redundant I/O operations. - # Not clear how much of that strategy we want to keep here. elif self.metadata.zarr_format == 3: - zarr_json_bytes = await (store_path / ZARR_JSON).get() - if zarr_json_bytes is None: - raise KeyError(key) - else: - zarr_json = json.loads(zarr_json_bytes.to_bytes()) - metadata = _build_metadata_v3(zarr_json) - return _build_node_v3(metadata, store_path) - + return await _read_node_v3(store_path=store_path) elif self.metadata.zarr_format == 2: - # Q: how do we like optimistically fetching .zgroup, .zarray, and .zattrs? - # This guarantees that we will always make at least one extra request to the store - zgroup_bytes, zarray_bytes, zattrs_bytes = await asyncio.gather( - (store_path / ZGROUP_JSON).get(), - (store_path / ZARRAY_JSON).get(), - (store_path / ZATTRS_JSON).get(), - ) - - if zgroup_bytes is None and zarray_bytes is None: - raise KeyError(key) - - # unpack the zarray, if this is None then we must be opening a group - zarray = json.loads(zarray_bytes.to_bytes()) if zarray_bytes else None - zgroup = json.loads(zgroup_bytes.to_bytes()) if zgroup_bytes else None - # unpack the zattrs, this can be None if no attrs were written - zattrs = json.loads(zattrs_bytes.to_bytes()) if zattrs_bytes is not None else {} - - if zarray is not None: - metadata = _build_metadata_v2(zarray, zattrs) - return _build_node_v2(metadata=metadata, store_path=store_path) - else: - # this is just for mypy - if TYPE_CHECKING: - assert zgroup is not None - metadata = _build_metadata_v2(zgroup, zattrs) - return _build_node_v2(metadata=metadata, store_path=store_path) + return await _read_node_v2(store_path=store_path) else: raise ValueError(f"unexpected zarr_format: {self.metadata.zarr_format}") @@ -1431,7 +1397,9 @@ async def _members( # TODO: find a better name for this. create_tree could work. # TODO: include an example in the docstring async def create_hierarchy( - self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata] + self, + nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], + overwrite: bool, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ Create a hierarchy of arrays or groups rooted at this group. @@ -1445,13 +1413,20 @@ async def create_hierarchy( ---------- nodes : A dictionary representing the hierarchy to create + overwrite : bool + Whether or not existing arrays / groups should be replaced. + Returns ------- - An asynchronous iterator over the created nodes. + An asynchronous iterator over the created arrays and / or groups. """ semaphore = asyncio.Semaphore(config.get("async.concurrency")) async for node in create_hierarchy( - store_path=self.store_path, nodes=nodes, semaphore=semaphore + store_path=self.store_path, + nodes=nodes, + semaphore=semaphore, + overwrite=overwrite, + allow_root=False, ): yield node @@ -2078,7 +2053,9 @@ def members(self, max_depth: int | None = 0) -> tuple[tuple[str, Array | Group], return tuple((kv[0], _parse_async_node(kv[1])) for kv in _members) def create_hierarchy( - self, nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata] + self, + nodes: dict[str, ArrayV2Metadata | ArrayV3Metadata | GroupMetadata], + overwrite: bool = False, ) -> Iterator[ tuple[str, AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]] ]: @@ -2099,14 +2076,6 @@ def create_hierarchy( ------- A dict containing the created nodes, with the same keys as the input """ - if "" in nodes: - msg = ( - "Found the key '' in nodes, which denotes the root group. Creating the root group " - "from an existing group is not supported. If you want to create an entire Zarr group, " - "including the root group, from a dict then use the _from_flat method." - ) - raise ValueError(msg) - # check that all the nodes have the same zarr_format as Self for key, value in nodes.items(): if value.zarr_format != self.metadata.zarr_format: @@ -2116,9 +2085,19 @@ def create_hierarchy( f" has zarr_format {self.metadata.zarr_format}." ) raise ValueError(msg) - nodes_created = self._sync_iter(self._async_group.create_hierarchy(nodes)) - for n in nodes_created: - yield (_join_paths([self.path, n.name]), n) + try: + nodes_created = self._sync_iter( + self._async_group.create_hierarchy(nodes, overwrite=overwrite) + ) + for n in nodes_created: + yield (_join_paths([self.path, n.name]), n) + except RootedHierarchyError as e: + msg = ( + "The input defines a root node, but a root node already exists, namely this Group instance." + "It is an error to use this method to create a root node. " + "Remove the root node from the input dict, or use a function like _from_flat to create a rooted hierarchy." + ) + raise ValueError(msg) from e def keys(self) -> Generator[str, None]: """Return an iterator over group member names. @@ -2862,6 +2841,7 @@ async def create_hierarchy( nodes: dict[str, GroupMetadata | ArrayV3Metadata | ArrayV2Metadata], semaphore: asyncio.Semaphore | None = None, overwrite: bool = False, + allow_root: bool = True, ) -> AsyncIterator[AsyncGroup | AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata]]: """ Create a complete zarr hierarchy concurrently. Groups that are implicitly defined by the input @@ -2883,6 +2863,11 @@ async def create_hierarchy( semaphore : asyncio.Semaphore | None An optional semaphore to limit the number of concurrent tasks. If not provided, the number of concurrent tasks is not limited. + allow_root : bool + Whether to allow a root node to be created. If ``False``, attempting to create a root node + will result in an error. Use this option when calling this function as part of a method + defined on ``AsyncGroup`` instances, because in this case the root node has already been + created. Yields ------ @@ -2891,11 +2876,75 @@ async def create_hierarchy( """ nodes_parsed = _parse_hierarchy_dict(nodes) - if overwrite: - await store_path.delete_dir() - else: - # TODO: check if any of the nodes already exist, and error if so - raise NotImplementedError + if not allow_root and "" in nodes_parsed: + msg = ( + "Found the key '' in nodes (after key name normalization). That key denotes the root of a hierarchy, but ``allow_root`` is False, and so creating this node " + "is not allowed. Either remove this key from ``nodes``, or set ``allow_root`` to True." + ) + raise RootedHierarchyError(msg) + + # we allow creating empty hierarchies -- it's a no-op + if len(nodes_parsed) > 0: + if overwrite: + await store_path.delete_dir() + else: + # attempt to fetch all of the metadata described in hierarchy + # first figure out which zarr format we are dealing with + sample, *_ = nodes_parsed.values() + redundant_implicit_groups = [] + # TODO: decide if this set difference is sufficient for detecting implicit groups. + # an alternative would be to use an explicit implicit group class. + + implicit_group_names = set(nodes_parsed.keys()) - set(nodes.keys()) + + zarr_format = sample.zarr_format + if zarr_format == 3 or zarr_format == 2: + func = _read_metadata_v3 + else: + raise ValueError(f"Invalid zarr_format: {zarr_format}") + + coros = (func(store_path=store_path / key) for key in nodes_parsed) + extant_node_query = dict( + zip( + nodes_parsed.keys(), + await asyncio.gather(*coros, return_exceptions=True), + strict=False, + ) + ) + + for key, value in extant_node_query.items(): + if isinstance(value, BaseException): + if isinstance(value, KeyError): + # ignore KeyErrors, because they represent nodes we can safely create + pass + else: + # Any other exception is a real error + raise value + else: + # this is a node that already exists, but a node with this name was specified in + # nodes_parsed. + # Two cases produce exceptions: + # 1. The node is a group, and a node with this name was explicitly defined in + # nodes + # 2. The node is an array. + # The third case is when this extant node is a group, but its name was not + # explicitly defined in nodes. This means it was added as an implicit group by + # _parse_hierarchy_dict, and we can remove the reference to this node from + # nodes_parsed. We don't need to create this node. + + if isinstance(value, GroupMetadata): + if key not in implicit_group_names: + raise ContainsGroupError(store_path.store, key) + else: + # as there is already a group with this name, we should not create a new one + redundant_implicit_groups.append(key) + elif isinstance(value, ArrayV2Metadata | ArrayV3Metadata): + raise ContainsArrayError(store_path.store, key) + + nodes_parsed = { + k: v for k, v in nodes_parsed.items() if k not in redundant_implicit_groups + } + async for node in create_nodes(store_path=store_path, nodes=nodes_parsed, semaphore=semaphore): yield node @@ -3043,12 +3092,19 @@ def _parse_hierarchy_dict( ) raise ValueError(msg) - out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {} + # normalize the keys of the dict + + data_normed: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = ( + _normalize_path_keys(data) + ) + + out: dict[str, GroupMetadata | ArrayV2Metadata | ArrayV3Metadata] = {**data_normed} for k, v in data.items(): # TODO: ensure that the key is a valid path key_split = k.split("/") - *subpaths, _ = accumulate(key_split, lambda a, b: "/".join([a, b])) + # we use /.join here because it checks the types of its inputs, unlike an f string + *subpaths, _ = accumulate(key_split, lambda a, b: "/".join([a, b])) # noqa: FLY002 for subpath in subpaths: # If a component is not already in the output dict, add a group @@ -3061,7 +3117,6 @@ def _parse_hierarchy_dict( "This is invalid. Only Zarr groups can contain other nodes." ) raise ValueError(msg) - return out @@ -3084,7 +3139,7 @@ def _normalize_paths(paths: Iterable[str]) -> tuple[str, ...]: return tuple(path_map.keys()) -def _normalize_path_keys(data: dict[str, T]) -> dict[str, T]: +def _normalize_path_keys(data: Mapping[str, T]) -> dict[str, T]: """ Normalize the keys of the input dict according to the normalization scheme used for zarr node paths. If any two keys in the input normalize to the value, raise a ValueError. Return the @@ -3212,20 +3267,56 @@ async def _iter_members_deep( yield key, node -def _resolve_metadata_v2( - blobs: tuple[str | bytes | bytearray, str | bytes | bytearray], -) -> ArrayV2Metadata | GroupMetadata: - zarr_metadata = json.loads(blobs[0]) - attrs = json.loads(blobs[1]) - if "shape" in zarr_metadata: - return ArrayV2Metadata.from_dict(zarr_metadata | {"attrs": attrs}) +async def _read_metadata_v3(store_path: StorePath) -> ArrayV3Metadata | GroupMetadata: + """ + Given a store_path, return ArrayV3Metadata or GroupMetadata defined by the metadata + document stored at store_path.path / zarr.json. If no such document is found, raise a KeyError. + """ + zarr_json_bytes = await (store_path / ZARR_JSON).get() + if zarr_json_bytes is None: + raise KeyError(store_path.path) + else: + zarr_json = json.loads(zarr_json_bytes.to_bytes()) + return _build_metadata_v3(zarr_json) + + +async def _read_metadata_v2(store_path: StorePath) -> ArrayV2Metadata | GroupMetadata: + """ + Given a store_path, return ArrayV2Metadata or GroupMetadata defined by the metadata + document stored at store_path.path / (.zgroup | .zarray). If no such document is found, + raise a KeyError. + """ + # TODO: consider first fetching array metadata, and only fetching group metadata when we don't + # find an array + zarray_bytes, zgroup_bytes, zattrs_bytes = await asyncio.gather( + (store_path / ZARRAY_JSON).get(), + (store_path / ZGROUP_JSON).get(), + (store_path / ZATTRS_JSON).get(), + ) + + if zattrs_bytes is None: + zattrs = {} else: - return GroupMetadata.from_dict(zarr_metadata | {"attrs": attrs}) + zattrs = json.loads(zattrs_bytes.to_bytes()) + + # TODO: decide how to handle finding both array and group metadata. The spec does not seem to + # consider this situation. A practical approach would be to ignore that combination, and only + # return the array metadata. + if zarray_bytes is not None: + zmeta = json.loads(zarray_bytes.to_bytes()) + else: + if zgroup_bytes is None: + # neither .zarray or .zgroup were found results in KeyError + raise KeyError(store_path.path) + else: + zmeta = json.loads(zgroup_bytes.to_bytes()) + return _build_metadata_v2(zmeta, zattrs) -def _build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMetadata: + +def _build_metadata_v3(zarr_json: dict[str, JSON]) -> ArrayV3Metadata | GroupMetadata: """ - Take a dict and convert it into the correct metadata type. + Convert a dict representation of Zarr V3 metadata into the corresponding metadata class. """ if "node_type" not in zarr_json: raise KeyError("missing `node_type` key in metadata document.") @@ -3239,10 +3330,10 @@ def _build_metadata_v3(zarr_json: dict[str, Any]) -> ArrayV3Metadata | GroupMeta def _build_metadata_v2( - zarr_json: dict[str, Any], attrs_json: dict[str, Any] + zarr_json: dict[str, object], attrs_json: dict[str, JSON] ) -> ArrayV2Metadata | GroupMetadata: """ - Take a dict and convert it into the correct metadata type. + Convert a dict representation of Zarr V2 metadata into the corresponding metadata class. """ match zarr_json: case {"shape": _}: @@ -3282,6 +3373,37 @@ def _build_node_v2( raise ValueError(f"Unexpected metadata type: {type(metadata)}") +async def _read_node_v2(store_path: StorePath) -> AsyncArray[ArrayV2Metadata] | AsyncGroup: + """ + Read a Zarr v2 AsyncArray or AsyncGroup from a location defined by a StorePath. + """ + metadata = await _read_metadata_v2(store_path=store_path) + return _build_node_v2(metadata=metadata, store_path=store_path) + + +async def _read_node_v3(store_path: StorePath) -> AsyncArray[ArrayV3Metadata] | AsyncGroup: + """ + Read a Zarr v3 AsyncArray or AsyncGroup from a location defined by a StorePath. + """ + metadata = await _read_metadata_v3(store_path=store_path) + return _build_node_v3(metadata=metadata, store_path=store_path) + + +async def _read_node( + store_path: StorePath, zarr_format: ZarrFormat +) -> AsyncArray[ArrayV2Metadata] | AsyncArray[ArrayV3Metadata] | AsyncGroup: + """ + Read and AsyncArray or AsyncGroup from a location defined by a StorePath. + """ + match zarr_format: + case 2: + return await _read_node_v2(store_path=store_path) + case 3: + return await _read_node_v3(store_path=store_path) + case _: + raise ValueError(f"Unexpected zarr format: {zarr_format}") + + async def _set_return_key(*, store: Store, key: str, value: Buffer, replace: bool) -> str: """ Either write a value to storage at the given key, or ensure that there is already a value in @@ -3314,8 +3436,6 @@ def _persist_metadata( ) -> tuple[Coroutine[None, None, str], ...]: """ Prepare to save a metadata document to storage, returning a tuple of coroutines that must be awaited. - If ``metadata`` is an instance of ``_ImplicitGroupMetadata``, then _set_return_key will be invoked with - ``replace=False``, which defers to a pre-existing metadata document in storage if one exists. Otherwise, existing values will be overwritten. """ to_save = metadata.to_buffer_dict(default_buffer_prototype()) diff --git a/src/zarr/errors.py b/src/zarr/errors.py index 441cdab9a3..855ea51b9d 100644 --- a/src/zarr/errors.py +++ b/src/zarr/errors.py @@ -57,3 +57,10 @@ class NodeTypeValidationError(MetadataValidationError): This can be raised when the value is invalid or unexpected given the context, for example an 'array' node when we expected a 'group'. """ + + +class RootedHierarchyError(BaseZarrError): + """ + Exception raised when attempting to create a rooted hierarchy in a context where that is not + permitted. + """ diff --git a/tests/test_group.py b/tests/test_group.py index def4fc554a..3d599141b7 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -28,6 +28,7 @@ _join_paths, _normalize_path_keys, _normalize_paths, + _read_node, create_hierarchy, create_nodes, ) @@ -1487,20 +1488,29 @@ async def test_create_hierarchy(store: Store, overwrite: bool, zarr_format: Zarr hierarchy_spec = { "group": GroupMetadata(attributes={"foo": 10}, zarr_format=zarr_format), "group/array_0": meta_from_array(np.arange(3), zarr_format=zarr_format), - "group/array_1": meta_from_array(np.arange(4), zarr_format=zarr_format), "group/subgroup/array_0": meta_from_array(np.arange(4), zarr_format=zarr_format), - "group/subgroup/array_1": meta_from_array(np.arange(5), zarr_format=zarr_format), } - pre_existing_nodes = {"extra": GroupMetadata(zarr_format=zarr_format)} + pre_existing_nodes = { + "group/extra": GroupMetadata(zarr_format=zarr_format, attributes={"name": "extra"}), + "": GroupMetadata(zarr_format=zarr_format, attributes={"name": "root"}), + } # we expect create_hierarchy to insert a group that was missing from the hierarchy spec expected_meta = hierarchy_spec | {"group/subgroup": GroupMetadata(zarr_format=zarr_format)} spath = await make_store_path(store, path=path) + # initialize the group with some nodes - await _from_flat(store_path=spath, nodes=pre_existing_nodes) + sync(_collect_aiterator(create_nodes(store_path=spath, nodes=pre_existing_nodes))) + observed_nodes = { str(PurePosixPath(a.name).relative_to("/" + path)): a async for a in create_hierarchy(store_path=spath, nodes=expected_meta, overwrite=overwrite) } + if not overwrite: + extra_group = await _read_node(spath / "group/extra", zarr_format=zarr_format) + assert extra_group.metadata.attributes == {"name": "extra"} + else: + with pytest.raises(KeyError): + await _read_node(spath / "group/extra", zarr_format=zarr_format) assert expected_meta == {k: v.metadata for k, v in observed_nodes.items()} @@ -1518,15 +1528,31 @@ def test_group_create_hierarchy(store: Store, zarr_format: ZarrFormat, overwrite np.zeros(5), zarr_format=zarr_format, attributes={"name": "a/b/c"} ), } - nodes = g.create_hierarchy(tree) + nodes = g.create_hierarchy(tree, overwrite=overwrite) for k, v in g.members(max_depth=None): assert v.metadata == tree[k] == nodes[k].metadata +@pytest.mark.parametrize("store", ["memory"], indirect=True) +@pytest.mark.parametrize("overwrite", [True, False]) +def test_group_create_hierarchy_no_root(store: Store, zarr_format: ZarrFormat, overwrite: bool): + """ + Test that the Group.create_hierarchy method will error if the dict provided contains a root. + """ + g = Group.from_store(store, zarr_format=zarr_format) + tree = { + "": GroupMetadata(zarr_format=zarr_format, attributes={"name": "a"}), + } + with pytest.raises( + ValueError, match="It is an error to use this method to create a root node. " + ): + _ = tuple(g.create_hierarchy(tree, overwrite=overwrite)) + + @pytest.mark.parametrize("store", ["memory"], indirect=True) def test_group_create_hierarchy_invalid_mixed_zarr_format(store: Store, zarr_format: ZarrFormat): """ - Test that ```Group.create_hierarchy``` will raise an error if the zarr_format of the nodes is + Test that ``Group.create_hierarchy`` will raise an error if the zarr_format of the nodes is different from the parent group. """ other_format = 2 if zarr_format == 3 else 3