Skip to content

Commit

Permalink
3.10 linting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Aug 6, 2024
1 parent cdc3417 commit 38e6831
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 124 deletions.
11 changes: 5 additions & 6 deletions .github/pages/make_switcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,28 +3,27 @@
from argparse import ArgumentParser
from pathlib import Path
from subprocess import CalledProcessError, check_output
from typing import List, Optional


def report_output(stdout: bytes, label: str) -> List[str]:
def report_output(stdout: bytes, label: str) -> list[str]:
ret = stdout.decode().strip().split("\n")
print(f"{label}: {ret}")
return ret


def get_branch_contents(ref: str) -> List[str]:
def get_branch_contents(ref: str) -> list[str]:
"""Get the list of directories in a branch."""
stdout = check_output(["git", "ls-tree", "-d", "--name-only", ref])
return report_output(stdout, "Branch contents")


def get_sorted_tags_list() -> List[str]:
def get_sorted_tags_list() -> list[str]:
"""Get a list of sorted tags in descending order from the repository."""
stdout = check_output(["git", "tag", "-l", "--sort=-v:refname"])
return report_output(stdout, "Tags list")


def get_versions(ref: str, add: Optional[str]) -> List[str]:
def get_versions(ref: str, add: str | None) -> list[str]:
"""Generate the file containing the list of all GitHub Pages builds."""
# Get the directories (i.e. builds) from the GitHub Pages branch
try:
Expand All @@ -41,7 +40,7 @@ def get_versions(ref: str, add: Optional[str]) -> List[str]:
tags = get_sorted_tags_list()

# Make the sorted versions list from main branches and tags
versions: List[str] = []
versions: list[str] = []
for version in ["master", "main"] + tags:
if version in builds:
versions.append(version)
Expand Down
81 changes: 37 additions & 44 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,12 @@
from __future__ import annotations

from collections.abc import Callable, Iterable, Iterator, Sequence
from dataclasses import field
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
Iterator,
List,
Optional,
Sequence,
Type,
TypeVar,
Union,
)
from typing import Any, Generic, Literal, TypeVar

import numpy as np
from pydantic import BaseConfig, Extra, Field, ValidationError, create_model
from pydantic.error_wrappers import ErrorWrapper
from typing_extensions import Literal

__all__ = [
"if_instance_do",
Expand All @@ -43,11 +30,11 @@ class StrictConfig(BaseConfig):


def discriminated_union_of_subclasses(
super_cls: Optional[Type] = None,
super_cls: type | None = None,
*,
discriminator: str = "type",
config: Optional[Type[BaseConfig]] = None,
) -> Union[Type, Callable[[Type], Type]]:
config: type[BaseConfig] | None = None,
) -> type | Callable[[type], type]:
"""Add all subclasses of super_cls to a discriminated union.
For all subclasses of super_cls, add a discriminator field to identify
Expand Down Expand Up @@ -114,7 +101,7 @@ def calculate(self) -> int:
subclasses. Defaults to None.
Returns:
Union[Type, Callable[[Type], Type]]: A decorator that adds the necessary
Type | Callable[[Type], Type]: A decorator that adds the necessary
functionality to a class.
"""

Expand All @@ -130,12 +117,12 @@ def wrap(cls):


def _discriminated_union_of_subclasses(
super_cls: Type,
super_cls: type,
discriminator: str,
config: Optional[Type[BaseConfig]] = None,
) -> Union[Type, Callable[[Type], Type]]:
super_cls._ref_classes = set()
super_cls._model = None
config: type[BaseConfig] | None = None,
) -> type | Callable[[type], type]:
super_cls._ref_classes = set() # type: ignore
super_cls._model = None # type: ignore

def __init_subclass__(cls) -> None:
# Keep track of inherting classes in super class
Expand All @@ -157,7 +144,11 @@ def __validate__(cls, v: Any) -> Any:
# needs to be done once, after all subclasses have been
# declared
if cls._model is None:
root = Union[tuple(cls._ref_classes)] # type: ignore
refs: tuple[type] = tuple(cls._ref_classes) # type: ignore
root = refs[0]
if len(refs) > 1:
for ref in refs:
root |= ref
cls._model = create_model(
super_cls.__name__,
__root__=(root, Field(..., discriminator=discriminator)),
Expand Down Expand Up @@ -185,7 +176,7 @@ def __validate__(cls, v: Any) -> Any:
return super_cls


def if_instance_do(x: Any, cls: Type, func: Callable):
def if_instance_do(x: Any, cls: type, func: Callable):
"""If x is of type cls then return func(x), otherwise return NotImplemented.
Used as a helper when implementing operator overloading.
Expand All @@ -201,7 +192,7 @@ def if_instance_do(x: Any, cls: Type, func: Callable):

#: Map of axes to float ndarray of points
#: E.g. {xmotor: array([0, 1, 2]), ymotor: array([2, 2, 2])}
AxesPoints = Dict[Axis, np.ndarray]
AxesPoints = dict[Axis, np.ndarray]


class Frames(Generic[Axis]):
Expand Down Expand Up @@ -234,9 +225,9 @@ class Frames(Generic[Axis]):
def __init__(
self,
midpoints: AxesPoints[Axis],
lower: Optional[AxesPoints[Axis]] = None,
upper: Optional[AxesPoints[Axis]] = None,
gap: Optional[np.ndarray] = None,
lower: AxesPoints[Axis] | None = None,
upper: AxesPoints[Axis] | None = None,
gap: np.ndarray | None = None,
):
#: The midpoints of scan frames for each axis
self.midpoints = midpoints
Expand All @@ -253,7 +244,9 @@ def __init__(
# We have a gap if upper[i] != lower[i+1] for any axes
axes_gap = [
np.roll(upper, 1) != lower
for upper, lower in zip(self.upper.values(), self.lower.values())
for upper, lower in zip(
self.upper.values(), self.lower.values(), strict=False
)
]
self.gap = np.logical_or.reduce(axes_gap)
# Check all axes and ordering are the same
Expand All @@ -270,7 +263,7 @@ def __init__(
lengths.add(len(self.gap))
assert len(lengths) <= 1, f"Mismatching lengths {list(lengths)}"

def axes(self) -> List[Axis]:
def axes(self) -> list[Axis]:
"""The axes which will move during the scan.
These will be present in `midpoints`, `lower` and `upper`.
Expand Down Expand Up @@ -300,7 +293,7 @@ def extract_dict(ds: Iterable[AxesPoints[Axis]]) -> AxesPoints[Axis]:
return {k: v[dim_indices] for k, v in d.items()}
return {}

def extract_gap(gaps: Iterable[np.ndarray]) -> Optional[np.ndarray]:
def extract_gap(gaps: Iterable[np.ndarray]) -> np.ndarray | None:
for gap in gaps:
if not calculate_gap:
return gap[dim_indices]
Expand Down Expand Up @@ -371,7 +364,7 @@ def zip_gap(gaps: Sequence[np.ndarray]) -> np.ndarray:
def _merge_frames(
*stack: Frames[Axis],
dict_merge=Callable[[Sequence[AxesPoints[Axis]]], AxesPoints[Axis]], # type: ignore
gap_merge=Callable[[Sequence[np.ndarray]], Optional[np.ndarray]],
gap_merge=Callable[[Sequence[np.ndarray]], np.ndarray | None],
) -> Frames[Axis]:
types = {type(fs) for fs in stack}
assert len(types) == 1, f"Mismatching types for {stack}"
Expand All @@ -397,9 +390,9 @@ class SnakedFrames(Frames[Axis]):
def __init__(
self,
midpoints: AxesPoints[Axis],
lower: Optional[AxesPoints[Axis]] = None,
upper: Optional[AxesPoints[Axis]] = None,
gap: Optional[np.ndarray] = None,
lower: AxesPoints[Axis] | None = None,
upper: AxesPoints[Axis] | None = None,
gap: np.ndarray | None = None,
):
super().__init__(midpoints, lower=lower, upper=upper, gap=gap)
# Override first element of gap to be True, as subsequent runs
Expand Down Expand Up @@ -431,7 +424,7 @@ def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]:
length = len(self)
backwards = (indices // length) % 2
snake_indices = np.where(backwards, (length - 1) - indices, indices) % length
cls: Type[Frames[Any]]
cls: type[Frames[Any]]
if not calculate_gap:
cls = Frames
gap = self.gap[np.where(backwards, length - indices, indices) % length]
Expand Down Expand Up @@ -464,7 +457,7 @@ def gap_between_frames(frames1: Frames[Axis], frames2: Frames[Axis]) -> bool:
return any(frames1.upper[a][-1] != frames2.lower[a][0] for a in frames1.axes())


def squash_frames(stack: List[Frames[Axis]], check_path_changes=True) -> Frames[Axis]:
def squash_frames(stack: list[Frames[Axis]], check_path_changes=True) -> Frames[Axis]:
"""Squash a stack of nested Frames into a single one.
Args:
Expand Down Expand Up @@ -530,7 +523,7 @@ class Path(Generic[Axis]):
"""

def __init__(
self, stack: List[Frames[Axis]], start: int = 0, num: Optional[int] = None
self, stack: list[Frames[Axis]], start: int = 0, num: int | None = None
):
#: The Frames stack describing the scan, from slowest to fastest moving
self.stack = stack
Expand All @@ -544,7 +537,7 @@ def __init__(
if num is not None and start + num < self.end_index:
self.end_index = start + num

def consume(self, num: Optional[int] = None) -> Frames[Axis]:
def consume(self, num: int | None = None) -> Frames[Axis]:
"""Consume at most num frames from the Path and return as a Frames object.
>>> fx = SnakedFrames({"x": np.array([1, 2])})
Expand Down Expand Up @@ -619,12 +612,12 @@ class Midpoints(Generic[Axis]):
{'y': np.int64(4), 'x': np.int64(1)}
"""

def __init__(self, stack: List[Frames[Axis]]):
def __init__(self, stack: list[Frames[Axis]]):
#: The stack of Frames describing the scan, from slowest to fastest moving
self.stack = stack

@property
def axes(self) -> List[Axis]:
def axes(self) -> list[Axis]:
"""The axes that will be present in each points dictionary."""
axes = []
for frames in self.stack:
Expand All @@ -635,7 +628,7 @@ def __len__(self) -> int:
"""The number of dictionaries that will be produced if iterated over."""
return int(np.prod([len(frames) for frames in self.stack]))

def __iter__(self) -> Iterator[Dict[Axis, float]]:
def __iter__(self) -> Iterator[dict[Axis, float]]:
"""Yield {axis: midpoint} for each frame in the scan."""
path = Path(self.stack)
while len(path):
Expand Down
19 changes: 10 additions & 9 deletions src/scanspec/plot.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections.abc import Iterator
from itertools import cycle
from typing import Any, Dict, Iterator, List, Optional
from typing import Any

import numpy as np
from matplotlib import colors, patches
Expand All @@ -14,7 +15,7 @@
__all__ = ["plot_spec"]


def _plot_arrays(axes, arrays: List[np.ndarray], **kwargs):
def _plot_arrays(axes, arrays: list[np.ndarray], **kwargs):
if len(arrays) > 2:
axes.plot3D(arrays[2], arrays[1], arrays[0], **kwargs)
elif len(arrays) == 2:
Expand All @@ -38,7 +39,7 @@ def do_3d_projection(self, renderer=None):
return np.min(zs)


def _plot_arrow(axes, arrays: List[np.ndarray]):
def _plot_arrow(axes, arrays: list[np.ndarray]):
if len(arrays) == 1:
arrays = [np.array([0, 0])] + arrays
if len(arrays) == 2:
Expand All @@ -58,16 +59,16 @@ def _plot_arrow(axes, arrays: List[np.ndarray]):
axes.add_artist(a)


def _plot_spline(axes, ranges, arrays: List[np.ndarray], index_colours: Dict[int, str]):
scaled_arrays = [a / r for a, r in zip(arrays, ranges)]
def _plot_spline(axes, ranges, arrays: list[np.ndarray], index_colours: dict[int, str]):
scaled_arrays = [a / r for a, r in zip(arrays, ranges, strict=False)]
# Define curves parametrically
t = np.zeros(len(arrays[0]))
t[1:] = np.sqrt(sum((arr[1:] - arr[:-1]) ** 2 for arr in scaled_arrays))
t = np.cumsum(t)
if t[-1] > 0:
# Can't make a spline that starts and ends in the same place, so add a small
# delta
for s, r in zip(scaled_arrays, ranges):
for s, r in zip(scaled_arrays, ranges, strict=False):
if s[0] == s[-1]:
s += np.linspace(0, r * 1e-7, len(s))
# There are no duplicated points, plot a spline
Expand All @@ -76,16 +77,16 @@ def _plot_spline(axes, ranges, arrays: List[np.ndarray], index_colours: Dict[int
tck, _ = interpolate.splprep(scaled_arrays, k=2, s=0)
starts = sorted(index_colours)
stops = starts[1:] + [len(arrays[0]) - 1]
for start, stop in zip(starts, stops):
for start, stop in zip(starts, stops, strict=False):
tnew = np.linspace(t[start], t[stop], num=1001)
spline = interpolate.splev(tnew, tck)
# Scale the splines back to the original scaling
unscaled_splines = [a * r for a, r in zip(spline, ranges)]
unscaled_splines = [a * r for a, r in zip(spline, ranges, strict=False)]
_plot_arrays(axes, unscaled_splines, color=index_colours[start])
yield unscaled_splines


def plot_spec(spec: Spec[Any], title: Optional[str] = None):
def plot_spec(spec: Spec[Any], title: str | None = None):
"""Plot a spec, drawing the path taken through the scan.
Uses a different colour for each frame, grey for the turnarounds, and
Expand Down
Loading

0 comments on commit 38e6831

Please sign in to comment.