Skip to content

Commit

Permalink
add v3/types.py, bring v2 into closer alignment to v3 api
Browse files Browse the repository at this point in the history
  • Loading branch information
d-v-b committed Jan 7, 2024
1 parent a7ce7a1 commit 27740b0
Show file tree
Hide file tree
Showing 31 changed files with 1,192 additions and 1,047 deletions.
31 changes: 18 additions & 13 deletions zarr/tests/test_array_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
from typing import Any, Dict, Literal, Tuple, Union
import numpy as np

from zarr.v3.common import ChunkCoords
from zarr.v3.metadata import DefaultChunkKeyEncodingMetadata, RegularChunkGridMetadata
from zarr.v3.types import Attributes, ChunkCoords
from zarr.v3.metadata.v3 import DefaultChunkKeyEncoding, RegularChunkGrid, RegularChunkGridConfig

# todo: parametrize by chunks
@pytest.mark.asyncio
@pytest.mark.parametrize("zarr_version", ("2", "3"))
@pytest.mark.parametrize(
"shape",
Expand All @@ -30,19 +31,19 @@
@pytest.mark.parametrize("attributes", ({}, dict(a=10, b=10)))
@pytest.mark.parametrize("fill_value", (0, 1, 2))
@pytest.mark.parametrize("dimension_separator", (".", "/"))
def test_array(
async def test_array(
tmpdir,
zarr_version: Literal["2", "3"],
shape: Tuple[int, ...],
dtype: Union[str, np.dtype],
attributes: Dict[str, Any],
attributes: Attributes,
fill_value: float,
dimension_separator: Literal[".", "/"],
):
store_path = str(tmpdir)
arr: Union[v2.Array, v3.Array]
arr: Union[v2.AsyncArray, v3.Array]
if zarr_version == "2":
arr = v2.Array.create(
arr = await v2.Array.create(
store=store_path,
shape=shape,
dtype=dtype,
Expand All @@ -53,7 +54,7 @@ def test_array(
exists_ok=True,
)
else:
arr = v3.Array.create(
arr = await v3.Array.create(
store=store_path,
shape=shape,
dtype=dtype,
Expand Down Expand Up @@ -82,16 +83,18 @@ def test_init_format(zarr_format: Literal[2, 3]):
shape = (10,)
if zarr_format == 2:
with pytest.raises(ValueError):
arr = v2.ArrayMetadata(shape=shape, dtype=dtype, chunks=shape, zarr_format=3)
arr1 = v2.ArrayMetadata(shape=shape, dtype=dtype, chunks=shape, zarr_format=3)
else:
with pytest.raises(ValueError):
arr = v3.ArrayMetadata(
arr2 = v3.ArrayMetadata(
shape=shape,
data_type=dtype,
codecs=[],
chunk_grid=RegularChunkGridMetadata(configuration={"chunk_shape": shape}),
chunk_grid=RegularChunkGrid(
configuration=RegularChunkGridConfig(chunk_shape=shape)
),
fill_value=0,
chunk_key_encoding=DefaultChunkKeyEncodingMetadata(),
chunk_key_encoding=DefaultChunkKeyEncoding(),
zarr_format=2,
)

Expand All @@ -109,8 +112,10 @@ def test_init_node_type(zarr_format: Literal["2", "3"]):
shape=shape,
data_type=dtype,
codecs=[],
chunk_grid=RegularChunkGridMetadata(configuration={"chunk_shape": shape}),
chunk_grid=RegularChunkGrid(
configuration=RegularChunkGridConfig(chunk_shape=shape)
),
fill_value=0,
chunk_key_encoding=DefaultChunkKeyEncodingMetadata(),
chunk_key_encoding=DefaultChunkKeyEncoding(),
node_type="group",
)
57 changes: 24 additions & 33 deletions zarr/tests/test_codecs_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@
import pytest
import zarr
from zarr.v3 import codecs
from zarr.v3.array.base import runtime_configuration
from zarr.v3.array.v3 import Array, AsyncArray
from zarr.v3.common import Selection
from zarr.v3.types import Selection
from zarr.v3.array.indexing import morton_order_iter
from zarr.v3.metadata import CodecMetadata, ShardingCodecIndexLocation, runtime_configuration
from zarr.v3.metadata.v3 import CodecMetadata

from zarr.v3.store import MemoryStore, Store

Expand Down Expand Up @@ -46,12 +47,8 @@ def sample_data() -> np.ndarray:
return np.arange(0, 128 * 128 * 128, dtype="uint16").reshape((128, 128, 128), order="F")


@pytest.mark.parametrize(
"index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
)
def test_sharding(
store: Store, sample_data: np.ndarray, index_location: ShardingCodecIndexLocation
):
@pytest.mark.parametrize("index_location", ["start", "end"])
def test_sharding(store: Store, sample_data: np.ndarray, index_location: Literal["start", "end"]):
a = Array.create(
store / "sample",
shape=sample_data.shape,
Expand All @@ -78,11 +75,9 @@ def test_sharding(
assert np.array_equal(sample_data, read_data)


@pytest.mark.parametrize(
"index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
)
@pytest.mark.parametrize("index_location", ["start", "end"])
def test_sharding_partial(
store: Store, sample_data: np.ndarray, index_location: ShardingCodecIndexLocation
store: Store, sample_data: np.ndarray, index_location: Literal["start", "end"]
):
a = Array.create(
store / "sample",
Expand Down Expand Up @@ -113,11 +108,9 @@ def test_sharding_partial(
assert np.array_equal(sample_data, read_data)


@pytest.mark.parametrize(
"index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
)
@pytest.mark.parametrize("index_location", ["start", "end"])
def test_sharding_partial_read(
store: Store, sample_data: np.ndarray, index_location: ShardingCodecIndexLocation
store: Store, sample_data: np.ndarray, index_location: Literal["start", "end"]
):
a = Array.create(
store / "sample",
Expand All @@ -142,11 +135,9 @@ def test_sharding_partial_read(
assert np.all(read_data == 1)


@pytest.mark.parametrize(
"index_location", [ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end]
)
@pytest.mark.parametrize("index_location", ["start", "end"])
def test_sharding_partial_overwrite(
store: Store, sample_data: np.ndarray, index_location: ShardingCodecIndexLocation
store: Store, sample_data: np.ndarray, index_location: Literal["start", "end"]
):
data = sample_data[:10, :10, :10]

Expand Down Expand Up @@ -182,17 +173,17 @@ def test_sharding_partial_overwrite(

@pytest.mark.parametrize(
"outer_index_location",
[ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end],
["start", "end"],
)
@pytest.mark.parametrize(
"inner_index_location",
[ShardingCodecIndexLocation.start, ShardingCodecIndexLocation.end],
["start", "end"],
)
def test_nested_sharding(
store: Store,
sample_data: np.ndarray,
outer_index_location: ShardingCodecIndexLocation,
inner_index_location: ShardingCodecIndexLocation,
outer_index_location: Literal["start", "end"],
inner_index_location: Literal["start", "end"],
):
a = Array.create(
store / "l4_sample" / "color" / "1",
Expand Down Expand Up @@ -243,7 +234,7 @@ async def test_order(
else [codecs.transpose_codec(store_order, data.ndim), codecs.bytes_codec()]
)

a = await AsyncArray.create(
a_create = await AsyncArray.create(
store / "order",
shape=data.shape,
chunk_shape=(32, 8),
Expand All @@ -254,15 +245,15 @@ async def test_order(
runtime_configuration=runtime_configuration(runtime_write_order),
)

await _AsyncArrayProxy(a)[:, :].set(data)
read_data = await _AsyncArrayProxy(a)[:, :].get()
await _AsyncArrayProxy(a_create)[:, :].set(data)
read_data = await _AsyncArrayProxy(a_create)[:, :].get()
assert np.array_equal(data, read_data)

a = await AsyncArray.open(
a_open = await AsyncArray.open(
store / "order",
runtime_configuration=runtime_configuration(order=runtime_read_order),
)
read_data = await _AsyncArrayProxy(a)[:, :].get()
read_data = await _AsyncArrayProxy(a_open)[:, :].get()
assert np.array_equal(data, read_data)

if runtime_read_order == "F":
Expand Down Expand Up @@ -922,7 +913,7 @@ def test_invalid_metadata(store: Store):
async def test_resize(store: Store):
data = np.zeros((16, 18), dtype="uint16")

a = await AsyncArray.create(
a_create = await AsyncArray.create(
store / "resize",
shape=data.shape,
chunk_shape=(10, 10),
Expand All @@ -931,14 +922,14 @@ async def test_resize(store: Store):
fill_value=1,
)

await _AsyncArrayProxy(a)[:16, :18].set(data)
await _AsyncArrayProxy(a_create)[:16, :18].set(data)
assert await store.get_async("resize/0.0") is not None
assert await store.get_async("resize/0.1") is not None
assert await store.get_async("resize/1.0") is not None
assert await store.get_async("resize/1.1") is not None

a = await a.resize((10, 12))
assert a.metadata.shape == (10, 12)
a_resize = await a_create.resize((10, 12))
assert a_resize.metadata.shape == (10, 12)
assert await store.get_async("resize/0.0") is not None
assert await store.get_async("resize/0.1") is not None
assert await store.get_async("resize/1.0") is None
Expand Down
73 changes: 73 additions & 0 deletions zarr/tests/test_v3x.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import pytest

from zarr.v3.metadata.v3 import (
DefaultChunkKeyConfig,
DefaultChunkKeyEncoding,
RegularChunkGrid,
RegularChunkGridConfig,
)

from .test_codecs_v3 import store
from zarr.v3.array.v3 import ArrayMetadata
from zarr.v3.array.v3x import Array, lower_index


def test_v3x(store):
zmeta = ArrayMetadata(
shape=(10, 10),
data_type="uint16",
chunk_grid=RegularChunkGrid(configuration=RegularChunkGridConfig),
chunk_key_encoding=DefaultChunkKeyEncoding(configuration=DefaultChunkKeyConfig),
fill_value=0,
codecs=[],
)

arr = Array(
zmeta, store, shape=zmeta.shape, index=tuple(map(slice, zmeta.shape)), attributes={}
)
arr[slice(None), slice(None)]


# as fun as it looks
@pytest.mark.parametrize(
"args, expected",
(
(((slice(0, 10, 1),), (10,)), ((((0,), (slice(0, 10, 1),))),)),
(
((slice(0, 10, 1), slice(0, 10, 1)), (10, 10)),
(((0, 0), (slice(0, 10, 1), slice(0, 10, 1))),),
),
(((slice(0, 1, 1),), (10,)), (((0,), (slice(0, 1, 1),)),)),
(
((slice(0, 1, 1), slice(0, 1, 1)), (10, 4)),
(((0, 0), (slice(0, 1, 1), slice(0, 1, 1))),),
),
(
((slice(3, 11, 1), slice(0, 1, 1)), (10, 4)),
(
((0, 0), (slice(3, 10, 1), slice(0, 1, 1))),
((1, 0), (slice(0, 1, 1), slice(0, 1, 1))),
),
),
(
((slice(3, 22, 1), slice(0, 1, 1)), (10, 4)),
(
((0, 0), (slice(3, 10, 1), slice(0, 1, 1))),
((1, 0), (slice(0, 10, 1), slice(0, 1, 1))),
((2, 0), (slice(0, 2, 1), slice(0, 1, 1))),
),
),
(
((slice(3, 22, 1), slice(0, 5, 1)), (20, 4)),
(
((0, 0), (slice(3, 20, 1), slice(0, 4, 1))),
((0, 1), (slice(3, 20, 1), slice(0, 1, 1))),
((1, 0), (slice(0, 2, 1), slice(0, 4, 1))),
((1, 1), (slice(0, 2, 1), slice(0, 1, 1))),
),
),
),
)
def test_lower_index(args, expected):
observed = tuple(lower_index(*args))
assert observed == expected
8 changes: 4 additions & 4 deletions zarr/v3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from typing import Union

from zarr.v3.array.v3 import Array as ZArrayV3 # noqa: F401
from zarr.v3.array.v2 import Array as ZArrayV2 # noqa: F401
from zarr.v3.array.v3 import Array as ArrayV3 # noqa: F401
from zarr.v3.array.v2 import AsyncArray as ArrayV2 # noqa: F401
from zarr.v3.group import Group # noqa: F401
from zarr.v3.group_v2 import GroupV2 # noqa: F401
from zarr.v3.array.base import RuntimeConfiguration, runtime_configuration # noqa: F401
Expand All @@ -21,7 +21,7 @@
async def open_auto_async(
store: StoreLike,
runtime_configuration_: RuntimeConfiguration = RuntimeConfiguration(),
) -> Union[ZArrayV2, ZArrayV3, Group, GroupV2]:
) -> Union[ArrayV2, ArrayV3, Group, GroupV2]:
store_path = make_store_path(store)
try:
return await Group.open_or_array(store_path, runtime_configuration=runtime_configuration_)
Expand All @@ -32,7 +32,7 @@ async def open_auto_async(
def open_auto(
store: StoreLike,
runtime_configuration_: RuntimeConfiguration = RuntimeConfiguration(),
) -> Union[ZArrayV2, ZArrayV3, Group, GroupV2]:
) -> Union[ArrayV2, ArrayV3, Group, GroupV2]:
return _sync(
open_auto_async(store, runtime_configuration_),
runtime_configuration_.asyncio_loop,
Expand Down
Loading

0 comments on commit 27740b0

Please sign in to comment.