Skip to content

Commit

Permalink
Small refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Tinche committed Dec 18, 2023
1 parent b8a219e commit e3e0c8b
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 82 deletions.
23 changes: 11 additions & 12 deletions src/uapi/_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
from cattrs._compat import is_union_type
from incant import is_subclass

from .attrschema import build_attrs_schema
from .openapi import (
AnySchema,
ApiKeySecurityScheme,
ArraySchema,
MediaType,
MediaTypeName,
OneOfSchema,
Expand Down Expand Up @@ -137,14 +136,14 @@ def build_operation(
req_type, loader = type_and_loader
if has(req_type):
request_bodies[loader.content_type or "*/*"] = MediaType(
builder.reference_for_type(req_type)
builder.get_schema_for_type(req_type)
)
else:
# It's a dict.
v_type = req_type.__args__[1] # type: ignore[attr-defined]

add_prop: Reference | Schema = (
builder.reference_for_type(v_type)
builder.get_schema_for_type(v_type)
if has(v_type)
else builder.PYTHON_PRIMITIVES_TO_OPENAPI[v_type]
)
Expand All @@ -159,7 +158,7 @@ def build_operation(
):
# A body form.
request_bodies["application/x-www-form-urlencoded"] = MediaType(
builder.reference_for_type(form_type)
builder.get_schema_for_type(form_type)
)
else:
if is_union_type(arg_type):
Expand All @@ -168,7 +167,7 @@ def build_operation(
if union_member is NoneType:
refs.append(Schema(Schema.Type.NULL))
elif union_member in builder.PYTHON_PRIMITIVES_TO_OPENAPI:
refs.append(builder.PYTHON_PRIMITIVES_TO_OPENAPI[union_member])
refs.append(builder.get_schema_for_type(union_member))
param_schema: OneOfSchema | Schema = OneOfSchema(refs)
else:
param_schema = builder.PYTHON_PRIMITIVES_TO_OPENAPI.get(
Expand Down Expand Up @@ -231,7 +230,7 @@ def build_operation(

def _coalesce_responses(rs: Sequence[Response]) -> Response:
first_resp = rs[0]
content: dict[MediaTypeName, list[Schema | ArraySchema | Reference]] = {}
content: dict[MediaTypeName, list[AnySchema | Reference]] = {}
for r in rs:
for mtn, mt in r.content.items():
if isinstance(mt.schema, OneOfSchema):
Expand Down Expand Up @@ -389,16 +388,16 @@ def gather_endpoint_components(handler: Callable, builder: SchemaBuilder) -> Non
and (arg_type := type_and_loader[0]) not in builder.names
):
if has(arg_type):
builder.build_schema_with(arg_type, build_attrs_schema)
builder.get_schema_for_type(arg_type)
else:
# It's a dict.
val_arg = arg_type.__args__[1] # type: ignore[attr-defined]
if has(val_arg):
builder.build_schema_with(val_arg, build_attrs_schema)
builder.get_schema_for_type(val_arg)
elif arg.annotation is not InspectParameter.empty and (
form_type := maybe_form_type(arg)
):
builder.build_schema_with(form_type, build_attrs_schema)
builder.get_schema_for_type(form_type)


def components_to_openapi(
Expand All @@ -415,7 +414,7 @@ def components_to_openapi(
gather_endpoint_components(handler, builder)

for component in builder._build_queue:
builder.build_schema_with(component, build_attrs_schema)
builder.get_schema_for_type(component)

return OpenAPI.Components(builder.components, security_schemes)

Expand Down Expand Up @@ -456,7 +455,7 @@ def make_openapi_spec(
)
while schema_builder._build_queue:
for component in list(schema_builder._build_queue):
schema_builder.build_schema_with(component, build_attrs_schema)
schema_builder.build_schema_from_rules(component)
return res


Expand Down
12 changes: 5 additions & 7 deletions src/uapi/attrschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,23 +41,21 @@ def build_attrs_schema(type: Any, builder: SchemaBuilder) -> Schema:
if a_type in builder.PYTHON_PRIMITIVES_TO_OPENAPI:
schema: AnySchema | Reference = builder.PYTHON_PRIMITIVES_TO_OPENAPI[a_type]
elif has(a_type):
schema = builder.reference_for_type(a_type)
schema = builder.get_schema_for_type(a_type)
elif getattr(a_type, "__origin__", None) is list:
arg = a_type.__args__[0]
if arg in mapping:
arg = mapping[arg]
if has(arg):
ref = builder.reference_for_type(arg)
ref = builder.get_schema_for_type(arg)
schema = ArraySchema(ref)
elif arg in builder.PYTHON_PRIMITIVES_TO_OPENAPI:
schema = ArraySchema(
Schema(builder.PYTHON_PRIMITIVES_TO_OPENAPI[arg].type)
)
schema = ArraySchema(builder.PYTHON_PRIMITIVES_TO_OPENAPI[arg])
elif getattr(a_type, "__origin__", None) is dict:
val_arg = a_type.__args__[1]

if has(val_arg):
add_prop: Reference | Schema = builder.reference_for_type(val_arg)
add_prop: Reference | Schema = builder.get_schema_for_type(val_arg)
else:
add_prop = builder.PYTHON_PRIMITIVES_TO_OPENAPI[val_arg]

Expand All @@ -68,7 +66,7 @@ def build_attrs_schema(type: Any, builder: SchemaBuilder) -> Schema:
refs: list[Reference | Schema] = []
for arg in a_type.__args__:
if has(arg):
refs.append(builder.reference_for_type(arg))
refs.append(builder.get_schema_for_type(arg))
elif arg is NoneType:
refs.append(Schema(Schema.Type.NULL))
elif arg in builder.PYTHON_PRIMITIVES_TO_OPENAPI:
Expand Down
118 changes: 92 additions & 26 deletions src/uapi/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
from enum import Enum, unique
from typing import Any, ClassVar, Literal, TypeAlias

from attrs import Factory, define, field, frozen
from attrs import Factory, define, field, frozen, has
from cattrs import override
from cattrs._compat import is_generic
from cattrs._compat import get_args, is_generic, is_sequence
from cattrs.gen import make_dict_structure_fn, make_dict_unstructure_fn
from cattrs.preconf.json import make_converter

converter = make_converter(omit_if_default=True)

# MediaTypeNames are like `application/json`.
MediaTypeName = str
MediaTypeName: TypeAlias = str
# HTTP status codes
StatusCodeType: TypeAlias = str

Expand Down Expand Up @@ -45,20 +45,35 @@ class Type(Enum):
required: list[str] = Factory(list)


@frozen
class IntegerSchema:
format: Literal[None, "int32", "int64"] = None
minimum: int | None = None
maximum: int | None = None
exclusiveMinimum: bool = False
exclusiveMaximum: bool = False
multipleOf: int | None = None
enum: list[int] = Factory(list)
type: Literal[Schema.Type.INTEGER] = Schema.Type.INTEGER


@frozen
class ArraySchema:
items: Schema | Reference
items: Schema | IntegerSchema | Reference
type: Literal[Schema.Type.ARRAY] = Schema.Type.ARRAY


@frozen
class OneOfSchema:
oneOf: Sequence[Reference | Schema | ArraySchema]
oneOf: Sequence[AnySchema | Reference]


AnySchema = Schema | IntegerSchema | ArraySchema | OneOfSchema


@frozen
class MediaType:
schema: Schema | OneOfSchema | ArraySchema | Reference
schema: AnySchema | Reference


@frozen
Expand All @@ -79,10 +94,7 @@ class Kind(str, Enum):
name: str
kind: Kind
required: bool = False
schema: Schema | Reference | OneOfSchema | None = None


AnySchema = Schema | ArraySchema | OneOfSchema
schema: AnySchema | Reference | None = None


@frozen
Expand Down Expand Up @@ -144,13 +156,17 @@ class Path:
components: Components


Predicate: TypeAlias = Callable[[Any], bool]
BuildHook: TypeAlias = Callable[[Any, "SchemaBuilder"], AnySchema]


@define
class SchemaBuilder:
"""A helper builder for defining OpenAPI/JSON schemas."""

PYTHON_PRIMITIVES_TO_OPENAPI: ClassVar = {
str: Schema(Schema.Type.STRING),
int: Schema(Schema.Type.INTEGER),
int: IntegerSchema(),
bool: Schema(Schema.Type.BOOLEAN),
float: Schema(Schema.Type.NUMBER, format="double"),
bytes: Schema(Schema.Type.STRING, format="binary"),
Expand All @@ -160,18 +176,37 @@ class SchemaBuilder:

names: dict[type, str] = Factory(dict)
components: dict[str, AnySchema | Reference] = Factory(dict)
build_rules: list[tuple[Predicate, BuildHook]] = Factory(
lambda self: self.default_build_rules(), takes_self=True
)
_build_queue: list[type] = field(factory=list, init=False)

def build_schema_with(
self, type: Any, hook: Callable[[Any, SchemaBuilder], Schema]
) -> Schema:
def build_schema_with(self, type: Any, hook: BuildHook) -> AnySchema:
"""Build the schema for `type` using the provided hook, bypassing rules."""
name = self._name_for(type)
self.components[name] = (r := hook(type, self))
if type in self._build_queue:
self._build_queue.remove(type)
return r

def reference_for_type(self, type: Any) -> Reference | Schema:
def build_schema_from_rules(self, type: Any) -> AnySchema:
for pred, hook in self.build_rules: # noqa: B007
if pred(type):
break
else:
raise Exception(f"Can't handle {type}")

name = self._name_for(type)
self.components[name] = (r := hook(type, self))
if type in self._build_queue:
self._build_queue.remove(type)
return r

def get_schema_for_type(self, type: Any) -> Reference | Schema:
# First check inline types.
if type in self.PYTHON_PRIMITIVES_TO_OPENAPI:
return self.PYTHON_PRIMITIVES_TO_OPENAPI[type]

name = self._name_for(type)
if name not in self.components and type not in self._build_queue:
self._build_queue.append(type)
Expand All @@ -187,6 +222,24 @@ def _name_for(self, type: Any) -> str:
self.names[type] = name
return self.names[type]

@classmethod
def default_build_rules(cls) -> list[tuple[Predicate, BuildHook]]:
"""Set up the default build rules."""
from .attrschema import build_attrs_schema

def build_sequence_schema(type: Any, builder: SchemaBuilder) -> AnySchema:
arg = get_args(type)[0]
return ArraySchema(builder.get_schema_for_type(arg))

return [
(
cls.PYTHON_PRIMITIVES_TO_OPENAPI.__contains__,
lambda t, _: cls.PYTHON_PRIMITIVES_TO_OPENAPI[t],
),
(is_sequence, build_sequence_schema),
(has, build_attrs_schema),
]


def _make_generic_name(type: type) -> str:
"""Used for generic attrs classes (Generic[int] instead of just Generic)."""
Expand All @@ -202,22 +255,23 @@ def _structure_schemas(val, _):
type = Schema.Type(val["type"])
if type is Schema.Type.ARRAY:
return converter.structure(val, ArraySchema)
if type is Schema.Type.INTEGER:
return converter.structure(val, IntegerSchema)
return converter.structure(val, Schema)


def _structure_inlinetype_ref(val, _):
return converter.structure(val, Schema if "type" in val else Reference)
def _structure_schema_or_ref(val, _) -> Schema | IntegerSchema | Reference:
if "$ref" in val:
return converter.structure(val, Reference)
type = Schema.Type(val["type"])
if type is Schema.Type.INTEGER:
return converter.structure(val, IntegerSchema)
return converter.structure(val, Schema)


converter.register_structure_hook(AnySchema | Reference, _structure_schemas)
converter.register_structure_hook(
Schema | OneOfSchema | ArraySchema | Reference, _structure_schemas
)
converter.register_structure_hook(Schema | Reference, _structure_inlinetype_ref)
converter.register_structure_hook(
Parameter, make_dict_structure_fn(Parameter, converter, kind=override(rename="in"))
)
converter.register_structure_hook(
Reference, make_dict_structure_fn(Reference, converter, ref=override(rename="$ref"))
Schema | IntegerSchema | Reference, _structure_schema_or_ref
)
converter.register_structure_hook(
Schema | ArraySchema | Reference,
Expand All @@ -239,6 +293,13 @@ def _structure_inlinetype_ref(val, _):
else converter.structure(v, Schema)
),
)
converter.register_structure_hook(
Parameter, make_dict_structure_fn(Parameter, converter, kind=override(rename="in"))
)
converter.register_structure_hook(
Reference, make_dict_structure_fn(Reference, converter, ref=override(rename="$ref"))
)

converter.register_unstructure_hook(
ApiKeySecurityScheme,
make_dict_unstructure_fn(
Expand All @@ -248,7 +309,6 @@ def _structure_inlinetype_ref(val, _):
type=override(omit_if_default=False),
),
)

converter.register_unstructure_hook(
Reference,
make_dict_unstructure_fn(Reference, converter, ref=override(rename="$ref")),
Expand All @@ -265,3 +325,9 @@ def _structure_inlinetype_ref(val, _):
ArraySchema, converter, type=override(omit_if_default=False)
),
)
converter.register_unstructure_hook(
IntegerSchema,
make_dict_unstructure_fn(
IntegerSchema, converter, type=override(omit_if_default=False)
),
)
4 changes: 1 addition & 3 deletions src/uapi/shorthands.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from incant import is_subclass
from orjson import dumps

from .attrschema import build_attrs_schema
from .openapi import MediaType, Response, SchemaBuilder
from .status import BaseResponse, NoContent, Ok

Expand Down Expand Up @@ -160,9 +159,8 @@ def is_union_member(value: Any) -> bool:

@staticmethod
def make_openapi_response(type: Any, builder: SchemaBuilder) -> Response | None:
builder.build_schema_with(type, build_attrs_schema)
return Response(
"OK", {"application/json": MediaType(builder.reference_for_type(type))}
"OK", {"application/json": MediaType(builder.get_schema_for_type(type))}
)

@staticmethod
Expand Down
Loading

0 comments on commit e3e0c8b

Please sign in to comment.