From 855813c1553452e1d93ec006c5222fa7ac7def2e Mon Sep 17 00:00:00 2001 From: Tom Cobb Date: Fri, 13 Sep 2024 08:11:00 +0000 Subject: [PATCH] Add pydantic autodoc and pydocstyle --- .devcontainer/devcontainer.json | 3 +- docs/conf.py | 33 +++++---------- docs/explanations/why-stack-frames.rst | 13 +++--- pyproject.toml | 22 +++++++--- src/scanspec/__init__.py | 4 +- src/scanspec/__main__.py | 2 + src/scanspec/cli.py | 2 + src/scanspec/core.py | 16 ++++++-- src/scanspec/plot.py | 2 + src/scanspec/regions.py | 45 +++++++++++--------- src/scanspec/service.py | 14 ++++++- src/scanspec/specs.py | 57 +++++++++++++++----------- src/scanspec/sphinxext.py | 3 ++ tests/test_errors.py | 2 +- 14 files changed, 130 insertions(+), 88 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index d3d639a5..b8781ba4 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -13,7 +13,8 @@ "vscode": { // Set *default* container specific settings.json values on container create. "settings": { - "python.defaultInterpreterPath": "/venv/bin/python" + "python.defaultInterpreterPath": "/venv/bin/python", + "remote.autoForwardPorts": false }, // Add the IDs of extensions you want installed when the container is created. "extensions": [ diff --git a/docs/conf.py b/docs/conf.py index de976947..acfda3be 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -35,6 +35,8 @@ "sphinx.ext.autodoc", # and making summary tables at the top of API docs "sphinx.ext.autosummary", + # With an extension for pydantic models + "sphinxcontrib.autodoc_pydantic", # This can parse google style docstrings "sphinx.ext.napoleon", # For linking to external sphinx documentation @@ -69,22 +71,8 @@ # domain name if present. Example entries would be ('py:func', 'int') or # ('envvar', 'LD_LIBRARY_PATH'). nitpick_ignore = [ - ("py:func", "int"), - ("py:class", "Axis"), - ("py:class", "~Axis"), - ("py:class", "scanspec.core.Axis"), - ("py:class", "AxesPoints"), - ("py:class", "np.ndarray"), - ("py:class", "NoneType"), - ("py:class", "'str'"), - ("py:class", "'float'"), - ("py:class", "'int'"), - ("py:class", "'bool'"), - ("py:class", "'object'"), - ("py:class", "'id'"), - ("py:class", "typing_extensions.Literal"), - ("py:class", "pydantic.config.BaseConfig"), - ("py:class", "starlette.responses.JSONResponse"), + ("py:class", "scanspec.core.C"), + ("py:class", "pydantic.config.ConfigDict"), ] # Both the class’ and the __init__ method’s docstring are concatenated and @@ -94,15 +82,14 @@ # Order the members by the order they appear in the source code autodoc_member_order = "bysource" -# Don't inherit docstrings from baseclasses -autodoc_inherit_docstrings = False +# For autodoc we want to document some additional optional modules +scanspec.__all__ += ["plot"] -# Insert inheritance links -autodoc_default_options = {"show-inheritance": True} +# Don't show config summary as it's not relevant +autodoc_pydantic_model_show_config_summary = False -# A dictionary for users defined type aliases that maps a type name to the -# full-qualified object name. -autodoc_type_aliases = {"AxesPoints": "scanspec.core.AxesPoints"} +# Show the fields in source order +autodoc_pydantic_model_summary_list_order = "bysource" # Include source in plot directive by default plot_include_source = True diff --git a/docs/explanations/why-stack-frames.rst b/docs/explanations/why-stack-frames.rst index 48ef7930..815c93b9 100644 --- a/docs/explanations/why-stack-frames.rst +++ b/docs/explanations/why-stack-frames.rst @@ -1,12 +1,12 @@ Why create a stack of Frames? ============================= -If a `Spec` tells you the parameters of a scan, `Frames` gives you the `Points` -that will let you actually execute the scan. A stack of Frames is interpreted as -nested from slowest moving to fastest moving, so each faster Frames object will -iterate once per position of the slower Frames object. When fly-scanning the -axis will traverse lower-midpoint-upper on the fastest Frames object for each -point in the scan. +If a `Spec` tells you the parameters of a scan, `Frames` gives you the `Points +` that will let you actually execute the scan. A stack of Frames is +interpreted as nested from slowest moving to fastest moving, so each faster +Frames object will iterate once per position of the slower Frames object. When +fly-scanning the axis will traverse lower-midpoint-upper on the fastest Frames +object for each point in the scan. An Example ---------- @@ -63,4 +63,3 @@ which point it destroys the performance of the VDS. For this reason, it is advisable to `Squash` any snaking Specs with the first non-snaking axis above it so that the HDF Dimension will not be snaking. See `./why-squash-can-change-path` for some details on this. - diff --git a/pyproject.toml b/pyproject.toml index f62469ae..b0bf50e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,9 @@ dev = [ # https://github.com/pypa/pip/issues/10393 "scanspec[plotting]", "scanspec[service]", + "autodoc_pydantic @ git+https://github.com/coretl/autodoc_pydantic.git@0b95311d8d10fce67a9ecd5830330364e31fa49c", "copier", + "httpx", "myst-parser", "pipdeptree", "pre-commit", @@ -44,8 +46,6 @@ dev = [ "sphinxcontrib-openapi", "tox-direct", "types-mock", - "httpx", - "myst-parser", ] [project.scripts] @@ -115,6 +115,7 @@ line-length = 88 extend-select = [ "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 + "D", # pydocstyle - https://docs.astral.sh/ruff/rules/#pydocstyle-d "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w @@ -124,10 +125,19 @@ extend-select = [ ] ignore = [ "B008", # We use function calls in service arguments + "D105", # Don't document magic methods as they don't appear in sphinx autodoc pages + "D107", # We document the class, not the __init__ method ] +[tool.ruff.lint.pydocstyle] +convention = "google" + [tool.ruff.lint.per-file-ignores] -# By default, private member access is allowed in tests -# See https://github.com/DiamondLightSource/python-copier-template/issues/154 -# Remove this line to forbid private member access in tests -"tests/**/*" = ["SLF001"] + +"tests/**/*" = [ + # By default, private member access is allowed in tests + # See https://github.com/DiamondLightSource/python-copier-template/issues/154 + # Remove this line to forbid private member access in tests + "SLF001", + "D", # Don't check docstrings in tests +] diff --git a/src/scanspec/__init__.py b/src/scanspec/__init__.py index 4ab90134..5546758d 100644 --- a/src/scanspec/__init__.py +++ b/src/scanspec/__init__.py @@ -6,7 +6,7 @@ Version number as calculated by https://github.com/pypa/setuptools_scm """ -from . import regions, specs +from . import core, regions, specs from ._version import __version__ -__all__ = ["__version__", "specs", "regions"] +__all__ = ["__version__", "core", "specs", "regions"] diff --git a/src/scanspec/__main__.py b/src/scanspec/__main__.py index 633c3d88..a8da08d1 100644 --- a/src/scanspec/__main__.py +++ b/src/scanspec/__main__.py @@ -1,3 +1,5 @@ +"""Interface for ``python -m scanspec``.""" + from scanspec import cli if __name__ == "__main__": diff --git a/src/scanspec/cli.py b/src/scanspec/cli.py index 33f56520..82a57cce 100644 --- a/src/scanspec/cli.py +++ b/src/scanspec/cli.py @@ -1,3 +1,5 @@ +"""Interface for ``python -m scanspec``.""" + import logging import string diff --git a/src/scanspec/core.py b/src/scanspec/core.py index a0ea6169..a15b644c 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -1,3 +1,5 @@ +"""Core classes like `Frames` and `Path`.""" + from __future__ import annotations from collections.abc import Callable, Iterable, Iterator, Sequence @@ -25,11 +27,10 @@ "StrictConfig", ] - +#: Used to ensure pydantic dataclasses error if given extra arguments StrictConfig: ConfigDict = {"extra": "forbid"} C = TypeVar("C") -T = TypeVar("T", type, Callable) def discriminated_union_of_subclasses( @@ -44,8 +45,7 @@ def discriminated_union_of_subclasses( Subclasses that extend this class must be Pydantic dataclasses, and types that need their schema to be updated when a new type that extends super_cls is - created must be either Pydantic dataclasses or BaseModels, and must be decorated - with @uses_tagged_union. + created must be either Pydantic dataclasses or BaseModels. Example:: @@ -106,6 +106,7 @@ def calculate(self) -> int: Returns: Type: decorated superclass with handling for subclasses to be added to its discriminated union for deserialization + """ tagged_union = _TaggedUnion(super_cls, discriminator) _tagged_unions[super_cls] = tagged_union @@ -217,6 +218,7 @@ class Frames(Generic[Axis]): See Also: `technical-terms` + """ def __init__( @@ -282,6 +284,7 @@ def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]: >>> frames = Frames({"x": np.array([1, 2, 3])}) >>> frames.extract(np.array([1, 0, 1])).midpoints {'x': array([2, 1, 2])} + """ dim_indices = indices % len(self) @@ -312,6 +315,7 @@ def concat(self, other: Frames[Axis], gap: bool = False) -> Frames[Axis]: >>> frames2 = Frames({"y": np.array([3, 2, 1]), "x": np.array([4, 5, 6])}) >>> frames.concat(frames2).midpoints {'x': array([1, 2, 3, 4, 5, 6]), 'y': array([6, 5, 4, 3, 2, 1])} + """ assert set(self.axes()) == set( other.axes() @@ -411,6 +415,7 @@ def extract(self, indices: np.ndarray, calculate_gap=True) -> Frames[Axis]: >>> frames = SnakedFrames({"x": np.array([1, 2, 3])}) >>> frames.extract(np.array([0, 1, 2, 3, 4, 5])).midpoints {'x': array([1, 2, 3, 3, 2, 1])} + """ # Calculate the indices # E.g for len = 4 @@ -470,6 +475,7 @@ def squash_frames(stack: list[Frames[Axis]], check_path_changes=True) -> Frames[ >>> fy = Frames({"y": np.array([3, 4])}) >>> squash_frames([fy, fx]).midpoints {'y': array([3, 3, 4, 4]), 'x': array([1, 2, 2, 1])} + """ path = Path(stack) # Consuming a Path through these Frames performs the squash @@ -517,6 +523,7 @@ class Path(Generic[Axis]): See Also: `iterate-a-spec` + """ def __init__( @@ -607,6 +614,7 @@ class Midpoints(Generic[Axis]): {'y': np.int64(3), 'x': np.int64(2)} {'y': np.int64(4), 'x': np.int64(2)} {'y': np.int64(4), 'x': np.int64(1)} + """ def __init__(self, stack: list[Frames[Axis]]): diff --git a/src/scanspec/plot.py b/src/scanspec/plot.py index e4fc1e8c..f81c3447 100644 --- a/src/scanspec/plot.py +++ b/src/scanspec/plot.py @@ -1,3 +1,5 @@ +"""`plot_spec` to visualize a scan.""" + from collections.abc import Iterator from itertools import cycle from typing import Any diff --git a/src/scanspec/regions.py b/src/scanspec/regions.py index e60a3494..0f8e6872 100644 --- a/src/scanspec/regions.py +++ b/src/scanspec/regions.py @@ -1,3 +1,10 @@ +"""`Region` and its subclasses. + +.. inheritance-diagram:: scanspec.regions + :top-classes: scanspec.regions.Region + :parts: 1 +""" + from __future__ import annotations from collections.abc import Iterator, Mapping @@ -45,11 +52,11 @@ class Region(Generic[Axis]): - ``^``: `SymmetricDifferenceOf` two Regions, midpoints present in one not both """ - def axis_sets(self) -> list[set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: # noqa: D102 """Produce the non-overlapping sets of axes this region spans.""" raise NotImplementedError(self) - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 """Produce a mask of which points are in the region.""" raise NotImplementedError(self) @@ -111,7 +118,7 @@ class CombinationOf(Region[Axis]): left: Region[Axis] = Field(description="The left-hand Region to combine") right: Region[Axis] = Field(description="The right-hand Region to combine") - def axis_sets(self) -> list[set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: # noqa: D102 axis_sets = list( _merge_axis_sets(self.left.axis_sets() + self.right.axis_sets()) ) @@ -130,7 +137,7 @@ class UnionOf(CombinationOf[Axis]): array([False, True, True, True, False]) """ - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 mask = get_mask(self.left, points) | get_mask(self.right, points) return mask @@ -146,7 +153,7 @@ class IntersectionOf(CombinationOf[Axis]): array([False, False, True, False, False]) """ - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 mask = get_mask(self.left, points) & get_mask(self.right, points) return mask @@ -162,7 +169,7 @@ class DifferenceOf(CombinationOf[Axis]): array([False, True, False, False, False]) """ - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 left_mask = get_mask(self.left, points) # Return the xor restricted to the left region mask = left_mask ^ get_mask(self.right, points) & left_mask @@ -180,7 +187,7 @@ class SymmetricDifferenceOf(CombinationOf[Axis]): array([False, True, False, True, False]) """ - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 mask = get_mask(self.left, points) ^ get_mask(self.right, points) return mask @@ -198,10 +205,10 @@ class Range(Region[Axis]): min: float = Field(description="The minimum inclusive value in the region") max: float = Field(description="The minimum inclusive value in the region") - def axis_sets(self) -> list[set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.axis}] - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 v = points[self.axis] mask = np.bitwise_and(v >= self.min, v <= self.max) return mask @@ -230,10 +237,10 @@ class Rectangle(Region[Axis]): description="Clockwise rotation angle of the rectangle", default=0.0 ) - def axis_sets(self) -> list[set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.x_axis, self.y_axis}] - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 x = points[self.x_axis] - self.x_min y = points[self.y_axis] - self.y_min if self.angle != 0: @@ -270,10 +277,10 @@ class Polygon(Region[Axis]): description="The Nx1 y coordinates of the polygons vertices", min_length=3 ) - def axis_sets(self) -> list[set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.x_axis, self.y_axis}] - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 x = points[self.x_axis] y = points[self.y_axis] v1x, v1y = self.x_verts[-1], self.y_verts[-1] @@ -310,10 +317,10 @@ class Circle(Region[Axis]): y_middle: float = Field(description="The central y point of the circle") radius: float = Field(description="Radius of the circle", gt=0) - def axis_sets(self) -> list[set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.x_axis, self.y_axis}] - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 x = points[self.x_axis] - self.x_middle y = points[self.y_axis] - self.y_middle mask = x * x + y * y <= (self.radius * self.radius) @@ -345,10 +352,10 @@ class Ellipse(Region[Axis]): ) angle: float = Field(description="The angle of the ellipse (degrees)", default=0.0) - def axis_sets(self) -> list[set[Axis]]: + def axis_sets(self) -> list[set[Axis]]: # noqa: D102 return [{self.x_axis, self.y_axis}] - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: # noqa: D102 x = points[self.x_axis] - self.x_middle y = points[self.y_axis] - self.y_middle if self.angle != 0: @@ -362,7 +369,7 @@ def mask(self, points: AxesPoints[Axis]) -> np.ndarray: return mask -def find_regions(obj) -> Iterator[Region[Axis]]: +def find_regions(obj) -> Iterator[Region]: """Recursively yield Regions from obj and its children.""" if ( hasattr(obj, "__pydantic_model__") @@ -372,5 +379,5 @@ def find_regions(obj) -> Iterator[Region[Axis]]: if isinstance(obj, Region): yield obj for name in obj.__dict__.keys(): - regions: Iterator[Region[Axis]] = find_regions(getattr(obj, name)) + regions: Iterator[Region] = find_regions(getattr(obj, name)) yield from regions diff --git a/src/scanspec/service.py b/src/scanspec/service.py index 52121833..ce628de5 100644 --- a/src/scanspec/service.py +++ b/src/scanspec/service.py @@ -1,3 +1,5 @@ +"""FastAPI service to query information about Specs.""" + import base64 import json from collections.abc import Mapping @@ -133,6 +135,7 @@ def valid( Returns: ValidResponse: A canonical version of the spec if it is valid. An error otherwise. + """ valid_spec = Spec.deserialize(spec.serialize()) return ValidResponse(spec, valid_spec) @@ -156,6 +159,7 @@ def midpoints( Returns: MidpointsResponse: Midpoints of the scan + """ chunk, total_frames = _to_chunk(request) return MidpointsResponse( @@ -182,6 +186,7 @@ def bounds( Returns: BoundsResponse: Bounds of the scan + """ chunk, total_frames = _to_chunk(request) return BoundsResponse( @@ -207,10 +212,11 @@ def gap( after each frame. Args: - request: Scanspec and formatting info. + spec: Scanspec and formatting info. Returns: GapResponse: Bounds of the scan + """ dims = spec.calculate() # Grab dimensions from spec path = Path(dims) # Convert to a path @@ -231,6 +237,7 @@ def smallest_step( Returns: SmallestStepResponse: A description of the smallest steps in the spec + """ dims = spec.calculate() # Grab dimensions from spec path = Path(dims) # Convert to a path @@ -281,6 +288,7 @@ def _format_axes_points( Returns: Mapping[str, Points]: A mapping of axis to formatted points. + """ if format is PointsFormat.FLOAT_LIST: return {axis: list(points) for axis, points in axes_points.items()} @@ -301,6 +309,7 @@ def _reduce_frames(stack: list[Frames[str]], max_frames: int) -> Path: Args: stack: A stack of Frames created by a spec max_frames: The maximum number of frames the user wishes to be returned + """ # Calculate the total number of frames num_frames = 1 @@ -320,6 +329,7 @@ def _sub_sample(frames: Frames[str], ratio: float) -> Frames: Args: frames: the Frames object to be reduced ratio: the reduction ratio of the dimension + """ num_indexes = int(len(frames) / ratio) indexes = np.linspace(0, len(frames) - 1, num_indexes, dtype=np.int32) @@ -344,6 +354,7 @@ def _abs_diffs(array: np.ndarray) -> np.ndarray: Returns: A newly constucted array of absolute differences + """ # [array[1] - array[0], array[2] - array[1], ...] adjacent_diffs = array[1:] - array[:-1] @@ -371,6 +382,7 @@ def scanspec_schema_text() -> str: Returns: str: The OpenAPI schema + """ return json.dumps( get_openapi( diff --git a/src/scanspec/specs.py b/src/scanspec/specs.py index a9e4d648..adfec1a7 100644 --- a/src/scanspec/specs.py +++ b/src/scanspec/specs.py @@ -1,3 +1,10 @@ +"""`Spec` and its subclasses. + +.. inheritance-diagram:: scanspec.specs + :top-classes: scanspec.specs.Spec + :parts: 1 +""" + from __future__ import annotations from collections.abc import Callable, Mapping @@ -58,14 +65,14 @@ class Spec(Generic[Axis]): - ``~``: `Snake` the Spec, reversing every other iteration of it """ - def axes(self) -> list[Axis]: + def axes(self) -> list[Axis]: # noqa: D102 """Return the list of axes that are present in the scan. Ordered from slowest moving to fastest moving. """ raise NotImplementedError(self) - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 """Produce a stack of nested `Frames` that form the scan. Ordered from slowest moving to fastest moving. @@ -130,10 +137,10 @@ class Product(Spec[Axis]): outer: Spec[Axis] = Field(description="Will be executed once") inner: Spec[Axis] = Field(description="Will be executed len(outer) times") - def axes(self) -> list: + def axes(self) -> list[Axis]: # noqa: D102 return self.outer.axes() + self.inner.axes() - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 frames_outer = self.outer.calculate(bounds=False, nested=nested) frames_inner = self.inner.calculate(bounds, nested=True) return frames_outer + frames_inner @@ -169,10 +176,10 @@ class Repeat(Spec[Axis]): default=True, ) - def axes(self) -> list: + def axes(self) -> list[Axis]: # noqa: D102 return [] - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 return [Frames({}, gap=np.full(self.num, self.gap))] @@ -206,10 +213,10 @@ class Zip(Spec[Axis]): description="The right-hand Spec to Zip, will appear later in axes" ) - def axes(self) -> list: + def axes(self) -> list[Axis]: # noqa: D102 return self.left.axes() + self.right.axes() - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 frames_left = self.left.calculate(bounds, nested) frames_right = self.right.calculate(bounds, nested) assert len(frames_left) >= len( @@ -274,10 +281,10 @@ class Mask(Spec[Axis]): default=True, ) - def axes(self) -> list: + def axes(self) -> list[Axis]: # noqa: D102 return self.spec.axes() - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 frames = self.spec.calculate(bounds, nested) for axis_set in self.region.axis_sets(): # Find the start and end index of any dimensions containing these axes @@ -332,10 +339,10 @@ class Snake(Spec[Axis]): description="The Spec to run in reverse every other iteration" ) - def axes(self) -> list: + def axes(self) -> list[Axis]: # noqa: D102 return self.spec.axes() - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 return [ SnakedFrames.from_frames(segment) for segment in self.spec.calculate(bounds, nested) @@ -371,14 +378,14 @@ class Concat(Spec[Axis]): default=True, ) - def axes(self) -> list: + def axes(self) -> list[Axis]: # noqa: D102 left_axes, right_axes = self.left.axes(), self.right.axes() # Assuming the axes are the same, the order does not matter, we inherit the # order from the left-hand side. See also scanspec.core.concat. assert set(left_axes) == set(right_axes), f"axes {left_axes} != {right_axes}" return left_axes - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 dim_left = squash_frames( self.left.calculate(bounds, nested), nested and self.check_path_changes ) @@ -401,6 +408,7 @@ class Squash(Spec[Axis]): from scanspec.specs import Line, Squash spec = Squash(Line("y", 1, 2, 3) * Line("x", 0, 1, 4)) + """ spec: Spec[Axis] = Field(description="The Spec to squash the dimensions of") @@ -409,10 +417,10 @@ class Squash(Spec[Axis]): default=True, ) - def axes(self) -> list: + def axes(self) -> list[Axis]: # noqa: D102 return self.spec.axes() - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 dims = self.spec.calculate(bounds, nested) dim = squash_frames(dims, nested and self.check_path_changes) return [dim] @@ -461,7 +469,7 @@ class Line(Spec[Axis]): stop: float = Field(description="Midpoint of the last point of the line") num: int = Field(ge=1, description="Number of frames to produce") - def axes(self) -> list: + def axes(self) -> list[Axis]: # noqa: D102 return [self.axis] def _line_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: @@ -476,7 +484,7 @@ def _line_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: first = self.start - step / 2 return {self.axis: indexes * step + first} - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 return _dimensions_from_indexes( self._line_from_indexes, self.axes(), self.num, bounds ) @@ -547,13 +555,13 @@ def duration( """ return cls(DURATION, duration, num) - def axes(self) -> list: + def axes(self) -> list[Axis]: # noqa: D102 return [self.axis] def _repeats_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: return {self.axis: np.full(len(indexes), self.value)} - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 return _dimensions_from_indexes( self._repeats_from_indexes, self.axes(), self.num, bounds ) @@ -589,7 +597,7 @@ class Spiral(Spec[Axis]): description="How much to rotate the angle of the spiral", default=0.0 ) - def axes(self) -> list[Axis]: + def axes(self) -> list[Axis]: # noqa: D102 # TODO: reversed from __init__ args, a good idea? return [self.y_axis, self.x_axis] @@ -610,7 +618,7 @@ def _spiral_from_indexes(self, indexes: np.ndarray) -> dict[Axis, np.ndarray]: self.x_axis: self.x_start + x_scale * phi * np.sin(phi + self.rotate), } - def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: + def calculate(self, bounds=True, nested=False) -> list[Frames[Axis]]: # noqa: D102 return _dimensions_from_indexes( self._spiral_from_indexes, self.axes(), self.num, bounds ) @@ -662,6 +670,7 @@ def fly(spec: Spec[Axis], duration: float) -> Spec[Axis]: from scanspec.specs import Line, fly spec = fly(Line("x", 1, 2, 3), 0.1) + """ return spec.zip(Static.duration(duration)) @@ -680,13 +689,13 @@ def step(spec: Spec[Axis], duration: float, num: int = 1) -> Spec[Axis]: from scanspec.specs import Line, step spec = step(Line("x", 1, 2, 3), 0.1) + """ return spec * Static.duration(duration, num) def get_constant_duration(frames: list[Frames]) -> float | None: - """ - Returns the duration of a number of ScanSpec frames, if known and consistent. + """Returns the duration of a number of ScanSpec frames, if known and consistent. Args: frames (List[Frames]): A number of Frame objects diff --git a/src/scanspec/sphinxext.py b/src/scanspec/sphinxext.py index 6a1e2630..69500422 100644 --- a/src/scanspec/sphinxext.py +++ b/src/scanspec/sphinxext.py @@ -1,3 +1,5 @@ +"""An example_spec directive.""" + from contextlib import contextmanager from docutils.statemachine import StringList @@ -26,6 +28,7 @@ class ExampleSpecDirective(plot_directive.PlotDirective): """Runs `plot_spec` on the ``spec`` definied in the content.""" def run(self): + """Run the directive.""" self.content = StringList( ["# Example Spec", "", "from scanspec.plot import plot_spec"] + [str(x) for x in self.content] diff --git a/tests/test_errors.py b/tests/test_errors.py index a7ff923c..80144234 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -14,7 +14,7 @@ def test_not_implemented() -> None: with pytest.raises(NotImplementedError): Spec().calculate() with pytest.raises(TypeError): - Spec() * Region() + Spec() * Region() # type: ignore def test_non_snake_not_allowed_inside_snaking_dim() -> None: