Skip to content

Commit

Permalink
fix mypy issues
Browse files Browse the repository at this point in the history
  • Loading branch information
bellini666 committed Dec 21, 2024
1 parent f5d77c8 commit 79db7ef
Show file tree
Hide file tree
Showing 38 changed files with 210 additions and 1,519 deletions.
6 changes: 4 additions & 2 deletions strawberry/experimental/pydantic/error_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
if TYPE_CHECKING:
from collections.abc import Sequence

from strawberry.types.base import WithStrawberryObjectDefinition


def get_type_for_field(field: CompatModelField) -> Union[type[Union[None, list]], Any]:
type_ = field.outer_type_
Expand Down Expand Up @@ -113,7 +115,7 @@ def wrap(cls: type) -> type:
if name in fields_set
]

wrapped = _wrap_dataclass(cls)
wrapped: type[WithStrawberryObjectDefinition] = _wrap_dataclass(cls)
extra_fields = cast(list[dataclasses.Field], _get_fields(wrapped, {}))
private_fields = get_private_fields(wrapped)

Expand Down Expand Up @@ -146,7 +148,7 @@ def wrap(cls: type) -> type:
)

model._strawberry_type = cls # type: ignore[attr-defined]
cls._pydantic_type = model
cls._pydantic_type = model # type: ignore[attr-defined]
return cls

return wrap
9 changes: 1 addition & 8 deletions strawberry/experimental/pydantic/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,7 @@
else:
raise

try:
from typing import GenericAlias as TypingGenericAlias # type: ignore
except ImportError:
import sys

# python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
# we do this under a conditional to avoid a mypy :)
raise
from typing import GenericAlias as TypingGenericAlias # type: ignore


def replace_pydantic_types(type_: Any, is_input: bool) -> Any:
Expand Down
10 changes: 5 additions & 5 deletions strawberry/ext/mypy_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def __str__(self) -> str:
return self.message


def lazy_type_analyze_callback(ctx: AnalyzeTypeContext) -> type:
def lazy_type_analyze_callback(ctx: AnalyzeTypeContext) -> Type:
if len(ctx.type.args) == 0:
# TODO: maybe this should throw an error

Expand All @@ -123,7 +123,7 @@ def _get_named_type(name: str, api: SemanticAnalyzerPluginInterface) -> Any:
return api.named_type(name)


def _get_type_for_expr(expr: Expression, api: SemanticAnalyzerPluginInterface) -> type:
def _get_type_for_expr(expr: Expression, api: SemanticAnalyzerPluginInterface) -> Type:
if isinstance(expr, NameExpr):
# guarding against invalid nodes, still have to figure out why this happens
# but sometimes mypy crashes because the internal node of the named type
Expand Down Expand Up @@ -247,7 +247,7 @@ def enum_hook(ctx: DynamicClassDefContext) -> None:
)
return

enum_type: Optional[type]
enum_type: Optional[Type]

try:
enum_type = _get_type_for_expr(first_argument, ctx.api)
Expand Down Expand Up @@ -295,7 +295,7 @@ def scalar_hook(ctx: DynamicClassDefContext) -> None:
)
return

scalar_type: Optional[type]
scalar_type: Optional[Type]

# TODO: add proper support for NewType

Expand Down Expand Up @@ -620,7 +620,7 @@ def _is_strawberry_pydantic_decorator(self, fullname: str) -> bool:
)


def plugin(version: str) -> typing.type[StrawberryPlugin]:
def plugin(version: str) -> typing.Type[StrawberryPlugin]:
match = VERSION_RE.match(version)
if match:
MypyVersion.VERSION = Decimal(".".join(match.groups()))
Expand Down
3 changes: 2 additions & 1 deletion strawberry/federation/object_type.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import builtins
from collections.abc import Iterable, Sequence
from typing import (
TYPE_CHECKING,
Expand All @@ -20,7 +21,7 @@
from .schema_directives import Key


T = TypeVar("T", bound=type)
T = TypeVar("T", bound=builtins.type)


def _impl_type(
Expand Down
4 changes: 2 additions & 2 deletions strawberry/federation/schema_directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ def schema_directive(
print_definition: bool = True,
compose: bool = False,
import_url: Optional[str] = None,
) -> Callable[..., T]:
) -> Callable[[T], T]:
def _wrap(cls: T) -> T:
cls = _wrap_dataclass(cls) # type: ignore
fields = _get_fields(cls, {})

cls.__strawberry_directive__ = StrawberryFederationSchemaDirective(
cls.__strawberry_directive__ = StrawberryFederationSchemaDirective( # type: ignore[attr-defined]
python_name=cls.__name__,
graphql_name=name,
locations=locations,
Expand Down
61 changes: 43 additions & 18 deletions strawberry/printer/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
overload,
)

from graphql import is_union_type
from graphql import GraphQLObjectType, GraphQLSchema, is_union_type
from graphql.language.printer import print_ast
from graphql.type import (
is_enum_type,
Expand All @@ -33,7 +33,11 @@
from graphql.utilities.print_schema import print_type as original_print_type

from strawberry.schema_directive import Location, StrawberrySchemaDirective
from strawberry.types.base import StrawberryContainer, has_object_definition
from strawberry.types.base import (
StrawberryContainer,
StrawberryObjectDefinition,
has_object_definition,
)
from strawberry.types.enum import EnumDefinition
from strawberry.types.scalar import ScalarWrapper
from strawberry.types.unset import UNSET
Expand Down Expand Up @@ -220,7 +224,12 @@ def print_args(
)


def print_fields(type_: type, schema: BaseSchema, *, extras: PrintExtras) -> str:
def print_fields(
type_: GraphQLObjectType,
schema: BaseSchema,
*,
extras: PrintExtras,
) -> str:
from strawberry.schema.schema_converter import GraphQLCoreConverter

fields = []
Expand Down Expand Up @@ -315,11 +324,13 @@ def print_enum(
)


def print_extends(type_: type, schema: BaseSchema) -> str:
def print_extends(type_: GraphQLObjectType, schema: BaseSchema) -> str:
from strawberry.schema.schema_converter import GraphQLCoreConverter

strawberry_type = type_.extensions and type_.extensions.get(
GraphQLCoreConverter.DEFINITION_BACKREF
strawberry_type = cast(
Optional[StrawberryObjectDefinition],
type_.extensions
and type_.extensions.get(GraphQLCoreConverter.DEFINITION_BACKREF),
)

if strawberry_type and strawberry_type.extend:
Expand All @@ -329,12 +340,14 @@ def print_extends(type_: type, schema: BaseSchema) -> str:


def print_type_directives(
type_: type, schema: BaseSchema, *, extras: PrintExtras
type_: GraphQLObjectType, schema: BaseSchema, *, extras: PrintExtras
) -> str:
from strawberry.schema.schema_converter import GraphQLCoreConverter

strawberry_type = type_.extensions and type_.extensions.get(
GraphQLCoreConverter.DEFINITION_BACKREF
strawberry_type = cast(
Optional[StrawberryObjectDefinition],
type_.extensions
and type_.extensions.get(GraphQLCoreConverter.DEFINITION_BACKREF),
)

if not strawberry_type:
Expand All @@ -349,7 +362,7 @@ def print_type_directives(
for directive in strawberry_type.directives or []
if any(
location in allowed_locations
for location in directive.__strawberry_directive__.locations
for location in directive.__strawberry_directive__.locations # type: ignore[attr-defined]
)
)

Expand Down Expand Up @@ -545,21 +558,33 @@ def is_builtin_directive(directive: GraphQLDirective) -> bool:


def print_schema(schema: BaseSchema) -> str:
graphql_core_schema = schema._schema # type: ignore
graphql_core_schema = cast(
GraphQLSchema,
schema._schema, # type: ignore
)
extras = PrintExtras()

directives = filter(
lambda n: not is_builtin_directive(n), graphql_core_schema.directives
)
filtered_directives = [
directive
for directive in graphql_core_schema.directives
if not is_builtin_directive(directive)
]

type_map = graphql_core_schema.type_map
types = filter(is_defined_type, map(type_map.get, sorted(type_map)))
types = [
type_
for type_name in sorted(type_map)
if is_defined_type(type_ := type_map[type_name])
]

types_printed = [_print_type(type_, schema, extras=extras) for type_ in types]
schema_definition = print_schema_definition(schema, extras=extras)

directives = filter(
None, [print_directive(directive, schema=schema) for directive in directives]
)
directives = [
printed_directive
for directive in filtered_directives
if (printed_directive := print_directive(directive, schema=schema)) is not None
]

def _name_getter(type_: Any) -> str:
if hasattr(type_, "name"):
Expand Down
11 changes: 7 additions & 4 deletions strawberry/schema/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@
ExecutionContext,
ExecutionResult,
)
from strawberry.types.base import StrawberryObjectDefinition
from strawberry.types.base import (
StrawberryObjectDefinition,
WithStrawberryObjectDefinition,
)
from strawberry.types.enum import EnumDefinition
from strawberry.types.graphql import OperationType
from strawberry.types.scalar import ScalarDefinition
Expand All @@ -31,9 +34,9 @@
class BaseSchema(Protocol):
config: StrawberryConfig
schema_converter: GraphQLCoreConverter
query: type
mutation: Optional[type]
subscription: Optional[type]
query: type[WithStrawberryObjectDefinition]
mutation: Optional[type[WithStrawberryObjectDefinition]]
subscription: Optional[type[WithStrawberryObjectDefinition]]
schema_directives: list[object]

@abstractmethod
Expand Down
22 changes: 18 additions & 4 deletions strawberry/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,11 @@
from strawberry.schema.schema_converter import GraphQLCoreConverter
from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY
from strawberry.types import ExecutionContext
from strawberry.types.base import StrawberryObjectDefinition, has_object_definition
from strawberry.types.base import (
StrawberryObjectDefinition,
WithStrawberryObjectDefinition,
has_object_definition,
)
from strawberry.types.graphql import OperationType

from ..printer import print_schema
Expand Down Expand Up @@ -140,14 +144,24 @@ class Query:
self.directives = directives
self.schema_directives = list(schema_directives)

query_type = self.schema_converter.from_object(query.__strawberry_definition__)
query_type = self.schema_converter.from_object(
cast(type[WithStrawberryObjectDefinition], query).__strawberry_definition__
)
mutation_type = (
self.schema_converter.from_object(mutation.__strawberry_definition__)
self.schema_converter.from_object(
cast(
type[WithStrawberryObjectDefinition], mutation
).__strawberry_definition__
)
if mutation
else None
)
subscription_type = (
self.schema_converter.from_object(subscription.__strawberry_definition__)
self.schema_converter.from_object(
cast(
type[WithStrawberryObjectDefinition], subscription
).__strawberry_definition__
)
if subscription
else None
)
Expand Down
5 changes: 3 additions & 2 deletions strawberry/schema/schema_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,8 @@ def from_directive(self, directive: StrawberryDirective) -> GraphQLDirective:

def from_schema_directive(self, cls: type) -> GraphQLDirective:
strawberry_directive = cast(
"StrawberrySchemaDirective", cls.__strawberry_directive__
"StrawberrySchemaDirective",
cls.__strawberry_directive__, # type: ignore[attr-defined]
)
module = sys.modules[cls.__module__]

Expand Down Expand Up @@ -770,7 +771,7 @@ def from_scalar(self, scalar: type) -> GraphQLScalarType:
else:
scalar_definition = _scalar_definition
else:
scalar_definition = scalar._scalar_definition
scalar_definition = scalar._scalar_definition # type: ignore[attr-defined]

scalar_name = self.config.name_converter.from_type(scalar_definition)

Expand Down
2 changes: 1 addition & 1 deletion strawberry/schema/types/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def _make_scalar_definition(scalar_type: GraphQLScalarType) -> ScalarDefinition:


def _get_scalar_definition(scalar: type) -> ScalarDefinition:
return scalar._scalar_definition
return scalar._scalar_definition # type: ignore[attr-defined]


DEFAULT_SCALAR_REGISTRY: dict[object, ScalarDefinition] = {
Expand Down
4 changes: 2 additions & 2 deletions strawberry/schema_directive.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def schema_directive(
name: Optional[str] = None,
repeatable: bool = False,
print_definition: bool = True,
) -> Callable[..., T]:
) -> Callable[[T], T]:
def _wrap(cls: T) -> T:
cls = _wrap_dataclass(cls) # type: ignore
fields = _get_fields(cls, {})

cls.__strawberry_directive__ = StrawberrySchemaDirective(
cls.__strawberry_directive__ = StrawberrySchemaDirective( # type: ignore[attr-defined]
python_name=cls.__name__,
graphql_name=name,
locations=locations,
Expand Down
Loading

0 comments on commit 79db7ef

Please sign in to comment.