diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..9377ce2161 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,6 @@ +Release type: minor + +This release adds support for using both Pydantic v1 and v2, when importing from +`pydantic.v1`. + +This is automatically detected and the correct version is used. diff --git a/noxfile.py b/noxfile.py index 12f420751e..64d437c287 100644 --- a/noxfile.py +++ b/noxfile.py @@ -107,7 +107,7 @@ def tests_integrations(session: Session, integration: str) -> None: session.run("pytest", *COMMON_PYTEST_OPTIONS, "-m", integration) -@session(python=PYTHON_VERSIONS, name="Pydantic tests", tags=["tests"]) +@session(python=PYTHON_VERSIONS, name="Pydantic tests", tags=["tests", "pydantic"]) @nox.parametrize("pydantic", ["1.10", "2.0.3"]) def test_pydantic(session: Session, pydantic: str) -> None: session.run_always("poetry", "install", external=True) diff --git a/strawberry/experimental/__init__.py b/strawberry/experimental/__init__.py index 6386ad81d7..32363dedfa 100644 --- a/strawberry/experimental/__init__.py +++ b/strawberry/experimental/__init__.py @@ -1,6 +1,6 @@ try: from . import pydantic -except ImportError: +except ModuleNotFoundError: pass else: __all__ = ["pydantic"] diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 79d919af23..9eecf28cd7 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -1,10 +1,16 @@ import dataclasses from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type +from decimal import Decimal +from functools import cached_property +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Type +from uuid import UUID +import pydantic from pydantic import BaseModel from pydantic.version import VERSION as PYDANTIC_VERSION +from strawberry.experimental.pydantic.exceptions import UnsupportedTypeError + if TYPE_CHECKING: from pydantic.fields import FieldInfo @@ -24,21 +30,101 @@ class CompatModelField: allow_none: bool has_alias: bool description: Optional[str] + _missing_type: Any + is_v1: bool + @property + def has_default_factory(self) -> bool: + return self.default_factory is not self._missing_type -if IS_PYDANTIC_V2: - from typing_extensions import get_args, get_origin + @property + def has_default(self) -> bool: + return self.default is not self._missing_type - from pydantic._internal._typing_extra import is_new_type - from pydantic._internal._utils import lenient_issubclass, smart_deepcopy - from pydantic_core import PydanticUndefined - PYDANTIC_MISSING_TYPE = PydanticUndefined +ATTR_TO_TYPE_MAP = { + "NoneStr": Optional[str], + "NoneBytes": Optional[bytes], + "StrBytes": None, + "NoneStrBytes": None, + "StrictStr": str, + "ConstrainedBytes": bytes, + "conbytes": bytes, + "ConstrainedStr": str, + "constr": str, + "EmailStr": str, + "PyObject": None, + "ConstrainedInt": int, + "conint": int, + "PositiveInt": int, + "NegativeInt": int, + "ConstrainedFloat": float, + "confloat": float, + "PositiveFloat": float, + "NegativeFloat": float, + "ConstrainedDecimal": Decimal, + "condecimal": Decimal, + "UUID1": UUID, + "UUID3": UUID, + "UUID4": UUID, + "UUID5": UUID, + "FilePath": None, + "DirectoryPath": None, + "Json": None, + "JsonWrapper": None, + "SecretStr": str, + "SecretBytes": bytes, + "StrictBool": bool, + "StrictInt": int, + "StrictFloat": float, + "PaymentCardNumber": None, + "ByteSize": None, + "AnyUrl": str, + "AnyHttpUrl": str, + "HttpUrl": str, + "PostgresDsn": str, + "RedisDsn": str, +} + +ATTR_TO_TYPE_MAP_Pydantic_V2 = { + "EmailStr": str, + "SecretStr": str, + "SecretBytes": bytes, + "AnyUrl": str, +} + +ATTR_TO_TYPE_MAP_Pydantic_Core_V2 = { + "MultiHostUrl": str, +} - def new_type_supertype(type_: Any) -> Any: - return type_.__supertype__ - def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]: +def get_fields_map_for_v2() -> Dict[Any, Any]: + import pydantic_core + + fields_map = { + getattr(pydantic, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP_Pydantic_V2.items() + if hasattr(pydantic, field_name) + } + fields_map.update( + { + getattr(pydantic_core, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP_Pydantic_Core_V2.items() + if hasattr(pydantic_core, field_name) + } + ) + + return fields_map + + +class PydanticV2Compat: + @property + def PYDANTIC_MISSING_TYPE(self) -> Any: + from pydantic_core import PydanticUndefined + + return PydanticUndefined + + def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: field_info: dict[str, FieldInfo] = model.model_fields new_fields = {} # Convert it into CompatModelField @@ -55,24 +141,34 @@ def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]: allow_none=False, has_alias=field is not None, description=field.description, + _missing_type=self.PYDANTIC_MISSING_TYPE, + is_v1=False, ) return new_fields -else: - from pydantic.typing import ( # type: ignore[no-redef] - get_args, - get_origin, - is_new_type, - new_type_supertype, - ) - from pydantic.utils import ( # type: ignore[no-redef] - lenient_issubclass, - smart_deepcopy, - ) + @cached_property + def fields_map(self) -> Dict[Any, Any]: + return get_fields_map_for_v2() - PYDANTIC_MISSING_TYPE = dataclasses.MISSING # type: ignore[assignment] + def get_basic_type(self, type_: Any) -> Type[Any]: + if type_ in self.fields_map: + type_ = self.fields_map[type_] - def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]: + if type_ is None: + raise UnsupportedTypeError() + + if is_new_type(type_): + return new_type_supertype(type_) + + return type_ + + +class PydanticV1Compat: + @property + def PYDANTIC_MISSING_TYPE(self) -> Any: + return dataclasses.MISSING + + def get_model_fields(self, model: Type[BaseModel]) -> Dict[str, CompatModelField]: new_fields = {} # Convert it into CompatModelField for name, field in model.__fields__.items(): # type: ignore[attr-defined] @@ -87,17 +183,104 @@ def get_model_fields(model: Type[BaseModel]) -> Dict[str, CompatModelField]: allow_none=field.allow_none, has_alias=field.has_alias, description=field.field_info.description, + _missing_type=self.PYDANTIC_MISSING_TYPE, + is_v1=True, ) return new_fields + @cached_property + def fields_map(self) -> Dict[Any, Any]: + if IS_PYDANTIC_V2: + return { + getattr(pydantic.v1, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP.items() + if hasattr(pydantic.v1, field_name) + } + + return { + getattr(pydantic, field_name): type + for field_name, type in ATTR_TO_TYPE_MAP.items() + if hasattr(pydantic, field_name) + } + + def get_basic_type(self, type_: Any) -> Type[Any]: + if IS_PYDANTIC_V1: + ConstrainedInt = pydantic.ConstrainedInt + ConstrainedFloat = pydantic.ConstrainedFloat + ConstrainedStr = pydantic.ConstrainedStr + ConstrainedList = pydantic.ConstrainedList + else: + ConstrainedInt = pydantic.v1.ConstrainedInt + ConstrainedFloat = pydantic.v1.ConstrainedFloat + ConstrainedStr = pydantic.v1.ConstrainedStr + ConstrainedList = pydantic.v1.ConstrainedList + + if lenient_issubclass(type_, ConstrainedInt): + return int + if lenient_issubclass(type_, ConstrainedFloat): + return float + if lenient_issubclass(type_, ConstrainedStr): + return str + if lenient_issubclass(type_, ConstrainedList): + return List[self.get_basic_type(type_.item_type)] # type: ignore + + if type_ in self.fields_map: + type_ = self.fields_map[type_] + + if type_ is None: + raise UnsupportedTypeError() + + if is_new_type(type_): + return new_type_supertype(type_) + + return type_ + + +class PydanticCompat: + def __init__(self, is_v2: bool): + if is_v2: + self._compat = PydanticV2Compat() + else: + self._compat = PydanticV1Compat() # type: ignore[assignment] + + @classmethod + def from_model(cls, model: Type[BaseModel]) -> "PydanticCompat": + if hasattr(model, "model_fields"): + return cls(is_v2=True) + + return cls(is_v2=False) + + def __getattr__(self, name: str) -> Any: + return getattr(self._compat, name) + + +if IS_PYDANTIC_V2: + from typing_extensions import get_args, get_origin + + from pydantic._internal._typing_extra import is_new_type + from pydantic._internal._utils import lenient_issubclass, smart_deepcopy + + def new_type_supertype(type_: Any) -> Any: + return type_.__supertype__ +else: + from pydantic.typing import ( # type: ignore[no-redef] + get_args, + get_origin, + is_new_type, + new_type_supertype, + ) + from pydantic.utils import ( # type: ignore[no-redef] + lenient_issubclass, + smart_deepcopy, + ) + __all__ = [ - "smart_deepcopy", + "PydanticCompat", + "is_new_type", "lenient_issubclass", - "get_args", "get_origin", - "is_new_type", + "get_args", "new_type_supertype", - "get_model_fields", - "PYDANTIC_MISSING_TYPE", + "smart_deepcopy", ] diff --git a/strawberry/experimental/pydantic/error_type.py b/strawberry/experimental/pydantic/error_type.py index adcebdc2f7..73eab1c6fb 100644 --- a/strawberry/experimental/pydantic/error_type.py +++ b/strawberry/experimental/pydantic/error_type.py @@ -19,7 +19,7 @@ from strawberry.auto import StrawberryAuto from strawberry.experimental.pydantic._compat import ( CompatModelField, - get_model_fields, + PydanticCompat, lenient_issubclass, ) from strawberry.experimental.pydantic.utils import ( @@ -72,7 +72,8 @@ def error_type( all_fields: bool = False, ) -> Callable[..., Type]: def wrap(cls: Type) -> Type: - model_fields = get_model_fields(model) + compat = PydanticCompat.from_model(model) + model_fields = compat.get_model_fields(model) fields_set = set(fields) if fields else set() if fields: diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index 6ffa6dc3d7..0d4183f59e 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -1,23 +1,17 @@ import builtins -from decimal import Decimal -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Union from typing_extensions import Annotated -from uuid import UUID -import pydantic from pydantic import BaseModel from strawberry.experimental.pydantic._compat import ( - IS_PYDANTIC_V1, + PydanticCompat, get_args, get_origin, - is_new_type, lenient_issubclass, - new_type_supertype, ) from strawberry.experimental.pydantic.exceptions import ( UnregisteredTypeException, - UnsupportedTypeError, ) from strawberry.types.types import StrawberryObjectDefinition @@ -44,115 +38,6 @@ raise -ATTR_TO_TYPE_MAP = { - "NoneStr": Optional[str], - "NoneBytes": Optional[bytes], - "StrBytes": None, - "NoneStrBytes": None, - "StrictStr": str, - "ConstrainedBytes": bytes, - "conbytes": bytes, - "ConstrainedStr": str, - "constr": str, - "EmailStr": str, - "PyObject": None, - "ConstrainedInt": int, - "conint": int, - "PositiveInt": int, - "NegativeInt": int, - "ConstrainedFloat": float, - "confloat": float, - "PositiveFloat": float, - "NegativeFloat": float, - "ConstrainedDecimal": Decimal, - "condecimal": Decimal, - "UUID1": UUID, - "UUID3": UUID, - "UUID4": UUID, - "UUID5": UUID, - "FilePath": None, - "DirectoryPath": None, - "Json": None, - "JsonWrapper": None, - "SecretStr": str, - "SecretBytes": bytes, - "StrictBool": bool, - "StrictInt": int, - "StrictFloat": float, - "PaymentCardNumber": None, - "ByteSize": None, - "AnyUrl": str, - "AnyHttpUrl": str, - "HttpUrl": str, - "PostgresDsn": str, - "RedisDsn": str, -} - -ATTR_TO_TYPE_MAP_Pydantic_V2 = { - "EmailStr": str, - "SecretStr": str, - "SecretBytes": bytes, - "AnyUrl": str, -} - -ATTR_TO_TYPE_MAP_Pydantic_Core_V2 = { - "MultiHostUrl": str, -} - - -def get_fields_map_for_v2() -> Dict[Any, Any]: - import pydantic_core - - fields_map = { - getattr(pydantic, field_name): type - for field_name, type in ATTR_TO_TYPE_MAP_Pydantic_V2.items() - if hasattr(pydantic, field_name) - } - fields_map.update( - { - getattr(pydantic_core, field_name): type - for field_name, type in ATTR_TO_TYPE_MAP_Pydantic_Core_V2.items() - if hasattr(pydantic_core, field_name) - } - ) - - return fields_map - - -FIELDS_MAP = ( - { - getattr(pydantic, field_name): type - for field_name, type in ATTR_TO_TYPE_MAP.items() - if hasattr(pydantic, field_name) - } - if IS_PYDANTIC_V1 - else get_fields_map_for_v2() -) - - -def get_basic_type(type_: Any) -> Type[Any]: - if IS_PYDANTIC_V1: - # only pydantic v1 has these - if lenient_issubclass(type_, pydantic.ConstrainedInt): - return int - if lenient_issubclass(type_, pydantic.ConstrainedFloat): - return float - if lenient_issubclass(type_, pydantic.ConstrainedStr): - return str - if lenient_issubclass(type_, pydantic.ConstrainedList): - return List[get_basic_type(type_.item_type)] # type: ignore - - if type_ in FIELDS_MAP: - type_ = FIELDS_MAP.get(type_) - if type_ is None: - raise UnsupportedTypeError() - - if is_new_type(type_): - return new_type_supertype(type_) - - return type_ - - def replace_pydantic_types(type_: Any, is_input: bool) -> Any: if lenient_issubclass(type_, BaseModel): attr = "_strawberry_input_type" if is_input else "_strawberry_type" @@ -163,17 +48,21 @@ def replace_pydantic_types(type_: Any, is_input: bool) -> Any: return type_ -def replace_types_recursively(type_: Any, is_input: bool) -> Any: +def replace_types_recursively( + type_: Any, is_input: bool, compat: PydanticCompat +) -> Any: """Runs the conversions recursively into the arguments of generic types if any""" - basic_type = get_basic_type(type_) + basic_type = compat.get_basic_type(type_) replaced_type = replace_pydantic_types(basic_type, is_input) origin = get_origin(type_) + if not origin or not hasattr(type_, "__args__"): return replaced_type converted = tuple( - replace_types_recursively(t, is_input=is_input) for t in get_args(replaced_type) + replace_types_recursively(t, is_input=is_input, compat=compat) + for t in get_args(replaced_type) ) if isinstance(replaced_type, TypingGenericAlias): diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index f7e87ed6ef..9bc4ee5e97 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -19,9 +19,8 @@ from strawberry.annotation import StrawberryAnnotation from strawberry.auto import StrawberryAuto from strawberry.experimental.pydantic._compat import ( - IS_PYDANTIC_V1, CompatModelField, - get_model_fields, + PydanticCompat, ) from strawberry.experimental.pydantic.conversion import ( convert_pydantic_model_to_strawberry_class, @@ -44,10 +43,12 @@ from graphql import GraphQLResolveInfo -def get_type_for_field(field: CompatModelField, is_input: bool): # noqa: ANN201 +def get_type_for_field(field: CompatModelField, is_input: bool, compat: PydanticCompat): # noqa: ANN201 outer_type = field.outer_type_ - replaced_type = replace_types_recursively(outer_type, is_input) - if IS_PYDANTIC_V1: + + replaced_type = replace_types_recursively(outer_type, is_input, compat=compat) + + if field.is_v1: # only pydantic v1 has this Optional logic should_add_optional: bool = field.allow_none if should_add_optional: @@ -62,9 +63,10 @@ def _build_dataclass_creation_fields( existing_fields: Dict[str, StrawberryField], auto_fields_set: Set[str], use_pydantic_alias: bool, + compat: PydanticCompat, ) -> DataclassCreationFields: field_type = ( - get_type_for_field(field, is_input) + get_type_for_field(field, is_input, compat=compat) if field.name in auto_fields_set else existing_fields[field.name].type ) @@ -129,7 +131,8 @@ def type( use_pydantic_alias: bool = True, ) -> Callable[..., Type[StrawberryTypeFromPydantic[PydanticModel]]]: def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: - model_fields = get_model_fields(model) + compat = PydanticCompat.from_model(model) + model_fields = compat.get_model_fields(model) original_fields_set = set(fields) if fields else set() if fields: @@ -182,7 +185,12 @@ def wrap(cls: Any) -> Type[StrawberryTypeFromPydantic[PydanticModel]]: all_model_fields: List[DataclassCreationFields] = [ _build_dataclass_creation_fields( - field, is_input, extra_fields_dict, auto_fields_set, use_pydantic_alias + field, + is_input, + extra_fields_dict, + auto_fields_set, + use_pydantic_alias, + compat=compat, ) for field_name, field in model_fields.items() if field_name in fields_set diff --git a/strawberry/experimental/pydantic/utils.py b/strawberry/experimental/pydantic/utils.py index 4f8629a0fc..79b8122ec9 100644 --- a/strawberry/experimental/pydantic/utils.py +++ b/strawberry/experimental/pydantic/utils.py @@ -14,9 +14,8 @@ ) from strawberry.experimental.pydantic._compat import ( - PYDANTIC_MISSING_TYPE, CompatModelField, - get_model_fields, + PydanticCompat, smart_deepcopy, ) from strawberry.experimental.pydantic.exceptions import ( @@ -83,12 +82,8 @@ def get_default_factory_for_field( Returns optionally a NoArgAnyCallable representing a default_factory parameter """ # replace dataclasses.MISSING with our own UNSET to make comparisons easier - default_factory = ( - field.default_factory - if field.default_factory is not PYDANTIC_MISSING_TYPE - else UNSET - ) - default = field.default if field.default is not PYDANTIC_MISSING_TYPE else UNSET + default_factory = field.default_factory if field.has_default_factory else UNSET + default = field.default if field.has_default else UNSET has_factory = default_factory is not None and default_factory is not UNSET has_default = default is not None and default is not UNSET @@ -126,8 +121,9 @@ def get_default_factory_for_field( def ensure_all_auto_fields_in_pydantic( model: Type[BaseModel], auto_fields: Set[str], cls_name: str ) -> None: + compat = PydanticCompat.from_model(model) # Raise error if user defined a strawberry.auto field not present in the model - non_existing_fields = list(auto_fields - get_model_fields(model).keys()) + non_existing_fields = list(auto_fields - compat.get_model_fields(model).keys()) if non_existing_fields: raise AutoFieldsNotInBaseModelError( diff --git a/tests/experimental/pydantic/schema/test_1_and_2.py b/tests/experimental/pydantic/schema/test_1_and_2.py new file mode 100644 index 0000000000..5e8b3fea72 --- /dev/null +++ b/tests/experimental/pydantic/schema/test_1_and_2.py @@ -0,0 +1,83 @@ +import textwrap +from typing import Optional, Union + +import strawberry +from tests.experimental.pydantic.utils import needs_pydantic_v2 + + +@needs_pydantic_v2 +def test_can_use_both_pydantic_1_and_2(): + import pydantic + from pydantic import v1 as pydantic_v1 + + class UserModel(pydantic.BaseModel): + age: int + name: Optional[str] + + @strawberry.experimental.pydantic.type(UserModel) + class User: + age: strawberry.auto + name: strawberry.auto + + class LegacyUserModel(pydantic_v1.BaseModel): + age: int + name: Optional[str] + int_field: pydantic.v1.NonNegativeInt = 1 + + @strawberry.experimental.pydantic.type(LegacyUserModel) + class LegacyUser: + age: strawberry.auto + name: strawberry.auto + int_field: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self, id: strawberry.ID) -> Union[User, LegacyUser]: + if id == "legacy": + return LegacyUser(age=1, name="legacy") + + return User(age=1, name="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + type LegacyUser { + age: Int! + name: String + intField: Int! + } + + type Query { + user(id: ID!): UserLegacyUser! + } + + type User { + age: Int! + name: String + } + + union UserLegacyUser = User | LegacyUser + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = """ + query ($id: ID!) { + user(id: $id) { + __typename + ... on User { name } + ... on LegacyUser { name } + } + } + """ + + result = schema.execute_sync(query, variable_values={"id": "new"}) + + assert not result.errors + assert result.data == {"user": {"__typename": "User", "name": "ABC"}} + + result = schema.execute_sync(query, variable_values={"id": "legacy"}) + + assert not result.errors + assert result.data == {"user": {"__typename": "LegacyUser", "name": "legacy"}} diff --git a/tests/experimental/pydantic/test_conversion.py b/tests/experimental/pydantic/test_conversion.py index ec9ba5495f..9a4ce2ddf3 100644 --- a/tests/experimental/pydantic/test_conversion.py +++ b/tests/experimental/pydantic/test_conversion.py @@ -11,8 +11,9 @@ import strawberry from strawberry.experimental.pydantic._compat import ( IS_PYDANTIC_V2, - PYDANTIC_MISSING_TYPE, CompatModelField, + PydanticV1Compat, + PydanticV2Compat, ) from strawberry.experimental.pydantic.exceptions import ( AutoFieldsNotInBaseModelError, @@ -23,11 +24,6 @@ from strawberry.types.types import StrawberryObjectDefinition from tests.experimental.pydantic.utils import needs_pydantic_v1 -if IS_PYDANTIC_V2: - pass -else: - pass - def test_can_use_type_standalone(): class User(BaseModel): @@ -847,9 +843,14 @@ class UserType: def test_get_default_factory_for_field(): + if IS_PYDANTIC_V2: + MISSING_TYPE = PydanticV2Compat().PYDANTIC_MISSING_TYPE + else: + MISSING_TYPE = PydanticV1Compat().PYDANTIC_MISSING_TYPE + def _get_field( - default: Any = PYDANTIC_MISSING_TYPE, - default_factory: Any = PYDANTIC_MISSING_TYPE, + default: Any = MISSING_TYPE, + default_factory: Any = MISSING_TYPE, ) -> CompatModelField: return CompatModelField( name="a", @@ -862,6 +863,8 @@ def _get_field( description="", has_alias=False, required=True, + _missing_type=MISSING_TYPE, + is_v1=not IS_PYDANTIC_V2, ) field = _get_field()