Skip to content

Commit

Permalink
Merge pull request #315 from tovrstra/ruff
Browse files Browse the repository at this point in the history
Ruff
  • Loading branch information
tovrstra authored May 29, 2024
2 parents 3235a15 + 47c165e commit 486e413
Show file tree
Hide file tree
Showing 79 changed files with 949 additions and 1,004 deletions.
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ repos:
hooks:
- id: remove-crlf
- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.28.3
rev: 0.28.4
hooks:
- id: check-github-workflows
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.4.5
hooks:
- id: ruff-format
- id: ruff
args: ["--fix", "--show-fixes"]
3 changes: 1 addition & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import os
import subprocess


# -- Fragile tricks for RTD -----------------------------------------------

# Normally sphinx-build should be called after iodata is installed somehow.
Expand Down Expand Up @@ -86,7 +85,7 @@

# General information about the project.
project = "IOData"
copyright = "2019, The IODATA Development Team"
copyright = "2019, The IODATA Development Team" # noqa: A001
author = "The IODATA Development Team"

# The version info for the project yo're documenting, acts as replacement for
Expand Down
7 changes: 3 additions & 4 deletions doc/gen_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from iodata.api import FORMAT_MODULES


__all__ = []


Expand All @@ -40,7 +39,7 @@


def _format_words(words):
return ", ".join("``{}``".format(word) for word in words)
return ", ".join(f"``{word}``" for word in words)


def _print_section(title, linechar):
Expand All @@ -65,7 +64,7 @@ def main():
# add labels for cross-referencing format (e.g. in formats table)
print(f".. _format_{modname}:")
print()
_print_section("{} (``{}``)".format(lines[0][:-1], modname), "=")
_print_section(f"{lines[0][:-1]} (``{modname}``)", "=")
print()
for line in lines[2:]:
print(line)
Expand All @@ -76,7 +75,7 @@ def main():
for fnname in FNNAMES:
fn = getattr(module, fnname, None)
if fn is not None:
_print_section(":py:func:`iodata.formats.{}.{}`".format(modname, fnname), "-")
_print_section(f":py:func:`iodata.formats.{modname}.{fnname}`", "-")
if fnname.startswith("load"):
print("- Always loads", _format_words(fn.guaranteed))
if fn.ifpresent:
Expand Down
3 changes: 1 addition & 2 deletions doc/gen_formats_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
# pylint: disable=unused-argument,redefined-builtin
"""Generate formats.rst."""

from collections import defaultdict
import inspect
from collections import defaultdict

import iodata


__all__ = []


Expand Down
8 changes: 4 additions & 4 deletions doc/gen_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"""Generate formats.rst."""

from gen_formats import _format_words, _print_section
from iodata.api import INPUT_MODULES

from iodata.api import INPUT_MODULES

__all__ = []

Expand Down Expand Up @@ -57,12 +57,12 @@ def main():
# add labels for cross-referencing format (e.g. in formats table)
print(f".. _input_{modname}:")
print()
_print_section("{} (``{}``)".format(lines[0][:-1], modname), "=")
_print_section(f"{lines[0][:-1]} (``{modname}``)", "=")
print()
for line in lines[2:]:
print(line)

_print_section(":py:func:`iodata.formats.{}.write_input`".format(modname), "-")
_print_section(f":py:func:`iodata.formats.{modname}.write_input`", "-")
fn = getattr(module, "write_input", None)
print("- Requires", _format_words(fn.required))
if fn.optional:
Expand All @@ -75,7 +75,7 @@ def main():
print()
template = getattr(module, "default_template", None)
if template:
code_block_lines = [" " + l for l in template.split("\n")]
code_block_lines = [" " + ell for ell in template.split("\n")]
print(TEMPLATE.format(code_block_lines="\n".join(code_block_lines)))
print()
print()
Expand Down
7 changes: 5 additions & 2 deletions iodata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
"""Input and Output Module."""

try:
from ._version import __version__
from ._version import __version__, __version_tuple__
except ImportError:
__version__ = "0.0.0.post0"
__version_tuple__ = (0, 0, 0, "a-dev")


from .api import dump_many, dump_one, load_many, load_one, write_input
from .iodata import IOData
from .api import *

__all__ = ("IOData", "load_one", "load_many", "dump_one", "dump_many", "write_input")
5 changes: 3 additions & 2 deletions iodata/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
"""CLI for file conversion."""

import argparse

import numpy as np

from .api import load_one, dump_one, load_many, dump_many, FORMAT_MODULES
from .api import FORMAT_MODULES, dump_many, dump_one, load_many, load_one

try:
from iodata.version import __version__
Expand Down Expand Up @@ -74,7 +75,7 @@ def parse_args():
"-V",
"--version",
action="version",
version="%(prog)s (IOData version {})".format(__version__),
version=f"%(prog)s (IOData version {__version__})",
)
parser.add_argument(
"-i", "--infmt", help="Select the input format, overrides automatic detection."
Expand Down
39 changes: 19 additions & 20 deletions iodata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
"""Functions to be used by end users."""

import os
from typing import Iterator
from types import ModuleType
from collections.abc import Iterator
from fnmatch import fnmatch
from pkgutil import iter_modules
from importlib import import_module
from pkgutil import iter_modules
from types import ModuleType
from typing import Optional

from .iodata import IOData
from .utils import LineIterator


__all__ = ["load_one", "load_many", "dump_one", "dump_many", "write_input"]


Expand All @@ -46,7 +46,7 @@ def _find_format_modules():
FORMAT_MODULES = _find_format_modules()


def _select_format_module(filename: str, attrname: str, fmt: str = None) -> ModuleType:
def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = None) -> ModuleType:
"""Find a file format module with the requested attribute name.
Parameters
Expand All @@ -68,14 +68,13 @@ def _select_format_module(filename: str, attrname: str, fmt: str = None) -> Modu
basename = os.path.basename(filename)
if fmt is None:
for format_module in FORMAT_MODULES.values():
if any(fnmatch(basename, pattern) for pattern in format_module.PATTERNS):
if hasattr(format_module, attrname):
return format_module
if any(fnmatch(basename, pattern) for pattern in format_module.PATTERNS) and hasattr(
format_module, attrname
):
return format_module
else:
return FORMAT_MODULES[fmt]
raise ValueError(
"Could not find file format with feature {} for file {}".format(attrname, filename)
)
raise ValueError(f"Could not find file format with feature {attrname} for file {filename}")


def _find_input_modules():
Expand Down Expand Up @@ -113,7 +112,7 @@ def _select_input_module(fmt: str) -> ModuleType:
raise ValueError(f"Could not find input format {fmt}!")


def load_one(filename: str, fmt: str = None, **kwargs) -> IOData:
def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
"""Load data from a file.
This function uses the extension or prefix of the filename to determine the
Expand Down Expand Up @@ -145,7 +144,7 @@ def load_one(filename: str, fmt: str = None, **kwargs) -> IOData:
return iodata


def load_many(filename: str, fmt: str = None, **kwargs) -> Iterator[IOData]:
def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IOData]:
"""Load multiple IOData instances from a file.
This function uses the extension or prefix of the filename to determine the
Expand All @@ -170,14 +169,14 @@ def load_many(filename: str, fmt: str = None, **kwargs) -> Iterator[IOData]:
"""
format_module = _select_format_module(filename, "load_many", fmt)
lit = LineIterator(filename)
for data in format_module.load_many(lit, **kwargs):
try:
try:
for data in format_module.load_many(lit, **kwargs):
yield IOData(**data)
except StopIteration:
return
except StopIteration:
return


def dump_one(iodata: IOData, filename: str, fmt: str = None, **kwargs):
def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs):
"""Write data to a file.
This routine uses the extension or prefix of the filename to determine
Expand All @@ -202,7 +201,7 @@ def dump_one(iodata: IOData, filename: str, fmt: str = None, **kwargs):
format_module.dump_one(f, iodata, **kwargs)


def dump_many(iodatas: Iterator[IOData], filename: str, fmt: str = None, **kwargs):
def dump_many(iodatas: Iterator[IOData], filename: str, fmt: Optional[str] = None, **kwargs):
"""Write multiple IOData instances to a file.
This routine uses the extension or prefix of the filename to determine
Expand All @@ -226,7 +225,7 @@ def dump_many(iodatas: Iterator[IOData], filename: str, fmt: str = None, **kwarg
format_module.dump_many(f, iodatas, **kwargs)


def write_input(iodata: IOData, filename: str, fmt: str, template: str = None, **kwargs):
def write_input(iodata: IOData, filename: str, fmt: str, template: Optional[str] = None, **kwargs):
"""Write input file using an instance of IOData for the specified software format.
Parameters
Expand Down
18 changes: 5 additions & 13 deletions iodata/attrutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import numpy as np


__all__ = ["convert_array_to", "validate_shape"]


Expand All @@ -35,7 +34,6 @@ def converter(array):
return converter


# pylint: disable=too-many-branches
def validate_shape(*shape_requirements: tuple):
"""Return a validator for the shape of an array or the length of an iterable.
Expand Down Expand Up @@ -87,26 +85,21 @@ def validator(obj, attribute, value):
other_name, other_axis = item
other = getattr(obj, other_name)
if other is None:
raise TypeError("Other attribute '{}' is not set.".format(other_name))
raise TypeError(f"Other attribute '{other_name}' is not set.")
if other_axis == 0:
expected_shape.append(len(other))
else:
if other_axis >= other.ndim or other_axis < 0:
raise TypeError(
"Cannot get length along axis "
"{} of attribute {} with ndim {}.".format(
other_axis, other_name, other.ndim
)
f"{other_axis} of attribute {other_name} with ndim {other.ndim}."
)
expected_shape.append(other.shape[other_axis])
else:
raise ValueError(f"Cannot interpret item in shape_requirements: {item}")
expected_shape = tuple(expected_shape)
# Get the actual shape
if isinstance(value, np.ndarray):
observed_shape = value.shape
else:
observed_shape = (len(value),)
observed_shape = value.shape if isinstance(value, np.ndarray) else (len(value),)
# Compare
match = True
if len(expected_shape) != len(observed_shape):
Expand All @@ -121,9 +114,8 @@ def validator(obj, attribute, value):
# Raise TypeError if needed.
if not match:
raise TypeError(
"Expecting shape {} for attribute {}, got {}".format(
expected_shape, attribute.name, observed_shape
)
f"Expecting shape {expected_shape} for attribute {attribute.name}, "
f"got {observed_shape}"
)

return validator
Loading

0 comments on commit 486e413

Please sign in to comment.