From 574668063a11eb97aa21ca29f7cba7389f7e3e30 Mon Sep 17 00:00:00 2001 From: Lachlan Deakin Date: Wed, 12 Feb 2025 08:09:25 +1100 Subject: [PATCH 1/7] fix: sharding codec with fancy indexing --- src/zarr/codecs/sharding.py | 6 +++++- tests/test_array.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index e8730c86dd..c1087b3ad5 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -481,8 +481,12 @@ async def _decode_partial_single( ) # setup output array + if hasattr(indexer, "sel_shape"): + out_shape = indexer.sel_shape + else: + out_shape = indexer.shape out = shard_spec.prototype.nd_buffer.create( - shape=indexer.shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0 + shape=out_shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0 ) indexed_chunks = list(indexer) diff --git a/tests/test_array.py b/tests/test_array.py index 1b84d1d061..8eb3cc6a23 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1420,3 +1420,18 @@ def test_multiprocessing(store: Store, method: Literal["fork", "spawn", "forkser results = pool.starmap(_index_array, [(arr, slice(len(data)))]) assert all(np.array_equal(r, data) for r in results) + + +async def test_sharding_coordinate_selection() -> None: + store = MemoryStore() + g = zarr.open_group(store, mode="w") + arr = g.create_array( + name="a", + shape=(10, 20, 30), + chunks=(5, 1, 30), + overwrite=True, + dtype=np.float32, + shards=(5, 20, 30), + ) + arr[:] = 1 + assert (arr[5, [0, 1]] == 1).all() From dfc66669df6a3156c16b420804b51a8250bbd2c0 Mon Sep 17 00:00:00 2001 From: Lachlan Deakin Date: Wed, 12 Feb 2025 08:16:01 +1100 Subject: [PATCH 2/7] changelog --- changes/2817.bugfix.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changes/2817.bugfix.rst diff --git a/changes/2817.bugfix.rst b/changes/2817.bugfix.rst new file mode 100644 index 0000000000..b1c0fa9220 --- /dev/null +++ b/changes/2817.bugfix.rst @@ -0,0 +1 @@ +Fix fancy indexing (e.g. arr[5, [0, 1]]) with the sharding codec \ No newline at end of file From 55ebb3104e7e835242a499d28d9fcdfb1188cc66 Mon Sep 17 00:00:00 2001 From: Lachlan Deakin Date: Wed, 12 Feb 2025 08:36:40 +1100 Subject: [PATCH 3/7] add a better test --- tests/test_array.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_array.py b/tests/test_array.py index 8eb3cc6a23..ff88359b55 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1433,5 +1433,5 @@ async def test_sharding_coordinate_selection() -> None: dtype=np.float32, shards=(5, 20, 30), ) - arr[:] = 1 - assert (arr[5, [0, 1]] == 1).all() + arr[:] = np.arange(10*20*30).reshape((10, 20, 30)) + assert (arr[5, [0, 1]] == np.vstack([np.arange(3000, 3030), np.arange(3030, 3060)])).all() From 3e43b1f7604f8f6c6b3852925f391cd2d6dd02b6 Mon Sep 17 00:00:00 2001 From: Lachlan Deakin Date: Wed, 12 Feb 2025 08:51:41 +1100 Subject: [PATCH 4/7] proper fix --- src/zarr/codecs/sharding.py | 12 ++++++------ tests/test_array.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/zarr/codecs/sharding.py b/src/zarr/codecs/sharding.py index c1087b3ad5..c287159216 100644 --- a/src/zarr/codecs/sharding.py +++ b/src/zarr/codecs/sharding.py @@ -481,12 +481,8 @@ async def _decode_partial_single( ) # setup output array - if hasattr(indexer, "sel_shape"): - out_shape = indexer.sel_shape - else: - out_shape = indexer.shape out = shard_spec.prototype.nd_buffer.create( - shape=out_shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0 + shape=indexer.shape, dtype=shard_spec.dtype, order=shard_spec.order, fill_value=0 ) indexed_chunks = list(indexer) @@ -533,7 +529,11 @@ async def _decode_partial_single( ], out, ) - return out + + if hasattr(indexer, "sel_shape"): + return out.reshape(indexer.sel_shape) + else: + return out async def _encode_single( self, diff --git a/tests/test_array.py b/tests/test_array.py index ff88359b55..842e071612 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1427,11 +1427,11 @@ async def test_sharding_coordinate_selection() -> None: g = zarr.open_group(store, mode="w") arr = g.create_array( name="a", - shape=(10, 20, 30), - chunks=(5, 1, 30), + shape=(2, 3, 4), + chunks=(1, 2, 2), overwrite=True, dtype=np.float32, - shards=(5, 20, 30), + shards=(2, 4, 4), ) - arr[:] = np.arange(10*20*30).reshape((10, 20, 30)) - assert (arr[5, [0, 1]] == np.vstack([np.arange(3000, 3030), np.arange(3030, 3060)])).all() + arr[:] = np.arange(2 * 3 * 4).reshape((2, 3, 4)) + assert (arr[1, [0, 1]] == np.array([[12, 13, 14, 15], [16, 17, 18, 19]])).all() From 1a30563a2b6d67c74357e14b091c144ab6befe46 Mon Sep 17 00:00:00 2001 From: Lachlan Deakin Date: Wed, 12 Feb 2025 09:49:03 +1100 Subject: [PATCH 5/7] fix: ArrayOfIntOrBool typing --- src/zarr/core/indexing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index 733b2464ac..c13a5fd716 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -34,7 +34,7 @@ from zarr.core.common import ChunkCoords IntSequence = list[int] | npt.NDArray[np.intp] -ArrayOfIntOrBool = npt.NDArray[np.intp] | npt.NDArray[np.bool_] +ArrayOfIntOrBool = IntSequence | npt.NDArray[np.bool_] BasicSelector = int | slice | EllipsisType Selector = BasicSelector | ArrayOfIntOrBool BasicSelection = BasicSelector | tuple[BasicSelector, ...] # also used for BlockIndex From ebb6c6dbd6d404333eb9615bb1cf1a35c96d58fc Mon Sep 17 00:00:00 2001 From: Lachlan Deakin Date: Wed, 12 Feb 2025 10:00:07 +1100 Subject: [PATCH 6/7] Revert "fix: ArrayOfIntOrBool typing" This reverts commit 1a30563a2b6d67c74357e14b091c144ab6befe46. --- src/zarr/core/indexing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/zarr/core/indexing.py b/src/zarr/core/indexing.py index c13a5fd716..733b2464ac 100644 --- a/src/zarr/core/indexing.py +++ b/src/zarr/core/indexing.py @@ -34,7 +34,7 @@ from zarr.core.common import ChunkCoords IntSequence = list[int] | npt.NDArray[np.intp] -ArrayOfIntOrBool = IntSequence | npt.NDArray[np.bool_] +ArrayOfIntOrBool = npt.NDArray[np.intp] | npt.NDArray[np.bool_] BasicSelector = int | slice | EllipsisType Selector = BasicSelector | ArrayOfIntOrBool BasicSelection = BasicSelector | tuple[BasicSelector, ...] # also used for BlockIndex From 40f337f12f8673f1eef951df08410ca05084f117 Mon Sep 17 00:00:00 2001 From: Lachlan Deakin Date: Wed, 12 Feb 2025 10:01:30 +1100 Subject: [PATCH 7/7] ignore typing error in test --- tests/test_array.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_array.py b/tests/test_array.py index 842e071612..a9e9d1232d 100644 --- a/tests/test_array.py +++ b/tests/test_array.py @@ -1434,4 +1434,4 @@ async def test_sharding_coordinate_selection() -> None: shards=(2, 4, 4), ) arr[:] = np.arange(2 * 3 * 4).reshape((2, 3, 4)) - assert (arr[1, [0, 1]] == np.array([[12, 13, 14, 15], [16, 17, 18, 19]])).all() + assert (arr[1, [0, 1]] == np.array([[12, 13, 14, 15], [16, 17, 18, 19]])).all() # type: ignore[index]