diff --git a/kafkaesk/ext/logging/handler.py b/kafkaesk/ext/logging/handler.py index 8cb3933..c0f14dc 100644 --- a/kafkaesk/ext/logging/handler.py +++ b/kafkaesk/ext/logging/handler.py @@ -3,7 +3,6 @@ from pydantic import BaseModel from typing import IO from typing import Optional -from typing import Tuple import asyncio import kafkaesk @@ -32,7 +31,9 @@ class KafkaeskQueue: def __init__( self, app: kafkaesk.app.Application, max_queue: int = 10000, ): - self._queue: asyncio.Queue[Tuple[str, BaseModel]] = asyncio.Queue(maxsize=max_queue) + self._queue: Optional[asyncio.Queue] = None + self._queue_size = max_queue + self._app = app self._app.on("finalize", self.flush) @@ -40,6 +41,9 @@ def __init__( self._task: Optional[asyncio.Task] = None def start(self) -> None: + if self._queue is None: + self._queue = asyncio.Queue(maxsize=self._queue_size) + if self._task is None or self._task.done(): self._task = asyncio.get_event_loop().create_task(self._run()) @@ -59,6 +63,9 @@ def running(self) -> bool: return True async def _run(self) -> None: + if self._queue is None: + raise RuntimeError("Queue must be started before workers") + while True: try: stream, message = await asyncio.wait_for(self._queue.get(), 1) @@ -72,9 +79,10 @@ async def _run(self) -> None: return async def flush(self) -> None: - while not self._queue.empty(): - stream, message = await self._queue.get() - await self._publish(stream, message) + if self._queue is not None: + while not self._queue.empty(): + stream, message = await self._queue.get() + await self._publish(stream, message) async def _publish(self, stream: str, message: BaseModel) -> None: if self._app._intialized: @@ -90,7 +98,8 @@ def _print_to_stderr(self, message: BaseModel, error: str) -> None: sys.stderr.write(f"Error sending log to Kafak: \n{error}\nMessage: {message.json()}") def put_nowait(self, stream: str, message: PydanticLogModel) -> None: - self._queue.put_nowait((stream, message)) + if self._queue is not None: + self._queue.put_nowait((stream, message)) class PydanticKafkaeskHandler(logging.Handler): diff --git a/tests/ext/logging/test_handler.py b/tests/ext/logging/test_handler.py index 8a30085..4416a92 100644 --- a/tests/ext/logging/test_handler.py +++ b/tests/ext/logging/test_handler.py @@ -154,6 +154,7 @@ async def consume(data: PydanticLogModel): async def test_queue_flush(self, app, queue, log_consumer): async with app: + queue.start() for i in range(10): queue.put_nowait("log.test", PydanticLogModel(count=i)) @@ -198,6 +199,7 @@ async def test_queue_publish(self, app, queue, log_consumer, capsys): @pytest.mark.with_max_queue(1) async def test_queue_max_size(self, app, queue): + queue.start() queue.put_nowait("log.test", PydanticLogModel()) with pytest.raises(asyncio.QueueFull):