From 8f4dcc05938bebe7f7fcd237679f2b2ce938d7a0 Mon Sep 17 00:00:00 2001 From: Savva Surenkov Date: Wed, 20 Dec 2023 02:54:31 +0400 Subject: [PATCH] Implement CoreAPI schema generator [ci skip] --- .../v2/rest_framework/coreapi.py | 222 +++++++++++++++++- django_pydantic_field/v2/types.py | 2 +- tests/v2/test_rest_framework.py | 33 ++- 3 files changed, 251 insertions(+), 6 deletions(-) diff --git a/django_pydantic_field/v2/rest_framework/coreapi.py b/django_pydantic_field/v2/rest_framework/coreapi.py index 7daa1fd..5a1412f 100644 --- a/django_pydantic_field/v2/rest_framework/coreapi.py +++ b/django_pydantic_field/v2/rest_framework/coreapi.py @@ -1,5 +1,223 @@ -from rest_framework.schemas import coreapi +from __future__ import annotations +import typing as ty -class AutoSchema(coreapi.AutoSchema): +from rest_framework.compat import coreapi, coreschema +from rest_framework.schemas.coreapi import AutoSchema as _CoreAPIAutoSchema + +from .fields import SchemaField + +if ty.TYPE_CHECKING: + from coreschema.schemas import Schema as _CoreAPISchema + from rest_framework.serializers import Serializer + +__all__ = ("AutoSchema",) + + +class AutoSchema(_CoreAPIAutoSchema): """Not implemented yet.""" + + def get_serializer_fields(self, path: str, method: str) -> list[coreapi.Field]: + base_field_schemas = super().get_serializer_fields(path, method) + if not base_field_schemas: + return [] + + serializer: Serializer = self.view.get_serializer() + pydantic_schema_fields: dict[str, coreapi.Field] = {} + + for field_name, field in serializer.fields.items(): + if not field.read_only and isinstance(field, SchemaField): + pydantic_schema_fields[field_name] = self._prepare_schema_field(field) + + if not pydantic_schema_fields: + return base_field_schemas + + return [pydantic_schema_fields.get(field.name, field) for field in base_field_schemas] + + def _prepare_schema_field(self, field: SchemaField) -> coreapi.Field: + build_core_schema = SimpleCoreSchemaTransformer(field.adapter.json_schema()) + return coreapi.Field( + name=field.field_name, + location="form", + required=field.required, + schema=build_core_schema(), + description=field.help_text, + ) + + +class SimpleCoreSchemaTransformer: + def __init__(self, json_schema: dict[str, ty.Any]): + self.root_schema = json_schema + + def __call__(self) -> _CoreAPISchema: + definitions = self._populate_definitions() + root_schema = self._transform(self.root_schema) + + if definitions: + if isinstance(root_schema, coreschema.Ref): + schema_name = root_schema.ref_name + else: + schema_name = root_schema.title or "Schema" + definitions[schema_name] = root_schema + + root_schema = coreschema.RefSpace(definitions, schema_name) + + return root_schema + + def _populate_definitions(self): + schemas = self.root_schema.get("$defs", {}) + return {ref_name: self._transform(schema) for ref_name, schema in schemas.items()} + + def _transform(self, schema) -> _CoreAPISchema: + schemas = [ + *self._transform_type_schema(schema), + *self._transform_composite_types(schema), + *self._transform_ref(schema), + ] + if not schemas: + schema = self._transform_any(schema) + elif len(schemas) == 1: + schema = schemas[0] + else: + schema = coreschema.Intersection(schemas) + return schema + + def _transform_type_schema(self, schema): + schema_type = schema.get("type", None) + + if schema_type is not None: + schema_types = schema_type if isinstance(schema_type, list) else [schema_type] + + for schema_type in schema_types: + transformer = getattr(self, f"transform_{schema_type}") + yield transformer(schema) + + def _transform_composite_types(self, schema): + for operation, transform_name in self.COMBINATOR_TYPES.items(): + value = schema.get(operation, None) + + if value is not None: + transformer = getattr(self, transform_name) + yield transformer(schema) + + def _transform_ref(self, schema): + reference = schema.get("$ref", None) + if reference is not None: + yield coreschema.Ref(reference) + + def _transform_any(self, schema): + attrs = self._get_common_attributes(schema) + return coreschema.Anything(**attrs) + + # Simple types transformers + + def transform_object(self, schema) -> coreschema.Object: + properties = schema.get("properties", None) + if properties is not None: + properties = {prop: self._transform(prop_schema) for prop, prop_schema in properties.items()} + + pattern_props = schema.get("patternProperties", None) + if pattern_props is not None: + pattern_props = {pattern: self._transform(prop_schema) for pattern, prop_schema in pattern_props.items()} + + extra_props = schema.get("additionalProperties", None) + if extra_props is not None: + if extra_props not in (True, False): + extra_props = self._transform(schema) + + return coreschema.Object( + properties=properties, + pattern_properties=pattern_props, + additional_properties=extra_props, # type: ignore + min_properties=schema.get("minProperties"), + max_properties=schema.get("maxProperties"), + required=schema.get("required", []), + **self._get_common_attributes(schema), + ) + + def transform_array(self, schema) -> coreschema.Array: + items = schema.get("items", None) + if items is not None: + if isinstance(items, list): + items = list(map(self._transform, items)) + elif items not in (True, False): + items = self._transform(items) + + extra_items = schema.get("additionalItems") + if extra_items is not None: + if isinstance(items, list): + items = list(map(self._transform, items)) + elif items not in (True, False): + items = self._transform(items) + + return coreschema.Array( + items=items, + additional_items=extra_items, + min_items=schema.get("minItems"), + max_items=schema.get("maxItems"), + unique_items=schema.get("uniqueItems"), + **self._get_common_attributes(schema), + ) + + def transform_boolean(self, schema) -> coreschema.Boolean: + attrs = self._get_common_attributes(schema) + return coreschema.Boolean(**attrs) + + def transform_integer(self, schema) -> coreschema.Integer: + return self._transform_numeric(schema, cls=coreschema.Integer) + + def transform_null(self, schema) -> coreschema.Null: + attrs = self._get_common_attributes(schema) + return coreschema.Null(**attrs) + + def transform_number(self, schema) -> coreschema.Number: + return self._transform_numeric(schema, cls=coreschema.Number) + + def transform_string(self, schema) -> coreschema.String: + return coreschema.String( + min_length=schema.get("minLength"), + max_length=schema.get("maxLength"), + pattern=schema.get("pattern"), + format=schema.get("format"), + **self._get_common_attributes(schema), + ) + + # Composite types transformers + + COMBINATOR_TYPES = { + "anyOf": "transform_union", + "oneOf": "transform_exclusive_union", + "allOf": "transform_intersection", + "not": "transform_not", + } + + def transform_union(self, schema): + return coreschema.Union([self._transform(option) for option in schema["anyOf"]]) + + def transform_exclusive_union(self, schema): + return coreschema.ExclusiveUnion([self._transform(option) for option in schema["oneOf"]]) + + def transform_intersection(self, schema): + return coreschema.Intersection([self._transform(option) for option in schema["allOf"]]) + + def transform_not(self, schema): + return coreschema.Not(self._transform(schema["not"])) + + # Common schema transformations + + def _get_common_attributes(self, schema): + return dict( + title=schema.get("title"), + description=schema.get("description"), + default=schema.get("default"), + ) + + def _transform_numeric(self, schema, cls): + return cls( + minimum=schema.get("minimum"), + maximum=schema.get("maximum"), + exclusive_minimum=schema.get("exclusiveMinimum"), + exclusive_maximum=schema.get("exclusiveMaximum"), + multiple_of=schema.get("multipleOf"), + **self._get_common_attributes(schema), + ) diff --git a/django_pydantic_field/v2/types.py b/django_pydantic_field/v2/types.py index c1bed91..30663af 100644 --- a/django_pydantic_field/v2/types.py +++ b/django_pydantic_field/v2/types.py @@ -143,7 +143,7 @@ def dump_json(self, value: ty.Any, **override_kwargs: ty.Unpack[ExportKwargs]) - union_kwargs = ChainMap(override_kwargs, self._dump_python_kwargs) # type: ignore return self.type_adapter.dump_json(value, **union_kwargs) - def json_schema(self) -> ty.Any: + def json_schema(self) -> dict[str, ty.Any]: """Return the JSON schema for the field.""" by_alias = self.export_kwargs.get("by_alias", True) return self.type_adapter.json_schema(by_alias=by_alias) diff --git a/tests/v2/test_rest_framework.py b/tests/v2/test_rest_framework.py index 65147b9..e7ad27a 100644 --- a/tests/v2/test_rest_framework.py +++ b/tests/v2/test_rest_framework.py @@ -1,10 +1,13 @@ import io import typing as t from datetime import date +from types import SimpleNamespace import pytest -from rest_framework import exceptions, generics, serializers, views +from django.urls import path +from rest_framework import exceptions, generics, schemas, serializers, views from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema +from rest_framework.request import Request from rest_framework.response import Response from tests.conftest import InnerSchema @@ -136,7 +139,7 @@ def test_schema_parser(): assert parser.parse(io.StringIO(existing_encoded)) == expected_instance -@api_view(["POST"]) +@api_view(["GET", "POST"]) @schema(coreapi.AutoSchema()) @parser_classes([rest_framework.SchemaParser[InnerSchema]]) @renderer_classes([rest_framework.SchemaRenderer[t.List[InnerSchema]]]) @@ -145,7 +148,7 @@ def sample_view(request): return Response([request.data]) -class ClassBasedViewWithSerializer(generics.RetrieveAPIView): +class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView): serializer_class = SampleSerializer schema = coreapi.AutoSchema() @@ -216,3 +219,27 @@ def test_end_to_end_list_create_api_view(request_factory): request = request_factory.get("/", content_type="application/json") response = ClassBasedViewWithModel.as_view()(request) assert response.data == [expected_result] + + +urlconf = SimpleNamespace( + urlpatterns=[ + path("/func", sample_view), + path("/class", ClassBasedViewWithSerializer.as_view()), + ], +) + + +@pytest.mark.parametrize( + "method, path", + [ + ("GET", "/func"), + ("POST", "/func"), + ("GET", "/class"), + ("PUT", "/class"), + ], +) +def test_coreapi_schema_generators(request_factory, method, path): + generator = schemas.SchemaGenerator(urlconf=urlconf) + request = Request(request_factory.generic(method, path)) + coreapi_schema = generator.get_schema(request) + assert coreapi_schema