Skip to content

Commit

Permalink
Django form field impl for V2.
Browse files Browse the repository at this point in the history
  • Loading branch information
surenkov committed Nov 7, 2023
1 parent 417b669 commit a3f25a8
Show file tree
Hide file tree
Showing 9 changed files with 257 additions and 30 deletions.
5 changes: 5 additions & 0 deletions django_pydantic_field/v1/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from ..compat.pydantic import PYDANTIC_V1

if not PYDANTIC_V1:
raise ImportError("django_pydantic_field.v1 package is only compatible with Pydantic v1")

from .fields import *
5 changes: 5 additions & 0 deletions django_pydantic_field/v2/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,6 @@
from ..compat.pydantic import PYDANTIC_V2

if not PYDANTIC_V2:
raise ImportError("django_pydantic_field.v2 package is only compatible with Pydantic v2")

from .fields import SchemaField as SchemaField
24 changes: 17 additions & 7 deletions django_pydantic_field/v2/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,13 @@
from django.db.models.lookups import Transform
from django.db.models.query_utils import DeferredAttribute

from . import types
from . import types, forms, utils
from ..compat.django import GenericContainer


class SchemaAttribute(DeferredAttribute):
field: PydanticSchemaField

def __set_name__(self, owner, name):
self.field.adapter.bind(owner, name)

def __set__(self, obj, value):
obj.__dict__[self.field.attname] = self.field.to_python(value)

Expand Down Expand Up @@ -71,14 +68,14 @@ def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]:

def validate(self, value: ty.Any, model_instance: ty.Any) -> None:
value = self.adapter.validate_python(value)
return super().validate(value, model_instance)
return super(JSONField, self).validate(value, model_instance)

def to_python(self, value: ty.Any):
try:
return self.adapter.validate_python(value)
except pydantic.ValidationError as exc:
error_params = {"errors": exc.errors(), "field": self}
raise exceptions.ValidationError(exc.title, code="invalid", params=error_params) from exc
raise exceptions.ValidationError(exc.json(), code="invalid", params=error_params) from exc

def get_prep_value(self, value: ty.Any):
if isinstance(value, BaseExpression):
Expand All @@ -87,7 +84,6 @@ def get_prep_value(self, value: ty.Any):

prep_value = self.adapter.validate_python(value)
plain_value = self.adapter.dump_python(prep_value)

return super().get_prep_value(plain_value)

def get_transform(self, lookup_name: str):
Expand All @@ -107,6 +103,20 @@ def get_default(self) -> types.ST:

return prep_value

def formfield(self, **kwargs):
schema = self.schema
if schema is None:
schema = utils.get_annotated_type(self.model, self.attname)

field_kwargs = dict(
form_class=forms.SchemaField,
schema=schema,
config=self.config,
**self.export_kwargs,
)
field_kwargs.update(kwargs)
return super().formfield(**field_kwargs) # type: ignore


class SchemaKeyTransformAdapter:
"""An adapter for creating key transforms for schema field lookups."""
Expand Down
80 changes: 77 additions & 3 deletions django_pydantic_field/v2/forms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,78 @@
import typing as ty
from collections import ChainMap
from django.forms import BaseForm, ModelForm

class SchemaField:
def __init__(*args, **kwargs):
...
import pydantic
from django.core.exceptions import ValidationError
from django.forms.fields import JSONField, JSONString, InvalidJSONInput
from django.utils.translation import gettext_lazy as _

from . import types, utils


class SchemaField(JSONField, ty.Generic[types.ST]):
adapter: types.SchemaAdapter
default_error_messages = {
"schema_error": _("Schema didn't match for %(title)s."),
}

def __init__(self, schema: types.ST | ty.ForwardRef, config: pydantic.ConfigDict | None = None, *args, **kwargs):
self.schema = schema
self.export_kwargs = types.SchemaAdapter.extract_export_kwargs(kwargs)
allow_null = None in self.empty_values
self.adapter = types.SchemaAdapter(schema, config, None, None, allow_null=allow_null, **self.export_kwargs)
super().__init__(*args, **kwargs)

def get_bound_field(self, form: ty.Any, field_name: str):
if not self.adapter.is_bound:
self._bind_schema_adapter(form, field_name)
return super().get_bound_field(form, field_name)

def bound_data(self, data: ty.Any, initial: ty.Any):
if self.disabled:
return initial
if data is None:
return None
try:
return self.adapter.validate_json(data)
except pydantic.ValidationError:
return InvalidJSONInput(data)

def to_python(self, value: ty.Any) -> ty.Any:
if self.disabled:
return value
if value in self.empty_values:
return None
elif isinstance(value, JSONString):
return value
try:
converted = self.adapter.validate_json(value)
except pydantic.ValidationError as exc:
error_params = {"value": value, "title": exc.title, "detail": exc.json(), "errors": exc.errors()}
raise ValidationError(self.error_messages["schema_error"], code="invalid", params=error_params) from exc

if isinstance(converted, str):
return JSONString(converted)

return converted

def prepare_value(self, value):
if isinstance(value, InvalidJSONInput):
return value

value = self.adapter.validate_python(value)
return self.adapter.dump_json(value).decode()

def has_changed(self, initial: ty.Any | None, data: ty.Any | None) -> bool:
if super(JSONField, self).has_changed(initial, data):
return True
return self.adapter.dump_json(initial) != self.adapter.dump_json(data)

def _bind_schema_adapter(self, form: BaseForm, field_name: str):
modelns = None
if isinstance(form, ModelForm):
modelns = ChainMap(
utils.get_local_namespace(form._meta.model),
utils.get_global_namespace(form._meta.model),
)
self.adapter.bind(form, field_name, modelns)
58 changes: 42 additions & 16 deletions django_pydantic_field/v2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,36 @@

import functools
import typing as ty
import typing_extensions as te
from collections import ChainMap

from pydantic.type_adapter import TypeAdapter
import pydantic

from . import utils
from ..compat.django import GenericContainer

ST = ty.TypeVar("ST", bound="SchemaT")

if ty.TYPE_CHECKING:
from pydantic import BaseModel
from collections.abc import MutableMapping

from pydantic.type_adapter import IncEx
from pydantic.dataclasses import DataclassClassOrWrapper
from django.db.models import Model

ModelType = ty.Type[BaseModel]
ModelType = ty.Type[pydantic.BaseModel]
DjangoModelType = ty.Type[Model]
SchemaT = ty.Union[
BaseModel,
pydantic.BaseModel,
DataclassClassOrWrapper,
ty.Sequence[ty.Any],
ty.Mapping[str, ty.Any],
ty.Set[ty.Any],
ty.FrozenSet[ty.Any],
]

class ExportKwargs(ty.TypedDict, total=False):

class ExportKwargs(te.TypedDict, total=False):
strict: bool
from_attributes: bool
mode: ty.Literal["json", "python"]
Expand All @@ -44,11 +48,11 @@ class ExportKwargs(ty.TypedDict, total=False):
class SchemaAdapter(ty.Generic[ST]):
def __init__(
self,
schema,
config,
parent_type,
attname,
allow_null,
schema: ty.Any,
config: pydantic.ConfigDict | None,
parent_type: type | None,
attname: str | None,
allow_null: bool | None = None,
*,
parent_depth=4,
**export_kwargs: ty.Unpack[ExportKwargs],
Expand All @@ -60,6 +64,7 @@ def __init__(
self.allow_null = allow_null
self.parent_depth = parent_depth
self.export_kwargs = export_kwargs
self.__namespace: MutableMapping[str, ty.Any] = {}

@staticmethod
def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs:
Expand All @@ -68,13 +73,18 @@ def extract_export_kwargs(kwargs: dict[str, ty.Any]) -> ExportKwargs:
return ty.cast(ExportKwargs, export_kwargs)

@functools.cached_property
def type_adapter(self) -> TypeAdapter:
def type_adapter(self) -> pydantic.TypeAdapter:
schema = self._get_prepared_schema()
return TypeAdapter(schema, config=self.config, _parent_depth=4) # type: ignore
return pydantic.TypeAdapter(schema, config=self.config, _parent_depth=4) # type: ignore

def bind(self, parent_type, attname):
@property
def is_bound(self) -> bool:
return self.parent_type is not None and self.attname is not None

def bind(self, parent_type, attname, __namespace: MutableMapping[str, ty.Any] | None = None):
self.parent_type = parent_type
self.attname = attname
self.__namespace = __namespace if __namespace is not None else {}
self.__dict__.pop("type_adapter", None)

def validate_schema(self) -> None:
Expand All @@ -89,10 +99,18 @@ def validate_python(self, value: ty.Any, *, strict: bool | None = None, from_att
from_attributes = self.export_kwargs.get("from_attributes", None)
return self.type_adapter.validate_python(value, strict=strict, from_attributes=from_attributes)

def validate_json(self, value: str | bytes, *, strict: bool | None = None) -> ST:
if strict is None:
strict = self.export_kwargs.get("strict", None)
return self.type_adapter.validate_json(value, strict=strict)

def dump_python(self, value: ty.Any) -> ty.Any:
"""Dump the value to a Python object."""
return self.type_adapter.dump_python(value, **self._dump_python_kwargs)

def dump_json(self, value: ty.Any) -> bytes:
return self.type_adapter.dump_json(value, **self._dump_python_kwargs)

def json_schema(self) -> ty.Any:
"""Return the JSON schema for the field."""
by_alias = self.export_kwargs.get("by_alias", True)
Expand All @@ -109,7 +127,10 @@ def _get_prepared_schema(self) -> type[ST]:
schema = self._resolve_schema_forward_ref(schema)

if schema is None:
error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}"
if self.parent_type is not None:
error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}"
else:
error_msg = "The adapter is accessed before it was bound to a field"
raise ValueError(error_msg)

if self.allow_null:
Expand All @@ -122,8 +143,13 @@ def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | Non
def _resolve_schema_forward_ref(self, schema: str | ty.ForwardRef) -> ty.Any:
if isinstance(schema, str):
schema = ty.ForwardRef(schema)
namespace = utils.get_local_namespace(self.parent_type)
return schema._evaluate(namespace, vars(self.parent_type), frozenset()) # type: ignore

globalns = ChainMap(
self.__namespace,
utils.get_local_namespace(self.parent_type),
utils.get_global_namespace(self.parent_type),
)
return schema._evaluate(dict(globalns), {}, frozenset()) # type: ignore

@functools.cached_property
def _dump_python_kwargs(self) -> dict[str, ty.Any]:
Expand Down
6 changes: 5 additions & 1 deletion django_pydantic_field/v2/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,13 @@ def get_annotated_type(obj, field, default=None) -> ty.Any:
return default


def get_local_namespace(cls) -> dict[str, ty.Any]:
def get_global_namespace(cls) -> dict[str, ty.Any]:
try:
module = cls.__module__
return vars(sys.modules[module])
except (KeyError, AttributeError):
return {}


def get_local_namespace(cls) -> dict[str, ty.Any]:
return dict(vars(cls))
8 changes: 5 additions & 3 deletions tests/test_form_field.py → tests/v1/test_forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import pytest
from django.core.exceptions import ValidationError
from django.forms import Form, modelform_factory
from django_pydantic_field import fields, forms

from .conftest import InnerSchema
from .test_app.models import SampleForwardRefModel, SampleSchema
from tests.conftest import InnerSchema
from tests.test_app.models import SampleForwardRefModel, SampleSchema

fields = pytest.importorskip("django_pydantic_field.v1.fields")
forms = pytest.importorskip("django_pydantic_field.v1.forms")


class SampleForm(Form):
Expand Down
Empty file added tests/v2/__init__.py
Empty file.
Loading

0 comments on commit a3f25a8

Please sign in to comment.