Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade ruff and enable a few more fixes #54

Merged
merged 1 commit into from
Jan 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repos:
- id: trailing-whitespace
exclude: \.dis\.grb$
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.8.6
rev: v0.9.3
hooks:
# Run the linter.
- id: ruff
Expand Down
14 changes: 10 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,17 @@ filterwarnings = [

[tool.ruff.lint]
select = [
"E", # pycodestyle
"F", # Pyflakes
"UP", # pyupgrade
"I", # isort
"E", # pycodestyle
"F", # Pyflakes
"I", # isort
"NPY", # NumPy-specific
"RET", # flake8-return
"RUF", # Ruff-specific rules
"UP", # pyupgrade
]
ignore = [
"RUF005", # Consider iterable unpacking instead of concatenation
"RUF012", # Mutable class attributes should be annotated with ...
]

[tool.ruff.lint.per-file-ignores]
Expand Down
4 changes: 2 additions & 2 deletions src/gridit/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def cli_main():
--array-from-raster {Path("tests/data/Mana.tif")} {cl}
--write-raster {tmpdir / "Mana_100m.tif"} {cl}
--write-creation-option COMPRESS=deflate
""" # noqa
"""
waitaku2 = Path("tests/data/waitaku2")
if has_netcdf4:
examples += f"""\
Expand All @@ -76,7 +76,7 @@ def cli_main():
--array-from-netcdf {waitaku2}.nc:rid:myvar:0 {cl}
--time-stats "quantile(0.75),max" {cl}
--write-text {tmpdir / "waitaku2_cat.txt"}
""" # noqa
"""
if has_flopy:
examples += f"""\

Expand Down
7 changes: 3 additions & 4 deletions src/gridit/array_from.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ def array_from_array(self, grid, array, resampling=None):

if not isinstance(grid, Grid):
raise TypeError(f"expected grid to be a Grid; found {type(grid)!r}")
elif not (hasattr(array, "ndim") and hasattr(array, "shape")):
if not (hasattr(array, "ndim") and hasattr(array, "shape")):
raise TypeError(f"expected array to be array_like; found {type(array)!r}")
elif not (array.ndim in (2, 3) and array.shape[-2:] == grid.shape):
if not (array.ndim in (2, 3) and array.shape[-2:] == grid.shape):
raise ValueError("array has different shape than grid")

rel_res_diff = abs((grid.resolution - self.resolution) / self.resolution) * 100
Expand Down Expand Up @@ -267,8 +267,7 @@ def mask_from_raster(self, fname: str, bidx: int = 1):
# return ar.mask
if ar.mask.shape:
return ar.mask
else:
return np.full(ar.shape, ar.mask)
return np.full(ar.shape, ar.mask)


def array_from_vector(
Expand Down
4 changes: 2 additions & 2 deletions src/gridit/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,9 +255,9 @@ def cell_geodataframe(self, *, values=None, mask=None, order="C"):
for name, array in values.items():
if not isinstance(name, str):
raise ValueError("key for values must be str")
elif getattr(array, "shape", None) != self.shape:
if getattr(array, "shape", None) != self.shape:
raise ValueError(
f"array {name!r} in values must have the same shape " "as the grid"
f"array {name!r} in values must have the same shape as the grid"
)
gdf[name] = array.ravel(order=order)[sel]
return gdf
4 changes: 2 additions & 2 deletions src/gridit/classmethods.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def get_shape_top_left(
buffer_items = [buffer] * 4
if not (minx <= maxx):
raise ValueError("'minx' must be less than 'maxx'")
elif not (miny <= maxy):
if not (miny <= maxy):
raise ValueError("'miny' must be less than 'maxy'")
elif resolution <= 0:
if resolution <= 0:
raise ValueError("'resolution' must be greater than zero")
if buffer:
minx -= buffer_items[0]
Expand Down
5 changes: 2 additions & 3 deletions src/gridit/display.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Display utilities module."""

__all__ = ["shorten", "print_array"]
__all__ = ["print_array", "shorten"]

import numpy as np

Expand All @@ -12,8 +12,7 @@ def shorten(text, width):
text = text.strip()
if len(text) < width:
return text
else:
return text[: (width - 5)] + "[...]"
return text[: (width - 5)] + "[...]"


def print_array(ar, logger=None):
Expand Down
23 changes: 10 additions & 13 deletions src/gridit/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def write_raster(grid, array, fname, driver=None, **kwargs):
raise ModuleNotFoundError("array_from_vector requires rasterio")
if array.ndim != 2:
raise ValueError("array must have two-dimensions")
elif array.shape != grid.shape:
if array.shape != grid.shape:
raise ValueError("array must have same shape " + str(grid.shape))
grid.logger.info("writing raster file: %s", fname)
if driver is None:
Expand Down Expand Up @@ -114,22 +114,20 @@ def fiona_property_type(ar):
frac = np.modf(ar)[0]
if (frac == 0.0).all():
return f"float:{precision}"
elif (has_dec := np.char.count(ar_str, ".") > 0).any():
if (has_dec := np.char.count(ar_str, ".") > 0).any():
ndc = ar_len[has_dec] - np.char.index(ar_str[has_dec], ".")
scale = ndc.max() - 1
return f"float:{precision}.{scale}"
else:
return "float"
elif np.issubdtype(ar.dtype, np.integer):
return "float"
if np.issubdtype(ar.dtype, np.integer):
scale = max(len(str(ar.min())), len(str(ar.max())))
return f"int:{scale}"
elif np.issubdtype(ar.dtype, np.bool_):
if np.issubdtype(ar.dtype, np.bool_):
return "int:1"
elif np.issubdtype(ar.dtype, np.str_) or np.issubdtype(ar.dtype, np.bytes_):
if np.issubdtype(ar.dtype, np.str_) or np.issubdtype(ar.dtype, np.bytes_):
scale = np.char.str_len(ar).max()
return f"str:{scale}"
else:
return "str"
return "str"


def write_vector(grid, array, fname, attribute, layer=None, driver=None, **kwargs):
Expand Down Expand Up @@ -179,9 +177,8 @@ def write_vector(grid, array, fname, attribute, layer=None, driver=None, **kwarg
if not isinstance(attribute, list) or len(attribute) != array.shape[0]:
if array.shape[0] == 1:
raise ValueError("attribute must be a str or a 1 item str list")
else:
raise ValueError(f"attribute must list of str with length {array.shape[0]}")
elif array.shape[-2:] != grid.shape:
raise ValueError(f"attribute must be a list of str with length {array.shape[0]}")
if array.shape[-2:] != grid.shape:
raise ValueError(f"last two dimensions of array shape must be {grid.shape}")
grid.logger.info("writing vector file: %s with layer: %s", fname, layer)
if driver is None:
Expand Down Expand Up @@ -285,7 +282,7 @@ def fiona_filter_collection(ds, filter):
raise ModuleNotFoundError("fiona_filter_collection requires fiona")
if not isinstance(ds, fiona.Collection):
raise ValueError(f"ds must be fiona.Collection; found {type(ds)}")
elif ds.closed:
if ds.closed:
raise ValueError("ds is closed")
flt = fiona.io.MemoryFile().open(driver=ds.driver, schema=ds.schema, crs=ds.crs)
if isinstance(filter, dict):
Expand Down
42 changes: 21 additions & 21 deletions src/gridit/gridpolyconv.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,15 @@
poly_idx = tuple(poly_idx)
except Exception:
raise ValueError(
"poly_idx must be a tuple or list-like; "
f"found {type(poly_idx)!r}"
f"poly_idx must be a tuple or list-like; found {type(poly_idx)!r}"
)
if len(poly_idx) != len(set(poly_idx)):
raise ValueError("poly_idx values are not unique")
self.poly_idx = poly_idx
# self.idx_d = dict(enumerate(poly_idx, 1))
if not isinstance(idx_ar, np.ndarray):
raise ValueError(f"idx_ar must be a numpy array; found {type(idx_ar)!r}")
elif not np.issubdtype(idx_ar.dtype, np.integer):
if not np.issubdtype(idx_ar.dtype, np.integer):
raise ValueError(f"idx_ar dtype must integer-based; found {idx_ar.dtype!r}")
self.idx_ar = idx_ar.copy()
self.idx_ar.flags.writeable = False
Expand All @@ -101,17 +100,17 @@
elif idx_ar.ndim == 3:
if ar_count is None:
raise ValueError("ar_count must be specified if idx_ar is 3D")
elif not isinstance(ar_count, np.ndarray):
if not isinstance(ar_count, np.ndarray):
raise ValueError(
"ar_count must be a numpy array; " f"found {type(ar_count)!r}"
f"ar_count must be a numpy array; found {type(ar_count)!r}"
)
elif not np.issubdtype(ar_count.dtype, np.integer):
if not np.issubdtype(ar_count.dtype, np.integer):
raise ValueError(
"ar_count dtype must integer-based; " f"found {ar_count.dtype!r}"
f"ar_count dtype must integer-based; found {ar_count.dtype!r}"
)
elif ar_count.shape != idx_ar.shape:
if ar_count.shape != idx_ar.shape:
raise ValueError(
"ar_count shape must match idx_ar; " f"found {ar_count.shape}"
f"ar_count shape must match idx_ar; found {ar_count.shape}"
)
self.ar_count = ar_count.copy()
self.ar_count.flags.writeable = False
Expand Down Expand Up @@ -224,13 +223,13 @@
raise ValueError("grid must be an instance of Grid")
if not isinstance(refine, int):
raise ValueError("refine must be int")
elif refine < 1:
if refine < 1:
raise ValueError("refine must be >= 1")
use_refine = refine > 1
if use_refine:
if not isinstance(max_levels, int):
raise ValueError("max_levels must be int")
elif max_levels < 1:
if max_levels < 1:
raise ValueError("max_levels must be >= 1")
if logger is None:
logger = get_logger(__package__)
Expand Down Expand Up @@ -263,14 +262,16 @@
]
if len(list_dir) == 0:
return None
elif fname in list_dir:
if fname in list_dir:
return dirname / fname
elif caching == 1:
if caching == 1:
prefix = fname[:9]
part_list_dir = [f for f in list_dir if f[:9] == prefix]
if part_list_dir:
# there might be more than one!
return dirname / part_list_dir[0]
return None
return None

if caching == 1:
args = (grid, poly_idx, refine)
Expand Down Expand Up @@ -491,8 +492,7 @@
def from_pickle(fname: str):
"""Unpickle object from a file."""
with open(fname, "rb") as f:
obj = pickle.load(f)
return obj
return pickle.load(f)

Check warning on line 495 in src/gridit/gridpolyconv.py

View check run for this annotation

Codecov / codecov/patch

src/gridit/gridpolyconv.py#L495

Added line #L495 was not covered by tests

def array_from_values(self, index, values, fill=0, enforce1d=False):
"""Generate 2D or 3D array from 1D or 2D values.
Expand Down Expand Up @@ -526,11 +526,11 @@
values = np.array(values)
if not hasattr(values, "ndim"):
raise ValueError("expected values be array-like")
elif values.ndim not in (1, 2):
if values.ndim not in (1, 2):
raise ValueError("expected values to have 1 or 2 dimensions")
elif len(index) != values.shape[-1]:
if len(index) != values.shape[-1]:
raise ValueError(
"length of last dimension of values " "does not match index length"
"length of last dimension of values does not match index length"
)
self.logger.info("reading array from values with shape %s", values.shape)
if enforce1d and values.ndim != 1:
Expand All @@ -540,7 +540,7 @@
poly_idx_l = list(self.poly_idx)
if index_s.isdisjoint(poly_idx_s):
raise ValueError("index is disjoint from poly_idx")
elif not index_s.issuperset(poly_idx_s):
if not index_s.issuperset(poly_idx_s):
raise ValueError("index is not a superset of poly_idx")
if index != poly_idx_l:
# subset and/or re-order values to match poly_idx
Expand Down Expand Up @@ -686,7 +686,7 @@
avail = list(ds.variables.keys())
if var_name not in avail:
raise AttributeError(f"cannot find '{var_name}' in variables: {avail}")
elif idx_name not in avail:
if idx_name not in avail:
raise AttributeError(f"cannot find '{idx_name}' in variables: {avail}")
if idx_name not in ds.coords:
new_coords = []
Expand Down Expand Up @@ -763,7 +763,7 @@
else:
month_sel = (month >= start_month) & (month <= end_month)
self.logger.info(
"performing statistics with a %d-month window, " "starting in %s",
"performing statistics with a %d-month window, starting in %s",
num_months,
calendar.month_name[start_month],
)
Expand Down
12 changes: 6 additions & 6 deletions src/gridit/modflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __post_init__(self):
delc = self.delc[0]
if not (self.delc == delc).all():
raise ValueError("model delc is not constant")
elif delr != delc:
if delr != delc:
raise ValueError("model delr and delc are different")

if self.rotation != 0:
Expand Down Expand Up @@ -165,21 +165,21 @@ def get_modflow_model(
stacklevel=2,
)
return SimpleNamespace(modelgrid=file_or_dir)
elif hasattr(file_or_dir, "modelgrid"): # this is a flopy-like object
if hasattr(file_or_dir, "modelgrid"): # this is a flopy-like object
warn(
"getting modflow model from a flopy-like object is deprecated; "
"use ModelGrid.from_modelgrid(obj.modelgrid) instead",
DeprecationWarning,
stacklevel=2,
)
return file_or_dir
elif not isinstance(file_or_dir, (str, PathLike)):
if not isinstance(file_or_dir, (str, PathLike)):
raise TypeError(f"expected str or PathLike object; found {type(file_or_dir)}")

pth = Path(file_or_dir).resolve()
if not pth.exists():
raise FileNotFoundError(f"cannot read path '{pth}'")
elif pth.suffixes[-2:] == [".dis", ".grb"]:
if pth.suffixes[-2:] == [".dis", ".grb"]:
# Binary grid file
if logger is not None:
logger.info("reading grid from a binary grid file: %s", pth)
Expand All @@ -192,7 +192,7 @@ def get_modflow_model(
logger.warning(msg, *args)
grb = flopy.mf6.utils.MfGrdFile(pth)
return SimpleNamespace(modelgrid=grb.modelgrid)
elif (pth.is_dir() and (pth / "mfsim.nam").is_file()) or pth.name == "mfsim.nam":
if (pth.is_dir() and (pth / "mfsim.nam").is_file()) or pth.name == "mfsim.nam":
# MODFLOW 6
sim_ws = str(pth) if pth.is_dir() else str(pth.parent)
if logger is not None:
Expand All @@ -219,7 +219,7 @@ def get_modflow_model(
model = sim.get_model(model_name)
model.tdis = sim.tdis # this is a bit of a hack
return model
elif pth.is_file(): # assume 'classic' MODFLOW file
if pth.is_file(): # assume 'classic' MOFLOW file
with catch_warnings():
filterwarnings("ignore", category=UserWarning)
try:
Expand Down
3 changes: 2 additions & 1 deletion src/gridit/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from gridit.logger import get_logger

__all__ = [
"is_same_crs",
"flat_grid_intersect",
"is_same_crs",
]


Expand All @@ -24,6 +24,7 @@ def is_same_crs(wkt1: str, wkt2: str) -> bool:
def epsg_code(wkt):
if match := re.fullmatch(r"EPSG:(\d+)", wkt, re.IGNORECASE):
return match.groups()[0]
return None

code1 = epsg_code(wkt1)
code2 = epsg_code(wkt2)
Expand Down
Loading