Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Run strawberry codegen on top of graphql schema. #3221

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: minor

Allow strawberry's code generator to build clients from graphql schema files.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is an awesome feature, we should extend this release note :)

11 changes: 9 additions & 2 deletions strawberry/cli/commands/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
import functools
import importlib
import inspect
from pathlib import Path # noqa: TCH003
from pathlib import Path
from typing import List, Optional, Type

import rich
import typer
from graphql.utilities import build_schema

from strawberry.cli.app import app
from strawberry.cli.utils import load_schema
from strawberry.codegen import ConsolePlugin, QueryCodegen, QueryCodegenPlugin
from strawberry.codegen.schema_adapter import GraphQLSchemaWrapper, SchemaLike


def _is_codegen_plugin(obj: object) -> bool:
Expand Down Expand Up @@ -123,7 +125,12 @@
if not query:
return

schema_symbol = load_schema(schema, app_dir)
schema_symbol: SchemaLike
if schema.endswith(".graphql"):
with Path(schema).open() as input_schema:
schema_symbol = GraphQLSchemaWrapper(build_schema(input_schema.read()))

Check warning on line 131 in strawberry/cli/commands/codegen.py

View check run for this annotation

Codecov / codecov/patch

strawberry/cli/commands/codegen.py#L131

Added line #L131 was not covered by tests
Comment on lines +129 to +131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should use a different flag for this :)

Also I prefer to remove one nesting level using read_text

Suggested change
if schema.endswith(".graphql"):
with Path(schema).open() as input_schema:
schema_symbol = GraphQLSchemaWrapper(build_schema(input_schema.read()))
if schema.endswith(".graphql"):
schema = Path(schema).read_text()
schema_symbol = GraphQLSchemaWrapper(build_schema(schema))

else:
schema_symbol = load_schema(schema, app_dir)

console_plugin_type = _load_plugin(cli_plugin) if cli_plugin else ConsolePlugin
console_plugin = console_plugin_type(output_dir)
Expand Down
23 changes: 10 additions & 13 deletions strawberry/codegen/query_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@
VariableDefinitionNode,
)

from strawberry.schema import Schema

from .schema_adapter import SchemaLike
from .types import GraphQLArgumentValue, GraphQLSelection, GraphQLType


Expand Down Expand Up @@ -302,7 +301,7 @@ def on_end(self, result: CodegenResult) -> None:
class QueryCodegen:
def __init__(
self,
schema: Schema,
schema: SchemaLike,
plugins: List[QueryCodegenPlugin],
console_plugin: Optional[ConsolePlugin] = None,
):
Expand Down Expand Up @@ -545,7 +544,7 @@ def _get_field_type(
not isinstance(field_type, StrawberryType)
and field_type in self.schema.schema_converter.scalar_registry
):
field_type = self.schema.schema_converter.scalar_registry[field_type] # type: ignore
field_type = self.schema.schema_converter.scalar_registry[field_type]

if isinstance(field_type, ScalarWrapper):
python_type = field_type.wrap
Expand Down Expand Up @@ -637,9 +636,10 @@ def _field_from_selection(
assert field, f"{parent_type.name},{selection.name.value}"

field_type = self._get_field_type(field.type)

return GraphQLField(
field.name, selection.alias.value if selection.alias else None, field_type
field.name,
selection.alias.value if selection.alias else None,
field_type,
)

def _unwrap_type(
Expand All @@ -660,8 +660,8 @@ def _unwrap_type(
elif isinstance(type_, StrawberryList):
type_, wrapper = self._unwrap_type(type_.of_type)
wrapper = (
GraphQLList if wrapper is None else lambda t: GraphQLList(wrapper(t)) # type: ignore[misc]
)
GraphQLList if wrapper is None else lambda t: GraphQLList(wrapper(t))
) # type: ignore[misc]

elif isinstance(type_, LazyType):
return self._unwrap_type(type_.resolve_type())
Expand All @@ -686,12 +686,9 @@ def _field_from_selection_set(
# but insertion order is maintained in python3.6+ (for CPython) and
# guaranteed for all python implementations in python3.7+, so that
# should be pretty safe.
if parent_type.type_var_map:
if getattr(parent_type, "type_var_map", None):
parent_type_name = (
"".join(
c.__name__ # type: ignore[union-attr]
for c in parent_type.type_var_map.values()
)
"".join(c.__name__ for c in parent_type.type_var_map.values()) # type: ignore[union-attr]
+ parent_type.name
)

Expand Down
251 changes: 251 additions & 0 deletions strawberry/codegen/schema_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
from __future__ import annotations

import functools
from enum import Enum
from typing import Any, Dict, Hashable, Optional, Tuple
from typing_extensions import Protocol

from graphql.type import (
GraphQLEnumType,
GraphQLField,
GraphQLInputObjectType,
GraphQLInterfaceType,
GraphQLList,
GraphQLNonNull,
GraphQLObjectType,
GraphQLOutputType,
GraphQLScalarType,
GraphQLSchema,
GraphQLType,
GraphQLUnionType,
GraphQLWrappingType,
)

from strawberry.custom_scalar import ScalarDefinition
from strawberry.enum import EnumDefinition, EnumValue
from strawberry.field import StrawberryField
from strawberry.type import StrawberryList, StrawberryType
from strawberry.types.types import StrawberryObjectDefinition
from strawberry.union import StrawberryUnion


class _ScalarRegistry:
"""A simple type registry for the GraphQLScalars that we encounter."""

def __init__(self) -> None:
self._cache: Dict[Any, Tuple[bool, Optional[ScalarDefinition]]] = {}

def _check_populate_cache(
self, obj: Hashable
) -> Tuple[bool, Optional[ScalarDefinition]]:
if obj in self._cache:
return self._cache[obj]

Check warning on line 42 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L42

Added line #L42 was not covered by tests

is_scalar = False
if isinstance(obj, GraphQLNonNull) and isinstance(
obj.of_type, GraphQLScalarType
):
is_scalar = True

Check warning on line 48 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L48

Added line #L48 was not covered by tests
elif isinstance(obj, GraphQLScalarType):
is_scalar = True
scalar_def = ScalarDefinition(
name=obj.name,
description=obj.description,
specified_by_url=obj.specified_by_url,
serialize=obj.serialize,
parse_value=obj.parse_value,
parse_literal=obj.parse_literal,
)

else:
scalar_def = None

Check warning on line 61 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L61

Added line #L61 was not covered by tests
if not is_scalar:
self._cache[obj] = (False, None)

Check warning on line 63 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L63

Added line #L63 was not covered by tests
self._cache[obj] = (is_scalar, scalar_def)
return self._cache[obj]

def __contains__(self, obj: Hashable) -> bool:
return self._check_populate_cache(obj)[0]

Check warning on line 68 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L68

Added line #L68 was not covered by tests

def __getitem__(self, obj: Hashable) -> ScalarDefinition:
_, result = self._check_populate_cache(obj)
if result is None:
raise KeyError(obj)

Check warning on line 73 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L73

Added line #L73 was not covered by tests
return result


class DeferredTypeStrawberryField(StrawberryField):
"""A basic strawberry field subclass for deferred resolution of the type property."""

def __init__(
self,
graphql_field_type: GraphQLOutputType,
schema_wrapper: GraphQLSchemaWrapper,
**kwargs: Any,
):
self.graphql_field_type = graphql_field_type
self.schema_wrapper = schema_wrapper
super().__init__(**kwargs)

@property
def type(self) -> Any:
inner_type = self.graphql_field_type
while isinstance(inner_type, GraphQLWrappingType):
inner_type = inner_type.of_type

name = getattr(inner_type, "name", None)
if name is not None:
field_type = self.schema_wrapper.get_type_by_name(name)
else:
raise ValueError(f"Unable to find type for {self.graphql_field_type}")

Check warning on line 100 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L100

Added line #L100 was not covered by tests
return field_type

@type.setter
def type(self, val: Any) -> None:
...


class GraphQLSchemaWrapper:
def __init__(self, schema: GraphQLSchema) -> None:
self.schema = schema
self.scalar_registry = _ScalarRegistry()
self._types_by_name: dict[str, Optional[StrawberryType]] = {}

def get_field_for_type(
field_name: str, type_name: str
) -> Optional[StrawberryField]:
return self._get_field_for_type(field_name, type_name)

self.get_field_for_type = functools.lru_cache(maxsize=None)(get_field_for_type)

def get_type_by_name(self, name: str) -> Optional[StrawberryType]:
if name not in self._types_by_name:
self._types_by_name[name] = self._get_type_by_name(name)

return self._types_by_name[name]

def _get_type_by_name(self, name: str) -> Optional[StrawberryType]:
schema_type = self.schema.get_type(name)
if schema_type is None:
return None

Check warning on line 130 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L130

Added line #L130 was not covered by tests
return self._strawberry_type_from_graphql_type(schema_type)

def _strawberry_type_from_graphql_type(
self, graphql_type: GraphQLType
) -> StrawberryType:
if isinstance(graphql_type, GraphQLNonNull):
graphql_type = graphql_type.of_type

Check warning on line 137 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L137

Added line #L137 was not covered by tests
if isinstance(graphql_type, GraphQLEnumType):
wrapped_cls = Enum("name", list(graphql_type.values)) # type: ignore[misc]
return EnumDefinition(
wrapped_cls=wrapped_cls,
name=graphql_type.name,
values=[
EnumValue(name=name, value=i)
for i, name in enumerate(graphql_type.values)
],
description=None,
)
if isinstance(
graphql_type,
(GraphQLObjectType, GraphQLInputObjectType, GraphQLInterfaceType),
):
obj_def = StrawberryObjectDefinition(
name=graphql_type.name,
is_input=False,
is_interface=False,
interfaces=[],
description=graphql_type.description,
origin=type(graphql_type.name, (), {}),
extend=False,
directives=[],
is_type_of=None,
resolve_type=None,
fields=[],
)
for graphql_field in graphql_type.fields.values():
obj_def.fields.append(
self._strawberry_field_from_graphql_field(graphql_field)
)
# This is just monkey-patching the strawberry-definition with itself.
obj_def.__strawberry_definition__ = obj_def # type:ignore[attr-defined]
return obj_def
if isinstance(graphql_type, GraphQLScalarType):
return self.scalar_registry[graphql_type]
if isinstance(graphql_type, GraphQLList):
return StrawberryList(

Check warning on line 176 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L176

Added line #L176 was not covered by tests
of_type=self._strawberry_type_from_graphql_type(graphql_type.of_type)
)
if isinstance(graphql_type, GraphQLUnionType):
types = [self.get_type_by_name(type_.name) for type_ in graphql_type.types]
return StrawberryUnion(
name=graphql_type.name, type_annotations=tuple(types)
)
raise ValueError(graphql_type)

Check warning on line 184 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L184

Added line #L184 was not covered by tests

def _strawberry_field_from_graphql_field(
self, graphql_field: GraphQLField
) -> StrawberryField:
ast_node = graphql_field.ast_node
if ast_node is None:
raise ValueError("GraphQLField must have an AST node to get it's name.")

Check warning on line 191 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L191

Added line #L191 was not covered by tests
name = ast_node.name.value
return DeferredTypeStrawberryField(
graphql_field_type=graphql_field.type,
schema_wrapper=self,
python_name=name,
graphql_name=name,
)

@property
def schema_converter(self) -> GraphQLSchemaWrapper:
return self

Check warning on line 202 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L202

Added line #L202 was not covered by tests

def _get_field_for_type(
self, field_name: str, type_name: str
) -> Optional[StrawberryField]:
type_ = self.get_type_by_name(type_name)
if type_ is None:
return None

Check warning on line 209 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L209

Added line #L209 was not covered by tests
if not isinstance(type_, StrawberryObjectDefinition):
raise TypeError(f"{type_name!r} does not correspond to an object type!")

Check warning on line 211 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L211

Added line #L211 was not covered by tests
return self.get_field(type_, field_name)

def get_field(
self, parent_type: StrawberryObjectDefinition, field_name: str
) -> Optional[StrawberryField]:
"""Get field of a given type with the given name."""
if field_name == "__typename":
field = StrawberryField(python_name=field_name, graphql_name=field_name)
field.type = self.get_type_by_name("String")
return field

Check warning on line 221 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L219-L221

Added lines #L219 - L221 were not covered by tests

return next(fld for fld in parent_type.fields if fld.name == field_name)


class Registry(Protocol):
def __contains__(self, key: Hashable) -> bool:
...

Check warning on line 228 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L228

Added line #L228 was not covered by tests

def __getitem__(self, key: Hashable) -> Any:
...

Check warning on line 231 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L231

Added line #L231 was not covered by tests


class SchemaConverterLike(Protocol):
@property
def scalar_registry(self) -> Registry:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we add some pragma no cover here? 😊

...

Check warning on line 237 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L237

Added line #L237 was not covered by tests


class SchemaLike(Protocol):
@property
def schema_converter(self) -> SchemaConverterLike:
...

Check warning on line 243 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L243

Added line #L243 was not covered by tests

def get_type_by_name(self, name: str) -> Optional[StrawberryType]:
...

Check warning on line 246 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L246

Added line #L246 was not covered by tests

def get_field_for_type(
self, field_name: str, type_name: str
) -> Optional[StrawberryField]:
...

Check warning on line 251 in strawberry/codegen/schema_adapter.py

View check run for this annotation

Codecov / codecov/patch

strawberry/codegen/schema_adapter.py#L251

Added line #L251 was not covered by tests
11 changes: 11 additions & 0 deletions tests/codegen/snapshots/from_graphql_schema/python/alias.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class OperationNameResultLazy:
# alias for something
lazy: bool

class OperationNameResult:
id: str
# alias for id
second_id: str
# alias for float
a_float: float
lazy: OperationNameResultLazy
18 changes: 18 additions & 0 deletions tests/codegen/snapshots/from_graphql_schema/python/basic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from uuid import UUID
from datetime import date, datetime, time
from decimal import Decimal

class OperationNameResultLazy:
something: bool

class OperationNameResult:
id: str
integer: int
float: float
boolean: bool
uuid: UUID
date: date
datetime: datetime
time: time
decimal: Decimal
lazy: OperationNameResultLazy
9 changes: 9 additions & 0 deletions tests/codegen/snapshots/from_graphql_schema/python/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from enum import Enum

class Color(Enum):
RED = "RED"
GREEN = "GREEN"
BLUE = "BLUE"

class OperationNameResult:
enum: Color
Loading
Loading