From a3f25a8e67f9dc0bc24a0c9430f3b6092819eb0a Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Tue, 7 Nov 2023 23:04:28 +0400 Subject: [PATCH] Django form field impl for V2. --- django_pydantic_field/v1/__init__.py | 5 + django_pydantic_field/v2/__init__.py | 5 + django_pydantic_field/v2/fields.py | 24 +++-- django_pydantic_field/v2/forms.py | 80 +++++++++++++- django_pydantic_field/v2/types.py | 58 +++++++--- django_pydantic_field/v2/utils.py | 6 +- .../{test_form_field.py => v1/test_forms.py} | 8 +- tests/v2/__init__.py | 0 tests/v2/test_forms.py | 101 ++++++++++++++++++ 9 files changed, 257 insertions(+), 30 deletions(-) rename tests/{test_form_field.py => v1/test_forms.py} (93%) create mode 100644 tests/v2/__init__.py create mode 100644 tests/v2/test_forms.py diff --git a/django_pydantic_field/v1/__init__.py b/django_pydantic_field/v1/__init__.py index 7746f2c..91323b9 100644 --- a/django_pydantic_field/v1/__init__.py +++ b/django_pydantic_field/v1/__init__.py @@ -1 +1,6 @@ +from ..compat.pydantic import PYDANTIC_V1 + +if not PYDANTIC_V1: + raise ImportError("django_pydantic_field.v1 package is only compatible with Pydantic v1") + from .fields import * diff --git a/django_pydantic_field/v2/__init__.py b/django_pydantic_field/v2/__init__.py index 9e72cfd..91b48f0 100644 --- a/django_pydantic_field/v2/__init__.py +++ b/django_pydantic_field/v2/__init__.py @@ -1 +1,6 @@ +from ..compat.pydantic import PYDANTIC_V2 + +if not PYDANTIC_V2: + raise ImportError("django_pydantic_field.v2 package is only compatible with Pydantic v2") + from .fields import SchemaField as SchemaField diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 3b59439..96bf621 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -12,16 +12,13 @@ from django.db.models.lookups import Transform from django.db.models.query_utils import DeferredAttribute -from . import types +from . import types, forms, utils from ..compat.django import GenericContainer class SchemaAttribute(DeferredAttribute): field: PydanticSchemaField - def __set_name__(self, owner, name): - self.field.adapter.bind(owner, name) - def __set__(self, obj, value): obj.__dict__[self.field.attname] = self.field.to_python(value) @@ -71,14 +68,14 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]: def validate(self, value: ty.Any, model_instance: ty.Any) -> None: value = self.adapter.validate_python(value) - return super().validate(value, model_instance) + return super(JSONField, self).validate(value, model_instance) def to_python(self, value: ty.Any): try: return self.adapter.validate_python(value) except pydantic.ValidationError as exc: error_params = {"errors": exc.errors(), "field": self} - raise exceptions.ValidationError(exc.title, code="invalid", params=error_params) from exc + raise exceptions.ValidationError(exc.json(), code="invalid", params=error_params) from exc def get_prep_value(self, value: ty.Any): if isinstance(value, BaseExpression): @@ -87,7 +84,6 @@ def get_prep_value(self, value: ty.Any): prep_value = self.adapter.validate_python(value) plain_value = self.adapter.dump_python(prep_value) - return super().get_prep_value(plain_value) def get_transform(self, lookup_name: str): @@ -107,6 +103,20 @@ def get_default(self) -> types.ST: return prep_value + def formfield(self, **kwargs): + schema = self.schema + if schema is None: + schema = utils.get_annotated_type(self.model, self.attname) + + field_kwargs = dict( + form_class=forms.SchemaField, + schema=schema, + config=self.config, + **self.export_kwargs, + ) + field_kwargs.update(kwargs) + return super().formfield(**field_kwargs) # type: ignore + class SchemaKeyTransformAdapter: """An adapter for creating key transforms for schema field lookups.""" diff --git a/django_pydantic_field/v2/forms.py b/django_pydantic_field/v2/forms.py index 5e4ea0c..73b35f1 100644 --- a/django_pydantic_field/v2/forms.py +++ b/django_pydantic_field/v2/forms.py @@ -1,4 +1,78 @@ +import typing as ty +from collections import ChainMap +from django.forms import BaseForm, ModelForm -class SchemaField: - def __init__(*args, **kwargs): - ... +import pydantic +from django.core.exceptions import ValidationError +from django.forms.fields import JSONField, JSONString, InvalidJSONInput +from django.utils.translation import gettext_lazy as _ + +from . import types, utils + + +class SchemaField(JSONField, ty.Generic[types.ST]): + adapter: types.SchemaAdapter + default_error_messages = { + "schema_error": _("Schema didn't match for %(title)s."), + } + + def __init__(self, schema: types.ST | ty.ForwardRef, config: pydantic.ConfigDict | None = None, *args, **kwargs): + self.schema = schema + self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs) + allow_null = None in self.empty_values + self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null=allow_null, **self.export_kwargs) + super().__init__(*args, **kwargs) + + def get_bound_field(self, form: ty.Any, field_name: str): + if not self.adapter.is_bound: + self._bind_schema_adapter(form, field_name) + return super().get_bound_field(form, field_name) + + def bound_data(self, data: ty.Any, initial: ty.Any): + if self.disabled: + return initial + if data is None: + return None + try: + return self.adapter.validate_json(data) + except pydantic.ValidationError: + return InvalidJSONInput(data) + + def to_python(self, value: ty.Any) -> ty.Any: + if self.disabled: + return value + if value in self.empty_values: + return None + elif isinstance(value, JSONString): + return value + try: + converted = self.adapter.validate_json(value) + except pydantic.ValidationError as exc: + error_params = {"value": value, "title": exc.title, "detail": exc.json(), "errors": exc.errors()} + raise ValidationError(self.error_messages["schema_error"], code="invalid", params=error_params) from exc + + if isinstance(converted, str): + return JSONString(converted) + + return converted + + def prepare_value(self, value): + if isinstance(value, InvalidJSONInput): + return value + + value = self.adapter.validate_python(value) + return self.adapter.dump_json(value).decode() + + def has_changed(self, initial: ty.Any | None, data: ty.Any | None) -> bool: + if super(JSONField, self).has_changed(initial, data): + return True + return self.adapter.dump_json(initial) != self.adapter.dump_json(data) + + def _bind_schema_adapter(self, form: BaseForm, field_name: str): + modelns = None + if isinstance(form, ModelForm): + modelns = ChainMap( + utils.get_local_namespace(form._meta.model), + utils.get_global_namespace(form._meta.model), + ) + self.adapter.bind(form, field_name, modelns) diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index 9747ec7..15b0970 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -2,8 +2,10 @@ import functools import typing as ty +import typing_extensions as te +from collections import ChainMap -from pydantic.type_adapter import TypeAdapter +import pydantic from . import utils from ..compat.django import GenericContainer @@ -11,15 +13,16 @@ ST = ty.TypeVar("ST", bound="SchemaT") if ty.TYPE_CHECKING: - from pydantic import BaseModel + from collections.abc import MutableMapping + from pydantic.type_adapter import IncEx from pydantic.dataclasses import DataclassClassOrWrapper from django.db.models import Model - ModelType = ty.Type[BaseModel] + ModelType = ty.Type[pydantic.BaseModel] DjangoModelType = ty.Type[Model] SchemaT = ty.Union[ - BaseModel, + pydantic.BaseModel, DataclassClassOrWrapper, ty.Sequence[ty.Any], ty.Mapping[str, ty.Any], @@ -27,7 +30,8 @@ ty.FrozenSet[ty.Any], ] -class ExportKwargs(ty.TypedDict, total=False): + +class ExportKwargs(te.TypedDict, total=False): strict: bool from_attributes: bool mode: ty.Literal["json", "python"] @@ -44,11 +48,11 @@ class ExportKwargs(ty.TypedDict, total=False): class SchemaAdapter(ty.Generic[ST]): def __init__( self, - schema, - config, - parent_type, - attname, - allow_null, + schema: ty.Any, + config: pydantic.ConfigDict | None, + parent_type: type | None, + attname: str | None, + allow_null: bool | None = None, *, parent_depth=4, **export_kwargs: ty.Unpack[ExportKwargs], @@ -60,6 +64,7 @@ def __init__( self.allow_null = allow_null self.parent_depth = parent_depth self.export_kwargs = export_kwargs + self.__namespace: MutableMapping[str, ty.Any] = {} @staticmethod def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: @@ -68,13 +73,18 @@ def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs: return ty.cast(ExportKwargs, export_kwargs) @functools.cached_property - def type_adapter(self) -> TypeAdapter: + def type_adapter(self) -> pydantic.TypeAdapter: schema = self._get_prepared_schema() - return TypeAdapter(schema, config=self.config, _parent_depth=4) # type: ignore + return pydantic.TypeAdapter(schema, config=self.config, _parent_depth=4) # type: ignore - def bind(self, parent_type, attname): + @property + def is_bound(self) -> bool: + return self.parent_type is not None and self.attname is not None + + def bind(self, parent_type, attname, __namespace: MutableMapping[str, ty.Any] | None = None): self.parent_type = parent_type self.attname = attname + self.__namespace = __namespace if __namespace is not None else {} self.__dict__.pop("type_adapter", None) def validate_schema(self) -> None: @@ -89,10 +99,18 @@ def validate_python(self, value: ty.Any, *, strict: bool | None = None, from_att from_attributes = self.export_kwargs.get("from_attributes", None) return self.type_adapter.validate_python(value, strict=strict, from_attributes=from_attributes) + def validate_json(self, value: str | bytes, *, strict: bool | None = None) -> ST: + if strict is None: + strict = self.export_kwargs.get("strict", None) + return self.type_adapter.validate_json(value, strict=strict) + def dump_python(self, value: ty.Any) -> ty.Any: """Dump the value to a Python object.""" return self.type_adapter.dump_python(value, **self._dump_python_kwargs) + def dump_json(self, value: ty.Any) -> bytes: + return self.type_adapter.dump_json(value, **self._dump_python_kwargs) + def json_schema(self) -> ty.Any: """Return the JSON schema for the field.""" by_alias = self.export_kwargs.get("by_alias", True) @@ -109,7 +127,10 @@ def _get_prepared_schema(self) -> type[ST]: schema = self._resolve_schema_forward_ref(schema) if schema is None: - error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}" + if self.parent_type is not None: + error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}" + else: + error_msg = "The adapter is accessed before it was bound to a field" raise ValueError(error_msg) if self.allow_null: @@ -122,8 +143,13 @@ def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | Non def _resolve_schema_forward_ref(self, schema: str | ty.ForwardRef) -> ty.Any: if isinstance(schema, str): schema = ty.ForwardRef(schema) - namespace = utils.get_local_namespace(self.parent_type) - return schema._evaluate(namespace, vars(self.parent_type), frozenset()) # type: ignore + + globalns = ChainMap( + self.__namespace, + utils.get_local_namespace(self.parent_type), + utils.get_global_namespace(self.parent_type), + ) + return schema._evaluate(dict(globalns), {}, frozenset()) # type: ignore @functools.cached_property def _dump_python_kwargs(self) -> dict[str, ty.Any]: diff --git a/django_pydantic_field/v2/utils.py b/django_pydantic_field/v2/utils.py index 39b625a..98868c1 100644 --- a/django_pydantic_field/v2/utils.py +++ b/django_pydantic_field/v2/utils.py @@ -16,9 +16,13 @@ def get_annotated_type(obj, field, default=None) -> ty.Any: return default -def get_local_namespace(cls) -> dict[str, ty.Any]: +def get_global_namespace(cls) -> dict[str, ty.Any]: try: module = cls.__module__ return vars(sys.modules[module]) except (KeyError, AttributeError): return {} + + +def get_local_namespace(cls) -> dict[str, ty.Any]: + return dict(vars(cls)) diff --git a/tests/test_form_field.py b/tests/v1/test_forms.py similarity index 93% rename from tests/test_form_field.py rename to tests/v1/test_forms.py index cadcc6a..a816d25 100644 --- a/tests/test_form_field.py +++ b/tests/v1/test_forms.py @@ -4,10 +4,12 @@ import pytest from django.core.exceptions import ValidationError from django.forms import Form, modelform_factory -from django_pydantic_field import fields, forms -from .conftest import InnerSchema -from .test_app.models import SampleForwardRefModel, SampleSchema +from tests.conftest import InnerSchema +from tests.test_app.models import SampleForwardRefModel, SampleSchema + +fields = pytest.importorskip("django_pydantic_field.v1.fields") +forms = pytest.importorskip("django_pydantic_field.v1.forms") class SampleForm(Form): diff --git a/tests/v2/__init__.py b/tests/v2/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/v2/test_forms.py b/tests/v2/test_forms.py new file mode 100644 index 0000000..4a1081e --- /dev/null +++ b/tests/v2/test_forms.py @@ -0,0 +1,101 @@ +import typing as ty + +import django +import pytest +from django.core.exceptions import ValidationError +from django.forms import Form, modelform_factory + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleForwardRefModel, SampleSchema + +fields = pytest.importorskip("django_pydantic_field.v2.fields") +forms = pytest.importorskip("django_pydantic_field.v2.forms") + + +class SampleForm(Form): + field = forms.SchemaField(ty.ForwardRef("SampleSchema")) + + +def test_form_schema_field(): + field = forms.SchemaField(InnerSchema) + + cleaned_data = field.clean('{"stub_str": "abc", "stub_list": ["1970-01-01"]}') + assert cleaned_data == InnerSchema.model_validate({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + + +def test_empty_form_values(): + field = forms.SchemaField(InnerSchema, required=False) + assert field.clean("") is None + assert field.clean(None) is None + + +def test_prepare_value(): + field = forms.SchemaField(InnerSchema, required=False) + expected = '{"stub_str":"abc","stub_int":1,"stub_list":["1970-01-01"]}' + assert expected == field.prepare_value({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + + +def test_empty_required_raises(): + field = forms.SchemaField(InnerSchema) + with pytest.raises(ValidationError) as e: + field.clean("") + + assert e.match("This field is required") + + +def test_invalid_schema_raises(): + field = forms.SchemaField(InnerSchema) + with pytest.raises(ValidationError) as e: + field.clean('{"stub_list": "abc"}') + + assert e.match("Schema didn't match for") + assert "stub_list" in e.value.params["detail"] # type: ignore + assert "stub_str" in e.value.params["detail"] # type: ignore + + +def test_invalid_json_raises(): + field = forms.SchemaField(InnerSchema) + with pytest.raises(ValidationError) as e: + field.clean('{"stub_list": "abc}') + + assert e.match("Schema didn't match for") + assert '"type":"json_invalid"' in e.value.params["detail"] # type: ignore + + +@pytest.mark.xfail( + django.VERSION[:2] < (4, 0), + reason="Django < 4 has it's own feeling on bound fields resolution", +) +def test_forwardref_field(): + form = SampleForm(data={"field": '{"field": "2"}'}) + assert form.is_valid() + + +def test_model_formfield(): + field = fields.PydanticSchemaField(schema=InnerSchema) + assert isinstance(field.formfield(), forms.SchemaField) + + +def test_forwardref_model_formfield(): + form_cls = modelform_factory(SampleForwardRefModel, exclude=("field",)) + form = form_cls(data={"annotated_field": '{"field": "2"}'}) + + assert form.is_valid(), form.errors + cleaned_data = form.cleaned_data + + assert cleaned_data is not None + assert cleaned_data["annotated_field"] == SampleSchema(field=2) + + +@pytest.mark.parametrize("export_kwargs", [ + {"include": {"stub_str", "stub_int"}}, + {"exclude": {"stub_list"}}, + {"exclude_unset": True}, + {"exclude_defaults": True}, + {"exclude_none": True}, + {"by_alias": True}, +]) +def test_form_field_export_kwargs(export_kwargs): + field = forms.SchemaField(InnerSchema, required=False, **export_kwargs) + value = InnerSchema.model_validate({"stub_str": "abc", "stub_list": ["1970-01-01"]}) + assert field.prepare_value(value)