From a829b3d7d7b1ba63a6dd60cca239fd249996afa6 Mon Sep 17 00:00:00 2001 From: David Stansby Date: Sun, 24 Dec 2023 16:25:34 +0000 Subject: [PATCH] Add some misc typing to zarr.util --- zarr/util.py | 50 ++++++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 24 deletions(-) diff --git a/zarr/util.py b/zarr/util.py index 54c389db69..fbdc600dc2 100644 --- a/zarr/util.py +++ b/zarr/util.py @@ -17,7 +17,9 @@ Union, Iterable, cast, + List, ) +from types import EllipsisType import numpy as np from asciitree import BoxStyle, LeftAligned @@ -54,7 +56,7 @@ def flatten(arg: Iterable) -> Iterable: class NumberEncoder(json.JSONEncoder): - def default(self, o): + def default(self, o: numbers.Number) -> float: # See json.JSONEncoder.default docstring for explanation # This is necessary to encode numpy dtype if isinstance(o, numbers.Integral): @@ -219,7 +221,7 @@ def normalize_dtype(dtype: Union[str, np.dtype], object_codec) -> Tuple[np.dtype # noinspection PyTypeChecker -def is_total_slice(item, shape: Tuple[int]) -> bool: +def is_total_slice(item: Union[EllipsisType, slice, tuple], shape: Tuple[int]) -> bool: """Determine whether `item` specifies a complete slice of array with the given `shape`. Used to optimize __setitem__ operations on the Chunk class.""" @@ -263,7 +265,7 @@ def normalize_resize_args(old_shape, *args): return new_shape -def human_readable_size(size) -> str: +def human_readable_size(size: float) -> str: if size < 2**10: return "%s" % size elif size < 2**20: @@ -391,7 +393,7 @@ def info_text_report(items: Dict[Any, Any]) -> str: return report -def info_html_report(items) -> str: +def info_html_report(items: dict) -> str: report = '' report += "" for k, v in items: @@ -420,25 +422,25 @@ def _repr_html_(self): class TreeNode: - def __init__(self, obj, depth=0, level=None): + def __init__(self, obj, depth: int = 0, level: Optional[int] = None): self.obj = obj self.depth = depth self.level = level - def get_children(self): + def get_children(self) -> List["TreeNode"]: if hasattr(self.obj, "values"): if self.level is None or self.depth < self.level: depth = self.depth + 1 return [TreeNode(o, depth=depth, level=self.level) for o in self.obj.values()] return [] - def get_text(self): + def get_text(self) -> str: name = self.obj.name.split("/")[-1] or "/" if hasattr(self.obj, "shape"): name += " {} {}".format(self.obj.shape, self.obj.dtype) return name - def get_type(self): + def get_type(self) -> str: return type(self.obj).__name__ @@ -466,7 +468,7 @@ def tree_get_icon(stype: str) -> str: raise ValueError("Unknown type: %s" % stype) -def tree_widget_sublist(node, root=False, expand=False): +def tree_widget_sublist(node, root: bool = False, expand: Union[bool, int] = False): import ipytree result = ipytree.Node() @@ -482,7 +484,7 @@ def tree_widget_sublist(node, root=False, expand=False): return result -def tree_widget(group, expand, level): +def tree_widget(group, expand: Union[bool, int], level: int): try: import ipytree except ImportError as error: @@ -501,7 +503,7 @@ def tree_widget(group, expand, level): class TreeViewer: - def __init__(self, group, expand=False, level=None): + def __init__(self, group, expand: Union[bool, int] = False, level: Optional[int] = None): self.group = group self.expand = expand self.level = level @@ -519,7 +521,7 @@ def __init__(self, group, expand=False, level=None): VERTICAL_AND_RIGHT="\u251C", ) - def __bytes__(self): + def __bytes__(self) -> bytes: drawer = LeftAligned( traverse=TreeTraversal(), draw=BoxStyle(gfx=self.bytes_kwargs, **self.text_kwargs) ) @@ -532,14 +534,14 @@ def __bytes__(self): return result - def __unicode__(self): + def __unicode__(self) -> str: drawer = LeftAligned( traverse=TreeTraversal(), draw=BoxStyle(gfx=self.unicode_kwargs, **self.text_kwargs) ) root = TreeNode(self.group, level=self.level) return drawer(root) - def __repr__(self): + def __repr__(self) -> str: return self.__unicode__() def _repr_mimebundle_(self, **kwargs): @@ -547,7 +549,7 @@ def _repr_mimebundle_(self, **kwargs): return tree._repr_mimebundle_(**kwargs) -def check_array_shape(param, array, shape): +def check_array_shape(param, array, shape) -> None: if not hasattr(array, "shape"): raise TypeError( "parameter {!r}: expected an array-like object, got {!r}".format(param, type(array)) @@ -560,7 +562,7 @@ def check_array_shape(param, array, shape): ) -def is_valid_python_name(name): +def is_valid_python_name(name: str) -> bool: from keyword import iskeyword return name.isidentifier() and not iskeyword(name) @@ -569,10 +571,10 @@ def is_valid_python_name(name): class NoLock: """A lock that doesn't lock.""" - def __enter__(self): + def __enter__(self) -> None: pass - def __exit__(self, *args): + def __exit__(self, *args: Any) -> None: pass @@ -600,7 +602,7 @@ def __init__(self, store_key, chunk_store): _key_path = "/".join(_key_path[:-1] + _chunk_path) self.key_path = _key_path - def prepare_chunk(self): + def prepare_chunk(self) -> None: assert self.buff is None header = self.fs.read_block(self.key_path, 0, 16) nbytes, self.cbytes, blocksize = cbuffer_sizes(header) @@ -620,7 +622,7 @@ def prepare_chunk(self): self.buff[16 : (16 + (self.nblocks * 4))] = start_points_buffer self.n_per_block = blocksize / typesize - def read_part(self, start, nitems): + def read_part(self, start, nitems) -> None: assert self.buff is not None if self.nblocks == 1: return @@ -654,7 +656,7 @@ def __init__(self, store_key, chunk_store, itemsize): self.store_key = store_key self.itemsize = itemsize - def prepare_chunk(self): + def prepare_chunk(self) -> None: pass def read_part(self, start, nitems): @@ -695,7 +697,7 @@ def retry_call( raise -def all_equal(value: Any, array: Any): +def all_equal(value: Any, array: Any) -> bool: """ Test if all the elements of an array are equivalent to a value. If `value` is None, then this function does not do any comparison and @@ -720,11 +722,11 @@ def all_equal(value: Any, array: Any): # Numpy errors if you call np.isnan on custom dtypes, so ensure # we are working with floats before calling isnan if np.issubdtype(array.dtype, np.floating) and np.isnan(value): - return np.all(np.isnan(array)) + return bool(np.all(np.isnan(array))) else: # using == raises warnings from numpy deprecated pattern, but # using np.equal() raises type errors for structured dtypes... - return np.all(value == array) + return bool(np.all(value == array)) def ensure_contiguous_ndarray_or_bytes(buf) -> Union[NDArrayLike, bytes]: