From e72db0d33b50a21e25e003aebafe781e3a3b56e1 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 19:39:12 +0100 Subject: [PATCH 01/12] Remove unused if statement --- tests/experimental/pydantic/test_conversion.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index ec9ba5495f..15e7cbeaa0 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -23,11 +23,6 @@ from strawberry.types.types import StrawberryObjectDefinition from tests.experimental.pydantic.utils import needs_pydantic_v1 -if IS_PYDANTIC_V2: - pass -else: - pass - def test_can_use_type_standalone(): class User(BaseModel): From 5cc04bf8656b2caf8ed46eed069b9165bcdb9b32 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 21:02:00 +0100 Subject: [PATCH 02/12] Initial stub at supporting both pydantic v1 and v2 --- noxfile.py | 2 +- strawberry/experimental/__init__.py | 4 +- strawberry/experimental/pydantic/_compat.py | 90 +++++++++++++------ .../experimental/pydantic/error_type.py | 5 +- .../experimental/pydantic/object_type.py | 5 +- strawberry/experimental/pydantic/utils.py | 14 ++- .../pydantic/schema/test_1_and_2.py | 65 ++++++++++++++ .../experimental/pydantic/test_conversion.py | 2 +- 8 files changed, 141 insertions(+), 46 deletions(-) create mode 100644 tests/experimental/pydantic/schema/test_1_and_2.py diff --git a/noxfile.py b/noxfile.py index 12f420751e..64d437c287 100644 --- a/noxfile.py +++ b/noxfile.py @@ -107,7 +107,7 @@ def tests_integrations(session: Session, integration: str) -> None: session.run("pytest", *COMMON_PYTEST_OPTIONS, "-m", integration) -@session(python=PYTHON_VERSIONS, name="Pydantic tests", tags=["tests"]) +@session(python=PYTHON_VERSIONS, name="Pydantic tests", tags=["tests", "pydantic"]) @nox.parametrize("pydantic", ["1.10", "2.0.3"]) def test_pydantic(session: Session, pydantic: str) -> None: session.run_always("poetry", "install", external=True) diff --git a/strawberry/experimental/__init__.py b/strawberry/experimental/__init__.py index 6386ad81d7..16c7c10115 100644 --- a/strawberry/experimental/__init__.py +++ b/strawberry/experimental/__init__.py @@ -1,6 +1,6 @@ try: from . import pydantic -except ImportError: - pass +except ImportError as e: + error = e else: __all__ = ["pydantic"] diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 79d919af23..13d5a943ca 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -24,21 +24,28 @@ class CompatModelField: allow_none: bool has_alias: bool description: Optional[str] + _missing_type: Any + @property + def has_default_factory(self) -> bool: + return self.default_factory is not self._missing_type -if IS_PYDANTIC_V2: - from typing_extensions import get_args, get_origin - - from pydantic._internal._typing_extra import is_new_type - from pydantic._internal._utils import lenient_issubclass, smart_deepcopy - from pydantic_core import PydanticUndefined + @property + def has_default(self) -> bool: + return self.default is not self._missing_type - PYDANTIC_MISSING_TYPE = PydanticUndefined - def new_type_supertype(type_: Any) -> Any: +class PydanticV2Compat: + def new_type_supertype(self, type_: Any) -> Any: return type_.__supertype__ - def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]: + @property + def PYDANTIC_MISSING_TYPE(self) -> Any: + from pydantic_core import PydanticUndefined + + return PydanticUndefined + + def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: field_info: dict[str, FieldInfo] = model.model_fields new_fields = {} # Convert it into CompatModelField @@ -55,24 +62,17 @@ def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]: allow_none=False, has_alias=field is not None, description=field.description, + _missing_type=self.PYDANTIC_MISSING_TYPE, ) return new_fields -else: - from pydantic.typing import ( # type: ignore[no-redef] - get_args, - get_origin, - is_new_type, - new_type_supertype, - ) - from pydantic.utils import ( # type: ignore[no-redef] - lenient_issubclass, - smart_deepcopy, - ) - - PYDANTIC_MISSING_TYPE = dataclasses.MISSING # type: ignore[assignment] - - def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]: + +class PydanticV1Compat: + @property + def PYDANTIC_MISSING_TYPE(self) -> Any: + return dataclasses.MISSING + + def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: new_fields = {} # Convert it into CompatModelField for name, field in model.__fields__.items(): # type: ignore[attr-defined] @@ -87,17 +87,49 @@ def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]: allow_none=field.allow_none, has_alias=field.has_alias, description=field.field_info.description, + _missing_type=self.PYDANTIC_MISSING_TYPE, ) return new_fields + def new_type_supertype(self, type_: Any) -> Any: + return type_ + + +class PydanticCompat: + # proxy based on v1 or v2 + def __init__(self): + if IS_PYDANTIC_V2: + self._compat = PydanticV2Compat() + else: + self._compat = PydanticV1Compat() + + @classmethod + def from_model(cls, model: Type[BaseModel]) -> "PydanticCompat": + return cls() + + def __getattr__(self, name: str) -> Any: + return getattr(self._compat, name) + + +if IS_PYDANTIC_V2: + from typing_extensions import get_args, get_origin + + from pydantic._internal._typing_extra import is_new_type + from pydantic._internal._utils import lenient_issubclass, smart_deepcopy + + def new_type_supertype(type_: Any) -> Any: + return type_.__supertype__ +else: + from pydantic.typing import get_args, get_origin, is_new_type + from pydantic.utils import lenient_issubclass, smart_deepcopy + __all__ = [ - "smart_deepcopy", + "PydanticCompat", + "is_new_type", "lenient_issubclass", - "get_args", "get_origin", - "is_new_type", + "get_args", "new_type_supertype", - "get_model_fields", - "PYDANTIC_MISSING_TYPE", + "smart_deepcopy", ] diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index adcebdc2f7..73eab1c6fb 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -19,7 +19,7 @@ from strawberry.auto import StrawberryAuto from strawberry.experimental.pydantic._compat import ( CompatModelField, - get_model_fields, + PydanticCompat, lenient_issubclass, ) from strawberry.experimental.pydantic.utils import ( @@ -72,7 +72,8 @@ def error_type( all_fields: bool = False, ) -> Callable[..., Type]: def wrap(cls: Type) -> Type: - model_fields = get_model_fields(model) + compat = PydanticCompat.from_model(model) + model_fields = compat.get_model_fields(model) fields_set = set(fields) if fields else set() if fields: diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index f7e87ed6ef..f25bf01aad 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -21,7 +21,7 @@ from strawberry.experimental.pydantic._compat import ( IS_PYDANTIC_V1, CompatModelField, - get_model_fields, + PydanticCompat, ) from strawberry.experimental.pydantic.conversion import ( convert_pydantic_model_to_strawberry_class, @@ -129,7 +129,8 @@ def type( use_pydantic_alias: bool = True, ) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: - model_fields = get_model_fields(model) + compat = PydanticCompat.from_model(model) + model_fields = compat.get_model_fields(model) original_fields_set = set(fields) if fields else set() if fields: diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index 4f8629a0fc..79b8122ec9 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -14,9 +14,8 @@ ) from strawberry.experimental.pydantic._compat import ( - PYDANTIC_MISSING_TYPE, CompatModelField, - get_model_fields, + PydanticCompat, smart_deepcopy, ) from strawberry.experimental.pydantic.exceptions import ( @@ -83,12 +82,8 @@ def get_default_factory_for_field( Returns optionally a NoArgAnyCallable representing a default_factory parameter """ # replace dataclasses.MISSING with our own UNSET to make comparisons easier - default_factory = ( - field.default_factory - if field.default_factory is not PYDANTIC_MISSING_TYPE - else UNSET - ) - default = field.default if field.default is not PYDANTIC_MISSING_TYPE else UNSET + default_factory = field.default_factory if field.has_default_factory else UNSET + default = field.default if field.has_default else UNSET has_factory = default_factory is not None and default_factory is not UNSET has_default = default is not None and default is not UNSET @@ -126,8 +121,9 @@ def get_default_factory_for_field( def ensure_all_auto_fields_in_pydantic( model: Type[BaseModel], auto_fields: Set[str], cls_name: str ) -> None: + compat = PydanticCompat.from_model(model) # Raise error if user defined a strawberry.auto field not present in the model - non_existing_fields = list(auto_fields - get_model_fields(model).keys()) + non_existing_fields = list(auto_fields - compat.get_model_fields(model).keys()) if non_existing_fields: raise AutoFieldsNotInBaseModelError( diff --git a/tests/experimental/pydantic/schema/test_1_and_2.py b/tests/experimental/pydantic/schema/test_1_and_2.py new file mode 100644 index 0000000000..b97836aaf0 --- /dev/null +++ b/tests/experimental/pydantic/schema/test_1_and_2.py @@ -0,0 +1,65 @@ +import textwrap +from typing import Optional, Union + +import pytest + +import strawberry +from tests.experimental.pydantic.utils import needs_pydantic_v2 + + +@needs_pydantic_v2 +@pytest.mark.xfail +def test_can_use_both_pydantic_1_and_2(): + import pydantic + from pydantic import v1 as pydantic_v1 + + class UserModel(pydantic.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic.type(UserModel) + class User: + age: strawberry.auto + password: strawberry.auto + + class LegacyUserModel(pydantic_v1.BaseModel): + age: int + password: Optional[str] + + @strawberry.experimental.pydantic.type(LegacyUserModel) + class LegacyUser: + age: strawberry.auto + password: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> Union[User, LegacyUser]: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type Query { + user: User! + } + + type User { + age: Int! + password: String + } + + type LegacyUser { + age: Int! + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { age } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 1 diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 15e7cbeaa0..83d3fed016 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -11,7 +11,6 @@ import strawberry from strawberry.experimental.pydantic._compat import ( IS_PYDANTIC_V2, - PYDANTIC_MISSING_TYPE, CompatModelField, ) from strawberry.experimental.pydantic.exceptions import ( @@ -841,6 +840,7 @@ class UserType: assert user.passwords == ["hunter2"] +@pytest.mark.xfail def test_get_default_factory_for_field(): def _get_field( default: Any = PYDANTIC_MISSING_TYPE, From df9812a6ef1f0cdc54bd573b36573439efc6c2ae Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 21:06:51 +0100 Subject: [PATCH 03/12] Better catch --- strawberry/experimental/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/strawberry/experimental/__init__.py b/strawberry/experimental/__init__.py index 16c7c10115..32363dedfa 100644 --- a/strawberry/experimental/__init__.py +++ b/strawberry/experimental/__init__.py @@ -1,6 +1,6 @@ try: from . import pydantic -except ImportError as e: - error = e +except ModuleNotFoundError: + pass else: __all__ = ["pydantic"] From a9d9c890d6acd352c0df38250f039283fb19bbb2 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 21:07:49 +0100 Subject: [PATCH 04/12] Fix import --- strawberry/experimental/pydantic/_compat.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 13d5a943ca..fcfe79ba05 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -36,9 +36,6 @@ def has_default(self) -> bool: class PydanticV2Compat: - def new_type_supertype(self, type_: Any) -> Any: - return type_.__supertype__ - @property def PYDANTIC_MISSING_TYPE(self) -> Any: from pydantic_core import PydanticUndefined @@ -91,9 +88,6 @@ def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField ) return new_fields - def new_type_supertype(self, type_: Any) -> Any: - return type_ - class PydanticCompat: # proxy based on v1 or v2 @@ -120,7 +114,7 @@ def __getattr__(self, name: str) -> Any: def new_type_supertype(type_: Any) -> Any: return type_.__supertype__ else: - from pydantic.typing import get_args, get_origin, is_new_type + from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype from pydantic.utils import lenient_issubclass, smart_deepcopy From 9bf8752f0c8c7d7e71cccf0af0c4507b244be34c Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 21:14:46 +0100 Subject: [PATCH 05/12] Reduce number of places that need to know about pydantic v2/v1 --- strawberry/experimental/pydantic/_compat.py | 119 +++++++++++++++++- strawberry/experimental/pydantic/fields.py | 119 +----------------- .../experimental/pydantic/object_type.py | 3 +- 3 files changed, 121 insertions(+), 120 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index fcfe79ba05..2927cb445f 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -1,10 +1,15 @@ import dataclasses from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type +from decimal import Decimal +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type +from uuid import UUID +import pydantic from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION +from strawberry.exceptions import UnsupportedTypeError + if TYPE_CHECKING: from pydantic.fields import FieldInfo @@ -25,6 +30,7 @@ class CompatModelField: has_alias: bool description: Optional[str] _missing_type: Any + is_v1: bool @property def has_default_factory(self) -> bool: @@ -60,6 +66,7 @@ def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField has_alias=field is not None, description=field.description, _missing_type=self.PYDANTIC_MISSING_TYPE, + is_v1=False, ) return new_fields @@ -85,6 +92,7 @@ def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField has_alias=field.has_alias, description=field.field_info.description, _missing_type=self.PYDANTIC_MISSING_TYPE, + is_v1=True, ) return new_fields @@ -118,6 +126,115 @@ def new_type_supertype(type_: Any) -> Any: from pydantic.utils import lenient_issubclass, smart_deepcopy +def get_basic_type(type_: Any) -> Type[Any]: + if IS_PYDANTIC_V1: + # only pydantic v1 has these + if lenient_issubclass(type_, pydantic.ConstrainedInt): + return int + if lenient_issubclass(type_, pydantic.ConstrainedFloat): + return float + if lenient_issubclass(type_, pydantic.ConstrainedStr): + return str + if lenient_issubclass(type_, pydantic.ConstrainedList): + return List[get_basic_type(type_.item_type)] # type: ignore + + if type_ in FIELDS_MAP: + type_ = FIELDS_MAP.get(type_) + if type_ is None: + raise UnsupportedTypeError() + + if is_new_type(type_): + return new_type_supertype(type_) + + return type_ + + +ATTR_TO_TYPE_MAP = { + "NoneStr": Optional[str], + "NoneBytes": Optional[bytes], + "StrBytes": None, + "NoneStrBytes": None, + "StrictStr": str, + "ConstrainedBytes": bytes, + "conbytes": bytes, + "ConstrainedStr": str, + "constr": str, + "EmailStr": str, + "PyObject": None, + "ConstrainedInt": int, + "conint": int, + "PositiveInt": int, + "NegativeInt": int, + "ConstrainedFloat": float, + "confloat": float, + "PositiveFloat": float, + "NegativeFloat": float, + "ConstrainedDecimal": Decimal, + "condecimal": Decimal, + "UUID1": UUID, + "UUID3": UUID, + "UUID4": UUID, + "UUID5": UUID, + "FilePath": None, + "DirectoryPath": None, + "Json": None, + "JsonWrapper": None, + "SecretStr": str, + "SecretBytes": bytes, + "StrictBool": bool, + "StrictInt": int, + "StrictFloat": float, + "PaymentCardNumber": None, + "ByteSize": None, + "AnyUrl": str, + "AnyHttpUrl": str, + "HttpUrl": str, + "PostgresDsn": str, + "RedisDsn": str, +} + +ATTR_TO_TYPE_MAP_Pydantic_V2 = { + "EmailStr": str, + "SecretStr": str, + "SecretBytes": bytes, + "AnyUrl": str, +} + +ATTR_TO_TYPE_MAP_Pydantic_Core_V2 = { + "MultiHostUrl": str, +} + + +def get_fields_map_for_v2() -> Dict[Any, Any]: + import pydantic_core + + fields_map = { + getattr(pydantic, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP_Pydantic_V2.items() + if hasattr(pydantic, field_name) + } + fields_map.update( + { + getattr(pydantic_core, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP_Pydantic_Core_V2.items() + if hasattr(pydantic_core, field_name) + } + ) + + return fields_map + + +FIELDS_MAP = ( + { + getattr(pydantic, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP.items() + if hasattr(pydantic, field_name) + } + if IS_PYDANTIC_V1 + else get_fields_map_for_v2() +) + + __all__ = [ "PydanticCompat", "is_new_type", diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index 6ffa6dc3d7..1c4a79ac82 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -1,23 +1,17 @@ import builtins -from decimal import Decimal -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Union from typing_extensions import Annotated -from uuid import UUID -import pydantic from pydantic import BaseModel from strawberry.experimental.pydantic._compat import ( - IS_PYDANTIC_V1, get_args, + get_basic_type, get_origin, - is_new_type, lenient_issubclass, - new_type_supertype, ) from strawberry.experimental.pydantic.exceptions import ( UnregisteredTypeException, - UnsupportedTypeError, ) from strawberry.types.types import StrawberryObjectDefinition @@ -44,115 +38,6 @@ raise -ATTR_TO_TYPE_MAP = { - "NoneStr": Optional[str], - "NoneBytes": Optional[bytes], - "StrBytes": None, - "NoneStrBytes": None, - "StrictStr": str, - "ConstrainedBytes": bytes, - "conbytes": bytes, - "ConstrainedStr": str, - "constr": str, - "EmailStr": str, - "PyObject": None, - "ConstrainedInt": int, - "conint": int, - "PositiveInt": int, - "NegativeInt": int, - "ConstrainedFloat": float, - "confloat": float, - "PositiveFloat": float, - "NegativeFloat": float, - "ConstrainedDecimal": Decimal, - "condecimal": Decimal, - "UUID1": UUID, - "UUID3": UUID, - "UUID4": UUID, - "UUID5": UUID, - "FilePath": None, - "DirectoryPath": None, - "Json": None, - "JsonWrapper": None, - "SecretStr": str, - "SecretBytes": bytes, - "StrictBool": bool, - "StrictInt": int, - "StrictFloat": float, - "PaymentCardNumber": None, - "ByteSize": None, - "AnyUrl": str, - "AnyHttpUrl": str, - "HttpUrl": str, - "PostgresDsn": str, - "RedisDsn": str, -} - -ATTR_TO_TYPE_MAP_Pydantic_V2 = { - "EmailStr": str, - "SecretStr": str, - "SecretBytes": bytes, - "AnyUrl": str, -} - -ATTR_TO_TYPE_MAP_Pydantic_Core_V2 = { - "MultiHostUrl": str, -} - - -def get_fields_map_for_v2() -> Dict[Any, Any]: - import pydantic_core - - fields_map = { - getattr(pydantic, field_name): type - for field_name, type in ATTR_TO_TYPE_MAP_Pydantic_V2.items() - if hasattr(pydantic, field_name) - } - fields_map.update( - { - getattr(pydantic_core, field_name): type - for field_name, type in ATTR_TO_TYPE_MAP_Pydantic_Core_V2.items() - if hasattr(pydantic_core, field_name) - } - ) - - return fields_map - - -FIELDS_MAP = ( - { - getattr(pydantic, field_name): type - for field_name, type in ATTR_TO_TYPE_MAP.items() - if hasattr(pydantic, field_name) - } - if IS_PYDANTIC_V1 - else get_fields_map_for_v2() -) - - -def get_basic_type(type_: Any) -> Type[Any]: - if IS_PYDANTIC_V1: - # only pydantic v1 has these - if lenient_issubclass(type_, pydantic.ConstrainedInt): - return int - if lenient_issubclass(type_, pydantic.ConstrainedFloat): - return float - if lenient_issubclass(type_, pydantic.ConstrainedStr): - return str - if lenient_issubclass(type_, pydantic.ConstrainedList): - return List[get_basic_type(type_.item_type)] # type: ignore - - if type_ in FIELDS_MAP: - type_ = FIELDS_MAP.get(type_) - if type_ is None: - raise UnsupportedTypeError() - - if is_new_type(type_): - return new_type_supertype(type_) - - return type_ - - def replace_pydantic_types(type_: Any, is_input: bool) -> Any: if lenient_issubclass(type_, BaseModel): attr = "_strawberry_input_type" if is_input else "_strawberry_type" diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index f25bf01aad..75a7a0c8fe 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -19,7 +19,6 @@ from strawberry.annotation import StrawberryAnnotation from strawberry.auto import StrawberryAuto from strawberry.experimental.pydantic._compat import ( - IS_PYDANTIC_V1, CompatModelField, PydanticCompat, ) @@ -47,7 +46,7 @@ def get_type_for_field(field: CompatModelField, is_input: bool): # noqa: ANN201 outer_type = field.outer_type_ replaced_type = replace_types_recursively(outer_type, is_input) - if IS_PYDANTIC_V1: + if field.is_v1: # only pydantic v1 has this Optional logic should_add_optional: bool = field.allow_none if should_add_optional: From 56e77cc8aa4216842ae9c23b8362efc06c851762 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 21:39:08 +0100 Subject: [PATCH 06/12] Use compat to get type --- strawberry/experimental/pydantic/_compat.py | 105 +++++++++--------- strawberry/experimental/pydantic/fields.py | 12 +- .../experimental/pydantic/object_type.py | 16 ++- 3 files changed, 72 insertions(+), 61 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 2927cb445f..4bd713147f 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION -from strawberry.exceptions import UnsupportedTypeError +from strawberry.experimental.pydantic.exceptions import UnsupportedTypeError if TYPE_CHECKING: from pydantic.fields import FieldInfo @@ -97,58 +97,6 @@ def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField return new_fields -class PydanticCompat: - # proxy based on v1 or v2 - def __init__(self): - if IS_PYDANTIC_V2: - self._compat = PydanticV2Compat() - else: - self._compat = PydanticV1Compat() - - @classmethod - def from_model(cls, model: Type[BaseModel]) -> "PydanticCompat": - return cls() - - def __getattr__(self, name: str) -> Any: - return getattr(self._compat, name) - - -if IS_PYDANTIC_V2: - from typing_extensions import get_args, get_origin - - from pydantic._internal._typing_extra import is_new_type - from pydantic._internal._utils import lenient_issubclass, smart_deepcopy - - def new_type_supertype(type_: Any) -> Any: - return type_.__supertype__ -else: - from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype - from pydantic.utils import lenient_issubclass, smart_deepcopy - - -def get_basic_type(type_: Any) -> Type[Any]: - if IS_PYDANTIC_V1: - # only pydantic v1 has these - if lenient_issubclass(type_, pydantic.ConstrainedInt): - return int - if lenient_issubclass(type_, pydantic.ConstrainedFloat): - return float - if lenient_issubclass(type_, pydantic.ConstrainedStr): - return str - if lenient_issubclass(type_, pydantic.ConstrainedList): - return List[get_basic_type(type_.item_type)] # type: ignore - - if type_ in FIELDS_MAP: - type_ = FIELDS_MAP.get(type_) - if type_ is None: - raise UnsupportedTypeError() - - if is_new_type(type_): - return new_type_supertype(type_) - - return type_ - - ATTR_TO_TYPE_MAP = { "NoneStr": Optional[str], "NoneBytes": Optional[bytes], @@ -235,6 +183,57 @@ def get_fields_map_for_v2() -> Dict[Any, Any]: ) +class PydanticCompat: + # proxy based on v1 or v2 + def __init__(self): + if IS_PYDANTIC_V2: + self._compat = PydanticV2Compat() + else: + self._compat = PydanticV1Compat() + + @classmethod + def from_model(cls, model: Type[BaseModel]) -> "PydanticCompat": + return cls() + + def __getattr__(self, name: str) -> Any: + return getattr(self._compat, name) + + def get_basic_type(self, type_: Any) -> Type[Any]: + if IS_PYDANTIC_V1: + # only pydantic v1 has these + if lenient_issubclass(type_, pydantic.ConstrainedInt): + return int + if lenient_issubclass(type_, pydantic.ConstrainedFloat): + return float + if lenient_issubclass(type_, pydantic.ConstrainedStr): + return str + if lenient_issubclass(type_, pydantic.ConstrainedList): + return List[self.get_basic_type(type_.item_type)] # type: ignore + + if type_ in FIELDS_MAP: + type_ = FIELDS_MAP.get(type_) + if type_ is None: + raise UnsupportedTypeError() + + if is_new_type(type_): + return new_type_supertype(type_) + + return type_ + + +if IS_PYDANTIC_V2: + from typing_extensions import get_args, get_origin + + from pydantic._internal._typing_extra import is_new_type + from pydantic._internal._utils import lenient_issubclass, smart_deepcopy + + def new_type_supertype(type_: Any) -> Any: + return type_.__supertype__ +else: + from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype + from pydantic.utils import lenient_issubclass, smart_deepcopy + + __all__ = [ "PydanticCompat", "is_new_type", diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index 1c4a79ac82..0d4183f59e 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -5,8 +5,8 @@ from pydantic import BaseModel from strawberry.experimental.pydantic._compat import ( + PydanticCompat, get_args, - get_basic_type, get_origin, lenient_issubclass, ) @@ -48,17 +48,21 @@ def replace_pydantic_types(type_: Any, is_input: bool) -> Any: return type_ -def replace_types_recursively(type_: Any, is_input: bool) -> Any: +def replace_types_recursively( + type_: Any, is_input: bool, compat: PydanticCompat +) -> Any: """Runs the conversions recursively into the arguments of generic types if any""" - basic_type = get_basic_type(type_) + basic_type = compat.get_basic_type(type_) replaced_type = replace_pydantic_types(basic_type, is_input) origin = get_origin(type_) + if not origin or not hasattr(type_, "__args__"): return replaced_type converted = tuple( - replace_types_recursively(t, is_input=is_input) for t in get_args(replaced_type) + replace_types_recursively(t, is_input=is_input, compat=compat) + for t in get_args(replaced_type) ) if isinstance(replaced_type, TypingGenericAlias): diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index 75a7a0c8fe..9bc4ee5e97 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -43,9 +43,11 @@ from graphql import GraphQLResolveInfo -def get_type_for_field(field: CompatModelField, is_input: bool): # noqa: ANN201 +def get_type_for_field(field: CompatModelField, is_input: bool, compat: PydanticCompat): # noqa: ANN201 outer_type = field.outer_type_ - replaced_type = replace_types_recursively(outer_type, is_input) + + replaced_type = replace_types_recursively(outer_type, is_input, compat=compat) + if field.is_v1: # only pydantic v1 has this Optional logic should_add_optional: bool = field.allow_none @@ -61,9 +63,10 @@ def _build_dataclass_creation_fields( existing_fields: Dict[str, StrawberryField], auto_fields_set: Set[str], use_pydantic_alias: bool, + compat: PydanticCompat, ) -> DataclassCreationFields: field_type = ( - get_type_for_field(field, is_input) + get_type_for_field(field, is_input, compat=compat) if field.name in auto_fields_set else existing_fields[field.name].type ) @@ -182,7 +185,12 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: all_model_fields: List[DataclassCreationFields] = [ _build_dataclass_creation_fields( - field, is_input, extra_fields_dict, auto_fields_set, use_pydantic_alias + field, + is_input, + extra_fields_dict, + auto_fields_set, + use_pydantic_alias, + compat=compat, ) for field_name, field in model_fields.items() if field_name in fields_set From 2f02448823f75a7ca30748da1fddd6ac8e75b141 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 21:44:57 +0100 Subject: [PATCH 07/12] Move fields map inside v1/v2 compat --- strawberry/experimental/pydantic/_compat.py | 173 +++++++++++--------- 1 file changed, 94 insertions(+), 79 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 4bd713147f..192f2c2a13 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -1,6 +1,7 @@ import dataclasses from dataclasses import dataclass from decimal import Decimal +from functools import cached_property from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type from uuid import UUID @@ -41,62 +42,6 @@ def has_default(self) -> bool: return self.default is not self._missing_type -class PydanticV2Compat: - @property - def PYDANTIC_MISSING_TYPE(self) -> Any: - from pydantic_core import PydanticUndefined - - return PydanticUndefined - - def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: - field_info: dict[str, FieldInfo] = model.model_fields - new_fields = {} - # Convert it into CompatModelField - for name, field in field_info.items(): - new_fields[name] = CompatModelField( - name=name, - type_=field.annotation, - outer_type_=field.annotation, - default=field.default, - default_factory=field.default_factory, - required=field.is_required(), - alias=field.alias, - # v2 doesn't have allow_none - allow_none=False, - has_alias=field is not None, - description=field.description, - _missing_type=self.PYDANTIC_MISSING_TYPE, - is_v1=False, - ) - return new_fields - - -class PydanticV1Compat: - @property - def PYDANTIC_MISSING_TYPE(self) -> Any: - return dataclasses.MISSING - - def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: - new_fields = {} - # Convert it into CompatModelField - for name, field in model.__fields__.items(): # type: ignore[attr-defined] - new_fields[name] = CompatModelField( - name=name, - type_=field.type_, - outer_type_=field.outer_type_, - default=field.default, - default_factory=field.default_factory, - required=field.required, - alias=field.alias, - allow_none=field.allow_none, - has_alias=field.has_alias, - description=field.field_info.description, - _missing_type=self.PYDANTIC_MISSING_TYPE, - is_v1=True, - ) - return new_fields - - ATTR_TO_TYPE_MAP = { "NoneStr": Optional[str], "NoneBytes": Optional[bytes], @@ -172,31 +117,84 @@ def get_fields_map_for_v2() -> Dict[Any, Any]: return fields_map -FIELDS_MAP = ( - { - getattr(pydantic, field_name): type - for field_name, type in ATTR_TO_TYPE_MAP.items() - if hasattr(pydantic, field_name) - } - if IS_PYDANTIC_V1 - else get_fields_map_for_v2() -) +class PydanticV2Compat: + @property + def PYDANTIC_MISSING_TYPE(self) -> Any: + from pydantic_core import PydanticUndefined + return PydanticUndefined -class PydanticCompat: - # proxy based on v1 or v2 - def __init__(self): - if IS_PYDANTIC_V2: - self._compat = PydanticV2Compat() - else: - self._compat = PydanticV1Compat() + def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: + field_info: dict[str, FieldInfo] = model.model_fields + new_fields = {} + # Convert it into CompatModelField + for name, field in field_info.items(): + new_fields[name] = CompatModelField( + name=name, + type_=field.annotation, + outer_type_=field.annotation, + default=field.default, + default_factory=field.default_factory, + required=field.is_required(), + alias=field.alias, + # v2 doesn't have allow_none + allow_none=False, + has_alias=field is not None, + description=field.description, + _missing_type=self.PYDANTIC_MISSING_TYPE, + is_v1=False, + ) + return new_fields - @classmethod - def from_model(cls, model: Type[BaseModel]) -> "PydanticCompat": - return cls() + @cached_property + def fields_map(self) -> Dict[Any, Any]: + return get_fields_map_for_v2() - def __getattr__(self, name: str) -> Any: - return getattr(self._compat, name) + def get_basic_type(self, type_: Any) -> Type[Any]: + if type_ in self.fields_map: + type_ = self.fields_map[type_] + + if type_ is None: + raise UnsupportedTypeError() + + if is_new_type(type_): + return new_type_supertype(type_) + + return type_ + + +class PydanticV1Compat: + @property + def PYDANTIC_MISSING_TYPE(self) -> Any: + return dataclasses.MISSING + + def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: + new_fields = {} + # Convert it into CompatModelField + for name, field in model.__fields__.items(): # type: ignore[attr-defined] + new_fields[name] = CompatModelField( + name=name, + type_=field.type_, + outer_type_=field.outer_type_, + default=field.default, + default_factory=field.default_factory, + required=field.required, + alias=field.alias, + allow_none=field.allow_none, + has_alias=field.has_alias, + description=field.field_info.description, + _missing_type=self.PYDANTIC_MISSING_TYPE, + is_v1=True, + ) + return new_fields + + @cached_property + def fields_map(self) -> Dict[Any, Any]: + return { + getattr(pydantic, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP.items() + if hasattr(pydantic, field_name) + } def get_basic_type(self, type_: Any) -> Type[Any]: if IS_PYDANTIC_V1: @@ -210,8 +208,9 @@ def get_basic_type(self, type_: Any) -> Type[Any]: if lenient_issubclass(type_, pydantic.ConstrainedList): return List[self.get_basic_type(type_.item_type)] # type: ignore - if type_ in FIELDS_MAP: - type_ = FIELDS_MAP.get(type_) + if type_ in self.fields_map: + type_ = self.fields_map[type_] + if type_ is None: raise UnsupportedTypeError() @@ -221,6 +220,22 @@ def get_basic_type(self, type_: Any) -> Type[Any]: return type_ +class PydanticCompat: + # proxy based on v1 or v2 + def __init__(self): + if IS_PYDANTIC_V2: + self._compat = PydanticV2Compat() + else: + self._compat = PydanticV1Compat() + + @classmethod + def from_model(cls, model: Type[BaseModel]) -> "PydanticCompat": + return cls() + + def __getattr__(self, name: str) -> Any: + return getattr(self._compat, name) + + if IS_PYDANTIC_V2: from typing_extensions import get_args, get_origin From a03fc946c97fc84bf8ae9b6877b5ccd8d3f9f7d7 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 22:22:08 +0100 Subject: [PATCH 08/12] V1 and V2 supposedly working --- strawberry/experimental/pydantic/_compat.py | 17 +++++-- .../pydantic/schema/test_1_and_2.py | 51 ++++++++++++------- 2 files changed, 46 insertions(+), 22 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 192f2c2a13..093592c7ee 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -190,6 +190,13 @@ def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField @cached_property def fields_map(self) -> Dict[Any, Any]: + if IS_PYDANTIC_V2: + return { + getattr(pydantic.v1, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP.items() + if hasattr(pydantic.v1, field_name) + } + return { getattr(pydantic, field_name): type for field_name, type in ATTR_TO_TYPE_MAP.items() @@ -221,16 +228,18 @@ def get_basic_type(self, type_: Any) -> Type[Any]: class PydanticCompat: - # proxy based on v1 or v2 - def __init__(self): - if IS_PYDANTIC_V2: + def __init__(self, is_v2: bool): + if is_v2: self._compat = PydanticV2Compat() else: self._compat = PydanticV1Compat() @classmethod def from_model(cls, model: Type[BaseModel]) -> "PydanticCompat": - return cls() + if hasattr(model, "model_fields"): + return cls(is_v2=True) + + return cls(is_v2=False) def __getattr__(self, name: str) -> Any: return getattr(self._compat, name) diff --git a/tests/experimental/pydantic/schema/test_1_and_2.py b/tests/experimental/pydantic/schema/test_1_and_2.py index b97836aaf0..554d891717 100644 --- a/tests/experimental/pydantic/schema/test_1_and_2.py +++ b/tests/experimental/pydantic/schema/test_1_and_2.py @@ -1,65 +1,80 @@ import textwrap from typing import Optional, Union -import pytest - import strawberry from tests.experimental.pydantic.utils import needs_pydantic_v2 @needs_pydantic_v2 -@pytest.mark.xfail def test_can_use_both_pydantic_1_and_2(): import pydantic from pydantic import v1 as pydantic_v1 class UserModel(pydantic.BaseModel): age: int - password: Optional[str] + name: Optional[str] @strawberry.experimental.pydantic.type(UserModel) class User: age: strawberry.auto - password: strawberry.auto + name: strawberry.auto class LegacyUserModel(pydantic_v1.BaseModel): age: int - password: Optional[str] + name: Optional[str] @strawberry.experimental.pydantic.type(LegacyUserModel) class LegacyUser: age: strawberry.auto - password: strawberry.auto + name: strawberry.auto @strawberry.type class Query: @strawberry.field - def user(self) -> Union[User, LegacyUser]: - return User(age=1, password="ABC") + def user(self, id: strawberry.ID) -> Union[User, LegacyUser]: + if id == "legacy": + return LegacyUser(age=1, name="legacy") + + return User(age=1, name="ABC") schema = strawberry.Schema(query=Query) expected_schema = """ + type LegacyUser { + age: Int! + name: String + } + type Query { - user: User! + user(id: ID!): UserLegacyUser! } type User { age: Int! - password: String + name: String } - type LegacyUser { - age: Int! - password: String - } + union UserLegacyUser = User | LegacyUser """ assert str(schema) == textwrap.dedent(expected_schema).strip() - query = "{ user { age } }" + query = """ + query ($id: ID!) { + user(id: $id) { + __typename + ... on User { name } + ... on LegacyUser { name } + } + } + """ + + result = schema.execute_sync(query, variable_values={"id": "new"}) + + assert not result.errors + assert result.data == {"user": {"__typename": "User", "name": "ABC"}} - result = schema.execute_sync(query) + result = schema.execute_sync(query, variable_values={"id": "legacy"}) assert not result.errors - assert result.data["user"]["age"] == 1 + assert result.data == {"user": {"__typename": "LegacyUser", "name": "legacy"}} From 2a4739da2c0ede8bdd25b1a187d0a377820a9a31 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 22:26:12 +0100 Subject: [PATCH 09/12] Fix mypy --- strawberry/experimental/pydantic/_compat.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 093592c7ee..cc4855669e 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -232,7 +232,7 @@ def __init__(self, is_v2: bool): if is_v2: self._compat = PydanticV2Compat() else: - self._compat = PydanticV1Compat() + self._compat = PydanticV1Compat() # type: ignore[assignment] @classmethod def from_model(cls, model: Type[BaseModel]) -> "PydanticCompat": @@ -254,8 +254,16 @@ def __getattr__(self, name: str) -> Any: def new_type_supertype(type_: Any) -> Any: return type_.__supertype__ else: - from pydantic.typing import get_args, get_origin, is_new_type, new_type_supertype - from pydantic.utils import lenient_issubclass, smart_deepcopy + from pydantic.typing import ( # type: ignore[no-redef] + get_args, + get_origin, + is_new_type, + new_type_supertype, + ) + from pydantic.utils import ( # type: ignore[no-redef] + lenient_issubclass, + smart_deepcopy, + ) __all__ = [ From fd7e83a5316e355a9a665f603e894073a595eb3f Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Fri, 29 Mar 2024 22:32:16 +0100 Subject: [PATCH 10/12] Add release file --- RELEASE.md | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 RELEASE.md diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..9377ce2161 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,6 @@ +Release type: minor + +This release adds support for using both Pydantic v1 and v2, when importing from +`pydantic.v1`. + +This is automatically detected and the correct version is used. From 1bc7055a27a9bbc283d8dcd15c038b72b58b08b3 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 30 Mar 2024 13:36:35 +0100 Subject: [PATCH 11/12] Fix subclasses check --- strawberry/experimental/pydantic/_compat.py | 27 ++++++++++++------- .../pydantic/schema/test_1_and_2.py | 3 +++ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index cc4855669e..9eecf28cd7 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -205,15 +205,24 @@ def fields_map(self) -> Dict[Any, Any]: def get_basic_type(self, type_: Any) -> Type[Any]: if IS_PYDANTIC_V1: - # only pydantic v1 has these - if lenient_issubclass(type_, pydantic.ConstrainedInt): - return int - if lenient_issubclass(type_, pydantic.ConstrainedFloat): - return float - if lenient_issubclass(type_, pydantic.ConstrainedStr): - return str - if lenient_issubclass(type_, pydantic.ConstrainedList): - return List[self.get_basic_type(type_.item_type)] # type: ignore + ConstrainedInt = pydantic.ConstrainedInt + ConstrainedFloat = pydantic.ConstrainedFloat + ConstrainedStr = pydantic.ConstrainedStr + ConstrainedList = pydantic.ConstrainedList + else: + ConstrainedInt = pydantic.v1.ConstrainedInt + ConstrainedFloat = pydantic.v1.ConstrainedFloat + ConstrainedStr = pydantic.v1.ConstrainedStr + ConstrainedList = pydantic.v1.ConstrainedList + + if lenient_issubclass(type_, ConstrainedInt): + return int + if lenient_issubclass(type_, ConstrainedFloat): + return float + if lenient_issubclass(type_, ConstrainedStr): + return str + if lenient_issubclass(type_, ConstrainedList): + return List[self.get_basic_type(type_.item_type)] # type: ignore if type_ in self.fields_map: type_ = self.fields_map[type_] diff --git a/tests/experimental/pydantic/schema/test_1_and_2.py b/tests/experimental/pydantic/schema/test_1_and_2.py index 554d891717..5e8b3fea72 100644 --- a/tests/experimental/pydantic/schema/test_1_and_2.py +++ b/tests/experimental/pydantic/schema/test_1_and_2.py @@ -22,11 +22,13 @@ class User: class LegacyUserModel(pydantic_v1.BaseModel): age: int name: Optional[str] + int_field: pydantic.v1.NonNegativeInt = 1 @strawberry.experimental.pydantic.type(LegacyUserModel) class LegacyUser: age: strawberry.auto name: strawberry.auto + int_field: strawberry.auto @strawberry.type class Query: @@ -43,6 +45,7 @@ def user(self, id: strawberry.ID) -> Union[User, LegacyUser]: type LegacyUser { age: Int! name: String + intField: Int! } type Query { From 4a1a498a62568c132fbd2fee9bd0d2a72cbebee6 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Sat, 30 Mar 2024 13:48:19 +0100 Subject: [PATCH 12/12] Unfail test --- tests/experimental/pydantic/test_conversion.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index 83d3fed016..9a4ce2ddf3 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -12,6 +12,8 @@ from strawberry.experimental.pydantic._compat import ( IS_PYDANTIC_V2, CompatModelField, + PydanticV1Compat, + PydanticV2Compat, ) from strawberry.experimental.pydantic.exceptions import ( AutoFieldsNotInBaseModelError, @@ -840,11 +842,15 @@ class UserType: assert user.passwords == ["hunter2"] -@pytest.mark.xfail def test_get_default_factory_for_field(): + if IS_PYDANTIC_V2: + MISSING_TYPE = PydanticV2Compat().PYDANTIC_MISSING_TYPE + else: + MISSING_TYPE = PydanticV1Compat().PYDANTIC_MISSING_TYPE + def _get_field( - default: Any = PYDANTIC_MISSING_TYPE, - default_factory: Any = PYDANTIC_MISSING_TYPE, + default: Any = MISSING_TYPE, + default_factory: Any = MISSING_TYPE, ) -> CompatModelField: return CompatModelField( name="a", @@ -857,6 +863,8 @@ def _get_field( description="", has_alias=False, required=True, + _missing_type=MISSING_TYPE, + is_v1=not IS_PYDANTIC_V2, ) field = _get_field()