From 9db8faac850aa16a89d82187bcf52485c2451266 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Mon, 6 Nov 2023 22:05:55 +0400 Subject: [PATCH] Adapt SchemaField's transformations for json lookups. - Fixes all e2e tests --- django_pydantic_field/v2/fields.py | 25 +++++++++++++++++++++++-- tests/test_e2e_models.py | 16 +++++----------- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/django_pydantic_field/v2/fields.py b/django_pydantic_field/v2/fields.py index 443fd7a..652d58f 100644 --- a/django_pydantic_field/v2/fields.py +++ b/django_pydantic_field/v2/fields.py @@ -7,8 +7,9 @@ from django.core import checks, exceptions from django.core.serializers.json import DjangoJSONEncoder -from django.db.models.expressions import BaseExpression +from django.db.models.expressions import BaseExpression, Col from django.db.models.fields.json import JSONField +from django.db.models.lookups import Transform from django.db.models.query_utils import DeferredAttribute from . import types @@ -84,13 +85,19 @@ def get_prep_value(self, value: ty.Any): try: prep_value = self.adapter.validate_python(value, strict=True) - except TypeError: + except pydantic.ValidationError: prep_value = self.adapter.dump_python(value) prep_value = self.adapter.validate_python(prep_value) plain_value = self.adapter.dump_python(prep_value) return super().get_prep_value(plain_value) + def get_transform(self, lookup_name: str): + transform = super().get_transform(lookup_name) + if transform is not None: + transform = SchemaKeyTransformAdapter(transform) + return transform + def get_default(self) -> types.ST: default_value = super().get_default() try: @@ -102,6 +109,20 @@ def get_default(self) -> types.ST: return prep_value +class SchemaKeyTransformAdapter: + """An adapter for creating key transforms for schema field lookups.""" + + def __init__(self, transform: type[Transform]): + self.transform = transform + + def __call__(self, col: Col | None = None, *args, **kwargs) -> Transform | None: + """All transforms should bypass the SchemaField's adaptaion with `get_prep_value`, + and routed to JSONField's `get_prep_value` for further processing.""" + if isinstance(col, BaseExpression): + col = col.copy() + col.output_field = super(PydanticSchemaField, col.output_field) # type: ignore + return self.transform(col, *args, **kwargs) + @ty.overload def SchemaField(schema: None = None) -> ty.Any: diff --git a/tests/test_e2e_models.py b/tests/test_e2e_models.py index ce5c5e7..c48cf73 100644 --- a/tests/test_e2e_models.py +++ b/tests/test_e2e_models.py @@ -3,8 +3,8 @@ import pytest from django.db.models import F, Q, JSONField, Value -from .conftest import InnerSchema -from .test_app.models import SampleModel +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel pytestmark = [ pytest.mark.usefixtures("available_database_backends"), @@ -17,15 +17,11 @@ [ ( { - "sample_field": InnerSchema( - stub_str="abc", stub_list=[date(2023, 6, 1)] - ), + "sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]), "sample_list": [InnerSchema(stub_str="abc", stub_list=[])], }, { - "sample_field": InnerSchema( - stub_str="abc", stub_list=[date(2023, 6, 1)] - ), + "sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]), "sample_list": [InnerSchema(stub_str="abc", stub_list=[])], }, ), @@ -35,9 +31,7 @@ "sample_list": [{"stub_str": "abc", "stub_list": []}], }, { - "sample_field": InnerSchema( - stub_str="abc", stub_list=[date(2023, 6, 1)] - ), + "sample_field": InnerSchema(stub_str="abc", stub_list=[date(2023, 6, 1)]), "sample_list": [InnerSchema(stub_str="abc", stub_list=[])], }, ),