Skip to content

Commit

Permalink
(fix): make pipeline pickleable (#67)
Browse files Browse the repository at this point in the history
* (fix): make pipeline pickleable

* (fix): add type to helper

* (fix): proper type

* (fix): dataclasses are not frozen by default

* (fix): format

* (chore): update docs
  • Loading branch information
ilan-gold authored Dec 10, 2024
1 parent 4ef3899 commit c9b1d85
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 18 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ zarr.config.set({
})
```

If the `ZarrsCodecPipeline` is pickled and then un-pickled, and during that time one of `store_empty_chunks`, `chunk_concurrent_minimum`, `chunk_concurrent_maximum`, or `num_threads` has changed, the newly un-pickled version will pick up the new value. However, once a `ZarrsCodecPipeline` object has been instantiated, these values are fixed. This may change in the future as guidance from the `zarr` community becomes clear.

## Concurrency

Concurrency can be classified into two types:
Expand Down
53 changes: 35 additions & 18 deletions python/zarrs/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, TypedDict

import numpy as np
from zarr.abc.codec import (
Expand All @@ -14,7 +14,7 @@

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from typing import Self
from typing import Any, Self

from zarr.abc.store import ByteGetter, ByteSetter
from zarr.core.array_spec import ArraySpec
Expand All @@ -32,10 +32,40 @@
)


@dataclass(frozen=True)
def get_codec_pipeline_impl(codec_metadata_json: str) -> CodecPipelineImpl:
    """Construct the Rust ``CodecPipelineImpl`` for the given codec metadata.

    Tuning options are read from the global zarr config *at call time*, so a
    pipeline rebuilt later (e.g. on unpickling) picks up current settings.
    """

    def _opt(key: str):
        # Every option defaults to None so the implementation chooses its own
        # fallback when the config key is unset.
        return config.get(key, None)

    return CodecPipelineImpl(
        codec_metadata_json,
        validate_checksums=_opt("codec_pipeline.validate_checksums"),
        # TODO: upstream zarr-python array.write_empty_chunks is not merged yet #2429
        store_empty_chunks=_opt("array.write_empty_chunks"),
        chunk_concurrent_minimum=_opt("codec_pipeline.chunk_concurrent_minimum"),
        chunk_concurrent_maximum=_opt("codec_pipeline.chunk_concurrent_maximum"),
        num_threads=_opt("threading.max_workers"),
    )


class ZarrsCodecPipelineState(TypedDict):
    """Picklable state of a ``ZarrsCodecPipeline``.

    Holds everything needed to rebuild the pipeline on unpickling; the
    non-picklable ``impl`` is reconstructed from ``codec_metadata_json``.
    """

    # JSON-serialized codec chain metadata used to rebuild the impl.
    codec_metadata_json: str
    # The zarr codec objects this pipeline was created from.
    codecs: tuple[Codec, ...]


@dataclass
class ZarrsCodecPipeline(CodecPipeline):
codecs: tuple[Codec, ...]
impl: CodecPipelineImpl
codec_metadata_json: str

def __getstate__(self) -> ZarrsCodecPipelineState:
    """Return only the picklable fields; ``impl`` is rebuilt in ``__setstate__``."""
    state: ZarrsCodecPipelineState = {
        "codec_metadata_json": self.codec_metadata_json,
        "codecs": self.codecs,
    }
    return state

def __setstate__(self, state: ZarrsCodecPipelineState):
    """Restore pickled fields and re-create the underlying implementation.

    Config-derived options are re-read by ``get_codec_pipeline_impl`` here,
    so an unpickled pipeline reflects the config at unpickle time.
    """
    self.codec_metadata_json = state["codec_metadata_json"]
    self.codecs = state["codecs"]
    self.impl = get_codec_pipeline_impl(self.codec_metadata_json)

def evolve_from_array_spec(self, array_spec: ArraySpec) -> Self:
    """Not supported by this pipeline; always raises ``NotImplementedError``."""
    raise NotImplementedError("evolve_from_array_spec")
Expand All @@ -49,22 +79,9 @@ def from_codecs(cls, codecs: Iterable[Codec]) -> Self:
# https://github.com/zarr-developers/zarr-python/issues/2409
# https://github.com/zarr-developers/zarr-python/pull/2429
return cls(
codec_metadata_json=codec_metadata_json,
codecs=tuple(codecs),
impl=CodecPipelineImpl(
codec_metadata_json,
validate_checksums=config.get(
"codec_pipeline.validate_checksums", None
),
# TODO: upstream zarr-python array.write_empty_chunks is not merged yet #2429
store_empty_chunks=config.get("array.write_empty_chunks", None),
chunk_concurrent_minimum=config.get(
"codec_pipeline.chunk_concurrent_minimum", None
),
chunk_concurrent_maximum=config.get(
"codec_pipeline.chunk_concurrent_maximum", None
),
num_threads=config.get("threading.max_workers", None),
),
impl=get_codec_pipeline_impl(codec_metadata_json),
)

@property
Expand Down
11 changes: 11 additions & 0 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import operator
import pickle
import tempfile
from collections.abc import Callable
from contextlib import contextmanager
Expand Down Expand Up @@ -229,3 +230,13 @@ def test_ellipsis_indexing_invalid(arr: zarr.Array):
# zarrs-python error: ValueError: operands could not be broadcast together with shapes (4,) (3,)
# numpy error: ValueError: could not broadcast input array from shape (3,) into shape (4,)
arr[2, ...] = stored_value


def test_pickle(arr: zarr.Array, tmp_path: Path):
    """Round-trip the codec pipeline through pickle and verify reads still work.

    The pipeline wraps a non-picklable implementation object that must be
    rebuilt from codec metadata via ``__getstate__``/``__setstate__``.
    """
    # Fill the array with a known ramp so any decode corruption is detectable.
    arr[:] = np.arange(reduce(operator.mul, arr.shape, 1)).reshape(arr.shape)
    expected = arr[:]
    pickle_path = tmp_path / "arr.pickle"
    # Idiom fix: call .open() on the Path instance instead of the unbound
    # Path.open(path, mode) form.
    with pickle_path.open("wb") as f:
        pickle.dump(arr._async_array.codec_pipeline, f)
    with pickle_path.open("rb") as f:
        # codec_pipeline may live on a frozen object; object.__setattr__
        # bypasses that, matching how the attribute is swapped in.
        object.__setattr__(arr._async_array, "codec_pipeline", pickle.load(f))
    assert (arr[:] == expected).all()

0 comments on commit c9b1d85

Please sign in to comment.