diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4470445eab..146534ef46 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -68,11 +68,9 @@ jobs: .nox key: ${{ runner.os }}-nox-${{ matrix.session.session }}-${{ - hashFiles('**/poetry.lock') }}-${{ hashFiles('**/noxfile.py') }}-2 + hashFiles('**/poetry.lock') }}-${{ hashFiles('**/noxfile.py') }}-3 - run: pip install poetry nox nox-poetry - - run: poetry --help - - run: which poetry - run: nox -r -t tests -s "${{ matrix.session.session }}" - uses: actions/upload-artifact@v4 if: ${{ always() }} diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..1b3d63d994 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,3 @@ +Release type: patch + +This releases fixes a bug where schema extensions where not running a LIFO order. diff --git a/strawberry/extensions/context.py b/strawberry/extensions/context.py index f2ef681562..216b0aad2f 100644 --- a/strawberry/extensions/context.py +++ b/strawberry/extensions/context.py @@ -2,14 +2,15 @@ import contextlib import inspect +import types import warnings from asyncio import iscoroutinefunction from typing import ( TYPE_CHECKING, Any, - AsyncIterator, + AsyncContextManager, Callable, - Iterator, + ContextManager, List, NamedTuple, Optional, @@ -28,12 +29,18 @@ class WrappedHook(NamedTuple): extension: SchemaExtension - initialized_hook: Union[AsyncIterator[None], Iterator[None]] + hook: Callable[..., Union[AsyncContextManager[None], ContextManager[None]]] is_async: bool class ExtensionContextManagerBase: - __slots__ = ("hooks", "deprecation_message", "default_hook") + __slots__ = ( + "hooks", + "deprecation_message", + "default_hook", + "async_exit_stack", + "exit_stack", + ) def __init_subclass__(cls): cls.DEPRECATION_MESSAGE = ( @@ -73,10 +80,20 @@ def get_hook(self, extension: SchemaExtension) -> Optional[WrappedHook]: if hook_fn: if inspect.isgeneratorfunction(hook_fn): - return WrappedHook(extension, hook_fn(extension), False) + context_manager = contextlib.contextmanager( + types.MethodType(hook_fn, extension) + ) + return WrappedHook( + extension=extension, hook=context_manager, is_async=False + ) if inspect.isasyncgenfunction(hook_fn): - return WrappedHook(extension, hook_fn(extension), True) + context_manager_async = contextlib.asynccontextmanager( + types.MethodType(hook_fn, extension) + ) + return WrappedHook( + extension=extension, hook=context_manager_async, is_async=True + ) if callable(hook_fn): return self.from_callable(extension, hook_fn) @@ -96,27 +113,31 @@ def from_legacy( ) -> WrappedHook: if iscoroutinefunction(on_start) or iscoroutinefunction(on_end): + @contextlib.asynccontextmanager async def iterator(): if on_start: await await_maybe(on_start()) + yield + if on_end: await await_maybe(on_end()) - hook = iterator() - return WrappedHook(extension, hook, True) + return WrappedHook(extension=extension, hook=iterator, is_async=True) else: - def iterator(): + @contextlib.contextmanager + def iterator_async(): if on_start: on_start() + yield + if on_end: on_end() - hook = iterator() - return WrappedHook(extension, hook, False) + return WrappedHook(extension=extension, hook=iterator_async, is_async=False) @staticmethod def from_callable( @@ -125,59 +146,34 @@ def from_callable( ) -> WrappedHook: if iscoroutinefunction(func): - async def async_iterator(): + @contextlib.asynccontextmanager + async def iterator(): await func(extension) yield - hook = async_iterator() - return WrappedHook(extension, hook, True) + return WrappedHook(extension=extension, hook=iterator, is_async=True) else: + @contextlib.contextmanager def iterator(): func(extension) yield - hook = iterator() - return WrappedHook(extension, hook, False) + return WrappedHook(extension=extension, hook=iterator, is_async=False) - def run_hooks_sync(self, is_exit: bool = False) -> None: - """Run extensions synchronously.""" - ctx = ( - contextlib.suppress(StopIteration, StopAsyncIteration) - if is_exit - else contextlib.nullcontext() - ) - for hook in self.hooks: - with ctx: - if hook.is_async: - raise RuntimeError( - f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} " - "failed to complete synchronously." - ) - else: - hook.initialized_hook.__next__() # type: ignore[union-attr] - - async def run_hooks_async(self, is_exit: bool = False) -> None: - """Run extensions asynchronously with support for sync lifecycle hooks. - - The ``is_exit`` flag is required as a `StopIteration` cannot be raised from - within a coroutine. - """ - ctx = ( - contextlib.suppress(StopIteration, StopAsyncIteration) - if is_exit - else contextlib.nullcontext() - ) + def __enter__(self) -> None: + self.exit_stack = contextlib.ExitStack() - for hook in self.hooks: - with ctx: - if hook.is_async: - await hook.initialized_hook.__anext__() # type: ignore[union-attr] - else: - hook.initialized_hook.__next__() # type: ignore[union-attr] + self.exit_stack.__enter__() - def __enter__(self): - self.run_hooks_sync() + for hook in self.hooks: + if hook.is_async: + raise RuntimeError( + f"SchemaExtension hook {hook.extension}.{self.HOOK_NAME} " + "failed to complete synchronously." + ) + else: + self.exit_stack.enter_context(hook.hook()) # type: ignore def __exit__( self, @@ -185,10 +181,18 @@ def __exit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ): - self.run_hooks_sync(is_exit=True) + self.exit_stack.__exit__(exc_type, exc_val, exc_tb) - async def __aenter__(self): - await self.run_hooks_async() + async def __aenter__(self) -> None: + self.async_exit_stack = contextlib.AsyncExitStack() + + await self.async_exit_stack.__aenter__() + + for hook in self.hooks: + if hook.is_async: + await self.async_exit_stack.enter_async_context(hook.hook()) # type: ignore + else: + self.async_exit_stack.enter_context(hook.hook()) # type: ignore async def __aexit__( self, @@ -196,7 +200,7 @@ async def __aexit__( exc_val: Optional[BaseException], exc_tb: Optional[TracebackType], ): - await self.run_hooks_async(is_exit=True) + await self.async_exit_stack.__aexit__(exc_type, exc_val, exc_tb) class OperationContextManager(ExtensionContextManagerBase): diff --git a/tests/schema/extensions/test_extensions.py b/tests/schema/extensions/test_extensions.py index 3deb9be140..555d83063c 100644 --- a/tests/schema/extensions/test_extensions.py +++ b/tests/schema/extensions/test_extensions.py @@ -407,18 +407,18 @@ async def on_execute(self): "ExtensionB, on_operation Entered", "ExtensionA, on_parse Entered", "ExtensionB, on_parse Entered", - "ExtensionA, on_parse Exited", "ExtensionB, on_parse Exited", + "ExtensionA, on_parse Exited", "ExtensionA, on_validate Entered", "ExtensionB, on_validate Entered", - "ExtensionA, on_validate Exited", "ExtensionB, on_validate Exited", + "ExtensionA, on_validate Exited", "ExtensionA, on_execute Entered", "ExtensionB, on_execute Entered", - "ExtensionA, on_execute Exited", "ExtensionB, on_execute Exited", - "ExtensionA, on_operation Exited", + "ExtensionA, on_execute Exited", "ExtensionB, on_operation Exited", + "ExtensionA, on_operation Exited", ] @@ -684,18 +684,19 @@ class ExtensionB(SchemaExtension): def on_execute(self): execution_order.append(type(self)) yield + execution_order.append(type(self)) class ExtensionC(SchemaExtension): def on_execute(self): execution_order.append(type(self)) yield + execution_order.append(type(self)) @strawberry.type class Query: food: str = "strawberry" - extensions = [ExtensionB, ExtensionC] - schema = strawberry.Schema(query=Query, extensions=extensions) + schema = strawberry.Schema(query=Query, extensions=[ExtensionB, ExtensionC]) query = """ query TestQuery { @@ -707,7 +708,7 @@ class Query: assert not result.errors assert result.data == {"food": "strawberry"} - assert execution_order == extensions + assert execution_order == [ExtensionB, ExtensionC, ExtensionC, ExtensionB] def test_async_extension_in_sync_context(): @@ -1024,3 +1025,20 @@ def hi(self) -> str: ValueError, match="Hook on_operation on <(.*)> must be callable, received 'ABC'" ): schema.execute_sync(query) + + +@pytest.mark.asyncio +async def test_calls_hooks_when_there_are_errors(async_extension): + @strawberry.type + class Query: + @strawberry.field + def hi(self) -> str: + raise Exception("This is an error") + + schema = strawberry.Schema(query=Query, extensions=[async_extension]) + + query = "{ hi }" + + result = await schema.execute(query) + assert result.errors + async_extension.perform_test()