Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ruff #315

Merged
merged 18 commits into from
May 29, 2024
Merged

Ruff #315

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ repos:
hooks:
- id: remove-crlf
- repo: https://github.com/python-jsonschema/check-jsonschema
rev: 0.28.3
rev: 0.28.4
hooks:
- id: check-github-workflows
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.4.4
rev: v0.4.5
hooks:
- id: ruff-format
- id: ruff
args: ["--fix", "--show-fixes"]
3 changes: 1 addition & 2 deletions doc/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import os
import subprocess


# -- Fragile tricks for RTD -----------------------------------------------

# Normally sphinx-build should be called after iodata is installed somehow.
Expand Down Expand Up @@ -86,7 +85,7 @@

# General information about the project.
project = "IOData"
copyright = "2019, The IODATA Development Team"
copyright = "2019, The IODATA Development Team" # noqa: A001
author = "The IODATA Development Team"

# The version info for the project yo're documenting, acts as replacement for
Expand Down
7 changes: 3 additions & 4 deletions doc/gen_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

from iodata.api import FORMAT_MODULES


__all__ = []


Expand All @@ -40,7 +39,7 @@


def _format_words(words):
return ", ".join("``{}``".format(word) for word in words)
return ", ".join(f"``{word}``" for word in words)


def _print_section(title, linechar):
Expand All @@ -65,7 +64,7 @@ def main():
# add labels for cross-referencing format (e.g. in formats table)
print(f".. _format_{modname}:")
print()
_print_section("{} (``{}``)".format(lines[0][:-1], modname), "=")
_print_section(f"{lines[0][:-1]} (``{modname}``)", "=")
print()
for line in lines[2:]:
print(line)
Expand All @@ -76,7 +75,7 @@ def main():
for fnname in FNNAMES:
fn = getattr(module, fnname, None)
if fn is not None:
_print_section(":py:func:`iodata.formats.{}.{}`".format(modname, fnname), "-")
_print_section(f":py:func:`iodata.formats.{modname}.{fnname}`", "-")
if fnname.startswith("load"):
print("- Always loads", _format_words(fn.guaranteed))
if fn.ifpresent:
Expand Down
3 changes: 1 addition & 2 deletions doc/gen_formats_tab.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,11 @@
# pylint: disable=unused-argument,redefined-builtin
"""Generate formats.rst."""

from collections import defaultdict
import inspect
from collections import defaultdict

import iodata


__all__ = []


Expand Down
8 changes: 4 additions & 4 deletions doc/gen_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"""Generate formats.rst."""

from gen_formats import _format_words, _print_section
from iodata.api import INPUT_MODULES

from iodata.api import INPUT_MODULES

__all__ = []

Expand Down Expand Up @@ -57,12 +57,12 @@ def main():
# add labels for cross-referencing format (e.g. in formats table)
print(f".. _input_{modname}:")
print()
_print_section("{} (``{}``)".format(lines[0][:-1], modname), "=")
_print_section(f"{lines[0][:-1]} (``{modname}``)", "=")
print()
for line in lines[2:]:
print(line)

_print_section(":py:func:`iodata.formats.{}.write_input`".format(modname), "-")
_print_section(f":py:func:`iodata.formats.{modname}.write_input`", "-")
fn = getattr(module, "write_input", None)
print("- Requires", _format_words(fn.required))
if fn.optional:
Expand All @@ -75,7 +75,7 @@ def main():
print()
template = getattr(module, "default_template", None)
if template:
code_block_lines = [" " + l for l in template.split("\n")]
code_block_lines = [" " + ell for ell in template.split("\n")]
print(TEMPLATE.format(code_block_lines="\n".join(code_block_lines)))
print()
print()
Expand Down
7 changes: 5 additions & 2 deletions iodata/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
"""Input and Output Module."""

try:
from ._version import __version__
from ._version import __version__, __version_tuple__
except ImportError:
__version__ = "0.0.0.post0"
__version_tuple__ = (0, 0, 0, "a-dev")


from .api import dump_many, dump_one, load_many, load_one, write_input
from .iodata import IOData
from .api import *

__all__ = ("IOData", "load_one", "load_many", "dump_one", "dump_many", "write_input")
5 changes: 3 additions & 2 deletions iodata/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
"""CLI for file conversion."""

import argparse

import numpy as np

from .api import load_one, dump_one, load_many, dump_many, FORMAT_MODULES
from .api import FORMAT_MODULES, dump_many, dump_one, load_many, load_one

try:
from iodata.version import __version__
Expand Down Expand Up @@ -74,7 +75,7 @@ def parse_args():
"-V",
"--version",
action="version",
version="%(prog)s (IOData version {})".format(__version__),
version=f"%(prog)s (IOData version {__version__})",
)
parser.add_argument(
"-i", "--infmt", help="Select the input format, overrides automatic detection."
Expand Down
39 changes: 19 additions & 20 deletions iodata/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
"""Functions to be used by end users."""

import os
from typing import Iterator
from types import ModuleType
from collections.abc import Iterator
from fnmatch import fnmatch
from pkgutil import iter_modules
from importlib import import_module
from pkgutil import iter_modules
from types import ModuleType
from typing import Optional

from .iodata import IOData
from .utils import LineIterator


__all__ = ["load_one", "load_many", "dump_one", "dump_many", "write_input"]


Expand All @@ -46,7 +46,7 @@ def _find_format_modules():
FORMAT_MODULES = _find_format_modules()


def _select_format_module(filename: str, attrname: str, fmt: str = None) -> ModuleType:
def _select_format_module(filename: str, attrname: str, fmt: Optional[str] = None) -> ModuleType:
"""Find a file format module with the requested attribute name.

Parameters
Expand All @@ -68,14 +68,13 @@ def _select_format_module(filename: str, attrname: str, fmt: str = None) -> Modu
basename = os.path.basename(filename)
if fmt is None:
for format_module in FORMAT_MODULES.values():
if any(fnmatch(basename, pattern) for pattern in format_module.PATTERNS):
if hasattr(format_module, attrname):
return format_module
if any(fnmatch(basename, pattern) for pattern in format_module.PATTERNS) and hasattr(
format_module, attrname
):
return format_module
else:
return FORMAT_MODULES[fmt]
raise ValueError(
"Could not find file format with feature {} for file {}".format(attrname, filename)
)
raise ValueError(f"Could not find file format with feature {attrname} for file {filename}")


def _find_input_modules():
Expand Down Expand Up @@ -113,7 +112,7 @@ def _select_input_module(fmt: str) -> ModuleType:
raise ValueError(f"Could not find input format {fmt}!")


def load_one(filename: str, fmt: str = None, **kwargs) -> IOData:
def load_one(filename: str, fmt: Optional[str] = None, **kwargs) -> IOData:
"""Load data from a file.

This function uses the extension or prefix of the filename to determine the
Expand Down Expand Up @@ -145,7 +144,7 @@ def load_one(filename: str, fmt: str = None, **kwargs) -> IOData:
return iodata


def load_many(filename: str, fmt: str = None, **kwargs) -> Iterator[IOData]:
def load_many(filename: str, fmt: Optional[str] = None, **kwargs) -> Iterator[IOData]:
"""Load multiple IOData instances from a file.

This function uses the extension or prefix of the filename to determine the
Expand All @@ -170,14 +169,14 @@ def load_many(filename: str, fmt: str = None, **kwargs) -> Iterator[IOData]:
"""
format_module = _select_format_module(filename, "load_many", fmt)
lit = LineIterator(filename)
for data in format_module.load_many(lit, **kwargs):
try:
try:
for data in format_module.load_many(lit, **kwargs):
yield IOData(**data)
except StopIteration:
return
except StopIteration:
return


def dump_one(iodata: IOData, filename: str, fmt: str = None, **kwargs):
def dump_one(iodata: IOData, filename: str, fmt: Optional[str] = None, **kwargs):
"""Write data to a file.

This routine uses the extension or prefix of the filename to determine
Expand All @@ -202,7 +201,7 @@ def dump_one(iodata: IOData, filename: str, fmt: str = None, **kwargs):
format_module.dump_one(f, iodata, **kwargs)


def dump_many(iodatas: Iterator[IOData], filename: str, fmt: str = None, **kwargs):
def dump_many(iodatas: Iterator[IOData], filename: str, fmt: Optional[str] = None, **kwargs):
"""Write multiple IOData instances to a file.

This routine uses the extension or prefix of the filename to determine
Expand All @@ -226,7 +225,7 @@ def dump_many(iodatas: Iterator[IOData], filename: str, fmt: str = None, **kwarg
format_module.dump_many(f, iodatas, **kwargs)


def write_input(iodata: IOData, filename: str, fmt: str, template: str = None, **kwargs):
def write_input(iodata: IOData, filename: str, fmt: str, template: Optional[str] = None, **kwargs):
"""Write input file using an instance of IOData for the specified software format.

Parameters
Expand Down
18 changes: 5 additions & 13 deletions iodata/attrutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

import numpy as np


__all__ = ["convert_array_to", "validate_shape"]


Expand All @@ -35,7 +34,6 @@ def converter(array):
return converter


# pylint: disable=too-many-branches
def validate_shape(*shape_requirements: tuple):
"""Return a validator for the shape of an array or the length of an iterable.

Expand Down Expand Up @@ -87,26 +85,21 @@ def validator(obj, attribute, value):
other_name, other_axis = item
other = getattr(obj, other_name)
if other is None:
raise TypeError("Other attribute '{}' is not set.".format(other_name))
raise TypeError(f"Other attribute '{other_name}' is not set.")
if other_axis == 0:
expected_shape.append(len(other))
else:
if other_axis >= other.ndim or other_axis < 0:
raise TypeError(
"Cannot get length along axis "
"{} of attribute {} with ndim {}.".format(
other_axis, other_name, other.ndim
)
f"{other_axis} of attribute {other_name} with ndim {other.ndim}."
)
expected_shape.append(other.shape[other_axis])
else:
raise ValueError(f"Cannot interpret item in shape_requirements: {item}")
expected_shape = tuple(expected_shape)
# Get the actual shape
if isinstance(value, np.ndarray):
observed_shape = value.shape
else:
observed_shape = (len(value),)
observed_shape = value.shape if isinstance(value, np.ndarray) else (len(value),)
# Compare
match = True
if len(expected_shape) != len(observed_shape):
Expand All @@ -121,9 +114,8 @@ def validator(obj, attribute, value):
# Raise TypeError if needed.
if not match:
raise TypeError(
"Expecting shape {} for attribute {}, got {}".format(
expected_shape, attribute.name, observed_shape
)
f"Expecting shape {expected_shape} for attribute {attribute.name}, "
f"got {observed_shape}"
)

return validator
Loading