diff --git a/strawberry/schema/execute.py b/strawberry/schema/execute.py index 1ecdb4c1c2..08c4c8df82 100644 --- a/strawberry/schema/execute.py +++ b/strawberry/schema/execute.py @@ -88,65 +88,76 @@ async def execute( extensions=list(extensions), ) - async with extensions_runner.operation(): - # Note: In graphql-core the schema would be validated here but in - # Strawberry we are validating it at initialisation time instead - if not execution_context.query: - raise MissingQueryError() - - async with extensions_runner.parsing(): - try: - if not execution_context.graphql_document: - execution_context.graphql_document = parse_document( - execution_context.query, **execution_context.parse_options + try: + async with extensions_runner.operation(): + # Note: In graphql-core the schema would be validated here but in + # Strawberry we are validating it at initialisation time instead + if not execution_context.query: + raise MissingQueryError() + + async with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + execution_context.graphql_document = parse_document( + execution_context.query, **execution_context.parse_options + ) + + except GraphQLError as error: + execution_context.errors = [error] + process_errors([error], execution_context) + return ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), ) - except GraphQLError as error: - execution_context.errors = [error] - process_errors([error], execution_context) - return ExecutionResult( - data=None, - errors=[error], - extensions=await extensions_runner.get_extensions_results(), - ) - - if execution_context.operation_type not in allowed_operation_types: - raise InvalidOperationTypeError(execution_context.operation_type) - - async with extensions_runner.validation(): - _run_validation(execution_context) - if execution_context.errors: - process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) - - async with extensions_runner.executing(): - if not execution_context.result: - result = original_execute( - schema, - execution_context.graphql_document, - root_value=execution_context.root_value, - middleware=extensions_runner.as_middleware_manager(), - variable_values=execution_context.variables, - operation_name=execution_context.operation_name, - context_value=execution_context.context, - execution_context_class=execution_context_class, - ) - - if isawaitable(result): - result = await cast(Awaitable["GraphQLExecutionResult"], result) - - result = cast("GraphQLExecutionResult", result) - execution_context.result = result - # Also set errors on the execution_context so that it's easier - # to access in extensions - if result.errors: - execution_context.errors = result.errors - - # Run the `Schema.process_errors` function here before - # extensions have a chance to modify them (see the MaskErrors - # extension). That way we can log the original errors but - # only return a sanitised version to the client. - process_errors(result.errors, execution_context) + if execution_context.operation_type not in allowed_operation_types: + raise InvalidOperationTypeError(execution_context.operation_type) + + async with extensions_runner.validation(): + _run_validation(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + return ExecutionResult(data=None, errors=execution_context.errors) + + async with extensions_runner.executing(): + if not execution_context.result: + result = original_execute( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + middleware=extensions_runner.as_middleware_manager(), + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + context_value=execution_context.context, + execution_context_class=execution_context_class, + ) + + if isawaitable(result): + result = await cast(Awaitable["GraphQLExecutionResult"], result) + + result = cast("GraphQLExecutionResult", result) + execution_context.result = result + # Also set errors on the execution_context so that it's easier + # to access in extensions + if result.errors: + execution_context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, execution_context) + + except Exception as exc: + error = GraphQLError(str(exc), original_error=exc) + execution_context.errors = [error] + process_errors([error], execution_context) + return ExecutionResult( + data=None, + errors=[error], + extensions=await extensions_runner.get_extensions_results(), + ) return ExecutionResult( data=execution_context.result.data, @@ -169,69 +180,80 @@ def execute_sync( extensions=list(extensions), ) - with extensions_runner.operation(): - # Note: In graphql-core the schema would be validated here but in - # Strawberry we are validating it at initialisation time instead - if not execution_context.query: - raise MissingQueryError() - - with extensions_runner.parsing(): - try: - if not execution_context.graphql_document: - execution_context.graphql_document = parse_document( - execution_context.query, **execution_context.parse_options + try: + with extensions_runner.operation(): + # Note: In graphql-core the schema would be validated here but in + # Strawberry we are validating it at initialisation time instead + if not execution_context.query: + raise MissingQueryError() + + with extensions_runner.parsing(): + try: + if not execution_context.graphql_document: + execution_context.graphql_document = parse_document( + execution_context.query, **execution_context.parse_options + ) + + except GraphQLError as error: + execution_context.errors = [error] + process_errors([error], execution_context) + return ExecutionResult( + data=None, + errors=[error], + extensions=extensions_runner.get_extensions_results_sync(), ) - except GraphQLError as error: - execution_context.errors = [error] - process_errors([error], execution_context) - return ExecutionResult( - data=None, - errors=[error], - extensions=extensions_runner.get_extensions_results_sync(), - ) - - if execution_context.operation_type not in allowed_operation_types: - raise InvalidOperationTypeError(execution_context.operation_type) - - with extensions_runner.validation(): - _run_validation(execution_context) - if execution_context.errors: - process_errors(execution_context.errors, execution_context) - return ExecutionResult(data=None, errors=execution_context.errors) - - with extensions_runner.executing(): - if not execution_context.result: - result = original_execute( - schema, - execution_context.graphql_document, - root_value=execution_context.root_value, - middleware=extensions_runner.as_middleware_manager(), - variable_values=execution_context.variables, - operation_name=execution_context.operation_name, - context_value=execution_context.context, - execution_context_class=execution_context_class, - ) - - if isawaitable(result): - result = cast(Awaitable["GraphQLExecutionResult"], result) - ensure_future(result).cancel() - raise RuntimeError( - "GraphQL execution failed to complete synchronously." + if execution_context.operation_type not in allowed_operation_types: + raise InvalidOperationTypeError(execution_context.operation_type) + + with extensions_runner.validation(): + _run_validation(execution_context) + if execution_context.errors: + process_errors(execution_context.errors, execution_context) + return ExecutionResult(data=None, errors=execution_context.errors) + + with extensions_runner.executing(): + if not execution_context.result: + result = original_execute( + schema, + execution_context.graphql_document, + root_value=execution_context.root_value, + middleware=extensions_runner.as_middleware_manager(), + variable_values=execution_context.variables, + operation_name=execution_context.operation_name, + context_value=execution_context.context, + execution_context_class=execution_context_class, ) - result = cast("GraphQLExecutionResult", result) - execution_context.result = result - # Also set errors on the execution_context so that it's easier - # to access in extensions - if result.errors: - execution_context.errors = result.errors - - # Run the `Schema.process_errors` function here before - # extensions have a chance to modify them (see the MaskErrors - # extension). That way we can log the original errors but - # only return a sanitised version to the client. - process_errors(result.errors, execution_context) + if isawaitable(result): + result = cast(Awaitable["GraphQLExecutionResult"], result) + ensure_future(result).cancel() + raise RuntimeError( + "GraphQL execution failed to complete synchronously." + ) + + result = cast("GraphQLExecutionResult", result) + execution_context.result = result + # Also set errors on the execution_context so that it's easier + # to access in extensions + if result.errors: + execution_context.errors = result.errors + + # Run the `Schema.process_errors` function here before + # extensions have a chance to modify them (see the MaskErrors + # extension). That way we can log the original errors but + # only return a sanitised version to the client. + process_errors(result.errors, execution_context) + + except Exception as exc: + error = GraphQLError(str(exc), original_error=exc) + execution_context.errors = [error] + process_errors([error], execution_context) + return ExecutionResult( + data=None, + errors=[error], + extensions=extensions_runner.get_extensions_results_sync(), + ) return ExecutionResult( data=execution_context.result.data, diff --git a/tests/schema/extensions/test_extensions.py b/tests/schema/extensions/test_extensions.py index f2b55ed615..e2c3e97014 100644 --- a/tests/schema/extensions/test_extensions.py +++ b/tests/schema/extensions/test_extensions.py @@ -467,8 +467,10 @@ def on_executing_start(self): schema = strawberry.Schema( query=default_query_types_and_query.query_type, extensions=[WrongUsageExtension] ) - with pytest.raises(ValueError): - schema.execute_sync(default_query_types_and_query.query) + + result = schema.execute_sync(default_query_types_and_query.query) + assert len(result.errors) == 1 + assert isinstance(result.errors[0].original_error, ValueError) async def test_legacy_extension_supported(): @@ -628,6 +630,128 @@ def string(self) -> str: schema.execute_sync(query) +class ExceptionTestingExtension(SchemaExtension): + def __init__(self, failing_hook: str): + self.failing_hook = failing_hook + self.called_hooks = set() + + def on_operation(self): + if self.failing_hook == "on_operation_start": + raise Exception(self.failing_hook) + self.called_hooks.add(1) + + yield + + if self.failing_hook == "on_operation_end": + raise Exception(self.failing_hook) + self.called_hooks.add(8) + + def on_parse(self): + if self.failing_hook == "on_parse_start": + raise Exception(self.failing_hook) + self.called_hooks.add(2) + + yield + + if self.failing_hook == "on_parse_end": + raise Exception(self.failing_hook) + self.called_hooks.add(3) + + def on_validate(self): + if self.failing_hook == "on_validate_start": + raise Exception(self.failing_hook) + self.called_hooks.add(4) + + yield + + if self.failing_hook == "on_validate_end": + raise Exception(self.failing_hook) + self.called_hooks.add(5) + + def on_execute(self): + if self.failing_hook == "on_execute_start": + raise Exception(self.failing_hook) + self.called_hooks.add(6) + + yield + + if self.failing_hook == "on_execute_end": + raise Exception(self.failing_hook) + self.called_hooks.add(7) + + +@pytest.mark.parametrize( + "failing_hook", + [ + "on_operation_start", + "on_operation_end", + "on_parse_start", + "on_parse_end", + "on_validate_start", + "on_validate_end", + "on_execute_start", + "on_execute_end", + ], +) +@pytest.mark.asyncio +async def test_exceptions_are_included_in_the_execution_result(failing_hook): + @strawberry.type + class Query: + @strawberry.field + def ping(self) -> str: + return "pong" + + schema = strawberry.Schema( + query=Query, + extensions=[ExceptionTestingExtension(failing_hook)], + ) + document = "query { ping }" + + sync_result = schema.execute_sync(document) + assert sync_result.errors is not None + assert len(sync_result.errors) == 1 + assert sync_result.errors[0].message == failing_hook + + async_result = await schema.execute(document) + assert async_result.errors is not None + assert len(async_result.errors) == 1 + assert sync_result.errors[0].message == failing_hook + + +@pytest.mark.parametrize( + "failing_hook,expected_hooks", + [ + ["on_operation_start", set()], + ["on_parse_start", {1, 8}], + ["on_parse_end", {1, 2, 8}], + ["on_validate_start", {1, 2, 3, 8}], + ["on_validate_end", {1, 2, 3, 4, 8}], + ["on_execute_start", {1, 2, 3, 4, 5, 8}], + ["on_execute_end", {1, 2, 3, 4, 5, 6, 8}], + ["on_operation_end", {1, 2, 3, 4, 5, 6, 7}], + ], +) +@pytest.mark.asyncio +async def test_exceptions_abort_evaluation(failing_hook, expected_hooks): + @strawberry.type + class Query: + @strawberry.field + def ping(self) -> str: + return "pong" + + extension = ExceptionTestingExtension(failing_hook) + schema = strawberry.Schema(query=Query, extensions=[extension]) + document = "query { ping }" + + extension.called_hooks = set() + schema.execute_sync(document) + assert extension.called_hooks == expected_hooks + + extension.called_hooks = set() + await schema.execute(document) + assert extension.called_hooks == expected_hooks + + @pytest.mark.asyncio async def test_non_parsing_errors_are_not_swallowed_by_parsing_hooks(): class MyExtension(SchemaExtension): @@ -643,11 +767,13 @@ def ping(self) -> str: schema = strawberry.Schema(query=Query, extensions=[MyExtension]) query = "query { string }" - with pytest.raises(Exception, match="This shouldn't be swallowed"): - schema.execute_sync(query) + sync_result = schema.execute_sync(query) + assert len(sync_result.errors) == 1 + assert sync_result.errors[0].message == "This shouldn't be swallowed" - with pytest.raises(Exception, match="This shouldn't be swallowed"): - await schema.execute(query) + async_result = await schema.execute(query) + assert len(async_result.errors) == 1 + assert async_result.errors[0].message == "This shouldn't be swallowed" def test_on_parsing_end_is_called_with_parsing_errors(): @@ -721,8 +847,9 @@ class Query: schema = strawberry.Schema(query=Query, extensions=[ExtensionA]) - with pytest.raises(RuntimeError, match="failed to complete synchronously"): - schema.execute_sync("query { food }") + result = schema.execute_sync("query { food }") + assert len(result.errors) == 1 + assert result.errors[0].message.endswith("failed to complete synchronously.") def test_extension_override_execution(): @@ -1020,7 +1147,8 @@ def hi(self) -> str: # Query not set on input query = "{ hi }" - with pytest.raises( - ValueError, match="Hook on_operation on <(.*)> must be callable, received 'ABC'" - ): - schema.execute_sync(query) + result = schema.execute_sync(query) + assert len(result.errors) == 1 + assert isinstance(result.errors[0].original_error, ValueError) + assert result.errors[0].message.startswith("Hook on_operation on <") + assert result.errors[0].message.endswith("> must be callable, received 'ABC'")