diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0be55cd..6a48b0c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,11 +11,11 @@ repos: - id: isort args: [ "--settings-path", "./pyproject.toml", "--filter-files" ] files: "^django_pydantic_field/" - exclude: ^.*\b(\.pytest_cache|\.venv|venv|tests)\b.*$ + exclude: ^.*\b(\.pytest_cache|\.venv|venv).*\b.*$ - repo: https://github.com/psf/black rev: 24.3.0 hooks: - id: black args: [ "--config", "./pyproject.toml" ] files: "^django_pydantic_field/" - exclude: ^.*\b(\.pytest_cache|\.venv|venv|tests)\b.*$ + exclude: ^.*\b(\.pytest_cache|\.venv|venv).*\b.*$ diff --git a/README.md b/README.md index c05c9db..4f92714 100644 --- a/README.md +++ b/README.md @@ -82,8 +82,21 @@ class Bar(pydantic.BaseModel): slug: str = "foo_bar" ``` -In this case, exact type resolution will be postponed until initial access to the field. -Usually this happens on the first instantiation of the model. +**Pydantic v2 specific**: this behaviour is achieved by the fact that the exact type resolution will be postponed the until initial access to the field. Usually this happens on the first instantiation of the model. + +To reduce the number of runtime errors related to the postponed resolution, the field itself performs a few checks against the passed schema during `./manage.py check` command invocation, and consequently, in `runserver` and `makemigrations` commands. + +Here's the list of currently implemented checks: +- `pydantic.E001`: The passed schema could not be resolved. Most likely it does not exist in the scope of the defined field. +- `pydantic.E002`: `default=` value could not be serialized to the schema. +- `pydantic.W003`: The default value could not be reconstructed to the schema due to `include`/`exclude` configuration. + + +### `typing.Annotated` support +As of `v0.3.5`, SchemaField also supports `typing.Annotated[...]` expressions, both through `schema=` attribute or field annotation syntax; though I find the `schema=typing.Annotated[...]` variant highly discouraged. + +**The current limitation** is not in the field itself, but in possible `Annotated` metadata -- practically it can contain anything, and Django migrations serializers could refuse to write it to migrations. +For most relevant types in context of Pydantic, I wrote the specific serializers (particularly for `pydantic.FieldInfo`, `pydantic.Representation` and raw dataclasses), thus it should cover the majority of `Annotated` use cases. ## Django Forms support diff --git a/django_pydantic_field/compat/django.py b/django_pydantic_field/compat/django.py index ec88e87..d27a49e 100644 --- a/django_pydantic_field/compat/django.py +++ b/django_pydantic_field/compat/django.py @@ -13,19 +13,68 @@ `typing.Union` and its special forms, like `typing.Optional`, have its own inheritance chain. Moreover, `types.UnionType`, introduced in 3.10, do not allow explicit type construction, only with `X | Y` syntax. Both cases require a dedicated serializer for migration writes. + +[typing.Annotated](https://peps.python.org/pep-0593/) + `typing.Annotated` syntax is supported for direct field annotations, though I find it highly discouraged + while using in `schema=` attribute. + The limitation with `Annotated` types is that supplied metadata could be actually of any type. + In case of Pydantic, it is a `FieldInfo` objects, which are not compatible with Django Migrations serializers. + This module provides a few containers (`FieldInfoContainer` and `DataclassContainer`), + which allow Model serializers to work. """ +from __future__ import annotations + +import abc +import dataclasses import sys import types import typing as ty +import typing_extensions as te from django.db.migrations.serializer import BaseSerializer, serializer_factory from django.db.migrations.writer import MigrationWriter +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined +from .pydantic import PYDANTIC_V1 from .typing import get_args, get_origin +try: + from pydantic._internal._repr import Representation + from pydantic.fields import _DefaultValues as FieldInfoDefaultValues + from pydantic_core import PydanticUndefined +except ImportError: + # Assuming this is a Pydantic v1 + from pydantic.fields import Undefined as PydanticUndefined # type: ignore[attr-defined, no-redef] + from pydantic.utils import Representation # type: ignore[no-redef] + + FieldInfoDefaultValues = FieldInfo.__field_constraints__ # type: ignore[attr-defined] + + +class BaseContainer(abc.ABC): + __slot__ = () -class GenericContainer: + @classmethod + def unwrap(cls, value): + if isinstance(value, BaseContainer) and type(value) is not BaseContainer: + return value.unwrap(value) + return value + + def __eq__(self, other): + if isinstance(other, self.__class__): + return all(getattr(self, attr) == getattr(other, attr) for attr in self.__slots__) + return NotImplemented + + def __str__(self): + return repr(self.unwrap(self)) + + def __repr__(self): + attrs = tuple(getattr(self, attr) for attr in self.__slots__) + return f"{self.__class__.__name__}{attrs}" + + +class GenericContainer(BaseContainer): __slots__ = "origin", "args" def __init__(self, origin, args: tuple = ()): @@ -33,67 +82,183 @@ def __init__(self, origin, args: tuple = ()): self.args = args @classmethod - def wrap(cls, typ_): - if isinstance(typ_, GenericTypes): - wrapped_args = tuple(map(cls.wrap, get_args(typ_))) - return cls(get_origin(typ_), wrapped_args) - return typ_ + def wrap(cls, value): + # NOTE: due to a bug in typing_extensions for `3.8`, Annotated aliases are handled explicitly + if isinstance(value, AnnotatedAlias): + args = (value.__origin__, *value.__metadata__) + wrapped_args = tuple(map(cls.wrap, args)) + return cls(te.Annotated, wrapped_args) + if isinstance(value, GenericTypes): + wrapped_args = tuple(map(cls.wrap, get_args(value))) + return cls(get_origin(value), wrapped_args) + if isinstance(value, FieldInfo): + return FieldInfoContainer.wrap(value) + return value @classmethod - def unwrap(cls, type_): - if not isinstance(type_, cls): - return type_ + def unwrap(cls, value): + if not isinstance(value, cls): + return value + + if PYDANTIC_V1: + origin = get_origin(BaseContainer.unwrap(value.origin)) or value.origin + else: + origin = value.origin - if not type_.args: - return type_.origin + if not value.args: + return origin - unwrapped_args = tuple(map(cls.unwrap, type_.args)) + unwrapped_args = tuple(map(BaseContainer.unwrap, value.args)) try: # This is a fallback for Python < 3.8, please be careful with that - return type_.origin[unwrapped_args] + return origin[unwrapped_args] except TypeError: - return GenericAlias(type_.origin, unwrapped_args) + return GenericAlias(origin, unwrapped_args) + + def __eq__(self, other): + if isinstance(other, GenericTypes): + return self == self.wrap(other) + return super().__eq__(other) - def __repr__(self): - return repr(self.unwrap(self)) - __str__ = __repr__ +class DataclassContainer(BaseContainer): + __slots__ = "datacls", "kwargs" + + def __init__(self, datacls: type, kwargs: ty.Dict[str, ty.Any]): + self.datacls = datacls + self.kwargs = kwargs + + @classmethod + def wrap(cls, value): + if cls._is_dataclass_instance(value): + return cls(type(value), dataclasses.asdict(value)) + if isinstance(value, GenericTypes): + return GenericContainer.wrap(value) + return value + + @classmethod + def unwrap(cls, value): + if isinstance(value, cls): + return value.datacls(**value.kwargs) + return value + + @staticmethod + def _is_dataclass_instance(obj: ty.Any): + return dataclasses.is_dataclass(obj) and not isinstance(obj, type) def __eq__(self, other): - if isinstance(other, self.__class__): - return self.origin == other.origin and self.args == other.args - if isinstance(other, GenericTypes): + if self._is_dataclass_instance(other): return self == self.wrap(other) - return NotImplemented + return super().__eq__(other) + + +class FieldInfoContainer(BaseContainer): + __slots__ = "origin", "metadata", "kwargs" + + def __init__(self, origin, metadata, kwargs): + self.origin = origin + self.metadata = metadata + self.kwargs = kwargs + + @classmethod + def wrap(cls, field: FieldInfo): + if not isinstance(field, FieldInfo): + return field + + # `getattr` is important to preserve compatibility with Pydantic v1 + metadata = getattr(field, "metadata", ()) + origin = getattr(field, "annotation", None) + if origin is type(None): + origin = None + + origin = GenericContainer.wrap(origin) + metadata = tuple(map(DataclassContainer.wrap, metadata)) + + kwargs = dict(cls._iter_field_attrs(field)) + if PYDANTIC_V1: + kwargs.update(kwargs.pop("extra", {})) + + return cls(origin, metadata, kwargs) + @classmethod + def unwrap(cls, value): + if not isinstance(value, cls): + return value + if PYDANTIC_V1: + return FieldInfo(**value.kwargs) + + origin = GenericContainer.unwrap(value.origin) + metadata = tuple(map(BaseContainer.unwrap, value.metadata)) + try: + annotated_args = (origin, *metadata) + annotation = te.Annotated[annotated_args] + except TypeError: + annotation = None -class GenericSerializer(BaseSerializer): - value: GenericContainer + return FieldInfo(annotation=annotation, **value.kwargs) + + def __eq__(self, other): + if isinstance(other, FieldInfo): + return self == self.wrap(other) + return super().__eq__(other) + + @staticmethod + def _iter_field_attrs(field: FieldInfo): + available_attrs = set(field.__slots__) - {"annotation", "metadata", "_attributes_set"} + + for attr in available_attrs: + attr_value = getattr(field, attr) + if attr_value is not PydanticUndefined and attr_value != FieldInfoDefaultValues.get(attr): + yield attr, getattr(field, attr) + + @staticmethod + def _wrap_metadata_object(obj): + return DataclassContainer.wrap(obj) + + +class BaseContainerSerializer(BaseSerializer): + value: BaseContainer def serialize(self): - value = self.value + tp_repr, imports = serializer_factory(type(self.value)).serialize() + attrs = [] - tp_repr, imports = serializer_factory(type(value)).serialize() - orig_repr, orig_imports = serializer_factory(value.origin).serialize() - imports.update(orig_imports) + for attr in self._iter_container_attrs(): + attr_repr, attr_imports = serializer_factory(attr).serialize() + attrs.append(attr_repr) + imports.update(attr_imports) - args = [] - for arg in value.args: - arg_repr, arg_imports = serializer_factory(arg).serialize() - args.append(arg_repr) - imports.update(arg_imports) + attrs_repr = ", ".join(attrs) + return f"{tp_repr}({attrs_repr})", imports - if args: - args_repr = ", ".join(args) - generic_repr = "%s(%s, (%s,))" % (tp_repr, orig_repr, args_repr) - else: - generic_repr = "%s(%s)" % (tp_repr, orig_repr) + def _iter_container_attrs(self): + container = self.value + for attr in container.__slots__: + yield getattr(container, attr) + + +class DataclassContainerSerializer(BaseSerializer): + value: DataclassContainer - return generic_repr, imports + def serialize(self): + tp_repr, imports = serializer_factory(self.value.datacls).serialize() + + kwarg_pairs = [] + for arg, value in self.value.kwargs.items(): + value_repr, value_imports = serializer_factory(value).serialize() + kwarg_pairs.append(f"{arg}={value_repr}") + imports.update(value_imports) + + kwargs_repr = ", ".join(kwarg_pairs) + return f"{tp_repr}({kwargs_repr})", imports class TypingSerializer(BaseSerializer): def serialize(self): + value = GenericContainer.wrap(self.value) + if isinstance(value, GenericContainer): + return serializer_factory(value).serialize() + orig_module = self.value.__module__ orig_repr = repr(self.value) @@ -103,6 +268,36 @@ def serialize(self): return orig_repr, {f"import {orig_module}"} +class FieldInfoSerializer(BaseSerializer): + value: FieldInfo + + def serialize(self): + container = FieldInfoContainer.wrap(self.value) + return serializer_factory(container).serialize() + + +class RepresentationSerializer(BaseSerializer): + value: Representation + + def serialize(self): + tp_repr, imports = serializer_factory(type(self.value)).serialize() + repr_args = [] + + for arg_name, arg_value in self.value.__repr_args__(): + arg_value_repr, arg_value_imports = serializer_factory(arg_value).serialize() + imports.update(arg_value_imports) + + if arg_name is None: + repr_args.append(arg_value_repr) + else: + repr_args.append(f"{arg_name}={arg_value_repr}") + + final_args_repr = ", ".join(repr_args) + return f"{tp_repr}({final_args_repr})" + + +AnnotatedAlias = te._AnnotatedAlias + if sys.version_info >= (3, 9): GenericAlias = types.GenericAlias GenericTypes: ty.Tuple[ty.Any, ...] = ( @@ -117,7 +312,18 @@ def serialize(self): GenericTypes = GenericAlias, type(ty.List) # noqa -MigrationWriter.register_serializer(GenericContainer, GenericSerializer) +# BaseContainerSerializer *must be* registered after all specialized container serializers +MigrationWriter.register_serializer(DataclassContainer, DataclassContainerSerializer) +MigrationWriter.register_serializer(BaseContainer, BaseContainerSerializer) + +# Pydantic-specific datastructures serializers +MigrationWriter.register_serializer(FieldInfo, FieldInfoSerializer) +MigrationWriter.register_serializer(Representation, RepresentationSerializer) + +# Typing serializers +for type_ in GenericTypes: + MigrationWriter.register_serializer(type_, TypingSerializer) + MigrationWriter.register_serializer(ty.ForwardRef, TypingSerializer) MigrationWriter.register_serializer(type(ty.Union), TypingSerializer) # type: ignore diff --git a/django_pydantic_field/fields.pyi b/django_pydantic_field/fields.pyi index a484932..886e520 100644 --- a/django_pydantic_field/fields.pyi +++ b/django_pydantic_field/fields.pyi @@ -78,7 +78,16 @@ class _DeprecatedSchemaFieldKwargs(_SchemaFieldKwargs, total=False): @ty.overload def SchemaField( - schema: ty.Type[ST] | None | ty.ForwardRef = ..., + schema: ty.Type[ST | None] | ty.ForwardRef = ..., + config: ConfigType = ..., + default: OptSchemaT | ty.Callable[[], OptSchemaT] | BaseExpression = ..., + *args, + null: ty.Literal[True], + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> ST | None: ... +@ty.overload +def SchemaField( + schema: te.Annotated[ty.Type[ST | None], ...] = ..., config: ConfigType = ..., default: OptSchemaT | ty.Callable[[], OptSchemaT] | BaseExpression = ..., *args, @@ -95,12 +104,21 @@ def SchemaField( **kwargs: te.Unpack[_SchemaFieldKwargs], ) -> ST: ... @ty.overload +def SchemaField( + schema: te.Annotated[ty.Type[ST], ...] = ..., + config: ConfigType = ..., + default: SchemaT | ty.Callable[[], SchemaT] | BaseExpression = ..., + *args, + null: ty.Literal[False] = ..., + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> ST: ... +@ty.overload @te.deprecated( "Passing `json.dump` kwargs to `SchemaField` is not supported by " "Pydantic 2 and will be removed in the future versions." ) def SchemaField( - schema: ty.Type[ST] | None | ty.ForwardRef = ..., + schema: ty.Type[ST | None] | ty.ForwardRef = ..., config: ConfigType = ..., default: SchemaT | ty.Callable[[], SchemaT] | BaseExpression = ..., *args, diff --git a/django_pydantic_field/v1/fields.py b/django_pydantic_field/v1/fields.py index df6ec70..24a7af0 100644 --- a/django_pydantic_field/v1/fields.py +++ b/django_pydantic_field/v1/fields.py @@ -10,7 +10,7 @@ from django.db.models.fields.json import JSONField from django.db.models.query_utils import DeferredAttribute -from django_pydantic_field.compat.django import GenericContainer +from django_pydantic_field.compat.django import BaseContainer, GenericContainer from . import base, forms, utils @@ -44,7 +44,7 @@ class PydanticSchemaField(JSONField, t.Generic[base.ST]): def __init__( self, *args, - schema: t.Union[t.Type["base.ST"], "GenericContainer", "t.ForwardRef", str, None] = None, + schema: t.Union[t.Type["base.ST"], "BaseContainer", "t.ForwardRef", str, None] = None, config: t.Optional["base.ConfigType"] = None, **kwargs, ): @@ -137,7 +137,7 @@ def value_to_string(self, obj): return self.get_prep_value(value) def _resolve_schema(self, schema): - schema = t.cast(t.Type["base.ST"], GenericContainer.unwrap(schema)) + schema = t.cast(t.Type["base.ST"], BaseContainer.unwrap(schema)) self.schema = schema if schema is not None: diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 5cac0f6..9ca1f0e 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -3,6 +3,7 @@ import typing as ty import pydantic +import typing_extensions as te from django.core import checks, exceptions from django.core.serializers.json import DjangoJSONEncoder from django.db.models.expressions import BaseExpression, Col, Value @@ -12,14 +13,13 @@ from django.db.models.query_utils import DeferredAttribute from django_pydantic_field.compat import deprecation -from django_pydantic_field.compat.django import GenericContainer +from django_pydantic_field.compat.django import BaseContainer, GenericContainer from . import forms, types if ty.TYPE_CHECKING: import json - import typing_extensions as te from django.db.models import Model class _SchemaFieldKwargs(types.ExportKwargs, total=False): @@ -70,7 +70,7 @@ class PydanticSchemaField(JSONField, ty.Generic[types.ST]): def __init__( self, *args, - schema: type[types.ST] | GenericContainer | ty.ForwardRef | str | None = None, + schema: type[types.ST] | BaseContainer | ty.ForwardRef | str | None = None, config: pydantic.ConfigDict | None = None, **kwargs, ): @@ -78,7 +78,7 @@ def __init__( self.export_kwargs = export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) super().__init__(*args, **kwargs) - self.schema = GenericContainer.unwrap(schema) + self.schema = BaseContainer.unwrap(schema) self.config = config self.adapter = types.SchemaAdapter(schema, config, None, self.get_attname(), self.null, **export_kwargs) @@ -129,7 +129,7 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: schema_default = self.get_default() if schema_default is None: # If the default value is not set, try to get the default value from the schema. - prep_value = self.adapter.type_adapter.get_default_value() + prep_value = self.adapter.get_default_value() if prep_value is not None: prep_value = prep_value.value schema_default = prep_value @@ -137,11 +137,11 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: if schema_default is not None: try: # Perform the full round-trip transformation to test the export ability. - self.adapter.validate_python(self.get_prep_value(self.default)) + self.adapter.validate_python(self.get_prep_value(schema_default)) except pydantic.ValidationError as exc: message = f"Export arguments may lead to data integrity problems. Pydantic error: \n{str(exc)}" hint = "Please review `import` and `export` arguments." - performed_checks.append(checks.Warning(message, obj=self, hint=hint, id="pydantic.E003")) + performed_checks.append(checks.Warning(message, obj=self, hint=hint, id="pydantic.W003")) return performed_checks @@ -223,6 +223,28 @@ def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None: return self.transform(col, *args, **kwargs) +@ty.overload +def SchemaField( + schema: ty.Annotated[type[types.ST | None], ...] = ..., + config: pydantic.ConfigDict = ..., + default: types.SchemaT | ty.Callable[[], types.SchemaT | None] | BaseExpression | None = ..., + *args, + null: ty.Literal[True], + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> types.ST | None: ... + + +@ty.overload +def SchemaField( + schema: ty.Annotated[type[types.ST], ...] = ..., + config: pydantic.ConfigDict = ..., + default: types.SchemaT | ty.Callable[[], types.SchemaT] | BaseExpression = ..., + *args, + null: ty.Literal[False] = ..., + **kwargs: te.Unpack[_SchemaFieldKwargs], +) -> types.ST: ... + + @ty.overload def SchemaField( schema: type[types.ST | None] | ty.ForwardRef = ..., diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 72dcb2b..f237570 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -6,7 +6,7 @@ import pydantic import typing_extensions as te -from django_pydantic_field.compat.django import GenericContainer +from django_pydantic_field.compat.django import BaseContainer, GenericContainer from django_pydantic_field.compat.functools import cached_property from . import utils @@ -60,7 +60,7 @@ def __init__( allow_null: bool | None = None, **export_kwargs: ty.Unpack[ExportKwargs], ): - self.schema = GenericContainer.unwrap(schema) + self.schema = BaseContainer.unwrap(schema) self.config = config self.parent_type = parent_type self.attname = attname @@ -149,6 +149,12 @@ def json_schema(self) -> dict[str, ty.Any]: by_alias = self.export_kwargs.get("by_alias", True) return self.type_adapter.json_schema(by_alias=by_alias) + def get_default_value(self) -> ST | None: + wrapped = self.type_adapter.get_default_value() + if wrapped is not None: + return wrapped.value + return None + def _prepare_schema(self) -> type[ST]: """Prepare the schema for the adapter. diff --git a/pyproject.toml b/pyproject.toml index 400d91e..a8bf701 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "django-pydantic-field" -version = "0.3.4" +version = "0.3.5" description = "Django JSONField with Pydantic models as a Schema" readme = "README.md" license = { file = "LICENSE" } @@ -90,9 +90,6 @@ include_trailing_comma = true force_alphabetical_sort_within_sections = true force_grid_wrap = 0 use_parentheses = true -skip_glob = [ - "*/migrations/*", -] [tool.black] target-version = ["py38", "py39", "py310", "py311", "py312"] @@ -102,7 +99,6 @@ exclude = ''' \.pytest_cache | \.venv | venv - | migrations )/ ''' @@ -111,7 +107,7 @@ plugins = [ "mypy_django_plugin.main", "mypy_drf_plugin.main" ] -exclude = [".env", "tests"] +exclude = [".env", ".venv", "tests"] enable_incomplete_feature = ["Unpack"] [tool.django-stubs] diff --git a/tests/test_app/migrations/0001_initial.py b/tests/test_app/migrations/0001_initial.py index 80a3436..9e71d9e 100644 --- a/tests/test_app/migrations/0001_initial.py +++ b/tests/test_app/migrations/0001_initial.py @@ -1,19 +1,39 @@ -# Generated by Django 3.2.23 on 2023-11-21 14:37 +# Generated by Django 5.0.3 on 2024-03-25 22:22 +import typing_extensions +import annotated_types import django.core.serializers.json -from django.db import migrations, models import django_pydantic_field.compat.django import django_pydantic_field.fields import tests.conftest import tests.test_app.models +import types +import typing +from django.db import migrations, models class Migration(migrations.Migration): + initial = True dependencies = [] operations = [ + migrations.CreateModel( + name="ExampleModel", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ( + "example_field", + django_pydantic_field.fields.PydanticSchemaField( + config=None, + default={"count": 1}, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.test_app.models.ExampleSchema, + ), + ), + ], + ), migrations.CreateModel( name="SampleForwardRefModel", fields=[ @@ -73,4 +93,57 @@ class Migration(migrations.Migration): ), ], ), + migrations.CreateModel( + name="SampleModelAnnotated", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ( + "annotated_field", + django_pydantic_field.fields.PydanticSchemaField( + config=None, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( + typing_extensions.Annotated, + ( + django_pydantic_field.compat.django.GenericContainer(typing.Union, (int, float)), + django_pydantic_field.compat.django.FieldInfoContainer( + None, (annotated_types.Gt(gt=0),), {"title": "Annotated Field"} + ), + ), + ), + ), + ), + ( + "annotated_schema", + django_pydantic_field.fields.PydanticSchemaField( + config=None, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( + typing_extensions.Annotated, + ( + django_pydantic_field.compat.django.GenericContainer(typing.Union, (int, float)), + django_pydantic_field.compat.django.FieldInfoContainer( + None, (annotated_types.Gt(gt=0),), {} + ), + ), + ), + ), + ), + ], + ), + migrations.CreateModel( + name="SampleModelWithRoot", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ( + "root_field", + django_pydantic_field.fields.PydanticSchemaField( + config=None, + default=list, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.test_app.models.RootSchema, + ), + ), + ], + ), ] diff --git a/tests/test_app/migrations/0002_examplemodel.py b/tests/test_app/migrations/0002_examplemodel.py deleted file mode 100644 index c39b5b4..0000000 --- a/tests/test_app/migrations/0002_examplemodel.py +++ /dev/null @@ -1,30 +0,0 @@ -# Generated by Django 5.0.1 on 2024-01-27 18:30 - -import django.core.serializers.json -import django_pydantic_field.fields -import tests.test_app.models -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("test_app", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="ExampleModel", - fields=[ - ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ( - "example_field", - django_pydantic_field.fields.PydanticSchemaField( - config=None, - default={"count": 1}, - encoder=django.core.serializers.json.DjangoJSONEncoder, - schema=tests.test_app.models.ExampleSchema, - ), - ), - ], - ), - ] diff --git a/tests/test_app/migrations/0003_samplemodelwithroot.py b/tests/test_app/migrations/0003_samplemodelwithroot.py deleted file mode 100644 index 004f6db..0000000 --- a/tests/test_app/migrations/0003_samplemodelwithroot.py +++ /dev/null @@ -1,31 +0,0 @@ -# Generated by Django 5.0.1 on 2024-03-11 22:29 - -import django.core.serializers.json -import django_pydantic_field.fields -import tests.test_app.models -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("test_app", "0002_examplemodel"), - ] - - operations = [ - migrations.CreateModel( - name="SampleModelWithRoot", - fields=[ - ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ( - "root_field", - django_pydantic_field.fields.PydanticSchemaField( - config=None, - default=list, - encoder=django.core.serializers.json.DjangoJSONEncoder, - schema=tests.test_app.models.RootSchema, - ), - ), - ], - ), - ] diff --git a/tests/test_app/models.py b/tests/test_app/models.py index 56fb317..d87ec50 100644 --- a/tests/test_app/models.py +++ b/tests/test_app/models.py @@ -1,4 +1,5 @@ import typing as t +import typing_extensions as te import pydantic from django.db import models @@ -51,3 +52,8 @@ class RootSchema(pydantic.BaseModel): class SampleModelWithRoot(models.Model): root_field = SchemaField(schema=RootSchema, default=list) + + +class SampleModelAnnotated(models.Model): + annotated_field: te.Annotated[t.Union[int, float], pydantic.Field(gt=0, title="Annotated Field")] = SchemaField() + annotated_schema = SchemaField(schema=te.Annotated[t.Union[int, float], pydantic.Field(gt=0)]) diff --git a/tests/v2/rest_framework/test_fields.py b/tests/v2/rest_framework/test_fields.py index 6098da6..2186ad6 100644 --- a/tests/v2/rest_framework/test_fields.py +++ b/tests/v2/rest_framework/test_fields.py @@ -1,7 +1,9 @@ import typing as ty from datetime import date +import pydantic import pytest +import typing_extensions as te from rest_framework import exceptions, serializers from tests.conftest import InnerSchema @@ -12,6 +14,11 @@ class SampleSerializer(serializers.Serializer): field = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + annotated = rest_framework.SchemaField( + schema=te.Annotated[ty.List[InnerSchema], pydantic.Field(alias="annotated_field")], + default=list, + by_alias=True, + ) class SampleModelSerializer(serializers.ModelSerializer): @@ -55,15 +62,16 @@ def test_field_schema_with_custom_config(): def test_serializer_marshalling_with_schema_field(): - existing_instance = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]} - expected_data = {"field": [{"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}]} + existing_instance = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])], "annotated_field": []} + expected_data = {"field": [{"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}], "annotated": []} + expected_validated_data = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])], "annotated": []} serializer = SampleSerializer(instance=existing_instance) assert serializer.data == expected_data serializer = SampleSerializer(data=expected_data) serializer.is_valid(raise_exception=True) - assert serializer.validated_data == existing_instance + assert serializer.validated_data == expected_validated_data def test_model_serializer_marshalling_with_schema_field(): diff --git a/tests/v2/test_forms.py b/tests/v2/test_forms.py index f7fa3d1..64b1b33 100644 --- a/tests/v2/test_forms.py +++ b/tests/v2/test_forms.py @@ -4,6 +4,7 @@ import django import pydantic import pytest +import typing_extensions as te from django.core.exceptions import ValidationError from django.forms import Form, modelform_factory @@ -146,3 +147,9 @@ def test_form_field_export_kwargs(export_kwargs): field = forms.SchemaField(InnerSchema, required=False, **export_kwargs) value = InnerSchema.model_validate({"stub_str": "abc", "stub_list": ["1970-01-01"]}) assert field.prepare_value(value) + + +def test_annotated_acceptance(): + field = forms.SchemaField(te.Annotated[InnerSchema, pydantic.Field(title="Inner Schema")]) + value = InnerSchema.model_validate({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + assert field.prepare_value(value)