Skip to content

Commit 512560f

Browse files
kafka improvements
1 parent 6dc07d6 commit 512560f

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

broadcaster/_backends/kafka.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import typing
45
from urllib.parse import urlparse
56

@@ -10,9 +11,11 @@
1011

1112

1213
class KafkaBackend(BroadcastBackend):
13-
def __init__(self, url: str):
14-
self._servers = [urlparse(url).netloc]
14+
def __init__(self, urls: str | list[str]) -> None:
15+
urls = [urls] if isinstance(urls, str) else urls
16+
self._servers = [urlparse(url).netloc for url in urls]
1517
self._consumer_channels: set[str] = set()
18+
self._ready = asyncio.Event()
1619

1720
async def connect(self) -> None:
1821
self._producer = AIOKafkaProducer(bootstrap_servers=self._servers)
@@ -27,6 +30,7 @@ async def disconnect(self) -> None:
2730
async def subscribe(self, channel: str) -> None:
2831
self._consumer_channels.add(channel)
2932
self._consumer.subscribe(topics=self._consumer_channels)
33+
await self._wait_for_assignment()
3034

3135
async def unsubscribe(self, channel: str) -> None:
3236
self._consumer.unsubscribe()
@@ -35,5 +39,13 @@ async def publish(self, channel: str, message: typing.Any) -> None:
3539
await self._producer.send_and_wait(channel, message.encode("utf8"))
3640

3741
async def next_published(self) -> Event:
42+
await self._ready.wait()
3843
message = await self._consumer.getone()
3944
return Event(channel=message.topic, message=message.value.decode("utf8"))
45+
46+
async def _wait_for_assignment(self) -> None:
47+
"""Wait for the consumer to be assigned to the partition."""
48+
while not self._consumer.assignment():
49+
await asyncio.sleep(0.001)
50+
51+
self._ready.set()

tests/test_broadcast.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from broadcaster import Broadcast, BroadcastBackend, Event
9+
from broadcaster._backends.kafka import KafkaBackend
910

1011

1112
class CustomBackend(BroadcastBackend):
@@ -65,7 +66,6 @@ async def test_postgres():
6566
assert event.message == "hello"
6667

6768

68-
@pytest.mark.skip("Deadlock on `next_published`")
6969
@pytest.mark.asyncio
7070
async def test_kafka():
7171
async with Broadcast("kafka://localhost:9092") as broadcast:
@@ -76,6 +76,16 @@ async def test_kafka():
7676
assert event.message == "hello"
7777

7878

79+
@pytest.mark.asyncio
80+
async def test_kafka_multiple_urls():
81+
async with Broadcast(backend=KafkaBackend(urls=["kafka://localhost:9092", "kafka://localhost:9092"])) as broadcast:
82+
async with broadcast.subscribe("chatroom") as subscriber:
83+
await broadcast.publish("chatroom", "hello")
84+
event = await subscriber.get()
85+
assert event.channel == "chatroom"
86+
assert event.message == "hello"
87+
88+
7989
@pytest.mark.asyncio
8090
async def test_custom():
8191
backend = CustomBackend("")

0 commit comments

Comments
 (0)