Skip to content

Commit

Permalink
Warn the user when shape or chunks contains float values (#2579)
Browse files Browse the repository at this point in the history
* Warn user when shape or chunks contains non-integer values like floats

* Test for non-integer warnings
  • Loading branch information
faymanns authored Dec 28, 2024
1 parent 2ab280a commit c969f5c
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
9 changes: 8 additions & 1 deletion zarr/tests/test_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os.path
import shutil
import warnings
import numbers

import numpy as np
import pytest
Expand Down Expand Up @@ -762,7 +763,13 @@ def test_create_with_storage_transformers(at_root):
)
def test_shape_chunk_ints(init_shape, init_chunks, shape, chunks):
g = open_group()
array = g.create_dataset("ds", shape=init_shape, chunks=init_chunks, dtype=np.uint8)
if not isinstance(init_shape[0], numbers.Integral) or not isinstance(
init_chunks[0], numbers.Integral
):
with pytest.warns(UserWarning):
array = g.create_dataset("ds", shape=init_shape, chunks=init_chunks, dtype=np.uint8)
else:
array = g.create_dataset("ds", shape=init_shape, chunks=init_chunks, dtype=np.uint8)

assert all(
isinstance(s, int) for s in array.shape
Expand Down
3 changes: 2 additions & 1 deletion zarr/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def test_normalize_shape():
with pytest.raises(TypeError):
normalize_shape(None)
with pytest.raises(ValueError):
normalize_shape("foo")
with pytest.warns(UserWarning):
normalize_shape("foo")


def test_normalize_chunks():
Expand Down
6 changes: 6 additions & 0 deletions zarr/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Iterable,
cast,
)
import warnings

import numpy as np
from asciitree import BoxStyle, LeftAligned
Expand Down Expand Up @@ -88,6 +89,8 @@ def normalize_shape(shape: Union[int, Tuple[int, ...], None]) -> Tuple[int, ...]

# normalize
shape = cast(Tuple[int, ...], shape)
if not all(isinstance(s, numbers.Integral) for s in shape):
warnings.warn("shape contains non-integer value(s)", UserWarning, stacklevel=2)
shape = tuple(int(s) for s in shape)
return shape

Expand Down Expand Up @@ -176,6 +179,9 @@ def normalize_chunks(chunks: Any, shape: Tuple[int, ...], typesize: int) -> Tupl
if -1 in chunks or None in chunks:
chunks = tuple(s if c == -1 or c is None else int(c) for s, c in zip(shape, chunks))

if not all(isinstance(c, numbers.Integral) for c in chunks):
warnings.warn("chunks contains non-integer value(s)", UserWarning, stacklevel=2)

chunks = tuple(int(c) for c in chunks)
return chunks

Expand Down

0 comments on commit c969f5c

Please sign in to comment.