From 4fd78c94cedfb565b447348695ef13ccb05906a4 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Sat, 4 Nov 2023 22:34:38 +0400 Subject: [PATCH 01/34] Add compatibility layer to separate v1 and v2. --- .gitignore | 2 + django_pydantic_field/__init__.py | 17 +- django_pydantic_field/compat/__init__.py | 5 + .../django.py} | 4 +- django_pydantic_field/compat/imports.py | 38 +++ django_pydantic_field/compat/pydantic.py | 6 + django_pydantic_field/fields.py | 196 +------------ django_pydantic_field/fields.pyi | 59 ++++ django_pydantic_field/forms.py | 68 +---- django_pydantic_field/rest_framework.py | 262 +----------------- django_pydantic_field/v1/__init__.py | 1 + .../v1/_migration_serializers.py | 0 django_pydantic_field/{ => v1}/base.py | 0 django_pydantic_field/v1/fields.py | 194 +++++++++++++ django_pydantic_field/v1/forms.py | 66 +++++ django_pydantic_field/v1/rest_framework.py | 260 +++++++++++++++++ django_pydantic_field/{ => v1}/utils.py | 0 django_pydantic_field/v2/__init__.py | 0 pyproject.toml | 4 +- tests/sample_app/migrations/0001_initial.py | 2 +- tests/test_migration_serializers.py | 2 +- 21 files changed, 662 insertions(+), 524 deletions(-) create mode 100644 django_pydantic_field/compat/__init__.py rename django_pydantic_field/{_migration_serializers.py => compat/django.py} (97%) create mode 100644 django_pydantic_field/compat/imports.py create mode 100644 django_pydantic_field/compat/pydantic.py create mode 100644 django_pydantic_field/fields.pyi create mode 100644 django_pydantic_field/v1/__init__.py create mode 100644 django_pydantic_field/v1/_migration_serializers.py rename django_pydantic_field/{ => v1}/base.py (100%) create mode 100644 django_pydantic_field/v1/fields.py create mode 100644 django_pydantic_field/v1/forms.py create mode 100644 django_pydantic_field/v1/rest_framework.py rename django_pydantic_field/{ => v1}/utils.py (100%) create mode 100644 django_pydantic_field/v2/__init__.py diff --git a/.gitignore b/.gitignore index 37737ef..f936d7d 100644 --- a/.gitignore +++ b/.gitignore @@ -12,3 +12,5 @@ dist/ *.egg-info/ build htmlcov + +.python-version diff --git a/django_pydantic_field/__init__.py b/django_pydantic_field/__init__.py index 7746f2c..9d5c124 100644 --- a/django_pydantic_field/__init__.py +++ b/django_pydantic_field/__init__.py @@ -1 +1,16 @@ -from .fields import * +from .fields import SchemaField as SchemaField + +def __getattr__(name): + if name == "_migration_serializers": + import warnings + from .compat import django + + DEPRECATION_MSG = ( + "Module 'django_pydantic_field._migration_serializers' is deprecated " + "and will be removed in version 1.0.0. " + "Please replace it with 'django_pydantic_field.compat.django' in migrations." + ) + warnings.warn(DEPRECATION_MSG, category=DeprecationWarning) + return django + + raise AttributeError(f"Module {__name__!r} has no attribute {name!r}") diff --git a/django_pydantic_field/compat/__init__.py b/django_pydantic_field/compat/__init__.py new file mode 100644 index 0000000..4374b6b --- /dev/null +++ b/django_pydantic_field/compat/__init__.py @@ -0,0 +1,5 @@ +from .pydantic import PYDANTIC_V1 as PYDANTIC_V1 +from .pydantic import PYDANTIC_V2 as PYDANTIC_V2 +from .pydantic import PYDANTIC_VERSION as PYDANTIC_VERSION +from .django import GenericContainer as GenericContainer +from .django import MigrationWriter as MigrationWriter diff --git a/django_pydantic_field/_migration_serializers.py b/django_pydantic_field/compat/django.py similarity index 97% rename from django_pydantic_field/_migration_serializers.py rename to django_pydantic_field/compat/django.py index d608855..682f126 100644 --- a/django_pydantic_field/_migration_serializers.py +++ b/django_pydantic_field/compat/django.py @@ -115,8 +115,8 @@ def serialize(self): else: # types.GenericAlias is missing, meaning python version < 3.9, # which has a different inheritance models for typed generics - GenericAlias = type(t.List[int]) - GenericTypes = GenericAlias, type(t.List) + GenericAlias = type(t.List[int]) # noqa + GenericTypes = GenericAlias, type(t.List) # noqa MigrationWriter.register_serializer(GenericContainer, GenericSerializer) diff --git a/django_pydantic_field/compat/imports.py b/django_pydantic_field/compat/imports.py new file mode 100644 index 0000000..1587d91 --- /dev/null +++ b/django_pydantic_field/compat/imports.py @@ -0,0 +1,38 @@ +import functools +import importlib +import types + +from .pydantic import PYDANTIC_V1, PYDANTIC_V2, PYDANTIC_VERSION + +__all__ = ("compat_getattr", "compat_dir") + + +def compat_getattr(module_name: str): + module = _import_compat_module(module_name) + return functools.partial(getattr, module) + + +def compat_dir(module_name: str): + compat_module = _import_compat_module(module_name) + return dir(compat_module) + + +def _import_compat_module(module_name: str) -> types.ModuleType: + try: + package, _, module = module_name.partition(".") + except ValueError: + package, module = module_name, "" + + module_path_parts = [package] + if PYDANTIC_V2: + module_path_parts.append("v2") + elif PYDANTIC_V1: + module_path_parts.append("v1") + else: + raise RuntimeError(f"Pydantic {PYDANTIC_VERSION} is not supported") + + if module: + module_path_parts.append(module) + + module_path = ".".join(module_path_parts) + return importlib.import_module(module_path) diff --git a/django_pydantic_field/compat/pydantic.py b/django_pydantic_field/compat/pydantic.py new file mode 100644 index 0000000..e8d7f12 --- /dev/null +++ b/django_pydantic_field/compat/pydantic.py @@ -0,0 +1,6 @@ +from pydantic.version import VERSION as PYDANTIC_VERSION + +__all__ = ("PYDANTIC_V2", "PYDANTIC_V1", "PYDANTIC_VERSION") + +PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.") +PYDANTIC_V1 = PYDANTIC_VERSION.startswith("1.") diff --git a/django_pydantic_field/fields.py b/django_pydantic_field/fields.py index 9ff16b4..d76f309 100644 --- a/django_pydantic_field/fields.py +++ b/django_pydantic_field/fields.py @@ -1,194 +1,4 @@ -import json -import typing as t -from functools import partial +from .compat.imports import compat_getattr, compat_dir -import django -import pydantic -from django.core import exceptions as django_exceptions -from django.db.models.fields import NOT_PROVIDED -from django.db.models.fields.json import JSONField -from django.db.models.query_utils import DeferredAttribute - -from . import base, forms, utils -from ._migration_serializers import GenericContainer, GenericTypes - -__all__ = ("SchemaField",) - - -class SchemaAttribute(DeferredAttribute): - """ - Forces Django to call to_python on fields when setting them. - This is useful when you want to add some custom field data postprocessing. - - Should be added to field like a so: - - ``` - def contribute_to_class(self, cls, name, *args, **kwargs): - super().contribute_to_class(cls, name, *args, **kwargs) - setattr(cls, name, SchemaDeferredAttribute(self)) - ``` - """ - - field: "PydanticSchemaField" - - def __set__(self, obj, value): - obj.__dict__[self.field.attname] = self.field.to_python(value) - - -class PydanticSchemaField(JSONField, t.Generic[base.ST]): - descriptor_class = SchemaAttribute - _is_prepared_schema: bool = False - - def __init__( - self, - *args, - schema: t.Union[t.Type["base.ST"], "GenericContainer", "t.ForwardRef", str, None] = None, - config: "base.ConfigType" = None, - **kwargs, - ): - self.export_params = base.extract_export_kwargs(kwargs, dict.pop) - super().__init__(*args, **kwargs) - - self.config = config - self._resolve_schema(schema) - - def __copy__(self): - _, _, args, kwargs = self.deconstruct() - copied = type(self)(*args, **kwargs) - copied.set_attributes_from_name(self.name) - return copied - - def get_default(self): - value = super().get_default() - return self.to_python(value) - - def to_python(self, value) -> "base.SchemaT": - # Attempt to resolve forward referencing schema if it was not succesful - # during `.contribute_to_class` call - if not self._is_prepared_schema: - self._prepare_model_schema() - try: - assert self.decoder is not None - return self.decoder().decode(value) - except pydantic.ValidationError as e: - raise django_exceptions.ValidationError(e.errors()) - - if django.VERSION[:2] >= (4, 2): - - def get_prep_value(self, value): - if not self._is_prepared_schema: - self._prepare_model_schema() - prep_value = super().get_prep_value(value) - prep_value = self.encoder().encode(prep_value) # type: ignore - return json.loads(prep_value) - - def deconstruct(self): - name, path, args, kwargs = super().deconstruct() - self._deconstruct_schema(kwargs) - self._deconstruct_default(kwargs) - self._deconstruct_config(kwargs) - - kwargs.pop("decoder") - kwargs.pop("encoder") - - return name, path, args, kwargs - - def contribute_to_class(self, cls, name, private_only=False): - if self.schema is None: - self._resolve_schema_from_type_hints(cls, name) - - try: - self._prepare_model_schema(cls) - except NameError: - # Pydantic was not able to resolve forward references, which means - # that it should be postponed until initial access to the field - self._is_prepared_schema = False - - super().contribute_to_class(cls, name, private_only) - - def formfield(self, **kwargs): - if self.schema is None: - self._resolve_schema_from_type_hints(self.model, self.attname) - - owner_model = getattr(self, "model", None) - field_kwargs = dict( - form_class=forms.SchemaField, - schema=self.schema, - config=self.config, - __module__=getattr(owner_model, "__module__", None), - **self.export_params, - ) - field_kwargs.update(kwargs) - return super().formfield(**field_kwargs) - - def _resolve_schema(self, schema): - schema = t.cast(t.Type["base.ST"], GenericContainer.unwrap(schema)) - - self.schema = schema - if schema is not None: - self.serializer_schema = serializer = base.wrap_schema(schema, self.config, self.null) - self.decoder = partial(base.SchemaDecoder, serializer) # type: ignore - self.encoder = partial(base.SchemaEncoder, schema=serializer, export=self.export_params) # type: ignore - - def _resolve_schema_from_type_hints(self, cls, name): - annotated_schema = utils.get_annotated_type(cls, name) - if annotated_schema is None: - raise django_exceptions.FieldError( - f"{cls._meta.label}.{name} needs to be either annotated " - "or `schema=` field attribute should be explicitly passed" - ) - self._resolve_schema(annotated_schema) - - def _prepare_model_schema(self, cls=None): - cls = cls or getattr(self, "model", None) - if cls is not None: - base.prepare_schema(self.serializer_schema, cls) - self._is_prepared_schema = True - - def _deconstruct_default(self, kwargs): - default = kwargs.get("default", NOT_PROVIDED) - - if not (default is NOT_PROVIDED or callable(default)): - if self._is_prepared_schema: - default = self.get_prep_value(default) - kwargs.update(default=default) - - def _deconstruct_schema(self, kwargs): - kwargs.update(schema=GenericContainer.wrap(self.schema)) - - def _deconstruct_config(self, kwargs): - kwargs.update(base.deconstruct_export_kwargs(self.export_params)) - kwargs.update(config=self.config) - - -if t.TYPE_CHECKING: - OptSchemaT = t.Optional[base.SchemaT] - - -@t.overload -def SchemaField( - schema: "t.Union[t.Type[t.Optional[base.ST]], t.ForwardRef]" = ..., - config: "base.ConfigType" = ..., - default: "t.Union[OptSchemaT, t.Callable[[], OptSchemaT]]" = ..., - *args, - null: "t.Literal[True]", - **kwargs, -) -> "t.Optional[base.ST]": - ... - - -@t.overload -def SchemaField( - schema: "t.Union[t.Type[base.ST], t.ForwardRef]" = ..., - config: "base.ConfigType" = ..., - default: "t.Union[base.SchemaT, t.Callable[[], base.SchemaT]]" = ..., - *args, - null: "t.Literal[False]" = ..., - **kwargs, -) -> "base.ST": - ... - - -def SchemaField(schema=None, config=None, *args, **kwargs) -> t.Any: - kwargs.update(schema=schema, config=config) - return PydanticSchemaField(*args, **kwargs) +__getattr__ = compat_getattr(__name__) +__dir__ = compat_dir(__name__) diff --git a/django_pydantic_field/fields.pyi b/django_pydantic_field/fields.pyi new file mode 100644 index 0000000..728e2be --- /dev/null +++ b/django_pydantic_field/fields.pyi @@ -0,0 +1,59 @@ +from __future__ import annotations + +import typing as ty + +from pydantic import BaseModel +from pydantic.dataclasses import DataclassClassOrWrapper + +from .compat.pydantic import PYDANTIC_V1, PYDANTIC_V2 + +__all__ = ("SchemaField",) + +SchemaT: ty.TypeAlias = ty.Union[ + BaseModel, + DataclassClassOrWrapper, + ty.Sequence[ty.Any], + ty.Mapping[str, ty.Any], + ty.Set[ty.Any], + ty.FrozenSet[ty.Any], +] +OptSchemaT: ty.TypeAlias = ty.Optional[SchemaT] +ST = ty.TypeVar("ST", bound=SchemaT) + +ConfigType: ty.TypeAlias = ty.Any + +if PYDANTIC_V1: + from pydantic import ConfigDict, BaseConfig + + ConfigType = ty.Union[ConfigDict, type[BaseConfig], type] +elif PYDANTIC_V2: + from pydantic import ConfigDict + + ConfigType = ConfigDict + + +@ty.overload +def SchemaField( + schema: type[ST | None] | ty.ForwardRef = ..., + config: ConfigType = ..., + default: OptSchemaT | ty.Callable[[], OptSchemaT] = ..., + *args, + null: ty.Literal[True], + **kwargs, +) -> ST | None: + ... + + +@ty.overload +def SchemaField( + schema: type[ST] | ty.ForwardRef = ..., + config: ConfigType = ..., + default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ..., + *args, + null: ty.Literal[False] = ..., + **kwargs, +) -> ST: + ... + +def SchemaField(*args, **kwargs) -> ty.Any: + ... diff --git a/django_pydantic_field/forms.py b/django_pydantic_field/forms.py index 3185aad..d76f309 100644 --- a/django_pydantic_field/forms.py +++ b/django_pydantic_field/forms.py @@ -1,66 +1,4 @@ -import typing as t -from functools import partial +from .compat.imports import compat_getattr, compat_dir -import pydantic -from django.core.exceptions import ValidationError -from django.forms.fields import InvalidJSONInput, JSONField -from django.utils.translation import gettext_lazy as _ - -from . import base - -__all__ = ("SchemaField",) - - -class SchemaField(JSONField, t.Generic[base.ST]): - default_error_messages = { - "schema_error": _("Schema didn't match. Detail: %(detail)s"), - } - - def __init__( - self, - schema: t.Union[t.Type["base.ST"], t.ForwardRef], - config: t.Optional["base.ConfigType"] = None, - __module__: t.Optional[str] = None, - **kwargs, - ): - self.schema = base.wrap_schema( - schema, - config, - allow_null=not kwargs.get("required", True), - __module__=__module__, - ) - export_params = base.extract_export_kwargs(kwargs, dict.pop) - decoder = partial(base.SchemaDecoder, self.schema) - encoder = partial( - base.SchemaEncoder, - schema=self.schema, - export=export_params, - raise_errors=True, - ) - kwargs.update(encoder=encoder, decoder=decoder) - super().__init__(**kwargs) - - def to_python(self, value): - try: - return super().to_python(value) - except pydantic.ValidationError as e: - raise ValidationError( - self.error_messages["schema_error"], - code="invalid", - params={ - "value": value, - "detail": str(e), - "errors": e.errors(), - "json": e.json(), - }, - ) - - def bound_data(self, data, initial): - try: - return super().bound_data(data, initial) - except pydantic.ValidationError: - return InvalidJSONInput(data) - - def get_bound_field(self, form, field_name): - base.prepare_schema(self.schema, form) - return super().get_bound_field(form, field_name) +__getattr__ = compat_getattr(__name__) +__dir__ = compat_dir(__name__) diff --git a/django_pydantic_field/rest_framework.py b/django_pydantic_field/rest_framework.py index 4001dcd..d76f309 100644 --- a/django_pydantic_field/rest_framework.py +++ b/django_pydantic_field/rest_framework.py @@ -1,260 +1,4 @@ -import typing as t +from .compat.imports import compat_getattr, compat_dir -try: - from typing import get_args -except ImportError: - from typing_extensions import get_args - -from django.conf import settings -from pydantic import BaseModel, ValidationError - -from rest_framework import exceptions, parsers, renderers, serializers -from rest_framework.schemas import openapi -from rest_framework.schemas.utils import is_list_view - -from . import base - -__all__ = ( - "SchemaField", - "SchemaRenderer", - "SchemaParser", - "AutoSchema", -) - -if t.TYPE_CHECKING: - RequestResponseContext = t.Mapping[str, t.Any] - - -class AnnotatedSchemaT(t.Generic[base.ST]): - schema_ctx_attr: t.ClassVar[str] = "schema" - require_explicit_schema: t.ClassVar[bool] = False - _cached_annotation_schema: t.Type[BaseModel] - - def get_schema(self, ctx: "RequestResponseContext") -> t.Optional[t.Type[BaseModel]]: - schema = self.get_context_schema(ctx) - if schema is None: - schema = self.get_annotation_schema(ctx) - - if self.require_explicit_schema and schema is None: - raise ValueError( - "Schema should be either explicitly set with annotation " - "or passed in the context" - ) - - return schema - - def get_context_schema(self, ctx: "RequestResponseContext"): - schema = ctx.get(self.schema_ctx_attr) - if schema is not None: - schema = base.wrap_schema(schema) - base.prepare_schema(schema, ctx.get("view")) - - return schema - - def get_annotation_schema(self, ctx: "RequestResponseContext"): - try: - schema = self._cached_annotation_schema - except AttributeError: - try: - schema = get_args(self.__orig_class__)[0] # type: ignore - except (AttributeError, IndexError): - return None - - self._cached_annotation_schema = schema = base.wrap_schema(schema) - base.prepare_schema(schema, ctx.get("view")) - - return schema - - -class SchemaField(serializers.Field, t.Generic[base.ST]): - decoder: "base.SchemaDecoder[base.ST]" - _is_prepared_schema: bool = False - - def __init__( - self, - schema: t.Type["base.ST"], - config: t.Optional["base.ConfigType"] = None, - **kwargs, - ): - nullable = kwargs.get("allow_null", False) - - self.schema = field_schema = base.wrap_schema(schema, config, nullable) - self.export_params = base.extract_export_kwargs(kwargs, dict.pop) - self.decoder = base.SchemaDecoder(field_schema) - super().__init__(**kwargs) - - def bind(self, field_name, parent): - if not self._is_prepared_schema: - base.prepare_schema(self.schema, parent) - self._is_prepared_schema = True - - super().bind(field_name, parent) - - def to_internal_value(self, data: t.Any) -> t.Optional["base.ST"]: - try: - return self.decoder.decode(data) - except ValidationError as e: - raise serializers.ValidationError(e.errors(), self.field_name) - - def to_representation(self, value: t.Optional["base.ST"]) -> t.Any: - obj = self.schema.parse_obj(value) - raw_obj = obj.dict(**self.export_params) - return raw_obj["__root__"] - - -class SchemaRenderer(AnnotatedSchemaT[base.ST], renderers.JSONRenderer): - schema_ctx_attr = "render_schema" - - def render(self, data, accepted_media_type=None, renderer_context=None): - renderer_context = renderer_context or {} - response = renderer_context.get("response") - if response is not None and response.exception: - return super().render(data, accepted_media_type, renderer_context) - - try: - json_str = self.render_data(data, renderer_context) - except ValidationError as e: - json_str = e.json().encode() - except AttributeError: - json_str = super().render(data, accepted_media_type, renderer_context) - - return json_str - - def render_data(self, data, renderer_ctx) -> bytes: - schema = self.get_schema(renderer_ctx or {}) - if schema is not None: - data = schema(__root__=data) - - export_kw = base.extract_export_kwargs(renderer_ctx) - json_str = data.json(**export_kw, ensure_ascii=self.ensure_ascii) - return json_str.encode() - - -class SchemaParser(AnnotatedSchemaT[base.ST], parsers.JSONParser): - schema_ctx_attr = "parser_schema" - renderer_class = SchemaRenderer - require_explicit_schema = True - - def parse(self, stream, media_type=None, parser_context=None): - parser_context = parser_context or {} - encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET) - schema = t.cast(BaseModel, self.get_schema(parser_context)) - - try: - return schema.parse_raw(stream.read(), encoding=encoding).__root__ - except ValidationError as e: - raise exceptions.ParseError(e.errors()) - - -class AutoSchema(openapi.AutoSchema): - get_request_serializer: t.Callable - _get_reference: t.Callable - - def map_field(self, field: serializers.Field): - if isinstance(field, SchemaField): - return field.schema.schema() - return super().map_field(field) - - def map_parsers(self, path: str, method: str): - request_types: t.List[t.Any] = [] - parser_ctx = self.view.get_parser_context(None) - - for parser_type in self.view.parser_classes: - parser = parser_type() - - if isinstance(parser, SchemaParser): - schema = self._extract_openapi_schema(parser, parser_ctx) - if schema is not None: - request_types.append((parser.media_type, schema)) - else: - request_types.append(parser.media_type) - else: - request_types.append(parser.media_type) - - return request_types - - def map_renderers(self, path: str, method: str): - response_types: t.List[t.Any] = [] - renderer_ctx = self.view.get_renderer_context() - - for renderer_type in self.view.renderer_classes: - renderer = renderer_type() - - if isinstance(renderer, SchemaRenderer): - schema = self._extract_openapi_schema(renderer, renderer_ctx) - if schema is not None: - response_types.append((renderer.media_type, schema)) - else: - response_types.append(renderer.media_type) - - elif not isinstance(renderer, renderers.BrowsableAPIRenderer): - response_types.append(renderer.media_type) - - return response_types - - def get_request_body(self, path: str, method: str): - if method not in ('PUT', 'PATCH', 'POST'): - return {} - - self.request_media_types = self.map_parsers(path, method) - serializer = self.get_request_serializer(path, method) - content_schemas = {} - - for request_type in self.request_media_types: - if isinstance(request_type, tuple): - media_type, request_schema = request_type - content_schemas[media_type] = {"schema": request_schema} - else: - serializer_ref = self._get_reference(serializer) - content_schemas[request_type] = {"schema": serializer_ref} - - return {'content': content_schemas} - - def get_responses(self, path: str, method: str): - if method == "DELETE": - return {"204": {"description": ""}} - - self.response_media_types = self.map_renderers(path, method) - status_code = "201" if method == "POST" else "200" - content_types = {} - - for response_type in self.response_media_types: - if isinstance(response_type, tuple): - media_type, response_schema = response_type - content_types[media_type] = {"schema": response_schema} - else: - response_schema = self._get_serializer_response_schema(path, method) - content_types[response_type] = {"schema": response_schema} - - return { - status_code: { - "content": content_types, - "description": "", - } - } - - def _extract_openapi_schema(self, schemable: AnnotatedSchemaT, ctx: "RequestResponseContext"): - schema_model = schemable.get_schema(ctx) - if schema_model is not None: - return schema_model.schema() - return None - - def _get_serializer_response_schema(self, path, method): - serializer = self.get_response_serializer(path, method) - - if not isinstance(serializer, serializers.Serializer): - item_schema = {} - else: - item_schema = self._get_reference(serializer) - - if is_list_view(path, method, self.view): - response_schema = { - "type": "array", - "items": item_schema, - } - paginator = self.get_paginator() - if paginator: - response_schema = paginator.get_paginated_response_schema(response_schema) - else: - response_schema = item_schema - return response_schema +__getattr__ = compat_getattr(__name__) +__dir__ = compat_dir(__name__) diff --git a/django_pydantic_field/v1/__init__.py b/django_pydantic_field/v1/__init__.py new file mode 100644 index 0000000..7746f2c --- /dev/null +++ b/django_pydantic_field/v1/__init__.py @@ -0,0 +1 @@ +from .fields import * diff --git a/django_pydantic_field/v1/_migration_serializers.py b/django_pydantic_field/v1/_migration_serializers.py new file mode 100644 index 0000000..e69de29 diff --git a/django_pydantic_field/base.py b/django_pydantic_field/v1/base.py similarity index 100% rename from django_pydantic_field/base.py rename to django_pydantic_field/v1/base.py diff --git a/django_pydantic_field/v1/fields.py b/django_pydantic_field/v1/fields.py new file mode 100644 index 0000000..92d43e6 --- /dev/null +++ b/django_pydantic_field/v1/fields.py @@ -0,0 +1,194 @@ +import json +import typing as t +from functools import partial + +import django +import pydantic +from django.core import exceptions as django_exceptions +from django.db.models.fields import NOT_PROVIDED +from django.db.models.fields.json import JSONField +from django.db.models.query_utils import DeferredAttribute + +from . import base, forms, utils +from ..compat.django import GenericContainer, GenericTypes + +__all__ = ("SchemaField",) + + +class SchemaAttribute(DeferredAttribute): + """ + Forces Django to call to_python on fields when setting them. + This is useful when you want to add some custom field data postprocessing. + + Should be added to field like a so: + + ``` + def contribute_to_class(self, cls, name, *args, **kwargs): + super().contribute_to_class(cls, name, *args, **kwargs) + setattr(cls, name, SchemaDeferredAttribute(self)) + ``` + """ + + field: "PydanticSchemaField" + + def __set__(self, obj, value): + obj.__dict__[self.field.attname] = self.field.to_python(value) + + +class PydanticSchemaField(JSONField, t.Generic[base.ST]): + descriptor_class = SchemaAttribute + _is_prepared_schema: bool = False + + def __init__( + self, + *args, + schema: t.Union[t.Type["base.ST"], "GenericContainer", "t.ForwardRef", str, None] = None, + config: "base.ConfigType" = None, + **kwargs, + ): + self.export_params = base.extract_export_kwargs(kwargs, dict.pop) + super().__init__(*args, **kwargs) + + self.config = config + self._resolve_schema(schema) + + def __copy__(self): + _, _, args, kwargs = self.deconstruct() + copied = type(self)(*args, **kwargs) + copied.set_attributes_from_name(self.name) + return copied + + def get_default(self): + value = super().get_default() + return self.to_python(value) + + def to_python(self, value) -> "base.SchemaT": + # Attempt to resolve forward referencing schema if it was not succesful + # during `.contribute_to_class` call + if not self._is_prepared_schema: + self._prepare_model_schema() + try: + assert self.decoder is not None + return self.decoder().decode(value) + except pydantic.ValidationError as e: + raise django_exceptions.ValidationError(e.errors()) + + if django.VERSION[:2] >= (4, 2): + + def get_prep_value(self, value): + if not self._is_prepared_schema: + self._prepare_model_schema() + prep_value = super().get_prep_value(value) + prep_value = self.encoder().encode(prep_value) # type: ignore + return json.loads(prep_value) + + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + self._deconstruct_schema(kwargs) + self._deconstruct_default(kwargs) + self._deconstruct_config(kwargs) + + kwargs.pop("decoder") + kwargs.pop("encoder") + + return name, path, args, kwargs + + def contribute_to_class(self, cls, name, private_only=False): + if self.schema is None: + self._resolve_schema_from_type_hints(cls, name) + + try: + self._prepare_model_schema(cls) + except NameError: + # Pydantic was not able to resolve forward references, which means + # that it should be postponed until initial access to the field + self._is_prepared_schema = False + + super().contribute_to_class(cls, name, private_only) + + def formfield(self, **kwargs): + if self.schema is None: + self._resolve_schema_from_type_hints(self.model, self.attname) + + owner_model = getattr(self, "model", None) + field_kwargs = dict( + form_class=forms.SchemaField, + schema=self.schema, + config=self.config, + __module__=getattr(owner_model, "__module__", None), + **self.export_params, + ) + field_kwargs.update(kwargs) + return super().formfield(**field_kwargs) + + def _resolve_schema(self, schema): + schema = t.cast(t.Type["base.ST"], GenericContainer.unwrap(schema)) + + self.schema = schema + if schema is not None: + self.serializer_schema = serializer = base.wrap_schema(schema, self.config, self.null) + self.decoder = partial(base.SchemaDecoder, serializer) # type: ignore + self.encoder = partial(base.SchemaEncoder, schema=serializer, export=self.export_params) # type: ignore + + def _resolve_schema_from_type_hints(self, cls, name): + annotated_schema = utils.get_annotated_type(cls, name) + if annotated_schema is None: + raise django_exceptions.FieldError( + f"{cls._meta.label}.{name} needs to be either annotated " + "or `schema=` field attribute should be explicitly passed" + ) + self._resolve_schema(annotated_schema) + + def _prepare_model_schema(self, cls=None): + cls = cls or getattr(self, "model", None) + if cls is not None: + base.prepare_schema(self.serializer_schema, cls) + self._is_prepared_schema = True + + def _deconstruct_default(self, kwargs): + default = kwargs.get("default", NOT_PROVIDED) + + if not (default is NOT_PROVIDED or callable(default)): + if self._is_prepared_schema: + default = self.get_prep_value(default) + kwargs.update(default=default) + + def _deconstruct_schema(self, kwargs): + kwargs.update(schema=GenericContainer.wrap(self.schema)) + + def _deconstruct_config(self, kwargs): + kwargs.update(base.deconstruct_export_kwargs(self.export_params)) + kwargs.update(config=self.config) + + +if t.TYPE_CHECKING: + OptSchemaT = t.Optional[base.SchemaT] + + +@t.overload +def SchemaField( + schema: "t.Union[t.Type[t.Optional[base.ST]], t.ForwardRef]" = ..., + config: "base.ConfigType" = ..., + default: "t.Union[OptSchemaT, t.Callable[[], OptSchemaT]]" = ..., + *args, + null: "t.Literal[True]", + **kwargs, +) -> "t.Optional[base.ST]": + ... + + +@t.overload +def SchemaField( + schema: "t.Union[t.Type[base.ST], t.ForwardRef]" = ..., + config: "base.ConfigType" = ..., + default: "t.Union[base.SchemaT, t.Callable[[], base.SchemaT]]" = ..., + *args, + null: "t.Literal[False]" = ..., + **kwargs, +) -> "base.ST": + ... + + +def SchemaField(schema=None, config=None, *args, **kwargs) -> t.Any: + kwargs.update(schema=schema, config=config) + return PydanticSchemaField(*args, **kwargs) diff --git a/django_pydantic_field/v1/forms.py b/django_pydantic_field/v1/forms.py new file mode 100644 index 0000000..3185aad --- /dev/null +++ b/django_pydantic_field/v1/forms.py @@ -0,0 +1,66 @@ +import typing as t +from functools import partial + +import pydantic +from django.core.exceptions import ValidationError +from django.forms.fields import InvalidJSONInput, JSONField +from django.utils.translation import gettext_lazy as _ + +from . import base + +__all__ = ("SchemaField",) + + +class SchemaField(JSONField, t.Generic[base.ST]): + default_error_messages = { + "schema_error": _("Schema didn't match. Detail: %(detail)s"), + } + + def __init__( + self, + schema: t.Union[t.Type["base.ST"], t.ForwardRef], + config: t.Optional["base.ConfigType"] = None, + __module__: t.Optional[str] = None, + **kwargs, + ): + self.schema = base.wrap_schema( + schema, + config, + allow_null=not kwargs.get("required", True), + __module__=__module__, + ) + export_params = base.extract_export_kwargs(kwargs, dict.pop) + decoder = partial(base.SchemaDecoder, self.schema) + encoder = partial( + base.SchemaEncoder, + schema=self.schema, + export=export_params, + raise_errors=True, + ) + kwargs.update(encoder=encoder, decoder=decoder) + super().__init__(**kwargs) + + def to_python(self, value): + try: + return super().to_python(value) + except pydantic.ValidationError as e: + raise ValidationError( + self.error_messages["schema_error"], + code="invalid", + params={ + "value": value, + "detail": str(e), + "errors": e.errors(), + "json": e.json(), + }, + ) + + def bound_data(self, data, initial): + try: + return super().bound_data(data, initial) + except pydantic.ValidationError: + return InvalidJSONInput(data) + + def get_bound_field(self, form, field_name): + base.prepare_schema(self.schema, form) + return super().get_bound_field(form, field_name) diff --git a/django_pydantic_field/v1/rest_framework.py b/django_pydantic_field/v1/rest_framework.py new file mode 100644 index 0000000..4001dcd --- /dev/null +++ b/django_pydantic_field/v1/rest_framework.py @@ -0,0 +1,260 @@ +import typing as t + +try: + from typing import get_args +except ImportError: + from typing_extensions import get_args + +from django.conf import settings +from pydantic import BaseModel, ValidationError + +from rest_framework import exceptions, parsers, renderers, serializers +from rest_framework.schemas import openapi +from rest_framework.schemas.utils import is_list_view + +from . import base + +__all__ = ( + "SchemaField", + "SchemaRenderer", + "SchemaParser", + "AutoSchema", +) + +if t.TYPE_CHECKING: + RequestResponseContext = t.Mapping[str, t.Any] + + +class AnnotatedSchemaT(t.Generic[base.ST]): + schema_ctx_attr: t.ClassVar[str] = "schema" + require_explicit_schema: t.ClassVar[bool] = False + _cached_annotation_schema: t.Type[BaseModel] + + def get_schema(self, ctx: "RequestResponseContext") -> t.Optional[t.Type[BaseModel]]: + schema = self.get_context_schema(ctx) + if schema is None: + schema = self.get_annotation_schema(ctx) + + if self.require_explicit_schema and schema is None: + raise ValueError( + "Schema should be either explicitly set with annotation " + "or passed in the context" + ) + + return schema + + def get_context_schema(self, ctx: "RequestResponseContext"): + schema = ctx.get(self.schema_ctx_attr) + if schema is not None: + schema = base.wrap_schema(schema) + base.prepare_schema(schema, ctx.get("view")) + + return schema + + def get_annotation_schema(self, ctx: "RequestResponseContext"): + try: + schema = self._cached_annotation_schema + except AttributeError: + try: + schema = get_args(self.__orig_class__)[0] # type: ignore + except (AttributeError, IndexError): + return None + + self._cached_annotation_schema = schema = base.wrap_schema(schema) + base.prepare_schema(schema, ctx.get("view")) + + return schema + + +class SchemaField(serializers.Field, t.Generic[base.ST]): + decoder: "base.SchemaDecoder[base.ST]" + _is_prepared_schema: bool = False + + def __init__( + self, + schema: t.Type["base.ST"], + config: t.Optional["base.ConfigType"] = None, + **kwargs, + ): + nullable = kwargs.get("allow_null", False) + + self.schema = field_schema = base.wrap_schema(schema, config, nullable) + self.export_params = base.extract_export_kwargs(kwargs, dict.pop) + self.decoder = base.SchemaDecoder(field_schema) + super().__init__(**kwargs) + + def bind(self, field_name, parent): + if not self._is_prepared_schema: + base.prepare_schema(self.schema, parent) + self._is_prepared_schema = True + + super().bind(field_name, parent) + + def to_internal_value(self, data: t.Any) -> t.Optional["base.ST"]: + try: + return self.decoder.decode(data) + except ValidationError as e: + raise serializers.ValidationError(e.errors(), self.field_name) + + def to_representation(self, value: t.Optional["base.ST"]) -> t.Any: + obj = self.schema.parse_obj(value) + raw_obj = obj.dict(**self.export_params) + return raw_obj["__root__"] + + +class SchemaRenderer(AnnotatedSchemaT[base.ST], renderers.JSONRenderer): + schema_ctx_attr = "render_schema" + + def render(self, data, accepted_media_type=None, renderer_context=None): + renderer_context = renderer_context or {} + response = renderer_context.get("response") + if response is not None and response.exception: + return super().render(data, accepted_media_type, renderer_context) + + try: + json_str = self.render_data(data, renderer_context) + except ValidationError as e: + json_str = e.json().encode() + except AttributeError: + json_str = super().render(data, accepted_media_type, renderer_context) + + return json_str + + def render_data(self, data, renderer_ctx) -> bytes: + schema = self.get_schema(renderer_ctx or {}) + if schema is not None: + data = schema(__root__=data) + + export_kw = base.extract_export_kwargs(renderer_ctx) + json_str = data.json(**export_kw, ensure_ascii=self.ensure_ascii) + return json_str.encode() + + +class SchemaParser(AnnotatedSchemaT[base.ST], parsers.JSONParser): + schema_ctx_attr = "parser_schema" + renderer_class = SchemaRenderer + require_explicit_schema = True + + def parse(self, stream, media_type=None, parser_context=None): + parser_context = parser_context or {} + encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET) + schema = t.cast(BaseModel, self.get_schema(parser_context)) + + try: + return schema.parse_raw(stream.read(), encoding=encoding).__root__ + except ValidationError as e: + raise exceptions.ParseError(e.errors()) + + +class AutoSchema(openapi.AutoSchema): + get_request_serializer: t.Callable + _get_reference: t.Callable + + def map_field(self, field: serializers.Field): + if isinstance(field, SchemaField): + return field.schema.schema() + return super().map_field(field) + + def map_parsers(self, path: str, method: str): + request_types: t.List[t.Any] = [] + parser_ctx = self.view.get_parser_context(None) + + for parser_type in self.view.parser_classes: + parser = parser_type() + + if isinstance(parser, SchemaParser): + schema = self._extract_openapi_schema(parser, parser_ctx) + if schema is not None: + request_types.append((parser.media_type, schema)) + else: + request_types.append(parser.media_type) + else: + request_types.append(parser.media_type) + + return request_types + + def map_renderers(self, path: str, method: str): + response_types: t.List[t.Any] = [] + renderer_ctx = self.view.get_renderer_context() + + for renderer_type in self.view.renderer_classes: + renderer = renderer_type() + + if isinstance(renderer, SchemaRenderer): + schema = self._extract_openapi_schema(renderer, renderer_ctx) + if schema is not None: + response_types.append((renderer.media_type, schema)) + else: + response_types.append(renderer.media_type) + + elif not isinstance(renderer, renderers.BrowsableAPIRenderer): + response_types.append(renderer.media_type) + + return response_types + + def get_request_body(self, path: str, method: str): + if method not in ('PUT', 'PATCH', 'POST'): + return {} + + self.request_media_types = self.map_parsers(path, method) + serializer = self.get_request_serializer(path, method) + content_schemas = {} + + for request_type in self.request_media_types: + if isinstance(request_type, tuple): + media_type, request_schema = request_type + content_schemas[media_type] = {"schema": request_schema} + else: + serializer_ref = self._get_reference(serializer) + content_schemas[request_type] = {"schema": serializer_ref} + + return {'content': content_schemas} + + def get_responses(self, path: str, method: str): + if method == "DELETE": + return {"204": {"description": ""}} + + self.response_media_types = self.map_renderers(path, method) + status_code = "201" if method == "POST" else "200" + content_types = {} + + for response_type in self.response_media_types: + if isinstance(response_type, tuple): + media_type, response_schema = response_type + content_types[media_type] = {"schema": response_schema} + else: + response_schema = self._get_serializer_response_schema(path, method) + content_types[response_type] = {"schema": response_schema} + + return { + status_code: { + "content": content_types, + "description": "", + } + } + + def _extract_openapi_schema(self, schemable: AnnotatedSchemaT, ctx: "RequestResponseContext"): + schema_model = schemable.get_schema(ctx) + if schema_model is not None: + return schema_model.schema() + return None + + def _get_serializer_response_schema(self, path, method): + serializer = self.get_response_serializer(path, method) + + if not isinstance(serializer, serializers.Serializer): + item_schema = {} + else: + item_schema = self._get_reference(serializer) + + if is_list_view(path, method, self.view): + response_schema = { + "type": "array", + "items": item_schema, + } + paginator = self.get_paginator() + if paginator: + response_schema = paginator.get_paginated_response_schema(response_schema) + else: + response_schema = item_schema + return response_schema diff --git a/django_pydantic_field/utils.py b/django_pydantic_field/v1/utils.py similarity index 100% rename from django_pydantic_field/utils.py rename to django_pydantic_field/v1/utils.py diff --git a/django_pydantic_field/v2/__init__.py b/django_pydantic_field/v2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index e9fe486..d8e5862 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "django-pydantic-field" -version = "0.2.11" +version = "0.3.0-alpha1" description = "Django JSONField with Pydantic models as a Schema" readme = "README.md" license = { file = "LICENSE" } @@ -39,7 +39,7 @@ classifiers = [ requires-python = ">=3.7" dependencies = [ - "pydantic>=1.9,<2", + "pydantic>=1.10,<3", "django>=3.1,<5", "typing_extensions", ] diff --git a/tests/sample_app/migrations/0001_initial.py b/tests/sample_app/migrations/0001_initial.py index 67e4a7e..a52069f 100644 --- a/tests/sample_app/migrations/0001_initial.py +++ b/tests/sample_app/migrations/0001_initial.py @@ -1,7 +1,7 @@ # Generated by Django 4.2.2 on 2023-06-19 12:36 import typing -import django_pydantic_field._migration_serializers +from django_pydantic_field.compat.django import GenericContainer import django_pydantic_field.fields import tests.sample_app.models import typing_extensions diff --git a/tests/test_migration_serializers.py b/tests/test_migration_serializers.py index 9a6a802..246665c 100644 --- a/tests/test_migration_serializers.py +++ b/tests/test_migration_serializers.py @@ -6,7 +6,7 @@ import pytest import django_pydantic_field -from django_pydantic_field._migration_serializers import GenericContainer +from django_pydantic_field.compat.django import GenericContainer if sys.version_info < (3, 9): test_types = [ From 3b7b14d0eb5e8abf7554a7878248184e8948acaa Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Sat, 4 Nov 2023 23:05:31 +0400 Subject: [PATCH 02/34] Add pydantic to test matrix. --- .github/workflows/python-test.yml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/python-test.yml b/.github/workflows/python-test.yml index c2b3eaf..cc12569 100644 --- a/.github/workflows/python-test.yml +++ b/.github/workflows/python-test.yml @@ -42,6 +42,7 @@ jobs: strategy: matrix: python-version: ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"] + pydantic-version: ["1.10.13", "2.4.2"] services: postgres: @@ -75,5 +76,7 @@ jobs: sudo apt update && sudo apt install -qy python3-dev default-libmysqlclient-dev build-essential python -m pip install --upgrade pip python -m pip install -e .[dev,test,ci] + - name: Install Pydantic ${{ matrix.pydantic-version }} + run: python -m pip install "pydantic==${{ matrix.pydantic-version }}" - name: Test package run: pytest From 2cfdf49c8dc22535d1cece052592a473b4f4ed29 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Sun, 5 Nov 2023 02:11:27 +0400 Subject: [PATCH 03/34] Implement basic abstractions around pydantic 2. --- .editorconfig | 16 ++++ django_pydantic_field/v2/__init__.py | 1 + django_pydantic_field/v2/fields.py | 94 +++++++++++++++++++ django_pydantic_field/v2/types.py | 130 +++++++++++++++++++++++++++ django_pydantic_field/v2/utils.py | 24 +++++ pyproject.toml | 25 ++++++ 6 files changed, 290 insertions(+) create mode 100644 .editorconfig create mode 100644 django_pydantic_field/v2/fields.py create mode 100644 django_pydantic_field/v2/types.py create mode 100644 django_pydantic_field/v2/utils.py diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..ccd4e7f --- /dev/null +++ b/.editorconfig @@ -0,0 +1,16 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true +indent_style = space +indent_size = 4 + +[*.py] +max_line_length = 120 + +[{*.json,*.yml,*.yaml}] +indent_size = 2 +insert_final_newline = false diff --git a/django_pydantic_field/v2/__init__.py b/django_pydantic_field/v2/__init__.py index e69de29..9e72cfd 100644 --- a/django_pydantic_field/v2/__init__.py +++ b/django_pydantic_field/v2/__init__.py @@ -0,0 +1 @@ +from .fields import SchemaField as SchemaField diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py new file mode 100644 index 0000000..3edc08b --- /dev/null +++ b/django_pydantic_field/v2/fields.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import typing as ty + +import pydantic + +from django.core import checks, exceptions + +from django.db.models.expressions import BaseExpression +from django.db.models.fields.json import JSONField +from django.db.models.query_utils import DeferredAttribute + +from . import types + +if ty.TYPE_CHECKING: + from ..compat.django import GenericContainer + + +class SchemaAttribute(DeferredAttribute): + field: PydanticSchemaField + + def __set_name__(self, owner, name): + self.field.adapter.bind(owner, name) + + def __set__(self, obj, value): + obj.__dict__[self.field.attname] = self.field.to_python(value) + + +class PydanticSchemaField(JSONField, ty.Generic[types.ST]): + def __init__( + self, + *args, + schema: type[types.ST] | GenericContainer | ty.ForwardRef | str | None = None, + config: pydantic.ConfigDict | None = None, + **kwargs, + ): + self.export_kwargs = export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) + super().__init__(*args, **kwargs) + self.schema = schema + self.config = config + self.adapter = types.SchemaAdapter(schema, config, None, self.get_attname(), self.null, **export_kwargs) + + def __copy__(self): + _, _, args, kwargs = self.deconstruct() + copied = self.__class__(*args, **kwargs) + copied.set_attributes_from_name(self.name) + return copied + + def deconstruct(self) -> ty.Any: + field_name, import_path, args, kwargs = super().deconstruct() + kwargs.update(schema=GenericContainer.wrap(self.schema), config=self.config, **self.export_kwargs) + return field_name, import_path, args, kwargs + + def contribute_to_class(self, cls: types.DjangoModelType, name: str, private_only: bool = False) -> None: + self.adapter.bind(cls, name) + super().contribute_to_class(cls, name, private_only) + + def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: + performed_checks = super().check(**kwargs) + try: + self.adapter.validate_schema() + except ValueError as exc: + performed_checks.append(checks.Error(exc.args[0], obj=self)) + return performed_checks + + def to_python(self, value: ty.Any): + try: + return self.adapter.validate_python(value) + except pydantic.ValidationError as exc: + raise exceptions.ValidationError(exc.title, code="invalid", params=exc.errors()) from exc + + def get_prep_value(self, value: ty.Any): + if isinstance(value, BaseExpression): + # We don't want to perform coercion on database query expressions. + return super().get_prep_value(value) + return self.adapter.dump_python(value) + + def validate(self, value: ty.Any, model_instance: ty.Any) -> None: + value = self.adapter.validate_python(value) + return super().validate(value, model_instance) + + +@ty.overload +def SchemaField(schema: None = None) -> ty.Any: + ... + + +@ty.overload +def SchemaField(schema: type[types.ST]) -> ty.Any: + ... + + +def SchemaField(schema=None, config=None, *args, **kwargs): # type: ignore + return PydanticSchemaField(*args, schema=schema, config=config, **kwargs) diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py new file mode 100644 index 0000000..c49485e --- /dev/null +++ b/django_pydantic_field/v2/types.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import functools +import typing as ty + +from pydantic.type_adapter import TypeAdapter + +from . import utils +from ..compat.django import GenericContainer + +ST = ty.TypeVar("ST", bound="SchemaT") + +if ty.TYPE_CHECKING: + from pydantic import BaseModel + from pydantic.type_adapter import IncEx + from pydantic.dataclasses import DataclassClassOrWrapper + from django.db.models import Model + + ModelType = ty.Type[BaseModel] + DjangoModelType = ty.Type[Model] + SchemaT = ty.Union[ + BaseModel, + DataclassClassOrWrapper, + ty.Sequence[ty.Any], + ty.Mapping[str, ty.Any], + ty.Set[ty.Any], + ty.FrozenSet[ty.Any], + ] + +class ExportKwargs(ty.TypedDict, total=False): + strict: bool + mode: ty.Literal["json", "python"] + include: IncEx | None + exclude: IncEx | None + by_alias: bool + exclude_unset: bool + exclude_defaults: bool + exclude_none: bool + round_trip: bool + warnings: bool + + +class SchemaAdapter(ty.Generic[ST]): + def __init__( + self, + schema, + config, + parent_type, + attname, + allow_null, + *, + parent_depth=4, + **export_kwargs: ty.Unpack[ExportKwargs], + ): + self.schema = schema + self.config = config + self.parent_type = parent_type + self.attname = attname + self.allow_null = allow_null + self.parent_depth = parent_depth + self.export_kwargs = export_kwargs + + @staticmethod + def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: + common_keys = kwargs.keys() & ExportKwargs.__annotations__.keys() + export_kwargs = {key: kwargs.pop(key) for key in common_keys} + return ty.cast(ExportKwargs, export_kwargs) + + @functools.cached_property + def type_adapter(self) -> TypeAdapter: + schema = self._get_prepared_schema() + return TypeAdapter(schema, config=self.config, _parent_depth=4) # type: ignore + + def bind(self, parent_type, attname): + self.parent_type = parent_type + self.attname = attname + self.__dict__.pop("type_adapter", None) + + def validate_schema(self) -> None: + """Validate the schema and raise an exception if it is invalid.""" + self._get_prepared_schema() + + def validate_python(self, value: ty.Any) -> ST: + """Validate the value and raise an exception if it is invalid.""" + return self.type_adapter.validate_python( + value, + strict=self.export_kwargs.get("strict", None), + ) + + def dump_python(self, value: ty.Any) -> ty.Any: + """Dump the value to a Python object.""" + return self.type_adapter.dump_python(value, **self._dump_python_kwargs) + + def json_schema(self) -> ty.Any: + """Return the JSON schema for the field.""" + by_alias = self.export_kwargs.get("by_alias", True) + return self.type_adapter.json_schema(by_alias=by_alias) + + def _get_prepared_schema(self) -> type[ST]: + schema = self.schema + + if schema is None: + schema = self._guess_schema_from_annotations() + if isinstance(schema, GenericContainer): + schema = ty.cast(type[ST], GenericContainer.unwrap(schema)) + if isinstance(schema, (str, ty.ForwardRef)): + schema = self._resolve_schema_forward_ref(schema) + + if schema is None: + error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}" + raise ValueError(error_msg) + + if self.allow_null: + schema = ty.Optional[schema] + return ty.cast(type[ST], schema) + + def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | None: + return utils.get_annotated_type(self.parent_type, self.attname) + + def _resolve_schema_forward_ref(self, schema: str | ty.ForwardRef) -> ty.Any: + if isinstance(schema, str): + schema = ty.ForwardRef(schema) + namespace = utils.get_local_namespace(self.parent_type) + return schema._evaluate(namespace, vars(self.parent_type), frozenset()) # type: ignore + + @functools.cached_property + def _dump_python_kwargs(self) -> dict[str, ty.Any]: + export_kwargs = self.export_kwargs.copy() + export_kwargs.pop("strict", None) + return ty.cast(dict, export_kwargs) diff --git a/django_pydantic_field/v2/utils.py b/django_pydantic_field/v2/utils.py new file mode 100644 index 0000000..39b625a --- /dev/null +++ b/django_pydantic_field/v2/utils.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +import sys +import typing as ty + + +def get_annotated_type(obj, field, default=None) -> ty.Any: + try: + if isinstance(obj, type): + annotations = obj.__dict__["__annotations__"] + else: + annotations = obj.__annotations__ + + return annotations[field] + except (AttributeError, KeyError): + return default + + +def get_local_namespace(cls) -> dict[str, ty.Any]: + try: + module = cls.__module__ + return vars(sys.modules[module]) + except (KeyError, AttributeError): + return {} diff --git a/pyproject.toml b/pyproject.toml index d8e5862..864762a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,31 @@ Changelog = "https://github.com/surenkov/django-pydantic-field/releases" [tool.setuptools] packages = ["django_pydantic_field"] +[tool.isort] +py_version = 312 +profile = "black" +line_length = 120 +multi_line_output = 3 +include_trailing_comma = true +force_alphabetical_sort_within_sections = true +force_grid_wrap = 0 +use_parentheses = true +skip_glob = [ + "*/migrations/*", +] + +[tool.black] +target-version = ['py312'] +line-length = 120 +exclude = ''' +/( + \.pytest_cache + | \.venv + | venv + | migrations +)/ +''' + [tool.mypy] plugins = [ "mypy_django_plugin.main", From 9aeec6398366999ee95a3f0c83e7cba963c6f444 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Mon, 6 Nov 2023 00:53:05 +0400 Subject: [PATCH 04/34] Make all ORM field tests pass. --- django_pydantic_field/__init__.py | 12 +--- .../_migration_serializers.py | 9 +++ django_pydantic_field/v2/fields.py | 37 +++++++++-- django_pydantic_field/v2/types.py | 9 ++- pyproject.toml | 2 +- tests/sample_app/migrations/0001_initial.py | 1 - tests/sample_app/models.py | 2 +- tests/test_app/models.py | 6 +- ...t_django_model_field.py => test_fields.py} | 49 ++------------ tests/test_migration_serializers.py | 5 +- .../v1/__init__.py | 0 .../test_base.py} | 66 ++++++++++++------- tests/v1/test_fields.py | 42 ++++++++++++ 13 files changed, 148 insertions(+), 92 deletions(-) create mode 100644 django_pydantic_field/_migration_serializers.py rename tests/{test_django_model_field.py => test_fields.py} (85%) rename django_pydantic_field/v1/_migration_serializers.py => tests/v1/__init__.py (100%) rename tests/{test_base_marshalling.py => v1/test_base.py} (67%) create mode 100644 tests/v1/test_fields.py diff --git a/django_pydantic_field/__init__.py b/django_pydantic_field/__init__.py index 9d5c124..29e6ff9 100644 --- a/django_pydantic_field/__init__.py +++ b/django_pydantic_field/__init__.py @@ -2,15 +2,7 @@ def __getattr__(name): if name == "_migration_serializers": - import warnings - from .compat import django - - DEPRECATION_MSG = ( - "Module 'django_pydantic_field._migration_serializers' is deprecated " - "and will be removed in version 1.0.0. " - "Please replace it with 'django_pydantic_field.compat.django' in migrations." - ) - warnings.warn(DEPRECATION_MSG, category=DeprecationWarning) - return django + module = __import__("django_pydantic_field._migration_serializers", fromlist=["*"]) + return module raise AttributeError(f"Module {__name__!r} has no attribute {name!r}") diff --git a/django_pydantic_field/_migration_serializers.py b/django_pydantic_field/_migration_serializers.py new file mode 100644 index 0000000..20a834f --- /dev/null +++ b/django_pydantic_field/_migration_serializers.py @@ -0,0 +1,9 @@ +import warnings +from .compat.django import * + +DEPRECATION_MSG = ( + "Module 'django_pydantic_field._migration_serializers' is deprecated " + "and will be removed in version 1.0.0. " + "Please replace it with 'django_pydantic_field.compat.django' in migrations." +) +warnings.warn(DEPRECATION_MSG, category=DeprecationWarning) diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 3edc08b..443fd7a 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -5,15 +5,14 @@ import pydantic from django.core import checks, exceptions +from django.core.serializers.json import DjangoJSONEncoder from django.db.models.expressions import BaseExpression from django.db.models.fields.json import JSONField from django.db.models.query_utils import DeferredAttribute from . import types - -if ty.TYPE_CHECKING: - from ..compat.django import GenericContainer +from ..compat.django import GenericContainer class SchemaAttribute(DeferredAttribute): @@ -27,6 +26,8 @@ def __set__(self, obj, value): class PydanticSchemaField(JSONField, ty.Generic[types.ST]): + descriptor_class = SchemaAttribute + def __init__( self, *args, @@ -34,8 +35,11 @@ def __init__( config: pydantic.ConfigDict | None = None, **kwargs, ): + kwargs.setdefault("encoder", DjangoJSONEncoder) + self.export_kwargs = export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) super().__init__(*args, **kwargs) + self.schema = schema self.config = config self.adapter = types.SchemaAdapter(schema, config, None, self.get_attname(), self.null, **export_kwargs) @@ -63,6 +67,10 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: performed_checks.append(checks.Error(exc.args[0], obj=self)) return performed_checks + def validate(self, value: ty.Any, model_instance: ty.Any) -> None: + value = self.adapter.validate_python(value) + return super().validate(value, model_instance) + def to_python(self, value: ty.Any): try: return self.adapter.validate_python(value) @@ -73,11 +81,26 @@ def get_prep_value(self, value: ty.Any): if isinstance(value, BaseExpression): # We don't want to perform coercion on database query expressions. return super().get_prep_value(value) - return self.adapter.dump_python(value) - def validate(self, value: ty.Any, model_instance: ty.Any) -> None: - value = self.adapter.validate_python(value) - return super().validate(value, model_instance) + try: + prep_value = self.adapter.validate_python(value, strict=True) + except TypeError: + prep_value = self.adapter.dump_python(value) + prep_value = self.adapter.validate_python(prep_value) + + plain_value = self.adapter.dump_python(prep_value) + return super().get_prep_value(plain_value) + + def get_default(self) -> types.ST: + default_value = super().get_default() + try: + raw_value = dict(default_value) + prep_value = self.adapter.validate_python(raw_value, strict=True) + except (TypeError, ValueError): + prep_value = self.adapter.validate_python(default_value) + + return prep_value + @ty.overload diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index c49485e..1dfc9fe 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -80,12 +80,11 @@ def validate_schema(self) -> None: """Validate the schema and raise an exception if it is invalid.""" self._get_prepared_schema() - def validate_python(self, value: ty.Any) -> ST: + def validate_python(self, value: ty.Any, *, strict: bool | None = None) -> ST: """Validate the value and raise an exception if it is invalid.""" - return self.type_adapter.validate_python( - value, - strict=self.export_kwargs.get("strict", None), - ) + if strict is None: + strict = self.export_kwargs.get("strict", None) + return self.type_adapter.validate_python(value, strict=strict) def dump_python(self, value: ty.Any) -> ty.Any: """Dump the value to a Python object.""" diff --git a/pyproject.toml b/pyproject.toml index 864762a..1b095f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,7 @@ Changelog = "https://github.com/surenkov/django-pydantic-field/releases" packages = ["django_pydantic_field"] [tool.isort] -py_version = 312 +py_version = 311 profile = "black" line_length = 120 multi_line_output = 3 diff --git a/tests/sample_app/migrations/0001_initial.py b/tests/sample_app/migrations/0001_initial.py index a52069f..6d72082 100644 --- a/tests/sample_app/migrations/0001_initial.py +++ b/tests/sample_app/migrations/0001_initial.py @@ -1,7 +1,6 @@ # Generated by Django 4.2.2 on 2023-06-19 12:36 import typing -from django_pydantic_field.compat.django import GenericContainer import django_pydantic_field.fields import tests.sample_app.models import typing_extensions diff --git a/tests/sample_app/models.py b/tests/sample_app/models.py index f8bb8a0..d9e68c8 100644 --- a/tests/sample_app/models.py +++ b/tests/sample_app/models.py @@ -28,7 +28,7 @@ class BuildingMeta(pydantic.BaseModel): class PostponedBuilding(models.Model): - meta: "BuildingMeta" = SchemaField(default=BuildingMeta(type=BuildingTypes.FRAME), by_alias=True) + meta: "BuildingMeta" = SchemaField(default=BuildingMeta(buildingType=BuildingTypes.FRAME), by_alias=True) meta_builtin_list: t.List[BuildingMeta] = SchemaField(schema=t.List[BuildingMeta], default=list) meta_typing_list: t.List["BuildingMeta"] = SchemaField(default=list) meta_untyped_list: list = SchemaField(schema=t.List, default=list) diff --git a/tests/test_app/models.py b/tests/test_app/models.py index e5d75da..62bd340 100644 --- a/tests/test_app/models.py +++ b/tests/test_app/models.py @@ -7,8 +7,12 @@ from ..conftest import InnerSchema +class FrozenInnerSchema(InnerSchema): + model_config = pydantic.ConfigDict({"frozen": True}) + + class SampleModel(models.Model): - sample_field: InnerSchema = SchemaField(config={"frozen": True, "allow_mutation": False}) + sample_field: InnerSchema = SchemaField() sample_list: t.List[InnerSchema] = SchemaField() sample_seq: t.Sequence[InnerSchema] = SchemaField(schema=t.List[InnerSchema], default=list) diff --git a/tests/test_django_model_field.py b/tests/test_fields.py similarity index 85% rename from tests/test_django_model_field.py rename to tests/test_fields.py index e754c2b..79fe2ae 100644 --- a/tests/test_django_model_field.py +++ b/tests/test_fields.py @@ -5,10 +5,9 @@ from copy import copy from datetime import date -import django import pytest -from django.core.exceptions import FieldError, ValidationError -from django.db import models +from django.core.exceptions import ValidationError +from django.db import models, connection from django.db.migrations.writer import MigrationWriter from django_pydantic_field import fields @@ -22,10 +21,9 @@ def test_sample_field(): existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) expected_encoded = {"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]} - if django.VERSION[:2] < (4, 2): - expected_encoded = json.dumps(expected_encoded) + expected_prepared = json.dumps(expected_encoded) - assert sample_field.get_prep_value(existing_instance) == expected_encoded + assert sample_field.get_db_prep_value(existing_instance, connection) == expected_prepared assert sample_field.to_python(expected_encoded) == existing_instance @@ -34,37 +32,12 @@ def test_sample_field_with_raw_data(): existing_raw = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} expected_encoded = {"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]} - if django.VERSION[:2] < (4, 2): - expected_encoded = json.dumps(expected_encoded) + expected_prepared = json.dumps(expected_encoded) - assert sample_field.get_prep_value(existing_raw) == expected_encoded + assert sample_field.get_db_prep_value(existing_raw, connection) == expected_prepared assert sample_field.to_python(expected_encoded) == InnerSchema(**existing_raw) -def test_simple_model_field(): - sample_field = SampleModel._meta.get_field("sample_field") - assert sample_field.schema == InnerSchema - - sample_list_field = SampleModel._meta.get_field("sample_list") - assert sample_list_field.schema == t.List[InnerSchema] - - sample_seq_field = SampleModel._meta.get_field("sample_seq") - assert sample_seq_field.schema == t.List[InnerSchema] - - existing_raw_field = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} - existing_raw_list = [{"stub_str": "abc", "stub_list": []}] - - instance = SampleModel( - sample_field=existing_raw_field, sample_list=existing_raw_list - ) - - expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - expected_list = [InnerSchema(stub_str="abc", stub_list=[])] - - assert instance.sample_field == expected_instance - assert instance.sample_list == expected_list - - def test_null_field(): field = fields.SchemaField(InnerSchema, null=True, default=None) assert field.to_python(None) is None @@ -74,16 +47,6 @@ def test_null_field(): assert field.get_prep_value(None) is None -def test_untyped_model_field_raises(): - with pytest.raises(FieldError): - - class UntypedModel(models.Model): - sample_field = fields.SchemaField() - - class Meta: - app_label = "test_app" - - def test_forwardrefs_deferred_resolution(): obj = SampleForwardRefModel(field={}, annotated_field={}) assert isinstance(obj.field, SampleSchema) diff --git a/tests/test_migration_serializers.py b/tests/test_migration_serializers.py index 246665c..c67b122 100644 --- a/tests/test_migration_serializers.py +++ b/tests/test_migration_serializers.py @@ -6,7 +6,10 @@ import pytest import django_pydantic_field -from django_pydantic_field.compat.django import GenericContainer +try: + from django_pydantic_field.compat.django import GenericContainer +except ImportError: + from django_pydantic_field._migration_serializers import GenericContainer if sys.version_info < (3, 9): test_types = [ diff --git a/django_pydantic_field/v1/_migration_serializers.py b/tests/v1/__init__.py similarity index 100% rename from django_pydantic_field/v1/_migration_serializers.py rename to tests/v1/__init__.py diff --git a/tests/test_base_marshalling.py b/tests/v1/test_base.py similarity index 67% rename from tests/test_base_marshalling.py rename to tests/v1/test_base.py index 011b113..8637e30 100644 --- a/tests/test_base_marshalling.py +++ b/tests/v1/test_base.py @@ -5,9 +5,10 @@ import pydantic import pytest -from django_pydantic_field import base -from .conftest import InnerSchema, SampleDataclass +from tests.conftest import InnerSchema, SampleDataclass + +base = pytest.importorskip("django_pydantic_field.v1.base") class SampleSchema(pydantic.BaseModel): @@ -66,7 +67,7 @@ def test_schema_wrapper_transformers(): assert parsed_wrapper.__root__ == [expected_decoded] -class test_schema_wrapper_config_inheritance(): +def test_schema_wrapper_config_inheritance(): parsed_wrapper = base.wrap_schema(InnerSchema, config={"allow_mutation": False}) assert not parsed_wrapper.Config.allow_mutation assert not parsed_wrapper.Config.frozen @@ -76,13 +77,24 @@ class test_schema_wrapper_config_inheritance(): assert parsed_wrapper.Config.frozen -@pytest.mark.parametrize("type_, encoded, decoded", [ - (InnerSchema, '{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}', InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])), - (SampleDataclass, '{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}', SampleDataclass(stub_str="abc", stub_list=[date(2022, 7, 1)])), - (t.List[int], '[1, 2, 3]', [1, 2, 3]), - (t.Mapping[int, date], '{"1": "1970-01-01"}', {1: date(1970, 1, 1)}), - (t.Set[UUID], '["ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c"]', {UUID("ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c")}), -]) +@pytest.mark.parametrize( + "type_, encoded, decoded", + [ + ( + InnerSchema, + '{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}', + InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), + ), + ( + SampleDataclass, + '{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}', + SampleDataclass(stub_str="abc", stub_list=[date(2022, 7, 1)]), + ), + (t.List[int], "[1, 2, 3]", [1, 2, 3]), + (t.Mapping[int, date], '{"1": "1970-01-01"}', {1: date(1970, 1, 1)}), + (t.Set[UUID], '["ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c"]', {UUID("ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c")}), + ], +) def test_concrete_types(type_, encoded, decoded): schema = base.wrap_schema(type_) encoder = base.SchemaEncoder(schema=schema) @@ -96,13 +108,20 @@ def test_concrete_types(type_, encoded, decoded): @pytest.mark.skipif(sys.version_info < (3, 9), reason="Should test against builtin generic types") -@pytest.mark.parametrize("type_factory, encoded, decoded", [ - (lambda: list[InnerSchema], '[{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}]', [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]), - (lambda: list[SampleDataclass], '[{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}]', [SampleDataclass(stub_str="abc", stub_list=[date(2022, 7, 1)])]), # type: ignore - (lambda: list[int], '[1, 2, 3]', [1, 2, 3]), - (lambda: dict[int, date], '{"1": "1970-01-01"}', {1: date(1970, 1, 1)}), - (lambda: set[UUID], '["ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c"]', {UUID("ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c")}), -]) +@pytest.mark.parametrize( + "type_factory, encoded, decoded", + [ + ( + lambda: list[InnerSchema], + '[{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}]', + [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])], + ), + (lambda: list[SampleDataclass], '[{"stub_str": "abc", "stub_list": ["2022-07-01"], "stub_int": 1}]', [SampleDataclass(stub_str="abc", stub_list=[date(2022, 7, 1)])]), # type: ignore + (lambda: list[int], "[1, 2, 3]", [1, 2, 3]), + (lambda: dict[int, date], '{"1": "1970-01-01"}', {1: date(1970, 1, 1)}), + (lambda: set[UUID], '["ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c"]', {UUID("ba6eb330-4f7f-11eb-a2fb-67c34e9ac07c")}), + ], +) def test_concrete_raw_types(type_factory, encoded, decoded): type_ = type_factory() @@ -117,11 +136,14 @@ def test_concrete_raw_types(type_factory, encoded, decoded): assert decoder.decode(existing_encoded) == decoded -@pytest.mark.parametrize("forward_ref, sample_data", [ - (t.ForwardRef("t.List[int]"), '[1, 2]'), - (t.ForwardRef("InnerSchema"), '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}'), - (t.ForwardRef("PostponedSchema"), '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}'), -]) +@pytest.mark.parametrize( + "forward_ref, sample_data", + [ + (t.ForwardRef("t.List[int]"), "[1, 2]"), + (t.ForwardRef("InnerSchema"), '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}'), + (t.ForwardRef("PostponedSchema"), '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}'), + ], +) def test_forward_refs_preparation(forward_ref, sample_data): schema = base.wrap_schema(forward_ref) base.prepare_schema(schema, test_forward_refs_preparation) diff --git a/tests/v1/test_fields.py b/tests/v1/test_fields.py new file mode 100644 index 0000000..a983dfb --- /dev/null +++ b/tests/v1/test_fields.py @@ -0,0 +1,42 @@ +import pytest +import typing as t +from datetime import date + +from django.core.exceptions import FieldError + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +fields = pytest.importorskip("django_pydantic_field.v1.fields") + + +def test_simple_model_field(): + sample_field = SampleModel._meta.get_field("sample_field") + assert sample_field.schema == InnerSchema + + sample_list_field = SampleModel._meta.get_field("sample_list") + assert sample_list_field.schema == t.List[InnerSchema] + + sample_seq_field = SampleModel._meta.get_field("sample_seq") + assert sample_seq_field.schema == t.List[InnerSchema] + + existing_raw_field = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + existing_raw_list = [{"stub_str": "abc", "stub_list": []}] + + instance = SampleModel(sample_field=existing_raw_field, sample_list=existing_raw_list) + + expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_list = [InnerSchema(stub_str="abc", stub_list=[])] + + assert instance.sample_field == expected_instance + assert instance.sample_list == expected_list + + +def test_untyped_model_field_raises(): + with pytest.raises(FieldError): + + class UntypedModel(models.Model): + sample_field = fields.SchemaField() + + class Meta: + app_label = "test_app" From 74a5d93194eed8db279a001a3603965e56ee1090 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Mon, 6 Nov 2023 14:16:54 +0400 Subject: [PATCH 05/34] Add stubs for form/drf fields. --- django_pydantic_field/v2/forms.py | 4 ++++ django_pydantic_field/v2/rest_framework.py | 23 ++++++++++++++++++++++ 2 files changed, 27 insertions(+) create mode 100644 django_pydantic_field/v2/forms.py create mode 100644 django_pydantic_field/v2/rest_framework.py diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py new file mode 100644 index 0000000..5e4ea0c --- /dev/null +++ b/django_pydantic_field/v2/forms.py @@ -0,0 +1,4 @@ + +class SchemaField: + def __init__(*args, **kwargs): + ... diff --git a/django_pydantic_field/v2/rest_framework.py b/django_pydantic_field/v2/rest_framework.py new file mode 100644 index 0000000..c8bb04c --- /dev/null +++ b/django_pydantic_field/v2/rest_framework.py @@ -0,0 +1,23 @@ +import typing as ty + +from . import types + + +class SchemaField: + def __init__(*args, **kwargs): + ... + + +class AutoSchema: + def __init__(*args, **kwargs): + ... + + +class SchemaParser(ty.Generic[types.ST]): + def __init__(*args, **kwargs): + ... + + +class SchemaRenderer(ty.Generic[types.ST]): + def __init__(*args, **kwargs): + ... From 9db8faac850aa16a89d82187bcf52485c2451266 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Mon, 6 Nov 2023 22:05:55 +0400 Subject: [PATCH 06/34] Adapt SchemaField's transformations for json lookups. - Fixes all e2e tests --- django_pydantic_field/v2/fields.py | 25 +++++++++++++++++++++++-- tests/test_e2e_models.py | 16 +++++----------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 443fd7a..652d58f 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -7,8 +7,9 @@ from django.core import checks, exceptions from django.core.serializers.json import DjangoJSONEncoder -from django.db.models.expressions import BaseExpression +from django.db.models.expressions import BaseExpression, Col from django.db.models.fields.json import JSONField +from django.db.models.lookups import Transform from django.db.models.query_utils import DeferredAttribute from . import types @@ -84,13 +85,19 @@ def get_prep_value(self, value: ty.Any): try: prep_value = self.adapter.validate_python(value, strict=True) - except TypeError: + except pydantic.ValidationError: prep_value = self.adapter.dump_python(value) prep_value = self.adapter.validate_python(prep_value) plain_value = self.adapter.dump_python(prep_value) return super().get_prep_value(plain_value) + def get_transform(self, lookup_name: str): + transform = super().get_transform(lookup_name) + if transform is not None: + transform = SchemaKeyTransformAdapter(transform) + return transform + def get_default(self) -> types.ST: default_value = super().get_default() try: @@ -102,6 +109,20 @@ def get_default(self) -> types.ST: return prep_value +class SchemaKeyTransformAdapter: + """An adapter for creating key transforms for schema field lookups.""" + + def __init__(self, transform: type[Transform]): + self.transform = transform + + def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None: + """All transforms should bypass the SchemaField's adaptaion with `get_prep_value`, + and routed to JSONField's `get_prep_value` for further processing.""" + if isinstance(col, BaseExpression): + col = col.copy() + col.output_field = super(PydanticSchemaField, col.output_field) # type: ignore + return self.transform(col, *args, **kwargs) + @ty.overload def SchemaField(schema: None = None) -> ty.Any: diff --git a/tests/test_e2e_models.py b/tests/test_e2e_models.py index ce5c5e7..c48cf73 100644 --- a/tests/test_e2e_models.py +++ b/tests/test_e2e_models.py @@ -3,8 +3,8 @@ import pytest from django.db.models import F, Q, JSONField, Value -from .conftest import InnerSchema -from .test_app.models import SampleModel +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel pytestmark = [ pytest.mark.usefixtures("available_database_backends"), @@ -17,15 +17,11 @@ [ ( { - "sample_field": InnerSchema( - stub_str="abc", stub_list=[date(2023, 6, 1)] - ), + "sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]), "sample_list": [InnerSchema(stub_str="abc", stub_list=[])], }, { - "sample_field": InnerSchema( - stub_str="abc", stub_list=[date(2023, 6, 1)] - ), + "sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]), "sample_list": [InnerSchema(stub_str="abc", stub_list=[])], }, ), @@ -35,9 +31,7 @@ "sample_list": [{"stub_str": "abc", "stub_list": []}], }, { - "sample_field": InnerSchema( - stub_str="abc", stub_list=[date(2023, 6, 1)] - ), + "sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]), "sample_list": [InnerSchema(stub_str="abc", stub_list=[])], }, ), From e1af3dff475e8d5a5a787aed1aadf1926f97af10 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Mon, 6 Nov 2023 23:46:29 +0400 Subject: [PATCH 07/34] Adapt v2 impl to latest mypy. --- django_pydantic_field/compat/django.py | 2 +- django_pydantic_field/fields.pyi | 22 +++------------------- django_pydantic_field/v2/fields.py | 15 +++++++-------- django_pydantic_field/v2/types.py | 7 +++++-- pyproject.toml | 5 +++-- tests/v1/test_fields.py | 1 + 6 files changed, 20 insertions(+), 32 deletions(-) diff --git a/django_pydantic_field/compat/django.py b/django_pydantic_field/compat/django.py index 682f126..57b99f3 100644 --- a/django_pydantic_field/compat/django.py +++ b/django_pydantic_field/compat/django.py @@ -21,7 +21,7 @@ try: from typing import get_args, get_origin except ImportError: - from typing_extensions import get_args, get_origin + from typing_extensions import get_args, get_origin # type: ignore[no-redef] from django.db.migrations.serializer import BaseSerializer, serializer_factory from django.db.migrations.writer import MigrationWriter diff --git a/django_pydantic_field/fields.pyi b/django_pydantic_field/fields.pyi index 728e2be..f0230b0 100644 --- a/django_pydantic_field/fields.pyi +++ b/django_pydantic_field/fields.pyi @@ -2,11 +2,9 @@ from __future__ import annotations import typing as ty -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from pydantic.dataclasses import DataclassClassOrWrapper -from .compat.pydantic import PYDANTIC_V1, PYDANTIC_V2 - __all__ = ("SchemaField",) SchemaT: ty.TypeAlias = ty.Union[ @@ -20,22 +18,11 @@ SchemaT: ty.TypeAlias = ty.Union[ OptSchemaT: ty.TypeAlias = ty.Optional[SchemaT] ST = ty.TypeVar("ST", bound=SchemaT) -ConfigType: ty.TypeAlias = ty.Any - -if PYDANTIC_V1: - from pydantic import ConfigDict, BaseConfig - - ConfigType = ty.Union[ConfigDict, type[BaseConfig], type] -elif PYDANTIC_V2: - from pydantic import ConfigDict - - ConfigType = ConfigDict - @ty.overload def SchemaField( schema: type[ST | None] | ty.ForwardRef = ..., - config: ConfigType = ..., + config: ConfigDict = ..., default: OptSchemaT | ty.Callable[[], OptSchemaT] = ..., *args, null: ty.Literal[True], @@ -47,13 +34,10 @@ def SchemaField( @ty.overload def SchemaField( schema: type[ST] | ty.ForwardRef = ..., - config: ConfigType = ..., + config: ConfigDict = ..., default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ..., *args, null: ty.Literal[False] = ..., **kwargs, ) -> ST: ... - -def SchemaField(*args, **kwargs) -> ty.Any: - ... diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 652d58f..3b59439 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -27,7 +27,8 @@ def __set__(self, obj, value): class PydanticSchemaField(JSONField, ty.Generic[types.ST]): - descriptor_class = SchemaAttribute + descriptor_class: type[DeferredAttribute] = SchemaAttribute + adapter: types.SchemaAdapter def __init__( self, @@ -76,23 +77,21 @@ def to_python(self, value: ty.Any): try: return self.adapter.validate_python(value) except pydantic.ValidationError as exc: - raise exceptions.ValidationError(exc.title, code="invalid", params=exc.errors()) from exc + error_params = {"errors": exc.errors(), "field": self} + raise exceptions.ValidationError(exc.title, code="invalid", params=error_params) from exc def get_prep_value(self, value: ty.Any): if isinstance(value, BaseExpression): # We don't want to perform coercion on database query expressions. return super().get_prep_value(value) - try: - prep_value = self.adapter.validate_python(value, strict=True) - except pydantic.ValidationError: - prep_value = self.adapter.dump_python(value) - prep_value = self.adapter.validate_python(prep_value) - + prep_value = self.adapter.validate_python(value) plain_value = self.adapter.dump_python(prep_value) + return super().get_prep_value(plain_value) def get_transform(self, lookup_name: str): + transform: type[Transform] | SchemaKeyTransformAdapter | None transform = super().get_transform(lookup_name) if transform is not None: transform = SchemaKeyTransformAdapter(transform) diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 1dfc9fe..9747ec7 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -29,6 +29,7 @@ class ExportKwargs(ty.TypedDict, total=False): strict: bool + from_attributes: bool mode: ty.Literal["json", "python"] include: IncEx | None exclude: IncEx | None @@ -80,11 +81,13 @@ def validate_schema(self) -> None: """Validate the schema and raise an exception if it is invalid.""" self._get_prepared_schema() - def validate_python(self, value: ty.Any, *, strict: bool | None = None) -> ST: + def validate_python(self, value: ty.Any, *, strict: bool | None = None, from_attributes: bool | None = None) -> ST: """Validate the value and raise an exception if it is invalid.""" if strict is None: strict = self.export_kwargs.get("strict", None) - return self.type_adapter.validate_python(value, strict=strict) + if from_attributes is None: + from_attributes = self.export_kwargs.get("from_attributes", None) + return self.type_adapter.validate_python(value, strict=strict, from_attributes=from_attributes) def dump_python(self, value: ty.Any) -> ty.Any: """Dump the value to a Python object.""" diff --git a/pyproject.toml b/pyproject.toml index 1b095f6..11a4a1f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,8 +51,8 @@ dev = [ "mypy", "pytest==7.0.*", "djangorestframework>=3.11,<4", - "django-stubs[compatible-mypy]~=1.12.0", - "djangorestframework-stubs[compatible-mypy]~=1.7.0", + "django-stubs[compatible-mypy]~=4.2", + "djangorestframework-stubs[compatible-mypy]~=3.14", "pytest-django>=4.5,<5", ] test = [ @@ -106,6 +106,7 @@ plugins = [ "mypy_drf_plugin.main" ] exclude = [".env", "tests"] +enable_incomplete_feature = ["Unpack"] [tool.django-stubs] django_settings_module = "tests.settings.django_test_settings" diff --git a/tests/v1/test_fields.py b/tests/v1/test_fields.py index a983dfb..a2fdd9d 100644 --- a/tests/v1/test_fields.py +++ b/tests/v1/test_fields.py @@ -3,6 +3,7 @@ from datetime import date from django.core.exceptions import FieldError +from django.db import models from tests.conftest import InnerSchema from tests.test_app.models import SampleModel From 3d08bf91aede443fd5112e6b2a8d9af590854e9e Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Mon, 6 Nov 2023 23:52:19 +0400 Subject: [PATCH 08/34] Adapt v1 to latest mypy. --- django_pydantic_field/v1/fields.py | 2 +- django_pydantic_field/v1/rest_framework.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/django_pydantic_field/v1/fields.py b/django_pydantic_field/v1/fields.py index 92d43e6..b85041e 100644 --- a/django_pydantic_field/v1/fields.py +++ b/django_pydantic_field/v1/fields.py @@ -43,7 +43,7 @@ def __init__( self, *args, schema: t.Union[t.Type["base.ST"], "GenericContainer", "t.ForwardRef", str, None] = None, - config: "base.ConfigType" = None, + config: t.Optional["base.ConfigType"] = None, **kwargs, ): self.export_params = base.extract_export_kwargs(kwargs, dict.pop) diff --git a/django_pydantic_field/v1/rest_framework.py b/django_pydantic_field/v1/rest_framework.py index 4001dcd..53708fd 100644 --- a/django_pydantic_field/v1/rest_framework.py +++ b/django_pydantic_field/v1/rest_framework.py @@ -94,7 +94,7 @@ def to_internal_value(self, data: t.Any) -> t.Optional["base.ST"]: try: return self.decoder.decode(data) except ValidationError as e: - raise serializers.ValidationError(e.errors(), self.field_name) + raise serializers.ValidationError(e.errors(), self.field_name) # type: ignore[arg-type] def to_representation(self, value: t.Optional["base.ST"]) -> t.Any: obj = self.schema.parse_obj(value) From 417b6697919e87ecafa0cc975f6369e6370224ae Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Mon, 6 Nov 2023 23:58:44 +0400 Subject: [PATCH 09/34] Add missing `uritemplate` after migrating to recent deps. --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 11a4a1f..30258ce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ dependencies = [ ] [project.optional-dependencies] +openapi = ["uritemplate"] dev = [ "black", "isort", @@ -56,6 +57,7 @@ dev = [ "pytest-django>=4.5,<5", ] test = [ + "django_pydantic_field[openapi]", "dj-database-url~=2.0", "djangorestframework>=3,<4", "pyyaml", From bc5e59163c20fabb3e2393591fa70eb01e5161f8 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 7 Nov 2023 23:04:28 +0400 Subject: [PATCH 10/34] Django form field impl for V2. --- django_pydantic_field/v1/__init__.py | 5 + django_pydantic_field/v2/__init__.py | 5 + django_pydantic_field/v2/fields.py | 34 +++--- django_pydantic_field/v2/forms.py | 82 +++++++++++++- django_pydantic_field/v2/types.py | 67 ++++++++---- django_pydantic_field/v2/utils.py | 11 +- .../{test_form_field.py => v1/test_forms.py} | 8 +- tests/v2/__init__.py | 0 tests/v2/test_forms.py | 101 ++++++++++++++++++ 9 files changed, 268 insertions(+), 45 deletions(-) rename tests/{test_form_field.py => v1/test_forms.py} (93%) create mode 100644 tests/v2/__init__.py create mode 100644 tests/v2/test_forms.py diff --git a/django_pydantic_field/v1/__init__.py b/django_pydantic_field/v1/__init__.py index 7746f2c..91323b9 100644 --- a/django_pydantic_field/v1/__init__.py +++ b/django_pydantic_field/v1/__init__.py @@ -1 +1,6 @@ +from ..compat.pydantic import PYDANTIC_V1 + +if not PYDANTIC_V1: + raise ImportError("django_pydantic_field.v1 package is only compatible with Pydantic v1") + from .fields import * diff --git a/django_pydantic_field/v2/__init__.py b/django_pydantic_field/v2/__init__.py index 9e72cfd..91b48f0 100644 --- a/django_pydantic_field/v2/__init__.py +++ b/django_pydantic_field/v2/__init__.py @@ -1 +1,6 @@ +from ..compat.pydantic import PYDANTIC_V2 + +if not PYDANTIC_V2: + raise ImportError("django_pydantic_field.v2 package is only compatible with Pydantic v2") + from .fields import SchemaField as SchemaField diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 3b59439..2b2998c 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -12,16 +12,13 @@ from django.db.models.lookups import Transform from django.db.models.query_utils import DeferredAttribute -from . import types +from . import types, forms, utils from ..compat.django import GenericContainer class SchemaAttribute(DeferredAttribute): field: PydanticSchemaField - def __set_name__(self, owner, name): - self.field.adapter.bind(owner, name) - def __set__(self, obj, value): obj.__dict__[self.field.attname] = self.field.to_python(value) @@ -71,14 +68,14 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: def validate(self, value: ty.Any, model_instance: ty.Any) -> None: value = self.adapter.validate_python(value) - return super().validate(value, model_instance) + return super(JSONField, self).validate(value, model_instance) def to_python(self, value: ty.Any): try: return self.adapter.validate_python(value) except pydantic.ValidationError as exc: error_params = {"errors": exc.errors(), "field": self} - raise exceptions.ValidationError(exc.title, code="invalid", params=error_params) from exc + raise exceptions.ValidationError(exc.json(), code="invalid", params=error_params) from exc def get_prep_value(self, value: ty.Any): if isinstance(value, BaseExpression): @@ -87,7 +84,6 @@ def get_prep_value(self, value: ty.Any): prep_value = self.adapter.validate_python(value) plain_value = self.adapter.dump_python(prep_value) - return super().get_prep_value(plain_value) def get_transform(self, lookup_name: str): @@ -107,6 +103,20 @@ def get_default(self) -> types.ST: return prep_value + def formfield(self, **kwargs): + schema = self.schema + if schema is None: + schema = utils.get_annotated_type(self.model, self.attname) + + field_kwargs = dict( + form_class=forms.SchemaField, + schema=schema, + config=self.config, + **self.export_kwargs, + ) + field_kwargs.update(kwargs) + return super().formfield(**field_kwargs) # type: ignore + class SchemaKeyTransformAdapter: """An adapter for creating key transforms for schema field lookups.""" @@ -123,15 +133,5 @@ def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None: return self.transform(col, *args, **kwargs) -@ty.overload -def SchemaField(schema: None = None) -> ty.Any: - ... - - -@ty.overload -def SchemaField(schema: type[types.ST]) -> ty.Any: - ... - - def SchemaField(schema=None, config=None, *args, **kwargs): # type: ignore return PydanticSchemaField(*args, schema=schema, config=config, **kwargs) diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py index 5e4ea0c..afc59f4 100644 --- a/django_pydantic_field/v2/forms.py +++ b/django_pydantic_field/v2/forms.py @@ -1,4 +1,80 @@ +from __future__ import annotations -class SchemaField: - def __init__(*args, **kwargs): - ... +import typing as ty +from collections import ChainMap +from django.forms import BaseForm, ModelForm + +import pydantic +from django.core.exceptions import ValidationError +from django.forms.fields import JSONField, JSONString, InvalidJSONInput +from django.utils.translation import gettext_lazy as _ + +from . import types, utils + + +class SchemaField(JSONField, ty.Generic[types.ST]): + adapter: types.SchemaAdapter + default_error_messages = { + "schema_error": _("Schema didn't match for %(title)s."), + } + + def __init__(self, schema: types.ST | ty.ForwardRef, config: pydantic.ConfigDict | None = None, *args, **kwargs): + self.schema = schema + self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) + allow_null = None in self.empty_values + self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null=allow_null, **self.export_kwargs) + super().__init__(*args, **kwargs) + + def get_bound_field(self, form: ty.Any, field_name: str): + if not self.adapter.is_bound: + self._bind_schema_adapter(form, field_name) + return super().get_bound_field(form, field_name) + + def bound_data(self, data: ty.Any, initial: ty.Any): + if self.disabled: + return initial + if data is None: + return None + try: + return self.adapter.validate_json(data) + except pydantic.ValidationError: + return InvalidJSONInput(data) + + def to_python(self, value: ty.Any) -> ty.Any: + if self.disabled: + return value + if value in self.empty_values: + return None + elif isinstance(value, JSONString): + return value + try: + converted = self.adapter.validate_json(value) + except pydantic.ValidationError as exc: + error_params = {"value": value, "title": exc.title, "detail": exc.json(), "errors": exc.errors()} + raise ValidationError(self.error_messages["schema_error"], code="invalid", params=error_params) from exc + + if isinstance(converted, str): + return JSONString(converted) + + return converted + + def prepare_value(self, value): + if isinstance(value, InvalidJSONInput): + return value + + value = self.adapter.validate_python(value) + return self.adapter.dump_json(value).decode() + + def has_changed(self, initial: ty.Any | None, data: ty.Any | None) -> bool: + if super(JSONField, self).has_changed(initial, data): + return True + return self.adapter.dump_json(initial) != self.adapter.dump_json(data) + + def _bind_schema_adapter(self, form: BaseForm, field_name: str): + modelns = None + if isinstance(form, ModelForm): + modelns = ChainMap( + utils.get_local_namespace(form._meta.model), + utils.get_global_namespace(form._meta.model), + ) + self.adapter.bind(form, field_name, modelns) diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 9747ec7..1a0e25c 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -1,9 +1,10 @@ from __future__ import annotations -import functools import typing as ty +import typing_extensions as te +from collections import ChainMap -from pydantic.type_adapter import TypeAdapter +import pydantic from . import utils from ..compat.django import GenericContainer @@ -11,15 +12,16 @@ ST = ty.TypeVar("ST", bound="SchemaT") if ty.TYPE_CHECKING: - from pydantic import BaseModel + from collections.abc import MutableMapping + from pydantic.type_adapter import IncEx from pydantic.dataclasses import DataclassClassOrWrapper from django.db.models import Model - ModelType = ty.Type[BaseModel] + ModelType = ty.Type[pydantic.BaseModel] DjangoModelType = ty.Type[Model] SchemaT = ty.Union[ - BaseModel, + pydantic.BaseModel, DataclassClassOrWrapper, ty.Sequence[ty.Any], ty.Mapping[str, ty.Any], @@ -27,7 +29,8 @@ ty.FrozenSet[ty.Any], ] -class ExportKwargs(ty.TypedDict, total=False): + +class ExportKwargs(te.TypedDict, total=False): strict: bool from_attributes: bool mode: ty.Literal["json", "python"] @@ -44,11 +47,11 @@ class ExportKwargs(ty.TypedDict, total=False): class SchemaAdapter(ty.Generic[ST]): def __init__( self, - schema, - config, - parent_type, - attname, - allow_null, + schema: ty.Any, + config: pydantic.ConfigDict | None, + parent_type: type | None, + attname: str | None, + allow_null: bool | None = None, *, parent_depth=4, **export_kwargs: ty.Unpack[ExportKwargs], @@ -60,6 +63,7 @@ def __init__( self.allow_null = allow_null self.parent_depth = parent_depth self.export_kwargs = export_kwargs + self.__namespace: MutableMapping[str, ty.Any] = {} @staticmethod def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: @@ -67,14 +71,19 @@ def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: export_kwargs = {key: kwargs.pop(key) for key in common_keys} return ty.cast(ExportKwargs, export_kwargs) - @functools.cached_property - def type_adapter(self) -> TypeAdapter: + @utils.cached_property + def type_adapter(self) -> pydantic.TypeAdapter: schema = self._get_prepared_schema() - return TypeAdapter(schema, config=self.config, _parent_depth=4) # type: ignore + return pydantic.TypeAdapter(schema, config=self.config, _parent_depth=4) # type: ignore - def bind(self, parent_type, attname): + @property + def is_bound(self) -> bool: + return self.parent_type is not None and self.attname is not None + + def bind(self, parent_type, attname, __namespace: MutableMapping[str, ty.Any] | None = None): self.parent_type = parent_type self.attname = attname + self.__namespace = __namespace if __namespace is not None else {} self.__dict__.pop("type_adapter", None) def validate_schema(self) -> None: @@ -89,10 +98,18 @@ def validate_python(self, value: ty.Any, *, strict: bool | None = None, from_att from_attributes = self.export_kwargs.get("from_attributes", None) return self.type_adapter.validate_python(value, strict=strict, from_attributes=from_attributes) + def validate_json(self, value: str | bytes, *, strict: bool | None = None) -> ST: + if strict is None: + strict = self.export_kwargs.get("strict", None) + return self.type_adapter.validate_json(value, strict=strict) + def dump_python(self, value: ty.Any) -> ty.Any: """Dump the value to a Python object.""" return self.type_adapter.dump_python(value, **self._dump_python_kwargs) + def dump_json(self, value: ty.Any) -> bytes: + return self.type_adapter.dump_json(value, **self._dump_python_kwargs) + def json_schema(self) -> ty.Any: """Return the JSON schema for the field.""" by_alias = self.export_kwargs.get("by_alias", True) @@ -104,17 +121,20 @@ def _get_prepared_schema(self) -> type[ST]: if schema is None: schema = self._guess_schema_from_annotations() if isinstance(schema, GenericContainer): - schema = ty.cast(type[ST], GenericContainer.unwrap(schema)) + schema = ty.cast(ty.Type[ST], GenericContainer.unwrap(schema)) if isinstance(schema, (str, ty.ForwardRef)): schema = self._resolve_schema_forward_ref(schema) if schema is None: - error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}" + if self.parent_type is not None: + error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}" + else: + error_msg = "The adapter is accessed before it was bound to a field" raise ValueError(error_msg) if self.allow_null: schema = ty.Optional[schema] - return ty.cast(type[ST], schema) + return ty.cast(ty.Type[ST], schema) def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | None: return utils.get_annotated_type(self.parent_type, self.attname) @@ -122,10 +142,15 @@ def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | Non def _resolve_schema_forward_ref(self, schema: str | ty.ForwardRef) -> ty.Any: if isinstance(schema, str): schema = ty.ForwardRef(schema) - namespace = utils.get_local_namespace(self.parent_type) - return schema._evaluate(namespace, vars(self.parent_type), frozenset()) # type: ignore - @functools.cached_property + globalns = ChainMap( + self.__namespace, + utils.get_local_namespace(self.parent_type), + utils.get_global_namespace(self.parent_type), + ) + return schema._evaluate(dict(globalns), {}, frozenset()) # type: ignore + + @utils.cached_property def _dump_python_kwargs(self) -> dict[str, ty.Any]: export_kwargs = self.export_kwargs.copy() export_kwargs.pop("strict", None) diff --git a/django_pydantic_field/v2/utils.py b/django_pydantic_field/v2/utils.py index 39b625a..ca005cc 100644 --- a/django_pydantic_field/v2/utils.py +++ b/django_pydantic_field/v2/utils.py @@ -3,6 +3,11 @@ import sys import typing as ty +try: + from functools import cached_property +except ImportError: + from django.utils.functional import cached_property + def get_annotated_type(obj, field, default=None) -> ty.Any: try: @@ -16,9 +21,13 @@ def get_annotated_type(obj, field, default=None) -> ty.Any: return default -def get_local_namespace(cls) -> dict[str, ty.Any]: +def get_global_namespace(cls) -> dict[str, ty.Any]: try: module = cls.__module__ return vars(sys.modules[module]) except (KeyError, AttributeError): return {} + + +def get_local_namespace(cls) -> dict[str, ty.Any]: + return dict(vars(cls)) diff --git a/tests/test_form_field.py b/tests/v1/test_forms.py similarity index 93% rename from tests/test_form_field.py rename to tests/v1/test_forms.py index cadcc6a..a816d25 100644 --- a/tests/test_form_field.py +++ b/tests/v1/test_forms.py @@ -4,10 +4,12 @@ import pytest from django.core.exceptions import ValidationError from django.forms import Form, modelform_factory -from django_pydantic_field import fields, forms -from .conftest import InnerSchema -from .test_app.models import SampleForwardRefModel, SampleSchema +from tests.conftest import InnerSchema +from tests.test_app.models import SampleForwardRefModel, SampleSchema + +fields = pytest.importorskip("django_pydantic_field.v1.fields") +forms = pytest.importorskip("django_pydantic_field.v1.forms") class SampleForm(Form): diff --git a/tests/v2/__init__.py b/tests/v2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/v2/test_forms.py b/tests/v2/test_forms.py new file mode 100644 index 0000000..4a1081e --- /dev/null +++ b/tests/v2/test_forms.py @@ -0,0 +1,101 @@ +import typing as ty + +import django +import pytest +from django.core.exceptions import ValidationError +from django.forms import Form, modelform_factory + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleForwardRefModel, SampleSchema + +fields = pytest.importorskip("django_pydantic_field.v2.fields") +forms = pytest.importorskip("django_pydantic_field.v2.forms") + + +class SampleForm(Form): + field = forms.SchemaField(ty.ForwardRef("SampleSchema")) + + +def test_form_schema_field(): + field = forms.SchemaField(InnerSchema) + + cleaned_data = field.clean('{"stub_str": "abc", "stub_list": ["1970-01-01"]}') + assert cleaned_data == InnerSchema.model_validate({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + + +def test_empty_form_values(): + field = forms.SchemaField(InnerSchema, required=False) + assert field.clean("") is None + assert field.clean(None) is None + + +def test_prepare_value(): + field = forms.SchemaField(InnerSchema, required=False) + expected = '{"stub_str":"abc","stub_int":1,"stub_list":["1970-01-01"]}' + assert expected == field.prepare_value({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + + +def test_empty_required_raises(): + field = forms.SchemaField(InnerSchema) + with pytest.raises(ValidationError) as e: + field.clean("") + + assert e.match("This field is required") + + +def test_invalid_schema_raises(): + field = forms.SchemaField(InnerSchema) + with pytest.raises(ValidationError) as e: + field.clean('{"stub_list": "abc"}') + + assert e.match("Schema didn't match for") + assert "stub_list" in e.value.params["detail"] # type: ignore + assert "stub_str" in e.value.params["detail"] # type: ignore + + +def test_invalid_json_raises(): + field = forms.SchemaField(InnerSchema) + with pytest.raises(ValidationError) as e: + field.clean('{"stub_list": "abc}') + + assert e.match("Schema didn't match for") + assert '"type":"json_invalid"' in e.value.params["detail"] # type: ignore + + +@pytest.mark.xfail( + django.VERSION[:2] < (4, 0), + reason="Django < 4 has it's own feeling on bound fields resolution", +) +def test_forwardref_field(): + form = SampleForm(data={"field": '{"field": "2"}'}) + assert form.is_valid() + + +def test_model_formfield(): + field = fields.PydanticSchemaField(schema=InnerSchema) + assert isinstance(field.formfield(), forms.SchemaField) + + +def test_forwardref_model_formfield(): + form_cls = modelform_factory(SampleForwardRefModel, exclude=("field",)) + form = form_cls(data={"annotated_field": '{"field": "2"}'}) + + assert form.is_valid(), form.errors + cleaned_data = form.cleaned_data + + assert cleaned_data is not None + assert cleaned_data["annotated_field"] == SampleSchema(field=2) + + +@pytest.mark.parametrize("export_kwargs", [ + {"include": {"stub_str", "stub_int"}}, + {"exclude": {"stub_list"}}, + {"exclude_unset": True}, + {"exclude_defaults": True}, + {"exclude_none": True}, + {"by_alias": True}, +]) +def test_form_field_export_kwargs(export_kwargs): + field = forms.SchemaField(InnerSchema, required=False, **export_kwargs) + value = InnerSchema.model_validate({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + assert field.prepare_value(value) From 672c2c56351e38e993ba0f9f0a20f356f9bbc01e Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 8 Nov 2023 02:25:25 +0400 Subject: [PATCH 11/34] Some compatibility for py3.7 and py3.8 [skip ci] --- django_pydantic_field/v2/types.py | 2 +- django_pydantic_field/v2/utils.py | 23 ++++++++++++++++++++--- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 1a0e25c..359c628 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -148,7 +148,7 @@ def _resolve_schema_forward_ref(self, schema: str | ty.ForwardRef) -> ty.Any: utils.get_local_namespace(self.parent_type), utils.get_global_namespace(self.parent_type), ) - return schema._evaluate(dict(globalns), {}, frozenset()) # type: ignore + return utils.evaluate_forward_ref(schema, globalns) @utils.cached_property def _dump_python_kwargs(self) -> dict[str, ty.Any]: diff --git a/django_pydantic_field/v2/utils.py b/django_pydantic_field/v2/utils.py index ca005cc..3d076be 100644 --- a/django_pydantic_field/v2/utils.py +++ b/django_pydantic_field/v2/utils.py @@ -4,9 +4,12 @@ import typing as ty try: - from functools import cached_property + from functools import cached_property as cached_property except ImportError: - from django.utils.functional import cached_property + from django.utils.functional import cached_property as cached_property + +if ty.TYPE_CHECKING: + from collections.abc import Mapping def get_annotated_type(obj, field, default=None) -> ty.Any: @@ -30,4 +33,18 @@ def get_global_namespace(cls) -> dict[str, ty.Any]: def get_local_namespace(cls) -> dict[str, ty.Any]: - return dict(vars(cls)) + try: + return dict(vars(cls)) + except TypeError: + return {} + + +if sys.version_info >= (3, 9): + + def evaluate_forward_ref(ref: ty.ForwardRef, ns: Mapping[str, ty.Any]) -> ty.Any: + return ref._evaluate(dict(ns), {}, frozenset()) + +else: + + def evaluate_forward_ref(ref: ty.ForwardRef, ns: Mapping[str, ty.Any]) -> ty.Any: + return ref._evaluate(dict(ns), {}) From 39fde666daa1755a76a269b9da3e53821d9656e4 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 8 Nov 2023 23:53:21 +0400 Subject: [PATCH 12/34] Make all form fields tests homogenic with v1. [skip ci] --- django_pydantic_field/v2/fields.py | 10 ++++------ django_pydantic_field/v2/forms.py | 32 +++++++++++++++--------------- django_pydantic_field/v2/types.py | 25 +++++++++-------------- django_pydantic_field/v2/utils.py | 9 +++++++-- 4 files changed, 36 insertions(+), 40 deletions(-) diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 2b2998c..c904b6a 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -12,7 +12,7 @@ from django.db.models.lookups import Transform from django.db.models.query_utils import DeferredAttribute -from . import types, forms, utils +from . import types, forms from ..compat.django import GenericContainer @@ -104,13 +104,11 @@ def get_default(self) -> types.ST: return prep_value def formfield(self, **kwargs): - schema = self.schema - if schema is None: - schema = utils.get_annotated_type(self.model, self.attname) - field_kwargs = dict( form_class=forms.SchemaField, - schema=schema, + # Trying to resolve the schema before passing it to the formfield, since in Django < 4.0, + # formfield is unbound during form validation and is not able to resolve forward refs defined in the model. + schema=self.adapter.prepared_schema, config=self.config, **self.export_kwargs, ) diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py index afc59f4..573d35b 100644 --- a/django_pydantic_field/v2/forms.py +++ b/django_pydantic_field/v2/forms.py @@ -2,15 +2,17 @@ import typing as ty from collections import ChainMap -from django.forms import BaseForm, ModelForm import pydantic from django.core.exceptions import ValidationError -from django.forms.fields import JSONField, JSONString, InvalidJSONInput +from django.forms.fields import InvalidJSONInput, JSONField, JSONString from django.utils.translation import gettext_lazy as _ from . import types, utils +if ty.TYPE_CHECKING: + from django.forms import BaseForm + class SchemaField(JSONField, ty.Generic[types.ST]): adapter: types.SchemaAdapter @@ -18,21 +20,28 @@ class SchemaField(JSONField, ty.Generic[types.ST]): "schema_error": _("Schema didn't match for %(title)s."), } - def __init__(self, schema: types.ST | ty.ForwardRef, config: pydantic.ConfigDict | None = None, *args, **kwargs): + def __init__( + self, + schema: types.ST, + config: pydantic.ConfigDict | None = None, + allow_null: bool | None = None, + *args, + **kwargs, + ): self.schema = schema + self.config = config self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) - allow_null = None in self.empty_values - self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null=allow_null, **self.export_kwargs) + self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null, **self.export_kwargs) super().__init__(*args, **kwargs) def get_bound_field(self, form: ty.Any, field_name: str): if not self.adapter.is_bound: - self._bind_schema_adapter(form, field_name) + self.adapter.bind(form, field_name) return super().get_bound_field(form, field_name) def bound_data(self, data: ty.Any, initial: ty.Any): if self.disabled: - return initial + return self.adapter.validate_python(initial) if data is None: return None try: @@ -69,12 +78,3 @@ def has_changed(self, initial: ty.Any | None, data: ty.Any | None) -> bool: if super(JSONField, self).has_changed(initial, data): return True return self.adapter.dump_json(initial) != self.adapter.dump_json(data) - - def _bind_schema_adapter(self, form: BaseForm, field_name: str): - modelns = None - if isinstance(form, ModelForm): - modelns = ChainMap( - utils.get_local_namespace(form._meta.model), - utils.get_global_namespace(form._meta.model), - ) - self.adapter.bind(form, field_name, modelns) diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 359c628..43ce115 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -12,7 +12,7 @@ ST = ty.TypeVar("ST", bound="SchemaT") if ty.TYPE_CHECKING: - from collections.abc import MutableMapping + from collections.abc import Mapping from pydantic.type_adapter import IncEx from pydantic.dataclasses import DataclassClassOrWrapper @@ -52,8 +52,6 @@ def __init__( parent_type: type | None, attname: str | None, allow_null: bool | None = None, - *, - parent_depth=4, **export_kwargs: ty.Unpack[ExportKwargs], ): self.schema = schema @@ -61,9 +59,7 @@ def __init__( self.parent_type = parent_type self.attname = attname self.allow_null = allow_null - self.parent_depth = parent_depth self.export_kwargs = export_kwargs - self.__namespace: MutableMapping[str, ty.Any] = {} @staticmethod def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: @@ -73,22 +69,21 @@ def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: @utils.cached_property def type_adapter(self) -> pydantic.TypeAdapter: - schema = self._get_prepared_schema() - return pydantic.TypeAdapter(schema, config=self.config, _parent_depth=4) # type: ignore + return pydantic.TypeAdapter(self.prepared_schema, config=self.config) # type: ignore @property def is_bound(self) -> bool: return self.parent_type is not None and self.attname is not None - def bind(self, parent_type, attname, __namespace: MutableMapping[str, ty.Any] | None = None): + def bind(self, parent_type: type, attname: str) -> None: self.parent_type = parent_type self.attname = attname - self.__namespace = __namespace if __namespace is not None else {} + self.__dict__.pop("prepared_schema", None) self.__dict__.pop("type_adapter", None) def validate_schema(self) -> None: """Validate the schema and raise an exception if it is invalid.""" - self._get_prepared_schema() + self.prepared_schema() def validate_python(self, value: ty.Any, *, strict: bool | None = None, from_attributes: bool | None = None) -> ST: """Validate the value and raise an exception if it is invalid.""" @@ -115,7 +110,8 @@ def json_schema(self) -> ty.Any: by_alias = self.export_kwargs.get("by_alias", True) return self.type_adapter.json_schema(by_alias=by_alias) - def _get_prepared_schema(self) -> type[ST]: + @utils.cached_property + def prepared_schema(self) -> type[ST]: schema = self.schema if schema is None: @@ -134,6 +130,7 @@ def _get_prepared_schema(self) -> type[ST]: if self.allow_null: schema = ty.Optional[schema] + return ty.cast(ty.Type[ST], schema) def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | None: @@ -143,11 +140,7 @@ def _resolve_schema_forward_ref(self, schema: str | ty.ForwardRef) -> ty.Any: if isinstance(schema, str): schema = ty.ForwardRef(schema) - globalns = ChainMap( - self.__namespace, - utils.get_local_namespace(self.parent_type), - utils.get_global_namespace(self.parent_type), - ) + globalns = utils.get_namespace(self.parent_type) return utils.evaluate_forward_ref(schema, globalns) @utils.cached_property diff --git a/django_pydantic_field/v2/utils.py b/django_pydantic_field/v2/utils.py index 3d076be..2f2f890 100644 --- a/django_pydantic_field/v2/utils.py +++ b/django_pydantic_field/v2/utils.py @@ -2,11 +2,12 @@ import sys import typing as ty +from collections import ChainMap try: from functools import cached_property as cached_property except ImportError: - from django.utils.functional import cached_property as cached_property + from django.utils.functional import cached_property as cached_property # type: ignore if ty.TYPE_CHECKING: from collections.abc import Mapping @@ -24,6 +25,10 @@ def get_annotated_type(obj, field, default=None) -> ty.Any: return default +def get_namespace(cls) -> ChainMap[str, ty.Any]: + return ChainMap(get_local_namespace(cls), get_global_namespace(cls)) + + def get_global_namespace(cls) -> dict[str, ty.Any]: try: module = cls.__module__ @@ -34,7 +39,7 @@ def get_global_namespace(cls) -> dict[str, ty.Any]: def get_local_namespace(cls) -> dict[str, ty.Any]: try: - return dict(vars(cls)) + return vars(cls) except TypeError: return {} From 776c9a1ddcc23e68ffbe816548b591d1983806d1 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Thu, 9 Nov 2023 01:28:20 +0400 Subject: [PATCH 13/34] Implement REST framework schema field. --- django_pydantic_field/v2/rest_framework.py | 53 +++++++++++++++++++--- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/django_pydantic_field/v2/rest_framework.py b/django_pydantic_field/v2/rest_framework.py index c8bb04c..7d88dd9 100644 --- a/django_pydantic_field/v2/rest_framework.py +++ b/django_pydantic_field/v2/rest_framework.py @@ -1,23 +1,62 @@ +from __future__ import annotations + import typing as ty -from . import types +import pydantic +from rest_framework import exceptions, fields, parsers, renderers +from . import types -class SchemaField: - def __init__(*args, **kwargs): - ... +if ty.TYPE_CHECKING: + from rest_framework.serializers import BaseSerializer + + +class SchemaField(fields.Field, ty.Generic[types.ST]): + def __init__( + self, + schema: type[types.ST], + config: pydantic.ConfigDict | None = None, + *args, + allow_null: bool = False, + **kwargs, + ): + self.schema = schema + self.config = config + self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) + self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null, **self.export_kwargs) + super().__init__(*args, **kwargs) + + def bind(self, field_name: str, parent: BaseSerializer): + if not self.adapter.is_bound: + self.adapter.bind(type(parent), field_name) + super().bind(field_name, parent) + + def to_internal_value(self, data: ty.Any): + try: + if isinstance(data, (str, bytes)): + return self.adapter.validate_json(data) + return self.adapter.validate_python(data) + except pydantic.ValidationError as exc: + raise exceptions.ValidationError(exc.errors(), code="invalid") # type: ignore + + def to_representation(self, value: ty.Optional[types.ST]): + try: + prep_value = self.adapter.validate_python(value) + return self.adapter.dump_python(prep_value) + except pydantic.ValidationError as exc: + raise exceptions.ValidationError(exc.errors(), code="invalid") # type: ignore -class AutoSchema: +class SchemaParser(ty.Generic[types.ST]): def __init__(*args, **kwargs): ... -class SchemaParser(ty.Generic[types.ST]): +class SchemaRenderer(ty.Generic[types.ST]): def __init__(*args, **kwargs): ... -class SchemaRenderer(ty.Generic[types.ST]): +class AutoSchema: def __init__(*args, **kwargs): ... From 49266295e86eaecd149c3c7df6c9b5b41d1f326b Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Thu, 9 Nov 2023 02:24:47 +0400 Subject: [PATCH 14/34] Implement REST framework schema parser in renderer for v2. --- django_pydantic_field/v2/rest_framework.py | 89 ++++++- django_pydantic_field/v2/types.py | 6 +- django_pydantic_field/v2/utils.py | 5 + .../test_rest_framework.py} | 8 +- tests/v2/test_rest_framework.py | 217 ++++++++++++++++++ 5 files changed, 311 insertions(+), 14 deletions(-) rename tests/{test_drf_adapters.py => v1/test_rest_framework.py} (98%) create mode 100644 tests/v2/test_rest_framework.py diff --git a/django_pydantic_field/v2/rest_framework.py b/django_pydantic_field/v2/rest_framework.py index 7d88dd9..4979992 100644 --- a/django_pydantic_field/v2/rest_framework.py +++ b/django_pydantic_field/v2/rest_framework.py @@ -5,13 +5,18 @@ import pydantic from rest_framework import exceptions, fields, parsers, renderers -from . import types +from . import types, utils if ty.TYPE_CHECKING: + from collections.abc import Mapping from rest_framework.serializers import BaseSerializer + RequestResponseContext = Mapping[str, ty.Any] + class SchemaField(fields.Field, ty.Generic[types.ST]): + adapter: types.SchemaAdapter + def __init__( self, schema: type[types.ST], @@ -47,14 +52,84 @@ def to_representation(self, value: ty.Optional[types.ST]): raise exceptions.ValidationError(exc.errors(), code="invalid") # type: ignore -class SchemaParser(ty.Generic[types.ST]): - def __init__(*args, **kwargs): - ... +class _AnnotatedAdapterMixin(ty.Generic[types.ST]): + schema_context_key: ty.ClassVar[str] = "response_schema" + config_context_key: ty.ClassVar[str] = "response_schema_config" + def get_adapter(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: + adapter = self._make_adapter_from_context(ctx) + if adapter is None: + adapter = self._make_adapter_from_annotation(ctx) -class SchemaRenderer(ty.Generic[types.ST]): - def __init__(*args, **kwargs): - ... + return adapter + + def _make_adapter_from_context(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: + schema = ctx.get(self.schema_context_key) + if schema is not None: + config = ctx.get(self.config_context_key) + export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(ctx)) + return types.SchemaAdapter(schema, config, type(ctx.get("view")), None, **export_kwargs) + + return schema + + def _make_adapter_from_annotation(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: + try: + schema = utils.get_args(self.__orig_class__)[0] # type: ignore + except (AttributeError, IndexError): + return None + + config = ctx.get(self.config_context_key) + export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(ctx)) + return types.SchemaAdapter(schema, config, type(ctx.get("view")), None, **export_kwargs) + + +class SchemaRenderer(_AnnotatedAdapterMixin[types.ST], renderers.JSONRenderer): + schema_context_key = "renderer_schema" + config_context_key = "renderer_schema_config" + + def render(self, data: ty.Any, accepted_media_type=None, renderer_context=None): + renderer_context = renderer_context or {} + response = renderer_context.get("response") + if response is not None and response.exception: + return super().render(data, accepted_media_type, renderer_context) + + adapter = self.get_adapter(renderer_context) + if adapter is None and isinstance(data, pydantic.BaseModel): + return self.render_pydantic_model(data, renderer_context) + if adapter is None: + raise RuntimeError("Schema should be either explicitly set with annotation or passed in the context") + + try: + prep_data = adapter.validate_python(data) + return adapter.dump_json(prep_data) + except pydantic.ValidationError as exc: + return exc.json(indent=True, include_input=True).encode() + + def render_pydantic_model(self, instance: pydantic.BaseModel, renderer_context: Mapping[str, ty.Any]): + export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(renderer_context)) + export_kwargs.pop("strict", None) + export_kwargs.pop("from_attributes", None) + export_kwargs.pop("mode", None) + + json_dump = instance.model_dump_json(**export_kwargs) # type: ignore + return json_dump.encode() + + +class SchemaParser(_AnnotatedAdapterMixin[types.ST], parsers.JSONParser): + schema_context_key = "parser_schema" + config_context_key = "parser_schema_config" + renderer_class = SchemaRenderer + + def parse(self, stream: ty.IO[bytes], media_type=None, parser_context=None): + parser_context = parser_context or {} + adapter = self.get_adapter(parser_context) + if adapter is None: + raise RuntimeError("Schema should be either explicitly set with annotation or passed in the context") + + try: + return adapter.validate_json(stream.read()) + except pydantic.ValidationError as exc: + raise exceptions.ParseError(exc.errors()) # type: ignore class AutoSchema: diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 43ce115..55e74aa 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -114,7 +114,7 @@ def json_schema(self) -> ty.Any: def prepared_schema(self) -> type[ST]: schema = self.schema - if schema is None: + if schema is None and self.attname is not None: schema = self._guess_schema_from_annotations() if isinstance(schema, GenericContainer): schema = ty.cast(ty.Type[ST], GenericContainer.unwrap(schema)) @@ -122,10 +122,10 @@ def prepared_schema(self) -> type[ST]: schema = self._resolve_schema_forward_ref(schema) if schema is None: - if self.parent_type is not None: + if self.parent_type is not None and self.attname is not None: error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}" else: - error_msg = "The adapter is accessed before it was bound to a field" + error_msg = "The adapter is accessed before it was bound" raise ValueError(error_msg) if self.allow_null: diff --git a/django_pydantic_field/v2/utils.py b/django_pydantic_field/v2/utils.py index 2f2f890..1f3fd6a 100644 --- a/django_pydantic_field/v2/utils.py +++ b/django_pydantic_field/v2/utils.py @@ -9,6 +9,11 @@ except ImportError: from django.utils.functional import cached_property as cached_property # type: ignore +try: + from typing import get_args as get_args +except ImportError: + from typing_extensions import get_args as get_args # type: ignore + if ty.TYPE_CHECKING: from collections.abc import Mapping diff --git a/tests/test_drf_adapters.py b/tests/v1/test_rest_framework.py similarity index 98% rename from tests/test_drf_adapters.py rename to tests/v1/test_rest_framework.py index c1f04c6..f23bbdf 100644 --- a/tests/test_drf_adapters.py +++ b/tests/v1/test_rest_framework.py @@ -1,18 +1,18 @@ import io -import json import typing as t from datetime import date import pytest import yaml from django.urls import path -from django_pydantic_field import rest_framework from rest_framework import exceptions, generics, schemas, serializers, views from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema from rest_framework.response import Response -from .conftest import InnerSchema -from .test_app.models import SampleModel +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +rest_framework = pytest.importorskip("django_pydantic_field.v1.rest_framework") class SampleSerializer(serializers.Serializer): diff --git a/tests/v2/test_rest_framework.py b/tests/v2/test_rest_framework.py new file mode 100644 index 0000000..8223fc5 --- /dev/null +++ b/tests/v2/test_rest_framework.py @@ -0,0 +1,217 @@ +import io +import typing as t +from datetime import date + +import pytest +from rest_framework import exceptions, generics, serializers, views +from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema +from rest_framework.response import Response + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +class SampleSerializer(serializers.Serializer): + field = rest_framework.SchemaField(schema=t.List[InnerSchema]) + + +class SampleModelSerializer(serializers.ModelSerializer): + sample_field = rest_framework.SchemaField(schema=InnerSchema) + sample_list = rest_framework.SchemaField(schema=t.List[InnerSchema]) + sample_seq = rest_framework.SchemaField(schema=t.List[InnerSchema], default=list) + + class Meta: + model = SampleModel + fields = "sample_field", "sample_list", "sample_seq" + + +def test_schema_field(): + field = rest_framework.SchemaField(InnerSchema) + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = { + "stub_str": "abc", + "stub_int": 1, + "stub_list": [date(2022, 7, 1)], + } + + assert field.to_representation(existing_instance) == expected_encoded + assert field.to_internal_value(expected_encoded) == existing_instance + + with pytest.raises(serializers.ValidationError): + field.to_internal_value(None) + + with pytest.raises(serializers.ValidationError): + field.to_internal_value("null") + + +def test_field_schema_with_custom_config(): + field = rest_framework.SchemaField(InnerSchema, allow_null=True, exclude={"stub_int"}) + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + + assert field.to_representation(existing_instance) == expected_encoded + assert field.to_internal_value(expected_encoded) == existing_instance + assert field.to_internal_value(None) is None + assert field.to_internal_value("null") is None + + +def test_serializer_marshalling_with_schema_field(): + existing_instance = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]} + expected_data = {"field": [{"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}]} + + serializer = SampleSerializer(instance=existing_instance) + assert serializer.data == expected_data + + serializer = SampleSerializer(data=expected_data) + serializer.is_valid(raise_exception=True) + assert serializer.validated_data == existing_instance + + +def test_model_serializer_marshalling_with_schema_field(): + instance = SampleModel( + sample_field=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), + sample_list=[InnerSchema(stub_str="abc", stub_int=2, stub_list=[date(2022, 7, 1)])] * 2, + sample_seq=[InnerSchema(stub_str="abc", stub_int=3, stub_list=[date(2022, 7, 1)])] * 3, + ) + serializer = SampleModelSerializer(instance) + + expected_data = { + "sample_field": {"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}, + "sample_list": [{"stub_str": "abc", "stub_int": 2, "stub_list": [date(2022, 7, 1)]}] * 2, + "sample_seq": [{"stub_str": "abc", "stub_int": 3, "stub_list": [date(2022, 7, 1)]}] * 3, + } + assert serializer.data == expected_data + + +@pytest.mark.parametrize( + "export_kwargs", + [ + {"include": {"stub_str", "stub_int"}}, + {"exclude": {"stub_list"}}, + {"exclude_unset": True}, + {"exclude_defaults": True}, + {"exclude_none": True}, + {"by_alias": True}, + ], +) +def test_field_export_kwargs(export_kwargs): + field = rest_framework.SchemaField(InnerSchema, **export_kwargs) + assert field.to_representation(InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])) + + +def test_invalid_data_serialization(): + invalid_data = {"field": [{"stub_int": "abc", "stub_list": ["abc"]}]} + serializer = SampleSerializer(data=invalid_data) + + with pytest.raises(exceptions.ValidationError) as e: + serializer.is_valid(raise_exception=True) + + assert e.match(r".*stub_str.*stub_int.*stub_list.*") + + +def test_schema_renderer(): + renderer = rest_framework.SchemaRenderer() + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + assert renderer.render(existing_instance) == expected_encoded + + +def test_typed_schema_renderer(): + renderer = rest_framework.SchemaRenderer[InnerSchema]() + existing_data = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + assert renderer.render(existing_data) == expected_encoded + + +def test_schema_parser(): + parser = rest_framework.SchemaParser[InnerSchema]() + existing_encoded = '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}' + expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + + assert parser.parse(io.StringIO(existing_encoded)) == expected_instance + + +@api_view(["POST"]) +@schema(rest_framework.AutoSchema()) +@parser_classes([rest_framework.SchemaParser[InnerSchema]]) +@renderer_classes([rest_framework.SchemaRenderer[t.List[InnerSchema]]]) +def sample_view(request): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + +class ClassBasedViewWithSerializer(generics.RetrieveAPIView): + serializer_class = SampleSerializer + schema = rest_framework.AutoSchema() + + +class ClassBasedViewWithModel(generics.ListCreateAPIView): + queryset = SampleModel.objects.all() + serializer_class = SampleModelSerializer + + +class ClassBasedView(views.APIView): + parser_classes = [rest_framework.SchemaParser[InnerSchema]] + renderer_classes = [rest_framework.SchemaRenderer[t.List[InnerSchema]]] + + def post(self, request, *args, **kwargs): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + +class ClassBasedViewWithSchemaContext(ClassBasedView): + parser_classes = [rest_framework.SchemaParser] + renderer_classes = [rest_framework.SchemaRenderer] + + def get_renderer_context(self): + ctx = super().get_renderer_context() + return dict(ctx, renderer_schema=t.List[InnerSchema]) + + def get_parser_context(self, http_request): + ctx = super().get_parser_context(http_request) + return dict(ctx, parser_schema=InnerSchema) + + +@pytest.mark.parametrize( + "view", + [ + sample_view, + ClassBasedView.as_view(), + ClassBasedViewWithSchemaContext.as_view(), + ], +) +def test_end_to_end_api_view(view, request_factory): + expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + existing_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + request = request_factory.post("/", existing_encoded, content_type="application/json") + response = view(request) + + assert response.data == [expected_instance] + assert response.data[0] is not expected_instance + + assert response.rendered_content == b"[%s]" % existing_encoded + + +@pytest.mark.django_db +def test_end_to_end_list_create_api_view(request_factory): + field_data = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]).json() + expected_result = { + "sample_field": {"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}, + "sample_list": [{"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}], + "sample_seq": [], + } + + payload = '{"sample_field": %s, "sample_list": [%s], "sample_seq": []}' % ((field_data,) * 2) + request = request_factory.post("/", payload.encode(), content_type="application/json") + response = ClassBasedViewWithModel.as_view()(request) + + assert response.data == expected_result + + request = request_factory.get("/", content_type="application/json") + response = ClassBasedViewWithModel.as_view()(request) + assert response.data == [expected_result] From da0dfb9970088ba69f1138748a3b655bf278b3a2 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Sat, 18 Nov 2023 01:23:32 +0400 Subject: [PATCH 15/34] Refine some typings and error handling in SchemaAdapter [ci skip] --- .editorconfig | 5 +++ Makefile | 14 ++---- django_pydantic_field/compat/django.py | 25 +++++------ django_pydantic_field/compat/functools.py | 4 ++ django_pydantic_field/compat/typing.py | 6 +++ django_pydantic_field/v1/rest_framework.py | 6 +-- django_pydantic_field/v2/fields.py | 2 +- django_pydantic_field/v2/forms.py | 6 +-- django_pydantic_field/v2/rest_framework.py | 5 ++- django_pydantic_field/v2/types.py | 50 +++++++++++++--------- django_pydantic_field/v2/utils.py | 10 ----- 11 files changed, 64 insertions(+), 69 deletions(-) create mode 100644 django_pydantic_field/compat/functools.py create mode 100644 django_pydantic_field/compat/typing.py diff --git a/.editorconfig b/.editorconfig index ccd4e7f..429639c 100644 --- a/.editorconfig +++ b/.editorconfig @@ -11,6 +11,11 @@ indent_size = 4 [*.py] max_line_length = 120 +[Makefile] +indent_style = tab +indent_size = 4 + + [{*.json,*.yml,*.yaml}] indent_size = 2 insert_final_newline = false diff --git a/Makefile b/Makefile index 2d52756..88d7574 100644 --- a/Makefile +++ b/Makefile @@ -1,36 +1,28 @@ +.PHONY: install build test lint upload upload-test clean -.PHONY: install install: python3 -m pip install build twine python3 -m pip install -e .[dev,test] - -.PHONY: build build: python3 -m build +migrations: + DJANGO_SETTINGS_MODULE="tests.settings.django_test_settings" python3 -m django makemigrations --noinput -.PHONY: test test: A= test: pytest $(A) -.PHONY: lint lint: A=. lint: mypy $(A) - -.PHONY: upload upload: python3 -m twine upload dist/* - -.PHONY: upload-test upload-test: python3 -m twine upload --repository testpypi dist/* - -.PHONY: clean clean: rm -rf dist/* diff --git a/django_pydantic_field/compat/django.py b/django_pydantic_field/compat/django.py index 57b99f3..304a64c 100644 --- a/django_pydantic_field/compat/django.py +++ b/django_pydantic_field/compat/django.py @@ -16,16 +16,13 @@ """ import sys import types -import typing as t - -try: - from typing import get_args, get_origin -except ImportError: - from typing_extensions import get_args, get_origin # type: ignore[no-redef] +import typing as ty from django.db.migrations.serializer import BaseSerializer, serializer_factory from django.db.migrations.writer import MigrationWriter +from .typing import get_args, get_origin + class GenericContainer: __slots__ = "origin", "args" @@ -107,21 +104,21 @@ def serialize(self): if sys.version_info >= (3, 9): GenericAlias = types.GenericAlias - GenericTypes: t.Tuple[t.Any, ...] = ( + GenericTypes: ty.Tuple[ty.Any, ...] = ( GenericAlias, - type(t.List[int]), - type(t.List), + type(ty.List[int]), + type(ty.List), ) else: # types.GenericAlias is missing, meaning python version < 3.9, # which has a different inheritance models for typed generics - GenericAlias = type(t.List[int]) # noqa - GenericTypes = GenericAlias, type(t.List) # noqa + GenericAlias = type(ty.List[int]) # noqa + GenericTypes = GenericAlias, type(ty.List) # noqa MigrationWriter.register_serializer(GenericContainer, GenericSerializer) -MigrationWriter.register_serializer(t.ForwardRef, TypingSerializer) -MigrationWriter.register_serializer(type(t.Union), TypingSerializer) # type: ignore +MigrationWriter.register_serializer(ty.ForwardRef, TypingSerializer) +MigrationWriter.register_serializer(type(ty.Union), TypingSerializer) # type: ignore if sys.version_info >= (3, 10): @@ -132,7 +129,7 @@ class UnionTypeSerializer(BaseSerializer): def serialize(self): imports = set() - if isinstance(self.value, type(t.Union)): # type: ignore + if isinstance(self.value, type(ty.Union)): # type: ignore imports.add("import typing") for arg in get_args(self.value): diff --git a/django_pydantic_field/compat/functools.py b/django_pydantic_field/compat/functools.py new file mode 100644 index 0000000..d89be7b --- /dev/null +++ b/django_pydantic_field/compat/functools.py @@ -0,0 +1,4 @@ +try: + from functools import cached_property as cached_property +except ImportError: + from django.utils.functional import cached_property as cached_property # type: ignore diff --git a/django_pydantic_field/compat/typing.py b/django_pydantic_field/compat/typing.py new file mode 100644 index 0000000..fd0865e --- /dev/null +++ b/django_pydantic_field/compat/typing.py @@ -0,0 +1,6 @@ +try: + from typing import get_args as get_args + from typing import get_origin as get_origin +except ImportError: + from typing_extensions import get_args as get_args # type: ignore + from typing_extensions import get_origin as get_origin # type: ignore diff --git a/django_pydantic_field/v1/rest_framework.py b/django_pydantic_field/v1/rest_framework.py index 53708fd..234f8b9 100644 --- a/django_pydantic_field/v1/rest_framework.py +++ b/django_pydantic_field/v1/rest_framework.py @@ -1,10 +1,5 @@ import typing as t -try: - from typing import get_args -except ImportError: - from typing_extensions import get_args - from django.conf import settings from pydantic import BaseModel, ValidationError @@ -13,6 +8,7 @@ from rest_framework.schemas.utils import is_list_view from . import base +from ..compat.typing import get_args __all__ = ( "SchemaField", diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index c904b6a..beafb1a 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -62,7 +62,7 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: performed_checks = super().check(**kwargs) try: self.adapter.validate_schema() - except ValueError as exc: + except types.ImproperlyConfiguredSchema as exc: performed_checks.append(checks.Error(exc.args[0], obj=self)) return performed_checks diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py index 573d35b..7f5d656 100644 --- a/django_pydantic_field/v2/forms.py +++ b/django_pydantic_field/v2/forms.py @@ -1,17 +1,13 @@ from __future__ import annotations import typing as ty -from collections import ChainMap import pydantic from django.core.exceptions import ValidationError from django.forms.fields import InvalidJSONInput, JSONField, JSONString from django.utils.translation import gettext_lazy as _ -from . import types, utils - -if ty.TYPE_CHECKING: - from django.forms import BaseForm +from . import types class SchemaField(JSONField, ty.Generic[types.ST]): diff --git a/django_pydantic_field/v2/rest_framework.py b/django_pydantic_field/v2/rest_framework.py index 4979992..7684c71 100644 --- a/django_pydantic_field/v2/rest_framework.py +++ b/django_pydantic_field/v2/rest_framework.py @@ -5,7 +5,8 @@ import pydantic from rest_framework import exceptions, fields, parsers, renderers -from . import types, utils +from . import types +from ..compat.typing import get_args if ty.TYPE_CHECKING: from collections.abc import Mapping @@ -74,7 +75,7 @@ def _make_adapter_from_context(self, ctx: RequestResponseContext) -> types.Schem def _make_adapter_from_annotation(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: try: - schema = utils.get_args(self.__orig_class__)[0] # type: ignore + schema = get_args(self.__orig_class__)[0] # type: ignore except (AttributeError, IndexError): return None diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 55e74aa..657fdc9 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -1,34 +1,34 @@ from __future__ import annotations import typing as ty -import typing_extensions as te -from collections import ChainMap import pydantic +import typing_extensions as te -from . import utils from ..compat.django import GenericContainer - -ST = ty.TypeVar("ST", bound="SchemaT") +from ..compat.functools import cached_property +from . import utils if ty.TYPE_CHECKING: - from collections.abc import Mapping + from collections.abc import Mapping, Sequence - from pydantic.type_adapter import IncEx - from pydantic.dataclasses import DataclassClassOrWrapper from django.db.models import Model + from pydantic.dataclasses import DataclassClassOrWrapper + from pydantic.type_adapter import IncEx ModelType = ty.Type[pydantic.BaseModel] DjangoModelType = ty.Type[Model] SchemaT = ty.Union[ pydantic.BaseModel, DataclassClassOrWrapper, - ty.Sequence[ty.Any], - ty.Mapping[str, ty.Any], - ty.Set[ty.Any], - ty.FrozenSet[ty.Any], + Sequence[ty.Any], + Mapping[str, ty.Any], + set[ty.Any], + frozenset[ty.Any], ] +ST = ty.TypeVar("ST", bound="SchemaT") + class ExportKwargs(te.TypedDict, total=False): strict: bool @@ -44,6 +44,10 @@ class ExportKwargs(te.TypedDict, total=False): warnings: bool +class ImproperlyConfiguredSchema(ValueError): + """Raised when a schema is improperly configured.""" + + class SchemaAdapter(ty.Generic[ST]): def __init__( self, @@ -54,7 +58,7 @@ def __init__( allow_null: bool | None = None, **export_kwargs: ty.Unpack[ExportKwargs], ): - self.schema = schema + self.schema = GenericContainer.unwrap(schema) self.config = config self.parent_type = parent_type self.attname = attname @@ -67,7 +71,7 @@ def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: export_kwargs = {key: kwargs.pop(key) for key in common_keys} return ty.cast(ExportKwargs, export_kwargs) - @utils.cached_property + @cached_property def type_adapter(self) -> pydantic.TypeAdapter: return pydantic.TypeAdapter(self.prepared_schema, config=self.config) # type: ignore @@ -83,7 +87,12 @@ def bind(self, parent_type: type, attname: str) -> None: def validate_schema(self) -> None: """Validate the schema and raise an exception if it is invalid.""" - self.prepared_schema() + try: + self._prepare_schema() + except Exception as exc: + if not isinstance(exc, ImproperlyConfiguredSchema): + raise ImproperlyConfiguredSchema(*exc.args) from exc + raise def validate_python(self, value: ty.Any, *, strict: bool | None = None, from_attributes: bool | None = None) -> ST: """Validate the value and raise an exception if it is invalid.""" @@ -110,14 +119,11 @@ def json_schema(self) -> ty.Any: by_alias = self.export_kwargs.get("by_alias", True) return self.type_adapter.json_schema(by_alias=by_alias) - @utils.cached_property - def prepared_schema(self) -> type[ST]: + def _prepare_schema(self) -> type[ST]: schema = self.schema if schema is None and self.attname is not None: schema = self._guess_schema_from_annotations() - if isinstance(schema, GenericContainer): - schema = ty.cast(ty.Type[ST], GenericContainer.unwrap(schema)) if isinstance(schema, (str, ty.ForwardRef)): schema = self._resolve_schema_forward_ref(schema) @@ -126,13 +132,15 @@ def prepared_schema(self) -> type[ST]: error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}" else: error_msg = "The adapter is accessed before it was bound" - raise ValueError(error_msg) + raise ImproperlyConfiguredSchema(error_msg) if self.allow_null: schema = ty.Optional[schema] return ty.cast(ty.Type[ST], schema) + prepared_schema = cached_property(_prepare_schema) + def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | None: return utils.get_annotated_type(self.parent_type, self.attname) @@ -143,7 +151,7 @@ def _resolve_schema_forward_ref(self, schema: str | ty.ForwardRef) -> ty.Any: globalns = utils.get_namespace(self.parent_type) return utils.evaluate_forward_ref(schema, globalns) - @utils.cached_property + @cached_property def _dump_python_kwargs(self) -> dict[str, ty.Any]: export_kwargs = self.export_kwargs.copy() export_kwargs.pop("strict", None) diff --git a/django_pydantic_field/v2/utils.py b/django_pydantic_field/v2/utils.py index 1f3fd6a..24f1801 100644 --- a/django_pydantic_field/v2/utils.py +++ b/django_pydantic_field/v2/utils.py @@ -4,16 +4,6 @@ import typing as ty from collections import ChainMap -try: - from functools import cached_property as cached_property -except ImportError: - from django.utils.functional import cached_property as cached_property # type: ignore - -try: - from typing import get_args as get_args -except ImportError: - from typing_extensions import get_args as get_args # type: ignore - if ty.TYPE_CHECKING: from collections.abc import Mapping From 816b7ecf2823918ca2175444d7b0310015ad3443 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 21 Nov 2023 04:12:18 +0400 Subject: [PATCH 16/34] Finally make migrations perform well enough [ci skip] Also add runtime checks for possible data integrity violations --- django_pydantic_field/compat/django.py | 4 +- django_pydantic_field/v2/fields.py | 70 ++++++++---- django_pydantic_field/v2/types.py | 71 ++++++++++-- tests/sample_app/migrations/0001_initial.py | 116 +++++++++----------- tests/sample_app/models.py | 4 +- tests/test_app/migrations/0001_initial.py | 56 ++++------ tests/test_fields.py | 107 ++++++------------ 7 files changed, 226 insertions(+), 202 deletions(-) diff --git a/django_pydantic_field/compat/django.py b/django_pydantic_field/compat/django.py index 304a64c..af08042 100644 --- a/django_pydantic_field/compat/django.py +++ b/django_pydantic_field/compat/django.py @@ -40,7 +40,7 @@ def wrap(cls, typ_): @classmethod def unwrap(cls, type_): - if not isinstance(type_, GenericContainer): + if not isinstance(type_, cls): return type_ if not type_.args: @@ -129,7 +129,7 @@ class UnionTypeSerializer(BaseSerializer): def serialize(self): imports = set() - if isinstance(self.value, type(ty.Union)): # type: ignore + if isinstance(self.value, (type(ty.Union), types.UnionType)): # type: ignore imports.add("import typing") for arg in get_args(self.value): diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index beafb1a..a217cc2 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -7,7 +7,8 @@ from django.core import checks, exceptions from django.core.serializers.json import DjangoJSONEncoder -from django.db.models.expressions import BaseExpression, Col +from django.db.models.fields import NOT_PROVIDED +from django.db.models.expressions import BaseExpression, Col, Value from django.db.models.fields.json import JSONField from django.db.models.lookups import Transform from django.db.models.query_utils import DeferredAttribute @@ -19,6 +20,9 @@ class SchemaAttribute(DeferredAttribute): field: PydanticSchemaField + def __set_name__(self, owner, name): + self.field.adapter.bind(owner, name) + def __set__(self, obj, value): obj.__dict__[self.field.attname] = self.field.to_python(value) @@ -35,11 +39,10 @@ def __init__( **kwargs, ): kwargs.setdefault("encoder", DjangoJSONEncoder) - self.export_kwargs = export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) super().__init__(*args, **kwargs) - self.schema = schema + self.schema = GenericContainer.unwrap(schema) self.config = config self.adapter = types.SchemaAdapter(schema, config, None, self.get_attname(), self.null, **export_kwargs) @@ -51,7 +54,13 @@ def __copy__(self): def deconstruct(self) -> ty.Any: field_name, import_path, args, kwargs = super().deconstruct() - kwargs.update(schema=GenericContainer.wrap(self.schema), config=self.config, **self.export_kwargs) + + default = kwargs.get("default", NOT_PROVIDED) + if default is not NOT_PROVIDED and not callable(default): + kwargs["default"] = self.adapter.dump_python(default, include=None, exclude=None, round_trip=True) + + prep_schema = GenericContainer.wrap(self.adapter.prepared_schema) + kwargs.update(schema=prep_schema, config=self.config, **self.export_kwargs) return field_name, import_path, args, kwargs def contribute_to_class(self, cls: types.DjangoModelType, name: str, private_only: bool = False) -> None: @@ -59,11 +68,36 @@ def contribute_to_class(self, cls: types.DjangoModelType, name: str, private_onl super().contribute_to_class(cls, name, private_only) def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: - performed_checks = super().check(**kwargs) + # Remove checks of using mutable datastructure instances as `default` values, since they'll be adapted anyway. + performed_checks = [check for check in super().check(**kwargs) if check.id != "fields.E010"] try: self.adapter.validate_schema() except types.ImproperlyConfiguredSchema as exc: performed_checks.append(checks.Error(exc.args[0], obj=self)) + + if self.has_default(): + try: + self.get_prep_value(self.get_default()) + except pydantic.ValidationError as exc: + message = f"Default value cannot be adapted to the schema. Pydantic error: \n{str(exc)}" + performed_checks.append(checks.Error(message, obj=self, id="pydantic.E001")) + + if {"include", "exclude"} & self.export_kwargs.keys(): + schema_default = self.get_default() + if schema_default is None: + prep_value = self.adapter.type_adapter.get_default_value() + if prep_value is not None: + prep_value = prep_value.value + schema_default = prep_value + + if schema_default is not None: + try: + self.adapter.validate_python(self.get_prep_value(self.default)) + except pydantic.ValidationError as exc: + message = f"Export arguments may lead to data integrity problems. Pydantic error: \n{str(exc)}" + hint = "Please review `import` and `export` arguments." + performed_checks.append(checks.Warning(message, obj=self, hint=hint, id="pydantic.E002")) + return performed_checks def validate(self, value: ty.Any, model_instance: ty.Any) -> None: @@ -78,30 +112,25 @@ def to_python(self, value: ty.Any): raise exceptions.ValidationError(exc.json(), code="invalid", params=error_params) from exc def get_prep_value(self, value: ty.Any): - if isinstance(value, BaseExpression): - # We don't want to perform coercion on database query expressions. - return super().get_prep_value(value) + if isinstance(value, Value) and isinstance(value.output_field, self.__class__): + # Prepare inner value for `Value`-wrapped expressions. + value = Value(self.get_prep_value(value.value), value.output_field) + elif not isinstance(value, BaseExpression): + # Prepare the value if it is not a query expression. + prep_value = self.adapter.validate_python(value) + value = self.adapter.dump_python(prep_value) - prep_value = self.adapter.validate_python(value) - plain_value = self.adapter.dump_python(prep_value) - return super().get_prep_value(plain_value) + return super().get_prep_value(value) def get_transform(self, lookup_name: str): - transform: type[Transform] | SchemaKeyTransformAdapter | None - transform = super().get_transform(lookup_name) + transform: type[Transform] | SchemaKeyTransformAdapter | None = super().get_transform(lookup_name) if transform is not None: transform = SchemaKeyTransformAdapter(transform) return transform def get_default(self) -> types.ST: default_value = super().get_default() - try: - raw_value = dict(default_value) - prep_value = self.adapter.validate_python(raw_value, strict=True) - except (TypeError, ValueError): - prep_value = self.adapter.validate_python(default_value) - - return prep_value + return self.adapter.validate_python(default_value) def formfield(self, **kwargs): field_kwargs = dict( @@ -132,4 +161,5 @@ def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None: def SchemaField(schema=None, config=None, *args, **kwargs): # type: ignore + kwargs.pop("_adapter", None) return PydanticSchemaField(*args, schema=schema, config=config, **kwargs) diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 657fdc9..d5e78f1 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as ty +from collections import ChainMap import pydantic import typing_extensions as te @@ -107,12 +108,14 @@ def validate_json(self, value: str | bytes, *, strict: bool | None = None) -> ST strict = self.export_kwargs.get("strict", None) return self.type_adapter.validate_json(value, strict=strict) - def dump_python(self, value: ty.Any) -> ty.Any: + def dump_python(self, value: ty.Any, **override_kwargs: ty.Unpack[ExportKwargs]) -> ty.Any: """Dump the value to a Python object.""" - return self.type_adapter.dump_python(value, **self._dump_python_kwargs) + union_kwargs = ChainMap(override_kwargs, self._dump_python_kwargs) # type: ignore + return self.type_adapter.dump_python(value, **union_kwargs) - def dump_json(self, value: ty.Any) -> bytes: - return self.type_adapter.dump_json(value, **self._dump_python_kwargs) + def dump_json(self, value: ty.Any, **override_kwargs: ty.Unpack[ExportKwargs]) -> bytes: + union_kwargs = ChainMap(override_kwargs, self._dump_python_kwargs) # type: ignore + return self.type_adapter.dump_json(value, **union_kwargs) def json_schema(self) -> ty.Any: """Return the JSON schema for the field.""" @@ -124,9 +127,10 @@ def _prepare_schema(self) -> type[ST]: if schema is None and self.attname is not None: schema = self._guess_schema_from_annotations() - if isinstance(schema, (str, ty.ForwardRef)): - schema = self._resolve_schema_forward_ref(schema) + if isinstance(schema, str): + schema = ty.ForwardRef(schema) + schema = self._resolve_schema_forward_ref(schema) if schema is None: if self.parent_type is not None and self.attname is not None: error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}" @@ -135,21 +139,64 @@ def _prepare_schema(self) -> type[ST]: raise ImproperlyConfiguredSchema(error_msg) if self.allow_null: - schema = ty.Optional[schema] + schema = ty.Optional[schema] # type: ignore return ty.cast(ty.Type[ST], schema) prepared_schema = cached_property(_prepare_schema) + def __copy__(self): + instance = self.__class__( + self.schema, + self.config, + self.parent_type, + self.attname, + self.allow_null, + **self.export_kwargs, + ) + instance.__dict__.update(self.__dict__) + return instance + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(bound={self.is_bound}, schema={self.schema!r}, config={self.config!r})" + + def __eq__(self, other: ty.Any) -> bool: + if not isinstance(other, self.__class__): + return NotImplemented + + self_fields = [self.attname, self.export_kwargs] + other_fields = [other.attname, other.export_kwargs] + try: + self_fields.append(self.prepared_schema) + other_fields.append(other.prepared_schema) + except ImproperlyConfiguredSchema: + if self.is_bound and other.is_bound: + return False + else: + self_fields.extend((self.schema, self.config, self.allow_null)) + other_fields.extend((other.schema, other.config, other.allow_null)) + + return self_fields == other_fields + def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | None: return utils.get_annotated_type(self.parent_type, self.attname) - def _resolve_schema_forward_ref(self, schema: str | ty.ForwardRef) -> ty.Any: - if isinstance(schema, str): - schema = ty.ForwardRef(schema) + def _resolve_schema_forward_ref(self, schema: ty.Any) -> ty.Any: + if schema is None: + return None + + if isinstance(schema, ty.ForwardRef): + globalns = utils.get_namespace(self.parent_type) + return utils.evaluate_forward_ref(schema, globalns) + + wrapped_schema = GenericContainer.wrap(schema) + if not isinstance(wrapped_schema, GenericContainer): + return schema + + origin = self._resolve_schema_forward_ref(wrapped_schema.origin) + args = map(self._resolve_schema_forward_ref, wrapped_schema.args) + return GenericContainer.unwrap(GenericContainer(origin, tuple(args))) - globalns = utils.get_namespace(self.parent_type) - return utils.evaluate_forward_ref(schema, globalns) @cached_property def _dump_python_kwargs(self) -> dict[str, ty.Any]: diff --git a/tests/sample_app/migrations/0001_initial.py b/tests/sample_app/migrations/0001_initial.py index 6d72082..5edca8e 100644 --- a/tests/sample_app/migrations/0001_initial.py +++ b/tests/sample_app/migrations/0001_initial.py @@ -1,10 +1,11 @@ -# Generated by Django 4.2.2 on 2023-06-19 12:36 -import typing +# Generated by Django 4.2.7 on 2023-11-20 18:11 -import django_pydantic_field.fields -import tests.sample_app.models -import typing_extensions +import django.core.serializers.json from django.db import migrations, models +import django_pydantic_field.compat.django +import django_pydantic_field.v2.fields +import tests.sample_app.models +import typing class Migration(migrations.Migration): @@ -16,26 +17,19 @@ class Migration(migrations.Migration): migrations.CreateModel( name="Building", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "opt_meta", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( config=None, - default={"type": "frame"}, + default={"buildingType": "frame"}, + encoder=django.core.serializers.json.DjangoJSONEncoder, exclude={"type"}, null=True, - schema=django_pydantic_field._migration_serializers.GenericContainer( + schema=django_pydantic_field.compat.django.GenericContainer( typing.Union, ( - typing.ForwardRef("BuildingMeta"), + tests.sample_app.models.BuildingMeta, type(None), ), ), @@ -43,41 +37,47 @@ class Migration(migrations.Migration): ), ( "meta", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( + by_alias=True, config=None, - default={"type": "frame"}, + default={"buildingType": "frame"}, + encoder=django.core.serializers.json.DjangoJSONEncoder, include={"type"}, - schema="BuildingMeta", + schema=tests.sample_app.models.BuildingMeta, ), ), ( "meta_schema_list", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( config=None, default=list, - schema=typing.ForwardRef("t.List[BuildingMeta]"), + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( + list, (tests.sample_app.models.BuildingMeta,) + ), ), ), ( "meta_typing_list", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( config=None, default=list, - schema=django_pydantic_field._migration_serializers.GenericContainer( - list, (typing.ForwardRef("BuildingMeta"),) + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( + list, (tests.sample_app.models.BuildingMeta,) ), ), ), ( "meta_untyped_list", - django_pydantic_field.fields.PydanticSchemaField( - config=None, default=list, schema=list + django_pydantic_field.v2.fields.PydanticSchemaField( + config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ( "meta_untyped_builtin_list", - django_pydantic_field.fields.PydanticSchemaField( - config=None, default=list, schema=list + django_pydantic_field.v2.fields.PydanticSchemaField( + config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ], @@ -85,74 +85,64 @@ class Migration(migrations.Migration): migrations.CreateModel( name="PostponedBuilding", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "meta", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( by_alias=True, config=None, - default={"buildingType": None}, - schema="BuildingMeta", + default={"buildingType": tests.sample_app.models.BuildingTypes["FRAME"]}, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.sample_app.models.BuildingMeta, ), ), ( "meta_builtin_list", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( config=None, default=list, - schema=django_pydantic_field._migration_serializers.GenericContainer( + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( list, (tests.sample_app.models.BuildingMeta,) ), ), ), ( "meta_typing_list", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( config=None, default=list, - schema=django_pydantic_field._migration_serializers.GenericContainer( - typing.List, (typing.ForwardRef("BuildingMeta"),) + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( + list, (tests.sample_app.models.BuildingMeta,) ), ), ), ( "meta_untyped_list", - django_pydantic_field.fields.PydanticSchemaField( - config=None, default=list, schema=list + django_pydantic_field.v2.fields.PydanticSchemaField( + config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ( "meta_untyped_builtin_list", - django_pydantic_field.fields.PydanticSchemaField( - config=None, default=list, schema=list + django_pydantic_field.v2.fields.PydanticSchemaField( + config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ( "nested_generics", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( config=None, - schema=django_pydantic_field._migration_serializers.GenericContainer( + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( typing.Union, ( - django_pydantic_field._migration_serializers.GenericContainer( - typing.List, - ( - django_pydantic_field._migration_serializers.GenericContainer( - typing_extensions.Literal, ("foo",) - ), - ), - ), - django_pydantic_field._migration_serializers.GenericContainer( - typing_extensions.Literal, ("bar",) + django_pydantic_field.compat.django.GenericContainer( + list, + (django_pydantic_field.compat.django.GenericContainer(typing.Literal, ("foo",)),), ), + django_pydantic_field.compat.django.GenericContainer(typing.Literal, ("bar",)), ), ), ), diff --git a/tests/sample_app/models.py b/tests/sample_app/models.py index d9e68c8..be3c465 100644 --- a/tests/sample_app/models.py +++ b/tests/sample_app/models.py @@ -14,8 +14,8 @@ class BuildingTypes(str, enum.Enum): class Building(models.Model): - opt_meta: t.Optional["BuildingMeta"] = SchemaField(default={"type": "frame"}, exclude={"type"}, null=True) - meta: "BuildingMeta" = SchemaField(default={"type": "frame"}, include={"type"}) + opt_meta: t.Optional["BuildingMeta"] = SchemaField(default={"buildingType": "frame"}, exclude={"type"}, null=True) + meta: "BuildingMeta" = SchemaField(default={"buildingType": "frame"}, include={"type"}, by_alias=True) meta_schema_list = SchemaField(schema=t.ForwardRef("t.List[BuildingMeta]"), default=list) meta_typing_list: t.List["BuildingMeta"] = SchemaField(default=list) diff --git a/tests/test_app/migrations/0001_initial.py b/tests/test_app/migrations/0001_initial.py index 00a6f1c..1a41b53 100644 --- a/tests/test_app/migrations/0001_initial.py +++ b/tests/test_app/migrations/0001_initial.py @@ -1,10 +1,11 @@ -# Generated by Django 4.1.7 on 2023-02-27 14:41 +# Generated by Django 4.2.7 on 2023-11-20 18:11 +import django.core.serializers.json from django.db import migrations, models -import django_pydantic_field._migration_serializers -import django_pydantic_field.fields +import django_pydantic_field.compat.django +import django_pydantic_field.v2.fields import tests.conftest -import typing +import tests.test_app.models class Migration(migrations.Migration): @@ -16,27 +17,23 @@ class Migration(migrations.Migration): migrations.CreateModel( name="SampleForwardRefModel", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "annotated_field", - django_pydantic_field.fields.PydanticSchemaField( - config=None, default=dict, schema="SampleSchema" + django_pydantic_field.v2.fields.PydanticSchemaField( + config=None, + default=dict, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.test_app.models.SampleSchema, ), ), ( "field", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( config=None, default=dict, - schema=typing.ForwardRef("SampleSchema"), + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.test_app.models.SampleSchema, ), ), ], @@ -44,37 +41,32 @@ class Migration(migrations.Migration): migrations.CreateModel( name="SampleModel", fields=[ - ( - "id", - models.BigAutoField( - auto_created=True, - primary_key=True, - serialize=False, - verbose_name="ID", - ), - ), + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "sample_field", - django_pydantic_field.fields.PydanticSchemaField( - config={"allow_mutation": False, "frozen": True}, + django_pydantic_field.v2.fields.PydanticSchemaField( + config=None, + encoder=django.core.serializers.json.DjangoJSONEncoder, schema=tests.conftest.InnerSchema, ), ), ( "sample_list", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( config=None, - schema=django_pydantic_field._migration_serializers.GenericContainer( + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( list, (tests.conftest.InnerSchema,) ), ), ), ( "sample_seq", - django_pydantic_field.fields.PydanticSchemaField( + django_pydantic_field.v2.fields.PydanticSchemaField( config=None, default=list, - schema=django_pydantic_field._migration_serializers.GenericContainer( + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( list, (tests.conftest.InnerSchema,) ), ), diff --git a/tests/test_fields.py b/tests/test_fields.py index 79fe2ae..a993ac4 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,6 +1,6 @@ import json import sys -import typing as t +import typing as ty from collections import abc from copy import copy from datetime import date @@ -43,7 +43,7 @@ def test_null_field(): assert field.to_python(None) is None assert field.get_prep_value(None) is None - field = fields.SchemaField(t.Optional[InnerSchema], null=True, default=None) + field = fields.SchemaField(ty.Optional[InnerSchema], null=True, default=None) assert field.get_prep_value(None) is None @@ -54,7 +54,12 @@ def test_forwardrefs_deferred_resolution(): @pytest.mark.parametrize( - "forward_ref", ["InnerSchema", t.ForwardRef("SampleDataclass"), t.List["int"]] + "forward_ref", + [ + "InnerSchema", + ty.ForwardRef("SampleDataclass"), + ty.List["int"], + ], ) def test_resolved_forwardrefs(forward_ref): class ModelWithForwardRefs(models.Model): @@ -84,55 +89,41 @@ class Meta: schema=SampleDataclass, default={"stub_str": "abc", "stub_list": [date(2022, 7, 1)]}, ), - fields.PydanticSchemaField( - schema=t.Optional[InnerSchema], null=True, default=None - ), + fields.PydanticSchemaField(schema=ty.Optional[InnerSchema], null=True, default=None), ], ) def test_field_serialization(field): _test_field_serialization(field) -@pytest.mark.skipif( - sys.version_info < (3, 9), reason="Built-in type subscription supports only in 3.9+" -) +@pytest.mark.skipif(sys.version_info < (3, 9), reason="Built-in type subscription supports only in 3.9+") @pytest.mark.parametrize( "field_factory", [ lambda: fields.PydanticSchemaField(schema=list[InnerSchema], default=list), - lambda: fields.PydanticSchemaField(schema=dict[str, InnerSchema], default=list), - lambda: fields.PydanticSchemaField( - schema=abc.Sequence[InnerSchema], default=list - ), - lambda: fields.PydanticSchemaField( - schema=abc.Mapping[str, InnerSchema], default=dict - ), + lambda: fields.PydanticSchemaField(schema=dict[str, InnerSchema], default=dict), + lambda: fields.PydanticSchemaField(schema=abc.Sequence[InnerSchema], default=list), + lambda: fields.PydanticSchemaField(schema=abc.Mapping[str, InnerSchema], default=dict), ], ) def test_field_builtin_annotations_serialization(field_factory): _test_field_serialization(field_factory()) -@pytest.mark.skipif( - sys.version_info < (3, 10), reason="Union type syntax supported only in 3.10+" -) +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Union type syntax supported only in 3.10+") def test_field_union_type_serialization(): - field = fields.PydanticSchemaField( - schema=(InnerSchema | None), null=True, default=None - ) + field = fields.PydanticSchemaField(schema=(InnerSchema | None), null=True, default=None) _test_field_serialization(field) -@pytest.mark.skipif( - sys.version_info >= (3, 9), reason="Should test against builtin generic types" -) +@pytest.mark.skipif(sys.version_info >= (3, 9), reason="Should test against builtin generic types") @pytest.mark.parametrize( "field", [ - fields.PydanticSchemaField(schema=t.List[InnerSchema], default=list), - fields.PydanticSchemaField(schema=t.Dict[str, InnerSchema], default=list), - fields.PydanticSchemaField(schema=t.Sequence[InnerSchema], default=list), - fields.PydanticSchemaField(schema=t.Mapping[str, InnerSchema], default=dict), + fields.PydanticSchemaField(schema=ty.List[InnerSchema], default=list), + fields.PydanticSchemaField(schema=ty.Dict[str, InnerSchema], default=list), + fields.PydanticSchemaField(schema=ty.Sequence[InnerSchema], default=list), + fields.PydanticSchemaField(schema=ty.Mapping[str, InnerSchema], default=dict), ], ) def test_field_typing_annotations_serialization(field): @@ -147,42 +138,24 @@ def test_field_typing_annotations_serialization(field): "old_field, new_field", [ ( - lambda: fields.PydanticSchemaField( - schema=t.List[InnerSchema], default=list - ), + lambda: fields.PydanticSchemaField(schema=ty.List[InnerSchema], default=list), lambda: fields.PydanticSchemaField(schema=list[InnerSchema], default=list), ), ( - lambda: fields.PydanticSchemaField( - schema=t.Dict[str, InnerSchema], default=list - ), - lambda: fields.PydanticSchemaField( - schema=dict[str, InnerSchema], default=list - ), + lambda: fields.PydanticSchemaField(schema=ty.Dict[str, InnerSchema], default=dict), + lambda: fields.PydanticSchemaField(schema=dict[str, InnerSchema], default=dict), ), ( - lambda: fields.PydanticSchemaField( - schema=t.Sequence[InnerSchema], default=list - ), - lambda: fields.PydanticSchemaField( - schema=abc.Sequence[InnerSchema], default=list - ), + lambda: fields.PydanticSchemaField(schema=ty.Sequence[InnerSchema], default=list), + lambda: fields.PydanticSchemaField(schema=abc.Sequence[InnerSchema], default=list), ), ( - lambda: fields.PydanticSchemaField( - schema=t.Mapping[str, InnerSchema], default=dict - ), - lambda: fields.PydanticSchemaField( - schema=abc.Mapping[str, InnerSchema], default=dict - ), + lambda: fields.PydanticSchemaField(schema=ty.Mapping[str, InnerSchema], default=dict), + lambda: fields.PydanticSchemaField(schema=abc.Mapping[str, InnerSchema], default=dict), ), ( - lambda: fields.PydanticSchemaField( - schema=t.Mapping[str, InnerSchema], default=dict - ), - lambda: fields.PydanticSchemaField( - schema=abc.Mapping[str, InnerSchema], default=dict - ), + lambda: fields.PydanticSchemaField(schema=ty.Mapping[str, InnerSchema], default=dict), + lambda: fields.PydanticSchemaField(schema=abc.Mapping[str, InnerSchema], default=dict), ), ], ) @@ -192,19 +165,11 @@ def test_field_typing_to_builtin_serialization(old_field, new_field): _, _, args, kwargs = old_field.deconstruct() reconstructed_field = fields.PydanticSchemaField(*args, **kwargs) - assert ( - old_field.get_default() - == new_field.get_default() - == reconstructed_field.get_default() - ) + assert old_field.get_default() == new_field.get_default() == reconstructed_field.get_default() assert new_field.schema == reconstructed_field.schema deserialized_field = reconstruct_field(serialize_field(old_field)) - assert ( - old_field.get_default() - == deserialized_field.get_default() - == new_field.get_default() - ) + assert old_field.get_default() == deserialized_field.get_default() == new_field.get_default() assert new_field.schema == deserialized_field.schema @@ -212,8 +177,8 @@ def test_field_typing_to_builtin_serialization(old_field, new_field): "field, flawed_data", [ (fields.PydanticSchemaField(schema=InnerSchema), {}), - (fields.PydanticSchemaField(schema=t.List[InnerSchema]), [{}]), - (fields.PydanticSchemaField(schema=t.Dict[int, float]), {"1": "abc"}), + (fields.PydanticSchemaField(schema=ty.List[InnerSchema]), [{}]), + (fields.PydanticSchemaField(schema=ty.Dict[int, float]), {"1": "abc"}), ], ) def test_field_validation_exceptions(field, flawed_data): @@ -260,15 +225,15 @@ def test_export_kwargs_support(export_kwargs): def _test_field_serialization(field): - _, _, args, kwargs = field.deconstruct() + _, _, args, kwargs = field_data = field.deconstruct() reconstructed_field = fields.PydanticSchemaField(*args, **kwargs) assert field.get_default() == reconstructed_field.get_default() - assert field.schema == reconstructed_field.schema + assert reconstructed_field.deconstruct() == field_data deserialized_field = reconstruct_field(serialize_field(field)) assert deserialized_field.get_default() == field.get_default() - assert field.schema == deserialized_field.schema + assert deserialized_field.deconstruct() == field_data def serialize_field(field: fields.PydanticSchemaField) -> str: From dc5bf37177918c1e1e67367221903e391ec8833b Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 21 Nov 2023 04:20:31 +0400 Subject: [PATCH 17/34] Normalize import paths on field deconstruction for v1/v2 [ci skip] --- django_pydantic_field/v1/fields.py | 3 +++ django_pydantic_field/v2/fields.py | 3 +++ tests/sample_app/migrations/0001_initial.py | 28 ++++++++++----------- tests/test_app/migrations/0001_initial.py | 14 +++++------ 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/django_pydantic_field/v1/fields.py b/django_pydantic_field/v1/fields.py index b85041e..b9cc59b 100644 --- a/django_pydantic_field/v1/fields.py +++ b/django_pydantic_field/v1/fields.py @@ -84,6 +84,9 @@ def get_prep_value(self, value): def deconstruct(self): name, path, args, kwargs = super().deconstruct() + if path.startswith("django_pydantic_field.v1."): + path = path.replace("django_pydantic_field.v1", "django_pydantic_field", 1) + self._deconstruct_schema(kwargs) self._deconstruct_default(kwargs) self._deconstruct_config(kwargs) diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index a217cc2..452ab7a 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -54,6 +54,8 @@ def __copy__(self): def deconstruct(self) -> ty.Any: field_name, import_path, args, kwargs = super().deconstruct() + if import_path.startswith("django_pydantic_field.v2."): + import_path = import_path.replace("django_pydantic_field.v2", "django_pydantic_field", 1) default = kwargs.get("default", NOT_PROVIDED) if default is not NOT_PROVIDED and not callable(default): @@ -61,6 +63,7 @@ def deconstruct(self) -> ty.Any: prep_schema = GenericContainer.wrap(self.adapter.prepared_schema) kwargs.update(schema=prep_schema, config=self.config, **self.export_kwargs) + return field_name, import_path, args, kwargs def contribute_to_class(self, cls: types.DjangoModelType, name: str, private_only: bool = False) -> None: diff --git a/tests/sample_app/migrations/0001_initial.py b/tests/sample_app/migrations/0001_initial.py index 5edca8e..1c5047c 100644 --- a/tests/sample_app/migrations/0001_initial.py +++ b/tests/sample_app/migrations/0001_initial.py @@ -1,9 +1,9 @@ -# Generated by Django 4.2.7 on 2023-11-20 18:11 +# Generated by Django 4.2.7 on 2023-11-20 18:19 import django.core.serializers.json from django.db import migrations, models import django_pydantic_field.compat.django -import django_pydantic_field.v2.fields +import django_pydantic_field.fields import tests.sample_app.models import typing @@ -20,7 +20,7 @@ class Migration(migrations.Migration): ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "opt_meta", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default={"buildingType": "frame"}, encoder=django.core.serializers.json.DjangoJSONEncoder, @@ -37,7 +37,7 @@ class Migration(migrations.Migration): ), ( "meta", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( by_alias=True, config=None, default={"buildingType": "frame"}, @@ -48,7 +48,7 @@ class Migration(migrations.Migration): ), ( "meta_schema_list", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, @@ -59,7 +59,7 @@ class Migration(migrations.Migration): ), ( "meta_typing_list", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, @@ -70,13 +70,13 @@ class Migration(migrations.Migration): ), ( "meta_untyped_list", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ( "meta_untyped_builtin_list", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), @@ -88,7 +88,7 @@ class Migration(migrations.Migration): ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "meta", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( by_alias=True, config=None, default={"buildingType": tests.sample_app.models.BuildingTypes["FRAME"]}, @@ -98,7 +98,7 @@ class Migration(migrations.Migration): ), ( "meta_builtin_list", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, @@ -109,7 +109,7 @@ class Migration(migrations.Migration): ), ( "meta_typing_list", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, @@ -120,19 +120,19 @@ class Migration(migrations.Migration): ), ( "meta_untyped_list", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ( "meta_untyped_builtin_list", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=list ), ), ( "nested_generics", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=django_pydantic_field.compat.django.GenericContainer( diff --git a/tests/test_app/migrations/0001_initial.py b/tests/test_app/migrations/0001_initial.py index 1a41b53..1659b62 100644 --- a/tests/test_app/migrations/0001_initial.py +++ b/tests/test_app/migrations/0001_initial.py @@ -1,9 +1,9 @@ -# Generated by Django 4.2.7 on 2023-11-20 18:11 +# Generated by Django 4.2.7 on 2023-11-20 18:19 import django.core.serializers.json from django.db import migrations, models import django_pydantic_field.compat.django -import django_pydantic_field.v2.fields +import django_pydantic_field.fields import tests.conftest import tests.test_app.models @@ -20,7 +20,7 @@ class Migration(migrations.Migration): ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "annotated_field", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=dict, encoder=django.core.serializers.json.DjangoJSONEncoder, @@ -29,7 +29,7 @@ class Migration(migrations.Migration): ), ( "field", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=dict, encoder=django.core.serializers.json.DjangoJSONEncoder, @@ -44,7 +44,7 @@ class Migration(migrations.Migration): ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), ( "sample_field", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=tests.conftest.InnerSchema, @@ -52,7 +52,7 @@ class Migration(migrations.Migration): ), ( "sample_list", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, encoder=django.core.serializers.json.DjangoJSONEncoder, schema=django_pydantic_field.compat.django.GenericContainer( @@ -62,7 +62,7 @@ class Migration(migrations.Migration): ), ( "sample_seq", - django_pydantic_field.v2.fields.PydanticSchemaField( + django_pydantic_field.fields.PydanticSchemaField( config=None, default=list, encoder=django.core.serializers.json.DjangoJSONEncoder, From a58bcdd11f423240ce1a15fb1243db9c71d3538c Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 22 Nov 2023 00:05:59 +0400 Subject: [PATCH 18/34] Make all tests pass on v1 & v2 --- django_pydantic_field/v1/forms.py | 8 +++-- django_pydantic_field/v2/fields.py | 2 +- django_pydantic_field/v2/types.py | 4 +-- tests/sample_app/migrations/0001_initial.py | 13 +++++-- tests/test_app/migrations/0001_initial.py | 2 +- tests/test_fields.py | 40 ++++++++++++++++----- 6 files changed, 51 insertions(+), 18 deletions(-) diff --git a/django_pydantic_field/v1/forms.py b/django_pydantic_field/v1/forms.py index 3185aad..6cfaad4 100644 --- a/django_pydantic_field/v1/forms.py +++ b/django_pydantic_field/v1/forms.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import typing as t from functools import partial @@ -15,6 +17,8 @@ class SchemaField(JSONField, t.Generic[base.ST]): default_error_messages = { "schema_error": _("Schema didn't match. Detail: %(detail)s"), } + decoder: partial[base.SchemaDecoder] + encoder: partial[base.SchemaEncoder] def __init__( self, @@ -30,8 +34,8 @@ def __init__( __module__=__module__, ) export_params = base.extract_export_kwargs(kwargs, dict.pop) - decoder = partial(base.SchemaDecoder, self.schema) - encoder = partial( + decoder: partial[base.SchemaDecoder] = partial(base.SchemaDecoder, self.schema) + encoder: partial[base.SchemaEncoder] = partial( base.SchemaEncoder, schema=self.schema, export=export_params, diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 452ab7a..7568cb6 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -126,7 +126,7 @@ def get_prep_value(self, value: ty.Any): return super().get_prep_value(value) def get_transform(self, lookup_name: str): - transform: type[Transform] | SchemaKeyTransformAdapter | None = super().get_transform(lookup_name) + transform: ty.Any = super().get_transform(lookup_name) if transform is not None: transform = SchemaKeyTransformAdapter(transform) return transform diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index d5e78f1..593ac00 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -164,8 +164,8 @@ def __eq__(self, other: ty.Any) -> bool: if not isinstance(other, self.__class__): return NotImplemented - self_fields = [self.attname, self.export_kwargs] - other_fields = [other.attname, other.export_kwargs] + self_fields: list[ty.Any] = [self.attname, self.export_kwargs] + other_fields: list[ty.Any] = [other.attname, other.export_kwargs] try: self_fields.append(self.prepared_schema) other_fields.append(other.prepared_schema) diff --git a/tests/sample_app/migrations/0001_initial.py b/tests/sample_app/migrations/0001_initial.py index 1c5047c..78db57b 100644 --- a/tests/sample_app/migrations/0001_initial.py +++ b/tests/sample_app/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.7 on 2023-11-20 18:19 +# Generated by Django 3.2.23 on 2023-11-21 14:37 import django.core.serializers.json from django.db import migrations, models @@ -6,6 +6,7 @@ import django_pydantic_field.fields import tests.sample_app.models import typing +import typing_extensions class Migration(migrations.Migration): @@ -140,9 +141,15 @@ class Migration(migrations.Migration): ( django_pydantic_field.compat.django.GenericContainer( list, - (django_pydantic_field.compat.django.GenericContainer(typing.Literal, ("foo",)),), + ( + django_pydantic_field.compat.django.GenericContainer( + typing_extensions.Literal, ("foo",) + ), + ), + ), + django_pydantic_field.compat.django.GenericContainer( + typing_extensions.Literal, ("bar",) ), - django_pydantic_field.compat.django.GenericContainer(typing.Literal, ("bar",)), ), ), ), diff --git a/tests/test_app/migrations/0001_initial.py b/tests/test_app/migrations/0001_initial.py index 1659b62..80a3436 100644 --- a/tests/test_app/migrations/0001_initial.py +++ b/tests/test_app/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2.7 on 2023-11-20 18:19 +# Generated by Django 3.2.23 on 2023-11-21 14:37 import django.core.serializers.json from django.db import migrations, models diff --git a/tests/test_fields.py b/tests/test_fields.py index a993ac4..621e9aa 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -5,13 +5,16 @@ from copy import copy from datetime import date +import pydantic import pytest from django.core.exceptions import ValidationError -from django.db import models, connection +from django.db import connection, models from django.db.migrations.writer import MigrationWriter + from django_pydantic_field import fields +from django_pydantic_field.compat.pydantic import PYDANTIC_V1, PYDANTIC_V2 -from .conftest import InnerSchema, SampleDataclass +from .conftest import InnerSchema, SampleDataclass # noqa from .sample_app.models import Building from .test_app.models import SampleForwardRefModel, SampleModel, SampleSchema @@ -76,10 +79,6 @@ class Meta: schema=InnerSchema, default=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), ), - fields.PydanticSchemaField( - schema=InnerSchema, - default=(("stub_str", "abc"), ("stub_list", [date(2022, 7, 1)])), - ), fields.PydanticSchemaField( schema=InnerSchema, default={"stub_str": "abc", "stub_list": [date(2022, 7, 1)]}, @@ -90,6 +89,17 @@ class Meta: default={"stub_str": "abc", "stub_list": [date(2022, 7, 1)]}, ), fields.PydanticSchemaField(schema=ty.Optional[InnerSchema], null=True, default=None), + pytest.param( + fields.PydanticSchemaField( + schema=InnerSchema, + default=(("stub_str", "abc"), ("stub_list", [date(2022, 7, 1)])), + ), + marks=pytest.mark.xfail( + PYDANTIC_V2, + reason="Tuple-based default reconstruction is not supported with Pydantic 2", + raises=pydantic.ValidationError, + ), + ), ], ) def test_field_serialization(field): @@ -121,7 +131,7 @@ def test_field_union_type_serialization(): "field", [ fields.PydanticSchemaField(schema=ty.List[InnerSchema], default=list), - fields.PydanticSchemaField(schema=ty.Dict[str, InnerSchema], default=list), + fields.PydanticSchemaField(schema=ty.Dict[str, InnerSchema], default=dict), fields.PydanticSchemaField(schema=ty.Sequence[InnerSchema], default=list), fields.PydanticSchemaField(schema=ty.Mapping[str, InnerSchema], default=dict), ], @@ -229,11 +239,23 @@ def _test_field_serialization(field): reconstructed_field = fields.PydanticSchemaField(*args, **kwargs) assert field.get_default() == reconstructed_field.get_default() - assert reconstructed_field.deconstruct() == field_data + + if PYDANTIC_V2: + assert reconstructed_field.deconstruct() == field_data + elif PYDANTIC_V1: + assert reconstructed_field.schema == field.schema + else: + pytest.fail("Unsupported Pydantic version") deserialized_field = reconstruct_field(serialize_field(field)) assert deserialized_field.get_default() == field.get_default() - assert deserialized_field.deconstruct() == field_data + + if PYDANTIC_V2: + assert deserialized_field.deconstruct() == field_data + elif PYDANTIC_V1: + assert deserialized_field.schema == field.schema + else: + pytest.fail("Unsupported Pydantic version") def serialize_field(field: fields.PydanticSchemaField) -> str: From 48afa3490a98416336e3ff9102225fb26671fc5d Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 22 Nov 2023 01:09:42 +0400 Subject: [PATCH 19/34] Add warnings about deprecated `json.dumps(...)`-passed kwargs [ci skip] --- django_pydantic_field/compat/deprecation.py | 23 +++++++++++++++++++++ django_pydantic_field/v2/fields.py | 9 ++++---- django_pydantic_field/v2/forms.py | 3 +++ django_pydantic_field/v2/rest_framework.py | 5 ++++- 4 files changed, 34 insertions(+), 6 deletions(-) create mode 100644 django_pydantic_field/compat/deprecation.py diff --git a/django_pydantic_field/compat/deprecation.py b/django_pydantic_field/compat/deprecation.py new file mode 100644 index 0000000..fc156a0 --- /dev/null +++ b/django_pydantic_field/compat/deprecation.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import typing as ty +import warnings + +_NOT_PROVIDED = object() +_DEPRECATED_KWARGS = ( + "allow_nan", + "indent", + "separators", + "skipkeys", + "sort_keys", +) +_DEPRECATED_KWARGS_MESSAGE = ( + "The `%s=` argument is not supported by Pydantic v2 and will be removed in the future versions." +) + + +def truncate_deprecated_v1_export_kwargs(kwargs: dict[str, ty.Any]) -> None: + for kwarg in _DEPRECATED_KWARGS: + maybe_present_kwarg = kwargs.pop(kwarg, _NOT_PROVIDED) + if maybe_present_kwarg is not _NOT_PROVIDED: + warnings.warn(_DEPRECATED_KWARGS_MESSAGE % kwarg, DeprecationWarning, stacklevel=2) diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 7568cb6..b5f6cc8 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -3,18 +3,17 @@ import typing as ty import pydantic - from django.core import checks, exceptions from django.core.serializers.json import DjangoJSONEncoder - -from django.db.models.fields import NOT_PROVIDED from django.db.models.expressions import BaseExpression, Col, Value +from django.db.models.fields import NOT_PROVIDED from django.db.models.fields.json import JSONField from django.db.models.lookups import Transform from django.db.models.query_utils import DeferredAttribute -from . import types, forms +from ..compat.deprecation import truncate_deprecated_v1_export_kwargs from ..compat.django import GenericContainer +from . import forms, types class SchemaAttribute(DeferredAttribute): @@ -164,5 +163,5 @@ def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None: def SchemaField(schema=None, config=None, *args, **kwargs): # type: ignore - kwargs.pop("_adapter", None) + truncate_deprecated_v1_export_kwargs(kwargs) return PydanticSchemaField(*args, schema=schema, config=config, **kwargs) diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py index 7f5d656..b162a1b 100644 --- a/django_pydantic_field/v2/forms.py +++ b/django_pydantic_field/v2/forms.py @@ -7,6 +7,7 @@ from django.forms.fields import InvalidJSONInput, JSONField, JSONString from django.utils.translation import gettext_lazy as _ +from ..compat.deprecation import truncate_deprecated_v1_export_kwargs from . import types @@ -24,6 +25,8 @@ def __init__( *args, **kwargs, ): + truncate_deprecated_v1_export_kwargs(kwargs) + self.schema = schema self.config = config self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) diff --git a/django_pydantic_field/v2/rest_framework.py b/django_pydantic_field/v2/rest_framework.py index 7684c71..8436590 100644 --- a/django_pydantic_field/v2/rest_framework.py +++ b/django_pydantic_field/v2/rest_framework.py @@ -5,11 +5,13 @@ import pydantic from rest_framework import exceptions, fields, parsers, renderers -from . import types +from ..compat.deprecation import truncate_deprecated_v1_export_kwargs from ..compat.typing import get_args +from . import types if ty.TYPE_CHECKING: from collections.abc import Mapping + from rest_framework.serializers import BaseSerializer RequestResponseContext = Mapping[str, ty.Any] @@ -26,6 +28,7 @@ def __init__( allow_null: bool = False, **kwargs, ): + truncate_deprecated_v1_export_kwargs(kwargs) self.schema = schema self.config = config self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) From 7fe3d781b503b37db5ced7a67e876b30db5b28b3 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 22 Nov 2023 01:22:24 +0400 Subject: [PATCH 20/34] Add some explanatory comments [ci skip] --- django_pydantic_field/v2/fields.py | 12 +++++++++--- django_pydantic_field/v2/types.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index b5f6cc8..43c1041 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -73,20 +73,25 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: # Remove checks of using mutable datastructure instances as `default` values, since they'll be adapted anyway. performed_checks = [check for check in super().check(**kwargs) if check.id != "fields.E010"] try: + # Test that the schema could be resolved in runtime, even if it contains forward references. self.adapter.validate_schema() except types.ImproperlyConfiguredSchema as exc: - performed_checks.append(checks.Error(exc.args[0], obj=self)) + message = f"Cannot resolve the schema. Original error: \n{exc.args[0]}" + performed_checks.append(checks.Error(message, obj=self, id="pydantic.E001")) if self.has_default(): try: + # Test that the default value conforms to the schema. self.get_prep_value(self.get_default()) except pydantic.ValidationError as exc: message = f"Default value cannot be adapted to the schema. Pydantic error: \n{str(exc)}" - performed_checks.append(checks.Error(message, obj=self, id="pydantic.E001")) + performed_checks.append(checks.Error(message, obj=self, id="pydantic.E002")) if {"include", "exclude"} & self.export_kwargs.keys(): + # Try to prepare the default value to test export ability against it. schema_default = self.get_default() if schema_default is None: + # If the default value is not set, try to get the default value from the schema. prep_value = self.adapter.type_adapter.get_default_value() if prep_value is not None: prep_value = prep_value.value @@ -94,11 +99,12 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: if schema_default is not None: try: + # Perform the full round-trip transformation to test the export ability. self.adapter.validate_python(self.get_prep_value(self.default)) except pydantic.ValidationError as exc: message = f"Export arguments may lead to data integrity problems. Pydantic error: \n{str(exc)}" hint = "Please review `import` and `export` arguments." - performed_checks.append(checks.Warning(message, obj=self, hint=hint, id="pydantic.E002")) + performed_checks.append(checks.Warning(message, obj=self, hint=hint, id="pydantic.E003")) return performed_checks diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 593ac00..6530948 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -46,7 +46,7 @@ class ExportKwargs(te.TypedDict, total=False): class ImproperlyConfiguredSchema(ValueError): - """Raised when a schema is improperly configured.""" + """Raised when the schema is improperly configured.""" class SchemaAdapter(ty.Generic[ST]): From 4c16036e20eaa8aabbe498a606ed95049503ab7f Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Sat, 25 Nov 2023 20:19:06 +0400 Subject: [PATCH 21/34] Add tests for v2 `SchemaAdapter`. --- django_pydantic_field/v2/types.py | 40 +++++++-- tests/v2/test_types.py | 141 ++++++++++++++++++++++++++++++ 2 files changed, 176 insertions(+), 5 deletions(-) create mode 100644 tests/v2/test_types.py diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 6530948..1aaa787 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -66,8 +66,31 @@ def __init__( self.allow_null = allow_null self.export_kwargs = export_kwargs + @classmethod + def from_type( + cls, + schema: ty.Any, + config: pydantic.ConfigDict | None = None, + **kwargs: ty.Unpack[ExportKwargs], + ) -> SchemaAdapter[ST]: + """Create an adapter from a type.""" + return cls(schema, config, None, None, **kwargs) + + @classmethod + def from_annotation( + cls, + parent_type: type, + attname: str, + config: pydantic.ConfigDict | None = None, + **kwargs: ty.Unpack[ExportKwargs], + ) -> SchemaAdapter[ST]: + """Create an adapter from a type annotation.""" + return cls(None, config, parent_type, attname, **kwargs) + @staticmethod def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: + """Extract the export kwargs from the kwargs passed to the field. + This method mutates passed kwargs by removing those that are used by the adapter.""" common_keys = kwargs.keys() & ExportKwargs.__annotations__.keys() export_kwargs = {key: kwargs.pop(key) for key in common_keys} return ty.cast(ExportKwargs, export_kwargs) @@ -78,9 +101,11 @@ def type_adapter(self) -> pydantic.TypeAdapter: @property def is_bound(self) -> bool: + """Return True if the adapter is bound to a specific attribute of a `parent_type`.""" return self.parent_type is not None and self.attname is not None def bind(self, parent_type: type, attname: str) -> None: + """Bind the adapter to specific attribute of a `parent_type`.""" self.parent_type = parent_type self.attname = attname self.__dict__.pop("prepared_schema", None) @@ -123,19 +148,24 @@ def json_schema(self) -> ty.Any: return self.type_adapter.json_schema(by_alias=by_alias) def _prepare_schema(self) -> type[ST]: + """Prepare the schema for the adapter. + + This method is called by `prepared_schema` property and should not be called directly. + The intent is to resolve the real schema from an annotations or a forward references. + """ schema = self.schema - if schema is None and self.attname is not None: + if schema is None and self.is_bound: schema = self._guess_schema_from_annotations() if isinstance(schema, str): schema = ty.ForwardRef(schema) schema = self._resolve_schema_forward_ref(schema) if schema is None: - if self.parent_type is not None and self.attname is not None: - error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}" + if self.is_bound: + error_msg = f"Annotation is not provided for {self.parent_type.__name__}.{self.attname}" # type: ignore[union-attr] else: - error_msg = "The adapter is accessed before it was bound" + error_msg = "Cannot resolve the schema. The adapter is accessed before it was bound." raise ImproperlyConfiguredSchema(error_msg) if self.allow_null: @@ -197,9 +227,9 @@ def _resolve_schema_forward_ref(self, schema: ty.Any) -> ty.Any: args = map(self._resolve_schema_forward_ref, wrapped_schema.args) return GenericContainer.unwrap(GenericContainer(origin, tuple(args))) - @cached_property def _dump_python_kwargs(self) -> dict[str, ty.Any]: export_kwargs = self.export_kwargs.copy() export_kwargs.pop("strict", None) + export_kwargs.pop("from_attributes", None) return ty.cast(dict, export_kwargs) diff --git a/tests/v2/test_types.py b/tests/v2/test_types.py new file mode 100644 index 0000000..891ad24 --- /dev/null +++ b/tests/v2/test_types.py @@ -0,0 +1,141 @@ +import sys +import pydantic +import pytest +import typing as ty + +from ..conftest import InnerSchema, SampleDataclass + +types = pytest.importorskip("django_pydantic_field.v2.types") +skip_unsupported_builtin_subscription = pytest.mark.skipif(sys.version_info < (3, 9), reason="Built-in type subscription supports only in 3.9+") + + +@pytest.mark.parametrize( + "ctor, args, kwargs", + [ + pytest.param(types.SchemaAdapter, ["list[int]", None, None, None], {}, marks=skip_unsupported_builtin_subscription), + pytest.param(types.SchemaAdapter, ["list[int]", {"strict": True}, None, None], {}, marks=skip_unsupported_builtin_subscription), + (types.SchemaAdapter, [ty.List[int], None, None, None], {}), + (types.SchemaAdapter, [ty.List[int], {"strict": True}, None, None], {}), + (types.SchemaAdapter, [None, None, InnerSchema, "stub_int"], {}), + (types.SchemaAdapter, [None, None, SampleDataclass, "stub_int"], {}), + pytest.param(types.SchemaAdapter.from_type, ["list[int]"], {}, marks=skip_unsupported_builtin_subscription), + pytest.param(types.SchemaAdapter.from_type, ["list[int]", {"strict": True}], {}, marks=skip_unsupported_builtin_subscription), + (types.SchemaAdapter.from_type, [ty.List[int]], {}), + (types.SchemaAdapter.from_type, [ty.List[int], {"strict": True}], {}), + (types.SchemaAdapter.from_annotation, [InnerSchema, "stub_int"], {}), + (types.SchemaAdapter.from_annotation, [InnerSchema, "stub_int", {"strict": True}], {}), + (types.SchemaAdapter.from_annotation, [SampleDataclass, "stub_int"], {}), + (types.SchemaAdapter.from_annotation, [SampleDataclass, "stub_int", {"strict": True}], {}), + ] +) +def test_schema_adapter_constructors(ctor, args, kwargs): + adapter = ctor(*args, **kwargs) + adapter.validate_schema() + assert isinstance(adapter.type_adapter, pydantic.TypeAdapter) + + +def test_schema_adapter_is_bound(): + adapter = types.SchemaAdapter(None, None, None, None) + with pytest.raises(types.ImproperlyConfiguredSchema): + adapter.validate_schema() # Schema cannot be resolved for fully unbound adapter + + adapter = types.SchemaAdapter(ty.List[int], None, None, None) + assert not adapter.is_bound, "SchemaAdapter should not be bound" + adapter.validate_schema() # Schema should be resolved from direct argument + + adapter.bind(InnerSchema, "stub_int") + assert adapter.is_bound, "SchemaAdapter should be bound" + adapter.validate_schema() # Schema should be resolved from direct argument + + adapter = types.SchemaAdapter(None, None, InnerSchema, "stub_int") + assert adapter.is_bound, "SchemaAdapter should be bound" + adapter.validate_schema() # Schema should be resolved from bound attribute + + +@pytest.mark.parametrize( + "kwargs, expected_export_kwargs", + [ + ({}, {}), + ({"strict": True}, {"strict": True}), + ({"strict": True, "by_alias": False}, {"strict": True, "by_alias": False}), + ({"strict": True, "from_attributes": False, "on_delete": "CASCADE"}, {"strict": True, "from_attributes": False}), + ], +) +def test_schema_adapter_extract_export_kwargs(kwargs, expected_export_kwargs): + orig_kwargs = dict(kwargs) + assert types.SchemaAdapter.extract_export_kwargs(kwargs) == expected_export_kwargs + assert kwargs == {key: orig_kwargs[key] for key in orig_kwargs.keys() - expected_export_kwargs.keys()} + + +def test_schema_adapter_validate_python(): + adapter = types.SchemaAdapter.from_type(ty.List[int]) + assert adapter.validate_python([1, 2, 3]) == [1, 2, 3] + assert adapter.validate_python([1, 2, 3], strict=True) == [1, 2, 3] + assert adapter.validate_python([1, 2, 3], strict=False) == [1, 2, 3] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {"strict": True}) + assert adapter.validate_python([1, 2, 3]) == [1, 2, 3] + assert adapter.validate_python(["1", "2", "3"], strict=False) == [1, 2, 3] + assert sorted(adapter.validate_python({1, 2, 3}, strict=False)) == [1, 2, 3] + with pytest.raises(pydantic.ValidationError): + assert adapter.validate_python(["1", "2", "3"]) == [1, 2, 3] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {"strict": False}) + assert adapter.validate_python([1, 2, 3]) == [1, 2, 3] + assert adapter.validate_python([1, 2, 3], strict=False) == [1, 2, 3] + assert sorted(adapter.validate_python({1, 2, 3})) == [1, 2, 3] + with pytest.raises(pydantic.ValidationError): + assert adapter.validate_python({1, 2, 3}, strict=True) == [1, 2, 3] + + +def test_schema_adapter_validate_json(): + adapter = types.SchemaAdapter.from_type(ty.List[int]) + assert adapter.validate_json("[1, 2, 3]") == [1, 2, 3] + assert adapter.validate_json("[1, 2, 3]", strict=True) == [1, 2, 3] + assert adapter.validate_json("[1, 2, 3]", strict=False) == [1, 2, 3] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {"strict": True}) + assert adapter.validate_json("[1, 2, 3]") == [1, 2, 3] + assert adapter.validate_json('["1", "2", "3"]', strict=False) == [1, 2, 3] + with pytest.raises(pydantic.ValidationError): + assert adapter.validate_json('["1", "2", "3"]') == [1, 2, 3] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {"strict": False}) + assert adapter.validate_json("[1, 2, 3]") == [1, 2, 3] + assert adapter.validate_json("[1, 2, 3]", strict=False) == [1, 2, 3] + with pytest.raises(pydantic.ValidationError): + assert adapter.validate_json('["1", "2", "3"]', strict=True) == [1, 2, 3] + + +def test_schema_adapter_dump_python(): + adapter = types.SchemaAdapter.from_type(ty.List[int]) + assert adapter.dump_python([1, 2, 3]) == [1, 2, 3] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {}) + assert adapter.dump_python([1, 2, 3]) == [1, 2, 3] + assert sorted(adapter.dump_python({1, 2, 3})) == [1, 2, 3] + with pytest.warns(UserWarning): + assert adapter.dump_python(["1", "2", "3"]) == ["1", "2", "3"] + + adapter = types.SchemaAdapter.from_type(ty.List[int], {}) + assert adapter.dump_python([1, 2, 3]) == [1, 2, 3] + assert sorted(adapter.dump_python({1, 2, 3})) == [1, 2, 3] + with pytest.warns(UserWarning): + assert adapter.dump_python(["1", "2", "3"]) == ["1", "2", "3"] + + +def test_schema_adapter_dump_json(): + adapter = types.SchemaAdapter.from_type(ty.List[int]) + assert adapter.dump_json([1, 2, 3]) == b"[1,2,3]" + + adapter = types.SchemaAdapter.from_type(ty.List[int], {}) + assert adapter.dump_json([1, 2, 3]) == b"[1,2,3]" + assert adapter.dump_json({1, 2, 3}) == b"[1,2,3]" + with pytest.warns(UserWarning): + assert adapter.dump_json(["1", "2", "3"]) == b'["1","2","3"]' + + adapter = types.SchemaAdapter.from_type(ty.List[int], {}) + assert adapter.dump_json([1, 2, 3]) == b"[1,2,3]" + assert adapter.dump_json({1, 2, 3}) == b"[1,2,3]" + with pytest.warns(UserWarning): + assert adapter.dump_json(["1", "2", "3"]) == b'["1","2","3"]' From 6ec9e185e3dbb1bc5d66f0ae0b2951c617fe4fe3 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 28 Nov 2023 01:08:38 +0400 Subject: [PATCH 22/34] Add type annotations and stubs for v2 and compat layer. --- django_pydantic_field/compat/deprecation.py | 8 +- django_pydantic_field/fields.pyi | 93 +++++++++++++++++---- django_pydantic_field/forms.pyi | 63 ++++++++++++++ django_pydantic_field/rest_framework.pyi | 60 +++++++++++++ django_pydantic_field/v1/fields.py | 4 +- django_pydantic_field/v2/fields.py | 62 +++++++++++++- django_pydantic_field/v2/forms.py | 2 +- django_pydantic_field/v2/rest_framework.py | 10 +-- django_pydantic_field/v2/types.py | 3 +- pyproject.toml | 2 +- 10 files changed, 276 insertions(+), 31 deletions(-) create mode 100644 django_pydantic_field/forms.pyi create mode 100644 django_pydantic_field/rest_framework.pyi diff --git a/django_pydantic_field/compat/deprecation.py b/django_pydantic_field/compat/deprecation.py index fc156a0..3758164 100644 --- a/django_pydantic_field/compat/deprecation.py +++ b/django_pydantic_field/compat/deprecation.py @@ -3,7 +3,7 @@ import typing as ty import warnings -_NOT_PROVIDED = object() +_MISSING = object() _DEPRECATED_KWARGS = ( "allow_nan", "indent", @@ -18,6 +18,6 @@ def truncate_deprecated_v1_export_kwargs(kwargs: dict[str, ty.Any]) -> None: for kwarg in _DEPRECATED_KWARGS: - maybe_present_kwarg = kwargs.pop(kwarg, _NOT_PROVIDED) - if maybe_present_kwarg is not _NOT_PROVIDED: - warnings.warn(_DEPRECATED_KWARGS_MESSAGE % kwarg, DeprecationWarning, stacklevel=2) + maybe_present_kwarg = kwargs.pop(kwarg, _MISSING) + if maybe_present_kwarg is not _MISSING: + warnings.warn(_DEPRECATED_KWARGS_MESSAGE % (kwarg,), DeprecationWarning, stacklevel=2) diff --git a/django_pydantic_field/fields.pyi b/django_pydantic_field/fields.pyi index f0230b0..622427b 100644 --- a/django_pydantic_field/fields.pyi +++ b/django_pydantic_field/fields.pyi @@ -1,15 +1,22 @@ from __future__ import annotations +import json import typing as ty +import typing_extensions as te -from pydantic import BaseModel, ConfigDict -from pydantic.dataclasses import DataclassClassOrWrapper +import typing_extensions as te +from pydantic import BaseConfig, BaseModel, ConfigDict + +try: + from pydantic.dataclasses import DataclassClassOrWrapper as PydanticDataclass +except ImportError: + from pydantic._internal._dataclasses import PydanticDataclass as PydanticDataclass __all__ = ("SchemaField",) SchemaT: ty.TypeAlias = ty.Union[ BaseModel, - DataclassClassOrWrapper, + PydanticDataclass, ty.Sequence[ty.Any], ty.Mapping[str, ty.Any], ty.Set[ty.Any], @@ -17,27 +24,83 @@ SchemaT: ty.TypeAlias = ty.Union[ ] OptSchemaT: ty.TypeAlias = ty.Optional[SchemaT] ST = ty.TypeVar("ST", bound=SchemaT) +IncEx = ty.Union[ty.Set[int], ty.Set[str], ty.Dict[int, ty.Any], ty.Dict[str, ty.Any]] +ConfigType = ty.Union[ConfigDict, ty.Type[BaseConfig], type] + +class _FieldKwargs(te.TypedDict, total=False): + name: str | None + verbose_name: str | None + primary_key: bool + max_length: int | None + unique: bool + blank: bool + db_index: bool + rel: ty.Any + editable: bool + serialize: bool + unique_for_date: str | None + unique_for_month: str | None + unique_for_year: str | None + choices: ty.Sequence[ty.Tuple[str, str]] | None + help_text: str | None + db_column: str | None + db_tablespace: str | None + auto_created: bool + validators: ty.Sequence[ty.Callable] | None + error_messages: ty.Mapping[str, str] | None + db_comment: str | None + +class _JSONFieldKwargs(_FieldKwargs, total=False): + encoder: ty.Callable[[], json.JSONEncoder] + decoder: ty.Callable[[], json.JSONDecoder] +class _ExportKwargs(te.TypedDict, total=False): + strict: bool + from_attributes: bool + mode: te.Literal["json", "python"] + include: IncEx | None + exclude: IncEx | None + by_alias: bool + exclude_unset: bool + exclude_defaults: bool + exclude_none: bool + round_trip: bool + warnings: bool + +class _SchemaFieldKwargs(_JSONFieldKwargs, _ExportKwargs, total=False): ... + +class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs, total=False): + allow_nan: ty.Any + indent: ty.Any + separators: ty.Any + skipkeys: ty.Any + sort_keys: ty.Any @ty.overload def SchemaField( - schema: type[ST | None] | ty.ForwardRef = ..., - config: ConfigDict = ..., + schema: ty.Type[ST | None] | ty.ForwardRef = ..., + config: ConfigType = ..., default: OptSchemaT | ty.Callable[[], OptSchemaT] = ..., *args, null: ty.Literal[True], - **kwargs, -) -> ST | None: - ... - - + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> ST | None: ... @ty.overload def SchemaField( - schema: type[ST] | ty.ForwardRef = ..., - config: ConfigDict = ..., + schema: ty.Type[ST] | ty.ForwardRef = ..., + config: ConfigType = ..., default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ..., *args, null: ty.Literal[False] = ..., - **kwargs, -) -> ST: - ... + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> ST: ... +@ty.overload +@te.deprecated("Passing `json.dumps` kwargs to `SchemaField` is not supported by Pydantic v2 and will be removed in the future versions.") +def SchemaField( + schema: ty.Type[ST] | ty.ForwardRef = ..., + config: ConfigType = ..., + default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ..., + *args, + null: bool = ..., + **kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs], +) -> ST | None: ... diff --git a/django_pydantic_field/forms.pyi b/django_pydantic_field/forms.pyi new file mode 100644 index 0000000..c1b2272 --- /dev/null +++ b/django_pydantic_field/forms.pyi @@ -0,0 +1,63 @@ +from __future__ import annotations + +import json +import typing as ty +import typing_extensions as te + +from django.forms.fields import JSONField +from django.forms.widgets import Widget +from django.utils.functional import _StrOrPromise + +from .fields import ST, ConfigType, _ExportKwargs + +__all__ = ("SchemaField",) + +class _FieldKwargs(ty.TypedDict, total=False): + required: bool + widget: Widget | type[Widget] | None + label: _StrOrPromise | None + initial: ty.Any | None + help_text: _StrOrPromise + error_messages: ty.Mapping[str, _StrOrPromise] | None + show_hidden_initial: bool + validators: ty.Sequence[ty.Callable[[ty.Any], None]] + localize: bool + disabled: bool + label_suffix: str | None + +class _CharFieldKwargs(_FieldKwargs, total=False): + max_length: int | None + min_length: int | None + strip: bool + empty_value: ty.Any + +class _JSONFieldKwargs(_CharFieldKwargs, total=False): + encoder: ty.Callable[[], json.JSONEncoder] | None + decoder: ty.Callable[[], json.JSONDecoder] | None + +class _SchemaFieldKwargs(_ExportKwargs, _JSONFieldKwargs, total=False): + allow_null: bool | None + +class SchemaField(JSONField, ty.Generic[ST]): + @ty.overload + def __init__( + self, + schema: ty.Type[ST] | ty.ForwardRef | str, + config: ConfigType | None = ..., + *args, + **kwargs: te.Unpack[_SchemaFieldKwargs], + ) -> None: ... + @ty.overload + @te.deprecated("Passing `json.dumps` kwargs to `SchemaField` is not supported by Pydantic v2 and will be removed in the future versions.") + def __init__( + self, + schema: ty.Type[ST] | ty.ForwardRef | str, + config: ConfigType | None = ..., + *args, + allow_nan: ty.Any = ..., + indent: ty.Any = ..., + separators: ty.Any = ..., + skipkeys: ty.Any = ..., + sort_keys: ty.Any = ..., + **kwargs: te.Unpack[_SchemaFieldKwargs], + ) -> None: ... diff --git a/django_pydantic_field/rest_framework.pyi b/django_pydantic_field/rest_framework.pyi new file mode 100644 index 0000000..dbf0d90 --- /dev/null +++ b/django_pydantic_field/rest_framework.pyi @@ -0,0 +1,60 @@ +import typing as ty +import typing_extensions as te + +from rest_framework import parsers, renderers +from rest_framework.fields import _DefaultInitial, Field +from rest_framework.validators import Validator + +from django.utils.functional import _StrOrPromise + +from .fields import ST, ConfigType, _ExportKwargs + +__all__ = ("SchemaField", "SchemaParser", "SchemaRenderer") + +class _FieldKwargs(te.TypedDict, ty.Generic[ST], total=False): + read_only: bool + write_only: bool + required: bool + default: _DefaultInitial[ST] + initial: _DefaultInitial[ST] + source: str + label: _StrOrPromise + help_text: _StrOrPromise + style: dict[str, ty.Any] + error_messages: dict[str, _StrOrPromise] + validators: ty.Sequence[Validator[ST]] + allow_null: bool + +class _SchemaFieldKwargs(_FieldKwargs[ST], _ExportKwargs, total=False): ... + +class SchemaField(Field, ty.Generic[ST]): + @ty.overload + def __init__( + self, + schema: ty.Type[ST] | ty.ForwardRef | str, + config: ConfigType | None = ..., + *args, + **kwargs: te.Unpack[_SchemaFieldKwargs[ST]], + ) -> None: ... + @ty.overload + @te.deprecated("Passing `json.dumps` kwargs to `SchemaField` is not supported by Pydantic v2 and will be removed in the future versions.") + def __init__( + self, + schema: ty.Type[ST] | ty.ForwardRef | str, + config: ConfigType | None = ..., + *args, + allow_nan: ty.Any = ..., + indent: ty.Any = ..., + separators: ty.Any = ..., + skipkeys: ty.Any = ..., + sort_keys: ty.Any = ..., + **kwargs: te.Unpack[_SchemaFieldKwargs], + ) -> None: ... + +class SchemaParser(parsers.JSONParser, ty.Generic[ST]): + schema_context_key: ty.ClassVar[str] + config_context_key: ty.ClassVar[str] + +class SchemaRenderer(renderers.JSONRenderer, ty.Generic[ST]): + schema_context_key: ty.ClassVar[str] + config_context_key: ty.ClassVar[str] diff --git a/django_pydantic_field/v1/fields.py b/django_pydantic_field/v1/fields.py index b9cc59b..5dcf21f 100644 --- a/django_pydantic_field/v1/fields.py +++ b/django_pydantic_field/v1/fields.py @@ -192,6 +192,6 @@ def SchemaField( ... -def SchemaField(schema=None, config=None, *args, **kwargs) -> t.Any: - kwargs.update(schema=schema, config=config) +def SchemaField(schema=None, config=None, default=None, *args, **kwargs) -> t.Any: + kwargs.update(schema=schema, config=config, default=default) return PydanticSchemaField(*args, **kwargs) diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 43c1041..fca96f6 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -15,6 +15,40 @@ from ..compat.django import GenericContainer from . import forms, types +if ty.TYPE_CHECKING: + import json + import typing_extensions as te + + class _SchemaFieldKwargs(types.ExportKwargs, total=False): + # django.db.models.fields.Field kwargs + name: str | None + verbose_name: str | None + primary_key: bool + max_length: int | None + unique: bool + blank: bool + db_index: bool + rel: ty.Any + editable: bool + serialize: bool + unique_for_date: str | None + unique_for_month: str | None + unique_for_year: str | None + choices: ty.Sequence[ty.Tuple[str, str]] | None + help_text: str | None + db_column: str | None + db_tablespace: str | None + auto_created: bool + validators: ty.Sequence[ty.Callable] | None + error_messages: ty.Mapping[str, str] | None + db_comment: str | None + # django.db.models.fields.json.JSONField kwargs + encoder: ty.Callable[[], json.JSONEncoder] + decoder: ty.Callable[[], json.JSONDecoder] + + +__all__ = ("SchemaField",) + class SchemaAttribute(DeferredAttribute): field: PydanticSchemaField @@ -168,6 +202,30 @@ def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None: return self.transform(col, *args, **kwargs) -def SchemaField(schema=None, config=None, *args, **kwargs): # type: ignore +@ty.overload +def SchemaField( + schema: type[types.ST | None] | ty.ForwardRef = ..., + config: pydantic.ConfigDict = ..., + default: types.SchemaT | None | ty.Callable[[], types.SchemaT | None] = ..., + *args, + null: ty.Literal[True], + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> types.ST | None: + ... + + +@ty.overload +def SchemaField( + schema: type[types.ST] | ty.ForwardRef = ..., + config: pydantic.ConfigDict = ..., + default: ty.Union[types.SchemaT, ty.Callable[[], types.SchemaT]] = ..., + *args, + null: ty.Literal[False] = ..., + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> types.ST: + ... + + +def SchemaField(schema=None, config=None, default=None, *args, **kwargs): # type: ignore truncate_deprecated_v1_export_kwargs(kwargs) - return PydanticSchemaField(*args, schema=schema, config=config, **kwargs) + return PydanticSchemaField(*args, schema=schema, config=config, default=default, **kwargs) diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py index b162a1b..d2fb06c 100644 --- a/django_pydantic_field/v2/forms.py +++ b/django_pydantic_field/v2/forms.py @@ -19,7 +19,7 @@ class SchemaField(JSONField, ty.Generic[types.ST]): def __init__( self, - schema: types.ST, + schema: type[types.ST] | ty.ForwardRef | str, config: pydantic.ConfigDict | None = None, allow_null: bool | None = None, *args, diff --git a/django_pydantic_field/v2/rest_framework.py b/django_pydantic_field/v2/rest_framework.py index 8436590..e8087e4 100644 --- a/django_pydantic_field/v2/rest_framework.py +++ b/django_pydantic_field/v2/rest_framework.py @@ -4,6 +4,7 @@ import pydantic from rest_framework import exceptions, fields, parsers, renderers +from rest_framework.schemas import coreapi from ..compat.deprecation import truncate_deprecated_v1_export_kwargs from ..compat.typing import get_args @@ -89,7 +90,7 @@ def _make_adapter_from_annotation(self, ctx: RequestResponseContext) -> types.Sc class SchemaRenderer(_AnnotatedAdapterMixin[types.ST], renderers.JSONRenderer): schema_context_key = "renderer_schema" - config_context_key = "renderer_schema_config" + config_context_key = "renderer_config" def render(self, data: ty.Any, accepted_media_type=None, renderer_context=None): renderer_context = renderer_context or {} @@ -121,7 +122,7 @@ def render_pydantic_model(self, instance: pydantic.BaseModel, renderer_context: class SchemaParser(_AnnotatedAdapterMixin[types.ST], parsers.JSONParser): schema_context_key = "parser_schema" - config_context_key = "parser_schema_config" + config_context_key = "parser_config" renderer_class = SchemaRenderer def parse(self, stream: ty.IO[bytes], media_type=None, parser_context=None): @@ -136,6 +137,5 @@ def parse(self, stream: ty.IO[bytes], media_type=None, parser_context=None): raise exceptions.ParseError(exc.errors()) # type: ignore -class AutoSchema: - def __init__(*args, **kwargs): - ... +class AutoSchema(coreapi.AutoSchema): + """Not implemented yet.""" diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 1aaa787..c1bed91 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -104,12 +104,13 @@ def is_bound(self) -> bool: """Return True if the adapter is bound to a specific attribute of a `parent_type`.""" return self.parent_type is not None and self.attname is not None - def bind(self, parent_type: type, attname: str) -> None: + def bind(self, parent_type: type | None, attname: str | None) -> te.Self: """Bind the adapter to specific attribute of a `parent_type`.""" self.parent_type = parent_type self.attname = attname self.__dict__.pop("prepared_schema", None) self.__dict__.pop("type_adapter", None) + return self def validate_schema(self) -> None: """Validate the schema and raise an exception if it is invalid.""" diff --git a/pyproject.toml b/pyproject.toml index 30258ce..2e31739 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,7 @@ dev = [ "black", "isort", "mypy", - "pytest==7.0.*", + "pytest~=7.4", "djangorestframework>=3.11,<4", "django-stubs[compatible-mypy]~=4.2", "djangorestframework-stubs[compatible-mypy]~=3.14", From 78b0a01c1f13c7ee6e392d0ae9f6fd3df2752ce7 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 28 Nov 2023 22:46:55 +0400 Subject: [PATCH 23/34] Update project layout, add new packages to setuptools discovery. --- django_pydantic_field/v2/fields.py | 2 +- pyproject.toml | 10 ++++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index fca96f6..4888deb 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -226,6 +226,6 @@ def SchemaField( ... -def SchemaField(schema=None, config=None, default=None, *args, **kwargs): # type: ignore +def SchemaField(schema=None, config=None, default=NOT_PROVIDED, *args, **kwargs): # type: ignore truncate_deprecated_v1_export_kwargs(kwargs) return PydanticSchemaField(*args, schema=schema, config=config, default=default, **kwargs) diff --git a/pyproject.toml b/pyproject.toml index 2e31739..bfd89e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "django-pydantic-field" -version = "0.3.0-alpha1" +version = "0.3.0-alpha2" description = "Django JSONField with Pydantic models as a Schema" readme = "README.md" license = { file = "LICENSE" } @@ -47,6 +47,7 @@ dependencies = [ [project.optional-dependencies] openapi = ["uritemplate"] dev = [ + "build", "black", "isort", "mypy", @@ -75,7 +76,12 @@ Source = "https://github.com/surenkov/django-pydantic-field" Changelog = "https://github.com/surenkov/django-pydantic-field/releases" [tool.setuptools] -packages = ["django_pydantic_field"] +packages = [ + "django_pydantic_field", + "django_pydantic_field.compat", + "django_pydantic_field.v1", + "django_pydantic_field.v2", +] [tool.isort] py_version = 311 From 2ec096f6cf44cb937d42864b0af2e0302a405a49 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 28 Nov 2023 23:07:47 +0400 Subject: [PATCH 24/34] Update pyproject.toml to make sure of package contents. --- .github/workflows/python-publish.yml | 2 +- pyproject.toml | 10 +--------- 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index bf6b02e..ba8d668 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -23,7 +23,7 @@ jobs: - name: Build package run: python -m build - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 + uses: pypa/gh-action-pypi-publish@v1.8.10 with: user: __token__ password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/pyproject.toml b/pyproject.toml index bfd89e5..d3e5a5c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "django-pydantic-field" -version = "0.3.0-alpha2" +version = "0.3.0-alpha3" description = "Django JSONField with Pydantic models as a Schema" readme = "README.md" license = { file = "LICENSE" } @@ -75,14 +75,6 @@ Documentation = "https://github.com/surenkov/django-pydantic-field" Source = "https://github.com/surenkov/django-pydantic-field" Changelog = "https://github.com/surenkov/django-pydantic-field/releases" -[tool.setuptools] -packages = [ - "django_pydantic_field", - "django_pydantic_field.compat", - "django_pydantic_field.v1", - "django_pydantic_field.v2", -] - [tool.isort] py_version = 311 profile = "black" From 35b8ecd4415133550d303b2dc1c24836ea2a5649 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 29 Nov 2023 00:19:32 +0400 Subject: [PATCH 25/34] Tweak type stubs for deprecated arguments. --- django_pydantic_field/fields.pyi | 18 ++++++++++++++---- django_pydantic_field/forms.pyi | 18 +++++++++++------- django_pydantic_field/rest_framework.pyi | 19 +++++++++++-------- pyproject.toml | 2 +- 4 files changed, 37 insertions(+), 20 deletions(-) diff --git a/django_pydantic_field/fields.pyi b/django_pydantic_field/fields.pyi index 622427b..20aa9ad 100644 --- a/django_pydantic_field/fields.pyi +++ b/django_pydantic_field/fields.pyi @@ -78,7 +78,7 @@ class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs, total=False): @ty.overload def SchemaField( - schema: ty.Type[ST | None] | ty.ForwardRef = ..., + schema: ty.Type[ST] | None | ty.ForwardRef = ..., config: ConfigType = ..., default: OptSchemaT | ty.Callable[[], OptSchemaT] = ..., *args, @@ -95,12 +95,22 @@ def SchemaField( **kwargs: te.Unpack[_SchemaFieldKwargs], ) -> ST: ... @ty.overload -@te.deprecated("Passing `json.dumps` kwargs to `SchemaField` is not supported by Pydantic v2 and will be removed in the future versions.") +@te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.") def SchemaField( - schema: ty.Type[ST] | ty.ForwardRef = ..., + schema: ty.Type[ST] | None | ty.ForwardRef = ..., config: ConfigType = ..., default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ..., *args, - null: bool = ..., + null: ty.Literal[True], **kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs], ) -> ST | None: ... +@ty.overload +@te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.") +def SchemaField( + schema: ty.Type[ST] | ty.ForwardRef = ..., + config: ConfigType = ..., + default: ty.Union[SchemaT, ty.Callable[[], SchemaT]] = ..., + *args, + null: ty.Literal[False] = ..., + **kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs], +) -> ST: ... diff --git a/django_pydantic_field/forms.pyi b/django_pydantic_field/forms.pyi index c1b2272..646f3a7 100644 --- a/django_pydantic_field/forms.pyi +++ b/django_pydantic_field/forms.pyi @@ -38,6 +38,15 @@ class _JSONFieldKwargs(_CharFieldKwargs, total=False): class _SchemaFieldKwargs(_ExportKwargs, _JSONFieldKwargs, total=False): allow_null: bool | None + +class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs, total=False): + allow_nan: ty.Any + indent: ty.Any + separators: ty.Any + skipkeys: ty.Any + sort_keys: ty.Any + + class SchemaField(JSONField, ty.Generic[ST]): @ty.overload def __init__( @@ -48,16 +57,11 @@ class SchemaField(JSONField, ty.Generic[ST]): **kwargs: te.Unpack[_SchemaFieldKwargs], ) -> None: ... @ty.overload - @te.deprecated("Passing `json.dumps` kwargs to `SchemaField` is not supported by Pydantic v2 and will be removed in the future versions.") + @te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.") def __init__( self, schema: ty.Type[ST] | ty.ForwardRef | str, config: ConfigType | None = ..., *args, - allow_nan: ty.Any = ..., - indent: ty.Any = ..., - separators: ty.Any = ..., - skipkeys: ty.Any = ..., - sort_keys: ty.Any = ..., - **kwargs: te.Unpack[_SchemaFieldKwargs], + **kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs], ) -> None: ... diff --git a/django_pydantic_field/rest_framework.pyi b/django_pydantic_field/rest_framework.pyi index dbf0d90..e4a3b87 100644 --- a/django_pydantic_field/rest_framework.pyi +++ b/django_pydantic_field/rest_framework.pyi @@ -25,7 +25,15 @@ class _FieldKwargs(te.TypedDict, ty.Generic[ST], total=False): validators: ty.Sequence[Validator[ST]] allow_null: bool -class _SchemaFieldKwargs(_FieldKwargs[ST], _ExportKwargs, total=False): ... +class _SchemaFieldKwargs(_FieldKwargs[ST], _ExportKwargs, total=False): + pass + +class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs[ST], total=False): + allow_nan: ty.Any + indent: ty.Any + separators: ty.Any + skipkeys: ty.Any + sort_keys: ty.Any class SchemaField(Field, ty.Generic[ST]): @ty.overload @@ -37,18 +45,13 @@ class SchemaField(Field, ty.Generic[ST]): **kwargs: te.Unpack[_SchemaFieldKwargs[ST]], ) -> None: ... @ty.overload - @te.deprecated("Passing `json.dumps` kwargs to `SchemaField` is not supported by Pydantic v2 and will be removed in the future versions.") + @te.deprecated("Passing `json.dump` kwargs to `SchemaField` is not supported by Pydantic 2 and will be removed in the future versions.") def __init__( self, schema: ty.Type[ST] | ty.ForwardRef | str, config: ConfigType | None = ..., *args, - allow_nan: ty.Any = ..., - indent: ty.Any = ..., - separators: ty.Any = ..., - skipkeys: ty.Any = ..., - sort_keys: ty.Any = ..., - **kwargs: te.Unpack[_SchemaFieldKwargs], + **kwargs: te.Unpack[_DeprecatedSchemaFieldKwargs[ST]], ) -> None: ... class SchemaParser(parsers.JSONParser, ty.Generic[ST]): diff --git a/pyproject.toml b/pyproject.toml index d3e5a5c..1a73683 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,7 +89,7 @@ skip_glob = [ ] [tool.black] -target-version = ['py312'] +target-version = ["py38", "py39", "py310", "py311", "py312"] line-length = 120 exclude = ''' /( From 435b94056b406382a610863fe7a0f1ab5aca2af7 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Sun, 3 Dec 2023 21:14:53 +0400 Subject: [PATCH 26/34] Split `.rest_framework` module into package. [ci skip] --- django_pydantic_field/v2/rest_framework.py | 141 ------------------ .../v2/rest_framework/__init__.py | 20 +++ .../v2/rest_framework/coreapi.py | 5 + .../v2/rest_framework/fields.py | 55 +++++++ .../v2/rest_framework/mixins.py | 42 ++++++ .../v2/rest_framework/openapi.py | 5 + .../v2/rest_framework/parsers.py | 26 ++++ .../v2/rest_framework/renderers.py | 48 ++++++ tests/v2/test_rest_framework.py | 5 +- 9 files changed, 204 insertions(+), 143 deletions(-) delete mode 100644 django_pydantic_field/v2/rest_framework.py create mode 100644 django_pydantic_field/v2/rest_framework/__init__.py create mode 100644 django_pydantic_field/v2/rest_framework/coreapi.py create mode 100644 django_pydantic_field/v2/rest_framework/fields.py create mode 100644 django_pydantic_field/v2/rest_framework/mixins.py create mode 100644 django_pydantic_field/v2/rest_framework/openapi.py create mode 100644 django_pydantic_field/v2/rest_framework/parsers.py create mode 100644 django_pydantic_field/v2/rest_framework/renderers.py diff --git a/django_pydantic_field/v2/rest_framework.py b/django_pydantic_field/v2/rest_framework.py deleted file mode 100644 index e8087e4..0000000 --- a/django_pydantic_field/v2/rest_framework.py +++ /dev/null @@ -1,141 +0,0 @@ -from __future__ import annotations - -import typing as ty - -import pydantic -from rest_framework import exceptions, fields, parsers, renderers -from rest_framework.schemas import coreapi - -from ..compat.deprecation import truncate_deprecated_v1_export_kwargs -from ..compat.typing import get_args -from . import types - -if ty.TYPE_CHECKING: - from collections.abc import Mapping - - from rest_framework.serializers import BaseSerializer - - RequestResponseContext = Mapping[str, ty.Any] - - -class SchemaField(fields.Field, ty.Generic[types.ST]): - adapter: types.SchemaAdapter - - def __init__( - self, - schema: type[types.ST], - config: pydantic.ConfigDict | None = None, - *args, - allow_null: bool = False, - **kwargs, - ): - truncate_deprecated_v1_export_kwargs(kwargs) - self.schema = schema - self.config = config - self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) - self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null, **self.export_kwargs) - super().__init__(*args, **kwargs) - - def bind(self, field_name: str, parent: BaseSerializer): - if not self.adapter.is_bound: - self.adapter.bind(type(parent), field_name) - super().bind(field_name, parent) - - def to_internal_value(self, data: ty.Any): - try: - if isinstance(data, (str, bytes)): - return self.adapter.validate_json(data) - return self.adapter.validate_python(data) - except pydantic.ValidationError as exc: - raise exceptions.ValidationError(exc.errors(), code="invalid") # type: ignore - - def to_representation(self, value: ty.Optional[types.ST]): - try: - prep_value = self.adapter.validate_python(value) - return self.adapter.dump_python(prep_value) - except pydantic.ValidationError as exc: - raise exceptions.ValidationError(exc.errors(), code="invalid") # type: ignore - - -class _AnnotatedAdapterMixin(ty.Generic[types.ST]): - schema_context_key: ty.ClassVar[str] = "response_schema" - config_context_key: ty.ClassVar[str] = "response_schema_config" - - def get_adapter(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: - adapter = self._make_adapter_from_context(ctx) - if adapter is None: - adapter = self._make_adapter_from_annotation(ctx) - - return adapter - - def _make_adapter_from_context(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: - schema = ctx.get(self.schema_context_key) - if schema is not None: - config = ctx.get(self.config_context_key) - export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(ctx)) - return types.SchemaAdapter(schema, config, type(ctx.get("view")), None, **export_kwargs) - - return schema - - def _make_adapter_from_annotation(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: - try: - schema = get_args(self.__orig_class__)[0] # type: ignore - except (AttributeError, IndexError): - return None - - config = ctx.get(self.config_context_key) - export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(ctx)) - return types.SchemaAdapter(schema, config, type(ctx.get("view")), None, **export_kwargs) - - -class SchemaRenderer(_AnnotatedAdapterMixin[types.ST], renderers.JSONRenderer): - schema_context_key = "renderer_schema" - config_context_key = "renderer_config" - - def render(self, data: ty.Any, accepted_media_type=None, renderer_context=None): - renderer_context = renderer_context or {} - response = renderer_context.get("response") - if response is not None and response.exception: - return super().render(data, accepted_media_type, renderer_context) - - adapter = self.get_adapter(renderer_context) - if adapter is None and isinstance(data, pydantic.BaseModel): - return self.render_pydantic_model(data, renderer_context) - if adapter is None: - raise RuntimeError("Schema should be either explicitly set with annotation or passed in the context") - - try: - prep_data = adapter.validate_python(data) - return adapter.dump_json(prep_data) - except pydantic.ValidationError as exc: - return exc.json(indent=True, include_input=True).encode() - - def render_pydantic_model(self, instance: pydantic.BaseModel, renderer_context: Mapping[str, ty.Any]): - export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(renderer_context)) - export_kwargs.pop("strict", None) - export_kwargs.pop("from_attributes", None) - export_kwargs.pop("mode", None) - - json_dump = instance.model_dump_json(**export_kwargs) # type: ignore - return json_dump.encode() - - -class SchemaParser(_AnnotatedAdapterMixin[types.ST], parsers.JSONParser): - schema_context_key = "parser_schema" - config_context_key = "parser_config" - renderer_class = SchemaRenderer - - def parse(self, stream: ty.IO[bytes], media_type=None, parser_context=None): - parser_context = parser_context or {} - adapter = self.get_adapter(parser_context) - if adapter is None: - raise RuntimeError("Schema should be either explicitly set with annotation or passed in the context") - - try: - return adapter.validate_json(stream.read()) - except pydantic.ValidationError as exc: - raise exceptions.ParseError(exc.errors()) # type: ignore - - -class AutoSchema(coreapi.AutoSchema): - """Not implemented yet.""" diff --git a/django_pydantic_field/v2/rest_framework/__init__.py b/django_pydantic_field/v2/rest_framework/__init__.py new file mode 100644 index 0000000..f151648 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/__init__.py @@ -0,0 +1,20 @@ +from .fields import SchemaField as SchemaField +from .parsers import SchemaParser as SchemaParser +from .renderers import SchemaRenderer as SchemaRenderer + +_DEPRECATED_MESSAGE = ( + "`django_pydantic_field.rest_framework.AutoSchema` is deprecated, " + "please use explicit imports for `django_pydantic_field.rest_framework.openapi.AutoSchema` " + "or `django_pydantic_field.rest_framework.coreapi.AutoSchema` instead." +) + +def __getattr__(key): + if key == "AutoSchema": + import warnings + + from .openapi import AutoSchema + + warnings.warn(_DEPRECATED_MESSAGE, DeprecationWarning) + return AutoSchema + + raise AttributeError(key) diff --git a/django_pydantic_field/v2/rest_framework/coreapi.py b/django_pydantic_field/v2/rest_framework/coreapi.py new file mode 100644 index 0000000..7daa1fd --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/coreapi.py @@ -0,0 +1,5 @@ +from rest_framework.schemas import coreapi + + +class AutoSchema(coreapi.AutoSchema): + """Not implemented yet.""" diff --git a/django_pydantic_field/v2/rest_framework/fields.py b/django_pydantic_field/v2/rest_framework/fields.py new file mode 100644 index 0000000..1303389 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/fields.py @@ -0,0 +1,55 @@ +from __future__ import annotations + +import typing as ty + +import pydantic +from rest_framework import exceptions, fields + +from ...compat.deprecation import truncate_deprecated_v1_export_kwargs +from .. import types + +if ty.TYPE_CHECKING: + from collections.abc import Mapping + + from rest_framework.serializers import BaseSerializer + + RequestResponseContext = Mapping[str, ty.Any] + + +class SchemaField(fields.Field, ty.Generic[types.ST]): + adapter: types.SchemaAdapter + + def __init__( + self, + schema: type[types.ST], + config: pydantic.ConfigDict | None = None, + *args, + allow_null: bool = False, + **kwargs, + ): + truncate_deprecated_v1_export_kwargs(kwargs) + self.schema = schema + self.config = config + self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) + self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null, **self.export_kwargs) + super().__init__(*args, **kwargs) + + def bind(self, field_name: str, parent: BaseSerializer): + if not self.adapter.is_bound: + self.adapter.bind(type(parent), field_name) + super().bind(field_name, parent) + + def to_internal_value(self, data: ty.Any): + try: + if isinstance(data, (str, bytes)): + return self.adapter.validate_json(data) + return self.adapter.validate_python(data) + except pydantic.ValidationError as exc: + raise exceptions.ValidationError(exc.errors(), code="invalid") # type: ignore + + def to_representation(self, value: ty.Optional[types.ST]): + try: + prep_value = self.adapter.validate_python(value) + return self.adapter.dump_python(prep_value) + except pydantic.ValidationError as exc: + raise exceptions.ValidationError(exc.errors(), code="invalid") # type: ignore diff --git a/django_pydantic_field/v2/rest_framework/mixins.py b/django_pydantic_field/v2/rest_framework/mixins.py new file mode 100644 index 0000000..6f3d6ae --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/mixins.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +import typing as ty + +from ...compat.typing import get_args +from .. import types + +if ty.TYPE_CHECKING: + from collections.abc import Mapping + + RequestResponseContext = Mapping[str, ty.Any] + + +class AnnotatedAdapterMixin(ty.Generic[types.ST]): + schema_context_key: ty.ClassVar[str] = "response_schema" + config_context_key: ty.ClassVar[str] = "response_schema_config" + + def get_adapter(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: + adapter = self._make_adapter_from_context(ctx) + if adapter is None: + adapter = self._make_adapter_from_annotation(ctx) + + return adapter + + def _make_adapter_from_context(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: + schema = ctx.get(self.schema_context_key) + if schema is not None: + config = ctx.get(self.config_context_key) + export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(ctx)) + return types.SchemaAdapter(schema, config, type(ctx.get("view")), None, **export_kwargs) + + return schema + + def _make_adapter_from_annotation(self, ctx: RequestResponseContext) -> types.SchemaAdapter[types.ST] | None: + try: + schema = get_args(self.__orig_class__)[0] # type: ignore + except (AttributeError, IndexError): + return None + + config = ctx.get(self.config_context_key) + export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(ctx)) + return types.SchemaAdapter(schema, config, type(ctx.get("view")), None, **export_kwargs) diff --git a/django_pydantic_field/v2/rest_framework/openapi.py b/django_pydantic_field/v2/rest_framework/openapi.py new file mode 100644 index 0000000..3299bd5 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/openapi.py @@ -0,0 +1,5 @@ +from rest_framework.schemas import openapi + + +class AutoSchema(openapi.AutoSchema): + """Not implemented yet.""" diff --git a/django_pydantic_field/v2/rest_framework/parsers.py b/django_pydantic_field/v2/rest_framework/parsers.py new file mode 100644 index 0000000..726ad85 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/parsers.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import typing as ty + +import pydantic +from rest_framework import exceptions, parsers + +from .. import types +from . import mixins, renderers + + +class SchemaParser(mixins.AnnotatedAdapterMixin[types.ST], parsers.JSONParser): + schema_context_key = "parser_schema" + config_context_key = "parser_config" + renderer_class = renderers.SchemaRenderer + + def parse(self, stream: ty.IO[bytes], media_type=None, parser_context=None): + parser_context = parser_context or {} + adapter = self.get_adapter(parser_context) + if adapter is None: + raise RuntimeError("Schema should be either explicitly set with annotation or passed in the context") + + try: + return adapter.validate_json(stream.read()) + except pydantic.ValidationError as exc: + raise exceptions.ParseError(exc.errors()) # type: ignore diff --git a/django_pydantic_field/v2/rest_framework/renderers.py b/django_pydantic_field/v2/rest_framework/renderers.py new file mode 100644 index 0000000..820aa54 --- /dev/null +++ b/django_pydantic_field/v2/rest_framework/renderers.py @@ -0,0 +1,48 @@ +from __future__ import annotations + +import typing as ty + +import pydantic +from rest_framework import renderers + +from .. import types +from . import mixins + +if ty.TYPE_CHECKING: + from collections.abc import Mapping + + RequestResponseContext = Mapping[str, ty.Any] + +__all__ = ("SchemaRenderer",) + + +class SchemaRenderer(mixins.AnnotatedAdapterMixin[types.ST], renderers.JSONRenderer): + schema_context_key = "renderer_schema" + config_context_key = "renderer_config" + + def render(self, data: ty.Any, accepted_media_type=None, renderer_context=None): + renderer_context = renderer_context or {} + response = renderer_context.get("response") + if response is not None and response.exception: + return super().render(data, accepted_media_type, renderer_context) + + adapter = self.get_adapter(renderer_context) + if adapter is None and isinstance(data, pydantic.BaseModel): + return self.render_pydantic_model(data, renderer_context) + if adapter is None: + raise RuntimeError("Schema should be either explicitly set with annotation or passed in the context") + + try: + prep_data = adapter.validate_python(data) + return adapter.dump_json(prep_data) + except pydantic.ValidationError as exc: + return exc.json(indent=True, include_input=True).encode() + + def render_pydantic_model(self, instance: pydantic.BaseModel, renderer_context: Mapping[str, ty.Any]): + export_kwargs = types.SchemaAdapter.extract_export_kwargs(dict(renderer_context)) + export_kwargs.pop("strict", None) + export_kwargs.pop("from_attributes", None) + export_kwargs.pop("mode", None) + + json_dump = instance.model_dump_json(**export_kwargs) # type: ignore + return json_dump.encode() diff --git a/tests/v2/test_rest_framework.py b/tests/v2/test_rest_framework.py index 8223fc5..65147b9 100644 --- a/tests/v2/test_rest_framework.py +++ b/tests/v2/test_rest_framework.py @@ -11,6 +11,7 @@ from tests.test_app.models import SampleModel rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") class SampleSerializer(serializers.Serializer): @@ -136,7 +137,7 @@ def test_schema_parser(): @api_view(["POST"]) -@schema(rest_framework.AutoSchema()) +@schema(coreapi.AutoSchema()) @parser_classes([rest_framework.SchemaParser[InnerSchema]]) @renderer_classes([rest_framework.SchemaRenderer[t.List[InnerSchema]]]) def sample_view(request): @@ -146,7 +147,7 @@ def sample_view(request): class ClassBasedViewWithSerializer(generics.RetrieveAPIView): serializer_class = SampleSerializer - schema = rest_framework.AutoSchema() + schema = coreapi.AutoSchema() class ClassBasedViewWithModel(generics.ListCreateAPIView): From 08d5f594640e31ba28e03b7b99d45535c1954dcf Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 5 Dec 2023 02:19:44 +0400 Subject: [PATCH 27/34] Bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 36a1b44..3c79a84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "django-pydantic-field" -version = "0.3.0-alpha3" +version = "0.3.0-alpha4" description = "Django JSONField with Pydantic models as a Schema" readme = "README.md" license = { file = "LICENSE" } From 8089b7e0945e862a8bfe79ec5703985a966d2a7e Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 5 Dec 2023 02:23:16 +0400 Subject: [PATCH 28/34] Add Pydantic :: 2 to classifiers list. [ci skip] --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 57f3fd4..97399d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ classifiers = [ "Framework :: Django :: 5.0", "Framework :: Pydantic", "Framework :: Pydantic :: 1", + "Framework :: Pydantic :: 2", "License :: OSI Approved :: MIT License", "Programming Language :: Python", "Programming Language :: Python :: 3", From 8f4dcc05938bebe7f7fcd237679f2b2ce938d7a0 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 20 Dec 2023 02:54:31 +0400 Subject: [PATCH 29/34] Implement CoreAPI schema generator [ci skip] --- .../v2/rest_framework/coreapi.py | 222 +++++++++++++++++- django_pydantic_field/v2/types.py | 2 +- tests/v2/test_rest_framework.py | 33 ++- 3 files changed, 251 insertions(+), 6 deletions(-) diff --git a/django_pydantic_field/v2/rest_framework/coreapi.py b/django_pydantic_field/v2/rest_framework/coreapi.py index 7daa1fd..5a1412f 100644 --- a/django_pydantic_field/v2/rest_framework/coreapi.py +++ b/django_pydantic_field/v2/rest_framework/coreapi.py @@ -1,5 +1,223 @@ -from rest_framework.schemas import coreapi +from __future__ import annotations +import typing as ty -class AutoSchema(coreapi.AutoSchema): +from rest_framework.compat import coreapi, coreschema +from rest_framework.schemas.coreapi import AutoSchema as _CoreAPIAutoSchema + +from .fields import SchemaField + +if ty.TYPE_CHECKING: + from coreschema.schemas import Schema as _CoreAPISchema + from rest_framework.serializers import Serializer + +__all__ = ("AutoSchema",) + + +class AutoSchema(_CoreAPIAutoSchema): """Not implemented yet.""" + + def get_serializer_fields(self, path: str, method: str) -> list[coreapi.Field]: + base_field_schemas = super().get_serializer_fields(path, method) + if not base_field_schemas: + return [] + + serializer: Serializer = self.view.get_serializer() + pydantic_schema_fields: dict[str, coreapi.Field] = {} + + for field_name, field in serializer.fields.items(): + if not field.read_only and isinstance(field, SchemaField): + pydantic_schema_fields[field_name] = self._prepare_schema_field(field) + + if not pydantic_schema_fields: + return base_field_schemas + + return [pydantic_schema_fields.get(field.name, field) for field in base_field_schemas] + + def _prepare_schema_field(self, field: SchemaField) -> coreapi.Field: + build_core_schema = SimpleCoreSchemaTransformer(field.adapter.json_schema()) + return coreapi.Field( + name=field.field_name, + location="form", + required=field.required, + schema=build_core_schema(), + description=field.help_text, + ) + + +class SimpleCoreSchemaTransformer: + def __init__(self, json_schema: dict[str, ty.Any]): + self.root_schema = json_schema + + def __call__(self) -> _CoreAPISchema: + definitions = self._populate_definitions() + root_schema = self._transform(self.root_schema) + + if definitions: + if isinstance(root_schema, coreschema.Ref): + schema_name = root_schema.ref_name + else: + schema_name = root_schema.title or "Schema" + definitions[schema_name] = root_schema + + root_schema = coreschema.RefSpace(definitions, schema_name) + + return root_schema + + def _populate_definitions(self): + schemas = self.root_schema.get("$defs", {}) + return {ref_name: self._transform(schema) for ref_name, schema in schemas.items()} + + def _transform(self, schema) -> _CoreAPISchema: + schemas = [ + *self._transform_type_schema(schema), + *self._transform_composite_types(schema), + *self._transform_ref(schema), + ] + if not schemas: + schema = self._transform_any(schema) + elif len(schemas) == 1: + schema = schemas[0] + else: + schema = coreschema.Intersection(schemas) + return schema + + def _transform_type_schema(self, schema): + schema_type = schema.get("type", None) + + if schema_type is not None: + schema_types = schema_type if isinstance(schema_type, list) else [schema_type] + + for schema_type in schema_types: + transformer = getattr(self, f"transform_{schema_type}") + yield transformer(schema) + + def _transform_composite_types(self, schema): + for operation, transform_name in self.COMBINATOR_TYPES.items(): + value = schema.get(operation, None) + + if value is not None: + transformer = getattr(self, transform_name) + yield transformer(schema) + + def _transform_ref(self, schema): + reference = schema.get("$ref", None) + if reference is not None: + yield coreschema.Ref(reference) + + def _transform_any(self, schema): + attrs = self._get_common_attributes(schema) + return coreschema.Anything(**attrs) + + # Simple types transformers + + def transform_object(self, schema) -> coreschema.Object: + properties = schema.get("properties", None) + if properties is not None: + properties = {prop: self._transform(prop_schema) for prop, prop_schema in properties.items()} + + pattern_props = schema.get("patternProperties", None) + if pattern_props is not None: + pattern_props = {pattern: self._transform(prop_schema) for pattern, prop_schema in pattern_props.items()} + + extra_props = schema.get("additionalProperties", None) + if extra_props is not None: + if extra_props not in (True, False): + extra_props = self._transform(schema) + + return coreschema.Object( + properties=properties, + pattern_properties=pattern_props, + additional_properties=extra_props, # type: ignore + min_properties=schema.get("minProperties"), + max_properties=schema.get("maxProperties"), + required=schema.get("required", []), + **self._get_common_attributes(schema), + ) + + def transform_array(self, schema) -> coreschema.Array: + items = schema.get("items", None) + if items is not None: + if isinstance(items, list): + items = list(map(self._transform, items)) + elif items not in (True, False): + items = self._transform(items) + + extra_items = schema.get("additionalItems") + if extra_items is not None: + if isinstance(items, list): + items = list(map(self._transform, items)) + elif items not in (True, False): + items = self._transform(items) + + return coreschema.Array( + items=items, + additional_items=extra_items, + min_items=schema.get("minItems"), + max_items=schema.get("maxItems"), + unique_items=schema.get("uniqueItems"), + **self._get_common_attributes(schema), + ) + + def transform_boolean(self, schema) -> coreschema.Boolean: + attrs = self._get_common_attributes(schema) + return coreschema.Boolean(**attrs) + + def transform_integer(self, schema) -> coreschema.Integer: + return self._transform_numeric(schema, cls=coreschema.Integer) + + def transform_null(self, schema) -> coreschema.Null: + attrs = self._get_common_attributes(schema) + return coreschema.Null(**attrs) + + def transform_number(self, schema) -> coreschema.Number: + return self._transform_numeric(schema, cls=coreschema.Number) + + def transform_string(self, schema) -> coreschema.String: + return coreschema.String( + min_length=schema.get("minLength"), + max_length=schema.get("maxLength"), + pattern=schema.get("pattern"), + format=schema.get("format"), + **self._get_common_attributes(schema), + ) + + # Composite types transformers + + COMBINATOR_TYPES = { + "anyOf": "transform_union", + "oneOf": "transform_exclusive_union", + "allOf": "transform_intersection", + "not": "transform_not", + } + + def transform_union(self, schema): + return coreschema.Union([self._transform(option) for option in schema["anyOf"]]) + + def transform_exclusive_union(self, schema): + return coreschema.ExclusiveUnion([self._transform(option) for option in schema["oneOf"]]) + + def transform_intersection(self, schema): + return coreschema.Intersection([self._transform(option) for option in schema["allOf"]]) + + def transform_not(self, schema): + return coreschema.Not(self._transform(schema["not"])) + + # Common schema transformations + + def _get_common_attributes(self, schema): + return dict( + title=schema.get("title"), + description=schema.get("description"), + default=schema.get("default"), + ) + + def _transform_numeric(self, schema, cls): + return cls( + minimum=schema.get("minimum"), + maximum=schema.get("maximum"), + exclusive_minimum=schema.get("exclusiveMinimum"), + exclusive_maximum=schema.get("exclusiveMaximum"), + multiple_of=schema.get("multipleOf"), + **self._get_common_attributes(schema), + ) diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index c1bed91..30663af 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -143,7 +143,7 @@ def dump_json(self, value: ty.Any, **override_kwargs: ty.Unpack[ExportKwargs]) - union_kwargs = ChainMap(override_kwargs, self._dump_python_kwargs) # type: ignore return self.type_adapter.dump_json(value, **union_kwargs) - def json_schema(self) -> ty.Any: + def json_schema(self) -> dict[str, ty.Any]: """Return the JSON schema for the field.""" by_alias = self.export_kwargs.get("by_alias", True) return self.type_adapter.json_schema(by_alias=by_alias) diff --git a/tests/v2/test_rest_framework.py b/tests/v2/test_rest_framework.py index 65147b9..e7ad27a 100644 --- a/tests/v2/test_rest_framework.py +++ b/tests/v2/test_rest_framework.py @@ -1,10 +1,13 @@ import io import typing as t from datetime import date +from types import SimpleNamespace import pytest -from rest_framework import exceptions, generics, serializers, views +from django.urls import path +from rest_framework import exceptions, generics, schemas, serializers, views from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema +from rest_framework.request import Request from rest_framework.response import Response from tests.conftest import InnerSchema @@ -136,7 +139,7 @@ def test_schema_parser(): assert parser.parse(io.StringIO(existing_encoded)) == expected_instance -@api_view(["POST"]) +@api_view(["GET", "POST"]) @schema(coreapi.AutoSchema()) @parser_classes([rest_framework.SchemaParser[InnerSchema]]) @renderer_classes([rest_framework.SchemaRenderer[t.List[InnerSchema]]]) @@ -145,7 +148,7 @@ def sample_view(request): return Response([request.data]) -class ClassBasedViewWithSerializer(generics.RetrieveAPIView): +class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): serializer_class = SampleSerializer schema = coreapi.AutoSchema() @@ -216,3 +219,27 @@ def test_end_to_end_list_create_api_view(request_factory): request = request_factory.get("/", content_type="application/json") response = ClassBasedViewWithModel.as_view()(request) assert response.data == [expected_result] + + +urlconf = SimpleNamespace( + urlpatterns=[ + path("/func", sample_view), + path("/class", ClassBasedViewWithSerializer.as_view()), + ], +) + + +@pytest.mark.parametrize( + "method, path", + [ + ("GET", "/func"), + ("POST", "/func"), + ("GET", "/class"), + ("PUT", "/class"), + ], +) +def test_coreapi_schema_generators(request_factory, method, path): + generator = schemas.SchemaGenerator(urlconf=urlconf) + request = Request(request_factory.generic(method, path)) + coreapi_schema = generator.get_schema(request) + assert coreapi_schema From b4ff98ad61ffbc6b3bf6623e14e118c4eef7a3a1 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Thu, 28 Dec 2023 02:25:54 +0400 Subject: [PATCH 30/34] Update out-of-module compat imports --- django_pydantic_field/v1/__init__.py | 2 +- django_pydantic_field/v1/fields.py | 3 ++- django_pydantic_field/v1/rest_framework.py | 2 +- django_pydantic_field/v2/__init__.py | 2 +- django_pydantic_field/v2/fields.py | 7 ++++--- django_pydantic_field/v2/forms.py | 4 ++-- django_pydantic_field/v2/rest_framework/fields.py | 7 ++++--- django_pydantic_field/v2/rest_framework/mixins.py | 4 ++-- django_pydantic_field/v2/types.py | 4 ++-- 9 files changed, 19 insertions(+), 16 deletions(-) diff --git a/django_pydantic_field/v1/__init__.py b/django_pydantic_field/v1/__init__.py index 91323b9..00c87c8 100644 --- a/django_pydantic_field/v1/__init__.py +++ b/django_pydantic_field/v1/__init__.py @@ -1,4 +1,4 @@ -from ..compat.pydantic import PYDANTIC_V1 +from django_pydantic_field.compat.pydantic import PYDANTIC_V1 if not PYDANTIC_V1: raise ImportError("django_pydantic_field.v1 package is only compatible with Pydantic v1") diff --git a/django_pydantic_field/v1/fields.py b/django_pydantic_field/v1/fields.py index 5dcf21f..188ecf4 100644 --- a/django_pydantic_field/v1/fields.py +++ b/django_pydantic_field/v1/fields.py @@ -9,8 +9,9 @@ from django.db.models.fields.json import JSONField from django.db.models.query_utils import DeferredAttribute +from django_pydantic_field.compat.django import GenericContainer, GenericTypes + from . import base, forms, utils -from ..compat.django import GenericContainer, GenericTypes __all__ = ("SchemaField",) diff --git a/django_pydantic_field/v1/rest_framework.py b/django_pydantic_field/v1/rest_framework.py index 234f8b9..465bfe4 100644 --- a/django_pydantic_field/v1/rest_framework.py +++ b/django_pydantic_field/v1/rest_framework.py @@ -8,7 +8,7 @@ from rest_framework.schemas.utils import is_list_view from . import base -from ..compat.typing import get_args +from django_pydantic_field.compat.typing import get_args __all__ = ( "SchemaField", diff --git a/django_pydantic_field/v2/__init__.py b/django_pydantic_field/v2/__init__.py index 91b48f0..c905b8b 100644 --- a/django_pydantic_field/v2/__init__.py +++ b/django_pydantic_field/v2/__init__.py @@ -1,4 +1,4 @@ -from ..compat.pydantic import PYDANTIC_V2 +from django_pydantic_field.compat.pydantic import PYDANTIC_V2 if not PYDANTIC_V2: raise ImportError("django_pydantic_field.v2 package is only compatible with Pydantic v2") diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 4888deb..d8e6a8e 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -11,8 +11,9 @@ from django.db.models.lookups import Transform from django.db.models.query_utils import DeferredAttribute -from ..compat.deprecation import truncate_deprecated_v1_export_kwargs -from ..compat.django import GenericContainer +from django_pydantic_field.compat import deprecation +from django_pydantic_field.compat.django import GenericContainer + from . import forms, types if ty.TYPE_CHECKING: @@ -227,5 +228,5 @@ def SchemaField( def SchemaField(schema=None, config=None, default=NOT_PROVIDED, *args, **kwargs): # type: ignore - truncate_deprecated_v1_export_kwargs(kwargs) + deprecation.truncate_deprecated_v1_export_kwargs(kwargs) return PydanticSchemaField(*args, schema=schema, config=config, default=default, **kwargs) diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py index d2fb06c..b2c49d2 100644 --- a/django_pydantic_field/v2/forms.py +++ b/django_pydantic_field/v2/forms.py @@ -7,7 +7,7 @@ from django.forms.fields import InvalidJSONInput, JSONField, JSONString from django.utils.translation import gettext_lazy as _ -from ..compat.deprecation import truncate_deprecated_v1_export_kwargs +from django_pydantic_field.compat import deprecation from . import types @@ -25,7 +25,7 @@ def __init__( *args, **kwargs, ): - truncate_deprecated_v1_export_kwargs(kwargs) + deprecation.truncate_deprecated_v1_export_kwargs(kwargs) self.schema = schema self.config = config diff --git a/django_pydantic_field/v2/rest_framework/fields.py b/django_pydantic_field/v2/rest_framework/fields.py index 1303389..b15aa43 100644 --- a/django_pydantic_field/v2/rest_framework/fields.py +++ b/django_pydantic_field/v2/rest_framework/fields.py @@ -5,8 +5,8 @@ import pydantic from rest_framework import exceptions, fields -from ...compat.deprecation import truncate_deprecated_v1_export_kwargs -from .. import types +from django_pydantic_field.compat import deprecation +from django_pydantic_field.v2 import types if ty.TYPE_CHECKING: from collections.abc import Mapping @@ -27,7 +27,8 @@ def __init__( allow_null: bool = False, **kwargs, ): - truncate_deprecated_v1_export_kwargs(kwargs) + deprecation.truncate_deprecated_v1_export_kwargs(kwargs) + self.schema = schema self.config = config self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) diff --git a/django_pydantic_field/v2/rest_framework/mixins.py b/django_pydantic_field/v2/rest_framework/mixins.py index 6f3d6ae..47bf75d 100644 --- a/django_pydantic_field/v2/rest_framework/mixins.py +++ b/django_pydantic_field/v2/rest_framework/mixins.py @@ -2,8 +2,8 @@ import typing as ty -from ...compat.typing import get_args -from .. import types +from django_pydantic_field.compat.typing import get_args +from django_pydantic_field.v2 import types if ty.TYPE_CHECKING: from collections.abc import Mapping diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 30663af..7cfa372 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -6,8 +6,8 @@ import pydantic import typing_extensions as te -from ..compat.django import GenericContainer -from ..compat.functools import cached_property +from django_pydantic_field.compat.django import GenericContainer +from django_pydantic_field.compat.functools import cached_property from . import utils if ty.TYPE_CHECKING: From 1dbbb08b71e4577f68f118e0d842bd0e6f4b8967 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Thu, 28 Dec 2023 02:26:16 +0400 Subject: [PATCH 31/34] Add initial OpenAPI schema generator --- .../v2/rest_framework/openapi.py | 74 ++++++++++++++++++- pyproject.toml | 3 +- 2 files changed, 75 insertions(+), 2 deletions(-) diff --git a/django_pydantic_field/v2/rest_framework/openapi.py b/django_pydantic_field/v2/rest_framework/openapi.py index 3299bd5..597cce3 100644 --- a/django_pydantic_field/v2/rest_framework/openapi.py +++ b/django_pydantic_field/v2/rest_framework/openapi.py @@ -1,5 +1,77 @@ +from __future__ import annotations + +import typing as ty + +import pydantic + +from rest_framework import serializers from rest_framework.schemas import openapi +from .fields import SchemaField + +if ty.TYPE_CHECKING: + from pydantic.json_schema import JsonSchemaMode + class AutoSchema(openapi.AutoSchema): - """Not implemented yet.""" + _SCHEMA_REF_TEMPLATE_PREFIX = "#/components/schemas/{model}" + + def __init__(self, tags=None, operation_id_base=None, component_name=None) -> None: + super().__init__(tags, operation_id_base, component_name) + self.collected_schema_defs = {} + + def get_components(self, path: str, method: str) -> dict[str, ty.Any]: + if method.lower() == "delete": + return {} + + request_serializer = self.get_request_serializer(path, method) + response_serializer = self.get_response_serializer(path, method) + components = {} + + if isinstance(request_serializer, serializers.Serializer): + component_name = self.get_component_name(request_serializer) + content = self.map_serializer(request_serializer, "validation") + components.setdefault(component_name, content) + + if isinstance(response_serializer, serializers.Serializer): + component_name = self.get_component_name(response_serializer) + content = self.map_serializer(response_serializer, "serialization") + components.setdefault(component_name, content) + + if self.collected_schema_defs: + components.update(self.collected_schema_defs) + self.collected_schema_defs = {} + + return components + + def map_serializer( + self, + serializer: serializers.Serializer, + mode: JsonSchemaMode = "validation", + ) -> dict[str, ty.Any]: + component_content = super().map_serializer(serializer) + schema_fields_adapters = [] + + for field in serializer.fields.values(): + if isinstance(field, SchemaField): + schema_fields_adapters.append((field.field_name, mode, field.adapter.type_adapter)) + + if schema_fields_adapters: + field_schemas, common_schemas = pydantic.TypeAdapter.json_schemas( + schema_fields_adapters, + ref_template=self._SCHEMA_REF_TEMPLATE_PREFIX, + ) + for (field_name, _), field_schema in field_schemas.items(): + component_content["properties"][field_name] = field_schema + + self.collected_schema_defs.update(common_schemas.get("$defs", {})) + + return component_content + + def map_parsers(self, path: str, method: str) -> list[str]: + # TODO: Implmenent SchemaParser + return super().map_parsers(path, method) + + def map_renderers(self, path: str, method: str) -> list[str]: + # TODO: Implement SchemaRenderer + return super().map_renderers(path, method) diff --git a/pyproject.toml b/pyproject.toml index 97399d9..6e729e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ dependencies = [ [project.optional-dependencies] openapi = ["uritemplate"] +coreapi = ["coreapi"] dev = [ "build", "black", @@ -61,7 +62,7 @@ dev = [ "pytest-django>=4.5,<5", ] test = [ - "django_pydantic_field[openapi]", + "django_pydantic_field[openapi,coreapi]", "dj-database-url~=2.0", "djangorestframework>=3,<4", "pyyaml", From ad15ee16ef1e19f837f3617d4f9eba818972e0b5 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 3 Jan 2024 00:35:04 +0400 Subject: [PATCH 32/34] OpenAPI schema generation for parsers/renderers. --- .../v2/rest_framework/coreapi.py | 2 +- .../v2/rest_framework/mixins.py | 1 + .../v2/rest_framework/openapi.py | 193 ++++++++++++++---- tests/v2/test_rest_framework.py | 2 + 4 files changed, 157 insertions(+), 41 deletions(-) diff --git a/django_pydantic_field/v2/rest_framework/coreapi.py b/django_pydantic_field/v2/rest_framework/coreapi.py index 5a1412f..bc1fe79 100644 --- a/django_pydantic_field/v2/rest_framework/coreapi.py +++ b/django_pydantic_field/v2/rest_framework/coreapi.py @@ -8,7 +8,7 @@ from .fields import SchemaField if ty.TYPE_CHECKING: - from coreschema.schemas import Schema as _CoreAPISchema + from coreschema.schemas import Schema as _CoreAPISchema # type: ignore[import-untyped] from rest_framework.serializers import Serializer __all__ = ("AutoSchema",) diff --git a/django_pydantic_field/v2/rest_framework/mixins.py b/django_pydantic_field/v2/rest_framework/mixins.py index 47bf75d..a8a8c59 100644 --- a/django_pydantic_field/v2/rest_framework/mixins.py +++ b/django_pydantic_field/v2/rest_framework/mixins.py @@ -12,6 +12,7 @@ class AnnotatedAdapterMixin(ty.Generic[types.ST]): + media_type: ty.ClassVar[str] schema_context_key: ty.ClassVar[str] = "response_schema" config_context_key: ty.ClassVar[str] = "response_schema_config" diff --git a/django_pydantic_field/v2/rest_framework/openapi.py b/django_pydantic_field/v2/rest_framework/openapi.py index 597cce3..3c741f5 100644 --- a/django_pydantic_field/v2/rest_framework/openapi.py +++ b/django_pydantic_field/v2/rest_framework/openapi.py @@ -3,75 +3,188 @@ import typing as ty import pydantic - +import weakref from rest_framework import serializers -from rest_framework.schemas import openapi +from rest_framework.schemas import openapi, utils as drf_schema_utils +from rest_framework.test import APIRequestFactory -from .fields import SchemaField +from . import fields, parsers, renderers if ty.TYPE_CHECKING: + from collections.abc import Iterable + from pydantic.json_schema import JsonSchemaMode + from . import mixins + class AutoSchema(openapi.AutoSchema): - _SCHEMA_REF_TEMPLATE_PREFIX = "#/components/schemas/{model}" + REF_TEMPLATE_PREFIX = "#/components/schemas/{model}" def __init__(self, tags=None, operation_id_base=None, component_name=None) -> None: super().__init__(tags, operation_id_base, component_name) - self.collected_schema_defs = {} + self.collected_schema_defs: dict[str, ty.Any] = {} + self.adapter_type_to_schema_refs = weakref.WeakKeyDictionary[type, str]() + self.adapter_mode: JsonSchemaMode = "validation" + self.rf = APIRequestFactory() def get_components(self, path: str, method: str) -> dict[str, ty.Any]: if method.lower() == "delete": return {} - request_serializer = self.get_request_serializer(path, method) - response_serializer = self.get_response_serializer(path, method) - components = {} - - if isinstance(request_serializer, serializers.Serializer): - component_name = self.get_component_name(request_serializer) - content = self.map_serializer(request_serializer, "validation") - components.setdefault(component_name, content) + super().get_components - if isinstance(response_serializer, serializers.Serializer): - component_name = self.get_component_name(response_serializer) - content = self.map_serializer(response_serializer, "serialization") - components.setdefault(component_name, content) + request_serializer = self.get_request_serializer(path, method) # type: ignore[attr-defined] + response_serializer = self.get_response_serializer(path, method) # type: ignore[attr-defined] + components = { + **self._collect_serializer_component(response_serializer, "serialization"), + **self._collect_serializer_component(request_serializer, "validation"), + } if self.collected_schema_defs: components.update(self.collected_schema_defs) self.collected_schema_defs = {} return components - def map_serializer( - self, - serializer: serializers.Serializer, - mode: JsonSchemaMode = "validation", - ) -> dict[str, ty.Any]: + def get_request_body(self, path, method): + if method not in ("PUT", "PATCH", "POST"): + return {} + + self.request_media_types = self.map_parsers(path, method) + + request_schema = {} + serializer = self.get_request_serializer(path, method) + if isinstance(serializer, serializers.Serializer): + request_schema = self.get_reference(serializer) + + schema_content = {} + + for parser, ct in zip(self.view.parser_classes, self.request_media_types): + if issubclass(parser, parsers.SchemaParser): + ref_path = self._get_component_ref(self.adapter_type_to_schema_refs[parser]) + schema_content[ct] = {"schema": {"$ref": ref_path}} + else: + schema_content[ct] = request_schema + + return {"content": schema_content} + + def get_responses(self, path, method): + if method == "DELETE": + return {"204": {"description": ""}} + + self.response_media_types = self.map_renderers(path, method) + serializer = self.get_response_serializer(path, method) + + item_schema = {} + if isinstance(serializer, serializers.Serializer): + item_schema = self.get_reference(serializer) + + if drf_schema_utils.is_list_view(path, method, self.view): + response_schema = {"type": "array", "items": item_schema} + paginator = self.get_paginator() + if paginator: + response_schema = paginator.get_paginated_response_schema(response_schema) + else: + response_schema = item_schema + + schema_content = {} + for renderer, ct in zip(self.view.renderer_classes, self.response_media_types): + if issubclass(renderer, renderers.SchemaRenderer): + ref_path = self._get_component_ref(self.adapter_type_to_schema_refs[renderer]) + schema_content[ct] = {"schema": {"$ref": ref_path}} + else: + schema_content[ct] = response_schema + + status_code = "201" if method == "POST" else "200" + return { + status_code: { + "content": schema_content, + "description": "", + } + } + + def map_parsers(self, path: str, method: str) -> list[str]: + schema_parsers = [] + media_types = [] + + for parser in self.view.parser_classes: + media_types.append(parser.media_type) + if issubclass(parser, parsers.SchemaParser): + schema_parsers.append(parser()) + + if schema_parsers: + self.adapter_mode = "validation" + request = self.rf.generic(method, path) + schemas = self._collect_adapter_components(schema_parsers, self.view.get_parser_context(request)) + self.collected_schema_defs.update(schemas) + + return media_types + + def map_renderers(self, path: str, method: str) -> list[str]: + schema_renderers = [] + media_types = [] + + for renderer in self.view.renderer_classes: + media_types.append(renderer.media_type) + if issubclass(renderer, renderers.SchemaRenderer): + schema_renderers.append(renderer()) + + if schema_renderers: + self.adapter_mode = "serialization" + schemas = self._collect_adapter_components(schema_renderers, self.view.get_renderer_context()) + self.collected_schema_defs.update(schemas) + + return media_types + + def map_serializer(self, serializer): component_content = super().map_serializer(serializer) - schema_fields_adapters = [] + field_adapters = [] for field in serializer.fields.values(): - if isinstance(field, SchemaField): - schema_fields_adapters.append((field.field_name, mode, field.adapter.type_adapter)) - - if schema_fields_adapters: - field_schemas, common_schemas = pydantic.TypeAdapter.json_schemas( - schema_fields_adapters, - ref_template=self._SCHEMA_REF_TEMPLATE_PREFIX, - ) - for (field_name, _), field_schema in field_schemas.items(): - component_content["properties"][field_name] = field_schema + if isinstance(field, fields.SchemaField): + field_adapters.append((field.field_name, self.adapter_mode, field.adapter.type_adapter)) - self.collected_schema_defs.update(common_schemas.get("$defs", {})) + if field_adapters: + field_schemas = self._collect_type_adapter_schemas(field_adapters) + for field_name, field_schema in field_schemas.items(): + component_content["properties"][field_name] = field_schema return component_content - def map_parsers(self, path: str, method: str) -> list[str]: - # TODO: Implmenent SchemaParser - return super().map_parsers(path, method) + def _collect_serializer_component(self, serializer: serializers.BaseSerializer | None, mode: JsonSchemaMode): + schema_definition = {} + if isinstance(serializer, serializers.Serializer): + self.adapter_mode = mode + component_name = self.get_component_name(serializer) + schema_definition[component_name] = self.map_serializer(serializer) + return schema_definition - def map_renderers(self, path: str, method: str) -> list[str]: - # TODO: Implement SchemaRenderer - return super().map_renderers(path, method) + def _collect_adapter_components(self, components: Iterable[mixins.AnnotatedAdapterMixin], context: dict): + type_adapters = [] + + for component in components: + schema_adapter = component.get_adapter(context) + if schema_adapter is not None: + schema_name = schema_adapter.prepared_schema.__class__.__name__ + self.adapter_type_to_schema_refs[type(component)] = schema_name + + type_adapters.append((schema_name, self.adapter_mode, schema_adapter.type_adapter)) + + if type_adapters: + return self._collect_type_adapter_schemas(type_adapters) + + return {} + + def _collect_type_adapter_schemas(self, adapters: Iterable[tuple[str, JsonSchemaMode, pydantic.TypeAdapter]]): + inner_schemas = {} + + schemas, common_schemas = pydantic.TypeAdapter.json_schemas(adapters, ref_template=self.REF_TEMPLATE_PREFIX) + for (field_name, _), field_schema in schemas.items(): + inner_schemas[field_name] = field_schema + + self.collected_schema_defs.update(common_schemas.get("$defs", {})) + return inner_schemas + + def _get_component_ref(self, model: str): + return self.REF_TEMPLATE_PREFIX.format(model=model) diff --git a/tests/v2/test_rest_framework.py b/tests/v2/test_rest_framework.py index e7ad27a..e2837a7 100644 --- a/tests/v2/test_rest_framework.py +++ b/tests/v2/test_rest_framework.py @@ -1,4 +1,5 @@ import io +import sys import typing as t from datetime import date from types import SimpleNamespace @@ -229,6 +230,7 @@ def test_end_to_end_list_create_api_view(request_factory): ) +@pytest.mark.skipif(sys.version_info >= (3, 12), reason="CoreAPI is not compatible with 3.12") @pytest.mark.parametrize( "method, path", [ From 6a674cbd454a69fdcdfa2aac1e5d56563dc492a9 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 3 Jan 2024 01:04:02 +0400 Subject: [PATCH 33/34] Split v2.rest_framework package tests. --- .../v2/rest_framework/openapi.py | 63 +++-- tests/v2/rest_framework/__init__.py | 0 tests/v2/rest_framework/test_coreapi.py | 26 ++ tests/v2/rest_framework/test_e2e_views.py | 56 ++++ tests/v2/rest_framework/test_fields.py | 108 ++++++++ tests/v2/rest_framework/test_openapi.py | 23 ++ tests/v2/rest_framework/test_parsers.py | 23 ++ tests/v2/rest_framework/test_renderers.py | 23 ++ tests/v2/rest_framework/view_fixtures.py | 88 +++++++ tests/v2/test_rest_framework.py | 247 ------------------ 10 files changed, 378 insertions(+), 279 deletions(-) create mode 100644 tests/v2/rest_framework/__init__.py create mode 100644 tests/v2/rest_framework/test_coreapi.py create mode 100644 tests/v2/rest_framework/test_e2e_views.py create mode 100644 tests/v2/rest_framework/test_fields.py create mode 100644 tests/v2/rest_framework/test_openapi.py create mode 100644 tests/v2/rest_framework/test_parsers.py create mode 100644 tests/v2/rest_framework/test_renderers.py create mode 100644 tests/v2/rest_framework/view_fixtures.py delete mode 100644 tests/v2/test_rest_framework.py diff --git a/django_pydantic_field/v2/rest_framework/openapi.py b/django_pydantic_field/v2/rest_framework/openapi.py index 3c741f5..2965e41 100644 --- a/django_pydantic_field/v2/rest_framework/openapi.py +++ b/django_pydantic_field/v2/rest_framework/openapi.py @@ -24,7 +24,7 @@ class AutoSchema(openapi.AutoSchema): def __init__(self, tags=None, operation_id_base=None, component_name=None) -> None: super().__init__(tags, operation_id_base, component_name) self.collected_schema_defs: dict[str, ty.Any] = {} - self.adapter_type_to_schema_refs = weakref.WeakKeyDictionary[type, str]() + self.collected_adapter_schema_refs: dict[str, ty.Any] = {} self.adapter_mode: JsonSchemaMode = "validation" self.rf = APIRequestFactory() @@ -32,8 +32,6 @@ def get_components(self, path: str, method: str) -> dict[str, ty.Any]: if method.lower() == "delete": return {} - super().get_components - request_serializer = self.get_request_serializer(path, method) # type: ignore[attr-defined] response_serializer = self.get_response_serializer(path, method) # type: ignore[attr-defined] @@ -61,9 +59,9 @@ def get_request_body(self, path, method): schema_content = {} for parser, ct in zip(self.view.parser_classes, self.request_media_types): - if issubclass(parser, parsers.SchemaParser): - ref_path = self._get_component_ref(self.adapter_type_to_schema_refs[parser]) - schema_content[ct] = {"schema": {"$ref": ref_path}} + if isinstance(parser(), parsers.SchemaParser): + parser_schema = self.collected_adapter_schema_refs[repr(parser)] + schema_content[ct] = {"schema": parser_schema} else: schema_content[ct] = request_schema @@ -76,23 +74,21 @@ def get_responses(self, path, method): self.response_media_types = self.map_renderers(path, method) serializer = self.get_response_serializer(path, method) - item_schema = {} + response_schema = {} if isinstance(serializer, serializers.Serializer): - item_schema = self.get_reference(serializer) + response_schema = self.get_reference(serializer) - if drf_schema_utils.is_list_view(path, method, self.view): - response_schema = {"type": "array", "items": item_schema} - paginator = self.get_paginator() - if paginator: - response_schema = paginator.get_paginated_response_schema(response_schema) - else: - response_schema = item_schema + is_list_view = drf_schema_utils.is_list_view(path, method, self.view) + if is_list_view: + response_schema = self._get_paginated_schema(response_schema) schema_content = {} for renderer, ct in zip(self.view.renderer_classes, self.response_media_types): - if issubclass(renderer, renderers.SchemaRenderer): - ref_path = self._get_component_ref(self.adapter_type_to_schema_refs[renderer]) - schema_content[ct] = {"schema": {"$ref": ref_path}} + if isinstance(renderer(), renderers.SchemaRenderer): + renderer_schema = {"schema": self.collected_adapter_schema_refs[repr(renderer)]} + if is_list_view: + renderer_schema = self._get_paginated_schema(renderer_schema) + schema_content[ct] = renderer_schema else: schema_content[ct] = response_schema @@ -110,14 +106,15 @@ def map_parsers(self, path: str, method: str) -> list[str]: for parser in self.view.parser_classes: media_types.append(parser.media_type) - if issubclass(parser, parsers.SchemaParser): - schema_parsers.append(parser()) + instance = parser() + if isinstance(instance, parsers.SchemaParser): + schema_parsers.append(parser) if schema_parsers: self.adapter_mode = "validation" request = self.rf.generic(method, path) schemas = self._collect_adapter_components(schema_parsers, self.view.get_parser_context(request)) - self.collected_schema_defs.update(schemas) + self.collected_adapter_schema_refs.update(schemas) return media_types @@ -127,13 +124,14 @@ def map_renderers(self, path: str, method: str) -> list[str]: for renderer in self.view.renderer_classes: media_types.append(renderer.media_type) - if issubclass(renderer, renderers.SchemaRenderer): - schema_renderers.append(renderer()) + instance = renderer() + if isinstance(instance, renderers.SchemaRenderer): + schema_renderers.append(renderer) if schema_renderers: self.adapter_mode = "serialization" schemas = self._collect_adapter_components(schema_renderers, self.view.get_renderer_context()) - self.collected_schema_defs.update(schemas) + self.collected_adapter_schema_refs.update(schemas) return media_types @@ -160,16 +158,13 @@ def _collect_serializer_component(self, serializer: serializers.BaseSerializer | schema_definition[component_name] = self.map_serializer(serializer) return schema_definition - def _collect_adapter_components(self, components: Iterable[mixins.AnnotatedAdapterMixin], context: dict): + def _collect_adapter_components(self, components: Iterable[type[mixins.AnnotatedAdapterMixin]], context: dict): type_adapters = [] for component in components: - schema_adapter = component.get_adapter(context) + schema_adapter = component().get_adapter(context) if schema_adapter is not None: - schema_name = schema_adapter.prepared_schema.__class__.__name__ - self.adapter_type_to_schema_refs[type(component)] = schema_name - - type_adapters.append((schema_name, self.adapter_mode, schema_adapter.type_adapter)) + type_adapters.append((repr(component), self.adapter_mode, schema_adapter.type_adapter)) if type_adapters: return self._collect_type_adapter_schemas(type_adapters) @@ -186,5 +181,9 @@ def _collect_type_adapter_schemas(self, adapters: Iterable[tuple[str, JsonSchema self.collected_schema_defs.update(common_schemas.get("$defs", {})) return inner_schemas - def _get_component_ref(self, model: str): - return self.REF_TEMPLATE_PREFIX.format(model=model) + def _get_paginated_schema(self, schema) -> ty.Any: + response_schema = {"type": "array", "items": schema} + paginator = self.get_paginator() + if paginator: + response_schema = paginator.get_paginated_response_schema(response_schema) # type: ignore + return response_schema diff --git a/tests/v2/rest_framework/__init__.py b/tests/v2/rest_framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/v2/rest_framework/test_coreapi.py b/tests/v2/rest_framework/test_coreapi.py new file mode 100644 index 0000000..c3e2669 --- /dev/null +++ b/tests/v2/rest_framework/test_coreapi.py @@ -0,0 +1,26 @@ +import sys + +import pytest +from rest_framework import schemas +from rest_framework.request import Request + +from .view_fixtures import create_views_urlconf + +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") + +@pytest.mark.skipif(sys.version_info >= (3, 12), reason="CoreAPI is not compatible with 3.12") +@pytest.mark.parametrize( + "method, path", + [ + ("GET", "/func"), + ("POST", "/func"), + ("GET", "/class"), + ("PUT", "/class"), + ], +) +def test_coreapi_schema_generators(request_factory, method, path): + urlconf = create_views_urlconf(coreapi.AutoSchema) + generator = schemas.SchemaGenerator(urlconf=urlconf) + request = Request(request_factory.generic(method, path)) + coreapi_schema = generator.get_schema(request) + assert coreapi_schema diff --git a/tests/v2/rest_framework/test_e2e_views.py b/tests/v2/rest_framework/test_e2e_views.py new file mode 100644 index 0000000..4790263 --- /dev/null +++ b/tests/v2/rest_framework/test_e2e_views.py @@ -0,0 +1,56 @@ +from datetime import date + +import pytest + +from tests.conftest import InnerSchema + +from .view_fixtures import ( + ClassBasedView, + ClassBasedViewWithModel, + ClassBasedViewWithSchemaContext, + sample_view, +) + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") + + +@pytest.mark.parametrize( + "view", + [ + sample_view, + ClassBasedView.as_view(), + ClassBasedViewWithSchemaContext.as_view(), + ], +) +def test_end_to_end_api_view(view, request_factory): + expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + existing_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + request = request_factory.post("/", existing_encoded, content_type="application/json") + response = view(request) + + assert response.data == [expected_instance] + assert response.data[0] is not expected_instance + + assert response.rendered_content == b"[%s]" % existing_encoded + + +@pytest.mark.django_db +def test_end_to_end_list_create_api_view(request_factory): + field_data = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]).json() + expected_result = { + "sample_field": {"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}, + "sample_list": [{"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}], + "sample_seq": [], + } + + payload = '{"sample_field": %s, "sample_list": [%s], "sample_seq": []}' % ((field_data,) * 2) + request = request_factory.post("/", payload.encode(), content_type="application/json") + response = ClassBasedViewWithModel.as_view()(request) + + assert response.data == expected_result + + request = request_factory.get("/", content_type="application/json") + response = ClassBasedViewWithModel.as_view()(request) + assert response.data == [expected_result] diff --git a/tests/v2/rest_framework/test_fields.py b/tests/v2/rest_framework/test_fields.py new file mode 100644 index 0000000..6098da6 --- /dev/null +++ b/tests/v2/rest_framework/test_fields.py @@ -0,0 +1,108 @@ +import typing as ty +from datetime import date + +import pytest +from rest_framework import exceptions, serializers + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +class SampleSerializer(serializers.Serializer): + field = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + + +class SampleModelSerializer(serializers.ModelSerializer): + sample_field = rest_framework.SchemaField(schema=InnerSchema) + sample_list = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + sample_seq = rest_framework.SchemaField(schema=ty.List[InnerSchema], default=list) + + class Meta: + model = SampleModel + fields = "sample_field", "sample_list", "sample_seq" + + +def test_schema_field(): + field = rest_framework.SchemaField(InnerSchema) + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = { + "stub_str": "abc", + "stub_int": 1, + "stub_list": [date(2022, 7, 1)], + } + + assert field.to_representation(existing_instance) == expected_encoded + assert field.to_internal_value(expected_encoded) == existing_instance + + with pytest.raises(serializers.ValidationError): + field.to_internal_value(None) + + with pytest.raises(serializers.ValidationError): + field.to_internal_value("null") + + +def test_field_schema_with_custom_config(): + field = rest_framework.SchemaField(InnerSchema, allow_null=True, exclude={"stub_int"}) + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + + assert field.to_representation(existing_instance) == expected_encoded + assert field.to_internal_value(expected_encoded) == existing_instance + assert field.to_internal_value(None) is None + assert field.to_internal_value("null") is None + + +def test_serializer_marshalling_with_schema_field(): + existing_instance = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]} + expected_data = {"field": [{"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}]} + + serializer = SampleSerializer(instance=existing_instance) + assert serializer.data == expected_data + + serializer = SampleSerializer(data=expected_data) + serializer.is_valid(raise_exception=True) + assert serializer.validated_data == existing_instance + + +def test_model_serializer_marshalling_with_schema_field(): + instance = SampleModel( + sample_field=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), + sample_list=[InnerSchema(stub_str="abc", stub_int=2, stub_list=[date(2022, 7, 1)])] * 2, + sample_seq=[InnerSchema(stub_str="abc", stub_int=3, stub_list=[date(2022, 7, 1)])] * 3, + ) + serializer = SampleModelSerializer(instance) + + expected_data = { + "sample_field": {"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}, + "sample_list": [{"stub_str": "abc", "stub_int": 2, "stub_list": [date(2022, 7, 1)]}] * 2, + "sample_seq": [{"stub_str": "abc", "stub_int": 3, "stub_list": [date(2022, 7, 1)]}] * 3, + } + assert serializer.data == expected_data + + +@pytest.mark.parametrize( + "export_kwargs", + [ + {"include": {"stub_str", "stub_int"}}, + {"exclude": {"stub_list"}}, + {"exclude_unset": True}, + {"exclude_defaults": True}, + {"exclude_none": True}, + {"by_alias": True}, + ], +) +def test_field_export_kwargs(export_kwargs): + field = rest_framework.SchemaField(InnerSchema, **export_kwargs) + assert field.to_representation(InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])) + + +def test_invalid_data_serialization(): + invalid_data = {"field": [{"stub_int": "abc", "stub_list": ["abc"]}]} + serializer = SampleSerializer(data=invalid_data) + + with pytest.raises(exceptions.ValidationError) as e: + serializer.is_valid(raise_exception=True) + + assert e.match(r".*stub_str.*stub_int.*stub_list.*") diff --git a/tests/v2/rest_framework/test_openapi.py b/tests/v2/rest_framework/test_openapi.py new file mode 100644 index 0000000..a0d8bc6 --- /dev/null +++ b/tests/v2/rest_framework/test_openapi.py @@ -0,0 +1,23 @@ +import pytest +from rest_framework.schemas.openapi import SchemaGenerator +from rest_framework.request import Request + +from .view_fixtures import create_views_urlconf + +openapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.openapi") + +@pytest.mark.parametrize( + "method, path", + [ + ("GET", "/func"), + ("POST", "/func"), + ("GET", "/class"), + ("PUT", "/class"), + ], +) +def test_coreapi_schema_generators(request_factory, method, path): + urlconf = create_views_urlconf(openapi.AutoSchema) + generator = SchemaGenerator(urlconf=urlconf) + request = Request(request_factory.generic(method, path)) + openapi_schema = generator.get_schema(request) + assert openapi_schema diff --git a/tests/v2/rest_framework/test_parsers.py b/tests/v2/rest_framework/test_parsers.py new file mode 100644 index 0000000..db64a56 --- /dev/null +++ b/tests/v2/rest_framework/test_parsers.py @@ -0,0 +1,23 @@ +import io +from datetime import date + +import pytest + +from tests.conftest import InnerSchema + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +@pytest.mark.parametrize( + "schema_type, existing_encoded, expected_decoded", + [ + ( + InnerSchema, + '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}', + InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), + ) + ], +) +def test_schema_parser(schema_type, existing_encoded, expected_decoded): + parser = rest_framework.SchemaParser[schema_type]() + assert parser.parse(io.StringIO(existing_encoded)) == expected_decoded diff --git a/tests/v2/rest_framework/test_renderers.py b/tests/v2/rest_framework/test_renderers.py new file mode 100644 index 0000000..59d4aed --- /dev/null +++ b/tests/v2/rest_framework/test_renderers.py @@ -0,0 +1,23 @@ +from datetime import date + +import pytest + +from tests.conftest import InnerSchema + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +def test_schema_renderer(): + renderer = rest_framework.SchemaRenderer() + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + assert renderer.render(existing_instance) == expected_encoded + + +def test_typed_schema_renderer(): + renderer = rest_framework.SchemaRenderer[InnerSchema]() + existing_data = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + assert renderer.render(existing_data) == expected_encoded diff --git a/tests/v2/rest_framework/view_fixtures.py b/tests/v2/rest_framework/view_fixtures.py new file mode 100644 index 0000000..a55ddd6 --- /dev/null +++ b/tests/v2/rest_framework/view_fixtures.py @@ -0,0 +1,88 @@ +import typing as ty +from types import SimpleNamespace + +import pytest +from django.urls import path +from rest_framework import generics, serializers, views +from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema +from rest_framework.response import Response + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") + + +class SampleSerializer(serializers.Serializer): + field = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + + +class SampleModelSerializer(serializers.ModelSerializer): + sample_field = rest_framework.SchemaField(schema=InnerSchema) + sample_list = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + sample_seq = rest_framework.SchemaField(schema=ty.List[InnerSchema], default=list) + + class Meta: + model = SampleModel + fields = "sample_field", "sample_list", "sample_seq" + + +class ClassBasedView(views.APIView): + parser_classes = [rest_framework.SchemaParser[InnerSchema]] + renderer_classes = [rest_framework.SchemaRenderer[ty.List[InnerSchema]]] + + def post(self, request, *args, **kwargs): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + +class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): + serializer_class = SampleSerializer + + +class ClassBasedViewWithModel(generics.ListCreateAPIView): + queryset = SampleModel.objects.all() + serializer_class = SampleModelSerializer + + +class ClassBasedViewWithSchemaContext(ClassBasedView): + parser_classes = [rest_framework.SchemaParser] + renderer_classes = [rest_framework.SchemaRenderer] + + def get_renderer_context(self): + ctx = super().get_renderer_context() + return dict(ctx, renderer_schema=ty.List[InnerSchema]) + + def get_parser_context(self, http_request): + ctx = super().get_parser_context(http_request) + return dict(ctx, parser_schema=InnerSchema) + + +@api_view(["GET", "POST"]) +@parser_classes([rest_framework.SchemaParser[InnerSchema]]) +@renderer_classes([rest_framework.SchemaRenderer[ty.List[InnerSchema]]]) +def sample_view(request): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + +def create_views_urlconf(schema_view_inspector): + @api_view(["GET", "POST"]) + @schema(schema_view_inspector()) + @parser_classes([rest_framework.SchemaParser[InnerSchema]]) + @renderer_classes([rest_framework.SchemaRenderer[ty.List[InnerSchema]]]) + def sample_view(request): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): + serializer_class = SampleSerializer + schema = schema_view_inspector() + + return SimpleNamespace( + urlpatterns=[ + path("/func", sample_view), + path("/class", ClassBasedViewWithSerializer.as_view()), + ], + ) diff --git a/tests/v2/test_rest_framework.py b/tests/v2/test_rest_framework.py deleted file mode 100644 index e2837a7..0000000 --- a/tests/v2/test_rest_framework.py +++ /dev/null @@ -1,247 +0,0 @@ -import io -import sys -import typing as t -from datetime import date -from types import SimpleNamespace - -import pytest -from django.urls import path -from rest_framework import exceptions, generics, schemas, serializers, views -from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema -from rest_framework.request import Request -from rest_framework.response import Response - -from tests.conftest import InnerSchema -from tests.test_app.models import SampleModel - -rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") -coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") - - -class SampleSerializer(serializers.Serializer): - field = rest_framework.SchemaField(schema=t.List[InnerSchema]) - - -class SampleModelSerializer(serializers.ModelSerializer): - sample_field = rest_framework.SchemaField(schema=InnerSchema) - sample_list = rest_framework.SchemaField(schema=t.List[InnerSchema]) - sample_seq = rest_framework.SchemaField(schema=t.List[InnerSchema], default=list) - - class Meta: - model = SampleModel - fields = "sample_field", "sample_list", "sample_seq" - - -def test_schema_field(): - field = rest_framework.SchemaField(InnerSchema) - existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - expected_encoded = { - "stub_str": "abc", - "stub_int": 1, - "stub_list": [date(2022, 7, 1)], - } - - assert field.to_representation(existing_instance) == expected_encoded - assert field.to_internal_value(expected_encoded) == existing_instance - - with pytest.raises(serializers.ValidationError): - field.to_internal_value(None) - - with pytest.raises(serializers.ValidationError): - field.to_internal_value("null") - - -def test_field_schema_with_custom_config(): - field = rest_framework.SchemaField(InnerSchema, allow_null=True, exclude={"stub_int"}) - existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - expected_encoded = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} - - assert field.to_representation(existing_instance) == expected_encoded - assert field.to_internal_value(expected_encoded) == existing_instance - assert field.to_internal_value(None) is None - assert field.to_internal_value("null") is None - - -def test_serializer_marshalling_with_schema_field(): - existing_instance = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]} - expected_data = {"field": [{"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}]} - - serializer = SampleSerializer(instance=existing_instance) - assert serializer.data == expected_data - - serializer = SampleSerializer(data=expected_data) - serializer.is_valid(raise_exception=True) - assert serializer.validated_data == existing_instance - - -def test_model_serializer_marshalling_with_schema_field(): - instance = SampleModel( - sample_field=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), - sample_list=[InnerSchema(stub_str="abc", stub_int=2, stub_list=[date(2022, 7, 1)])] * 2, - sample_seq=[InnerSchema(stub_str="abc", stub_int=3, stub_list=[date(2022, 7, 1)])] * 3, - ) - serializer = SampleModelSerializer(instance) - - expected_data = { - "sample_field": {"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}, - "sample_list": [{"stub_str": "abc", "stub_int": 2, "stub_list": [date(2022, 7, 1)]}] * 2, - "sample_seq": [{"stub_str": "abc", "stub_int": 3, "stub_list": [date(2022, 7, 1)]}] * 3, - } - assert serializer.data == expected_data - - -@pytest.mark.parametrize( - "export_kwargs", - [ - {"include": {"stub_str", "stub_int"}}, - {"exclude": {"stub_list"}}, - {"exclude_unset": True}, - {"exclude_defaults": True}, - {"exclude_none": True}, - {"by_alias": True}, - ], -) -def test_field_export_kwargs(export_kwargs): - field = rest_framework.SchemaField(InnerSchema, **export_kwargs) - assert field.to_representation(InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])) - - -def test_invalid_data_serialization(): - invalid_data = {"field": [{"stub_int": "abc", "stub_list": ["abc"]}]} - serializer = SampleSerializer(data=invalid_data) - - with pytest.raises(exceptions.ValidationError) as e: - serializer.is_valid(raise_exception=True) - - assert e.match(r".*stub_str.*stub_int.*stub_list.*") - - -def test_schema_renderer(): - renderer = rest_framework.SchemaRenderer() - existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' - - assert renderer.render(existing_instance) == expected_encoded - - -def test_typed_schema_renderer(): - renderer = rest_framework.SchemaRenderer[InnerSchema]() - existing_data = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} - expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' - - assert renderer.render(existing_data) == expected_encoded - - -def test_schema_parser(): - parser = rest_framework.SchemaParser[InnerSchema]() - existing_encoded = '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}' - expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - - assert parser.parse(io.StringIO(existing_encoded)) == expected_instance - - -@api_view(["GET", "POST"]) -@schema(coreapi.AutoSchema()) -@parser_classes([rest_framework.SchemaParser[InnerSchema]]) -@renderer_classes([rest_framework.SchemaRenderer[t.List[InnerSchema]]]) -def sample_view(request): - assert isinstance(request.data, InnerSchema) - return Response([request.data]) - - -class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): - serializer_class = SampleSerializer - schema = coreapi.AutoSchema() - - -class ClassBasedViewWithModel(generics.ListCreateAPIView): - queryset = SampleModel.objects.all() - serializer_class = SampleModelSerializer - - -class ClassBasedView(views.APIView): - parser_classes = [rest_framework.SchemaParser[InnerSchema]] - renderer_classes = [rest_framework.SchemaRenderer[t.List[InnerSchema]]] - - def post(self, request, *args, **kwargs): - assert isinstance(request.data, InnerSchema) - return Response([request.data]) - - -class ClassBasedViewWithSchemaContext(ClassBasedView): - parser_classes = [rest_framework.SchemaParser] - renderer_classes = [rest_framework.SchemaRenderer] - - def get_renderer_context(self): - ctx = super().get_renderer_context() - return dict(ctx, renderer_schema=t.List[InnerSchema]) - - def get_parser_context(self, http_request): - ctx = super().get_parser_context(http_request) - return dict(ctx, parser_schema=InnerSchema) - - -@pytest.mark.parametrize( - "view", - [ - sample_view, - ClassBasedView.as_view(), - ClassBasedViewWithSchemaContext.as_view(), - ], -) -def test_end_to_end_api_view(view, request_factory): - expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - existing_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' - - request = request_factory.post("/", existing_encoded, content_type="application/json") - response = view(request) - - assert response.data == [expected_instance] - assert response.data[0] is not expected_instance - - assert response.rendered_content == b"[%s]" % existing_encoded - - -@pytest.mark.django_db -def test_end_to_end_list_create_api_view(request_factory): - field_data = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]).json() - expected_result = { - "sample_field": {"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}, - "sample_list": [{"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}], - "sample_seq": [], - } - - payload = '{"sample_field": %s, "sample_list": [%s], "sample_seq": []}' % ((field_data,) * 2) - request = request_factory.post("/", payload.encode(), content_type="application/json") - response = ClassBasedViewWithModel.as_view()(request) - - assert response.data == expected_result - - request = request_factory.get("/", content_type="application/json") - response = ClassBasedViewWithModel.as_view()(request) - assert response.data == [expected_result] - - -urlconf = SimpleNamespace( - urlpatterns=[ - path("/func", sample_view), - path("/class", ClassBasedViewWithSerializer.as_view()), - ], -) - - -@pytest.mark.skipif(sys.version_info >= (3, 12), reason="CoreAPI is not compatible with 3.12") -@pytest.mark.parametrize( - "method, path", - [ - ("GET", "/func"), - ("POST", "/func"), - ("GET", "/class"), - ("PUT", "/class"), - ], -) -def test_coreapi_schema_generators(request_factory, method, path): - generator = schemas.SchemaGenerator(urlconf=urlconf) - request = Request(request_factory.generic(method, path)) - coreapi_schema = generator.get_schema(request) - assert coreapi_schema From d76d23c6351d1b2d5ca37b9679328b6699a8aedf Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Thu, 4 Jan 2024 23:31:11 +0400 Subject: [PATCH 34/34] Add snapshot tests for OpenAPI schema generators. --- .../v2/rest_framework/openapi.py | 20 +- django_pydantic_field/v2/utils.py | 9 + pyproject.toml | 3 +- tests/conftest.py | 6 + ..._openapi_schema_generators[GET-class].json | 217 ++++++++++++++++++ ...t_openapi_schema_generators[GET-func].json | 217 ++++++++++++++++++ ..._openapi_schema_generators[POST-func].json | 217 ++++++++++++++++++ ..._openapi_schema_generators[PUT-class].json | 217 ++++++++++++++++++ tests/v2/rest_framework/test_openapi.py | 5 +- tests/v2/test_types.py | 11 +- 10 files changed, 906 insertions(+), 16 deletions(-) create mode 100644 tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-class].json create mode 100644 tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-func].json create mode 100644 tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[POST-func].json create mode 100644 tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[PUT-class].json diff --git a/django_pydantic_field/v2/rest_framework/openapi.py b/django_pydantic_field/v2/rest_framework/openapi.py index 2965e41..4ccb265 100644 --- a/django_pydantic_field/v2/rest_framework/openapi.py +++ b/django_pydantic_field/v2/rest_framework/openapi.py @@ -3,12 +3,13 @@ import typing as ty import pydantic -import weakref from rest_framework import serializers -from rest_framework.schemas import openapi, utils as drf_schema_utils +from rest_framework.schemas import openapi +from rest_framework.schemas import utils as drf_schema_utils from rest_framework.test import APIRequestFactory from . import fields, parsers, renderers +from ..utils import get_origin_type if ty.TYPE_CHECKING: from collections.abc import Iterable @@ -59,11 +60,12 @@ def get_request_body(self, path, method): schema_content = {} for parser, ct in zip(self.view.parser_classes, self.request_media_types): - if isinstance(parser(), parsers.SchemaParser): + if issubclass(get_origin_type(parser), parsers.SchemaParser): parser_schema = self.collected_adapter_schema_refs[repr(parser)] - schema_content[ct] = {"schema": parser_schema} else: - schema_content[ct] = request_schema + parser_schema = request_schema + + schema_content[ct] = {"schema": parser_schema} return {"content": schema_content} @@ -84,7 +86,7 @@ def get_responses(self, path, method): schema_content = {} for renderer, ct in zip(self.view.renderer_classes, self.response_media_types): - if isinstance(renderer(), renderers.SchemaRenderer): + if issubclass(get_origin_type(renderer), renderers.SchemaRenderer): renderer_schema = {"schema": self.collected_adapter_schema_refs[repr(renderer)]} if is_list_view: renderer_schema = self._get_paginated_schema(renderer_schema) @@ -106,8 +108,7 @@ def map_parsers(self, path: str, method: str) -> list[str]: for parser in self.view.parser_classes: media_types.append(parser.media_type) - instance = parser() - if isinstance(instance, parsers.SchemaParser): + if issubclass(get_origin_type(parser), parsers.SchemaParser): schema_parsers.append(parser) if schema_parsers: @@ -124,8 +125,7 @@ def map_renderers(self, path: str, method: str) -> list[str]: for renderer in self.view.renderer_classes: media_types.append(renderer.media_type) - instance = renderer() - if isinstance(instance, renderers.SchemaRenderer): + if issubclass(get_origin_type(renderer), renderers.SchemaRenderer): schema_renderers.append(renderer) if schema_renderers: diff --git a/django_pydantic_field/v2/utils.py b/django_pydantic_field/v2/utils.py index 24f1801..0a724b6 100644 --- a/django_pydantic_field/v2/utils.py +++ b/django_pydantic_field/v2/utils.py @@ -4,6 +4,8 @@ import typing as ty from collections import ChainMap +from django_pydantic_field.compat import typing + if ty.TYPE_CHECKING: from collections.abc import Mapping @@ -39,6 +41,13 @@ def get_local_namespace(cls) -> dict[str, ty.Any]: return {} +def get_origin_type(cls: type): + origin_tp = typing.get_origin(cls) + if origin_tp is not None: + return origin_tp + return cls + + if sys.version_info >= (3, 9): def evaluate_forward_ref(ref: ty.ForwardRef, ns: Mapping[str, ty.Any]) -> ty.Any: diff --git a/pyproject.toml b/pyproject.toml index 401caba..123cf2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "django-pydantic-field" -version = "0.3.0-alpha5" +version = "0.3.0-beta1" description = "Django JSONField with Pydantic models as a Schema" readme = "README.md" license = { file = "LICENSE" } @@ -66,6 +66,7 @@ test = [ "dj-database-url~=2.0", "djangorestframework>=3,<4", "pyyaml", + "syrupy>=3,<5", ] ci = [ 'psycopg[binary]>=3.1,<4; python_version>="3.9"', diff --git a/tests/conftest.py b/tests/conftest.py index 03c84b5..be81a87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import typing as t from datetime import date +from syrupy.extensions.json import JSONSnapshotExtension import pydantic import pytest @@ -70,3 +71,8 @@ def mysql_backend(settings): ) def available_database_backends(request, settings): yield request.param(settings) + + +@pytest.fixture +def snapshot_json(snapshot): + return snapshot.use_extension(JSONSnapshotExtension) diff --git a/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-class].json b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-class].json new file mode 100644 index 0000000..a0d5192 --- /dev/null +++ b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-class].json @@ -0,0 +1,217 @@ +{ + "components": { + "schemas": { + "InnerSchema": { + "properties": { + "stub_int": { + "default": 1, + "title": "Stub Int", + "type": "integer" + }, + "stub_list": { + "items": { + "format": "date", + "type": "string" + }, + "title": "Stub List", + "type": "array" + }, + "stub_str": { + "title": "Stub Str", + "type": "string" + } + }, + "required": [ + "stub_str", + "stub_list" + ], + "title": "InnerSchema", + "type": "object" + }, + "Sample": { + "properties": { + "field": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "required": [ + "field" + ], + "type": "object" + } + } + }, + "info": { + "title": "", + "version": "" + }, + "openapi": "3.0.2", + "paths": { + "/class": { + "get": { + "description": "", + "operationId": "retrieveSample", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "patch": { + "description": "", + "operationId": "partialUpdateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "put": { + "description": "", + "operationId": "updateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + } + }, + "/func": { + "get": { + "description": "", + "operationId": "listsample_views", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "items": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "type": "array" + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + }, + "post": { + "description": "", + "operationId": "createsample_view", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InnerSchema" + } + } + } + }, + "responses": { + "201": { + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + } + } + } +} diff --git a/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-func].json b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-func].json new file mode 100644 index 0000000..a0d5192 --- /dev/null +++ b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[GET-func].json @@ -0,0 +1,217 @@ +{ + "components": { + "schemas": { + "InnerSchema": { + "properties": { + "stub_int": { + "default": 1, + "title": "Stub Int", + "type": "integer" + }, + "stub_list": { + "items": { + "format": "date", + "type": "string" + }, + "title": "Stub List", + "type": "array" + }, + "stub_str": { + "title": "Stub Str", + "type": "string" + } + }, + "required": [ + "stub_str", + "stub_list" + ], + "title": "InnerSchema", + "type": "object" + }, + "Sample": { + "properties": { + "field": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "required": [ + "field" + ], + "type": "object" + } + } + }, + "info": { + "title": "", + "version": "" + }, + "openapi": "3.0.2", + "paths": { + "/class": { + "get": { + "description": "", + "operationId": "retrieveSample", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "patch": { + "description": "", + "operationId": "partialUpdateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "put": { + "description": "", + "operationId": "updateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + } + }, + "/func": { + "get": { + "description": "", + "operationId": "listsample_views", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "items": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "type": "array" + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + }, + "post": { + "description": "", + "operationId": "createsample_view", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InnerSchema" + } + } + } + }, + "responses": { + "201": { + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + } + } + } +} diff --git a/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[POST-func].json b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[POST-func].json new file mode 100644 index 0000000..a0d5192 --- /dev/null +++ b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[POST-func].json @@ -0,0 +1,217 @@ +{ + "components": { + "schemas": { + "InnerSchema": { + "properties": { + "stub_int": { + "default": 1, + "title": "Stub Int", + "type": "integer" + }, + "stub_list": { + "items": { + "format": "date", + "type": "string" + }, + "title": "Stub List", + "type": "array" + }, + "stub_str": { + "title": "Stub Str", + "type": "string" + } + }, + "required": [ + "stub_str", + "stub_list" + ], + "title": "InnerSchema", + "type": "object" + }, + "Sample": { + "properties": { + "field": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "required": [ + "field" + ], + "type": "object" + } + } + }, + "info": { + "title": "", + "version": "" + }, + "openapi": "3.0.2", + "paths": { + "/class": { + "get": { + "description": "", + "operationId": "retrieveSample", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "patch": { + "description": "", + "operationId": "partialUpdateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "put": { + "description": "", + "operationId": "updateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + } + }, + "/func": { + "get": { + "description": "", + "operationId": "listsample_views", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "items": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "type": "array" + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + }, + "post": { + "description": "", + "operationId": "createsample_view", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InnerSchema" + } + } + } + }, + "responses": { + "201": { + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + } + } + } +} diff --git a/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[PUT-class].json b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[PUT-class].json new file mode 100644 index 0000000..a0d5192 --- /dev/null +++ b/tests/v2/rest_framework/__snapshots__/test_openapi/test_openapi_schema_generators[PUT-class].json @@ -0,0 +1,217 @@ +{ + "components": { + "schemas": { + "InnerSchema": { + "properties": { + "stub_int": { + "default": 1, + "title": "Stub Int", + "type": "integer" + }, + "stub_list": { + "items": { + "format": "date", + "type": "string" + }, + "title": "Stub List", + "type": "array" + }, + "stub_str": { + "title": "Stub Str", + "type": "string" + } + }, + "required": [ + "stub_str", + "stub_list" + ], + "title": "InnerSchema", + "type": "object" + }, + "Sample": { + "properties": { + "field": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "required": [ + "field" + ], + "type": "object" + } + } + }, + "info": { + "title": "", + "version": "" + }, + "openapi": "3.0.2", + "paths": { + "/class": { + "get": { + "description": "", + "operationId": "retrieveSample", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "patch": { + "description": "", + "operationId": "partialUpdateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + }, + "put": { + "description": "", + "operationId": "updateSample", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "application/x-www-form-urlencoded": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + }, + "multipart/form-data": { + "schema": { + "$ref": "#/components/schemas/Sample" + } + } + } + }, + "responses": { + "200": { + "content": { + "application/json": { + "$ref": "#/components/schemas/Sample" + }, + "text/html": { + "$ref": "#/components/schemas/Sample" + } + }, + "description": "" + } + }, + "tags": [ + "class" + ] + } + }, + "/func": { + "get": { + "description": "", + "operationId": "listsample_views", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "items": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + }, + "type": "array" + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + }, + "post": { + "description": "", + "operationId": "createsample_view", + "parameters": [], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/InnerSchema" + } + } + } + }, + "responses": { + "201": { + "content": { + "application/json": { + "schema": { + "items": { + "$ref": "#/components/schemas/InnerSchema" + }, + "type": "array" + } + } + }, + "description": "" + } + }, + "tags": [ + "func" + ] + } + } + } +} diff --git a/tests/v2/rest_framework/test_openapi.py b/tests/v2/rest_framework/test_openapi.py index a0d8bc6..0ce6ca6 100644 --- a/tests/v2/rest_framework/test_openapi.py +++ b/tests/v2/rest_framework/test_openapi.py @@ -15,9 +15,8 @@ ("PUT", "/class"), ], ) -def test_coreapi_schema_generators(request_factory, method, path): +def test_openapi_schema_generators(request_factory, method, path, snapshot_json): urlconf = create_views_urlconf(openapi.AutoSchema) generator = SchemaGenerator(urlconf=urlconf) request = Request(request_factory.generic(method, path)) - openapi_schema = generator.get_schema(request) - assert openapi_schema + assert snapshot_json() == generator.get_schema(request) diff --git a/tests/v2/test_types.py b/tests/v2/test_types.py index 891ad24..7050b51 100644 --- a/tests/v2/test_types.py +++ b/tests/v2/test_types.py @@ -6,9 +6,13 @@ from ..conftest import InnerSchema, SampleDataclass types = pytest.importorskip("django_pydantic_field.v2.types") -skip_unsupported_builtin_subscription = pytest.mark.skipif(sys.version_info < (3, 9), reason="Built-in type subscription supports only in 3.9+") +skip_unsupported_builtin_subscription = pytest.mark.skipif( + sys.version_info < (3, 9), + reason="Built-in type subscription supports only in 3.9+", +) +# fmt: off @pytest.mark.parametrize( "ctor, args, kwargs", [ @@ -26,8 +30,9 @@ (types.SchemaAdapter.from_annotation, [InnerSchema, "stub_int", {"strict": True}], {}), (types.SchemaAdapter.from_annotation, [SampleDataclass, "stub_int"], {}), (types.SchemaAdapter.from_annotation, [SampleDataclass, "stub_int", {"strict": True}], {}), - ] + ], ) +# fmt: on def test_schema_adapter_constructors(ctor, args, kwargs): adapter = ctor(*args, **kwargs) adapter.validate_schema() @@ -52,6 +57,7 @@ def test_schema_adapter_is_bound(): adapter.validate_schema() # Schema should be resolved from bound attribute +# fmt: off @pytest.mark.parametrize( "kwargs, expected_export_kwargs", [ @@ -61,6 +67,7 @@ def test_schema_adapter_is_bound(): ({"strict": True, "from_attributes": False, "on_delete": "CASCADE"}, {"strict": True, "from_attributes": False}), ], ) +# fmt: on def test_schema_adapter_extract_export_kwargs(kwargs, expected_export_kwargs): orig_kwargs = dict(kwargs) assert types.SchemaAdapter.extract_export_kwargs(kwargs) == expected_export_kwargs