Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include generics in discriminated union schemas #157

Merged
merged 5 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion schema.json

Large diffs are not rendered by default.

87 changes: 82 additions & 5 deletions src/scanspec/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
from __future__ import annotations

import itertools
import sys
import warnings
from collections.abc import Callable, Iterable, Iterator, Sequence
from functools import lru_cache
from inspect import isclass
from types import GenericAlias
from typing import (
Any,
Generic,
Literal,
TypeVar,
get_args,
get_origin,
)

import numpy as np
Expand All @@ -20,6 +25,19 @@
from pydantic_core import CoreSchema
from pydantic_core.core_schema import tagged_union_schema

if sys.version_info < (3, 12):

    def get_original_bases(cls: type, /) -> tuple[Any, ...]:
        """Return the bases of ``cls`` as written in its class statement.

        Stand-in for ``types.get_original_bases``, which is only available
        from the standard library on Python 3.12 and later. ``__orig_bases__``
        is read from the class's own ``__dict__``; when it is absent the
        plain ``__bases__`` tuple is returned instead.

        Raises:
            TypeError: If the attribute lookups on ``cls`` fail with
                ``AttributeError``.
        """
        try:
            declared = cls.__dict__.get("__orig_bases__", cls.__bases__)
        except AttributeError:
            message = f"Expected an instance of type, not {type(cls).__name__!r}"
            raise TypeError(message) from None
        return declared

else:
    from types import get_original_bases


__all__ = [
"Axis",
"OtherAxis",
Expand All @@ -36,14 +54,20 @@
]

#: Used to ensure pydantic dataclasses error if given extra arguments
StrictConfig: ConfigDict = {"extra": "forbid"}
StrictConfig: ConfigDict = {"extra": "forbid", "arbitrary_types_allowed": True}

C = TypeVar("C")
T = TypeVar("T")

GapArray = npt.NDArray[np.bool_]


class UnsupportedSubclass(RuntimeWarning):
    """Warning category for subclasses left out of a tagged union.

    Emitted when a subclass has generics that are not supported, in which
    case the subclass is not added to the union.
    """


def discriminated_union_of_subclasses(
super_cls: type[C],
discriminator: str = "type",
Expand Down Expand Up @@ -132,15 +156,16 @@ def add_subclass_to_union(subclass: type[C]):
setattr(subclass, discriminator, Field(subclass.__name__, repr=False)) # type: ignore

def get_schema_of_union(
cls: type[C], source_type: Any, handler: GetCoreSchemaHandler
cls: type[C], actual_type: type, handler: GetCoreSchemaHandler
):
super(super_cls, cls).__init_subclass__()
if cls is not super_cls:
tagged_union.add_member(cls)
return handler(cls)
# Rebuild any dataclass (including this one) that references this union
# Note that this has to be done after the creation of the dataclass so that
# previously created classes can refer to this newly created class
return tagged_union.schema(handler)
return tagged_union.schema(actual_type, handler)

super_cls.__init_subclass__ = classmethod(add_subclass_to_union) # type: ignore
super_cls.__get_pydantic_core_schema__ = classmethod(get_schema_of_union) # type: ignore
Expand All @@ -157,10 +182,20 @@ def __init__(self, base_class: type[Any], discriminator: str):
self._discriminator = discriminator
# The members of the tagged union, i.e. subclasses of the baseclass
self._subclasses: list[type] = []
# The type parameters expected for the base class of the union
self._generics = _parameters(base_class)

def add_member(self, cls: type):
if cls in self._subclasses:
return
elif not self._support_subclass(cls):
warnings.warn(
f"Subclass {cls} has unsupported generics and will not be part "
"of the tagged union",
UnsupportedSubclass,
stacklevel=2,
)
return
self._subclasses.append(cls)
for member in self._subclasses:
if member is not cls:
Expand All @@ -174,13 +209,55 @@ def _rebuild(cls_or_func: Callable[..., T]) -> None:
if issubclass(cls_or_func, BaseModel):
cls_or_func.model_rebuild(force=True)

def schema(self, handler: GetCoreSchemaHandler) -> CoreSchema:
def schema(self, actual_type: type, handler: GetCoreSchemaHandler) -> CoreSchema:
return tagged_union_schema(
_make_schema(tuple(self._subclasses), handler),
_make_schema(
tuple(
self._specify_generics(sub, actual_type) for sub in self._subclasses
),
handler,
),
discriminator=self._discriminator,
ref=self._base_class.__name__,
)

def _support_subclass(self, subcls: type) -> bool:
    """Report whether ``subcls`` may become a member of this union.

    The base class of the union is always supported. Any other class must
    declare the same number of type parameters as the base class, each
    parameter must be compatible with the one in the same position on the
    base class, and every base listed in its class statement must itself
    be supported.
    """
    # The union's own base class terminates the recursion below.
    if subcls == self._base_class:
        return True

    expected = self._generics
    declared = _parameters(subcls)
    if len(expected) != len(declared):
        return False

    for ours, theirs in zip(expected, declared, strict=True):
        if not _compatible_types(ours, theirs):
            return False

    # Parameterised bases are unwrapped to their origin class before the
    # recursive check; plain classes are checked as they are.
    for base in get_original_bases(subcls):
        if not self._support_subclass(get_origin(base) or base):
            return False
    return True

def _specify_generics(self, subcls: type, actual_type: type) -> type | GenericAlias:
args = get_args(actual_type)
if args:
return GenericAlias(subcls, args)
return subcls


def _parameters(possibly_generic: type) -> tuple[Any, ...]:
return getattr(possibly_generic, "__parameters__", ())


def _compatible_types(left: TypeVar, right: TypeVar) -> bool:
return (
left.__bound__ == right.__bound__
and left.__constraints__ == right.__constraints__
and left.__covariant__ == right.__covariant__
and left.__contravariant__ == right.__contravariant__
)


@lru_cache(1)
def _make_schema(
Expand Down
157 changes: 157 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
from typing import Annotated, Any, Generic, TypeVar

import pytest
from pydantic import TypeAdapter
from pydantic.dataclasses import dataclass

from scanspec.core import (
UnsupportedSubclass,
discriminated_union_of_subclasses,
)

T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")

B = TypeVar("B", int, float)


# Generic base class of the tagged union used by the tests below.
@discriminated_union_of_subclasses
class Parent(Generic[T]):
    pass


# Subclass that forwards a single type variable of its own to Parent.
@dataclass
class Child(Parent[U]):
    a: U


# Subclass that wraps its type variable in Annotated before passing it to Parent.
@dataclass
class AnnotatedChild(Parent[Annotated[U, "comment"]]):
    b: U


# Second-level subclass: generic over V and derived from Child rather than Parent.
@dataclass
class GrandChild(Child[V]):
    # TODO: subclasses with fields?
    pass


# Base class of a tagged union that has no type parameters.
@discriminated_union_of_subclasses
class NonGenericParent:
    pass


# Plain (non-generic) member of the NonGenericParent union.
@dataclass
class NonGenericChild(NonGenericParent):
    a: int
    b: float


def test_specific_implementation_child():
    """Expect UnsupportedSubclass when Parent's parameter is fixed to ``int``.

    The warning is expected both for the class that subclasses ``Parent[int]``
    and for a further subclass of that class.
    """
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class Specific(Parent[int]):
            b: int

    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class SubSpecific(Specific):  # type: ignore
            pass


def test_extra_generic_parameters():
    """Expect UnsupportedSubclass when a subclass adds a second type parameter."""
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class ExtraGeneric(Parent[U], Generic[U, V]):  # type: ignore
            c: U
            d: V


def test_unrelated_generic_parameters():
    """Expect UnsupportedSubclass when the subclass's parameter is not Parent's.

    The subclass fixes Parent's parameter to ``int`` and introduces its own,
    separate type variable through ``Generic[U]``.
    """
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class UnrelatedGeneric(Parent[int], Generic[U]):  # type: ignore
            e: int
            f: U


def test_reordered_generics():
    """Expect UnsupportedSubclass when a subclass declares ``Generic[T, U, V]``.

    Only the middle parameter, ``U``, is the one handed on to ``Parent``.
    """
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class DisorderedGeneric(Parent[U], Generic[T, U, V]):  # type: ignore
            g: T
            h: U
            i: V


@pytest.mark.skip("Unsure if this case should be valid or not")
def test_unionised_child():
    """Skipped: a subclass of ``Parent[int | U]`` warning with UnsupportedSubclass.

    The test is skipped because it is undecided whether this case should be
    valid.
    """
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class UnionSubclass(Parent[int | U]):  # type: ignore
            a: U


def test_untyped_child():
    """Expect UnsupportedSubclass when Parent is subclassed without a parameter."""
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class UnmarkedChild(Parent):  # type: ignore we're testing the bad type annotations
            a: int


def test_additional_type_bounds():
    """Expect UnsupportedSubclass when the subclass's type variable is constrained.

    ``B`` is restricted to ``int`` or ``float``, whereas Parent's ``T`` is
    unrestricted.
    """
    with pytest.warns(UnsupportedSubclass):
        # Adding bounds to the generic parameter is not supported
        @dataclass
        class ConstrainedChild(Parent[B]):  # type: ignore
            cc: B


def test_adding_generics_to_nongeneric():
    """Expect UnsupportedSubclass when a generic class extends NonGenericParent."""
    with pytest.warns(UnsupportedSubclass):

        @dataclass
        class NewGenerics(NonGenericParent, Generic[T]):  # type: ignore
            a: T


def deserialize(target: type[Any], source: Any) -> Any:
    """Validate ``source`` against ``target`` with a pydantic ``TypeAdapter``."""
    adapter = TypeAdapter(target)
    return adapter.validate_python(source)  # type: ignore


def test_child():
    """Deserialize ``Child`` through ``Parent`` with several type arguments.

    Field ``a`` is expected to be validated as the type supplied to
    ``Parent``: ``int``, ``str`` and ``list[int]`` in turn.
    """
    cases = [
        (Parent[int], "42", 42),
        (Parent[str], "42", "42"),
        (Parent[list[int]], ["1", "2", "3"], [1, 2, 3]),
    ]
    for target, raw, expected in cases:
        child = deserialize(target, {"type": "Child", "a": raw})
        assert child.a == expected


def test_annotated_child():
    """Deserialize ``AnnotatedChild`` via ``Parent[int]`` and check ``b`` is 42."""
    payload = {"type": "AnnotatedChild", "b": "42"}
    result = deserialize(Parent[int], payload)
    assert result.b == 42


@pytest.mark.xfail(reason="Pydantic #11363")
def test_grandchild():
    """Deserialize ``GrandChild`` via ``Parent[int]``; marked as an expected failure."""
    payload = {"type": "GrandChild", "a": "42"}
    result = deserialize(Parent[int], payload)
    assert result.a == 42


def test_non_generic_child():
    """Deserialize ``NonGenericChild`` via ``NonGenericParent`` and check both fields."""
    payload = {"type": "NonGenericChild", "a": "42", "b": "3.14"}
    parsed = deserialize(NonGenericParent, payload)
    assert parsed.a == 42
    assert parsed.b == pytest.approx(3.14)  # type: ignore
Loading