Skip to content

Commit

Permalink
Schema extensions execution order should be LIFO (#3416)
Browse files Browse the repository at this point in the history
* fix #3413

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add comments

* run pre commit

* fix release notes

* add comments

* WIP

* Bound

* Add missing await

* Apply suggestions from code review

Co-authored-by: Doctor <50728601+ThirVondukr@users.noreply.github.com>

* Use kwargs, rename field

* Add some tests (one failing)

* Remove broken test

* Fully working errors

* Revert

* Update release file

* Fix types

* Fix cache?

* remove unused code

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Patrick Arminio <patrick.arminio@gmail.com>
Co-authored-by: Doctor <50728601+ThirVondukr@users.noreply.github.com>
  • Loading branch information
4 people authored Apr 13, 2024
1 parent 0c5bc4b commit 75e15b1
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 66 deletions.
4 changes: 1 addition & 3 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,9 @@ jobs:
.nox
key:
${{ runner.os }}-nox-${{ matrix.session.session }}-${{
hashFiles('**/poetry.lock') }}-${{ hashFiles('**/noxfile.py') }}-2
hashFiles('**/poetry.lock') }}-${{ hashFiles('**/noxfile.py') }}-3

- run: pip install poetry nox nox-poetry
- run: poetry --help
- run: which poetry
- run: nox -r -t tests -s "${{ matrix.session.session }}"
- uses: actions/upload-artifact@v4
if: ${{ always() }}
Expand Down
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Release type: patch

This releases fixes a bug where schema extensions where not running a LIFO order.
116 changes: 60 additions & 56 deletions strawberry/extensions/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@

import contextlib
import inspect
import types
import warnings
from asyncio import iscoroutinefunction
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
AsyncContextManager,
Callable,
Iterator,
ContextManager,
List,
NamedTuple,
Optional,
Expand All @@ -28,12 +29,18 @@

class WrappedHook(NamedTuple):
extension: SchemaExtension
initialized_hook: Union[AsyncIterator[None], Iterator[None]]
hook: Callable[..., Union[AsyncContextManager[None], ContextManager[None]]]
is_async: bool


class ExtensionContextManagerBase:
__slots__ = ("hooks", "deprecation_message", "default_hook")
__slots__ = (
"hooks",
"deprecation_message",
"default_hook",
"async_exit_stack",
"exit_stack",
)

def __init_subclass__(cls):
cls.DEPRECATION_MESSAGE = (
Expand Down Expand Up @@ -73,10 +80,20 @@ def get_hook(self, extension: SchemaExtension) -> Optional[WrappedHook]:

if hook_fn:
if inspect.isgeneratorfunction(hook_fn):
return WrappedHook(extension, hook_fn(extension), False)
context_manager = contextlib.contextmanager(
types.MethodType(hook_fn, extension)
)
return WrappedHook(
extension=extension, hook=context_manager, is_async=False
)

if inspect.isasyncgenfunction(hook_fn):
return WrappedHook(extension, hook_fn(extension), True)
context_manager_async = contextlib.asynccontextmanager(
types.MethodType(hook_fn, extension)
)
return WrappedHook(
extension=extension, hook=context_manager_async, is_async=True
)

if callable(hook_fn):
return self.from_callable(extension, hook_fn)
Expand All @@ -96,27 +113,31 @@ def from_legacy(
) -> WrappedHook:
if iscoroutinefunction(on_start) or iscoroutinefunction(on_end):

@contextlib.asynccontextmanager
async def iterator():
if on_start:
await await_maybe(on_start())

yield

if on_end:
await await_maybe(on_end())

hook = iterator()
return WrappedHook(extension, hook, True)
return WrappedHook(extension=extension, hook=iterator, is_async=True)

else:

def iterator():
@contextlib.contextmanager
def iterator_async():
if on_start:
on_start()

yield

if on_end:
on_end()

hook = iterator()
return WrappedHook(extension, hook, False)
return WrappedHook(extension=extension, hook=iterator_async, is_async=False)

@staticmethod
def from_callable(
Expand All @@ -125,78 +146,61 @@ def from_callable(
) -> WrappedHook:
if iscoroutinefunction(func):

async def async_iterator():
@contextlib.asynccontextmanager
async def iterator():
await func(extension)
yield

hook = async_iterator()
return WrappedHook(extension, hook, True)
return WrappedHook(extension=extension, hook=iterator, is_async=True)
else:

@contextlib.contextmanager
def iterator():
func(extension)
yield

hook = iterator()
return WrappedHook(extension, hook, False)
return WrappedHook(extension=extension, hook=iterator, is_async=False)

def run_hooks_sync(self, is_exit: bool = False) -> None:
"""Run extensions synchronously."""
ctx = (
contextlib.suppress(StopIteration, StopAsyncIteration)
if is_exit
else contextlib.nullcontext()
)
for hook in self.hooks:
with ctx:
if hook.is_async:
raise RuntimeError(
f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} "
"failed to complete synchronously."
)
else:
hook.initialized_hook.__next__() # type: ignore[union-attr]

async def run_hooks_async(self, is_exit: bool = False) -> None:
"""Run extensions asynchronously with support for sync lifecycle hooks.
The ``is_exit`` flag is required as a `StopIteration` cannot be raised from
within a coroutine.
"""
ctx = (
contextlib.suppress(StopIteration, StopAsyncIteration)
if is_exit
else contextlib.nullcontext()
)
def __enter__(self) -> None:
self.exit_stack = contextlib.ExitStack()

for hook in self.hooks:
with ctx:
if hook.is_async:
await hook.initialized_hook.__anext__() # type: ignore[union-attr]
else:
hook.initialized_hook.__next__() # type: ignore[union-attr]
self.exit_stack.__enter__()

def __enter__(self):
self.run_hooks_sync()
for hook in self.hooks:
if hook.is_async:
raise RuntimeError(
f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} "
"failed to complete synchronously."
)
else:
self.exit_stack.enter_context(hook.hook()) # type: ignore

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
):
self.run_hooks_sync(is_exit=True)
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)

async def __aenter__(self):
await self.run_hooks_async()
async def __aenter__(self) -> None:
self.async_exit_stack = contextlib.AsyncExitStack()

await self.async_exit_stack.__aenter__()

for hook in self.hooks:
if hook.is_async:
await self.async_exit_stack.enter_async_context(hook.hook()) # type: ignore
else:
self.async_exit_stack.enter_context(hook.hook()) # type: ignore

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
):
await self.run_hooks_async(is_exit=True)
await self.async_exit_stack.__aexit__(exc_type, exc_val, exc_tb)


class OperationContextManager(ExtensionContextManagerBase):
Expand Down
32 changes: 25 additions & 7 deletions tests/schema/extensions/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,18 +407,18 @@ async def on_execute(self):
"ExtensionB, on_operation Entered",
"ExtensionA, on_parse Entered",
"ExtensionB, on_parse Entered",
"ExtensionA, on_parse Exited",
"ExtensionB, on_parse Exited",
"ExtensionA, on_parse Exited",
"ExtensionA, on_validate Entered",
"ExtensionB, on_validate Entered",
"ExtensionA, on_validate Exited",
"ExtensionB, on_validate Exited",
"ExtensionA, on_validate Exited",
"ExtensionA, on_execute Entered",
"ExtensionB, on_execute Entered",
"ExtensionA, on_execute Exited",
"ExtensionB, on_execute Exited",
"ExtensionA, on_operation Exited",
"ExtensionA, on_execute Exited",
"ExtensionB, on_operation Exited",
"ExtensionA, on_operation Exited",
]


Expand Down Expand Up @@ -684,18 +684,19 @@ class ExtensionB(SchemaExtension):
def on_execute(self):
execution_order.append(type(self))
yield
execution_order.append(type(self))

class ExtensionC(SchemaExtension):
def on_execute(self):
execution_order.append(type(self))
yield
execution_order.append(type(self))

@strawberry.type
class Query:
food: str = "strawberry"

extensions = [ExtensionB, ExtensionC]
schema = strawberry.Schema(query=Query, extensions=extensions)
schema = strawberry.Schema(query=Query, extensions=[ExtensionB, ExtensionC])

query = """
query TestQuery {
Expand All @@ -707,7 +708,7 @@ class Query:

assert not result.errors
assert result.data == {"food": "strawberry"}
assert execution_order == extensions
assert execution_order == [ExtensionB, ExtensionC, ExtensionC, ExtensionB]


def test_async_extension_in_sync_context():
Expand Down Expand Up @@ -1024,3 +1025,20 @@ def hi(self) -> str:
ValueError, match="Hook on_operation on <(.*)> must be callable, received 'ABC'"
):
schema.execute_sync(query)


@pytest.mark.asyncio
async def test_calls_hooks_when_there_are_errors(async_extension):
@strawberry.type
class Query:
@strawberry.field
def hi(self) -> str:
raise Exception("This is an error")

schema = strawberry.Schema(query=Query, extensions=[async_extension])

query = "{ hi }"

result = await schema.execute(query)
assert result.errors
async_extension.perform_test()

0 comments on commit 75e15b1

Please sign in to comment.