diff --git a/vllm/utils.py b/vllm/utils.py
index 0147d595fec70..695764dadc123 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -13,10 +13,11 @@
 import sys
 import tempfile
 import threading
+import time
 import uuid
 import warnings
 import weakref
-from asyncio import FIRST_COMPLETED, ensure_future
+from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
 from collections.abc import Mapping
 from functools import lru_cache, partial, wraps
 from platform import uname
@@ -437,6 +438,12 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
     return _async_wrapper
 
 
+def _next_task(iterator: AsyncGenerator[T, None],
+               loop: AbstractEventLoop) -> Task:
+    # Can use anext() in python >= 3.10
+    return loop.create_task(iterator.__anext__())  # type: ignore[arg-type]
+
+
 async def iterate_with_cancellation(
     iterator: AsyncGenerator[T, None],
     is_cancelled: Callable[[], Awaitable[bool]],
@@ -445,19 +452,27 @@ async def iterate_with_cancellation(
     at least once per second to check for client cancellation.
     """
 
-    # Can use anext() in python >= 3.10
-    awaits = [ensure_future(iterator.__anext__())]
+    loop = asyncio.get_running_loop()
+
+    awaits: List[Future[T]] = [_next_task(iterator, loop)]
+    next_cancel_check: float = 0
     while True:
-        done, pending = await asyncio.wait(awaits, timeout=1)
-        if await is_cancelled():
-            with contextlib.suppress(BaseException):
-                awaits[0].cancel()
-                await iterator.aclose()
-            raise asyncio.CancelledError("client cancelled")
+        done, pending = await asyncio.wait(awaits, timeout=1.5)
+
+        # Check for cancellation at most once per second
+        time_now = time.monotonic()
+        if time_now >= next_cancel_check:
+            if await is_cancelled():
+                with contextlib.suppress(BaseException):
+                    awaits[0].cancel()
+                    await iterator.aclose()
+                raise asyncio.CancelledError("client cancelled")
+            next_cancel_check = time_now + 1
+
         if done:
             try:
                 item = await awaits[0]
-                awaits[0] = ensure_future(iterator.__anext__())
+                awaits[0] = _next_task(iterator, loop)
                 yield item
             except StopAsyncIteration:
                 # we are done
@@ -478,25 +493,29 @@ async def merge_async_iterators(
     to check for client cancellation.
     """
 
-    # Can use anext() in python >= 3.10
-    awaits = {
-        ensure_future(pair[1].__anext__()): pair
-        for pair in enumerate(iterators)
-    }
-    timeout = None if is_cancelled is None else 1
+    loop = asyncio.get_running_loop()
+
+    awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
+    timeout = None if is_cancelled is None else 1.5
+    next_cancel_check: float = 0
     try:
         while awaits:
             done, pending = await asyncio.wait(awaits.keys(),
                                                return_when=FIRST_COMPLETED,
                                                timeout=timeout)
-            if is_cancelled is not None and await is_cancelled():
-                raise asyncio.CancelledError("client cancelled")
+            if is_cancelled is not None:
+                # Check for cancellation at most once per second
+                time_now = time.monotonic()
+                if time_now >= next_cancel_check:
+                    if await is_cancelled():
+                        raise asyncio.CancelledError("client cancelled")
+                    next_cancel_check = time_now + 1
             for d in done:
                 pair = awaits.pop(d)
                 try:
                     item = await d
                     i, it = pair
-                    awaits[ensure_future(it.__anext__())] = pair
+                    awaits[_next_task(it, loop)] = pair
                     yield i, item
                 except StopAsyncIteration:
                     pass