Skip to content

Commit

Permalink
Implement CoreAPI schema generator [ci skip]
Browse files Browse the repository at this point in the history
  • Loading branch information
surenkov committed Dec 19, 2023
1 parent 9d5bb61 commit 8f4dcc0
Show file tree
Hide file tree
Showing 3 changed files with 251 additions and 6 deletions.
222 changes: 220 additions & 2 deletions django_pydantic_field/v2/rest_framework/coreapi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,223 @@
from rest_framework.schemas import coreapi
from __future__ import annotations

import typing as ty

class AutoSchema(coreapi.AutoSchema):
from rest_framework.compat import coreapi, coreschema
from rest_framework.schemas.coreapi import AutoSchema as _CoreAPIAutoSchema

from .fields import SchemaField

if ty.TYPE_CHECKING:
from coreschema.schemas import Schema as _CoreAPISchema
from rest_framework.serializers import Serializer

__all__ = ("AutoSchema",)


class AutoSchema(_CoreAPIAutoSchema):
"""Not implemented yet."""

def get_serializer_fields(self, path: str, method: str) -> list[coreapi.Field]:
base_field_schemas = super().get_serializer_fields(path, method)
if not base_field_schemas:
return []

serializer: Serializer = self.view.get_serializer()
pydantic_schema_fields: dict[str, coreapi.Field] = {}

for field_name, field in serializer.fields.items():
if not field.read_only and isinstance(field, SchemaField):
pydantic_schema_fields[field_name] = self._prepare_schema_field(field)

if not pydantic_schema_fields:
return base_field_schemas

return [pydantic_schema_fields.get(field.name, field) for field in base_field_schemas]

def _prepare_schema_field(self, field: SchemaField) -> coreapi.Field:
build_core_schema = SimpleCoreSchemaTransformer(field.adapter.json_schema())
return coreapi.Field(
name=field.field_name,
location="form",
required=field.required,
schema=build_core_schema(),
description=field.help_text,
)


class SimpleCoreSchemaTransformer:
def __init__(self, json_schema: dict[str, ty.Any]):
self.root_schema = json_schema

def __call__(self) -> _CoreAPISchema:
definitions = self._populate_definitions()
root_schema = self._transform(self.root_schema)

if definitions:
if isinstance(root_schema, coreschema.Ref):
schema_name = root_schema.ref_name
else:
schema_name = root_schema.title or "Schema"
definitions[schema_name] = root_schema

root_schema = coreschema.RefSpace(definitions, schema_name)

return root_schema

def _populate_definitions(self):
schemas = self.root_schema.get("$defs", {})
return {ref_name: self._transform(schema) for ref_name, schema in schemas.items()}

def _transform(self, schema) -> _CoreAPISchema:
schemas = [
*self._transform_type_schema(schema),
*self._transform_composite_types(schema),
*self._transform_ref(schema),
]
if not schemas:
schema = self._transform_any(schema)
elif len(schemas) == 1:
schema = schemas[0]
else:
schema = coreschema.Intersection(schemas)
return schema

def _transform_type_schema(self, schema):
schema_type = schema.get("type", None)

if schema_type is not None:
schema_types = schema_type if isinstance(schema_type, list) else [schema_type]

for schema_type in schema_types:
transformer = getattr(self, f"transform_{schema_type}")
yield transformer(schema)

def _transform_composite_types(self, schema):
for operation, transform_name in self.COMBINATOR_TYPES.items():
value = schema.get(operation, None)

if value is not None:
transformer = getattr(self, transform_name)
yield transformer(schema)

def _transform_ref(self, schema):
reference = schema.get("$ref", None)
if reference is not None:
yield coreschema.Ref(reference)

def _transform_any(self, schema):
attrs = self._get_common_attributes(schema)
return coreschema.Anything(**attrs)

# Simple types transformers

def transform_object(self, schema) -> coreschema.Object:
properties = schema.get("properties", None)
if properties is not None:
properties = {prop: self._transform(prop_schema) for prop, prop_schema in properties.items()}

pattern_props = schema.get("patternProperties", None)
if pattern_props is not None:
pattern_props = {pattern: self._transform(prop_schema) for pattern, prop_schema in pattern_props.items()}

extra_props = schema.get("additionalProperties", None)
if extra_props is not None:
if extra_props not in (True, False):
extra_props = self._transform(schema)

return coreschema.Object(
properties=properties,
pattern_properties=pattern_props,
additional_properties=extra_props, # type: ignore
min_properties=schema.get("minProperties"),
max_properties=schema.get("maxProperties"),
required=schema.get("required", []),
**self._get_common_attributes(schema),
)

def transform_array(self, schema) -> coreschema.Array:
items = schema.get("items", None)
if items is not None:
if isinstance(items, list):
items = list(map(self._transform, items))
elif items not in (True, False):
items = self._transform(items)

extra_items = schema.get("additionalItems")
if extra_items is not None:
if isinstance(items, list):
items = list(map(self._transform, items))
elif items not in (True, False):
items = self._transform(items)

return coreschema.Array(
items=items,
additional_items=extra_items,
min_items=schema.get("minItems"),
max_items=schema.get("maxItems"),
unique_items=schema.get("uniqueItems"),
**self._get_common_attributes(schema),
)

def transform_boolean(self, schema) -> coreschema.Boolean:
attrs = self._get_common_attributes(schema)
return coreschema.Boolean(**attrs)

def transform_integer(self, schema) -> coreschema.Integer:
return self._transform_numeric(schema, cls=coreschema.Integer)

def transform_null(self, schema) -> coreschema.Null:
attrs = self._get_common_attributes(schema)
return coreschema.Null(**attrs)

def transform_number(self, schema) -> coreschema.Number:
return self._transform_numeric(schema, cls=coreschema.Number)

def transform_string(self, schema) -> coreschema.String:
return coreschema.String(
min_length=schema.get("minLength"),
max_length=schema.get("maxLength"),
pattern=schema.get("pattern"),
format=schema.get("format"),
**self._get_common_attributes(schema),
)

# Composite types transformers

COMBINATOR_TYPES = {
"anyOf": "transform_union",
"oneOf": "transform_exclusive_union",
"allOf": "transform_intersection",
"not": "transform_not",
}

def transform_union(self, schema):
return coreschema.Union([self._transform(option) for option in schema["anyOf"]])

def transform_exclusive_union(self, schema):
return coreschema.ExclusiveUnion([self._transform(option) for option in schema["oneOf"]])

def transform_intersection(self, schema):
return coreschema.Intersection([self._transform(option) for option in schema["allOf"]])

def transform_not(self, schema):
return coreschema.Not(self._transform(schema["not"]))

# Common schema transformations

def _get_common_attributes(self, schema):
return dict(
title=schema.get("title"),
description=schema.get("description"),
default=schema.get("default"),
)

def _transform_numeric(self, schema, cls):
return cls(
minimum=schema.get("minimum"),
maximum=schema.get("maximum"),
exclusive_minimum=schema.get("exclusiveMinimum"),
exclusive_maximum=schema.get("exclusiveMaximum"),
multiple_of=schema.get("multipleOf"),
**self._get_common_attributes(schema),
)
2 changes: 1 addition & 1 deletion django_pydantic_field/v2/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def dump_json(self, value: ty.Any, **override_kwargs: ty.Unpack[ExportKwargs]) -
union_kwargs = ChainMap(override_kwargs, self._dump_python_kwargs) # type: ignore
return self.type_adapter.dump_json(value, **union_kwargs)

def json_schema(self) -> ty.Any:
def json_schema(self) -> dict[str, ty.Any]:
"""Return the JSON schema for the field."""
by_alias = self.export_kwargs.get("by_alias", True)
return self.type_adapter.json_schema(by_alias=by_alias)
Expand Down
33 changes: 30 additions & 3 deletions tests/v2/test_rest_framework.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import io
import typing as t
from datetime import date
from types import SimpleNamespace

import pytest
from rest_framework import exceptions, generics, serializers, views
from django.urls import path
from rest_framework import exceptions, generics, schemas, serializers, views
from rest_framework.decorators import api_view, parser_classes, renderer_classes, schema
from rest_framework.request import Request
from rest_framework.response import Response

from tests.conftest import InnerSchema
Expand Down Expand Up @@ -136,7 +139,7 @@ def test_schema_parser():
assert parser.parse(io.StringIO(existing_encoded)) == expected_instance


@api_view(["POST"])
@api_view(["GET", "POST"])
@schema(coreapi.AutoSchema())
@parser_classes([rest_framework.SchemaParser[InnerSchema]])
@renderer_classes([rest_framework.SchemaRenderer[t.List[InnerSchema]]])
Expand All @@ -145,7 +148,7 @@ def sample_view(request):
return Response([request.data])


class ClassBasedViewWithSerializer(generics.RetrieveAPIView):
class ClassBasedViewWithSerializer(generics.RetrieveUpdateAPIView):
serializer_class = SampleSerializer
schema = coreapi.AutoSchema()

Expand Down Expand Up @@ -216,3 +219,27 @@ def test_end_to_end_list_create_api_view(request_factory):
request = request_factory.get("/", content_type="application/json")
response = ClassBasedViewWithModel.as_view()(request)
assert response.data == [expected_result]


urlconf = SimpleNamespace(
urlpatterns=[
path("/func", sample_view),
path("/class", ClassBasedViewWithSerializer.as_view()),
],
)


@pytest.mark.parametrize(
"method, path",
[
("GET", "/func"),
("POST", "/func"),
("GET", "/class"),
("PUT", "/class"),
],
)
def test_coreapi_schema_generators(request_factory, method, path):
generator = schemas.SchemaGenerator(urlconf=urlconf)
request = Request(request_factory.generic(method, path))
coreapi_schema = generator.get_schema(request)
assert coreapi_schema

0 comments on commit 8f4dcc0

Please sign in to comment.