From eb43d5eae281302770764ec35e6b8570e918c761 Mon Sep 17 00:00:00 2001 From: Patrick Arminio Date: Wed, 22 Jan 2025 11:22:50 +0000 Subject: [PATCH] Super wip --- strawberry/http/async_base_view.py | 19 +++++++-- strawberry/schema/execute.py | 58 ++++++++++++---------------- strawberry/schema/schema.py | 44 ++++++++++++++++++--- strawberry/schema/subscribe.py | 10 ++++- tests/http/incremental/test_defer.py | 9 ++++- 5 files changed, 96 insertions(+), 44 deletions(-) diff --git a/strawberry/http/async_base_view.py b/strawberry/http/async_base_view.py index aa10b5400f..0722767d0b 100644 --- a/strawberry/http/async_base_view.py +++ b/strawberry/http/async_base_view.py @@ -358,9 +358,7 @@ async def run( except MissingQueryError as e: raise HTTPException(400, "No GraphQL query found in the request") from e - if HAS_INCREMENTAL_EXECUTION and isinstance( - result, ExperimentalIncrementalExecutionResults - ): + if isinstance(result, ExperimentalIncrementalExecutionResults): async def stream(): yield "---" @@ -524,9 +522,16 @@ async def parse_http_body( protocol=protocol, ) - def process_incremental_result( + async def process_incremental_result( self, request: Request, result: IncrementalResult ) -> GraphQLHTTPResponse: + result = await self.schema._handle_execution_result( + context=self.schema.execution_context, + result=result, + extensions_runner=self.schema.extensions_runner, + process_errors=self.schema.process_errors, + ) + if isinstance(result, IncrementalDeferResult): return { "data": result.data, @@ -568,6 +573,12 @@ async def process_result( result: Union[ExecutionResult, InitialIncrementalExecutionResult], ) -> GraphQLHTTPResponse: if not isinstance(result, InitialIncrementalExecutionResult): + result = await self.schema._handle_execution_result( + context=self.schema.execution_context, + result=result, + extensions_runner=self.schema.extensions_runner, + process_errors=self.schema.process_errors, + ) return process_result(result) return { diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index ac0eca1046..77fbd2ee6c 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -25,6 +25,26 @@ from .exceptions import InvalidOperationTypeError +try: + from graphql.execution.execute import ( + ExperimentalIncrementalExecutionResults, + InitialIncrementalExecutionResult, + ) + from graphql.execution.incremental_publisher import ( + IncrementalDeferResult, + IncrementalResult, + IncrementalStreamResult, + SubsequentIncrementalExecutionResult, + ) + +except ImportError: + from types import NoneType + + InitialIncrementalExecutionResult = NoneType + IncrementalResult = NoneType + IncrementalStreamResult = NoneType + SubsequentIncrementalExecutionResult = NoneType + if TYPE_CHECKING: from typing_extensions import NotRequired, TypeAlias, Unpack @@ -115,26 +135,6 @@ async def _parse_and_validate_async( return None -async def _handle_execution_result( - context: ExecutionContext, - result: Union[GraphQLExecutionResult, ExecutionResult], - extensions_runner: SchemaExtensionsRunner, - process_errors: ProcessErrors | None, -) -> ExecutionResult: - # TODO: deal with this later - # # Set errors on the context so that it's easier - # # to access in extensions - # if result.errors: - # context.errors = result.errors - # if process_errors: - # process_errors(result.errors, context) - # if isinstance(result, GraphQLExecutionResult): - # result = ExecutionResult(data=result.data, errors=result.errors) - # result.extensions = await extensions_runner.get_extensions_results(context) - # context.result = result # type: ignore # mypy failed to deduce correct type. - return result - - def _coerce_error(error: Union[GraphQLError, Exception]) -> GraphQLError: if isinstance(error, GraphQLError): return error @@ -157,9 +157,8 @@ async def execute( if errors := await _parse_and_validate_async( execution_context, extensions_runner ): - return await _handle_execution_result( - execution_context, errors, extensions_runner, process_errors - ) + # TODO: ... + return errors assert execution_context.graphql_document async with extensions_runner.executing(): @@ -195,16 +194,9 @@ async def execute( except (MissingQueryError, InvalidOperationTypeError): raise except Exception as exc: # noqa: BLE001 - return await _handle_execution_result( - execution_context, - PreExecutionError(data=None, errors=[_coerce_error(exc)]), - extensions_runner, - process_errors, - ) - # return results after all the operation completed. - return await _handle_execution_result( - execution_context, result, extensions_runner, None - ) + return PreExecutionError(data=None, errors=[_coerce_error(exc)]) + + return result def execute_sync( diff --git a/strawberry/schema/schema.py b/strawberry/schema/schema.py index 6ee9a1b21c..e81c68f9d4 100644 --- a/strawberry/schema/schema.py +++ b/strawberry/schema/schema.py @@ -5,13 +5,16 @@ from typing import ( TYPE_CHECKING, Any, + Callable, Optional, Union, cast, ) +from graphql import ExecutionResult as GraphQLExecutionResult from graphql import ( GraphQLBoolean, + GraphQLError, GraphQLField, GraphQLNamedType, GraphQLNonNull, @@ -37,7 +40,7 @@ from strawberry.printer import print_schema from strawberry.schema.schema_converter import GraphQLCoreConverter from strawberry.schema.types.scalar import DEFAULT_SCALAR_REGISTRY -from strawberry.types import ExecutionContext +from strawberry.types import ExecutionContext, ExecutionResult from strawberry.types.base import ( StrawberryObjectDefinition, WithStrawberryObjectDefinition, @@ -53,17 +56,21 @@ if TYPE_CHECKING: from collections.abc import Iterable + from typing_extensions import TypeAlias from graphql import ExecutionContext as GraphQLExecutionContext from strawberry.directive import StrawberryDirective - from strawberry.types import ExecutionResult from strawberry.types.base import StrawberryType from strawberry.types.enum import EnumDefinition from strawberry.types.field import StrawberryField from strawberry.types.scalar import ScalarDefinition, ScalarWrapper from strawberry.types.union import StrawberryUnion +ProcessErrors: TypeAlias = ( + "Callable[[list[GraphQLError], Optional[ExecutionContext]], None]" +) + DEFAULT_ALLOWED_OPERATION_TYPES = { OperationType.QUERY, OperationType.MUTATION, @@ -293,6 +300,30 @@ def _create_execution_context( provided_operation_name=operation_name, ) + # TODO: is this the right place to do this? + async def _handle_execution_result( + self, + context: ExecutionContext, + result: Union[GraphQLExecutionResult, ExecutionResult], + extensions_runner: SchemaExtensionsRunner, + process_errors: ProcessErrors | None, + ) -> ExecutionResult: + # Set errors on the context so that it's easier + # to access in extensions + if result.errors: + context.errors = result.errors + + if process_errors: + process_errors(result.errors, context) + + if isinstance(result, GraphQLExecutionResult): + result = ExecutionResult(data=result.data, errors=result.errors) + + result.extensions = await extensions_runner.get_extensions_results(context) + + context.result = result # type: ignore # mypy failed to deduce correct type. + return result + @lru_cache def get_type_by_name( self, name: str @@ -369,12 +400,15 @@ async def execute( # TODO (#3571): remove this when we implement execution context as parameter. for extension in extensions: extension.execution_context = execution_context + # TODO: fix (race conditions, ugly code) + self.execution_context = execution_context + self.extensions_runner = self.create_extensions_runner( + execution_context, extensions + ) return await execute( self._schema, execution_context=execution_context, - extensions_runner=self.create_extensions_runner( - execution_context, extensions - ), + extensions_runner=self.extensions_runner, process_errors=self._process_errors, middleware_manager=self._get_middleware_manager(extensions), execution_context_class=self.execution_context_class, diff --git a/strawberry/schema/subscribe.py b/strawberry/schema/subscribe.py index 8bda51a4d1..c9e320e187 100644 --- a/strawberry/schema/subscribe.py +++ b/strawberry/schema/subscribe.py @@ -17,7 +17,6 @@ from .execute import ( ProcessErrors, _coerce_error, - _handle_execution_result, _parse_and_validate_async, ) @@ -39,6 +38,15 @@ ] +def _handle_execution_result( + context: ExecutionContext, + result: Union[GraphQLExecutionResult, ExecutionResult], + extensions_runner: SchemaExtensionsRunner, + process_errors: ProcessErrors | None, +) -> ExecutionResult: + pass + + async def _subscribe( schema: GraphQLSchema, execution_context: ExecutionContext, diff --git a/tests/http/incremental/test_defer.py b/tests/http/incremental/test_defer.py index f21a4e30f8..54ac4bb94a 100644 --- a/tests/http/incremental/test_defer.py +++ b/tests/http/incremental/test_defer.py @@ -26,13 +26,20 @@ async def test_basic_defer(method: Literal["get", "post"], http_client: HttpClie "data": {"hello": "Hello world"}, "incremental": [], "hasNext": True, + # TODO: why is this None? "extensions": None, } subsequent = await stream.__anext__() assert subsequent == { - "incremental": [{"data": {"asyncHello": "Hello world"}}], + "incremental": [ + { + "data": {"asyncHello": "Hello world"}, + "extensions": {"example": "example"}, + } + ], "hasNext": False, + # TODO: how do we fill these? "extensions": None, }