From b50c278e39a8698bf511619e08adde08bb3580ce Mon Sep 17 00:00:00 2001 From: "o.ermakov" Date: Mon, 3 Mar 2025 19:16:48 +0100 Subject: [PATCH] Added support unions with recursive generics --- noxfile.py | 3 +- pyproject.toml | 1 + python/serpyco_rs/_describe.py | 31 ++++++++++++++------ tests/test_recursive_generic.py | 50 +++++++++++++++++++++++++++++++++ 4 files changed, 74 insertions(+), 11 deletions(-) create mode 100644 tests/test_recursive_generic.py diff --git a/noxfile.py b/noxfile.py index 1edb0ba..f697424 100644 --- a/noxfile.py +++ b/noxfile.py @@ -5,7 +5,6 @@ nox.options.sessions = ['test', 'lint', 'type_check', 'rust_lint'] -nox.options.python = False def build(session, use_pip: bool = False): @@ -85,7 +84,7 @@ def test_rc_leaks(session): ) -@nox.session +@nox.session(python=False) def bench_codespeed(session): build(session) install(session, '-r', 'requirements/bench.txt') diff --git a/pyproject.toml b/pyproject.toml index d75295a..4ed8b68 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ module-name = "serpyco_rs._serpyco_rs" [project] name = "serpyco-rs" +dynamic = ["version"] repository = "https://github.com/ermakov-oleg/serpyco-rs" requires-python = ">=3.9" license = { file = "LICENSE" } diff --git a/python/serpyco_rs/_describe.py b/python/serpyco_rs/_describe.py index e9582b3..75b0fc2 100644 --- a/python/serpyco_rs/_describe.py +++ b/python/serpyco_rs/_describe.py @@ -117,13 +117,18 @@ def describe_type( args = t.__args__ t = Union elif hasattr(t, '__parameters__'): - # Если передан generic-класс без type-параметров, значит по PEP-484 заменяем все параметры на Any + # If a generic class without type parameters is passed, + # then according to PEP-484 we replace all parameters with Any args = (Any,) * len(t.__parameters__) if not meta: meta = Meta(globals=_get_globals(t), state={}) - t = _evaluate_forwardref(t, meta) + if isinstance(t, ForwardRef): + t = _evaluate_forwardref(t, meta) + # ForwardRef evaluation can return a generic class that we need to resolve + if hasattr(t, '__origin__'): + return describe_type(_wrap_annotated(metadata)(t), meta, custom_type_resolver) filed_format = _find_metadata(metadata, FieldFormat, NoFormat) none_format = _find_metadata(metadata, NoneFormat, KeepNone) @@ -283,7 +288,10 @@ def describe_type( custom_encoder=custom_encoder, ) - if not all(dataclasses.is_dataclass(arg) or _is_attrs(arg) for arg in args): + if not all( + _applies_to_type_or_origin(arg, dataclasses.is_dataclass) or _applies_to_type_or_origin(arg, _is_attrs) + for arg in args + ): raise RuntimeError( f'Unions supported only for dataclasses or attrs. Provided: {t}[{",".join(map(str, args))}]' ) @@ -291,7 +299,7 @@ def describe_type( meta = dataclasses.replace(meta, discriminator_field=discriminator.name) return DiscriminatedUnionType( item_types={ - _get_discriminator_value(arg, discriminator.name): describe_type( + _get_discriminator_value(get_origin(arg) or arg, discriminator.name): describe_type( annotation_wrapper(arg), meta, custom_type_resolver ) for arg in args @@ -508,11 +516,8 @@ def _get_globals(t: Any) -> dict[str, Any]: return {} -def _evaluate_forwardref(t: type[_T], meta: Meta) -> type[_T]: - if not isinstance(t, ForwardRef): - return t - - return t._evaluate(meta.globals, {}, recursive_guard=set()) +def _evaluate_forwardref(t: ForwardRef, meta: Meta) -> Any: + return t._evaluate(meta.globals, {}, recursive_guard=frozenset[str]()) def _get_discriminator_value(t: Any, name: str) -> str: @@ -589,3 +594,11 @@ def _is_new_type(t: Any) -> TypeGuard[NewType]: def _is_supported_literal_args(args: Sequence[Any]) -> TypeGuard[list[Union[str, int, Enum]]]: return all(isinstance(arg, (str, int, Enum)) for arg in args) + + +def _applies_to_type_or_origin(t: Any, predicate: Callable[[Any], bool]) -> bool: + if predicate(t): + return True + if hasattr(t, '__origin__'): + return predicate(t.__origin__) + return False diff --git a/tests/test_recursive_generic.py b/tests/test_recursive_generic.py new file mode 100644 index 0000000..5b1eab5 --- /dev/null +++ b/tests/test_recursive_generic.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass +from typing import Annotated, Generic, Literal, TypeVar + +from serpyco_rs import Serializer +from serpyco_rs.metadata import Discriminator + + +TWidget_co = TypeVar('TWidget_co', bound='BaseWidget', covariant=True) + + +@dataclass +class BaseWidget(Generic[TWidget_co]): + type: str + childrens: list[TWidget_co] | None = None + + +@dataclass +class Widget1(BaseWidget[TWidget_co]): + type: Literal['Widget1'] = 'Widget1' + + +@dataclass +class Widget2(BaseWidget[TWidget_co]): + type: Literal['Widget2'] = 'Widget2' + some_field: str | None = None + + +Widget = Annotated[Widget1['Widget'] | Widget2['Widget'], Discriminator('type')] + + +def test_recursive_generics(): + serializer = Serializer(Widget) + + obj = Widget1(type='Widget1', childrens=[Widget2(type='Widget2', some_field='some_value')]) + + data = {'type': 'Widget1', 'childrens': [{'type': 'Widget2', 'some_field': 'some_value', 'childrens': None}]} + + assert serializer.dump(obj) == data + assert serializer.load(data) == obj + + +def test_recursive_generics_propagates_annotations(): + serializer = Serializer(Widget, camelcase_fields=True) + + obj = Widget1(type='Widget1', childrens=[Widget2(type='Widget2', some_field='some_value', childrens=[])]) + + data = {'type': 'Widget1', 'childrens': [{'type': 'Widget2', 'someField': 'some_value', 'childrens': []}]} + + assert serializer.dump(obj) == data + assert serializer.load(data) == obj