diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..429639c --- /dev/null +++ b/.editorconfig @@ -0,0 +1,21 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true +indent_style = space +indent_size = 4 + +[*.py] +max_line_length = 120 + +[Makefile] +indent_style = tab +indent_size = 4 + + +[{*.json,*.yml,*.yaml}] +indent_size = 2 +insert_final_newline = false diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index bf6b02e..ba8d668 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -23,7 +23,7 @@ jobs: - name: Build package run: python -m build - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + uses: pypa/gh-action-pypi-publish@v1.8.10 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index c2b3eaf..cc12569 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -42,6 +42,7 @@ jobs: strategy: matrix: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + pydantic-version: ["1.10.13", "2.4.2"] services: postgres: @@ -75,5 +76,7 @@ jobs: sudo apt update && sudo apt install -qy python3-dev default-libmysqlclient-dev build-essential python -m pip install --upgrade pip python -m pip install -e .[dev,test,ci] + - name: Install Pydantic ${{ matrix.pydantic-version }} + run: python -m pip install "pydantic==${{ matrix.pydantic-version }}" - name: Test package run: pytest diff --git a/.gitignore b/.gitignore index 37737ef..f936d7d 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ dist/ *.egg-info/ build htmlcov + +.python-version diff --git a/Makefile b/Makefile index 2d52756..88d7574 100644 --- a/Makefile +++ b/Makefile @@ -1,36 +1,28 @@ +.PHONY: install build test lint upload upload-test clean -.PHONY: install install: python3 -m pip install build twine python3 -m pip install -e .[dev,test] - -.PHONY: build build: python3 -m build +migrations: + DJANGO_SETTINGS_MODULE="tests.settings.django_test_settings" python3 -m django makemigrations --noinput -.PHONY: test test: A= test: pytest $(A) -.PHONY: lint lint: A=. 
lint: mypy $(A) - -.PHONY: upload upload: python3 -m twine upload dist/* - -.PHONY: upload-test upload-test: python3 -m twine upload --repository testpypi dist/* - -.PHONY: clean clean: rm -rf dist/* diff --git a/README.md b/README.md index 903cfe0..0ac9352 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,7 @@ Django JSONField with Pydantic models as a Schema -**[Pydantic 2 support](https://github.com/surenkov/django-pydantic-field/discussions/36) is in progress, -you can track the status [in this PR](https://github.com/surenkov/django-pydantic-field/pull/34)** +**[Pydantic 2 support](https://github.com/surenkov/django-pydantic-field/discussions/36) is in progress, you can track the status [in this PR](https://github.com/surenkov/django-pydantic-field/pull/34)** ## Usage diff --git a/django_pydantic_field/__init__.py b/django_pydantic_field/__init__.py index 7746f2c..29e6ff9 100644 --- a/django_pydantic_field/__init__.py +++ b/django_pydantic_field/__init__.py @@ -1 +1,8 @@ -from .fields import * +from .fields import SchemaField as SchemaField + +def __getattr__(name): + if name == "_migration_serializers": + module = __import__("django_pydantic_field._migration_serializers", fromlist=["*"]) + return module + + raise AttributeError(f"Module {__name__!r} has no attribute {name!r}") diff --git a/django_pydantic_field/_migration_serializers.py b/django_pydantic_field/_migration_serializers.py index d608855..20a834f 100644 --- a/django_pydantic_field/_migration_serializers.py +++ b/django_pydantic_field/_migration_serializers.py @@ -1,144 +1,9 @@ -""" -Django Migration serializer helpers - -[Built-in generic annotations](https://peps.python.org/pep-0585/) - introduced in Python 3.9 are having a different semantics from `typing` collections. - Due to how Django treats field serialization/reconstruction while writing migrations, - it is not possible to distnguish between `types.GenericAlias` and any other regular types, - thus annotations are being erased by `MigrationWriter` serializers. - - To mitigate this, I had to introduce custom container for schema deconstruction. - -[Union types syntax](https://peps.python.org/pep-0604/) - `typing.Union` and its special forms, like `typing.Optional`, have its own inheritance chain. - Moreover, `types.UnionType`, introduced in 3.10, do not allow explicit type construction, - only with `X | Y` syntax. Both cases require a dedicated serializer for migration writes. 
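For illustration, the round trip this helper family provides looks roughly like the sketch below (an aside, not part of the patch itself; it uses only `GenericContainer`, which this changeset relocates to `django_pydantic_field.compat.django`, plus an arbitrary `typing.Sequence[int]` schema):

```python
import typing

from django_pydantic_field.compat.django import GenericContainer

# A parametrized generic cannot be reliably serialized into a migration as-is,
# so it is wrapped into a deconstructible (origin, args) container...
wrapped = GenericContainer.wrap(typing.Sequence[int])

# ...which still compares equal to the original annotation via __eq__.
assert wrapped == typing.Sequence[int]

# When the migration is loaded back, unwrap() rebuilds the parametrized
# generic from the stored origin and arguments.
schema = GenericContainer.unwrap(wrapped)
```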
-""" -import sys -import types -import typing as t - -try: - from typing import get_args, get_origin -except ImportError: - from typing_extensions import get_args, get_origin - -from django.db.migrations.serializer import BaseSerializer, serializer_factory -from django.db.migrations.writer import MigrationWriter - - -class GenericContainer: - __slots__ = "origin", "args" - - def __init__(self, origin, args: tuple = ()): - self.origin = origin - self.args = args - - @classmethod - def wrap(cls, typ_): - if isinstance(typ_, GenericTypes): - wrapped_args = tuple(map(cls.wrap, get_args(typ_))) - return cls(get_origin(typ_), wrapped_args) - return typ_ - - @classmethod - def unwrap(cls, type_): - if not isinstance(type_, GenericContainer): - return type_ - - if not type_.args: - return type_.origin - - unwrapped_args = tuple(map(cls.unwrap, type_.args)) - try: - # This is a fallback for Python < 3.8, please be careful with that - return type_.origin[unwrapped_args] - except TypeError: - return GenericAlias(type_.origin, unwrapped_args) - - def __repr__(self): - return repr(self.unwrap(self)) - - __str__ = __repr__ - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.origin == other.origin and self.args == other.args - if isinstance(other, GenericTypes): - return self == self.wrap(other) - return NotImplemented - - -class GenericSerializer(BaseSerializer): - value: GenericContainer - - def serialize(self): - value = self.value - - tp_repr, imports = serializer_factory(type(value)).serialize() - orig_repr, orig_imports = serializer_factory(value.origin).serialize() - imports.update(orig_imports) - - args = [] - for arg in value.args: - arg_repr, arg_imports = serializer_factory(arg).serialize() - args.append(arg_repr) - imports.update(arg_imports) - - if args: - args_repr = ", ".join(args) - generic_repr = "%s(%s, (%s,))" % (tp_repr, orig_repr, args_repr) - else: - generic_repr = "%s(%s)" % (tp_repr, orig_repr) - - return generic_repr, imports - - -class TypingSerializer(BaseSerializer): - def serialize(self): - orig_module = self.value.__module__ - orig_repr = repr(self.value) - - if not orig_repr.startswith(orig_module): - orig_repr = f"{orig_module}.{orig_repr}" - - return orig_repr, {f"import {orig_module}"} - - -if sys.version_info >= (3, 9): - GenericAlias = types.GenericAlias - GenericTypes: t.Tuple[t.Any, ...] 
= ( - GenericAlias, - type(t.List[int]), - type(t.List), - ) -else: - # types.GenericAlias is missing, meaning python version < 3.9, - # which has a different inheritance models for typed generics - GenericAlias = type(t.List[int]) - GenericTypes = GenericAlias, type(t.List) - - -MigrationWriter.register_serializer(GenericContainer, GenericSerializer) -MigrationWriter.register_serializer(t.ForwardRef, TypingSerializer) -MigrationWriter.register_serializer(type(t.Union), TypingSerializer) # type: ignore - - -if sys.version_info >= (3, 10): - UnionType = types.UnionType - - class UnionTypeSerializer(BaseSerializer): - value: UnionType - - def serialize(self): - imports = set() - if isinstance(self.value, type(t.Union)): # type: ignore - imports.add("import typing") - - for arg in get_args(self.value): - _, arg_imports = serializer_factory(arg).serialize() - imports.update(arg_imports) - - return repr(self.value), imports - - MigrationWriter.register_serializer(UnionType, UnionTypeSerializer) +import warnings +from .compat.django import * + +DEPRECATION_MSG = ( + "Module 'django_pydantic_field._migration_serializers' is deprecated " + "and will be removed in version 1.0.0. " + "Please replace it with 'django_pydantic_field.compat.django' in migrations." +) +warnings.warn(DEPRECATION_MSG, category=DeprecationWarning) diff --git a/django_pydantic_field/compat/__init__.py b/django_pydantic_field/compat/__init__.py new file mode 100644 index 0000000..4374b6b --- /dev/null +++ b/django_pydantic_field/compat/__init__.py @@ -0,0 +1,5 @@ +from .pydantic import PYDANTIC_V1 as PYDANTIC_V1 +from .pydantic import PYDANTIC_V2 as PYDANTIC_V2 +from .pydantic import PYDANTIC_VERSION as PYDANTIC_VERSION +from .django import GenericContainer as GenericContainer +from .django import MigrationWriter as MigrationWriter diff --git a/django_pydantic_field/compat/deprecation.py b/django_pydantic_field/compat/deprecation.py new file mode 100644 index 0000000..3758164 --- /dev/null +++ b/django_pydantic_field/compat/deprecation.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import typing as ty +import warnings + +_MISSING = object() +_DEPRECATED_KWARGS = ( + "allow_nan", + "indent", + "separators", + "skipkeys", + "sort_keys", +) +_DEPRECATED_KWARGS_MESSAGE = ( + "The `%s=` argument is not supported by Pydantic v2 and will be removed in the future versions." +) + + +def truncate_deprecated_v1_export_kwargs(kwargs: dict[str, ty.Any]) -> None: + for kwarg in _DEPRECATED_KWARGS: + maybe_present_kwarg = kwargs.pop(kwarg, _MISSING) + if maybe_present_kwarg is not _MISSING: + warnings.warn(_DEPRECATED_KWARGS_MESSAGE % (kwarg,), DeprecationWarning, stacklevel=2) diff --git a/django_pydantic_field/compat/django.py b/django_pydantic_field/compat/django.py new file mode 100644 index 0000000..af08042 --- /dev/null +++ b/django_pydantic_field/compat/django.py @@ -0,0 +1,141 @@ +""" +Django Migration serializer helpers + +[Built-in generic annotations](https://peps.python.org/pep-0585/) + introduced in Python 3.9 have different semantics from `typing` collections. + Due to how Django treats field serialization/reconstruction while writing migrations, + it is not possible to distinguish between `types.GenericAlias` and any other regular types, + thus annotations are being erased by `MigrationWriter` serializers. + + To mitigate this, I had to introduce a custom container for schema deconstruction.
+ +[Union types syntax](https://peps.python.org/pep-0604/) + `typing.Union` and its special forms, like `typing.Optional`, have their own inheritance chain. + Moreover, `types.UnionType`, introduced in 3.10, does not allow explicit type construction, + only with `X | Y` syntax. Both cases require a dedicated serializer for migration writes. +""" +import sys +import types +import typing as ty + +from django.db.migrations.serializer import BaseSerializer, serializer_factory +from django.db.migrations.writer import MigrationWriter + +from .typing import get_args, get_origin + + +class GenericContainer: + __slots__ = "origin", "args" + + def __init__(self, origin, args: tuple = ()): + self.origin = origin + self.args = args + + @classmethod + def wrap(cls, typ_): + if isinstance(typ_, GenericTypes): + wrapped_args = tuple(map(cls.wrap, get_args(typ_))) + return cls(get_origin(typ_), wrapped_args) + return typ_ + + @classmethod + def unwrap(cls, type_): + if not isinstance(type_, cls): + return type_ + + if not type_.args: + return type_.origin + + unwrapped_args = tuple(map(cls.unwrap, type_.args)) + try: + # This is a fallback for Python < 3.8, please be careful with that + return type_.origin[unwrapped_args] + except TypeError: + return GenericAlias(type_.origin, unwrapped_args) + + def __repr__(self): + return repr(self.unwrap(self)) + + __str__ = __repr__ + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.origin == other.origin and self.args == other.args + if isinstance(other, GenericTypes): + return self == self.wrap(other) + return NotImplemented + + +class GenericSerializer(BaseSerializer): + value: GenericContainer + + def serialize(self): + value = self.value + + tp_repr, imports = serializer_factory(type(value)).serialize() + orig_repr, orig_imports = serializer_factory(value.origin).serialize() + imports.update(orig_imports) + + args = [] + for arg in value.args: + arg_repr, arg_imports = serializer_factory(arg).serialize() + args.append(arg_repr) + imports.update(arg_imports) + + if args: + args_repr = ", ".join(args) + generic_repr = "%s(%s, (%s,))" % (tp_repr, orig_repr, args_repr) + else: + generic_repr = "%s(%s)" % (tp_repr, orig_repr) + + return generic_repr, imports + + +class TypingSerializer(BaseSerializer): + def serialize(self): + orig_module = self.value.__module__ + orig_repr = repr(self.value) + + if not orig_repr.startswith(orig_module): + orig_repr = f"{orig_module}.{orig_repr}" + + return orig_repr, {f"import {orig_module}"} + + +if sys.version_info >= (3, 9): + GenericAlias = types.GenericAlias + GenericTypes: ty.Tuple[ty.Any, ...]
= ( + GenericAlias, + type(ty.List[int]), + type(ty.List), + ) +else: + # types.GenericAlias is missing, meaning python version < 3.9, + # which has a different inheritance models for typed generics + GenericAlias = type(ty.List[int]) # noqa + GenericTypes = GenericAlias, type(ty.List) # noqa + + +MigrationWriter.register_serializer(GenericContainer, GenericSerializer) +MigrationWriter.register_serializer(ty.ForwardRef, TypingSerializer) +MigrationWriter.register_serializer(type(ty.Union), TypingSerializer) # type: ignore + + +if sys.version_info >= (3, 10): + UnionType = types.UnionType + + class UnionTypeSerializer(BaseSerializer): + value: UnionType + + def serialize(self): + imports = set() + if isinstance(self.value, (type(ty.Union), types.UnionType)): # type: ignore + imports.add("import typing") + + for arg in get_args(self.value): + _, arg_imports = serializer_factory(arg).serialize() + imports.update(arg_imports) + + return repr(self.value), imports + + MigrationWriter.register_serializer(UnionType, UnionTypeSerializer) diff --git a/django_pydantic_field/compat/functools.py b/django_pydantic_field/compat/functools.py new file mode 100644 index 0000000..d89be7b --- /dev/null +++ b/django_pydantic_field/compat/functools.py @@ -0,0 +1,4 @@ +try: + from functools import cached_property as cached_property +except ImportError: + from django.utils.functional import cached_property as cached_property # type: ignore diff --git a/django_pydantic_field/compat/imports.py b/django_pydantic_field/compat/imports.py new file mode 100644 index 0000000..1587d91 --- /dev/null +++ b/django_pydantic_field/compat/imports.py @@ -0,0 +1,38 @@ +import functools +import importlib +import types + +from .pydantic import PYDANTIC_V1, PYDANTIC_V2, PYDANTIC_VERSION + +__all__ = ("compat_getattr", "compat_dir") + + +def compat_getattr(module_name: str): + module = _import_compat_module(module_name) + return functools.partial(getattr, module) + + +def compat_dir(module_name: str): + compat_module = _import_compat_module(module_name) + return dir(compat_module) + + +def _import_compat_module(module_name: str) -> types.ModuleType: + try: + package, _, module = module_name.partition(".") + except ValueError: + package, module = module_name, "" + + module_path_parts = [package] + if PYDANTIC_V2: + module_path_parts.append("v2") + elif PYDANTIC_V1: + module_path_parts.append("v1") + else: + raise RuntimeError(f"Pydantic {PYDANTIC_VERSION} is not supported") + + if module: + module_path_parts.append(module) + + module_path = ".".join(module_path_parts) + return importlib.import_module(module_path) diff --git a/django_pydantic_field/compat/pydantic.py b/django_pydantic_field/compat/pydantic.py new file mode 100644 index 0000000..e8d7f12 --- /dev/null +++ b/django_pydantic_field/compat/pydantic.py @@ -0,0 +1,6 @@ +from pydantic.version import VERSION as PYDANTIC_VERSION + +__all__ = ("PYDANTIC_V2", "PYDANTIC_V1", "PYDANTIC_VERSION") + +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") +PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.") diff --git a/django_pydantic_field/compat/typing.py b/django_pydantic_field/compat/typing.py new file mode 100644 index 0000000..fd0865e --- /dev/null +++ b/django_pydantic_field/compat/typing.py @@ -0,0 +1,6 @@ +try: + from typing import get_args as get_args + from typing import get_origin as get_origin +except ImportError: + from typing_extensions import get_args as get_args # type: ignore + from typing_extensions import get_origin as get_origin # type: ignore diff --git 
a/django_pydantic_field/fields.py b/django_pydantic_field/fields.py index 9ff16b4..d76f309 100644 --- a/django_pydantic_field/fields.py +++ b/django_pydantic_field/fields.py @@ -1,194 +1,4 @@ -import json -import typing as t -from functools import partial +from .compat.imports import compat_getattr, compat_dir -import django -import pydantic -from django.core import exceptions as django_exceptions -from django.db.models.fields import NOT_PROVIDED -from django.db.models.fields.json import JSONField -from django.db.models.query_utils import DeferredAttribute - -from . import base, forms, utils -from ._migration_serializers import GenericContainer, GenericTypes - -__all__ = ("SchemaField",) - - -class SchemaAttribute(DeferredAttribute): - """ - Forces Django to call to_python on fields when setting them. - This is useful when you want to add some custom field data postprocessing. - - Should be added to field like a so: - - ``` - def contribute_to_class(self, cls, name, *args, **kwargs): - super().contribute_to_class(cls, name, *args, **kwargs) - setattr(cls, name, SchemaDeferredAttribute(self)) - ``` - """ - - field: "PydanticSchemaField" - - def __set__(self, obj, value): - obj.__dict__[self.field.attname] = self.field.to_python(value) - - -class PydanticSchemaField(JSONField, t.Generic[base.ST]): - descriptor_class = SchemaAttribute - _is_prepared_schema: bool = False - - def __init__( - self, - *args, - schema: t.Union[t.Type["base.ST"], "GenericContainer", "t.ForwardRef", str, None] = None, - config: "base.ConfigType" = None, - **kwargs, - ): - self.export_params = base.extract_export_kwargs(kwargs, dict.pop) - super().__init__(*args, **kwargs) - - self.config = config - self._resolve_schema(schema) - - def __copy__(self): - _, _, args, kwargs = self.deconstruct() - copied = type(self)(*args, **kwargs) - copied.set_attributes_from_name(self.name) - return copied - - def get_default(self): - value = super().get_default() - return self.to_python(value) - - def to_python(self, value) -> "base.SchemaT": - # Attempt to resolve forward referencing schema if it was not succesful - # during `.contribute_to_class` call - if not self._is_prepared_schema: - self._prepare_model_schema() - try: - assert self.decoder is not None - return self.decoder().decode(value) - except pydantic.ValidationError as e: - raise django_exceptions.ValidationError(e.errors()) - - if django.VERSION[:2] >= (4, 2): - - def get_prep_value(self, value): - if not self._is_prepared_schema: - self._prepare_model_schema() - prep_value = super().get_prep_value(value) - prep_value = self.encoder().encode(prep_value) # type: ignore - return json.loads(prep_value) - - def deconstruct(self): - name, path, args, kwargs = super().deconstruct() - self._deconstruct_schema(kwargs) - self._deconstruct_default(kwargs) - self._deconstruct_config(kwargs) - - kwargs.pop("decoder") - kwargs.pop("encoder") - - return name, path, args, kwargs - - def contribute_to_class(self, cls, name, private_only=False): - if self.schema is None: - self._resolve_schema_from_type_hints(cls, name) - - try: - self._prepare_model_schema(cls) - except NameError: - # Pydantic was not able to resolve forward references, which means - # that it should be postponed until initial access to the field - self._is_prepared_schema = False - - super().contribute_to_class(cls, name, private_only) - - def formfield(self, **kwargs): - if self.schema is None: - self._resolve_schema_from_type_hints(self.model, self.attname) - - owner_model = getattr(self, "model", None) - 
field_kwargs = dict( - form_class=forms.SchemaField, - schema=self.schema, - config=self.config, - __module__=getattr(owner_model, "__module__", None), - **self.export_params, - ) - field_kwargs.update(kwargs) - return super().formfield(**field_kwargs) - - def _resolve_schema(self, schema): - schema = t.cast(t.Type["base.ST"], GenericContainer.unwrap(schema)) - - self.schema = schema - if schema is not None: - self.serializer_schema = serializer = base.wrap_schema(schema, self.config, self.null) - self.decoder = partial(base.SchemaDecoder, serializer) # type: ignore - self.encoder = partial(base.SchemaEncoder, schema=serializer, export=self.export_params) # type: ignore - - def _resolve_schema_from_type_hints(self, cls, name): - annotated_schema = utils.get_annotated_type(cls, name) - if annotated_schema is None: - raise django_exceptions.FieldError( - f"{cls._meta.label}.{name} needs to be either annotated " - "or `schema=` field attribute should be explicitly passed" - ) - self._resolve_schema(annotated_schema) - - def _prepare_model_schema(self, cls=None): - cls = cls or getattr(self, "model", None) - if cls is not None: - base.prepare_schema(self.serializer_schema, cls) - self._is_prepared_schema = True - - def _deconstruct_default(self, kwargs): - default = kwargs.get("default", NOT_PROVIDED) - - if not (default is NOT_PROVIDED or callable(default)): - if self._is_prepared_schema: - default = self.get_prep_value(default) - kwargs.update(default=default) - - def _deconstruct_schema(self, kwargs): - kwargs.update(schema=GenericContainer.wrap(self.schema)) - - def _deconstruct_config(self, kwargs): - kwargs.update(base.deconstruct_export_kwargs(self.export_params)) - kwargs.update(config=self.config) - - -if t.TYPE_CHECKING: - OptSchemaT = t.Optional[base.SchemaT] - - -@t.overload -def SchemaField( - schema: "t.Union[t.Type[t.Optional[base.ST]], t.ForwardRef]" = ..., - config: "base.ConfigType" = ..., - default: "t.Union[OptSchemaT, t.Callable[[], OptSchemaT]]" = ..., - *args, - null: "t.Literal[True]", - **kwargs, -) -> "t.Optional[base.ST]": - ... - - -@t.overload -def SchemaField( - schema: "t.Union[t.Type[base.ST], t.ForwardRef]" = ..., - config: "base.ConfigType" = ..., - default: "t.Union[base.SchemaT, t.Callable[[], base.SchemaT]]" = ..., - *args, - null: "t.Literal[False]" = ..., - **kwargs, -) -> "base.ST": - ... 
- - -def SchemaField(schema=None, config=None, *args, **kwargs) -> t.Any: - kwargs.update(schema=schema, config=config) - return PydanticSchemaField(*args, **kwargs) +__getattr__ = compat_getattr(__name__) +__dir__ = compat_dir(__name__) diff --git a/django_pydantic_field/fields.pyi b/django_pydantic_field/fields.pyi new file mode 100644 index 0000000..20aa9ad --- /dev/null +++ b/django_pydantic_field/fields.pyi @@ -0,0 +1,116 @@ +from __future__ import annotations + +import json +import typing as ty +import typing_extensions as te + +import typing_extensions as te +from pydantic import BaseConfig, BaseModel, ConfigDict + +try: + from pydantic.dataclasses import DataclassClassOrWrapper as PydanticDataclass +except ImportError: + from pydantic._internal._dataclasses import PydanticDataclass as PydanticDataclass + +__all__ = ("SchemaField",) + +SchemaT: ty.TypeAlias = ty.Union[ + BaseModel, + PydanticDataclass, + ty.Sequence[ty.Any], + ty.Mapping[str, ty.Any], + ty.Set[ty.Any], + ty.FrozenSet[ty.Any], +] +OptSchemaT: ty.TypeAlias = ty.Optional[SchemaT] +ST = ty.TypeVar("ST", bound=SchemaT) +IncEx = ty.Union[ty.Set[int], ty.Set[str], ty.Dict[int, ty.Any], ty.Dict[str, ty.Any]] +ConfigType = ty.Union[ConfigDict, ty.Type[BaseConfig], type] + +class _FieldKwargs(te.TypedDict, total=False): + name: str | None + verbose_name: str | None + primary_key: bool + max_length: int | None + unique: bool + blank: bool + db_index: bool + rel: ty.Any + editable: bool + serialize: bool + unique_for_date: str | None + unique_for_month: str | None + unique_for_year: str | None + choices: ty.Sequence[ty.Tuple[str, str]] | None + help_text: str | None + db_column: str | None + db_tablespace: str | None + auto_created: bool + validators: ty.Sequence[ty.Callable] | None + error_messages: ty.Mapping[str, str] | None + db_comment: str | None + +class _JSONFieldKwargs(_FieldKwargs, total=False): + encoder: ty.Callable[[], json.JSONEncoder] + decoder: ty.Callable[[], json.JSONDecoder] + +class _ExportKwargs(te.TypedDict, total=False): + strict: bool + from_attributes: bool + mode: te.Literal["json", "python"] + include: IncEx | None + exclude: IncEx | None + by_alias: bool + exclude_unset: bool + exclude_defaults: bool + exclude_none: bool + round_trip: bool + warnings: bool + +class _SchemaFieldKwargs(_JSONFieldKwargs, _ExportKwargs, total=False): ... + +class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs, total=False): + allow_nan: ty.Any + indent: ty.Any + separators: ty.Any + skipkeys: ty.Any + sort_keys: ty.Any + +@ty.overload +def SchemaField( + schema: ty.Type[ST] | None | ty.ForwardRef = ..., + config: ConfigType = ..., + default: OptSchemaT | ty.Callable[[], OptSchemaT] = ..., + *args, + null: ty.Literal[True], + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> ST | None: ... +@ty.overload +def SchemaField( + schema: ty.Type[ST] | ty.ForwardRef = ..., + config: ConfigType = ..., + default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ..., + *args, + null: ty.Literal[False] = ..., + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> ST: ... +@ty.overload +@te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.") +def SchemaField( + schema: ty.Type[ST] | None | ty.ForwardRef = ..., + config: ConfigType = ..., + default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ..., + *args, + null: ty.Literal[True], + **kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs], +) -> ST | None: ... 
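To make the intent of these overloads concrete, a rough usage sketch follows (an aside rather than part of the patch; `Sensor` and `Reading` are invented names, while `SchemaField` and its `schema=`/annotation-based resolution are defined by this changeset):

```python
import typing

import pydantic
from django.db import models

from django_pydantic_field import SchemaField


class Sensor(pydantic.BaseModel):
    name: str
    value: float


class Reading(models.Model):
    # Schema passed explicitly; null=True matches the Optional overload above.
    payload = SchemaField(schema=Sensor, null=True, default=None)
    # Schema inferred from the field annotation when no `schema=` is given.
    samples: typing.List[Sensor] = SchemaField(default=list)
```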
+@ty.overload +@te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.") +def SchemaField( + schema: ty.Type[ST] | ty.ForwardRef = ..., + config: ConfigType = ..., + default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ..., + *args, + null: ty.Literal[False] = ..., + **kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs], +) -> ST: ... diff --git a/django_pydantic_field/forms.py b/django_pydantic_field/forms.py index 3185aad..d76f309 100644 --- a/django_pydantic_field/forms.py +++ b/django_pydantic_field/forms.py @@ -1,66 +1,4 @@ -import typing as t -from functools import partial +from .compat.imports import compat_getattr, compat_dir -import pydantic -from django.core.exceptions import ValidationError -from django.forms.fields import InvalidJSONInput, JSONField -from django.utils.translation import gettext_lazy as _ - -from . import base - -__all__ = ("SchemaField",) - - -class SchemaField(JSONField, t.Generic[base.ST]): - default_error_messages = { - "schema_error": _("Schema didn't match. Detail: %(detail)s"), - } - - def __init__( - self, - schema: t.Union[t.Type["base.ST"], t.ForwardRef], - config: t.Optional["base.ConfigType"] = None, - __module__: t.Optional[str] = None, - **kwargs, - ): - self.schema = base.wrap_schema( - schema, - config, - allow_null=not kwargs.get("required", True), - __module__=__module__, - ) - export_params = base.extract_export_kwargs(kwargs, dict.pop) - decoder = partial(base.SchemaDecoder, self.schema) - encoder = partial( - base.SchemaEncoder, - schema=self.schema, - export=export_params, - raise_errors=True, - ) - kwargs.update(encoder=encoder, decoder=decoder) - super().__init__(**kwargs) - - def to_python(self, value): - try: - return super().to_python(value) - except pydantic.ValidationError as e: - raise ValidationError( - self.error_messages["schema_error"], - code="invalid", - params={ - "value": value, - "detail": str(e), - "errors": e.errors(), - "json": e.json(), - }, - ) - - def bound_data(self, data, initial): - try: - return super().bound_data(data, initial) - except pydantic.ValidationError: - return InvalidJSONInput(data) - - def get_bound_field(self, form, field_name): - base.prepare_schema(self.schema, form) - return super().get_bound_field(form, field_name) +__getattr__ = compat_getattr(__name__) +__dir__ = compat_dir(__name__) diff --git a/django_pydantic_field/forms.pyi b/django_pydantic_field/forms.pyi new file mode 100644 index 0000000..646f3a7 --- /dev/null +++ b/django_pydantic_field/forms.pyi @@ -0,0 +1,67 @@ +from __future__ import annotations + +import json +import typing as ty +import typing_extensions as te + +from django.forms.fields import JSONField +from django.forms.widgets import Widget +from django.utils.functional import _StrOrPromise + +from .fields import ST, ConfigType, _ExportKwargs + +__all__ = ("SchemaField",) + +class _FieldKwargs(ty.TypedDict, total=False): + required: bool + widget: Widget | type[Widget] | None + label: _StrOrPromise | None + initial: ty.Any | None + help_text: _StrOrPromise + error_messages: ty.Mapping[str, _StrOrPromise] | None + show_hidden_initial: bool + validators: ty.Sequence[ty.Callable[[ty.Any], None]] + localize: bool + disabled: bool + label_suffix: str | None + +class _CharFieldKwargs(_FieldKwargs, total=False): + max_length: int | None + min_length: int | None + strip: bool + empty_value: ty.Any + +class _JSONFieldKwargs(_CharFieldKwargs, total=False): + encoder: ty.Callable[[], json.JSONEncoder] | None + 
decoder: ty.Callable[[], json.JSONDecoder] | None + +class _SchemaFieldKwargs(_ExportKwargs, _JSONFieldKwargs, total=False): + allow_null: bool | None + + +class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs, total=False): + allow_nan: ty.Any + indent: ty.Any + separators: ty.Any + skipkeys: ty.Any + sort_keys: ty.Any + + +class SchemaField(JSONField, ty.Generic[ST]): + @ty.overload + def __init__( + self, + schema: ty.Type[ST] | ty.ForwardRef | str, + config: ConfigType | None = ..., + *args, + **kwargs: te.Unpack[_SchemaFieldKwargs], + ) -> None: ... + @ty.overload + @te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.") + def __init__( + self, + schema: ty.Type[ST] | ty.ForwardRef | str, + config: ConfigType | None = ..., + *args, + **kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs], + ) -> None: ... diff --git a/django_pydantic_field/rest_framework.py b/django_pydantic_field/rest_framework.py index 4001dcd..d76f309 100644 --- a/django_pydantic_field/rest_framework.py +++ b/django_pydantic_field/rest_framework.py @@ -1,260 +1,4 @@ -import typing as t +from .compat.imports import compat_getattr, compat_dir -try: - from typing import get_args -except ImportError: - from typing_extensions import get_args - -from django.conf import settings -from pydantic import BaseModel, ValidationError - -from rest_framework import exceptions, parsers, renderers, serializers -from rest_framework.schemas import openapi -from rest_framework.schemas.utils import is_list_view - -from . import base - -__all__ = ( - "SchemaField", - "SchemaRenderer", - "SchemaParser", - "AutoSchema", -) - -if t.TYPE_CHECKING: - RequestResponseContext = t.Mapping[str, t.Any] - - -class AnnotatedSchemaT(t.Generic[base.ST]): - schema_ctx_attr: t.ClassVar[str] = "schema" - require_explicit_schema: t.ClassVar[bool] = False - _cached_annotation_schema: t.Type[BaseModel] - - def get_schema(self, ctx: "RequestResponseContext") -> t.Optional[t.Type[BaseModel]]: - schema = self.get_context_schema(ctx) - if schema is None: - schema = self.get_annotation_schema(ctx) - - if self.require_explicit_schema and schema is None: - raise ValueError( - "Schema should be either explicitly set with annotation " - "or passed in the context" - ) - - return schema - - def get_context_schema(self, ctx: "RequestResponseContext"): - schema = ctx.get(self.schema_ctx_attr) - if schema is not None: - schema = base.wrap_schema(schema) - base.prepare_schema(schema, ctx.get("view")) - - return schema - - def get_annotation_schema(self, ctx: "RequestResponseContext"): - try: - schema = self._cached_annotation_schema - except AttributeError: - try: - schema = get_args(self.__orig_class__)[0] # type: ignore - except (AttributeError, IndexError): - return None - - self._cached_annotation_schema = schema = base.wrap_schema(schema) - base.prepare_schema(schema, ctx.get("view")) - - return schema - - -class SchemaField(serializers.Field, t.Generic[base.ST]): - decoder: "base.SchemaDecoder[base.ST]" - _is_prepared_schema: bool = False - - def __init__( - self, - schema: t.Type["base.ST"], - config: t.Optional["base.ConfigType"] = None, - **kwargs, - ): - nullable = kwargs.get("allow_null", False) - - self.schema = field_schema = base.wrap_schema(schema, config, nullable) - self.export_params = base.extract_export_kwargs(kwargs, dict.pop) - self.decoder = base.SchemaDecoder(field_schema) - super().__init__(**kwargs) - - def bind(self, field_name, parent): - if not self._is_prepared_schema: 
- base.prepare_schema(self.schema, parent) - self._is_prepared_schema = True - - super().bind(field_name, parent) - - def to_internal_value(self, data: t.Any) -> t.Optional["base.ST"]: - try: - return self.decoder.decode(data) - except ValidationError as e: - raise serializers.ValidationError(e.errors(), self.field_name) - - def to_representation(self, value: t.Optional["base.ST"]) -> t.Any: - obj = self.schema.parse_obj(value) - raw_obj = obj.dict(**self.export_params) - return raw_obj["__root__"] - - -class SchemaRenderer(AnnotatedSchemaT[base.ST], renderers.JSONRenderer): - schema_ctx_attr = "render_schema" - - def render(self, data, accepted_media_type=None, renderer_context=None): - renderer_context = renderer_context or {} - response = renderer_context.get("response") - if response is not None and response.exception: - return super().render(data, accepted_media_type, renderer_context) - - try: - json_str = self.render_data(data, renderer_context) - except ValidationError as e: - json_str = e.json().encode() - except AttributeError: - json_str = super().render(data, accepted_media_type, renderer_context) - - return json_str - - def render_data(self, data, renderer_ctx) -> bytes: - schema = self.get_schema(renderer_ctx or {}) - if schema is not None: - data = schema(__root__=data) - - export_kw = base.extract_export_kwargs(renderer_ctx) - json_str = data.json(**export_kw, ensure_ascii=self.ensure_ascii) - return json_str.encode() - - -class SchemaParser(AnnotatedSchemaT[base.ST], parsers.JSONParser): - schema_ctx_attr = "parser_schema" - renderer_class = SchemaRenderer - require_explicit_schema = True - - def parse(self, stream, media_type=None, parser_context=None): - parser_context = parser_context or {} - encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET) - schema = t.cast(BaseModel, self.get_schema(parser_context)) - - try: - return schema.parse_raw(stream.read(), encoding=encoding).__root__ - except ValidationError as e: - raise exceptions.ParseError(e.errors()) - - -class AutoSchema(openapi.AutoSchema): - get_request_serializer: t.Callable - _get_reference: t.Callable - - def map_field(self, field: serializers.Field): - if isinstance(field, SchemaField): - return field.schema.schema() - return super().map_field(field) - - def map_parsers(self, path: str, method: str): - request_types: t.List[t.Any] = [] - parser_ctx = self.view.get_parser_context(None) - - for parser_type in self.view.parser_classes: - parser = parser_type() - - if isinstance(parser, SchemaParser): - schema = self._extract_openapi_schema(parser, parser_ctx) - if schema is not None: - request_types.append((parser.media_type, schema)) - else: - request_types.append(parser.media_type) - else: - request_types.append(parser.media_type) - - return request_types - - def map_renderers(self, path: str, method: str): - response_types: t.List[t.Any] = [] - renderer_ctx = self.view.get_renderer_context() - - for renderer_type in self.view.renderer_classes: - renderer = renderer_type() - - if isinstance(renderer, SchemaRenderer): - schema = self._extract_openapi_schema(renderer, renderer_ctx) - if schema is not None: - response_types.append((renderer.media_type, schema)) - else: - response_types.append(renderer.media_type) - - elif not isinstance(renderer, renderers.BrowsableAPIRenderer): - response_types.append(renderer.media_type) - - return response_types - - def get_request_body(self, path: str, method: str): - if method not in ('PUT', 'PATCH', 'POST'): - return {} - - self.request_media_types = 
self.map_parsers(path, method) - serializer = self.get_request_serializer(path, method) - content_schemas = {} - - for request_type in self.request_media_types: - if isinstance(request_type, tuple): - media_type, request_schema = request_type - content_schemas[media_type] = {"schema": request_schema} - else: - serializer_ref = self._get_reference(serializer) - content_schemas[request_type] = {"schema": serializer_ref} - - return {'content': content_schemas} - - def get_responses(self, path: str, method: str): - if method == "DELETE": - return {"204": {"description": ""}} - - self.response_media_types = self.map_renderers(path, method) - status_code = "201" if method == "POST" else "200" - content_types = {} - - for response_type in self.response_media_types: - if isinstance(response_type, tuple): - media_type, response_schema = response_type - content_types[media_type] = {"schema": response_schema} - else: - response_schema = self._get_serializer_response_schema(path, method) - content_types[response_type] = {"schema": response_schema} - - return { - status_code: { - "content": content_types, - "description": "", - } - } - - def _extract_openapi_schema(self, schemable: AnnotatedSchemaT, ctx: "RequestResponseContext"): - schema_model = schemable.get_schema(ctx) - if schema_model is not None: - return schema_model.schema() - return None - - def _get_serializer_response_schema(self, path, method): - serializer = self.get_response_serializer(path, method) - - if not isinstance(serializer, serializers.Serializer): - item_schema = {} - else: - item_schema = self._get_reference(serializer) - - if is_list_view(path, method, self.view): - response_schema = { - "type": "array", - "items": item_schema, - } - paginator = self.get_paginator() - if paginator: - response_schema = paginator.get_paginated_response_schema(response_schema) - else: - response_schema = item_schema - return response_schema +__getattr__ = compat_getattr(__name__) +__dir__ = compat_dir(__name__) diff --git a/django_pydantic_field/rest_framework.pyi b/django_pydantic_field/rest_framework.pyi new file mode 100644 index 0000000..e4a3b87 --- /dev/null +++ b/django_pydantic_field/rest_framework.pyi @@ -0,0 +1,63 @@ +import typing as ty +import typing_extensions as te + +from rest_framework import parsers, renderers +from rest_framework.fields import _DefaultInitial, Field +from rest_framework.validators import Validator + +from django.utils.functional import _StrOrPromise + +from .fields import ST, ConfigType, _ExportKwargs + +__all__ = ("SchemaField", "SchemaParser", "SchemaRenderer") + +class _FieldKwargs(te.TypedDict, ty.Generic[ST], total=False): + read_only: bool + write_only: bool + required: bool + default: _DefaultInitial[ST] + initial: _DefaultInitial[ST] + source: str + label: _StrOrPromise + help_text: _StrOrPromise + style: dict[str, ty.Any] + error_messages: dict[str, _StrOrPromise] + validators: ty.Sequence[Validator[ST]] + allow_null: bool + +class _SchemaFieldKwargs(_FieldKwargs[ST], _ExportKwargs, total=False): + pass + +class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs[ST], total=False): + allow_nan: ty.Any + indent: ty.Any + separators: ty.Any + skipkeys: ty.Any + sort_keys: ty.Any + +class SchemaField(Field, ty.Generic[ST]): + @ty.overload + def __init__( + self, + schema: ty.Type[ST] | ty.ForwardRef | str, + config: ConfigType | None = ..., + *args, + **kwargs: te.Unpack[_SchemaFieldKwargs[ST]], + ) -> None: ... 
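A short sketch of how this stub is meant to be used inside a DRF serializer (again an illustrative aside; `Sensor` and `ReadingSerializer` are invented names, `SchemaField` here is the `rest_framework` field from this changeset):

```python
import pydantic
from rest_framework import serializers

from django_pydantic_field.rest_framework import SchemaField


class Sensor(pydantic.BaseModel):
    name: str
    value: float


class ReadingSerializer(serializers.Serializer):
    timestamp = serializers.DateTimeField()
    # Incoming payloads are validated against the Pydantic schema and
    # serialized back through it on output.
    payload = SchemaField(schema=Sensor, allow_null=True)
```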
+ @ty.overload + @te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.") + def __init__( + self, + schema: ty.Type[ST] | ty.ForwardRef | str, + config: ConfigType | None = ..., + *args, + **kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs[ST]], + ) -> None: ... + +class SchemaParser(parsers.JSONParser, ty.Generic[ST]): + schema_context_key: ty.ClassVar[str] + config_context_key: ty.ClassVar[str] + +class SchemaRenderer(renderers.JSONRenderer, ty.Generic[ST]): + schema_context_key: ty.ClassVar[str] + config_context_key: ty.ClassVar[str] diff --git a/django_pydantic_field/v1/__init__.py b/django_pydantic_field/v1/__init__.py new file mode 100644 index 0000000..00c87c8 --- /dev/null +++ b/django_pydantic_field/v1/__init__.py @@ -0,0 +1,6 @@ +from django_pydantic_field.compat.pydantic import PYDANTIC_V1 + +if not PYDANTIC_V1: + raise ImportError("django_pydantic_field.v1 package is only compatible with Pydantic v1") + +from .fields import * diff --git a/django_pydantic_field/base.py b/django_pydantic_field/v1/base.py similarity index 100% rename from django_pydantic_field/base.py rename to django_pydantic_field/v1/base.py diff --git a/django_pydantic_field/v1/fields.py b/django_pydantic_field/v1/fields.py new file mode 100644 index 0000000..188ecf4 --- /dev/null +++ b/django_pydantic_field/v1/fields.py @@ -0,0 +1,198 @@ +import json +import typing as t +from functools import partial + +import django +import pydantic +from django.core import exceptions as django_exceptions +from django.db.models.fields import NOT_PROVIDED +from django.db.models.fields.json import JSONField +from django.db.models.query_utils import DeferredAttribute + +from django_pydantic_field.compat.django import GenericContainer, GenericTypes + +from . import base, forms, utils + +__all__ = ("SchemaField",) + + +class SchemaAttribute(DeferredAttribute): + """ + Forces Django to call to_python on fields when setting them. + This is useful when you want to add some custom field data postprocessing. 
+ + Should be added to the field like so: + + ``` + def contribute_to_class(self, cls, name, *args, **kwargs): + super().contribute_to_class(cls, name, *args, **kwargs) + setattr(cls, name, SchemaAttribute(self)) + ``` + """ + + field: "PydanticSchemaField" + + def __set__(self, obj, value): + obj.__dict__[self.field.attname] = self.field.to_python(value) + + +class PydanticSchemaField(JSONField, t.Generic[base.ST]): + descriptor_class = SchemaAttribute + _is_prepared_schema: bool = False + + def __init__( + self, + *args, + schema: t.Union[t.Type["base.ST"], "GenericContainer", "t.ForwardRef", str, None] = None, + config: t.Optional["base.ConfigType"] = None, + **kwargs, + ): + self.export_params = base.extract_export_kwargs(kwargs, dict.pop) + super().__init__(*args, **kwargs) + + self.config = config + self._resolve_schema(schema) + + def __copy__(self): + _, _, args, kwargs = self.deconstruct() + copied = type(self)(*args, **kwargs) + copied.set_attributes_from_name(self.name) + return copied + + def get_default(self): + value = super().get_default() + return self.to_python(value) + + def to_python(self, value) -> "base.SchemaT": + # Attempt to resolve forward referencing schema if it was not successful + # during `.contribute_to_class` call + if not self._is_prepared_schema: + self._prepare_model_schema() + try: + assert self.decoder is not None + return self.decoder().decode(value) + except pydantic.ValidationError as e: + raise django_exceptions.ValidationError(e.errors()) + + if django.VERSION[:2] >= (4, 2): + + def get_prep_value(self, value): + if not self._is_prepared_schema: + self._prepare_model_schema() + prep_value = super().get_prep_value(value) + prep_value = self.encoder().encode(prep_value) # type: ignore + return json.loads(prep_value) + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if path.startswith("django_pydantic_field.v1."): + path = path.replace("django_pydantic_field.v1", "django_pydantic_field", 1) + + self._deconstruct_schema(kwargs) + self._deconstruct_default(kwargs) + self._deconstruct_config(kwargs) + + kwargs.pop("decoder") + kwargs.pop("encoder") + + return name, path, args, kwargs + + def contribute_to_class(self, cls, name, private_only=False): + if self.schema is None: + self._resolve_schema_from_type_hints(cls, name) + + try: + self._prepare_model_schema(cls) + except NameError: + # Pydantic was not able to resolve forward references, which means + # that it should be postponed until initial access to the field + self._is_prepared_schema = False + + super().contribute_to_class(cls, name, private_only) + + def formfield(self, **kwargs): + if self.schema is None: + self._resolve_schema_from_type_hints(self.model, self.attname) + + owner_model = getattr(self, "model", None) + field_kwargs = dict( + form_class=forms.SchemaField, + schema=self.schema, + config=self.config, + __module__=getattr(owner_model, "__module__", None), + **self.export_params, + ) + field_kwargs.update(kwargs) + return super().formfield(**field_kwargs) + + def _resolve_schema(self, schema): + schema = t.cast(t.Type["base.ST"], GenericContainer.unwrap(schema)) + + self.schema = schema + if schema is not None: + self.serializer_schema = serializer = base.wrap_schema(schema, self.config, self.null) + self.decoder = partial(base.SchemaDecoder, serializer) # type: ignore + self.encoder = partial(base.SchemaEncoder, schema=serializer, export=self.export_params) # type: ignore + + def _resolve_schema_from_type_hints(self, cls, name): + annotated_schema
= utils.get_annotated_type(cls, name) + if annotated_schema is None: + raise django_exceptions.FieldError( + f"{cls._meta.label}.{name} needs to be either annotated " + "or `schema=` field attribute should be explicitly passed" + ) + self._resolve_schema(annotated_schema) + + def _prepare_model_schema(self, cls=None): + cls = cls or getattr(self, "model", None) + if cls is not None: + base.prepare_schema(self.serializer_schema, cls) + self._is_prepared_schema = True + + def _deconstruct_default(self, kwargs): + default = kwargs.get("default", NOT_PROVIDED) + + if not (default is NOT_PROVIDED or callable(default)): + if self._is_prepared_schema: + default = self.get_prep_value(default) + kwargs.update(default=default) + + def _deconstruct_schema(self, kwargs): + kwargs.update(schema=GenericContainer.wrap(self.schema)) + + def _deconstruct_config(self, kwargs): + kwargs.update(base.deconstruct_export_kwargs(self.export_params)) + kwargs.update(config=self.config) + + +if t.TYPE_CHECKING: + OptSchemaT = t.Optional[base.SchemaT] + + +@t.overload +def SchemaField( + schema: "t.Union[t.Type[t.Optional[base.ST]], t.ForwardRef]" = ..., + config: "base.ConfigType" = ..., + default: "t.Union[OptSchemaT, t.Callable[[], OptSchemaT]]" = ..., + *args, + null: "t.Literal[True]", + **kwargs, +) -> "t.Optional[base.ST]": + ... + + +@t.overload +def SchemaField( + schema: "t.Union[t.Type[base.ST], t.ForwardRef]" = ..., + config: "base.ConfigType" = ..., + default: "t.Union[base.SchemaT, t.Callable[[], base.SchemaT]]" = ..., + *args, + null: "t.Literal[False]" = ..., + **kwargs, +) -> "base.ST": + ... + + +def SchemaField(schema=None, config=None, default=None, *args, **kwargs) -> t.Any: + kwargs.update(schema=schema, config=config, default=default) + return PydanticSchemaField(*args, **kwargs) diff --git a/django_pydantic_field/v1/forms.py b/django_pydantic_field/v1/forms.py new file mode 100644 index 0000000..6cfaad4 --- /dev/null +++ b/django_pydantic_field/v1/forms.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import typing as t +from functools import partial + +import pydantic +from django.core.exceptions import ValidationError +from django.forms.fields import InvalidJSONInput, JSONField +from django.utils.translation import gettext_lazy as _ + +from . import base + +__all__ = ("SchemaField",) + + +class SchemaField(JSONField, t.Generic[base.ST]): + default_error_messages = { + "schema_error": _("Schema didn't match. 
Detail: %(detail)s"), + } + decoder: partial[base.SchemaDecoder] + encoder: partial[base.SchemaEncoder] + + def __init__( + self, + schema: t.Union[t.Type["base.ST"], t.ForwardRef], + config: t.Optional["base.ConfigType"] = None, + __module__: t.Optional[str] = None, + **kwargs, + ): + self.schema = base.wrap_schema( + schema, + config, + allow_null=not kwargs.get("required", True), + __module__=__module__, + ) + export_params = base.extract_export_kwargs(kwargs, dict.pop) + decoder: partial[base.SchemaDecoder] = partial(base.SchemaDecoder, self.schema) + encoder: partial[base.SchemaEncoder] = partial( + base.SchemaEncoder, + schema=self.schema, + export=export_params, + raise_errors=True, + ) + kwargs.update(encoder=encoder, decoder=decoder) + super().__init__(**kwargs) + + def to_python(self, value): + try: + return super().to_python(value) + except pydantic.ValidationError as e: + raise ValidationError( + self.error_messages["schema_error"], + code="invalid", + params={ + "value": value, + "detail": str(e), + "errors": e.errors(), + "json": e.json(), + }, + ) + + def bound_data(self, data, initial): + try: + return super().bound_data(data, initial) + except pydantic.ValidationError: + return InvalidJSONInput(data) + + def get_bound_field(self, form, field_name): + base.prepare_schema(self.schema, form) + return super().get_bound_field(form, field_name) diff --git a/django_pydantic_field/v1/rest_framework.py b/django_pydantic_field/v1/rest_framework.py new file mode 100644 index 0000000..465bfe4 --- /dev/null +++ b/django_pydantic_field/v1/rest_framework.py @@ -0,0 +1,256 @@ +import typing as t + +from django.conf import settings +from pydantic import BaseModel, ValidationError + +from rest_framework import exceptions, parsers, renderers, serializers +from rest_framework.schemas import openapi +from rest_framework.schemas.utils import is_list_view + +from . 
import base +from django_pydantic_field.compat.typing import get_args + +__all__ = ( + "SchemaField", + "SchemaRenderer", + "SchemaParser", + "AutoSchema", +) + +if t.TYPE_CHECKING: + RequestResponseContext = t.Mapping[str, t.Any] + + +class AnnotatedSchemaT(t.Generic[base.ST]): + schema_ctx_attr: t.ClassVar[str] = "schema" + require_explicit_schema: t.ClassVar[bool] = False + _cached_annotation_schema: t.Type[BaseModel] + + def get_schema(self, ctx: "RequestResponseContext") -> t.Optional[t.Type[BaseModel]]: + schema = self.get_context_schema(ctx) + if schema is None: + schema = self.get_annotation_schema(ctx) + + if self.require_explicit_schema and schema is None: + raise ValueError( + "Schema should be either explicitly set with annotation " + "or passed in the context" + ) + + return schema + + def get_context_schema(self, ctx: "RequestResponseContext"): + schema = ctx.get(self.schema_ctx_attr) + if schema is not None: + schema = base.wrap_schema(schema) + base.prepare_schema(schema, ctx.get("view")) + + return schema + + def get_annotation_schema(self, ctx: "RequestResponseContext"): + try: + schema = self._cached_annotation_schema + except AttributeError: + try: + schema = get_args(self.__orig_class__)[0] # type: ignore + except (AttributeError, IndexError): + return None + + self._cached_annotation_schema = schema = base.wrap_schema(schema) + base.prepare_schema(schema, ctx.get("view")) + + return schema + + +class SchemaField(serializers.Field, t.Generic[base.ST]): + decoder: "base.SchemaDecoder[base.ST]" + _is_prepared_schema: bool = False + + def __init__( + self, + schema: t.Type["base.ST"], + config: t.Optional["base.ConfigType"] = None, + **kwargs, + ): + nullable = kwargs.get("allow_null", False) + + self.schema = field_schema = base.wrap_schema(schema, config, nullable) + self.export_params = base.extract_export_kwargs(kwargs, dict.pop) + self.decoder = base.SchemaDecoder(field_schema) + super().__init__(**kwargs) + + def bind(self, field_name, parent): + if not self._is_prepared_schema: + base.prepare_schema(self.schema, parent) + self._is_prepared_schema = True + + super().bind(field_name, parent) + + def to_internal_value(self, data: t.Any) -> t.Optional["base.ST"]: + try: + return self.decoder.decode(data) + except ValidationError as e: + raise serializers.ValidationError(e.errors(), self.field_name) # type: ignore[arg-type] + + def to_representation(self, value: t.Optional["base.ST"]) -> t.Any: + obj = self.schema.parse_obj(value) + raw_obj = obj.dict(**self.export_params) + return raw_obj["__root__"] + + +class SchemaRenderer(AnnotatedSchemaT[base.ST], renderers.JSONRenderer): + schema_ctx_attr = "render_schema" + + def render(self, data, accepted_media_type=None, renderer_context=None): + renderer_context = renderer_context or {} + response = renderer_context.get("response") + if response is not None and response.exception: + return super().render(data, accepted_media_type, renderer_context) + + try: + json_str = self.render_data(data, renderer_context) + except ValidationError as e: + json_str = e.json().encode() + except AttributeError: + json_str = super().render(data, accepted_media_type, renderer_context) + + return json_str + + def render_data(self, data, renderer_ctx) -> bytes: + schema = self.get_schema(renderer_ctx or {}) + if schema is not None: + data = schema(__root__=data) + + export_kw = base.extract_export_kwargs(renderer_ctx) + json_str = data.json(**export_kw, ensure_ascii=self.ensure_ascii) + return json_str.encode() + + +class 
SchemaParser(AnnotatedSchemaT[base.ST], parsers.JSONParser): + schema_ctx_attr = "parser_schema" + renderer_class = SchemaRenderer + require_explicit_schema = True + + def parse(self, stream, media_type=None, parser_context=None): + parser_context = parser_context or {} + encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET) + schema = t.cast(BaseModel, self.get_schema(parser_context)) + + try: + return schema.parse_raw(stream.read(), encoding=encoding).__root__ + except ValidationError as e: + raise exceptions.ParseError(e.errors()) + + +class AutoSchema(openapi.AutoSchema): + get_request_serializer: t.Callable + _get_reference: t.Callable + + def map_field(self, field: serializers.Field): + if isinstance(field, SchemaField): + return field.schema.schema() + return super().map_field(field) + + def map_parsers(self, path: str, method: str): + request_types: t.List[t.Any] = [] + parser_ctx = self.view.get_parser_context(None) + + for parser_type in self.view.parser_classes: + parser = parser_type() + + if isinstance(parser, SchemaParser): + schema = self._extract_openapi_schema(parser, parser_ctx) + if schema is not None: + request_types.append((parser.media_type, schema)) + else: + request_types.append(parser.media_type) + else: + request_types.append(parser.media_type) + + return request_types + + def map_renderers(self, path: str, method: str): + response_types: t.List[t.Any] = [] + renderer_ctx = self.view.get_renderer_context() + + for renderer_type in self.view.renderer_classes: + renderer = renderer_type() + + if isinstance(renderer, SchemaRenderer): + schema = self._extract_openapi_schema(renderer, renderer_ctx) + if schema is not None: + response_types.append((renderer.media_type, schema)) + else: + response_types.append(renderer.media_type) + + elif not isinstance(renderer, renderers.BrowsableAPIRenderer): + response_types.append(renderer.media_type) + + return response_types + + def get_request_body(self, path: str, method: str): + if method not in ('PUT', 'PATCH', 'POST'): + return {} + + self.request_media_types = self.map_parsers(path, method) + serializer = self.get_request_serializer(path, method) + content_schemas = {} + + for request_type in self.request_media_types: + if isinstance(request_type, tuple): + media_type, request_schema = request_type + content_schemas[media_type] = {"schema": request_schema} + else: + serializer_ref = self._get_reference(serializer) + content_schemas[request_type] = {"schema": serializer_ref} + + return {'content': content_schemas} + + def get_responses(self, path: str, method: str): + if method == "DELETE": + return {"204": {"description": ""}} + + self.response_media_types = self.map_renderers(path, method) + status_code = "201" if method == "POST" else "200" + content_types = {} + + for response_type in self.response_media_types: + if isinstance(response_type, tuple): + media_type, response_schema = response_type + content_types[media_type] = {"schema": response_schema} + else: + response_schema = self._get_serializer_response_schema(path, method) + content_types[response_type] = {"schema": response_schema} + + return { + status_code: { + "content": content_types, + "description": "", + } + } + + def _extract_openapi_schema(self, schemable: AnnotatedSchemaT, ctx: "RequestResponseContext"): + schema_model = schemable.get_schema(ctx) + if schema_model is not None: + return schema_model.schema() + return None + + def _get_serializer_response_schema(self, path, method): + serializer = self.get_response_serializer(path, method) + 
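+        # Responses that are not backed by a DRF serializer fall back to an empty item schema; actual serializers are referenced via their component schema.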
+ if not isinstance(serializer, serializers.Serializer): + item_schema = {} + else: + item_schema = self._get_reference(serializer) + + if is_list_view(path, method, self.view): + response_schema = { + "type": "array", + "items": item_schema, + } + paginator = self.get_paginator() + if paginator: + response_schema = paginator.get_paginated_response_schema(response_schema) + else: + response_schema = item_schema + return response_schema diff --git a/django_pydantic_field/utils.py b/django_pydantic_field/v1/utils.py similarity index 100% rename from django_pydantic_field/utils.py rename to django_pydantic_field/v1/utils.py diff --git a/django_pydantic_field/v2/__init__.py b/django_pydantic_field/v2/__init__.py new file mode 100644 index 0000000..c905b8b --- /dev/null +++ b/django_pydantic_field/v2/__init__.py @@ -0,0 +1,6 @@ +from django_pydantic_field.compat.pydantic import PYDANTIC_V2 + +if not PYDANTIC_V2: + raise ImportError("django_pydantic_field.v2 package is only compatible with Pydantic v2") + +from .fields import SchemaField as SchemaField diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py new file mode 100644 index 0000000..d8e6a8e --- /dev/null +++ b/django_pydantic_field/v2/fields.py @@ -0,0 +1,232 @@ +from __future__ import annotations + +import typing as ty + +import pydantic +from django.core import checks, exceptions +from django.core.serializers.json import DjangoJSONEncoder +from django.db.models.expressions import BaseExpression, Col, Value +from django.db.models.fields import NOT_PROVIDED +from django.db.models.fields.json import JSONField +from django.db.models.lookups import Transform +from django.db.models.query_utils import DeferredAttribute + +from django_pydantic_field.compat import deprecation +from django_pydantic_field.compat.django import GenericContainer + +from . 
import forms, types + +if ty.TYPE_CHECKING: + import json + import typing_extensions as te + + class _SchemaFieldKwargs(types.ExportKwargs, total=False): + # django.db.models.fields.Field kwargs + name: str | None + verbose_name: str | None + primary_key: bool + max_length: int | None + unique: bool + blank: bool + db_index: bool + rel: ty.Any + editable: bool + serialize: bool + unique_for_date: str | None + unique_for_month: str | None + unique_for_year: str | None + choices: ty.Sequence[ty.Tuple[str, str]] | None + help_text: str | None + db_column: str | None + db_tablespace: str | None + auto_created: bool + validators: ty.Sequence[ty.Callable] | None + error_messages: ty.Mapping[str, str] | None + db_comment: str | None + # django.db.models.fields.json.JSONField kwargs + encoder: ty.Callable[[], json.JSONEncoder] + decoder: ty.Callable[[], json.JSONDecoder] + + +__all__ = ("SchemaField",) + + +class SchemaAttribute(DeferredAttribute): + field: PydanticSchemaField + + def __set_name__(self, owner, name): + self.field.adapter.bind(owner, name) + + def __set__(self, obj, value): + obj.__dict__[self.field.attname] = self.field.to_python(value) + + +class PydanticSchemaField(JSONField, ty.Generic[types.ST]): + descriptor_class: type[DeferredAttribute] = SchemaAttribute + adapter: types.SchemaAdapter + + def __init__( + self, + *args, + schema: type[types.ST] | GenericContainer | ty.ForwardRef | str | None = None, + config: pydantic.ConfigDict | None = None, + **kwargs, + ): + kwargs.setdefault("encoder", DjangoJSONEncoder) + self.export_kwargs = export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) + super().__init__(*args, **kwargs) + + self.schema = GenericContainer.unwrap(schema) + self.config = config + self.adapter = types.SchemaAdapter(schema, config, None, self.get_attname(), self.null, **export_kwargs) + + def __copy__(self): + _, _, args, kwargs = self.deconstruct() + copied = self.__class__(*args, **kwargs) + copied.set_attributes_from_name(self.name) + return copied + + def deconstruct(self) -> ty.Any: + field_name, import_path, args, kwargs = super().deconstruct() + if import_path.startswith("django_pydantic_field.v2."): + import_path = import_path.replace("django_pydantic_field.v2", "django_pydantic_field", 1) + + default = kwargs.get("default", NOT_PROVIDED) + if default is not NOT_PROVIDED and not callable(default): + kwargs["default"] = self.adapter.dump_python(default, include=None, exclude=None, round_trip=True) + + prep_schema = GenericContainer.wrap(self.adapter.prepared_schema) + kwargs.update(schema=prep_schema, config=self.config, **self.export_kwargs) + + return field_name, import_path, args, kwargs + + def contribute_to_class(self, cls: types.DjangoModelType, name: str, private_only: bool = False) -> None: + self.adapter.bind(cls, name) + super().contribute_to_class(cls, name, private_only) + + def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: + # Remove checks of using mutable datastructure instances as `default` values, since they'll be adapted anyway. + performed_checks = [check for check in super().check(**kwargs) if check.id != "fields.E010"] + try: + # Test that the schema could be resolved in runtime, even if it contains forward references. + self.adapter.validate_schema() + except types.ImproperlyConfiguredSchema as exc: + message = f"Cannot resolve the schema. 
Original error: \n{exc.args[0]}" + performed_checks.append(checks.Error(message, obj=self, id="pydantic.E001")) + + if self.has_default(): + try: + # Test that the default value conforms to the schema. + self.get_prep_value(self.get_default()) + except pydantic.ValidationError as exc: + message = f"Default value cannot be adapted to the schema. Pydantic error: \n{str(exc)}" + performed_checks.append(checks.Error(message, obj=self, id="pydantic.E002")) + + if {"include", "exclude"} & self.export_kwargs.keys(): + # Try to prepare the default value to test export ability against it. + schema_default = self.get_default() + if schema_default is None: + # If the default value is not set, try to get the default value from the schema. + prep_value = self.adapter.type_adapter.get_default_value() + if prep_value is not None: + prep_value = prep_value.value + schema_default = prep_value + + if schema_default is not None: + try: + # Perform the full round-trip transformation to test the export ability. + self.adapter.validate_python(self.get_prep_value(self.default)) + except pydantic.ValidationError as exc: + message = f"Export arguments may lead to data integrity problems. Pydantic error: \n{str(exc)}" + hint = "Please review `import` and `export` arguments." + performed_checks.append(checks.Warning(message, obj=self, hint=hint, id="pydantic.E003")) + + return performed_checks + + def validate(self, value: ty.Any, model_instance: ty.Any) -> None: + value = self.adapter.validate_python(value) + return super(JSONField, self).validate(value, model_instance) + + def to_python(self, value: ty.Any): + try: + return self.adapter.validate_python(value) + except pydantic.ValidationError as exc: + error_params = {"errors": exc.errors(), "field": self} + raise exceptions.ValidationError(exc.json(), code="invalid", params=error_params) from exc + + def get_prep_value(self, value: ty.Any): + if isinstance(value, Value) and isinstance(value.output_field, self.__class__): + # Prepare inner value for `Value`-wrapped expressions. + value = Value(self.get_prep_value(value.value), value.output_field) + elif not isinstance(value, BaseExpression): + # Prepare the value if it is not a query expression. + prep_value = self.adapter.validate_python(value) + value = self.adapter.dump_python(prep_value) + + return super().get_prep_value(value) + + def get_transform(self, lookup_name: str): + transform: ty.Any = super().get_transform(lookup_name) + if transform is not None: + transform = SchemaKeyTransformAdapter(transform) + return transform + + def get_default(self) -> types.ST: + default_value = super().get_default() + return self.adapter.validate_python(default_value) + + def formfield(self, **kwargs): + field_kwargs = dict( + form_class=forms.SchemaField, + # Trying to resolve the schema before passing it to the formfield, since in Django < 4.0, + # formfield is unbound during form validation and is not able to resolve forward refs defined in the model. 
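# An illustrative usage sketch of the PydanticSchemaField declared above; `Point`, `Canvas`
# and the app label are hypothetical names, not taken from this patch. Values assigned to the
# attribute are coerced through `SchemaAttribute.__set__` -> `to_python`, and dumped back to
# JSON-compatible data by `get_prep_value` when the row is saved.
import pydantic
from django.db import models
from django_pydantic_field import SchemaField

class Point(pydantic.BaseModel):
    x: int = 0
    y: int = 0

class Canvas(models.Model):
    origin: Point = SchemaField(default={"x": 0, "y": 0})

    class Meta:
        app_label = "example"

canvas = Canvas(origin={"x": 1, "y": 2})
assert canvas.origin == Point(x=1, y=2)  # validated and coerced on assignment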
+ schema=self.adapter.prepared_schema, + config=self.config, + **self.export_kwargs, + ) + field_kwargs.update(kwargs) + return super().formfield(**field_kwargs) # type: ignore + + +class SchemaKeyTransformAdapter: + """An adapter for creating key transforms for schema field lookups.""" + + def __init__(self, transform: type[Transform]): + self.transform = transform + + def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None: + """All transforms should bypass the SchemaField's adaptation with `get_prep_value`, + and be routed to JSONField's `get_prep_value` for further processing.""" + if isinstance(col, BaseExpression): + col = col.copy() + col.output_field = super(PydanticSchemaField, col.output_field) # type: ignore + return self.transform(col, *args, **kwargs) + + +@ty.overload +def SchemaField( + schema: type[types.ST | None] | ty.ForwardRef = ..., + config: pydantic.ConfigDict = ..., + default: types.SchemaT | None | ty.Callable[[], types.SchemaT | None] = ..., + *args, + null: ty.Literal[True], + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> types.ST | None: + ... + + +@ty.overload +def SchemaField( + schema: type[types.ST] | ty.ForwardRef = ..., + config: pydantic.ConfigDict = ..., + default: ty.Union[types.SchemaT, ty.Callable[[], types.SchemaT]] = ..., + *args, + null: ty.Literal[False] = ..., + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> types.ST: + ... + + +def SchemaField(schema=None, config=None, default=NOT_PROVIDED, *args, **kwargs): # type: ignore + deprecation.truncate_deprecated_v1_export_kwargs(kwargs) + return PydanticSchemaField(*args, schema=schema, config=config, default=default, **kwargs) diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py new file mode 100644 index 0000000..b2c49d2 --- /dev/null +++ b/django_pydantic_field/v2/forms.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import typing as ty + +import pydantic +from django.core.exceptions import ValidationError +from django.forms.fields import InvalidJSONInput, JSONField, JSONString +from django.utils.translation import gettext_lazy as _ + +from django_pydantic_field.compat import deprecation +from . 
import types + + +class SchemaField(JSONField, ty.Generic[types.ST]): + adapter: types.SchemaAdapter + default_error_messages = { + "schema_error": _("Schema didn't match for %(title)s."), + } + + def __init__( + self, + schema: type[types.ST] | ty.ForwardRef | str, + config: pydantic.ConfigDict | None = None, + allow_null: bool | None = None, + *args, + **kwargs, + ): + deprecation.truncate_deprecated_v1_export_kwargs(kwargs) + + self.schema = schema + self.config = config + self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) + self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null, **self.export_kwargs) + super().__init__(*args, **kwargs) + + def get_bound_field(self, form: ty.Any, field_name: str): + if not self.adapter.is_bound: + self.adapter.bind(form, field_name) + return super().get_bound_field(form, field_name) + + def bound_data(self, data: ty.Any, initial: ty.Any): + if self.disabled: + return self.adapter.validate_python(initial) + if data is None: + return None + try: + return self.adapter.validate_json(data) + except pydantic.ValidationError: + return InvalidJSONInput(data) + + def to_python(self, value: ty.Any) -> ty.Any: + if self.disabled: + return value + if value in self.empty_values: + return None + elif isinstance(value, JSONString): + return value + try: + converted = self.adapter.validate_json(value) + except pydantic.ValidationError as exc: + error_params = {"value": value, "title": exc.title, "detail": exc.json(), "errors": exc.errors()} + raise ValidationError(self.error_messages["schema_error"], code="invalid", params=error_params) from exc + + if isinstance(converted, str): + return JSONString(converted) + + return converted + + def prepare_value(self, value): + if isinstance(value, InvalidJSONInput): + return value + + value = self.adapter.validate_python(value) + return self.adapter.dump_json(value).decode() + + def has_changed(self, initial: ty.Any | None, data: ty.Any | None) -> bool: + if super(JSONField, self).has_changed(initial, data): + return True + return self.adapter.dump_json(initial) != self.adapter.dump_json(data) diff --git a/django_pydantic_field/v2/rest_framework/__init__.py b/django_pydantic_field/v2/rest_framework/__init__.py new file mode 100644 index 0000000..f151648 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/__init__.py @@ -0,0 +1,20 @@ +from .fields import SchemaField as SchemaField +from .parsers import SchemaParser as SchemaParser +from .renderers import SchemaRenderer as SchemaRenderer + +_DEPRECATED_MESSAGE = ( + "`django_pydantic_field.rest_framework.AutoSchema` is deprecated, " + "please use explicit imports for `django_pydantic_field.rest_framework.openapi.AutoSchema` " + "or `django_pydantic_field.rest_framework.coreapi.AutoSchema` instead." 
+) + +def __getattr__(key): + if key == "AutoSchema": + import warnings + + from .openapi import AutoSchema + + warnings.warn(_DEPRECATED_MESSAGE, DeprecationWarning) + return AutoSchema + + raise AttributeError(key) diff --git a/django_pydantic_field/v2/rest_framework/coreapi.py b/django_pydantic_field/v2/rest_framework/coreapi.py new file mode 100644 index 0000000..bc1fe79 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/coreapi.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import typing as ty + +from rest_framework.compat import coreapi, coreschema +from rest_framework.schemas.coreapi import AutoSchema as _CoreAPIAutoSchema + +from .fields import SchemaField + +if ty.TYPE_CHECKING: + from coreschema.schemas import Schema as _CoreAPISchema # type: ignore[import-untyped] + from rest_framework.serializers import Serializer + +__all__ = ("AutoSchema",) + + +class AutoSchema(_CoreAPIAutoSchema): + """Not implemented yet.""" + + def get_serializer_fields(self, path: str, method: str) -> list[coreapi.Field]: + base_field_schemas = super().get_serializer_fields(path, method) + if not base_field_schemas: + return [] + + serializer: Serializer = self.view.get_serializer() + pydantic_schema_fields: dict[str, coreapi.Field] = {} + + for field_name, field in serializer.fields.items(): + if not field.read_only and isinstance(field, SchemaField): + pydantic_schema_fields[field_name] = self._prepare_schema_field(field) + + if not pydantic_schema_fields: + return base_field_schemas + + return [pydantic_schema_fields.get(field.name, field) for field in base_field_schemas] + + def _prepare_schema_field(self, field: SchemaField) -> coreapi.Field: + build_core_schema = SimpleCoreSchemaTransformer(field.adapter.json_schema()) + return coreapi.Field( + name=field.field_name, + location="form", + required=field.required, + schema=build_core_schema(), + description=field.help_text, + ) + + +class SimpleCoreSchemaTransformer: + def __init__(self, json_schema: dict[str, ty.Any]): + self.root_schema = json_schema + + def __call__(self) -> _CoreAPISchema: + definitions = self._populate_definitions() + root_schema = self._transform(self.root_schema) + + if definitions: + if isinstance(root_schema, coreschema.Ref): + schema_name = root_schema.ref_name + else: + schema_name = root_schema.title or "Schema" + definitions[schema_name] = root_schema + + root_schema = coreschema.RefSpace(definitions, schema_name) + + return root_schema + + def _populate_definitions(self): + schemas = self.root_schema.get("$defs", {}) + return {ref_name: self._transform(schema) for ref_name, schema in schemas.items()} + + def _transform(self, schema) -> _CoreAPISchema: + schemas = [ + *self._transform_type_schema(schema), + *self._transform_composite_types(schema), + *self._transform_ref(schema), + ] + if not schemas: + schema = self._transform_any(schema) + elif len(schemas) == 1: + schema = schemas[0] + else: + schema = coreschema.Intersection(schemas) + return schema + + def _transform_type_schema(self, schema): + schema_type = schema.get("type", None) + + if schema_type is not None: + schema_types = schema_type if isinstance(schema_type, list) else [schema_type] + + for schema_type in schema_types: + transformer = getattr(self, f"transform_{schema_type}") + yield transformer(schema) + + def _transform_composite_types(self, schema): + for operation, transform_name in self.COMBINATOR_TYPES.items(): + value = schema.get(operation, None) + + if value is not None: + transformer = getattr(self, transform_name) + yield 
transformer(schema) + + def _transform_ref(self, schema): + reference = schema.get("$ref", None) + if reference is not None: + yield coreschema.Ref(reference) + + def _transform_any(self, schema): + attrs = self._get_common_attributes(schema) + return coreschema.Anything(**attrs) + + # Simple types transformers + + def transform_object(self, schema) -> coreschema.Object: + properties = schema.get("properties", None) + if properties is not None: + properties = {prop: self._transform(prop_schema) for prop, prop_schema in properties.items()} + + pattern_props = schema.get("patternProperties", None) + if pattern_props is not None: + pattern_props = {pattern: self._transform(prop_schema) for pattern, prop_schema in pattern_props.items()} + + extra_props = schema.get("additionalProperties", None) + if extra_props is not None: + if extra_props not in (True, False): + extra_props = self._transform(schema) + + return coreschema.Object( + properties=properties, + pattern_properties=pattern_props, + additional_properties=extra_props, # type: ignore + min_properties=schema.get("minProperties"), + max_properties=schema.get("maxProperties"), + required=schema.get("required", []), + **self._get_common_attributes(schema), + ) + + def transform_array(self, schema) -> coreschema.Array: + items = schema.get("items", None) + if items is not None: + if isinstance(items, list): + items = list(map(self._transform, items)) + elif items not in (True, False): + items = self._transform(items) + + extra_items = schema.get("additionalItems") + if extra_items is not None: + if isinstance(items, list): + items = list(map(self._transform, items)) + elif items not in (True, False): + items = self._transform(items) + + return coreschema.Array( + items=items, + additional_items=extra_items, + min_items=schema.get("minItems"), + max_items=schema.get("maxItems"), + unique_items=schema.get("uniqueItems"), + **self._get_common_attributes(schema), + ) + + def transform_boolean(self, schema) -> coreschema.Boolean: + attrs = self._get_common_attributes(schema) + return coreschema.Boolean(**attrs) + + def transform_integer(self, schema) -> coreschema.Integer: + return self._transform_numeric(schema, cls=coreschema.Integer) + + def transform_null(self, schema) -> coreschema.Null: + attrs = self._get_common_attributes(schema) + return coreschema.Null(**attrs) + + def transform_number(self, schema) -> coreschema.Number: + return self._transform_numeric(schema, cls=coreschema.Number) + + def transform_string(self, schema) -> coreschema.String: + return coreschema.String( + min_length=schema.get("minLength"), + max_length=schema.get("maxLength"), + pattern=schema.get("pattern"), + format=schema.get("format"), + **self._get_common_attributes(schema), + ) + + # Composite types transformers + + COMBINATOR_TYPES = { + "anyOf": "transform_union", + "oneOf": "transform_exclusive_union", + "allOf": "transform_intersection", + "not": "transform_not", + } + + def transform_union(self, schema): + return coreschema.Union([self._transform(option) for option in schema["anyOf"]]) + + def transform_exclusive_union(self, schema): + return coreschema.ExclusiveUnion([self._transform(option) for option in schema["oneOf"]]) + + def transform_intersection(self, schema): + return coreschema.Intersection([self._transform(option) for option in schema["allOf"]]) + + def transform_not(self, schema): + return coreschema.Not(self._transform(schema["not"])) + + # Common schema transformations + + def _get_common_attributes(self, schema): + return dict( + 
title=schema.get("title"), + description=schema.get("description"), + default=schema.get("default"), + ) + + def _transform_numeric(self, schema, cls): + return cls( + minimum=schema.get("minimum"), + maximum=schema.get("maximum"), + exclusive_minimum=schema.get("exclusiveMinimum"), + exclusive_maximum=schema.get("exclusiveMaximum"), + multiple_of=schema.get("multipleOf"), + **self._get_common_attributes(schema), + ) diff --git a/django_pydantic_field/v2/rest_framework/fields.py b/django_pydantic_field/v2/rest_framework/fields.py new file mode 100644 index 0000000..b15aa43 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/fields.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +import typing as ty + +import pydantic +from rest_framework import exceptions, fields + +from django_pydantic_field.compat import deprecation +from django_pydantic_field.v2 import types + +if ty.TYPE_CHECKING: + from collections.abc import Mapping + + from rest_framework.serializers import BaseSerializer + + RequestResponseContext = Mapping[str, ty.Any] + + +class SchemaField(fields.Field, ty.Generic[types.ST]): + adapter: types.SchemaAdapter + + def __init__( + self, + schema: type[types.ST], + config: pydantic.ConfigDict | None = None, + *args, + allow_null: bool = False, + **kwargs, + ): + deprecation.truncate_deprecated_v1_export_kwargs(kwargs) + + self.schema = schema + self.config = config + self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) + self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null, **self.export_kwargs) + super().__init__(*args, **kwargs) + + def bind(self, field_name: str, parent: BaseSerializer): + if not self.adapter.is_bound: + self.adapter.bind(type(parent), field_name) + super().bind(field_name, parent) + + def to_internal_value(self, data: ty.Any): + try: + if isinstance(data, (str, bytes)): + return self.adapter.validate_json(data) + return self.adapter.validate_python(data) + except pydantic.ValidationError as exc: + raise exceptions.ValidationError(exc.errors(), code="invalid") # type: ignore + + def to_representation(self, value: ty.Optional[types.ST]): + try: + prep_value = self.adapter.validate_python(value) + return self.adapter.dump_python(prep_value) + except pydantic.ValidationError as exc: + raise exceptions.ValidationError(exc.errors(), code="invalid") # type: ignore diff --git a/django_pydantic_field/v2/rest_framework/mixins.py b/django_pydantic_field/v2/rest_framework/mixins.py new file mode 100644 index 0000000..a8a8c59 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/mixins.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +import typing as ty + +from django_pydantic_field.compat.typing import get_args +from django_pydantic_field.v2 import types + +if ty.TYPE_CHECKING: + from collections.abc import Mapping + + RequestResponseContext = Mapping[str, ty.Any] + + +class AnnotatedAdapterMixin(ty.Generic[types.ST]): + media_type: ty.ClassVar[str] + schema_context_key: ty.ClassVar[str] = "response_schema" + config_context_key: ty.ClassVar[str] = "response_schema_config" + + def get_adapter(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: + adapter = self._make_adapter_from_context(ctx) + if adapter is None: + adapter = self._make_adapter_from_annotation(ctx) + + return adapter + + def _make_adapter_from_context(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: + schema = ctx.get(self.schema_context_key) + if schema is not None: + config = 
ctx.get(self.config_context_key) + export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(ctx)) + return types.SchemaAdapter(schema, config, type(ctx.get("view")), None, **export_kwargs) + + return schema + + def _make_adapter_from_annotation(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: + try: + schema = get_args(self.__orig_class__)[0] # type: ignore + except (AttributeError, IndexError): + return None + + config = ctx.get(self.config_context_key) + export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(ctx)) + return types.SchemaAdapter(schema, config, type(ctx.get("view")), None, **export_kwargs) diff --git a/django_pydantic_field/v2/rest_framework/openapi.py b/django_pydantic_field/v2/rest_framework/openapi.py new file mode 100644 index 0000000..4ccb265 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/openapi.py @@ -0,0 +1,189 @@ +from __future__ import annotations + +import typing as ty + +import pydantic +from rest_framework import serializers +from rest_framework.schemas import openapi +from rest_framework.schemas import utils as drf_schema_utils +from rest_framework.test import APIRequestFactory + +from . import fields, parsers, renderers +from ..utils import get_origin_type + +if ty.TYPE_CHECKING: + from collections.abc import Iterable + + from pydantic.json_schema import JsonSchemaMode + + from . import mixins + + +class AutoSchema(openapi.AutoSchema): + REF_TEMPLATE_PREFIX = "#/components/schemas/{model}" + + def __init__(self, tags=None, operation_id_base=None, component_name=None) -> None: + super().__init__(tags, operation_id_base, component_name) + self.collected_schema_defs: dict[str, ty.Any] = {} + self.collected_adapter_schema_refs: dict[str, ty.Any] = {} + self.adapter_mode: JsonSchemaMode = "validation" + self.rf = APIRequestFactory() + + def get_components(self, path: str, method: str) -> dict[str, ty.Any]: + if method.lower() == "delete": + return {} + + request_serializer = self.get_request_serializer(path, method) # type: ignore[attr-defined] + response_serializer = self.get_response_serializer(path, method) # type: ignore[attr-defined] + + components = { + **self._collect_serializer_component(response_serializer, "serialization"), + **self._collect_serializer_component(request_serializer, "validation"), + } + if self.collected_schema_defs: + components.update(self.collected_schema_defs) + self.collected_schema_defs = {} + + return components + + def get_request_body(self, path, method): + if method not in ("PUT", "PATCH", "POST"): + return {} + + self.request_media_types = self.map_parsers(path, method) + + request_schema = {} + serializer = self.get_request_serializer(path, method) + if isinstance(serializer, serializers.Serializer): + request_schema = self.get_reference(serializer) + + schema_content = {} + + for parser, ct in zip(self.view.parser_classes, self.request_media_types): + if issubclass(get_origin_type(parser), parsers.SchemaParser): + parser_schema = self.collected_adapter_schema_refs[repr(parser)] + else: + parser_schema = request_schema + + schema_content[ct] = {"schema": parser_schema} + + return {"content": schema_content} + + def get_responses(self, path, method): + if method == "DELETE": + return {"204": {"description": ""}} + + self.response_media_types = self.map_renderers(path, method) + serializer = self.get_response_serializer(path, method) + + response_schema = {} + if isinstance(serializer, serializers.Serializer): + response_schema = self.get_reference(serializer) + + 
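# A sketch (not from the patch itself) of how a schema reaches the AnnotatedAdapterMixin shown
# above: either through the generic argument, read back from `__orig_class__`, or through the
# "parser_schema"/"renderer_schema" context keys. `Comment` and `CommentView` are hypothetical.
import pydantic
from rest_framework.response import Response
from rest_framework.views import APIView
from django_pydantic_field.v2.rest_framework import SchemaParser, SchemaRenderer

class Comment(pydantic.BaseModel):
    text: str

class CommentView(APIView):
    parser_classes = [SchemaParser[Comment]]      # schema taken from the generic argument
    renderer_classes = [SchemaRenderer[Comment]]  # likewise for the response body

    def post(self, request, *args, **kwargs):
        return Response(request.data)  # request.data is already a validated Comment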
is_list_view = drf_schema_utils.is_list_view(path, method, self.view) + if is_list_view: + response_schema = self._get_paginated_schema(response_schema) + + schema_content = {} + for renderer, ct in zip(self.view.renderer_classes, self.response_media_types): + if issubclass(get_origin_type(renderer), renderers.SchemaRenderer): + renderer_schema = {"schema": self.collected_adapter_schema_refs[repr(renderer)]} + if is_list_view: + renderer_schema = self._get_paginated_schema(renderer_schema) + schema_content[ct] = renderer_schema + else: + schema_content[ct] = response_schema + + status_code = "201" if method == "POST" else "200" + return { + status_code: { + "content": schema_content, + "description": "", + } + } + + def map_parsers(self, path: str, method: str) -> list[str]: + schema_parsers = [] + media_types = [] + + for parser in self.view.parser_classes: + media_types.append(parser.media_type) + if issubclass(get_origin_type(parser), parsers.SchemaParser): + schema_parsers.append(parser) + + if schema_parsers: + self.adapter_mode = "validation" + request = self.rf.generic(method, path) + schemas = self._collect_adapter_components(schema_parsers, self.view.get_parser_context(request)) + self.collected_adapter_schema_refs.update(schemas) + + return media_types + + def map_renderers(self, path: str, method: str) -> list[str]: + schema_renderers = [] + media_types = [] + + for renderer in self.view.renderer_classes: + media_types.append(renderer.media_type) + if issubclass(get_origin_type(renderer), renderers.SchemaRenderer): + schema_renderers.append(renderer) + + if schema_renderers: + self.adapter_mode = "serialization" + schemas = self._collect_adapter_components(schema_renderers, self.view.get_renderer_context()) + self.collected_adapter_schema_refs.update(schemas) + + return media_types + + def map_serializer(self, serializer): + component_content = super().map_serializer(serializer) + field_adapters = [] + + for field in serializer.fields.values(): + if isinstance(field, fields.SchemaField): + field_adapters.append((field.field_name, self.adapter_mode, field.adapter.type_adapter)) + + if field_adapters: + field_schemas = self._collect_type_adapter_schemas(field_adapters) + for field_name, field_schema in field_schemas.items(): + component_content["properties"][field_name] = field_schema + + return component_content + + def _collect_serializer_component(self, serializer: serializers.BaseSerializer | None, mode: JsonSchemaMode): + schema_definition = {} + if isinstance(serializer, serializers.Serializer): + self.adapter_mode = mode + component_name = self.get_component_name(serializer) + schema_definition[component_name] = self.map_serializer(serializer) + return schema_definition + + def _collect_adapter_components(self, components: Iterable[type[mixins.AnnotatedAdapterMixin]], context: dict): + type_adapters = [] + + for component in components: + schema_adapter = component().get_adapter(context) + if schema_adapter is not None: + type_adapters.append((repr(component), self.adapter_mode, schema_adapter.type_adapter)) + + if type_adapters: + return self._collect_type_adapter_schemas(type_adapters) + + return {} + + def _collect_type_adapter_schemas(self, adapters: Iterable[tuple[str, JsonSchemaMode, pydantic.TypeAdapter]]): + inner_schemas = {} + + schemas, common_schemas = pydantic.TypeAdapter.json_schemas(adapters, ref_template=self.REF_TEMPLATE_PREFIX) + for (field_name, _), field_schema in schemas.items(): + inner_schemas[field_name] = field_schema + + 
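# A standalone illustration of the pydantic v2 API used by `_collect_type_adapter_schemas`
# above (the "payload" key is arbitrary): `TypeAdapter.json_schemas` accepts (key, mode, adapter)
# triples and returns per-key schemas plus the shared definitions that feed `collected_schema_defs`.
import typing as ty
import pydantic

adapter = pydantic.TypeAdapter(ty.List[int])
schemas, shared = pydantic.TypeAdapter.json_schemas(
    [("payload", "validation", adapter)],
    ref_template="#/components/schemas/{model}",
)
# schemas[("payload", "validation")] -> {"type": "array", "items": {"type": "integer"}}
# shared.get("$defs", {})            -> referenced model components, if any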
self.collected_schema_defs.update(common_schemas.get("$defs", {})) + return inner_schemas + + def _get_paginated_schema(self, schema) -> ty.Any: + response_schema = {"type": "array", "items": schema} + paginator = self.get_paginator() + if paginator: + response_schema = paginator.get_paginated_response_schema(response_schema) # type: ignore + return response_schema diff --git a/django_pydantic_field/v2/rest_framework/parsers.py b/django_pydantic_field/v2/rest_framework/parsers.py new file mode 100644 index 0000000..726ad85 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/parsers.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import typing as ty + +import pydantic +from rest_framework import exceptions, parsers + +from .. import types +from . import mixins, renderers + + +class SchemaParser(mixins.AnnotatedAdapterMixin[types.ST], parsers.JSONParser): + schema_context_key = "parser_schema" + config_context_key = "parser_config" + renderer_class = renderers.SchemaRenderer + + def parse(self, stream: ty.IO[bytes], media_type=None, parser_context=None): + parser_context = parser_context or {} + adapter = self.get_adapter(parser_context) + if adapter is None: + raise RuntimeError("Schema should be either explicitly set with annotation or passed in the context") + + try: + return adapter.validate_json(stream.read()) + except pydantic.ValidationError as exc: + raise exceptions.ParseError(exc.errors()) # type: ignore diff --git a/django_pydantic_field/v2/rest_framework/renderers.py b/django_pydantic_field/v2/rest_framework/renderers.py new file mode 100644 index 0000000..820aa54 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/renderers.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import typing as ty + +import pydantic +from rest_framework import renderers + +from .. import types +from . 
import mixins + +if ty.TYPE_CHECKING: + from collections.abc import Mapping + + RequestResponseContext = Mapping[str, ty.Any] + +__all__ = ("SchemaRenderer",) + + +class SchemaRenderer(mixins.AnnotatedAdapterMixin[types.ST], renderers.JSONRenderer): + schema_context_key = "renderer_schema" + config_context_key = "renderer_config" + + def render(self, data: ty.Any, accepted_media_type=None, renderer_context=None): + renderer_context = renderer_context or {} + response = renderer_context.get("response") + if response is not None and response.exception: + return super().render(data, accepted_media_type, renderer_context) + + adapter = self.get_adapter(renderer_context) + if adapter is None and isinstance(data, pydantic.BaseModel): + return self.render_pydantic_model(data, renderer_context) + if adapter is None: + raise RuntimeError("Schema should be either explicitly set with annotation or passed in the context") + + try: + prep_data = adapter.validate_python(data) + return adapter.dump_json(prep_data) + except pydantic.ValidationError as exc: + return exc.json(indent=True, include_input=True).encode() + + def render_pydantic_model(self, instance: pydantic.BaseModel, renderer_context: Mapping[str, ty.Any]): + export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(renderer_context)) + export_kwargs.pop("strict", None) + export_kwargs.pop("from_attributes", None) + export_kwargs.pop("mode", None) + + json_dump = instance.model_dump_json(**export_kwargs) # type: ignore + return json_dump.encode() diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py new file mode 100644 index 0000000..7cfa372 --- /dev/null +++ b/django_pydantic_field/v2/types.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +import typing as ty +from collections import ChainMap + +import pydantic +import typing_extensions as te + +from django_pydantic_field.compat.django import GenericContainer +from django_pydantic_field.compat.functools import cached_property +from . 
import utils + +if ty.TYPE_CHECKING: + from collections.abc import Mapping, Sequence + + from django.db.models import Model + from pydantic.dataclasses import DataclassClassOrWrapper + from pydantic.type_adapter import IncEx + + ModelType = ty.Type[pydantic.BaseModel] + DjangoModelType = ty.Type[Model] + SchemaT = ty.Union[ + pydantic.BaseModel, + DataclassClassOrWrapper, + Sequence[ty.Any], + Mapping[str, ty.Any], + set[ty.Any], + frozenset[ty.Any], + ] + +ST = ty.TypeVar("ST", bound="SchemaT") + + +class ExportKwargs(te.TypedDict, total=False): + strict: bool + from_attributes: bool + mode: ty.Literal["json", "python"] + include: IncEx | None + exclude: IncEx | None + by_alias: bool + exclude_unset: bool + exclude_defaults: bool + exclude_none: bool + round_trip: bool + warnings: bool + + +class ImproperlyConfiguredSchema(ValueError): + """Raised when the schema is improperly configured.""" + + +class SchemaAdapter(ty.Generic[ST]): + def __init__( + self, + schema: ty.Any, + config: pydantic.ConfigDict | None, + parent_type: type | None, + attname: str | None, + allow_null: bool | None = None, + **export_kwargs: ty.Unpack[ExportKwargs], + ): + self.schema = GenericContainer.unwrap(schema) + self.config = config + self.parent_type = parent_type + self.attname = attname + self.allow_null = allow_null + self.export_kwargs = export_kwargs + + @classmethod + def from_type( + cls, + schema: ty.Any, + config: pydantic.ConfigDict | None = None, + **kwargs: ty.Unpack[ExportKwargs], + ) -> SchemaAdapter[ST]: + """Create an adapter from a type.""" + return cls(schema, config, None, None, **kwargs) + + @classmethod + def from_annotation( + cls, + parent_type: type, + attname: str, + config: pydantic.ConfigDict | None = None, + **kwargs: ty.Unpack[ExportKwargs], + ) -> SchemaAdapter[ST]: + """Create an adapter from a type annotation.""" + return cls(None, config, parent_type, attname, **kwargs) + + @staticmethod + def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: + """Extract the export kwargs from the kwargs passed to the field. 
+ This method mutates the passed kwargs by removing those that are used by the adapter.""" + common_keys = kwargs.keys() & ExportKwargs.__annotations__.keys() + export_kwargs = {key: kwargs.pop(key) for key in common_keys} + return ty.cast(ExportKwargs, export_kwargs) + + @cached_property + def type_adapter(self) -> pydantic.TypeAdapter: + return pydantic.TypeAdapter(self.prepared_schema, config=self.config) # type: ignore + + @property + def is_bound(self) -> bool: + """Return True if the adapter is bound to a specific attribute of a `parent_type`.""" + return self.parent_type is not None and self.attname is not None + + def bind(self, parent_type: type | None, attname: str | None) -> te.Self: + """Bind the adapter to a specific attribute of a `parent_type`.""" + self.parent_type = parent_type + self.attname = attname + self.__dict__.pop("prepared_schema", None) + self.__dict__.pop("type_adapter", None) + return self + + def validate_schema(self) -> None: + """Validate the schema and raise an exception if it is invalid.""" + try: + self._prepare_schema() + except Exception as exc: + if not isinstance(exc, ImproperlyConfiguredSchema): + raise ImproperlyConfiguredSchema(*exc.args) from exc + raise + + def validate_python(self, value: ty.Any, *, strict: bool | None = None, from_attributes: bool | None = None) -> ST: + """Validate the value and raise an exception if it is invalid.""" + if strict is None: + strict = self.export_kwargs.get("strict", None) + if from_attributes is None: + from_attributes = self.export_kwargs.get("from_attributes", None) + return self.type_adapter.validate_python(value, strict=strict, from_attributes=from_attributes) + + def validate_json(self, value: str | bytes, *, strict: bool | None = None) -> ST: + if strict is None: + strict = self.export_kwargs.get("strict", None) + return self.type_adapter.validate_json(value, strict=strict) + + def dump_python(self, value: ty.Any, **override_kwargs: ty.Unpack[ExportKwargs]) -> ty.Any: + """Dump the value to a Python object.""" + union_kwargs = ChainMap(override_kwargs, self._dump_python_kwargs) # type: ignore + return self.type_adapter.dump_python(value, **union_kwargs) + + def dump_json(self, value: ty.Any, **override_kwargs: ty.Unpack[ExportKwargs]) -> bytes: + union_kwargs = ChainMap(override_kwargs, self._dump_python_kwargs) # type: ignore + return self.type_adapter.dump_json(value, **union_kwargs) + + def json_schema(self) -> dict[str, ty.Any]: + """Return the JSON schema for the field.""" + by_alias = self.export_kwargs.get("by_alias", True) + return self.type_adapter.json_schema(by_alias=by_alias) + + def _prepare_schema(self) -> type[ST]: + """Prepare the schema for the adapter. + + This method is called by the `prepared_schema` property and should not be called directly. + The intent is to resolve the real schema from annotations or forward references. + """ + schema = self.schema + + if schema is None and self.is_bound: + schema = self._guess_schema_from_annotations() + if isinstance(schema, str): + schema = ty.ForwardRef(schema) + + schema = self._resolve_schema_forward_ref(schema) + if schema is None: + if self.is_bound: + error_msg = f"Annotation is not provided for {self.parent_type.__name__}.{self.attname}" # type: ignore[union-attr] + else: + error_msg = "Cannot resolve the schema. The adapter is accessed before it was bound." 
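# A standalone sketch of the SchemaAdapter API defined above; `Point` is a hypothetical schema,
# not part of the patch. Export kwargs passed at construction time are applied on dump, while
# validation accepts either Python objects or JSON.
import pydantic
from django_pydantic_field.v2.types import SchemaAdapter

class Point(pydantic.BaseModel):
    x: int
    y: int = 0

adapter = SchemaAdapter.from_type(Point, exclude={"y"})
value = adapter.validate_python({"x": 1, "y": 2})       # -> Point(x=1, y=2)
assert adapter.dump_json(value) == b'{"x":1}'           # "y" excluded on export
assert adapter.validate_json(b'{"x": 3}') == Point(x=3)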
+ raise ImproperlyConfiguredSchema(error_msg) + + if self.allow_null: + schema = ty.Optional[schema] # type: ignore + + return ty.cast(ty.Type[ST], schema) + + prepared_schema = cached_property(_prepare_schema) + + def __copy__(self): + instance = self.__class__( + self.schema, + self.config, + self.parent_type, + self.attname, + self.allow_null, + **self.export_kwargs, + ) + instance.__dict__.update(self.__dict__) + return instance + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(bound={self.is_bound}, schema={self.schema!r}, config={self.config!r})" + + def __eq__(self, other: ty.Any) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + + self_fields: list[ty.Any] = [self.attname, self.export_kwargs] + other_fields: list[ty.Any] = [other.attname, other.export_kwargs] + try: + self_fields.append(self.prepared_schema) + other_fields.append(other.prepared_schema) + except ImproperlyConfiguredSchema: + if self.is_bound and other.is_bound: + return False + else: + self_fields.extend((self.schema, self.config, self.allow_null)) + other_fields.extend((other.schema, other.config, other.allow_null)) + + return self_fields == other_fields + + def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | None: + return utils.get_annotated_type(self.parent_type, self.attname) + + def _resolve_schema_forward_ref(self, schema: ty.Any) -> ty.Any: + if schema is None: + return None + + if isinstance(schema, ty.ForwardRef): + globalns = utils.get_namespace(self.parent_type) + return utils.evaluate_forward_ref(schema, globalns) + + wrapped_schema = GenericContainer.wrap(schema) + if not isinstance(wrapped_schema, GenericContainer): + return schema + + origin = self._resolve_schema_forward_ref(wrapped_schema.origin) + args = map(self._resolve_schema_forward_ref, wrapped_schema.args) + return GenericContainer.unwrap(GenericContainer(origin, tuple(args))) + + @cached_property + def _dump_python_kwargs(self) -> dict[str, ty.Any]: + export_kwargs = self.export_kwargs.copy() + export_kwargs.pop("strict", None) + export_kwargs.pop("from_attributes", None) + return ty.cast(dict, export_kwargs) diff --git a/django_pydantic_field/v2/utils.py b/django_pydantic_field/v2/utils.py new file mode 100644 index 0000000..0a724b6 --- /dev/null +++ b/django_pydantic_field/v2/utils.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import sys +import typing as ty +from collections import ChainMap + +from django_pydantic_field.compat import typing + +if ty.TYPE_CHECKING: + from collections.abc import Mapping + + +def get_annotated_type(obj, field, default=None) -> ty.Any: + try: + if isinstance(obj, type): + annotations = obj.__dict__["__annotations__"] + else: + annotations = obj.__annotations__ + + return annotations[field] + except (AttributeError, KeyError): + return default + + +def get_namespace(cls) -> ChainMap[str, ty.Any]: + return ChainMap(get_local_namespace(cls), get_global_namespace(cls)) + + +def get_global_namespace(cls) -> dict[str, ty.Any]: + try: + module = cls.__module__ + return vars(sys.modules[module]) + except (KeyError, AttributeError): + return {} + + +def get_local_namespace(cls) -> dict[str, ty.Any]: + try: + return vars(cls) + except TypeError: + return {} + + +def get_origin_type(cls: type): + origin_tp = typing.get_origin(cls) + if origin_tp is not None: + return origin_tp + return cls + + +if sys.version_info >= (3, 9): + + def evaluate_forward_ref(ref: ty.ForwardRef, ns: Mapping[str, ty.Any]) -> ty.Any: + return 
ref._evaluate(dict(ns), {}, frozenset()) + +else: + + def evaluate_forward_ref(ref: ty.ForwardRef, ns: Mapping[str, ty.Any]) -> ty.Any: + return ref._evaluate(dict(ns), {}) diff --git a/pyproject.toml b/pyproject.toml index c24c106..123cf2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "django-pydantic-field" -version = "0.2.13" +version = "0.3.0-beta1" description = "Django JSONField with Pydantic models as a Schema" readme = "README.md" license = { file = "LICENSE" } @@ -27,6 +27,7 @@ classifiers = [ "Framework :: Django :: 5.0", "Framework :: Pydantic", "Framework :: Pydantic :: 1", + "Framework :: Pydantic :: 2", "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Programming Language :: Python :: 3", @@ -41,26 +42,31 @@ classifiers = [ requires-python = ">=3.7" dependencies = [ - "pydantic>=1.9,<2", + "pydantic>=1.10,<3", "django>=3.1,<6", "typing_extensions", ] [project.optional-dependencies] +openapi = ["uritemplate"] +coreapi = ["coreapi"] dev = [ + "build", "black", "isort", "mypy", - "pytest==7.0.*", + "pytest~=7.4", "djangorestframework>=3.11,<4", - "django-stubs[compatible-mypy]~=1.12.0", - "djangorestframework-stubs[compatible-mypy]~=1.7.0", + "django-stubs[compatible-mypy]~=4.2", + "djangorestframework-stubs[compatible-mypy]~=3.14", "pytest-django>=4.5,<5", ] test = [ + "django_pydantic_field[openapi,coreapi]", "dj-database-url~=2.0", "djangorestframework>=3,<4", "pyyaml", + "syrupy>=3,<5", ] ci = [ 'psycopg[binary]>=3.1,<4; python_version>="3.9"', @@ -74,12 +80,38 @@ Documentation = "https://github.com/surenkov/django-pydantic-field" Source = "https://github.com/surenkov/django-pydantic-field" Changelog = "https://github.com/surenkov/django-pydantic-field/releases" +[tool.isort] +py_version = 311 +profile = "black" +line_length = 120 +multi_line_output = 3 +include_trailing_comma = true +force_alphabetical_sort_within_sections = true +force_grid_wrap = 0 +use_parentheses = true +skip_glob = [ + "*/migrations/*", +] + +[tool.black] +target-version = ["py38", "py39", "py310", "py311", "py312"] +line-length = 120 +exclude = ''' +/( + \.pytest_cache + | \.venv + | venv + | migrations +)/ +''' + [tool.mypy] plugins = [ "mypy_django_plugin.main", "mypy_drf_plugin.main" ] exclude = [".env", "tests"] +enable_incomplete_feature = ["Unpack"] [tool.django-stubs] django_settings_module = "tests.settings.django_test_settings" diff --git a/tests/conftest.py b/tests/conftest.py index 03c84b5..be81a87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import typing as t from datetime import date +from syrupy.extensions.json import JSONSnapshotExtension import pydantic import pytest @@ -70,3 +71,8 @@ def mysql_backend(settings): ) def available_database_backends(request, settings): yield request.param(settings) + + +@pytest.fixture +def snapshot_json(snapshot): + return snapshot.use_extension(JSONSnapshotExtension) diff --git a/tests/sample_app/migrations/0001_initial.py b/tests/sample_app/migrations/0001_initial.py index 67e4a7e..78db57b 100644 --- a/tests/sample_app/migrations/0001_initial.py +++ b/tests/sample_app/migrations/0001_initial.py @@ -1,11 +1,12 @@ -# Generated by Django 4.2.2 on 2023-06-19 12:36 -import typing +# Generated by Django 3.2.23 on 2023-11-21 14:37 -import django_pydantic_field._migration_serializers +import django.core.serializers.json +from django.db import migrations, models +import django_pydantic_field.compat.django import 
django_pydantic_field.fields import tests.sample_app.models +import typing import typing_extensions -from django.db import migrations, models class Migration(migrations.Migration): @@ -17,26 +18,19 @@ class Migration(migrations.Migration): migrations.CreateModel( name="Building", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "opt_meta", django_pydantic_field.fields.PydanticSchemaField( config=None, - default={"type": "frame"}, + default={"buildingType": "frame"}, + encoder=django.core.serializers.json.DjangoJSONEncoder, exclude={"type"}, null=True, - schema=django_pydantic_field._migration_serializers.GenericContainer( + schema=django_pydantic_field.compat.django.GenericContainer( typing.Union, ( - typing.ForwardRef("BuildingMeta"), + tests.sample_app.models.BuildingMeta, type(None), ), ), @@ -45,10 +39,12 @@ class Migration(migrations.Migration): ( "meta", django_pydantic_field.fields.PydanticSchemaField( + by_alias=True, config=None, - default={"type": "frame"}, + default={"buildingType": "frame"}, + encoder=django.core.serializers.json.DjangoJSONEncoder, include={"type"}, - schema="BuildingMeta", + schema=tests.sample_app.models.BuildingMeta, ), ), ( @@ -56,7 +52,10 @@ class Migration(migrations.Migration): django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, - schema=typing.ForwardRef("t.List[BuildingMeta]"), + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( + list, (tests.sample_app.models.BuildingMeta,) + ), ), ), ( @@ -64,21 +63,22 @@ class Migration(migrations.Migration): django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, - schema=django_pydantic_field._migration_serializers.GenericContainer( - list, (typing.ForwardRef("BuildingMeta"),) + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( + list, (tests.sample_app.models.BuildingMeta,) ), ), ), ( "meta_untyped_list", django_pydantic_field.fields.PydanticSchemaField( - config=None, default=list, schema=list + config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ( "meta_untyped_builtin_list", django_pydantic_field.fields.PydanticSchemaField( - config=None, default=list, schema=list + config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ], @@ -86,22 +86,15 @@ class Migration(migrations.Migration): migrations.CreateModel( name="PostponedBuilding", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "meta", django_pydantic_field.fields.PydanticSchemaField( by_alias=True, config=None, - default={"buildingType": None}, - schema="BuildingMeta", + default={"buildingType": tests.sample_app.models.BuildingTypes["FRAME"]}, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.sample_app.models.BuildingMeta, ), ), ( @@ -109,7 +102,8 @@ class Migration(migrations.Migration): django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, - schema=django_pydantic_field._migration_serializers.GenericContainer( + 
encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( list, (tests.sample_app.models.BuildingMeta,) ), ), @@ -119,39 +113,41 @@ class Migration(migrations.Migration): django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, - schema=django_pydantic_field._migration_serializers.GenericContainer( - typing.List, (typing.ForwardRef("BuildingMeta"),) + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( + list, (tests.sample_app.models.BuildingMeta,) ), ), ), ( "meta_untyped_list", django_pydantic_field.fields.PydanticSchemaField( - config=None, default=list, schema=list + config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ( "meta_untyped_builtin_list", django_pydantic_field.fields.PydanticSchemaField( - config=None, default=list, schema=list + config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ( "nested_generics", django_pydantic_field.fields.PydanticSchemaField( config=None, - schema=django_pydantic_field._migration_serializers.GenericContainer( + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( typing.Union, ( - django_pydantic_field._migration_serializers.GenericContainer( - typing.List, + django_pydantic_field.compat.django.GenericContainer( + list, ( - django_pydantic_field._migration_serializers.GenericContainer( + django_pydantic_field.compat.django.GenericContainer( typing_extensions.Literal, ("foo",) ), ), ), - django_pydantic_field._migration_serializers.GenericContainer( + django_pydantic_field.compat.django.GenericContainer( typing_extensions.Literal, ("bar",) ), ), diff --git a/tests/sample_app/models.py b/tests/sample_app/models.py index f8bb8a0..be3c465 100644 --- a/tests/sample_app/models.py +++ b/tests/sample_app/models.py @@ -14,8 +14,8 @@ class BuildingTypes(str, enum.Enum): class Building(models.Model): - opt_meta: t.Optional["BuildingMeta"] = SchemaField(default={"type": "frame"}, exclude={"type"}, null=True) - meta: "BuildingMeta" = SchemaField(default={"type": "frame"}, include={"type"}) + opt_meta: t.Optional["BuildingMeta"] = SchemaField(default={"buildingType": "frame"}, exclude={"type"}, null=True) + meta: "BuildingMeta" = SchemaField(default={"buildingType": "frame"}, include={"type"}, by_alias=True) meta_schema_list = SchemaField(schema=t.ForwardRef("t.List[BuildingMeta]"), default=list) meta_typing_list: t.List["BuildingMeta"] = SchemaField(default=list) @@ -28,7 +28,7 @@ class BuildingMeta(pydantic.BaseModel): class PostponedBuilding(models.Model): - meta: "BuildingMeta" = SchemaField(default=BuildingMeta(type=BuildingTypes.FRAME), by_alias=True) + meta: "BuildingMeta" = SchemaField(default=BuildingMeta(buildingType=BuildingTypes.FRAME), by_alias=True) meta_builtin_list: t.List[BuildingMeta] = SchemaField(schema=t.List[BuildingMeta], default=list) meta_typing_list: t.List["BuildingMeta"] = SchemaField(default=list) meta_untyped_list: list = SchemaField(schema=t.List, default=list) diff --git a/tests/test_app/migrations/0001_initial.py b/tests/test_app/migrations/0001_initial.py index 00a6f1c..80a3436 100644 --- a/tests/test_app/migrations/0001_initial.py +++ b/tests/test_app/migrations/0001_initial.py @@ -1,10 +1,11 @@ -# Generated by Django 4.1.7 on 2023-02-27 14:41 +# Generated by Django 3.2.23 on 2023-11-21 14:37 +import 
django.core.serializers.json from django.db import migrations, models -import django_pydantic_field._migration_serializers +import django_pydantic_field.compat.django import django_pydantic_field.fields import tests.conftest -import typing +import tests.test_app.models class Migration(migrations.Migration): @@ -16,19 +17,14 @@ class Migration(migrations.Migration): migrations.CreateModel( name="SampleForwardRefModel", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "annotated_field", django_pydantic_field.fields.PydanticSchemaField( - config=None, default=dict, schema="SampleSchema" + config=None, + default=dict, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.test_app.models.SampleSchema, ), ), ( @@ -36,7 +32,8 @@ class Migration(migrations.Migration): django_pydantic_field.fields.PydanticSchemaField( config=None, default=dict, - schema=typing.ForwardRef("SampleSchema"), + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.test_app.models.SampleSchema, ), ), ], @@ -44,19 +41,12 @@ class Migration(migrations.Migration): migrations.CreateModel( name="SampleModel", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "sample_field", django_pydantic_field.fields.PydanticSchemaField( - config={"allow_mutation": False, "frozen": True}, + config=None, + encoder=django.core.serializers.json.DjangoJSONEncoder, schema=tests.conftest.InnerSchema, ), ), @@ -64,7 +54,8 @@ class Migration(migrations.Migration): "sample_list", django_pydantic_field.fields.PydanticSchemaField( config=None, - schema=django_pydantic_field._migration_serializers.GenericContainer( + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( list, (tests.conftest.InnerSchema,) ), ), @@ -74,7 +65,8 @@ class Migration(migrations.Migration): django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, - schema=django_pydantic_field._migration_serializers.GenericContainer( + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( list, (tests.conftest.InnerSchema,) ), ), diff --git a/tests/test_app/models.py b/tests/test_app/models.py index e5d75da..62bd340 100644 --- a/tests/test_app/models.py +++ b/tests/test_app/models.py @@ -7,8 +7,12 @@ from ..conftest import InnerSchema +class FrozenInnerSchema(InnerSchema): + model_config = pydantic.ConfigDict({"frozen": True}) + + class SampleModel(models.Model): - sample_field: InnerSchema = SchemaField(config={"frozen": True, "allow_mutation": False}) + sample_field: InnerSchema = SchemaField() sample_list: t.List[InnerSchema] = SchemaField() sample_seq: t.Sequence[InnerSchema] = SchemaField(schema=t.List[InnerSchema], default=list) diff --git a/tests/test_e2e_models.py b/tests/test_e2e_models.py index ce5c5e7..c48cf73 100644 --- a/tests/test_e2e_models.py +++ b/tests/test_e2e_models.py @@ -3,8 +3,8 @@ import pytest from django.db.models import F, Q, JSONField, Value -from .conftest import InnerSchema -from .test_app.models import SampleModel +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel 
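# An illustrative test (not part of the patch) showing how the `snapshot_json` fixture added to
# tests/conftest.py is meant to be used: syrupy records the serialized value on the first run
# and compares later runs against the stored snapshot.
def test_schema_snapshot(snapshot_json):
    generated = {"openapi": "3.0.2", "paths": {}}  # e.g. a rendered OpenAPI schema
    assert generated == snapshot_json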
pytestmark = [ pytest.mark.usefixtures("available_database_backends"), @@ -17,15 +17,11 @@ [ ( { - "sample_field": InnerSchema( - stub_str="abc", stub_list=[date(2023, 6, 1)] - ), + "sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]), "sample_list": [InnerSchema(stub_str="abc", stub_list=[])], }, { - "sample_field": InnerSchema( - stub_str="abc", stub_list=[date(2023, 6, 1)] - ), + "sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]), "sample_list": [InnerSchema(stub_str="abc", stub_list=[])], }, ), @@ -35,9 +31,7 @@ "sample_list": [{"stub_str": "abc", "stub_list": []}], }, { - "sample_field": InnerSchema( - stub_str="abc", stub_list=[date(2023, 6, 1)] - ), + "sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]), "sample_list": [InnerSchema(stub_str="abc", stub_list=[])], }, ), diff --git a/tests/test_django_model_field.py b/tests/test_fields.py similarity index 55% rename from tests/test_django_model_field.py rename to tests/test_fields.py index e754c2b..621e9aa 100644 --- a/tests/test_django_model_field.py +++ b/tests/test_fields.py @@ -1,18 +1,20 @@ import json import sys -import typing as t +import typing as ty from collections import abc from copy import copy from datetime import date -import django +import pydantic import pytest -from django.core.exceptions import FieldError, ValidationError -from django.db import models +from django.core.exceptions import ValidationError +from django.db import connection, models from django.db.migrations.writer import MigrationWriter + from django_pydantic_field import fields +from django_pydantic_field.compat.pydantic import PYDANTIC_V1, PYDANTIC_V2 -from .conftest import InnerSchema, SampleDataclass +from .conftest import InnerSchema, SampleDataclass # noqa from .sample_app.models import Building from .test_app.models import SampleForwardRefModel, SampleModel, SampleSchema @@ -22,10 +24,9 @@ def test_sample_field(): existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) expected_encoded = {"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]} - if django.VERSION[:2] < (4, 2): - expected_encoded = json.dumps(expected_encoded) + expected_prepared = json.dumps(expected_encoded) - assert sample_field.get_prep_value(existing_instance) == expected_encoded + assert sample_field.get_db_prep_value(existing_instance, connection) == expected_prepared assert sample_field.to_python(expected_encoded) == existing_instance @@ -34,56 +35,21 @@ def test_sample_field_with_raw_data(): existing_raw = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} expected_encoded = {"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]} - if django.VERSION[:2] < (4, 2): - expected_encoded = json.dumps(expected_encoded) + expected_prepared = json.dumps(expected_encoded) - assert sample_field.get_prep_value(existing_raw) == expected_encoded + assert sample_field.get_db_prep_value(existing_raw, connection) == expected_prepared assert sample_field.to_python(expected_encoded) == InnerSchema(**existing_raw) -def test_simple_model_field(): - sample_field = SampleModel._meta.get_field("sample_field") - assert sample_field.schema == InnerSchema - - sample_list_field = SampleModel._meta.get_field("sample_list") - assert sample_list_field.schema == t.List[InnerSchema] - - sample_seq_field = SampleModel._meta.get_field("sample_seq") - assert sample_seq_field.schema == t.List[InnerSchema] - - existing_raw_field = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} - existing_raw_list = 
[{"stub_str": "abc", "stub_list": []}] - - instance = SampleModel( - sample_field=existing_raw_field, sample_list=existing_raw_list - ) - - expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - expected_list = [InnerSchema(stub_str="abc", stub_list=[])] - - assert instance.sample_field == expected_instance - assert instance.sample_list == expected_list - - def test_null_field(): field = fields.SchemaField(InnerSchema, null=True, default=None) assert field.to_python(None) is None assert field.get_prep_value(None) is None - field = fields.SchemaField(t.Optional[InnerSchema], null=True, default=None) + field = fields.SchemaField(ty.Optional[InnerSchema], null=True, default=None) assert field.get_prep_value(None) is None -def test_untyped_model_field_raises(): - with pytest.raises(FieldError): - - class UntypedModel(models.Model): - sample_field = fields.SchemaField() - - class Meta: - app_label = "test_app" - - def test_forwardrefs_deferred_resolution(): obj = SampleForwardRefModel(field={}, annotated_field={}) assert isinstance(obj.field, SampleSchema) @@ -91,7 +57,12 @@ def test_forwardrefs_deferred_resolution(): @pytest.mark.parametrize( - "forward_ref", ["InnerSchema", t.ForwardRef("SampleDataclass"), t.List["int"]] + "forward_ref", + [ + "InnerSchema", + ty.ForwardRef("SampleDataclass"), + ty.List["int"], + ], ) def test_resolved_forwardrefs(forward_ref): class ModelWithForwardRefs(models.Model): @@ -108,10 +79,6 @@ class Meta: schema=InnerSchema, default=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), ), - fields.PydanticSchemaField( - schema=InnerSchema, - default=(("stub_str", "abc"), ("stub_list", [date(2022, 7, 1)])), - ), fields.PydanticSchemaField( schema=InnerSchema, default={"stub_str": "abc", "stub_list": [date(2022, 7, 1)]}, @@ -121,8 +88,17 @@ class Meta: schema=SampleDataclass, default={"stub_str": "abc", "stub_list": [date(2022, 7, 1)]}, ), - fields.PydanticSchemaField( - schema=t.Optional[InnerSchema], null=True, default=None + fields.PydanticSchemaField(schema=ty.Optional[InnerSchema], null=True, default=None), + pytest.param( + fields.PydanticSchemaField( + schema=InnerSchema, + default=(("stub_str", "abc"), ("stub_list", [date(2022, 7, 1)])), + ), + marks=pytest.mark.xfail( + PYDANTIC_V2, + reason="Tuple-based default reconstruction is not supported with Pydantic 2", + raises=pydantic.ValidationError, + ), ), ], ) @@ -130,46 +106,34 @@ def test_field_serialization(field): _test_field_serialization(field) -@pytest.mark.skipif( - sys.version_info < (3, 9), reason="Built-in type subscription supports only in 3.9+" -) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Built-in type subscription supports only in 3.9+") @pytest.mark.parametrize( "field_factory", [ lambda: fields.PydanticSchemaField(schema=list[InnerSchema], default=list), - lambda: fields.PydanticSchemaField(schema=dict[str, InnerSchema], default=list), - lambda: fields.PydanticSchemaField( - schema=abc.Sequence[InnerSchema], default=list - ), - lambda: fields.PydanticSchemaField( - schema=abc.Mapping[str, InnerSchema], default=dict - ), + lambda: fields.PydanticSchemaField(schema=dict[str, InnerSchema], default=dict), + lambda: fields.PydanticSchemaField(schema=abc.Sequence[InnerSchema], default=list), + lambda: fields.PydanticSchemaField(schema=abc.Mapping[str, InnerSchema], default=dict), ], ) def test_field_builtin_annotations_serialization(field_factory): _test_field_serialization(field_factory()) -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="Union 
type syntax supported only in 3.10+" -) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Union type syntax supported only in 3.10+") def test_field_union_type_serialization(): - field = fields.PydanticSchemaField( - schema=(InnerSchema | None), null=True, default=None - ) + field = fields.PydanticSchemaField(schema=(InnerSchema | None), null=True, default=None) _test_field_serialization(field) -@pytest.mark.skipif( - sys.version_info >= (3, 9), reason="Should test against builtin generic types" -) +@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Should test against builtin generic types") @pytest.mark.parametrize( "field", [ - fields.PydanticSchemaField(schema=t.List[InnerSchema], default=list), - fields.PydanticSchemaField(schema=t.Dict[str, InnerSchema], default=list), - fields.PydanticSchemaField(schema=t.Sequence[InnerSchema], default=list), - fields.PydanticSchemaField(schema=t.Mapping[str, InnerSchema], default=dict), + fields.PydanticSchemaField(schema=ty.List[InnerSchema], default=list), + fields.PydanticSchemaField(schema=ty.Dict[str, InnerSchema], default=dict), + fields.PydanticSchemaField(schema=ty.Sequence[InnerSchema], default=list), + fields.PydanticSchemaField(schema=ty.Mapping[str, InnerSchema], default=dict), ], ) def test_field_typing_annotations_serialization(field): @@ -184,42 +148,24 @@ def test_field_typing_annotations_serialization(field): "old_field, new_field", [ ( - lambda: fields.PydanticSchemaField( - schema=t.List[InnerSchema], default=list - ), + lambda: fields.PydanticSchemaField(schema=ty.List[InnerSchema], default=list), lambda: fields.PydanticSchemaField(schema=list[InnerSchema], default=list), ), ( - lambda: fields.PydanticSchemaField( - schema=t.Dict[str, InnerSchema], default=list - ), - lambda: fields.PydanticSchemaField( - schema=dict[str, InnerSchema], default=list - ), + lambda: fields.PydanticSchemaField(schema=ty.Dict[str, InnerSchema], default=dict), + lambda: fields.PydanticSchemaField(schema=dict[str, InnerSchema], default=dict), ), ( - lambda: fields.PydanticSchemaField( - schema=t.Sequence[InnerSchema], default=list - ), - lambda: fields.PydanticSchemaField( - schema=abc.Sequence[InnerSchema], default=list - ), + lambda: fields.PydanticSchemaField(schema=ty.Sequence[InnerSchema], default=list), + lambda: fields.PydanticSchemaField(schema=abc.Sequence[InnerSchema], default=list), ), ( - lambda: fields.PydanticSchemaField( - schema=t.Mapping[str, InnerSchema], default=dict - ), - lambda: fields.PydanticSchemaField( - schema=abc.Mapping[str, InnerSchema], default=dict - ), + lambda: fields.PydanticSchemaField(schema=ty.Mapping[str, InnerSchema], default=dict), + lambda: fields.PydanticSchemaField(schema=abc.Mapping[str, InnerSchema], default=dict), ), ( - lambda: fields.PydanticSchemaField( - schema=t.Mapping[str, InnerSchema], default=dict - ), - lambda: fields.PydanticSchemaField( - schema=abc.Mapping[str, InnerSchema], default=dict - ), + lambda: fields.PydanticSchemaField(schema=ty.Mapping[str, InnerSchema], default=dict), + lambda: fields.PydanticSchemaField(schema=abc.Mapping[str, InnerSchema], default=dict), ), ], ) @@ -229,19 +175,11 @@ def test_field_typing_to_builtin_serialization(old_field, new_field): _, _, args, kwargs = old_field.deconstruct() reconstructed_field = fields.PydanticSchemaField(*args, **kwargs) - assert ( - old_field.get_default() - == new_field.get_default() - == reconstructed_field.get_default() - ) + assert old_field.get_default() == new_field.get_default() == reconstructed_field.get_default() 
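# Editor's note at this step: both asserts in this test exercise the round trip Django performs
# for migrations -- old_field.deconstruct() yields the kwargs a migration would record,
# PydanticSchemaField(*args, **kwargs) rebuilds the field, and the default and schema must come
# back equal to the builtin-generic flavour. reconstruct_field(serialize_field(old_field)) below
# repeats the same check, presumably via MigrationWriter, which this module imports above.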
assert new_field.schema == reconstructed_field.schema deserialized_field = reconstruct_field(serialize_field(old_field)) - assert ( - old_field.get_default() - == deserialized_field.get_default() - == new_field.get_default() - ) + assert old_field.get_default() == deserialized_field.get_default() == new_field.get_default() assert new_field.schema == deserialized_field.schema @@ -249,8 +187,8 @@ def test_field_typing_to_builtin_serialization(old_field, new_field): "field, flawed_data", [ (fields.PydanticSchemaField(schema=InnerSchema), {}), - (fields.PydanticSchemaField(schema=t.List[InnerSchema]), [{}]), - (fields.PydanticSchemaField(schema=t.Dict[int, float]), {"1": "abc"}), + (fields.PydanticSchemaField(schema=ty.List[InnerSchema]), [{}]), + (fields.PydanticSchemaField(schema=ty.Dict[int, float]), {"1": "abc"}), ], ) def test_field_validation_exceptions(field, flawed_data): @@ -297,15 +235,27 @@ def test_export_kwargs_support(export_kwargs): def _test_field_serialization(field): - _, _, args, kwargs = field.deconstruct() + _, _, args, kwargs = field_data = field.deconstruct() reconstructed_field = fields.PydanticSchemaField(*args, **kwargs) assert field.get_default() == reconstructed_field.get_default() - assert field.schema == reconstructed_field.schema + + if PYDANTIC_V2: + assert reconstructed_field.deconstruct() == field_data + elif PYDANTIC_V1: + assert reconstructed_field.schema == field.schema + else: + pytest.fail("Unsupported Pydantic version") deserialized_field = reconstruct_field(serialize_field(field)) assert deserialized_field.get_default() == field.get_default() - assert field.schema == deserialized_field.schema + + if PYDANTIC_V2: + assert deserialized_field.deconstruct() == field_data + elif PYDANTIC_V1: + assert deserialized_field.schema == field.schema + else: + pytest.fail("Unsupported Pydantic version") def serialize_field(field: fields.PydanticSchemaField) -> str: diff --git a/tests/test_migration_serializers.py b/tests/test_migration_serializers.py index 9a6a802..c67b122 100644 --- a/tests/test_migration_serializers.py +++ b/tests/test_migration_serializers.py @@ -6,7 +6,10 @@ import pytest import django_pydantic_field -from django_pydantic_field._migration_serializers import GenericContainer +try: + from django_pydantic_field.compat.django import GenericContainer +except ImportError: + from django_pydantic_field._migration_serializers import GenericContainer if sys.version_info < (3, 9): test_types = [ diff --git a/tests/v1/__init__.py b/tests/v1/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_base_marshalling.py b/tests/v1/test_base.py similarity index 67% rename from tests/test_base_marshalling.py rename to tests/v1/test_base.py index 011b113..8637e30 100644 --- a/tests/test_base_marshalling.py +++ b/tests/v1/test_base.py @@ -5,9 +5,10 @@ import pydantic import pytest -from django_pydantic_field import base -from .conftest import InnerSchema, SampleDataclass +from tests.conftest import InnerSchema, SampleDataclass + +base = pytest.importorskip("django_pydantic_field.v1.base") class SampleSchema(pydantic.BaseModel): @@ -66,7 +67,7 @@ def test_schema_wrapper_transformers(): assert parsed_wrapper.__root__ == [expected_decoded] -class test_schema_wrapper_config_inheritance(): +def test_schema_wrapper_config_inheritance(): parsed_wrapper = base.wrap_schema(InnerSchema, config={"allow_mutation": False}) assert not parsed_wrapper.Config.allow_mutation assert not parsed_wrapper.Config.frozen @@ -76,13 +77,24 @@ class 
test_schema_wrapper_config_inheritance(): assert parsed_wrapper.Config.frozen -@pytest.mark.parametrize("type_, encoded, decoded", [ - (InnerSchema, '{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}', InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])), - (SampleDataclass, '{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}', SampleDataclass(stub_str="abc", stub_list=[date(2022, 7, 1)])), - (t.List[int], '[1, 2, 3]', [1, 2, 3]), - (t.Mapping[int, date], '{"1": "1970-01-01"}', {1: date(1970, 1, 1)}), - (t.Set[UUID], '["ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c"]', {UUID("ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c")}), -]) +@pytest.mark.parametrize( + "type_, encoded, decoded", + [ + ( + InnerSchema, + '{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}', + InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), + ), + ( + SampleDataclass, + '{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}', + SampleDataclass(stub_str="abc", stub_list=[date(2022, 7, 1)]), + ), + (t.List[int], "[1, 2, 3]", [1, 2, 3]), + (t.Mapping[int, date], '{"1": "1970-01-01"}', {1: date(1970, 1, 1)}), + (t.Set[UUID], '["ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c"]', {UUID("ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c")}), + ], +) def test_concrete_types(type_, encoded, decoded): schema = base.wrap_schema(type_) encoder = base.SchemaEncoder(schema=schema) @@ -96,13 +108,20 @@ def test_concrete_types(type_, encoded, decoded): @pytest.mark.skipif(sys.version_info < (3, 9), reason="Should test against builtin generic types") -@pytest.mark.parametrize("type_factory, encoded, decoded", [ - (lambda: list[InnerSchema], '[{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}]', [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]), - (lambda: list[SampleDataclass], '[{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}]', [SampleDataclass(stub_str="abc", stub_list=[date(2022, 7, 1)])]), # type: ignore - (lambda: list[int], '[1, 2, 3]', [1, 2, 3]), - (lambda: dict[int, date], '{"1": "1970-01-01"}', {1: date(1970, 1, 1)}), - (lambda: set[UUID], '["ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c"]', {UUID("ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c")}), -]) +@pytest.mark.parametrize( + "type_factory, encoded, decoded", + [ + ( + lambda: list[InnerSchema], + '[{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}]', + [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])], + ), + (lambda: list[SampleDataclass], '[{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}]', [SampleDataclass(stub_str="abc", stub_list=[date(2022, 7, 1)])]), # type: ignore + (lambda: list[int], "[1, 2, 3]", [1, 2, 3]), + (lambda: dict[int, date], '{"1": "1970-01-01"}', {1: date(1970, 1, 1)}), + (lambda: set[UUID], '["ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c"]', {UUID("ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c")}), + ], +) def test_concrete_raw_types(type_factory, encoded, decoded): type_ = type_factory() @@ -117,11 +136,14 @@ def test_concrete_raw_types(type_factory, encoded, decoded): assert decoder.decode(existing_encoded) == decoded -@pytest.mark.parametrize("forward_ref, sample_data", [ - (t.ForwardRef("t.List[int]"), '[1, 2]'), - (t.ForwardRef("InnerSchema"), '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}'), - (t.ForwardRef("PostponedSchema"), '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}'), -]) +@pytest.mark.parametrize( + "forward_ref, sample_data", + [ + (t.ForwardRef("t.List[int]"), "[1, 2]"), + (t.ForwardRef("InnerSchema"), '{"stub_str": "abc", 
"stub_int": 1, "stub_list": ["2022-07-01"]}'), + (t.ForwardRef("PostponedSchema"), '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}'), + ], +) def test_forward_refs_preparation(forward_ref, sample_data): schema = base.wrap_schema(forward_ref) base.prepare_schema(schema, test_forward_refs_preparation) diff --git a/tests/v1/test_fields.py b/tests/v1/test_fields.py new file mode 100644 index 0000000..a2fdd9d --- /dev/null +++ b/tests/v1/test_fields.py @@ -0,0 +1,43 @@ +import pytest +import typing as t +from datetime import date + +from django.core.exceptions import FieldError +from django.db import models + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +fields = pytest.importorskip("django_pydantic_field.v1.fields") + + +def test_simple_model_field(): + sample_field = SampleModel._meta.get_field("sample_field") + assert sample_field.schema == InnerSchema + + sample_list_field = SampleModel._meta.get_field("sample_list") + assert sample_list_field.schema == t.List[InnerSchema] + + sample_seq_field = SampleModel._meta.get_field("sample_seq") + assert sample_seq_field.schema == t.List[InnerSchema] + + existing_raw_field = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + existing_raw_list = [{"stub_str": "abc", "stub_list": []}] + + instance = SampleModel(sample_field=existing_raw_field, sample_list=existing_raw_list) + + expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_list = [InnerSchema(stub_str="abc", stub_list=[])] + + assert instance.sample_field == expected_instance + assert instance.sample_list == expected_list + + +def test_untyped_model_field_raises(): + with pytest.raises(FieldError): + + class UntypedModel(models.Model): + sample_field = fields.SchemaField() + + class Meta: + app_label = "test_app" diff --git a/tests/test_form_field.py b/tests/v1/test_forms.py similarity index 93% rename from tests/test_form_field.py rename to tests/v1/test_forms.py index cadcc6a..a816d25 100644 --- a/tests/test_form_field.py +++ b/tests/v1/test_forms.py @@ -4,10 +4,12 @@ import pytest from django.core.exceptions import ValidationError from django.forms import Form, modelform_factory -from django_pydantic_field import fields, forms -from .conftest import InnerSchema -from .test_app.models import SampleForwardRefModel, SampleSchema +from tests.conftest import InnerSchema +from tests.test_app.models import SampleForwardRefModel, SampleSchema + +fields = pytest.importorskip("django_pydantic_field.v1.fields") +forms = pytest.importorskip("django_pydantic_field.v1.forms") class SampleForm(Form): diff --git a/tests/test_drf_adapters.py b/tests/v1/test_rest_framework.py similarity index 98% rename from tests/test_drf_adapters.py rename to tests/v1/test_rest_framework.py index c1f04c6..f23bbdf 100644 --- a/tests/test_drf_adapters.py +++ b/tests/v1/test_rest_framework.py @@ -1,18 +1,18 @@ import io -import json import typing as t from datetime import date import pytest import yaml from django.urls import path -from django_pydantic_field import rest_framework from rest_framework import exceptions, generics, schemas, serializers, views from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema from rest_framework.response import Response -from .conftest import InnerSchema -from .test_app.models import SampleModel +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +rest_framework = 
pytest.importorskip("django_pydantic_field.v1.rest_framework") class SampleSerializer(serializers.Serializer): diff --git a/tests/v2/__init__.py b/tests/v2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/v2/rest_framework/__init__.py b/tests/v2/rest_framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-class].json b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-class].json new file mode 100644 index 0000000..a0d5192 --- /dev/null +++ b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-class].json @@ -0,0 +1,217 @@ +{ + "components": { + "schemas": { + "InnerSchema": { + "properties": { + "stub_int": { + "default": 1, + "title": "Stub Int", + "type": "integer" + }, + "stub_list": { + "items": { + "format": "date", + "type": "string" + }, + "title": "Stub List", + "type": "array" + }, + "stub_str": { + "title": "Stub Str", + "type": "string" + } + }, + "required": [ + "stub_str", + "stub_list" + ], + "title": "InnerSchema", + "type": "object" + }, + "Sample": { + "properties": { + "field": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "required": [ + "field" + ], + "type": "object" + } + } + }, + "info": { + "title": "", + "version": "" + }, + "openapi": "3.0.2", + "paths": { + "/class": { + "get": { + "description": "", + "operationId": "retrieveSample", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "patch": { + "description": "", + "operationId": "partialUpdateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "put": { + "description": "", + "operationId": "updateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + } + }, + "/func": { + "get": { + "description": "", + "operationId": "listsample_views", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "items": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "type": "array" + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + }, + "post": { + "description": "", + "operationId": "createsample_view", + 
"parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InnerSchema" + } + } + } + }, + "responses": { + "201": { + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + } + } + } +} diff --git a/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-func].json b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-func].json new file mode 100644 index 0000000..a0d5192 --- /dev/null +++ b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-func].json @@ -0,0 +1,217 @@ +{ + "components": { + "schemas": { + "InnerSchema": { + "properties": { + "stub_int": { + "default": 1, + "title": "Stub Int", + "type": "integer" + }, + "stub_list": { + "items": { + "format": "date", + "type": "string" + }, + "title": "Stub List", + "type": "array" + }, + "stub_str": { + "title": "Stub Str", + "type": "string" + } + }, + "required": [ + "stub_str", + "stub_list" + ], + "title": "InnerSchema", + "type": "object" + }, + "Sample": { + "properties": { + "field": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "required": [ + "field" + ], + "type": "object" + } + } + }, + "info": { + "title": "", + "version": "" + }, + "openapi": "3.0.2", + "paths": { + "/class": { + "get": { + "description": "", + "operationId": "retrieveSample", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "patch": { + "description": "", + "operationId": "partialUpdateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "put": { + "description": "", + "operationId": "updateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + } + }, + "/func": { + "get": { + "description": "", + "operationId": "listsample_views", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "items": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "type": "array" + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + }, + "post": { + "description": "", 
+ "operationId": "createsample_view", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InnerSchema" + } + } + } + }, + "responses": { + "201": { + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + } + } + } +} diff --git a/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[POST-func].json b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[POST-func].json new file mode 100644 index 0000000..a0d5192 --- /dev/null +++ b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[POST-func].json @@ -0,0 +1,217 @@ +{ + "components": { + "schemas": { + "InnerSchema": { + "properties": { + "stub_int": { + "default": 1, + "title": "Stub Int", + "type": "integer" + }, + "stub_list": { + "items": { + "format": "date", + "type": "string" + }, + "title": "Stub List", + "type": "array" + }, + "stub_str": { + "title": "Stub Str", + "type": "string" + } + }, + "required": [ + "stub_str", + "stub_list" + ], + "title": "InnerSchema", + "type": "object" + }, + "Sample": { + "properties": { + "field": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "required": [ + "field" + ], + "type": "object" + } + } + }, + "info": { + "title": "", + "version": "" + }, + "openapi": "3.0.2", + "paths": { + "/class": { + "get": { + "description": "", + "operationId": "retrieveSample", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "patch": { + "description": "", + "operationId": "partialUpdateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "put": { + "description": "", + "operationId": "updateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + } + }, + "/func": { + "get": { + "description": "", + "operationId": "listsample_views", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "items": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "type": "array" + } + }, + "description": "" + } + }, + "tags": [ + "func" 
+ ] + }, + "post": { + "description": "", + "operationId": "createsample_view", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InnerSchema" + } + } + } + }, + "responses": { + "201": { + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + } + } + } +} diff --git a/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[PUT-class].json b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[PUT-class].json new file mode 100644 index 0000000..a0d5192 --- /dev/null +++ b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[PUT-class].json @@ -0,0 +1,217 @@ +{ + "components": { + "schemas": { + "InnerSchema": { + "properties": { + "stub_int": { + "default": 1, + "title": "Stub Int", + "type": "integer" + }, + "stub_list": { + "items": { + "format": "date", + "type": "string" + }, + "title": "Stub List", + "type": "array" + }, + "stub_str": { + "title": "Stub Str", + "type": "string" + } + }, + "required": [ + "stub_str", + "stub_list" + ], + "title": "InnerSchema", + "type": "object" + }, + "Sample": { + "properties": { + "field": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "required": [ + "field" + ], + "type": "object" + } + } + }, + "info": { + "title": "", + "version": "" + }, + "openapi": "3.0.2", + "paths": { + "/class": { + "get": { + "description": "", + "operationId": "retrieveSample", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "patch": { + "description": "", + "operationId": "partialUpdateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "put": { + "description": "", + "operationId": "updateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + } + }, + "/func": { + "get": { + "description": "", + "operationId": "listsample_views", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "items": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "type": "array" + } + }, + 
"description": "" + } + }, + "tags": [ + "func" + ] + }, + "post": { + "description": "", + "operationId": "createsample_view", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InnerSchema" + } + } + } + }, + "responses": { + "201": { + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + } + } + } +} diff --git a/tests/v2/rest_framework/test_coreapi.py b/tests/v2/rest_framework/test_coreapi.py new file mode 100644 index 0000000..c3e2669 --- /dev/null +++ b/tests/v2/rest_framework/test_coreapi.py @@ -0,0 +1,26 @@ +import sys + +import pytest +from rest_framework import schemas +from rest_framework.request import Request + +from .view_fixtures import create_views_urlconf + +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") + +@pytest.mark.skipif(sys.version_info >= (3, 12), reason="CoreAPI is not compatible with 3.12") +@pytest.mark.parametrize( + "method, path", + [ + ("GET", "/func"), + ("POST", "/func"), + ("GET", "/class"), + ("PUT", "/class"), + ], +) +def test_coreapi_schema_generators(request_factory, method, path): + urlconf = create_views_urlconf(coreapi.AutoSchema) + generator = schemas.SchemaGenerator(urlconf=urlconf) + request = Request(request_factory.generic(method, path)) + coreapi_schema = generator.get_schema(request) + assert coreapi_schema diff --git a/tests/v2/rest_framework/test_e2e_views.py b/tests/v2/rest_framework/test_e2e_views.py new file mode 100644 index 0000000..4790263 --- /dev/null +++ b/tests/v2/rest_framework/test_e2e_views.py @@ -0,0 +1,56 @@ +from datetime import date + +import pytest + +from tests.conftest import InnerSchema + +from .view_fixtures import ( + ClassBasedView, + ClassBasedViewWithModel, + ClassBasedViewWithSchemaContext, + sample_view, +) + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") + + +@pytest.mark.parametrize( + "view", + [ + sample_view, + ClassBasedView.as_view(), + ClassBasedViewWithSchemaContext.as_view(), + ], +) +def test_end_to_end_api_view(view, request_factory): + expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + existing_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + request = request_factory.post("/", existing_encoded, content_type="application/json") + response = view(request) + + assert response.data == [expected_instance] + assert response.data[0] is not expected_instance + + assert response.rendered_content == b"[%s]" % existing_encoded + + +@pytest.mark.django_db +def test_end_to_end_list_create_api_view(request_factory): + field_data = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]).json() + expected_result = { + "sample_field": {"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}, + "sample_list": [{"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}], + "sample_seq": [], + } + + payload = '{"sample_field": %s, "sample_list": [%s], "sample_seq": []}' % ((field_data,) * 2) + request = request_factory.post("/", payload.encode(), content_type="application/json") + response = ClassBasedViewWithModel.as_view()(request) + + assert response.data == expected_result + + request = request_factory.get("/", content_type="application/json") + response = 
ClassBasedViewWithModel.as_view()(request) + assert response.data == [expected_result] diff --git a/tests/v2/rest_framework/test_fields.py b/tests/v2/rest_framework/test_fields.py new file mode 100644 index 0000000..6098da6 --- /dev/null +++ b/tests/v2/rest_framework/test_fields.py @@ -0,0 +1,108 @@ +import typing as ty +from datetime import date + +import pytest +from rest_framework import exceptions, serializers + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +class SampleSerializer(serializers.Serializer): + field = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + + +class SampleModelSerializer(serializers.ModelSerializer): + sample_field = rest_framework.SchemaField(schema=InnerSchema) + sample_list = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + sample_seq = rest_framework.SchemaField(schema=ty.List[InnerSchema], default=list) + + class Meta: + model = SampleModel + fields = "sample_field", "sample_list", "sample_seq" + + +def test_schema_field(): + field = rest_framework.SchemaField(InnerSchema) + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = { + "stub_str": "abc", + "stub_int": 1, + "stub_list": [date(2022, 7, 1)], + } + + assert field.to_representation(existing_instance) == expected_encoded + assert field.to_internal_value(expected_encoded) == existing_instance + + with pytest.raises(serializers.ValidationError): + field.to_internal_value(None) + + with pytest.raises(serializers.ValidationError): + field.to_internal_value("null") + + +def test_field_schema_with_custom_config(): + field = rest_framework.SchemaField(InnerSchema, allow_null=True, exclude={"stub_int"}) + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + + assert field.to_representation(existing_instance) == expected_encoded + assert field.to_internal_value(expected_encoded) == existing_instance + assert field.to_internal_value(None) is None + assert field.to_internal_value("null") is None + + +def test_serializer_marshalling_with_schema_field(): + existing_instance = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]} + expected_data = {"field": [{"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}]} + + serializer = SampleSerializer(instance=existing_instance) + assert serializer.data == expected_data + + serializer = SampleSerializer(data=expected_data) + serializer.is_valid(raise_exception=True) + assert serializer.validated_data == existing_instance + + +def test_model_serializer_marshalling_with_schema_field(): + instance = SampleModel( + sample_field=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), + sample_list=[InnerSchema(stub_str="abc", stub_int=2, stub_list=[date(2022, 7, 1)])] * 2, + sample_seq=[InnerSchema(stub_str="abc", stub_int=3, stub_list=[date(2022, 7, 1)])] * 3, + ) + serializer = SampleModelSerializer(instance) + + expected_data = { + "sample_field": {"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}, + "sample_list": [{"stub_str": "abc", "stub_int": 2, "stub_list": [date(2022, 7, 1)]}] * 2, + "sample_seq": [{"stub_str": "abc", "stub_int": 3, "stub_list": [date(2022, 7, 1)]}] * 3, + } + assert serializer.data == expected_data + + +@pytest.mark.parametrize( + "export_kwargs", + [ + {"include": {"stub_str", "stub_int"}}, + {"exclude": 
{"stub_list"}}, + {"exclude_unset": True}, + {"exclude_defaults": True}, + {"exclude_none": True}, + {"by_alias": True}, + ], +) +def test_field_export_kwargs(export_kwargs): + field = rest_framework.SchemaField(InnerSchema, **export_kwargs) + assert field.to_representation(InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])) + + +def test_invalid_data_serialization(): + invalid_data = {"field": [{"stub_int": "abc", "stub_list": ["abc"]}]} + serializer = SampleSerializer(data=invalid_data) + + with pytest.raises(exceptions.ValidationError) as e: + serializer.is_valid(raise_exception=True) + + assert e.match(r".*stub_str.*stub_int.*stub_list.*") diff --git a/tests/v2/rest_framework/test_openapi.py b/tests/v2/rest_framework/test_openapi.py new file mode 100644 index 0000000..0ce6ca6 --- /dev/null +++ b/tests/v2/rest_framework/test_openapi.py @@ -0,0 +1,22 @@ +import pytest +from rest_framework.schemas.openapi import SchemaGenerator +from rest_framework.request import Request + +from .view_fixtures import create_views_urlconf + +openapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.openapi") + +@pytest.mark.parametrize( + "method, path", + [ + ("GET", "/func"), + ("POST", "/func"), + ("GET", "/class"), + ("PUT", "/class"), + ], +) +def test_openapi_schema_generators(request_factory, method, path, snapshot_json): + urlconf = create_views_urlconf(openapi.AutoSchema) + generator = SchemaGenerator(urlconf=urlconf) + request = Request(request_factory.generic(method, path)) + assert snapshot_json() == generator.get_schema(request) diff --git a/tests/v2/rest_framework/test_parsers.py b/tests/v2/rest_framework/test_parsers.py new file mode 100644 index 0000000..db64a56 --- /dev/null +++ b/tests/v2/rest_framework/test_parsers.py @@ -0,0 +1,23 @@ +import io +from datetime import date + +import pytest + +from tests.conftest import InnerSchema + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +@pytest.mark.parametrize( + "schema_type, existing_encoded, expected_decoded", + [ + ( + InnerSchema, + '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}', + InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), + ) + ], +) +def test_schema_parser(schema_type, existing_encoded, expected_decoded): + parser = rest_framework.SchemaParser[schema_type]() + assert parser.parse(io.StringIO(existing_encoded)) == expected_decoded diff --git a/tests/v2/rest_framework/test_renderers.py b/tests/v2/rest_framework/test_renderers.py new file mode 100644 index 0000000..59d4aed --- /dev/null +++ b/tests/v2/rest_framework/test_renderers.py @@ -0,0 +1,23 @@ +from datetime import date + +import pytest + +from tests.conftest import InnerSchema + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +def test_schema_renderer(): + renderer = rest_framework.SchemaRenderer() + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + assert renderer.render(existing_instance) == expected_encoded + + +def test_typed_schema_renderer(): + renderer = rest_framework.SchemaRenderer[InnerSchema]() + existing_data = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + assert renderer.render(existing_data) == expected_encoded diff --git a/tests/v2/rest_framework/view_fixtures.py b/tests/v2/rest_framework/view_fixtures.py new file mode 100644 index 
0000000..a55ddd6 --- /dev/null +++ b/tests/v2/rest_framework/view_fixtures.py @@ -0,0 +1,88 @@ +import typing as ty +from types import SimpleNamespace + +import pytest +from django.urls import path +from rest_framework import generics, serializers, views +from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema +from rest_framework.response import Response + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") + + +class SampleSerializer(serializers.Serializer): + field = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + + +class SampleModelSerializer(serializers.ModelSerializer): + sample_field = rest_framework.SchemaField(schema=InnerSchema) + sample_list = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + sample_seq = rest_framework.SchemaField(schema=ty.List[InnerSchema], default=list) + + class Meta: + model = SampleModel + fields = "sample_field", "sample_list", "sample_seq" + + +class ClassBasedView(views.APIView): + parser_classes = [rest_framework.SchemaParser[InnerSchema]] + renderer_classes = [rest_framework.SchemaRenderer[ty.List[InnerSchema]]] + + def post(self, request, *args, **kwargs): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + +class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): + serializer_class = SampleSerializer + + +class ClassBasedViewWithModel(generics.ListCreateAPIView): + queryset = SampleModel.objects.all() + serializer_class = SampleModelSerializer + + +class ClassBasedViewWithSchemaContext(ClassBasedView): + parser_classes = [rest_framework.SchemaParser] + renderer_classes = [rest_framework.SchemaRenderer] + + def get_renderer_context(self): + ctx = super().get_renderer_context() + return dict(ctx, renderer_schema=ty.List[InnerSchema]) + + def get_parser_context(self, http_request): + ctx = super().get_parser_context(http_request) + return dict(ctx, parser_schema=InnerSchema) + + +@api_view(["GET", "POST"]) +@parser_classes([rest_framework.SchemaParser[InnerSchema]]) +@renderer_classes([rest_framework.SchemaRenderer[ty.List[InnerSchema]]]) +def sample_view(request): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + +def create_views_urlconf(schema_view_inspector): + @api_view(["GET", "POST"]) + @schema(schema_view_inspector()) + @parser_classes([rest_framework.SchemaParser[InnerSchema]]) + @renderer_classes([rest_framework.SchemaRenderer[ty.List[InnerSchema]]]) + def sample_view(request): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): + serializer_class = SampleSerializer + schema = schema_view_inspector() + + return SimpleNamespace( + urlpatterns=[ + path("/func", sample_view), + path("/class", ClassBasedViewWithSerializer.as_view()), + ], + ) diff --git a/tests/v2/test_forms.py b/tests/v2/test_forms.py new file mode 100644 index 0000000..4a1081e --- /dev/null +++ b/tests/v2/test_forms.py @@ -0,0 +1,101 @@ +import typing as ty + +import django +import pytest +from django.core.exceptions import ValidationError +from django.forms import Form, modelform_factory + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleForwardRefModel, SampleSchema + +fields = 
pytest.importorskip("django_pydantic_field.v2.fields") +forms = pytest.importorskip("django_pydantic_field.v2.forms") + + +class SampleForm(Form): + field = forms.SchemaField(ty.ForwardRef("SampleSchema")) + + +def test_form_schema_field(): + field = forms.SchemaField(InnerSchema) + + cleaned_data = field.clean('{"stub_str": "abc", "stub_list": ["1970-01-01"]}') + assert cleaned_data == InnerSchema.model_validate({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + + +def test_empty_form_values(): + field = forms.SchemaField(InnerSchema, required=False) + assert field.clean("") is None + assert field.clean(None) is None + + +def test_prepare_value(): + field = forms.SchemaField(InnerSchema, required=False) + expected = '{"stub_str":"abc","stub_int":1,"stub_list":["1970-01-01"]}' + assert expected == field.prepare_value({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + + +def test_empty_required_raises(): + field = forms.SchemaField(InnerSchema) + with pytest.raises(ValidationError) as e: + field.clean("") + + assert e.match("This field is required") + + +def test_invalid_schema_raises(): + field = forms.SchemaField(InnerSchema) + with pytest.raises(ValidationError) as e: + field.clean('{"stub_list": "abc"}') + + assert e.match("Schema didn't match for") + assert "stub_list" in e.value.params["detail"] # type: ignore + assert "stub_str" in e.value.params["detail"] # type: ignore + + +def test_invalid_json_raises(): + field = forms.SchemaField(InnerSchema) + with pytest.raises(ValidationError) as e: + field.clean('{"stub_list": "abc}') + + assert e.match("Schema didn't match for") + assert '"type":"json_invalid"' in e.value.params["detail"] # type: ignore + + +@pytest.mark.xfail( + django.VERSION[:2] < (4, 0), + reason="Django < 4 has it's own feeling on bound fields resolution", +) +def test_forwardref_field(): + form = SampleForm(data={"field": '{"field": "2"}'}) + assert form.is_valid() + + +def test_model_formfield(): + field = fields.PydanticSchemaField(schema=InnerSchema) + assert isinstance(field.formfield(), forms.SchemaField) + + +def test_forwardref_model_formfield(): + form_cls = modelform_factory(SampleForwardRefModel, exclude=("field",)) + form = form_cls(data={"annotated_field": '{"field": "2"}'}) + + assert form.is_valid(), form.errors + cleaned_data = form.cleaned_data + + assert cleaned_data is not None + assert cleaned_data["annotated_field"] == SampleSchema(field=2) + + +@pytest.mark.parametrize("export_kwargs", [ + {"include": {"stub_str", "stub_int"}}, + {"exclude": {"stub_list"}}, + {"exclude_unset": True}, + {"exclude_defaults": True}, + {"exclude_none": True}, + {"by_alias": True}, +]) +def test_form_field_export_kwargs(export_kwargs): + field = forms.SchemaField(InnerSchema, required=False, **export_kwargs) + value = InnerSchema.model_validate({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + assert field.prepare_value(value) diff --git a/tests/v2/test_types.py b/tests/v2/test_types.py new file mode 100644 index 0000000..7050b51 --- /dev/null +++ b/tests/v2/test_types.py @@ -0,0 +1,148 @@ +import sys +import pydantic +import pytest +import typing as ty + +from ..conftest import InnerSchema, SampleDataclass + +types = pytest.importorskip("django_pydantic_field.v2.types") +skip_unsupported_builtin_subscription = pytest.mark.skipif( + sys.version_info < (3, 9), + reason="Built-in type subscription supports only in 3.9+", +) + + +# fmt: off +@pytest.mark.parametrize( + "ctor, args, kwargs", + [ + pytest.param(types.SchemaAdapter, ["list[int]", None, None, None], {}, 
marks=skip_unsupported_builtin_subscription), + pytest.param(types.SchemaAdapter, ["list[int]", {"strict": True}, None, None], {}, marks=skip_unsupported_builtin_subscription), + (types.SchemaAdapter, [ty.List[int], None, None, None], {}), + (types.SchemaAdapter, [ty.List[int], {"strict": True}, None, None], {}), + (types.SchemaAdapter, [None, None, InnerSchema, "stub_int"], {}), + (types.SchemaAdapter, [None, None, SampleDataclass, "stub_int"], {}), + pytest.param(types.SchemaAdapter.from_type, ["list[int]"], {}, marks=skip_unsupported_builtin_subscription), + pytest.param(types.SchemaAdapter.from_type, ["list[int]", {"strict": True}], {}, marks=skip_unsupported_builtin_subscription), + (types.SchemaAdapter.from_type, [ty.List[int]], {}), + (types.SchemaAdapter.from_type, [ty.List[int], {"strict": True}], {}), + (types.SchemaAdapter.from_annotation, [InnerSchema, "stub_int"], {}), + (types.SchemaAdapter.from_annotation, [InnerSchema, "stub_int", {"strict": True}], {}), + (types.SchemaAdapter.from_annotation, [SampleDataclass, "stub_int"], {}), + (types.SchemaAdapter.from_annotation, [SampleDataclass, "stub_int", {"strict": True}], {}), + ], +) +# fmt: on +def test_schema_adapter_constructors(ctor, args, kwargs): + adapter = ctor(*args, **kwargs) + adapter.validate_schema() + assert isinstance(adapter.type_adapter, pydantic.TypeAdapter) + + +def test_schema_adapter_is_bound(): + adapter = types.SchemaAdapter(None, None, None, None) + with pytest.raises(types.ImproperlyConfiguredSchema): + adapter.validate_schema() # Schema cannot be resolved for fully unbound adapter + + adapter = types.SchemaAdapter(ty.List[int], None, None, None) + assert not adapter.is_bound, "SchemaAdapter should not be bound" + adapter.validate_schema() # Schema should be resolved from direct argument + + adapter.bind(InnerSchema, "stub_int") + assert adapter.is_bound, "SchemaAdapter should be bound" + adapter.validate_schema() # Schema should be resolved from direct argument + + adapter = types.SchemaAdapter(None, None, InnerSchema, "stub_int") + assert adapter.is_bound, "SchemaAdapter should be bound" + adapter.validate_schema() # Schema should be resolved from bound attribute + + +# fmt: off +@pytest.mark.parametrize( + "kwargs, expected_export_kwargs", + [ + ({}, {}), + ({"strict": True}, {"strict": True}), + ({"strict": True, "by_alias": False}, {"strict": True, "by_alias": False}), + ({"strict": True, "from_attributes": False, "on_delete": "CASCADE"}, {"strict": True, "from_attributes": False}), + ], +) +# fmt: on +def test_schema_adapter_extract_export_kwargs(kwargs, expected_export_kwargs): + orig_kwargs = dict(kwargs) + assert types.SchemaAdapter.extract_export_kwargs(kwargs) == expected_export_kwargs + assert kwargs == {key: orig_kwargs[key] for key in orig_kwargs.keys() - expected_export_kwargs.keys()} + + +def test_schema_adapter_validate_python(): + adapter = types.SchemaAdapter.from_type(ty.List[int]) + assert adapter.validate_python([1, 2, 3]) == [1, 2, 3] + assert adapter.validate_python([1, 2, 3], strict=True) == [1, 2, 3] + assert adapter.validate_python([1, 2, 3], strict=False) == [1, 2, 3] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {"strict": True}) + assert adapter.validate_python([1, 2, 3]) == [1, 2, 3] + assert adapter.validate_python(["1", "2", "3"], strict=False) == [1, 2, 3] + assert sorted(adapter.validate_python({1, 2, 3}, strict=False)) == [1, 2, 3] + with pytest.raises(pydantic.ValidationError): + assert adapter.validate_python(["1", "2", "3"]) == [1, 2, 3] + + adapter = 
types.SchemaAdapter.from_type(ty.List[int], {"strict": False}) + assert adapter.validate_python([1, 2, 3]) == [1, 2, 3] + assert adapter.validate_python([1, 2, 3], strict=False) == [1, 2, 3] + assert sorted(adapter.validate_python({1, 2, 3})) == [1, 2, 3] + with pytest.raises(pydantic.ValidationError): + assert adapter.validate_python({1, 2, 3}, strict=True) == [1, 2, 3] + + +def test_schema_adapter_validate_json(): + adapter = types.SchemaAdapter.from_type(ty.List[int]) + assert adapter.validate_json("[1, 2, 3]") == [1, 2, 3] + assert adapter.validate_json("[1, 2, 3]", strict=True) == [1, 2, 3] + assert adapter.validate_json("[1, 2, 3]", strict=False) == [1, 2, 3] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {"strict": True}) + assert adapter.validate_json("[1, 2, 3]") == [1, 2, 3] + assert adapter.validate_json('["1", "2", "3"]', strict=False) == [1, 2, 3] + with pytest.raises(pydantic.ValidationError): + assert adapter.validate_json('["1", "2", "3"]') == [1, 2, 3] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {"strict": False}) + assert adapter.validate_json("[1, 2, 3]") == [1, 2, 3] + assert adapter.validate_json("[1, 2, 3]", strict=False) == [1, 2, 3] + with pytest.raises(pydantic.ValidationError): + assert adapter.validate_json('["1", "2", "3"]', strict=True) == [1, 2, 3] + + +def test_schema_adapter_dump_python(): + adapter = types.SchemaAdapter.from_type(ty.List[int]) + assert adapter.dump_python([1, 2, 3]) == [1, 2, 3] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {}) + assert adapter.dump_python([1, 2, 3]) == [1, 2, 3] + assert sorted(adapter.dump_python({1, 2, 3})) == [1, 2, 3] + with pytest.warns(UserWarning): + assert adapter.dump_python(["1", "2", "3"]) == ["1", "2", "3"] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {}) + assert adapter.dump_python([1, 2, 3]) == [1, 2, 3] + assert sorted(adapter.dump_python({1, 2, 3})) == [1, 2, 3] + with pytest.warns(UserWarning): + assert adapter.dump_python(["1", "2", "3"]) == ["1", "2", "3"] + + +def test_schema_adapter_dump_json(): + adapter = types.SchemaAdapter.from_type(ty.List[int]) + assert adapter.dump_json([1, 2, 3]) == b"[1,2,3]" + + adapter = types.SchemaAdapter.from_type(ty.List[int], {}) + assert adapter.dump_json([1, 2, 3]) == b"[1,2,3]" + assert adapter.dump_json({1, 2, 3}) == b"[1,2,3]" + with pytest.warns(UserWarning): + assert adapter.dump_json(["1", "2", "3"]) == b'["1","2","3"]' + + adapter = types.SchemaAdapter.from_type(ty.List[int], {}) + assert adapter.dump_json([1, 2, 3]) == b"[1,2,3]" + assert adapter.dump_json({1, 2, 3}) == b"[1,2,3]" + with pytest.warns(UserWarning): + assert adapter.dump_json(["1", "2", "3"]) == b'["1","2","3"]'
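# Editor's usage sketch, not part of the commit: judging by the v2 tests above, SchemaAdapter
# wraps pydantic.TypeAdapter and carries the export kwargs (strict, by_alias, ...) that the
# field layer forwards to it. A minimal round trip under those assumptions:
import typing as ty

import pydantic

from django_pydantic_field.v2 import types

adapter = types.SchemaAdapter.from_type(ty.List[int])
adapter.validate_schema()                               # passes once the schema is resolvable
assert adapter.validate_json("[1, 2, 3]") == [1, 2, 3]
assert adapter.dump_json([1, 2, 3]) == b"[1,2,3]"

strict_adapter = types.SchemaAdapter.from_type(ty.List[int], {"strict": True})
try:
    strict_adapter.validate_python(["1", "2", "3"])     # str -> int coercion is rejected in strict mode
except pydantic.ValidationError:
    pass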