diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8abb9fddd3..67137be96c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,7 +35,6 @@ repos: - numcodecs - numpy - typing_extensions - - zstandard # Tests - pytest # Zarr v2 diff --git a/pyproject.toml b/pyproject.toml index 88542b2cf5..53b4cb3244 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,6 @@ dependencies = [ 'numcodecs>=0.10.0', 'fsspec>2024', 'crc32c', - 'zstandard', 'typing_extensions', 'donfig', ] @@ -85,8 +84,8 @@ docs = [ 'pydata-sphinx-theme', 'numpydoc', 'numcodecs[msgpack]', - "msgpack", - "lmdb", + 'msgpack', + 'lmdb', ] extra = [ 'msgpack', diff --git a/src/zarr/codecs/zstd.py b/src/zarr/codecs/zstd.py index b244ee703a..3b3d3f33dd 100644 --- a/src/zarr/codecs/zstd.py +++ b/src/zarr/codecs/zstd.py @@ -1,10 +1,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from functools import cached_property +from importlib.metadata import version +from typing import TYPE_CHECKING -import numpy.typing as npt -from zstandard import ZstdCompressor, ZstdDecompressor +from numcodecs.zstd import Zstd from zarr.abc.codec import BytesBytesCodec from zarr.array_spec import ArraySpec @@ -38,6 +39,14 @@ class ZstdCodec(BytesBytesCodec): checksum: bool = False def __init__(self, *, level: int = 0, checksum: bool = False) -> None: + # numcodecs 0.13.0 introduces the checksum attribute for the zstd codec + _numcodecs_version = tuple(map(int, version("numcodecs").split("."))) + if _numcodecs_version < (0, 13, 0): # pragma: no cover + raise RuntimeError( + "numcodecs version >= 0.13.0 is required to use the zstd codec. " + f"Version {_numcodecs_version} is currently installed." + ) + level_parsed = parse_zstd_level(level) checksum_parsed = parse_checksum(checksum) @@ -52,13 +61,10 @@ def from_dict(cls, data: dict[str, JSON]) -> Self: def to_dict(self) -> dict[str, JSON]: return {"name": "zstd", "configuration": {"level": self.level, "checksum": self.checksum}} - def _compress(self, data: npt.NDArray[Any]) -> bytes: - ctx = ZstdCompressor(level=self.level, write_checksum=self.checksum) - return ctx.compress(data.tobytes()) - - def _decompress(self, data: npt.NDArray[Any]) -> bytes: - ctx = ZstdDecompressor() - return ctx.decompress(data.tobytes()) + @cached_property + def _zstd_codec(self) -> Zstd: + config_dict = {"level": self.level, "checksum": self.checksum} + return Zstd.from_config(config_dict) async def _decode_single( self, @@ -66,7 +72,7 @@ async def _decode_single( chunk_spec: ArraySpec, ) -> Buffer: return await to_thread( - as_numpy_array_wrapper, self._decompress, chunk_bytes, chunk_spec.prototype + as_numpy_array_wrapper, self._zstd_codec.decode, chunk_bytes, chunk_spec.prototype ) async def _encode_single( @@ -75,7 +81,7 @@ async def _encode_single( chunk_spec: ArraySpec, ) -> Buffer | None: return await to_thread( - as_numpy_array_wrapper, self._compress, chunk_bytes, chunk_spec.prototype + as_numpy_array_wrapper, self._zstd_codec.encode, chunk_bytes, chunk_spec.prototype ) def compute_encoded_size(self, _input_byte_length: int, _chunk_spec: ArraySpec) -> int: