Skip to content

Commit

Permalink
Added support unions with recursive generics
Browse files Browse the repository at this point in the history
  • Loading branch information
ermakov-oleg committed Mar 3, 2025
1 parent d9fd921 commit b50c278
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 11 deletions.
3 changes: 1 addition & 2 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@


nox.options.sessions = ['test', 'lint', 'type_check', 'rust_lint']
nox.options.python = False


def build(session, use_pip: bool = False):
Expand Down Expand Up @@ -85,7 +84,7 @@ def test_rc_leaks(session):
)


@nox.session
@nox.session(python=False)
def bench_codespeed(session):
build(session)
install(session, '-r', 'requirements/bench.txt')
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ module-name = "serpyco_rs._serpyco_rs"

[project]
name = "serpyco-rs"
dynamic = ["version"]
repository = "https://github.com/ermakov-oleg/serpyco-rs"
requires-python = ">=3.9"
license = { file = "LICENSE" }
Expand Down
31 changes: 22 additions & 9 deletions python/serpyco_rs/_describe.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,18 @@ def describe_type(
args = t.__args__
t = Union
elif hasattr(t, '__parameters__'):
# Если передан generic-класс без type-параметров, значит по PEP-484 заменяем все параметры на Any
# If a generic class without type parameters is passed,
# then according to PEP-484 we replace all parameters with Any
args = (Any,) * len(t.__parameters__)

if not meta:
meta = Meta(globals=_get_globals(t), state={})

t = _evaluate_forwardref(t, meta)
if isinstance(t, ForwardRef):
t = _evaluate_forwardref(t, meta)
# ForwardRef evaluation can return a generic class that we need to resolve
if hasattr(t, '__origin__'):
return describe_type(_wrap_annotated(metadata)(t), meta, custom_type_resolver)

filed_format = _find_metadata(metadata, FieldFormat, NoFormat)
none_format = _find_metadata(metadata, NoneFormat, KeepNone)
Expand Down Expand Up @@ -283,15 +288,18 @@ def describe_type(
custom_encoder=custom_encoder,
)

if not all(dataclasses.is_dataclass(arg) or _is_attrs(arg) for arg in args):
if not all(
_applies_to_type_or_origin(arg, dataclasses.is_dataclass) or _applies_to_type_or_origin(arg, _is_attrs)
for arg in args
):
raise RuntimeError(
f'Unions supported only for dataclasses or attrs. Provided: {t}[{",".join(map(str, args))}]'
)

meta = dataclasses.replace(meta, discriminator_field=discriminator.name)
return DiscriminatedUnionType(
item_types={
_get_discriminator_value(arg, discriminator.name): describe_type(
_get_discriminator_value(get_origin(arg) or arg, discriminator.name): describe_type(
annotation_wrapper(arg), meta, custom_type_resolver
)
for arg in args
Expand Down Expand Up @@ -508,11 +516,8 @@ def _get_globals(t: Any) -> dict[str, Any]:
return {}


def _evaluate_forwardref(t: type[_T], meta: Meta) -> type[_T]:
if not isinstance(t, ForwardRef):
return t

return t._evaluate(meta.globals, {}, recursive_guard=set())
def _evaluate_forwardref(t: ForwardRef, meta: Meta) -> Any:
return t._evaluate(meta.globals, {}, recursive_guard=frozenset[str]())


def _get_discriminator_value(t: Any, name: str) -> str:
Expand Down Expand Up @@ -589,3 +594,11 @@ def _is_new_type(t: Any) -> TypeGuard[NewType]:

def _is_supported_literal_args(args: Sequence[Any]) -> TypeGuard[list[Union[str, int, Enum]]]:
return all(isinstance(arg, (str, int, Enum)) for arg in args)


def _applies_to_type_or_origin(t: Any, predicate: Callable[[Any], bool]) -> bool:
if predicate(t):
return True
if hasattr(t, '__origin__'):
return predicate(t.__origin__)
return False
50 changes: 50 additions & 0 deletions tests/test_recursive_generic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from dataclasses import dataclass
from typing import Annotated, Generic, Literal, TypeVar

from serpyco_rs import Serializer
from serpyco_rs.metadata import Discriminator


TWidget_co = TypeVar('TWidget_co', bound='BaseWidget', covariant=True)


@dataclass
class BaseWidget(Generic[TWidget_co]):
type: str
childrens: list[TWidget_co] | None = None


@dataclass
class Widget1(BaseWidget[TWidget_co]):
type: Literal['Widget1'] = 'Widget1'


@dataclass
class Widget2(BaseWidget[TWidget_co]):
type: Literal['Widget2'] = 'Widget2'
some_field: str | None = None


Widget = Annotated[Widget1['Widget'] | Widget2['Widget'], Discriminator('type')]


def test_recursive_generics():
serializer = Serializer(Widget)

obj = Widget1(type='Widget1', childrens=[Widget2(type='Widget2', some_field='some_value')])

data = {'type': 'Widget1', 'childrens': [{'type': 'Widget2', 'some_field': 'some_value', 'childrens': None}]}

assert serializer.dump(obj) == data
assert serializer.load(data) == obj


def test_recursive_generics_propagates_annotations():
serializer = Serializer(Widget, camelcase_fields=True)

obj = Widget1(type='Widget1', childrens=[Widget2(type='Widget2', some_field='some_value', childrens=[])])

data = {'type': 'Widget1', 'childrens': [{'type': 'Widget2', 'someField': 'some_value', 'childrens': []}]}

assert serializer.dump(obj) == data
assert serializer.load(data) == obj

0 comments on commit b50c278

Please sign in to comment.