Skip to content

Commit

Permalink
Handle exceptions raised within extension hooks
Browse files Browse the repository at this point in the history
  • Loading branch information
DoctorJohn committed Nov 14, 2023
1 parent 5a88374 commit 0c3bb2a
Show file tree
Hide file tree
Showing 2 changed files with 282 additions and 128 deletions.
258 changes: 142 additions & 116 deletions strawberry/schema/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,65 +88,78 @@ async def execute(
extensions=list(extensions),
)

async with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
if not execution_context.query:
raise MissingQueryError()

async with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
try:
async with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
if not execution_context.query:
raise MissingQueryError()

async with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

async with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)

async with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

if isawaitable(result):
result = await cast(Awaitable["GraphQLExecutionResult"], result)

result = cast("GraphQLExecutionResult", result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)
if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

async with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)

async with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

if isawaitable(result):
result = await cast(Awaitable["GraphQLExecutionResult"], result)

result = cast("GraphQLExecutionResult", result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)

except (MissingQueryError, InvalidOperationTypeError) as e:
raise e
except Exception as exc:
error = GraphQLError(str(exc), original_error=exc)
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=await extensions_runner.get_extensions_results(),
)

return ExecutionResult(
data=execution_context.result.data,
Expand All @@ -169,69 +182,82 @@ def execute_sync(
extensions=list(extensions),
)

with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
if not execution_context.query:
raise MissingQueryError()

with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
try:
with extensions_runner.operation():
# Note: In graphql-core the schema would be validated here but in
# Strawberry we are validating it at initialisation time instead
if not execution_context.query:
raise MissingQueryError()

with extensions_runner.parsing():
try:
if not execution_context.graphql_document:
execution_context.graphql_document = parse_document(
execution_context.query, **execution_context.parse_options
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

except GraphQLError as error:
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)

with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

if isawaitable(result):
result = cast(Awaitable["GraphQLExecutionResult"], result)
ensure_future(result).cancel()
raise RuntimeError(
"GraphQL execution failed to complete synchronously."
if execution_context.operation_type not in allowed_operation_types:
raise InvalidOperationTypeError(execution_context.operation_type)

with extensions_runner.validation():
_run_validation(execution_context)
if execution_context.errors:
process_errors(execution_context.errors, execution_context)
return ExecutionResult(data=None, errors=execution_context.errors)

with extensions_runner.executing():
if not execution_context.result:
result = original_execute(
schema,
execution_context.graphql_document,
root_value=execution_context.root_value,
middleware=extensions_runner.as_middleware_manager(),
variable_values=execution_context.variables,
operation_name=execution_context.operation_name,
context_value=execution_context.context,
execution_context_class=execution_context_class,
)

result = cast("GraphQLExecutionResult", result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)
if isawaitable(result):
result = cast(Awaitable["GraphQLExecutionResult"], result)
ensure_future(result).cancel()
raise RuntimeError(
"GraphQL execution failed to complete synchronously."
)

result = cast("GraphQLExecutionResult", result)
execution_context.result = result
# Also set errors on the execution_context so that it's easier
# to access in extensions
if result.errors:
execution_context.errors = result.errors

# Run the `Schema.process_errors` function here before
# extensions have a chance to modify them (see the MaskErrors
# extension). That way we can log the original errors but
# only return a sanitised version to the client.
process_errors(result.errors, execution_context)

except (MissingQueryError, InvalidOperationTypeError) as e:
raise e
except Exception as exc:
error = GraphQLError(str(exc), original_error=exc)
execution_context.errors = [error]
process_errors([error], execution_context)
return ExecutionResult(
data=None,
errors=[error],
extensions=extensions_runner.get_extensions_results_sync(),
)

return ExecutionResult(
data=execution_context.result.data,
Expand Down
Loading

0 comments on commit 0c3bb2a

Please sign in to comment.