From b140b53b01a1a72fe07e020b6090aaf0619531bf Mon Sep 17 00:00:00 2001 From: Zoheb Shaikh Date: Thu, 8 Aug 2024 13:33:38 +0100 Subject: [PATCH] initial commit --- src/scanspec/core.py | 123 +++++++++------------------------------- src/scanspec/regions.py | 52 ++++++++--------- src/scanspec/service.py | 7 +-- src/scanspec/specs.py | 10 ++-- tests/test_basemodel.py | 12 ++++ 5 files changed, 70 insertions(+), 134 deletions(-) create mode 100644 tests/test_basemodel.py diff --git a/src/scanspec/core.py b/src/scanspec/core.py index 39f2a57a..40646f7b 100644 --- a/src/scanspec/core.py +++ b/src/scanspec/core.py @@ -1,17 +1,12 @@ from __future__ import annotations -import dataclasses from collections.abc import Callable, Iterable, Iterator, Sequence from functools import partial -from inspect import isclass from typing import ( Any, Generic, Literal, TypeVar, - Union, - get_origin, - get_type_hints, ) import numpy as np @@ -21,8 +16,7 @@ GetCoreSchemaHandler, TypeAdapter, ) -from pydantic.dataclasses import rebuild_dataclass -from pydantic.fields import FieldInfo +from pydantic_core.core_schema import tagged_union_schema __all__ = [ "if_instance_do", @@ -107,18 +101,17 @@ def calculate(self) -> int: super_cls: The superclass of the union, Expression in the above example discriminator: The discriminator that will be inserted into the serialized documents for type determination. Defaults to "type". - config: A pydantic config class to be inserted into all - subclasses. Defaults to None. Returns: - Type | Callable[[Type], Type]: A decorator that adds the necessary + Type: A decorator that adds the necessary functionality to a class. """ tagged_union = _TaggedUnion(cls, discriminator) - _tagged_unions[cls] = tagged_union - cls.__init_subclass__ = classmethod(partial(__init_subclass__, discriminator)) + cls.__init_subclass__ = classmethod( + partial(_add_subclass_to_tagged_union, tagged_union, discriminator) + ) cls.__get_pydantic_core_schema__ = classmethod( - partial(__get_pydantic_core_schema__, tagged_union=tagged_union) + partial(_schema_of_tagged_union, tagged_union=tagged_union) ) return cls @@ -126,112 +119,52 @@ def calculate(self) -> int: T = TypeVar("T", type, Callable) -def deserialize_as(cls, obj): - return _tagged_unions[cls].type_adapter.validate_python(obj) - - -def uses_tagged_union(cls_or_func: T) -> T: - """ - Decorator that processes the type hints of a class or function to detect and - register any tagged unions. If a tagged union is detected in the type hints, - it registers the class or function as a referrer to that tagged union. - Args: - cls_or_func (T): The class or function to be processed for tagged unions. - Returns: - T: The original class or function, unmodified. - """ - for k, v in get_type_hints(cls_or_func).items(): - tagged_union = _tagged_unions.get(get_origin(v) or v, None) - if tagged_union: - tagged_union.add_referrer(cls_or_func, k) - return cls_or_func - - class _TaggedUnion: def __init__(self, base_class: type, discriminator: str): self._base_class = base_class # The members of the tagged union, i.e. subclasses of the baseclasses self._members: list[type] = [] # Classes and their field names that refer to this tagged union - self._referrers: dict[type | Callable, set[str]] = {} - self.type_adapter: TypeAdapter = TypeAdapter(None) + self.type_adapter: TypeAdapter | None = None self._discriminator = discriminator - def _make_union(self): - if len(self._members) > 0: - return Union[tuple(self._members)] # type: ignore # noqa - - def _set_discriminator(self, cls: type | Callable, field_name: str, field: Any): - # Set the field to use the `type` discriminator on deserialize - # https://docs.pydantic.dev/2.8/concepts/unions/#discriminated-unions-with-str-discriminators - if isclass(cls): - assert isinstance( - field, FieldInfo - ), f"Expected {cls.__name__}.{field_name} to be a Pydantic field, not {field!r}" # noqa: E501 - field.discriminator = self._discriminator - def add_member(self, cls: type): if cls in self._members: - # A side effect of hooking to __get_pydantic_core_schema__ is that it is - # called muliple times for the same member, do no process if it wouldn't - # change the member list return - self._members.append(cls) - union = self._make_union() - if union: - # There are more than 1 subclasses in the union, so set all the referrers - # to use this union - for referrer, fields in self._referrers.items(): - if isclass(referrer): - for field in dataclasses.fields(referrer): - if field.name in fields: - field.type = union - self._set_discriminator(referrer, field.name, field.default) - rebuild_dataclass(referrer, force=True) - # Make a type adapter for use in deserialization - self.type_adapter = TypeAdapter(union) - - def add_referrer(self, cls: type | Callable, attr_name: str): - self._referrers.setdefault(cls, set()).add(attr_name) - union = self._make_union() - if union: - # There are more than 1 subclasses in the union, so set the referrer - # (which is currently being constructed) to use it - # note that we use annotations as the class has not been turned into - # a dataclass yet - cls.__annotations__[attr_name] = union - self._set_discriminator(cls, attr_name, getattr(cls, attr_name, None)) - - -_tagged_unions: dict[type, _TaggedUnion] = {} - - -def __init_subclass__(discriminator: str, cls: type): + + def schema(self, handler): + return tagged_union_schema( + {member.__name__: handler(member) for member in self._members}, + self._discriminator, + ) + + +def _add_subclass_to_tagged_union( + tagged_union: _TaggedUnion, discriminator: str, cls: type +): # Add a discriminator field to the class so it can - # be identified when deserailizing, and make sure it is last in the list + # be identified when deserializing, and make sure it is last in the list cls.__annotations__ = { **cls.__annotations__, discriminator: Literal[cls.__name__], # type: ignore } - cls.type = Field(cls.__name__, repr=False) # type: ignore - # Replace any bare annotation with a discriminated union of subclasses - # and register this class as one that refers to that union so it can be updated - for k, v in get_type_hints(cls).items(): - # This works for Expression[T] or Expression - tagged_union = _tagged_unions.get(get_origin(v) or v, None) - if tagged_union: - tagged_union.add_referrer(cls, k) + setattr(cls, discriminator, Field(cls.__name__, repr=False)) # type: ignore + + def _return_handler_of_cls(cls, source_type: Any, handler: GetCoreSchemaHandler): + return handler(cls) + cls.__get_pydantic_core_schema__ = classmethod(_return_handler_of_cls) + tagged_union.add_member(cls) -def __get_pydantic_core_schema__( + +def _schema_of_tagged_union( cls, source_type: Any, handler: GetCoreSchemaHandler, tagged_union: _TaggedUnion ): # Rebuild any dataclass (including this one) that references this union # Note that this has to be done after the creation of the dataclass so that # previously created classes can refer to this newly created class - tagged_union.add_member(cls) - return handler(source_type) + return tagged_union.schema(handler) def if_instance_do(x: Any, cls: type, func: Callable): diff --git a/src/scanspec/regions.py b/src/scanspec/regions.py index 0c9b0b1f..bb965756 100644 --- a/src/scanspec/regions.py +++ b/src/scanspec/regions.py @@ -5,14 +5,13 @@ from typing import Any, Generic import numpy as np -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, TypeAdapter from pydantic.dataclasses import dataclass from .core import ( AxesPoints, Axis, StrictConfig, - deserialize_as, discriminated_union_of_subclasses, if_instance_do, ) @@ -71,9 +70,8 @@ def serialize(self) -> Mapping[str, Any]: return asdict(self) # type: ignore @staticmethod - def deserialize(obj): - """Deserialize the Region from a dictionary.""" - return deserialize_as(Region, obj) + def deserialize(obj: Mapping[str, Any]) -> Region: + return TypeAdapter(Region).validate_python(obj) def get_mask(region: Region[Axis], points: AxesPoints[Axis]) -> np.ndarray: @@ -119,6 +117,28 @@ def axis_sets(self) -> list[set[Axis]]: return axis_sets +@dataclass(config=StrictConfig) +class Range(Region[Axis]): + """Mask contains points of axis >= min and <= max. + + >>> r = Range("x", 1, 2) + >>> r.mask({"x": np.array([0, 1, 2, 3, 4])}) + array([False, True, True, False, False]) + """ + + axis: Axis = Field(description="The name matching the axis to mask in spec") + min: float = Field(description="The minimum inclusive value in the region") + max: float = Field(description="The minimum inclusive value in the region") + + def axis_sets(self) -> list[set[Axis]]: + return [{self.axis}] + + def mask(self, points: AxesPoints[Axis]) -> np.ndarray: + v = points[self.axis] + mask = np.bitwise_and(v >= self.min, v <= self.max) + return mask + + # Naming so we don't clash with typing.Union @dataclass(config=StrictConfig) class UnionOf(CombinationOf[Axis]): @@ -186,28 +206,6 @@ def mask(self, points: AxesPoints[Axis]) -> np.ndarray: return mask -@dataclass(config=StrictConfig) -class Range(Region[Axis]): - """Mask contains points of axis >= min and <= max. - - >>> r = Range("x", 1, 2) - >>> r.mask({"x": np.array([0, 1, 2, 3, 4])}) - array([False, True, True, False, False]) - """ - - axis: Axis = Field(description="The name matching the axis to mask in spec") - min: float = Field(description="The minimum inclusive value in the region") - max: float = Field(description="The minimum inclusive value in the region") - - def axis_sets(self) -> list[set[Axis]]: - return [{self.axis}] - - def mask(self, points: AxesPoints[Axis]) -> np.ndarray: - v = points[self.axis] - mask = np.bitwise_and(v >= self.min, v <= self.max) - return mask - - @dataclass(config=StrictConfig) class Rectangle(Region[Axis]): """Mask contains points of axis within a rotated xy rectangle. diff --git a/src/scanspec/service.py b/src/scanspec/service.py index e52ef5d9..ac8a147f 100644 --- a/src/scanspec/service.py +++ b/src/scanspec/service.py @@ -11,7 +11,7 @@ from pydantic import Field from pydantic.dataclasses import dataclass -from scanspec.core import AxesPoints, Frames, Path, uses_tagged_union +from scanspec.core import AxesPoints, Frames, Path from .specs import Line, Spec @@ -27,7 +27,6 @@ @dataclass -@uses_tagged_union class ValidResponse: """Response model for spec validation.""" @@ -44,7 +43,6 @@ class PointsFormat(str, Enum): @dataclass -@uses_tagged_union class PointsRequest: """A request for generated scan points.""" @@ -125,7 +123,6 @@ class SmallestStepResponse: @app.post("/valid", response_model=ValidResponse) -@uses_tagged_union def valid( spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]), ) -> ValidResponse | JSONResponse: @@ -198,7 +195,6 @@ def bounds( @app.post("/gap", response_model=GapResponse) -@uses_tagged_union def gap( spec: Spec = Body( ..., @@ -224,7 +220,6 @@ def gap( @app.post("/smalleststep", response_model=SmallestStepResponse) -@uses_tagged_union def smallest_step( spec: Spec = Body(..., examples=[_EXAMPLE_SPEC]), ) -> SmallestStepResponse: diff --git a/src/scanspec/specs.py b/src/scanspec/specs.py index d51e5a6f..74c4d262 100644 --- a/src/scanspec/specs.py +++ b/src/scanspec/specs.py @@ -8,7 +8,7 @@ ) import numpy as np -from pydantic import Field, validate_call +from pydantic import Field, TypeAdapter, validate_call from pydantic.dataclasses import dataclass from .core import ( @@ -18,7 +18,6 @@ Path, SnakedFrames, StrictConfig, - deserialize_as, discriminated_union_of_subclasses, gap_between_frames, if_instance_do, @@ -107,13 +106,12 @@ def concat(self, other: Spec) -> Concat[Axis]: return Concat(self, other) def serialize(self) -> Mapping[str, Any]: - """Serialize the spec to a dictionary.""" + """Serialize the Spec to a dictionary.""" return asdict(self) # type: ignore @staticmethod - def deserialize(obj): - """Deserialize the spec from a dictionary.""" - return deserialize_as(Spec, obj) + def deserialize(obj: Mapping[str, Any]) -> Spec: + return TypeAdapter(Spec).validate_python(obj) @dataclass(config=StrictConfig) diff --git a/tests/test_basemodel.py b/tests/test_basemodel.py new file mode 100644 index 00000000..8b247215 --- /dev/null +++ b/tests/test_basemodel.py @@ -0,0 +1,12 @@ +from pydantic import BaseModel, Field + +from scanspec.specs import Line, Spec + + +def test_base_model(): + class Foo(BaseModel): + # class Foo(BaseModel): + spec: Spec = Field(description="This is for test") + # spec: float = 1.0 + + Foo(spec=Line("x", 1, 2, 5))