Skip to content

Commit

Permalink
Implement basic abstractions around pydantic 2.
Browse files Browse the repository at this point in the history
  • Loading branch information
surenkov committed Nov 5, 2023
1 parent 0df6db1 commit 79477de
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 0 deletions.
1 change: 1 addition & 0 deletions django_pydantic_field/v2/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .fields import SchemaField as SchemaField
78 changes: 78 additions & 0 deletions django_pydantic_field/v2/fields.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from __future__ import annotations

import typing as ty

import pydantic

from django.core import checks

from django.db.models.expressions import BaseExpression
from django.db.models.fields.json import JSONField
from django.db.models.query_utils import DeferredAttribute

from . import types


class SchemaAttribute(DeferredAttribute):
field: PydanticSchemaField

def __set__(self, obj, value):
obj.__dict__[self.field.attname] = self.field.to_python(value)


class PydanticSchemaField(JSONField, ty.Generic[types.ST]):
def __init__(
self,
*args,
schema: type[types.ST] | str | None = None,
config: pydantic.ConfigDict | None = None,
**kwargs,
):
super().__init__(*args, **kwargs)
self.schema = schema
self.config = config
self.adapter = types.SchemaAdapter(schema, config, None, self.attname, self.null)

def __copy__(self):
_, _, args, kwargs = self.deconstruct()
copied = self.__class__(*args, **kwargs)
copied.set_attributes_from_name(self.name)
return copied

def contribute_to_class(self, cls, name, private_only):
self.adapter.bind(cls, name)
return super().contribute_to_class(cls, name, private_only)

def check(self, **kwargs: ty.Any) -> list[checks.CheckMessage]:
performed_checks = super().check(**kwargs)
try:
self.adapter.validate_schema()
except ValueError as exc:
performed_checks.append(checks.Error(exc.args[0], obj=self))
return performed_checks

def to_python(self, value: ty.Any):
return self.adapter.type_adapter.validate_python(value)

def get_prep_value(self, value: ty.Any):
if isinstance(value, BaseExpression):
# We don't want to perform coercion on database query expressions.
return super().get_prep_value(value)
return self.adapter.type_adapter.dump_python(value)

def validate(self, value: ty.Any, model_instance: ty.Any) -> None:
value = self.adapter.type_adapter.validate_python(value)
return super().validate(value, model_instance)


@ty.overload
def SchemaField(schema: None = None) -> ty.Any:
...

@ty.overload
def SchemaField(schema: type[types.ST]) -> ty.Any:
...


def SchemaField(schema=None): # type: ignore
return PydanticSchemaField(schema=schema)
68 changes: 68 additions & 0 deletions django_pydantic_field/v2/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import functools
import typing as ty

import pydantic
from django.core import exceptions

from . import utils

ST = ty.TypeVar("ST", bound="SchemaT")

if ty.TYPE_CHECKING:
from pydantic.dataclasses import DataclassClassOrWrapper

ModelType = ty.Type[pydantic.BaseModel]
SchemaT = ty.Union[
pydantic.BaseModel,
DataclassClassOrWrapper,
ty.Sequence[ty.Any],
ty.Mapping[str, ty.Any],
ty.Set[ty.Any],
ty.FrozenSet[ty.Any],
]


class SchemaAdapter(ty.Generic[ST]):
def __init__(self, schema, config, parent_type, attname, allow_null, *, parent_depth = 4):
self.schema = schema
self.config = config
self.parent_type = parent_type
self.attname = attname
self.allow_null = allow_null
self.parent_depth = parent_depth

@functools.cached_property
def type_adapter(self) -> pydantic.TypeAdapter:
schema = self._get_prepared_schema()
return pydantic.TypeAdapter(schema, config=self.config, _parent_depth=3) # type: ignore

def validate_schema(self) -> None:
"""Validate the schema and raise an exception if it is invalid."""
self._get_prepared_schema()

def bind(self, parent_type, attname):
self.parent_type = parent_type
self.attname = attname
del self.type_adapter

def _get_prepared_schema(self) -> type[ST]:
schema = self.schema
if schema is None:
schema = self._guess_schema_from_annotations()
if isinstance(schema, (str, ty.ForwardRef)):
schema = self._resolve_schema_forward_ref(schema)
if schema is None:
error_msg = f"Schema not provided for {self.parent_type.__name__}.{self.attname}"
raise ValueError(error_msg)
if self.allow_null:
schema = ty.Optional[schema]
return ty.cast(type[ST], schema)

def _guess_schema_from_annotations(self) -> type[ST] | str | ty.ForwardRef | None:
return utils.get_annotated_type(self.parent_type, self.attname)

def _resolve_schema_forward_ref(self, schema: str | ty.ForwardRef) -> ty.Any:
if isinstance(schema, str):
schema = ty.ForwardRef(schema)
namespace = utils.get_local_namespace(self.parent_type)
return schema._evaluate(namespace, vars(self.parent_type), frozenset()) # type: ignore
24 changes: 24 additions & 0 deletions django_pydantic_field/v2/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import sys
import typing as ty


def get_annotated_type(obj, field, default=None) -> ty.Any:
try:
if isinstance(obj, type):
annotations = obj.__dict__["__annotations__"]
else:
annotations = obj.__annotations__

return annotations[field]
except (AttributeError, KeyError):
return default


def get_local_namespace(cls) -> dict[str, ty.Any]:
try:
module = cls.__module__
return vars(sys.modules[module])
except (KeyError, AttributeError):
return {}

0 comments on commit 79477de

Please sign in to comment.