diff --git a/pytest_asyncio/plugin.py b/pytest_asyncio/plugin.py index 0f29be12..946de267 100644 --- a/pytest_asyncio/plugin.py +++ b/pytest_asyncio/plugin.py @@ -4,6 +4,8 @@ import functools import inspect import socket +from contextvars import Context, copy_context +from asyncio import coroutines import pytest try: @@ -124,6 +126,7 @@ async def async_finalizer(): return loop.run_until_complete(setup()) fixturedef.func = wrapper + elif inspect.iscoroutinefunction(fixturedef.func): coro = fixturedef.func @@ -186,10 +189,12 @@ def inner(**kwargs): def pytest_runtest_setup(item): if 'asyncio' in item.keywords: - # inject an event loop fixture for all async tests if 'event_loop' in item.fixturenames: item.fixturenames.remove('event_loop') item.fixturenames.insert(0, 'event_loop') + if 'context' in item.fixturenames: + item.fixturenames.remove('context') + item.fixturenames.insert(0, 'context') if item.get_closest_marker("asyncio") is not None \ and not getattr(item.obj, 'hypothesis', False) \ and getattr(item.obj, 'is_hypothesis_test', False): @@ -198,6 +203,40 @@ def pytest_runtest_setup(item): 'only works with Hypothesis 3.64.0 or later.' % item ) +class Task(asyncio.tasks._PyTask): + def __init__(self, coro, *, loop=None, name=None, context=None): + asyncio.futures._PyFuture.__init__(self, loop=loop) + if self._source_traceback: + del self._source_traceback[-1] + if not coroutines.iscoroutine(coro): + # raise after Future.__init__(), attrs are required for __del__ + # prevent logging for pending task in __del__ + self._log_destroy_pending = False + raise TypeError(f"a coroutine was expected, got {coro!r}") + + if name is None: + self._name = f'Task-{asyncio.tasks._task_name_counter()}' + else: + self._name = str(name) + + self._must_cancel = False + self._fut_waiter = None + self._coro = coro + self._context = context if context is not None else copy_context() + + self._loop.call_soon(self.__step, context=self._context) + asyncio._register_task(self) + + +@pytest.fixture +def context(event_loop, request): + """Create an empty context for the async test case and it's async fixtures.""" + context = Context() + def taskfactory(loop, coro): + return Task(coro, loop=loop, context=context) + event_loop.set_task_factory(taskfactory) + return context + @pytest.fixture def event_loop(request): diff --git a/tests/test_contextvars.py b/tests/test_contextvars.py new file mode 100644 index 00000000..5c165e57 --- /dev/null +++ b/tests/test_contextvars.py @@ -0,0 +1,21 @@ +"""Quick'n'dirty unit tests for provided fixtures and markers.""" +import asyncio +import pytest + +import pytest_asyncio.plugin + +from contextvars import ContextVar + + +ctxvar = ContextVar('ctxvar') + + +@pytest.fixture +async def set_some_context(context): + ctxvar.set('quarantine is fun') + + +@pytest.mark.asyncio +async def test_test(set_some_context): + # print ("Context in test:", list(context.items())) + assert ctxvar.get() == 'quarantine is fun'