Skip to content

Commit

Permalink
Include generics in discriminated union schemas (#157)
Browse files Browse the repository at this point in the history
* Restrict subclasses of typed unions
* Allow arbitrary types to be included in 'StrictConfig' types
* Reword _TaggedUnion comments
  • Loading branch information
tpoliaw authored Feb 7, 2025
1 parent fdb5bde commit 96f5982
Show file tree
Hide file tree
Showing 3 changed files with 240 additions and 6 deletions.
2 changes: 1 addition & 1 deletion schema.json

Large diffs are not rendered by default.

87 changes: 82 additions & 5 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
from __future__ import annotations

import itertools
import sys
import warnings
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import lru_cache
from inspect import isclass
from types import GenericAlias
from typing import (
Any,
Generic,
Literal,
TypeVar,
get_args,
get_origin,
)

import numpy as np
Expand All @@ -20,6 +25,19 @@
from pydantic_core import CoreSchema
from pydantic_core.core_schema import tagged_union_schema

if sys.version_info < (3, 12):

    def get_original_bases(cls: type, /) -> tuple[Any, ...]:
        """Return the bases of ``cls`` as written in its class statement.

        Backport of ``types.get_original_bases``, which only joined the
        standard library in Python 3.12. The ``__orig_bases__`` entry of the
        class namespace is preferred; ``__bases__`` is the fallback.

        Raises:
            TypeError: if ``cls`` lacks ``__dict__`` or ``__bases__``, i.e. it
                is not a class.
        """
        try:
            namespace = cls.__dict__
            return namespace.get("__orig_bases__", cls.__bases__)
        except AttributeError:
            raise TypeError(
                f"Expected an instance of type, not {type(cls).__name__!r}"
            ) from None

else:
    from types import get_original_bases


__all__ = [
"Axis",
"OtherAxis",
Expand All @@ -36,14 +54,20 @@
]

#: Used to ensure pydantic dataclasses error if given extra arguments, while
#: still accepting field types that pydantic has no built-in schema for
StrictConfig: ConfigDict = {"extra": "forbid", "arbitrary_types_allowed": True}

#: Type variable for the class handed to `discriminated_union_of_subclasses`
C = TypeVar("C")
#: General purpose type variable
T = TypeVar("T")

#: Numpy array of booleans
GapArray = npt.NDArray[np.bool_]


class UnsupportedSubclass(RuntimeWarning):
    """Warning raised for a subclass that is not a simple extension of a generic type.

    Such a subclass is left out of its parent's tagged union.
    """


def discriminated_union_of_subclasses(
super_cls: type[C],
discriminator: str = "type",
Expand Down Expand Up @@ -132,15 +156,16 @@ def add_subclass_to_union(subclass: type[C]):
setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore

def get_schema_of_union(
cls: type[C], source_type: Any, handler: GetCoreSchemaHandler
cls: type[C], actual_type: type, handler: GetCoreSchemaHandler
):
super(super_cls, cls).__init_subclass__()
if cls is not super_cls:
tagged_union.add_member(cls)
return handler(cls)
# Rebuild any dataclass (including this one) that references this union
# Note that this has to be done after the creation of the dataclass so that
# previously created classes can refer to this newly created class
return tagged_union.schema(handler)
return tagged_union.schema(actual_type, handler)

super_cls.__init_subclass__ = classmethod(add_subclass_to_union) # type: ignore
super_cls.__get_pydantic_core_schema__ = classmethod(get_schema_of_union) # type: ignore
Expand All @@ -157,10 +182,20 @@ def __init__(self, base_class: type[Any], discriminator: str):
self._discriminator = discriminator
# The members of the tagged union, i.e. subclasses of the baseclass
self._subclasses: list[type] = []
# The type parameters expected for the base class of the union
self._generics = _parameters(base_class)

def add_member(self, cls: type):
if cls in self._subclasses:
return
elif not self._support_subclass(cls):
warnings.warn(
f"Subclass {cls} has unsupported generics and will not be part "
"of the tagged union",
UnsupportedSubclass,
stacklevel=2,
)
return
self._subclasses.append(cls)
for member in self._subclasses:
if member is not cls:
Expand All @@ -174,13 +209,55 @@ def _rebuild(cls_or_func: Callable[..., T]) -> None:
if issubclass(cls_or_func, BaseModel):
cls_or_func.model_rebuild(force=True)

def schema(self, actual_type: type, handler: GetCoreSchemaHandler) -> CoreSchema:
    """Build the tagged-union core schema over every registered member.

    Args:
        actual_type: The type the schema was requested for. Any type
            arguments it carries (for example the ``int`` in ``Parent[int]``)
            are applied to each member before the member schemas are made.
        handler: Pydantic handler forwarded to ``_make_schema``.

    Returns:
        A tagged union schema that uses this union's discriminator field and
        is referenced by the base class name.
    """
    # Members are passed as a tuple so they can be handed to _make_schema
    # as a single hashable argument.
    members = tuple(
        self._specify_generics(sub, actual_type) for sub in self._subclasses
    )
    return tagged_union_schema(
        _make_schema(members, handler),
        discriminator=self._discriminator,
        ref=self._base_class.__name__,
    )

def _support_subclass(self, subcls: type) -> bool:
    """Report whether ``subcls`` may become a member of this union.

    The base class of the union is always supported. Any other class must
    declare the same number of type parameters as the base class, each one
    compatible with its counterpart according to ``_compatible_types``, and
    every one of its original bases must be supported as well.
    """
    if subcls == self._base_class:
        return True

    own_params = _parameters(subcls)
    if len(own_params) != len(self._generics):
        return False

    for expected, found in zip(self._generics, own_params, strict=True):
        if not _compatible_types(expected, found):
            return False

    # Subscripted bases such as Parent[U] are checked via their origin class
    return all(
        self._support_subclass(get_origin(base) or base)
        for base in get_original_bases(subcls)
    )

def _specify_generics(self, subcls: type, actual_type: type) -> type | GenericAlias:
    """Apply the type arguments of ``actual_type`` to ``subcls``.

    When ``actual_type`` carries no type arguments, ``subcls`` is returned
    unchanged; otherwise a ``GenericAlias`` of ``subcls`` with those
    arguments is returned.
    """
    type_args = get_args(actual_type)
    return GenericAlias(subcls, type_args) if type_args else subcls


def _parameters(possibly_generic: type) -> tuple[Any, ...]:
    """Return the ``__parameters__`` of ``possibly_generic``.

    An empty tuple is returned for anything without that attribute.
    """
    try:
        return possibly_generic.__parameters__
    except AttributeError:
        return ()


def _compatible_types(left: TypeVar, right: TypeVar) -> bool:
return (
left.__bound__ == right.__bound__
and left.__constraints__ == right.__constraints__
and left.__covariant__ == right.__covariant__
and left.__contravariant__ == right.__contravariant__
)


@lru_cache(1)
def _make_schema(
Expand Down
157 changes: 157 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from typing import Annotated, Any, Generic, TypeVar

import pytest
from pydantic import TypeAdapter
from pydantic.dataclasses import dataclass

from scanspec.core import (
UnsupportedSubclass,
discriminated_union_of_subclasses,
)

# Unconstrained type variables used to parametrise the generic test classes
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")

# Constrained type variable, used to check that adding constraints is rejected
B = TypeVar("B", int, float)


# Generic root of the discriminated union exercised by these tests
@discriminated_union_of_subclasses
class Parent(Generic[T]):
    ...


# Simplest supported member: forwards its single type parameter to Parent
@dataclass
class Child(Parent[U]):
    a: U


# Member whose type parameter reaches Parent wrapped in Annotated
@dataclass
class AnnotatedChild(Parent[Annotated[U, "comment"]]):
    b: U


# Second-generation member: extends Child, adding no fields of its own
@dataclass
class GrandChild(Child[V]):
    # TODO: subclasses with fields?
    ...


# Root of a union that has no type parameters at all
@discriminated_union_of_subclasses
class NonGenericParent:
    ...


# Plain member of the non-generic union
@dataclass
class NonGenericChild(NonGenericParent):
    a: int
    b: float


def test_specific_implementation_child():
    """A child that pins the parent's type argument warns, as does its subclass."""
    # Fixing T to int leaves the child with no type parameters of its own
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class Specific(Parent[int]):
            b: int

    # Extending the unsupported child is reported separately
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class SubSpecific(Specific):  # type: ignore
            ...


def test_extra_generic_parameters():
    """A child declaring more type parameters than the parent warns."""
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class ExtraGeneric(Parent[U], Generic[U, V]):  # type: ignore
            c: U
            d: V


def test_unrelated_generic_parameters():
    """A child that pins the parent's argument and adds its own parameter warns."""
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class UnrelatedGeneric(Parent[int], Generic[U]):  # type: ignore
            e: int
            f: U


def test_reordered_generics():
    """A child declaring three parameters around the parent's single one warns."""
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class DisorderedGeneric(Parent[U], Generic[T, U, V]):  # type: ignore
            g: T
            h: U
            i: V


@pytest.mark.skip("Unsure if this case should be valid or not")
def test_unionised_child():
    """Skipped: a child passing a union containing its parameter to the parent."""
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class UnionSubclass(Parent[int | U]):  # type: ignore
            a: U


def test_untyped_child():
    """A child extending the generic parent without any type argument warns."""
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class UnmarkedChild(Parent):  # type: ignore we're testing the bad type annotations
            a: int


def test_additional_type_bounds():
    """A child whose parameter is more constrained than the parent's warns."""
    with pytest.warns(UnsupportedSubclass):
        # Adding bounds to the generic parameter is not supported
        @dataclass
        class ConstrainedChild(Parent[B]):  # type: ignore
            cc: B


def test_adding_generics_to_nongeneric():
    """A child that adds a type parameter to a non-generic parent warns."""
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class NewGenerics(NonGenericParent, Generic[T]):  # type: ignore
            a: T


def deserialize(target: type[Any], source: Any) -> Any:
    """Validate ``source`` against ``target`` and return the validated object."""
    adapter = TypeAdapter(target)  # type: ignore
    return adapter.validate_python(source)  # type: ignore


def test_child():
    """The parent's type argument decides how the child's field is validated."""
    as_int = deserialize(Parent[int], {"type": "Child", "a": "42"})
    assert as_int.a == 42

    as_str = deserialize(Parent[str], {"type": "Child", "a": "42"})
    assert as_str.a == "42"

    as_list = deserialize(Parent[list[int]], {"type": "Child", "a": ["1", "2", "3"]})
    assert as_list.a == [1, 2, 3]


def test_annotated_child():
    """An Annotated type parameter is still resolved from the parent's argument."""
    payload = {"type": "AnnotatedChild", "b": "42"}
    annotated = deserialize(Parent[int], payload)
    assert annotated.b == 42


@pytest.mark.xfail(reason="Pydantic #11363")
def test_grandchild():
    """Expected failure: deserialising a second-generation member of the union."""
    payload = {"type": "GrandChild", "a": "42"}
    grandchild = deserialize(Parent[int], payload)
    assert grandchild.a == 42


def test_non_generic_child():
    """Members of a union without type parameters deserialise as before."""
    payload = {"type": "NonGenericChild", "a": "42", "b": "3.14"}
    ngc = deserialize(NonGenericParent, payload)
    assert ngc.a == 42
    assert ngc.b == pytest.approx(3.14)  # type: ignore

0 comments on commit 96f5982

Please sign in to comment.