From 6a674cbd454a69fdcdfa2aac1e5d56563dc492a9 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 3 Jan 2024 01:04:02 +0400 Subject: [PATCH] Split v2.rest_framework package tests. --- .../v2/rest_framework/openapi.py | 63 +++-- tests/v2/rest_framework/__init__.py | 0 tests/v2/rest_framework/test_coreapi.py | 26 ++ tests/v2/rest_framework/test_e2e_views.py | 56 ++++ tests/v2/rest_framework/test_fields.py | 108 ++++++++ tests/v2/rest_framework/test_openapi.py | 23 ++ tests/v2/rest_framework/test_parsers.py | 23 ++ tests/v2/rest_framework/test_renderers.py | 23 ++ tests/v2/rest_framework/view_fixtures.py | 88 +++++++ tests/v2/test_rest_framework.py | 247 ------------------ 10 files changed, 378 insertions(+), 279 deletions(-) create mode 100644 tests/v2/rest_framework/__init__.py create mode 100644 tests/v2/rest_framework/test_coreapi.py create mode 100644 tests/v2/rest_framework/test_e2e_views.py create mode 100644 tests/v2/rest_framework/test_fields.py create mode 100644 tests/v2/rest_framework/test_openapi.py create mode 100644 tests/v2/rest_framework/test_parsers.py create mode 100644 tests/v2/rest_framework/test_renderers.py create mode 100644 tests/v2/rest_framework/view_fixtures.py delete mode 100644 tests/v2/test_rest_framework.py diff --git a/django_pydantic_field/v2/rest_framework/openapi.py b/django_pydantic_field/v2/rest_framework/openapi.py index 3c741f5..2965e41 100644 --- a/django_pydantic_field/v2/rest_framework/openapi.py +++ b/django_pydantic_field/v2/rest_framework/openapi.py @@ -24,7 +24,7 @@ class AutoSchema(openapi.AutoSchema): def __init__(self, tags=None, operation_id_base=None, component_name=None) -> None: super().__init__(tags, operation_id_base, component_name) self.collected_schema_defs: dict[str, ty.Any] = {} - self.adapter_type_to_schema_refs = weakref.WeakKeyDictionary[type, str]() + self.collected_adapter_schema_refs: dict[str, ty.Any] = {} self.adapter_mode: JsonSchemaMode = "validation" self.rf = APIRequestFactory() @@ -32,8 +32,6 @@ def get_components(self, path: str, method: str) -> dict[str, ty.Any]: if method.lower() == "delete": return {} - super().get_components - request_serializer = self.get_request_serializer(path, method) # type: ignore[attr-defined] response_serializer = self.get_response_serializer(path, method) # type: ignore[attr-defined] @@ -61,9 +59,9 @@ def get_request_body(self, path, method): schema_content = {} for parser, ct in zip(self.view.parser_classes, self.request_media_types): - if issubclass(parser, parsers.SchemaParser): - ref_path = self._get_component_ref(self.adapter_type_to_schema_refs[parser]) - schema_content[ct] = {"schema": {"$ref": ref_path}} + if isinstance(parser(), parsers.SchemaParser): + parser_schema = self.collected_adapter_schema_refs[repr(parser)] + schema_content[ct] = {"schema": parser_schema} else: schema_content[ct] = request_schema @@ -76,23 +74,21 @@ def get_responses(self, path, method): self.response_media_types = self.map_renderers(path, method) serializer = self.get_response_serializer(path, method) - item_schema = {} + response_schema = {} if isinstance(serializer, serializers.Serializer): - item_schema = self.get_reference(serializer) + response_schema = self.get_reference(serializer) - if drf_schema_utils.is_list_view(path, method, self.view): - response_schema = {"type": "array", "items": item_schema} - paginator = self.get_paginator() - if paginator: - response_schema = paginator.get_paginated_response_schema(response_schema) - else: - response_schema = item_schema + is_list_view = drf_schema_utils.is_list_view(path, method, self.view) + if is_list_view: + response_schema = self._get_paginated_schema(response_schema) schema_content = {} for renderer, ct in zip(self.view.renderer_classes, self.response_media_types): - if issubclass(renderer, renderers.SchemaRenderer): - ref_path = self._get_component_ref(self.adapter_type_to_schema_refs[renderer]) - schema_content[ct] = {"schema": {"$ref": ref_path}} + if isinstance(renderer(), renderers.SchemaRenderer): + renderer_schema = {"schema": self.collected_adapter_schema_refs[repr(renderer)]} + if is_list_view: + renderer_schema = self._get_paginated_schema(renderer_schema) + schema_content[ct] = renderer_schema else: schema_content[ct] = response_schema @@ -110,14 +106,15 @@ def map_parsers(self, path: str, method: str) -> list[str]: for parser in self.view.parser_classes: media_types.append(parser.media_type) - if issubclass(parser, parsers.SchemaParser): - schema_parsers.append(parser()) + instance = parser() + if isinstance(instance, parsers.SchemaParser): + schema_parsers.append(parser) if schema_parsers: self.adapter_mode = "validation" request = self.rf.generic(method, path) schemas = self._collect_adapter_components(schema_parsers, self.view.get_parser_context(request)) - self.collected_schema_defs.update(schemas) + self.collected_adapter_schema_refs.update(schemas) return media_types @@ -127,13 +124,14 @@ def map_renderers(self, path: str, method: str) -> list[str]: for renderer in self.view.renderer_classes: media_types.append(renderer.media_type) - if issubclass(renderer, renderers.SchemaRenderer): - schema_renderers.append(renderer()) + instance = renderer() + if isinstance(instance, renderers.SchemaRenderer): + schema_renderers.append(renderer) if schema_renderers: self.adapter_mode = "serialization" schemas = self._collect_adapter_components(schema_renderers, self.view.get_renderer_context()) - self.collected_schema_defs.update(schemas) + self.collected_adapter_schema_refs.update(schemas) return media_types @@ -160,16 +158,13 @@ def _collect_serializer_component(self, serializer: serializers.BaseSerializer | schema_definition[component_name] = self.map_serializer(serializer) return schema_definition - def _collect_adapter_components(self, components: Iterable[mixins.AnnotatedAdapterMixin], context: dict): + def _collect_adapter_components(self, components: Iterable[type[mixins.AnnotatedAdapterMixin]], context: dict): type_adapters = [] for component in components: - schema_adapter = component.get_adapter(context) + schema_adapter = component().get_adapter(context) if schema_adapter is not None: - schema_name = schema_adapter.prepared_schema.__class__.__name__ - self.adapter_type_to_schema_refs[type(component)] = schema_name - - type_adapters.append((schema_name, self.adapter_mode, schema_adapter.type_adapter)) + type_adapters.append((repr(component), self.adapter_mode, schema_adapter.type_adapter)) if type_adapters: return self._collect_type_adapter_schemas(type_adapters) @@ -186,5 +181,9 @@ def _collect_type_adapter_schemas(self, adapters: Iterable[tuple[str, JsonSchema self.collected_schema_defs.update(common_schemas.get("$defs", {})) return inner_schemas - def _get_component_ref(self, model: str): - return self.REF_TEMPLATE_PREFIX.format(model=model) + def _get_paginated_schema(self, schema) -> ty.Any: + response_schema = {"type": "array", "items": schema} + paginator = self.get_paginator() + if paginator: + response_schema = paginator.get_paginated_response_schema(response_schema) # type: ignore + return response_schema diff --git a/tests/v2/rest_framework/__init__.py b/tests/v2/rest_framework/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/v2/rest_framework/test_coreapi.py b/tests/v2/rest_framework/test_coreapi.py new file mode 100644 index 0000000..c3e2669 --- /dev/null +++ b/tests/v2/rest_framework/test_coreapi.py @@ -0,0 +1,26 @@ +import sys + +import pytest +from rest_framework import schemas +from rest_framework.request import Request + +from .view_fixtures import create_views_urlconf + +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") + +@pytest.mark.skipif(sys.version_info >= (3, 12), reason="CoreAPI is not compatible with 3.12") +@pytest.mark.parametrize( + "method, path", + [ + ("GET", "/func"), + ("POST", "/func"), + ("GET", "/class"), + ("PUT", "/class"), + ], +) +def test_coreapi_schema_generators(request_factory, method, path): + urlconf = create_views_urlconf(coreapi.AutoSchema) + generator = schemas.SchemaGenerator(urlconf=urlconf) + request = Request(request_factory.generic(method, path)) + coreapi_schema = generator.get_schema(request) + assert coreapi_schema diff --git a/tests/v2/rest_framework/test_e2e_views.py b/tests/v2/rest_framework/test_e2e_views.py new file mode 100644 index 0000000..4790263 --- /dev/null +++ b/tests/v2/rest_framework/test_e2e_views.py @@ -0,0 +1,56 @@ +from datetime import date + +import pytest + +from tests.conftest import InnerSchema + +from .view_fixtures import ( + ClassBasedView, + ClassBasedViewWithModel, + ClassBasedViewWithSchemaContext, + sample_view, +) + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") + + +@pytest.mark.parametrize( + "view", + [ + sample_view, + ClassBasedView.as_view(), + ClassBasedViewWithSchemaContext.as_view(), + ], +) +def test_end_to_end_api_view(view, request_factory): + expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + existing_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + request = request_factory.post("/", existing_encoded, content_type="application/json") + response = view(request) + + assert response.data == [expected_instance] + assert response.data[0] is not expected_instance + + assert response.rendered_content == b"[%s]" % existing_encoded + + +@pytest.mark.django_db +def test_end_to_end_list_create_api_view(request_factory): + field_data = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]).json() + expected_result = { + "sample_field": {"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}, + "sample_list": [{"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}], + "sample_seq": [], + } + + payload = '{"sample_field": %s, "sample_list": [%s], "sample_seq": []}' % ((field_data,) * 2) + request = request_factory.post("/", payload.encode(), content_type="application/json") + response = ClassBasedViewWithModel.as_view()(request) + + assert response.data == expected_result + + request = request_factory.get("/", content_type="application/json") + response = ClassBasedViewWithModel.as_view()(request) + assert response.data == [expected_result] diff --git a/tests/v2/rest_framework/test_fields.py b/tests/v2/rest_framework/test_fields.py new file mode 100644 index 0000000..6098da6 --- /dev/null +++ b/tests/v2/rest_framework/test_fields.py @@ -0,0 +1,108 @@ +import typing as ty +from datetime import date + +import pytest +from rest_framework import exceptions, serializers + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +class SampleSerializer(serializers.Serializer): + field = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + + +class SampleModelSerializer(serializers.ModelSerializer): + sample_field = rest_framework.SchemaField(schema=InnerSchema) + sample_list = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + sample_seq = rest_framework.SchemaField(schema=ty.List[InnerSchema], default=list) + + class Meta: + model = SampleModel + fields = "sample_field", "sample_list", "sample_seq" + + +def test_schema_field(): + field = rest_framework.SchemaField(InnerSchema) + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = { + "stub_str": "abc", + "stub_int": 1, + "stub_list": [date(2022, 7, 1)], + } + + assert field.to_representation(existing_instance) == expected_encoded + assert field.to_internal_value(expected_encoded) == existing_instance + + with pytest.raises(serializers.ValidationError): + field.to_internal_value(None) + + with pytest.raises(serializers.ValidationError): + field.to_internal_value("null") + + +def test_field_schema_with_custom_config(): + field = rest_framework.SchemaField(InnerSchema, allow_null=True, exclude={"stub_int"}) + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + + assert field.to_representation(existing_instance) == expected_encoded + assert field.to_internal_value(expected_encoded) == existing_instance + assert field.to_internal_value(None) is None + assert field.to_internal_value("null") is None + + +def test_serializer_marshalling_with_schema_field(): + existing_instance = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]} + expected_data = {"field": [{"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}]} + + serializer = SampleSerializer(instance=existing_instance) + assert serializer.data == expected_data + + serializer = SampleSerializer(data=expected_data) + serializer.is_valid(raise_exception=True) + assert serializer.validated_data == existing_instance + + +def test_model_serializer_marshalling_with_schema_field(): + instance = SampleModel( + sample_field=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), + sample_list=[InnerSchema(stub_str="abc", stub_int=2, stub_list=[date(2022, 7, 1)])] * 2, + sample_seq=[InnerSchema(stub_str="abc", stub_int=3, stub_list=[date(2022, 7, 1)])] * 3, + ) + serializer = SampleModelSerializer(instance) + + expected_data = { + "sample_field": {"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}, + "sample_list": [{"stub_str": "abc", "stub_int": 2, "stub_list": [date(2022, 7, 1)]}] * 2, + "sample_seq": [{"stub_str": "abc", "stub_int": 3, "stub_list": [date(2022, 7, 1)]}] * 3, + } + assert serializer.data == expected_data + + +@pytest.mark.parametrize( + "export_kwargs", + [ + {"include": {"stub_str", "stub_int"}}, + {"exclude": {"stub_list"}}, + {"exclude_unset": True}, + {"exclude_defaults": True}, + {"exclude_none": True}, + {"by_alias": True}, + ], +) +def test_field_export_kwargs(export_kwargs): + field = rest_framework.SchemaField(InnerSchema, **export_kwargs) + assert field.to_representation(InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])) + + +def test_invalid_data_serialization(): + invalid_data = {"field": [{"stub_int": "abc", "stub_list": ["abc"]}]} + serializer = SampleSerializer(data=invalid_data) + + with pytest.raises(exceptions.ValidationError) as e: + serializer.is_valid(raise_exception=True) + + assert e.match(r".*stub_str.*stub_int.*stub_list.*") diff --git a/tests/v2/rest_framework/test_openapi.py b/tests/v2/rest_framework/test_openapi.py new file mode 100644 index 0000000..a0d8bc6 --- /dev/null +++ b/tests/v2/rest_framework/test_openapi.py @@ -0,0 +1,23 @@ +import pytest +from rest_framework.schemas.openapi import SchemaGenerator +from rest_framework.request import Request + +from .view_fixtures import create_views_urlconf + +openapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.openapi") + +@pytest.mark.parametrize( + "method, path", + [ + ("GET", "/func"), + ("POST", "/func"), + ("GET", "/class"), + ("PUT", "/class"), + ], +) +def test_coreapi_schema_generators(request_factory, method, path): + urlconf = create_views_urlconf(openapi.AutoSchema) + generator = SchemaGenerator(urlconf=urlconf) + request = Request(request_factory.generic(method, path)) + openapi_schema = generator.get_schema(request) + assert openapi_schema diff --git a/tests/v2/rest_framework/test_parsers.py b/tests/v2/rest_framework/test_parsers.py new file mode 100644 index 0000000..db64a56 --- /dev/null +++ b/tests/v2/rest_framework/test_parsers.py @@ -0,0 +1,23 @@ +import io +from datetime import date + +import pytest + +from tests.conftest import InnerSchema + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +@pytest.mark.parametrize( + "schema_type, existing_encoded, expected_decoded", + [ + ( + InnerSchema, + '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}', + InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), + ) + ], +) +def test_schema_parser(schema_type, existing_encoded, expected_decoded): + parser = rest_framework.SchemaParser[schema_type]() + assert parser.parse(io.StringIO(existing_encoded)) == expected_decoded diff --git a/tests/v2/rest_framework/test_renderers.py b/tests/v2/rest_framework/test_renderers.py new file mode 100644 index 0000000..59d4aed --- /dev/null +++ b/tests/v2/rest_framework/test_renderers.py @@ -0,0 +1,23 @@ +from datetime import date + +import pytest + +from tests.conftest import InnerSchema + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") + + +def test_schema_renderer(): + renderer = rest_framework.SchemaRenderer() + existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) + expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + assert renderer.render(existing_instance) == expected_encoded + + +def test_typed_schema_renderer(): + renderer = rest_framework.SchemaRenderer[InnerSchema]() + existing_data = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} + expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' + + assert renderer.render(existing_data) == expected_encoded diff --git a/tests/v2/rest_framework/view_fixtures.py b/tests/v2/rest_framework/view_fixtures.py new file mode 100644 index 0000000..a55ddd6 --- /dev/null +++ b/tests/v2/rest_framework/view_fixtures.py @@ -0,0 +1,88 @@ +import typing as ty +from types import SimpleNamespace + +import pytest +from django.urls import path +from rest_framework import generics, serializers, views +from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema +from rest_framework.response import Response + +from tests.conftest import InnerSchema +from tests.test_app.models import SampleModel + +rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") +coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") + + +class SampleSerializer(serializers.Serializer): + field = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + + +class SampleModelSerializer(serializers.ModelSerializer): + sample_field = rest_framework.SchemaField(schema=InnerSchema) + sample_list = rest_framework.SchemaField(schema=ty.List[InnerSchema]) + sample_seq = rest_framework.SchemaField(schema=ty.List[InnerSchema], default=list) + + class Meta: + model = SampleModel + fields = "sample_field", "sample_list", "sample_seq" + + +class ClassBasedView(views.APIView): + parser_classes = [rest_framework.SchemaParser[InnerSchema]] + renderer_classes = [rest_framework.SchemaRenderer[ty.List[InnerSchema]]] + + def post(self, request, *args, **kwargs): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + +class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): + serializer_class = SampleSerializer + + +class ClassBasedViewWithModel(generics.ListCreateAPIView): + queryset = SampleModel.objects.all() + serializer_class = SampleModelSerializer + + +class ClassBasedViewWithSchemaContext(ClassBasedView): + parser_classes = [rest_framework.SchemaParser] + renderer_classes = [rest_framework.SchemaRenderer] + + def get_renderer_context(self): + ctx = super().get_renderer_context() + return dict(ctx, renderer_schema=ty.List[InnerSchema]) + + def get_parser_context(self, http_request): + ctx = super().get_parser_context(http_request) + return dict(ctx, parser_schema=InnerSchema) + + +@api_view(["GET", "POST"]) +@parser_classes([rest_framework.SchemaParser[InnerSchema]]) +@renderer_classes([rest_framework.SchemaRenderer[ty.List[InnerSchema]]]) +def sample_view(request): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + +def create_views_urlconf(schema_view_inspector): + @api_view(["GET", "POST"]) + @schema(schema_view_inspector()) + @parser_classes([rest_framework.SchemaParser[InnerSchema]]) + @renderer_classes([rest_framework.SchemaRenderer[ty.List[InnerSchema]]]) + def sample_view(request): + assert isinstance(request.data, InnerSchema) + return Response([request.data]) + + class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): + serializer_class = SampleSerializer + schema = schema_view_inspector() + + return SimpleNamespace( + urlpatterns=[ + path("/func", sample_view), + path("/class", ClassBasedViewWithSerializer.as_view()), + ], + ) diff --git a/tests/v2/test_rest_framework.py b/tests/v2/test_rest_framework.py deleted file mode 100644 index e2837a7..0000000 --- a/tests/v2/test_rest_framework.py +++ /dev/null @@ -1,247 +0,0 @@ -import io -import sys -import typing as t -from datetime import date -from types import SimpleNamespace - -import pytest -from django.urls import path -from rest_framework import exceptions, generics, schemas, serializers, views -from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema -from rest_framework.request import Request -from rest_framework.response import Response - -from tests.conftest import InnerSchema -from tests.test_app.models import SampleModel - -rest_framework = pytest.importorskip("django_pydantic_field.v2.rest_framework") -coreapi = pytest.importorskip("django_pydantic_field.v2.rest_framework.coreapi") - - -class SampleSerializer(serializers.Serializer): - field = rest_framework.SchemaField(schema=t.List[InnerSchema]) - - -class SampleModelSerializer(serializers.ModelSerializer): - sample_field = rest_framework.SchemaField(schema=InnerSchema) - sample_list = rest_framework.SchemaField(schema=t.List[InnerSchema]) - sample_seq = rest_framework.SchemaField(schema=t.List[InnerSchema], default=list) - - class Meta: - model = SampleModel - fields = "sample_field", "sample_list", "sample_seq" - - -def test_schema_field(): - field = rest_framework.SchemaField(InnerSchema) - existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - expected_encoded = { - "stub_str": "abc", - "stub_int": 1, - "stub_list": [date(2022, 7, 1)], - } - - assert field.to_representation(existing_instance) == expected_encoded - assert field.to_internal_value(expected_encoded) == existing_instance - - with pytest.raises(serializers.ValidationError): - field.to_internal_value(None) - - with pytest.raises(serializers.ValidationError): - field.to_internal_value("null") - - -def test_field_schema_with_custom_config(): - field = rest_framework.SchemaField(InnerSchema, allow_null=True, exclude={"stub_int"}) - existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - expected_encoded = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} - - assert field.to_representation(existing_instance) == expected_encoded - assert field.to_internal_value(expected_encoded) == existing_instance - assert field.to_internal_value(None) is None - assert field.to_internal_value("null") is None - - -def test_serializer_marshalling_with_schema_field(): - existing_instance = {"field": [InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])]} - expected_data = {"field": [{"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}]} - - serializer = SampleSerializer(instance=existing_instance) - assert serializer.data == expected_data - - serializer = SampleSerializer(data=expected_data) - serializer.is_valid(raise_exception=True) - assert serializer.validated_data == existing_instance - - -def test_model_serializer_marshalling_with_schema_field(): - instance = SampleModel( - sample_field=InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]), - sample_list=[InnerSchema(stub_str="abc", stub_int=2, stub_list=[date(2022, 7, 1)])] * 2, - sample_seq=[InnerSchema(stub_str="abc", stub_int=3, stub_list=[date(2022, 7, 1)])] * 3, - ) - serializer = SampleModelSerializer(instance) - - expected_data = { - "sample_field": {"stub_str": "abc", "stub_int": 1, "stub_list": [date(2022, 7, 1)]}, - "sample_list": [{"stub_str": "abc", "stub_int": 2, "stub_list": [date(2022, 7, 1)]}] * 2, - "sample_seq": [{"stub_str": "abc", "stub_int": 3, "stub_list": [date(2022, 7, 1)]}] * 3, - } - assert serializer.data == expected_data - - -@pytest.mark.parametrize( - "export_kwargs", - [ - {"include": {"stub_str", "stub_int"}}, - {"exclude": {"stub_list"}}, - {"exclude_unset": True}, - {"exclude_defaults": True}, - {"exclude_none": True}, - {"by_alias": True}, - ], -) -def test_field_export_kwargs(export_kwargs): - field = rest_framework.SchemaField(InnerSchema, **export_kwargs) - assert field.to_representation(InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)])) - - -def test_invalid_data_serialization(): - invalid_data = {"field": [{"stub_int": "abc", "stub_list": ["abc"]}]} - serializer = SampleSerializer(data=invalid_data) - - with pytest.raises(exceptions.ValidationError) as e: - serializer.is_valid(raise_exception=True) - - assert e.match(r".*stub_str.*stub_int.*stub_list.*") - - -def test_schema_renderer(): - renderer = rest_framework.SchemaRenderer() - existing_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' - - assert renderer.render(existing_instance) == expected_encoded - - -def test_typed_schema_renderer(): - renderer = rest_framework.SchemaRenderer[InnerSchema]() - existing_data = {"stub_str": "abc", "stub_list": [date(2022, 7, 1)]} - expected_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' - - assert renderer.render(existing_data) == expected_encoded - - -def test_schema_parser(): - parser = rest_framework.SchemaParser[InnerSchema]() - existing_encoded = '{"stub_str": "abc", "stub_int": 1, "stub_list": ["2022-07-01"]}' - expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - - assert parser.parse(io.StringIO(existing_encoded)) == expected_instance - - -@api_view(["GET", "POST"]) -@schema(coreapi.AutoSchema()) -@parser_classes([rest_framework.SchemaParser[InnerSchema]]) -@renderer_classes([rest_framework.SchemaRenderer[t.List[InnerSchema]]]) -def sample_view(request): - assert isinstance(request.data, InnerSchema) - return Response([request.data]) - - -class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): - serializer_class = SampleSerializer - schema = coreapi.AutoSchema() - - -class ClassBasedViewWithModel(generics.ListCreateAPIView): - queryset = SampleModel.objects.all() - serializer_class = SampleModelSerializer - - -class ClassBasedView(views.APIView): - parser_classes = [rest_framework.SchemaParser[InnerSchema]] - renderer_classes = [rest_framework.SchemaRenderer[t.List[InnerSchema]]] - - def post(self, request, *args, **kwargs): - assert isinstance(request.data, InnerSchema) - return Response([request.data]) - - -class ClassBasedViewWithSchemaContext(ClassBasedView): - parser_classes = [rest_framework.SchemaParser] - renderer_classes = [rest_framework.SchemaRenderer] - - def get_renderer_context(self): - ctx = super().get_renderer_context() - return dict(ctx, renderer_schema=t.List[InnerSchema]) - - def get_parser_context(self, http_request): - ctx = super().get_parser_context(http_request) - return dict(ctx, parser_schema=InnerSchema) - - -@pytest.mark.parametrize( - "view", - [ - sample_view, - ClassBasedView.as_view(), - ClassBasedViewWithSchemaContext.as_view(), - ], -) -def test_end_to_end_api_view(view, request_factory): - expected_instance = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]) - existing_encoded = b'{"stub_str":"abc","stub_int":1,"stub_list":["2022-07-01"]}' - - request = request_factory.post("/", existing_encoded, content_type="application/json") - response = view(request) - - assert response.data == [expected_instance] - assert response.data[0] is not expected_instance - - assert response.rendered_content == b"[%s]" % existing_encoded - - -@pytest.mark.django_db -def test_end_to_end_list_create_api_view(request_factory): - field_data = InnerSchema(stub_str="abc", stub_list=[date(2022, 7, 1)]).json() - expected_result = { - "sample_field": {"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}, - "sample_list": [{"stub_str": "abc", "stub_list": [date(2022, 7, 1)], "stub_int": 1}], - "sample_seq": [], - } - - payload = '{"sample_field": %s, "sample_list": [%s], "sample_seq": []}' % ((field_data,) * 2) - request = request_factory.post("/", payload.encode(), content_type="application/json") - response = ClassBasedViewWithModel.as_view()(request) - - assert response.data == expected_result - - request = request_factory.get("/", content_type="application/json") - response = ClassBasedViewWithModel.as_view()(request) - assert response.data == [expected_result] - - -urlconf = SimpleNamespace( - urlpatterns=[ - path("/func", sample_view), - path("/class", ClassBasedViewWithSerializer.as_view()), - ], -) - - -@pytest.mark.skipif(sys.version_info >= (3, 12), reason="CoreAPI is not compatible with 3.12") -@pytest.mark.parametrize( - "method, path", - [ - ("GET", "/func"), - ("POST", "/func"), - ("GET", "/class"), - ("PUT", "/class"), - ], -) -def test_coreapi_schema_generators(request_factory, method, path): - generator = schemas.SchemaGenerator(urlconf=urlconf) - request = Request(request_factory.generic(method, path)) - coreapi_schema = generator.get_schema(request) - assert coreapi_schema