From 4c9e87181b9c38d6bc418c91fc6b2bda6966c626 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 26 Mar 2024 00:24:51 +0400 Subject: [PATCH] Implement containers for field info and dataclass serialization --- django_pydantic_field/compat/django.py | 260 +++++++++++++++--- django_pydantic_field/v1/fields.py | 6 +- django_pydantic_field/v2/fields.py | 12 +- django_pydantic_field/v2/types.py | 10 +- tests/test_app/migrations/0001_initial.py | 60 +++- .../test_app/migrations/0002_examplemodel.py | 30 -- .../migrations/0003_samplemodelwithroot.py | 31 --- tests/test_app/models.py | 3 +- 8 files changed, 297 insertions(+), 115 deletions(-) delete mode 100644 tests/test_app/migrations/0002_examplemodel.py delete mode 100644 tests/test_app/migrations/0003_samplemodelwithroot.py diff --git a/django_pydantic_field/compat/django.py b/django_pydantic_field/compat/django.py index ec88e87..d2819d5 100644 --- a/django_pydantic_field/compat/django.py +++ b/django_pydantic_field/compat/django.py @@ -15,17 +15,58 @@ only with `X | Y` syntax. Both cases require a dedicated serializer for migration writes. """ +from __future__ import annotations + +import abc +import dataclasses import sys import types import typing as ty +import typing_extensions as te from django.db.migrations.serializer import BaseSerializer, serializer_factory from django.db.migrations.writer import MigrationWriter +from pydantic.fields import FieldInfo +from pydantic_core import PydanticUndefined +from .pydantic import PYDANTIC_V1 from .typing import get_args, get_origin +try: + from pydantic._internal._repr import Representation + from pydantic.fields import _DefaultValues as FieldInfoDefaultValues + from pydantic_core import PydanticUndefined +except ImportError: + # Assuming this is a Pydantic v1 + from pydantic.fields import Undefined as PydanticUndefined # type: ignore[attr-defined, no-redef] + from pydantic.utils import Representation # type: ignore[no-redef] + + FieldInfoDefaultValues = FieldInfo.__field_constraints__ # type: ignore[attr-defined] + + +class BaseContainer(abc.ABC): + __slot__ = () + + @classmethod + def unwrap(cls, value): + if isinstance(value, BaseContainer) and type(value) is not BaseContainer: + return value.unwrap(value) + return value + + def __eq__(self, other): + if isinstance(other, self.__class__): + return all(getattr(self, attr) == getattr(other, attr) for attr in self.__slots__) + return NotImplemented + + def __str__(self): + return repr(self.unwrap(self)) + + def __repr__(self): + attrs = tuple(getattr(self, attr) for attr in self.__slots__) + return f"{self.__class__.__name__}{attrs}" -class GenericContainer: + +class GenericContainer(BaseContainer): __slots__ = "origin", "args" def __init__(self, origin, args: tuple = ()): @@ -33,67 +74,169 @@ def __init__(self, origin, args: tuple = ()): self.args = args @classmethod - def wrap(cls, typ_): - if isinstance(typ_, GenericTypes): - wrapped_args = tuple(map(cls.wrap, get_args(typ_))) - return cls(get_origin(typ_), wrapped_args) - return typ_ + def wrap(cls, value): + if isinstance(value, GenericTypes): + wrapped_args = tuple(map(cls.wrap, get_args(value))) + return cls(get_origin(value), wrapped_args) + if isinstance(value, FieldInfo): + return FieldInfoContainer.wrap(value) + return value @classmethod - def unwrap(cls, type_): - if not isinstance(type_, cls): - return type_ + def unwrap(cls, value): + if not isinstance(value, cls): + return value - if not type_.args: - return type_.origin + if not value.args: + return value.origin - unwrapped_args = tuple(map(cls.unwrap, type_.args)) + unwrapped_args = tuple(map(BaseContainer.unwrap, value.args)) try: # This is a fallback for Python < 3.8, please be careful with that - return type_.origin[unwrapped_args] + return value.origin[unwrapped_args] except TypeError: - return GenericAlias(type_.origin, unwrapped_args) + return GenericAlias(value.origin, unwrapped_args) - def __repr__(self): - return repr(self.unwrap(self)) + def __eq__(self, other): + if isinstance(other, GenericTypes): + return self == self.wrap(other) + return super().__eq__(other) + + +class DataclassContainer(BaseContainer): + __slots__ = "datacls", "kwargs" + + def __init__(self, datacls: type, kwargs: ty.Dict[str, ty.Any]): + self.datacls = datacls + self.kwargs = kwargs + + @classmethod + def wrap(cls, value): + if cls._is_dataclass_instance(value): + return cls(type(value), dataclasses.asdict(value)) + if isinstance(value, GenericTypes): + return GenericContainer.wrap(value) + return value + + @classmethod + def unwrap(cls, value): + if isinstance(value, cls): + return value.datacls(**value.kwargs) + return value - __str__ = __repr__ + @staticmethod + def _is_dataclass_instance(obj: ty.Any): + return dataclasses.is_dataclass(obj) and not isinstance(obj, type) def __eq__(self, other): - if isinstance(other, self.__class__): - return self.origin == other.origin and self.args == other.args - if isinstance(other, GenericTypes): + if self._is_dataclass_instance(other): return self == self.wrap(other) - return NotImplemented + return super().__eq__(other) + +class FieldInfoContainer(BaseContainer): + __slots__ = "origin", "metadata", "kwargs" + + def __init__(self, origin, metadata, kwargs): + self.origin = origin + self.metadata = metadata + self.kwargs = kwargs + + @classmethod + def wrap(cls, field: FieldInfo): + if not isinstance(field, FieldInfo): + return field -class GenericSerializer(BaseSerializer): - value: GenericContainer + # `getattr` is important to preserve compatibility with Pydantic v1 + origin = GenericContainer.wrap(getattr(field, "annotation", None)) + metadata = getattr(field, "metadata", ()) + metadata = tuple(map(DataclassContainer.wrap, metadata)) + + kwargs = dict(cls._iter_field_attrs(field)) + if PYDANTIC_V1: + kwargs.update(kwargs.pop("extra", {})) + + return cls(origin, metadata, kwargs) + + @classmethod + def unwrap(cls, value): + if not isinstance(value, cls): + return value + if PYDANTIC_V1: + return FieldInfo(**value.kwargs) + + origin = GenericContainer.unwrap(value.origin) + metadata = tuple(map(BaseContainer.unwrap, value.metadata)) + try: + annotated_args = (origin, *metadata) + annotation = te.Annotated[annotated_args] + except AttributeError: + annotation = None + + return FieldInfo(annotation=annotation, **value.kwargs) + + def __eq__(self, other): + if isinstance(other, FieldInfo): + return self == self.wrap(other) + return super().__eq__(other) + + @staticmethod + def _iter_field_attrs(field: FieldInfo): + available_attrs = set(field.__slots__) - {"annotation", "metadata", "_attributes_set"} + + for attr in available_attrs: + attr_value = getattr(field, attr) + if attr_value is not PydanticUndefined and attr_value != FieldInfoDefaultValues.get(attr): + yield attr, getattr(field, attr) + + @staticmethod + def _wrap_metadata_object(obj): + return DataclassContainer.wrap(obj) + + +class BaseContainerSerializer(BaseSerializer): + value: BaseContainer def serialize(self): - value = self.value + tp_repr, imports = serializer_factory(type(self.value)).serialize() + attrs = [] + + for attr in self._iter_container_attrs(): + attr_repr, attr_imports = serializer_factory(attr).serialize() + attrs.append(attr_repr) + imports.update(attr_imports) + + attrs_repr = ", ".join(attrs) + return f"{tp_repr}({attrs_repr})", imports + + def _iter_container_attrs(self): + container = self.value + for attr in container.__slots__: + yield getattr(container, attr) - tp_repr, imports = serializer_factory(type(value)).serialize() - orig_repr, orig_imports = serializer_factory(value.origin).serialize() - imports.update(orig_imports) - args = [] - for arg in value.args: - arg_repr, arg_imports = serializer_factory(arg).serialize() - args.append(arg_repr) - imports.update(arg_imports) +class DataclassContainerSerializer(BaseSerializer): + value: DataclassContainer - if args: - args_repr = ", ".join(args) - generic_repr = "%s(%s, (%s,))" % (tp_repr, orig_repr, args_repr) - else: - generic_repr = "%s(%s)" % (tp_repr, orig_repr) + def serialize(self): + tp_repr, imports = serializer_factory(self.value.datacls).serialize() + + kwarg_pairs = [] + for arg, value in self.value.kwargs.items(): + value_repr, value_imports = serializer_factory(value).serialize() + kwarg_pairs.append(f"{arg}={value_repr}") + imports.update(value_imports) - return generic_repr, imports + kwargs_repr = ", ".join(kwarg_pairs) + return f"{tp_repr}({kwargs_repr})", imports class TypingSerializer(BaseSerializer): def serialize(self): + value = GenericContainer.wrap(self.value) + if isinstance(value, GenericContainer): + return serializer_factory(value).serialize() + orig_module = self.value.__module__ orig_repr = repr(self.value) @@ -103,6 +246,34 @@ def serialize(self): return orig_repr, {f"import {orig_module}"} +class FieldInfoSerializer(BaseSerializer): + value: FieldInfo + + def serialize(self): + container = FieldInfoContainer.wrap(self.value) + return serializer_factory(container).serialize() + + +class RepresentationSerializer(BaseSerializer): + value: Representation + + def serialize(self): + tp_repr, imports = serializer_factory(type(self.value)).serialize() + repr_args = [] + + for arg_name, arg_value in self.value.__repr_args__(): + arg_value_repr, arg_value_imports = serializer_factory(arg_value).serialize() + imports.update(arg_value_imports) + + if arg_name is None: + repr_args.append(arg_value_repr) + else: + repr_args.append(f"{arg_name}={arg_value_repr}") + + final_args_repr = ", ".join(repr_args) + return f"{tp_repr}({final_args_repr})" + + if sys.version_info >= (3, 9): GenericAlias = types.GenericAlias GenericTypes: ty.Tuple[ty.Any, ...] = ( @@ -117,7 +288,18 @@ def serialize(self): GenericTypes = GenericAlias, type(ty.List) # noqa -MigrationWriter.register_serializer(GenericContainer, GenericSerializer) +# BaseContainerSerializer *must be* registered after all specialized container serializers +MigrationWriter.register_serializer(DataclassContainer, DataclassContainerSerializer) +MigrationWriter.register_serializer(BaseContainer, BaseContainerSerializer) + +# Pydantic-specific datastructures serializers +MigrationWriter.register_serializer(FieldInfo, FieldInfoSerializer) +MigrationWriter.register_serializer(Representation, RepresentationSerializer) + +# Typing serializers +for type_ in GenericTypes: + MigrationWriter.register_serializer(type_, TypingSerializer) + MigrationWriter.register_serializer(ty.ForwardRef, TypingSerializer) MigrationWriter.register_serializer(type(ty.Union), TypingSerializer) # type: ignore diff --git a/django_pydantic_field/v1/fields.py b/django_pydantic_field/v1/fields.py index df6ec70..24a7af0 100644 --- a/django_pydantic_field/v1/fields.py +++ b/django_pydantic_field/v1/fields.py @@ -10,7 +10,7 @@ from django.db.models.fields.json import JSONField from django.db.models.query_utils import DeferredAttribute -from django_pydantic_field.compat.django import GenericContainer +from django_pydantic_field.compat.django import BaseContainer, GenericContainer from . import base, forms, utils @@ -44,7 +44,7 @@ class PydanticSchemaField(JSONField, t.Generic[base.ST]): def __init__( self, *args, - schema: t.Union[t.Type["base.ST"], "GenericContainer", "t.ForwardRef", str, None] = None, + schema: t.Union[t.Type["base.ST"], "BaseContainer", "t.ForwardRef", str, None] = None, config: t.Optional["base.ConfigType"] = None, **kwargs, ): @@ -137,7 +137,7 @@ def value_to_string(self, obj): return self.get_prep_value(value) def _resolve_schema(self, schema): - schema = t.cast(t.Type["base.ST"], GenericContainer.unwrap(schema)) + schema = t.cast(t.Type["base.ST"], BaseContainer.unwrap(schema)) self.schema = schema if schema is not None: diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 5c0fe98..ee9d615 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -1,9 +1,9 @@ from __future__ import annotations import typing as ty -import typing_extensions as te import pydantic +import typing_extensions as te from django.core import checks, exceptions from django.core.serializers.json import DjangoJSONEncoder from django.db.models.expressions import BaseExpression, Col, Value @@ -13,7 +13,7 @@ from django.db.models.query_utils import DeferredAttribute from django_pydantic_field.compat import deprecation -from django_pydantic_field.compat.django import GenericContainer +from django_pydantic_field.compat.django import BaseContainer, GenericContainer from . import forms, types @@ -70,7 +70,7 @@ class PydanticSchemaField(JSONField, ty.Generic[types.ST]): def __init__( self, *args, - schema: type[types.ST] | GenericContainer | ty.ForwardRef | str | None = None, + schema: type[types.ST] | BaseContainer | ty.ForwardRef | str | None = None, config: pydantic.ConfigDict | None = None, **kwargs, ): @@ -78,7 +78,7 @@ def __init__( self.export_kwargs = export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) super().__init__(*args, **kwargs) - self.schema = GenericContainer.unwrap(schema) + self.schema = BaseContainer.unwrap(schema) self.config = config self.adapter = types.SchemaAdapter(schema, config, None, self.get_attname(), self.null, **export_kwargs) @@ -116,7 +116,7 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: f"Please consider using field annotation syntax, e.g. `{annot_hint} = SchemaField(...)`; " "or a fallback to `pydantic.RootModel` with annotation instead." ) - performed_checks.append(checks.Warning(message, obj=self, hint=hint, id="pydantic.W004")) + performed_checks.append(checks.Error(message, obj=self, hint=hint, id="pydantic.E004")) try: # Test that the schema could be resolved in runtime, even if it contains forward references. @@ -138,7 +138,7 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: schema_default = self.get_default() if schema_default is None: # If the default value is not set, try to get the default value from the schema. - prep_value = self.adapter.type_adapter.get_default_value() + prep_value = self.adapter.get_default_value() if prep_value is not None: prep_value = prep_value.value schema_default = prep_value diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 72dcb2b..f237570 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -6,7 +6,7 @@ import pydantic import typing_extensions as te -from django_pydantic_field.compat.django import GenericContainer +from django_pydantic_field.compat.django import BaseContainer, GenericContainer from django_pydantic_field.compat.functools import cached_property from . import utils @@ -60,7 +60,7 @@ def __init__( allow_null: bool | None = None, **export_kwargs: ty.Unpack[ExportKwargs], ): - self.schema = GenericContainer.unwrap(schema) + self.schema = BaseContainer.unwrap(schema) self.config = config self.parent_type = parent_type self.attname = attname @@ -149,6 +149,12 @@ def json_schema(self) -> dict[str, ty.Any]: by_alias = self.export_kwargs.get("by_alias", True) return self.type_adapter.json_schema(by_alias=by_alias) + def get_default_value(self) -> ST | None: + wrapped = self.type_adapter.get_default_value() + if wrapped is not None: + return wrapped.value + return None + def _prepare_schema(self) -> type[ST]: """Prepare the schema for the adapter. diff --git a/tests/test_app/migrations/0001_initial.py b/tests/test_app/migrations/0001_initial.py index 80a3436..e83735a 100644 --- a/tests/test_app/migrations/0001_initial.py +++ b/tests/test_app/migrations/0001_initial.py @@ -1,19 +1,38 @@ -# Generated by Django 3.2.23 on 2023-11-21 14:37 +# Generated by Django 5.0.3 on 2024-03-25 21:17 +import annotated_types import django.core.serializers.json -from django.db import migrations, models import django_pydantic_field.compat.django import django_pydantic_field.fields import tests.conftest import tests.test_app.models +import typing +import typing_extensions +from django.db import migrations, models class Migration(migrations.Migration): + initial = True dependencies = [] operations = [ + migrations.CreateModel( + name="ExampleModel", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ( + "example_field", + django_pydantic_field.fields.PydanticSchemaField( + config=None, + default={"count": 1}, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.test_app.models.ExampleSchema, + ), + ), + ], + ), migrations.CreateModel( name="SampleForwardRefModel", fields=[ @@ -73,4 +92,41 @@ class Migration(migrations.Migration): ), ], ), + migrations.CreateModel( + name="SampleModelAnnotated", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ( + "annotated_field", + django_pydantic_field.fields.PydanticSchemaField( + config=None, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=django_pydantic_field.compat.django.GenericContainer( + typing_extensions.Annotated, + ( + django_pydantic_field.compat.django.GenericContainer(typing.Union, (int, float)), + django_pydantic_field.compat.django.FieldInfoContainer( + None, (annotated_types.Gt(gt=0),), {"title": "Annotated Field"} + ), + ), + ), + ), + ), + ], + ), + migrations.CreateModel( + name="SampleModelWithRoot", + fields=[ + ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), + ( + "root_field", + django_pydantic_field.fields.PydanticSchemaField( + config=None, + default=list, + encoder=django.core.serializers.json.DjangoJSONEncoder, + schema=tests.test_app.models.RootSchema, + ), + ), + ], + ), ] diff --git a/tests/test_app/migrations/0002_examplemodel.py b/tests/test_app/migrations/0002_examplemodel.py deleted file mode 100644 index c39b5b4..0000000 --- a/tests/test_app/migrations/0002_examplemodel.py +++ /dev/null @@ -1,30 +0,0 @@ -# Generated by Django 5.0.1 on 2024-01-27 18:30 - -import django.core.serializers.json -import django_pydantic_field.fields -import tests.test_app.models -from django.db import migrations, models - - -class Migration(migrations.Migration): - dependencies = [ - ("test_app", "0001_initial"), - ] - - operations = [ - migrations.CreateModel( - name="ExampleModel", - fields=[ - ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ( - "example_field", - django_pydantic_field.fields.PydanticSchemaField( - config=None, - default={"count": 1}, - encoder=django.core.serializers.json.DjangoJSONEncoder, - schema=tests.test_app.models.ExampleSchema, - ), - ), - ], - ), - ] diff --git a/tests/test_app/migrations/0003_samplemodelwithroot.py b/tests/test_app/migrations/0003_samplemodelwithroot.py deleted file mode 100644 index 004f6db..0000000 --- a/tests/test_app/migrations/0003_samplemodelwithroot.py +++ /dev/null @@ -1,31 +0,0 @@ -# Generated by Django 5.0.1 on 2024-03-11 22:29 - -import django.core.serializers.json -import django_pydantic_field.fields -import tests.test_app.models -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ("test_app", "0002_examplemodel"), - ] - - operations = [ - migrations.CreateModel( - name="SampleModelWithRoot", - fields=[ - ("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")), - ( - "root_field", - django_pydantic_field.fields.PydanticSchemaField( - config=None, - default=list, - encoder=django.core.serializers.json.DjangoJSONEncoder, - schema=tests.test_app.models.RootSchema, - ), - ), - ], - ), - ] diff --git a/tests/test_app/models.py b/tests/test_app/models.py index 85464ec..610eb9b 100644 --- a/tests/test_app/models.py +++ b/tests/test_app/models.py @@ -55,5 +55,4 @@ class SampleModelWithRoot(models.Model): class SampleModelAnnotated(models.Model): - annotated_field: te.Annotated[t.Union[int, float], pydantic.Field(gt=0)] = SchemaField() - arg_field = SchemaField(schema=te.Annotated[t.Union[int, float], pydantic.Field(lt=0)]) + annotated_field: te.Annotated[t.Union[int, float], pydantic.Field(gt=0, title="Annotated Field")] = SchemaField()