Skip to content

Commit

Permalink
Adapt SchemaField's transformations for json lookups.
Browse files Browse the repository at this point in the history
- Fixes all e2e tests
  • Loading branch information
surenkov committed Nov 6, 2023
1 parent 74a5d93 commit 9db8faa
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 13 deletions.
25 changes: 23 additions & 2 deletions django_pydantic_field/v2/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
from django.core import checks, exceptions
from django.core.serializers.json import DjangoJSONEncoder

from django.db.models.expressions import BaseExpression
from django.db.models.expressions import BaseExpression, Col
from django.db.models.fields.json import JSONField
from django.db.models.lookups import Transform
from django.db.models.query_utils import DeferredAttribute

from . import types
Expand Down Expand Up @@ -84,13 +85,19 @@ def get_prep_value(self, value: ty.Any):

try:
prep_value = self.adapter.validate_python(value, strict=True)
except TypeError:
except pydantic.ValidationError:
prep_value = self.adapter.dump_python(value)
prep_value = self.adapter.validate_python(prep_value)

plain_value = self.adapter.dump_python(prep_value)
return super().get_prep_value(plain_value)

def get_transform(self, lookup_name: str):
transform = super().get_transform(lookup_name)
if transform is not None:
transform = SchemaKeyTransformAdapter(transform)
return transform

def get_default(self) -> types.ST:
default_value = super().get_default()
try:
Expand All @@ -102,6 +109,20 @@ def get_default(self) -> types.ST:
return prep_value


class SchemaKeyTransformAdapter:
"""An adapter for creating key transforms for schema field lookups."""

def __init__(self, transform: type[Transform]):
self.transform = transform

def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None:
"""All transforms should bypass the SchemaField's adaptaion with `get_prep_value`,
and routed to JSONField's `get_prep_value` for further processing."""
if isinstance(col, BaseExpression):
col = col.copy()
col.output_field = super(PydanticSchemaField, col.output_field) # type: ignore
return self.transform(col, *args, **kwargs)


@ty.overload
def SchemaField(schema: None = None) -> ty.Any:
Expand Down
16 changes: 5 additions & 11 deletions tests/test_e2e_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import pytest
from django.db.models import F, Q, JSONField, Value

from .conftest import InnerSchema
from .test_app.models import SampleModel
from tests.conftest import InnerSchema
from tests.test_app.models import SampleModel

pytestmark = [
pytest.mark.usefixtures("available_database_backends"),
Expand All @@ -17,15 +17,11 @@
[
(
{
"sample_field": InnerSchema(
stub_str="abc", stub_list=[date(2023, 6, 1)]
),
"sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]),
"sample_list": [InnerSchema(stub_str="abc", stub_list=[])],
},
{
"sample_field": InnerSchema(
stub_str="abc", stub_list=[date(2023, 6, 1)]
),
"sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]),
"sample_list": [InnerSchema(stub_str="abc", stub_list=[])],
},
),
Expand All @@ -35,9 +31,7 @@
"sample_list": [{"stub_str": "abc", "stub_list": []}],
},
{
"sample_field": InnerSchema(
stub_str="abc", stub_list=[date(2023, 6, 1)]
),
"sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]),
"sample_list": [InnerSchema(stub_str="abc", stub_list=[])],
},
),
Expand Down

0 comments on commit 9db8faa

Please sign in to comment.