Skip to content

Commit 399442d

Browse files
Improvements to Kafka backend (#125)
1 parent 6daa0d2 commit 399442d

File tree

2 files changed

+25
-3
lines changed

2 files changed

+25
-3
lines changed

broadcaster/_backends/kafka.py

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import asyncio
34
import typing
45
from urllib.parse import urlparse
56

@@ -10,9 +11,11 @@
1011

1112

1213
class KafkaBackend(BroadcastBackend):
13-
def __init__(self, url: str):
14-
self._servers = [urlparse(url).netloc]
14+
def __init__(self, urls: str | list[str]) -> None:
15+
urls = [urls] if isinstance(urls, str) else urls
16+
self._servers = [urlparse(url).netloc for url in urls]
1517
self._consumer_channels: set[str] = set()
18+
self._ready = asyncio.Event()
1619

1720
async def connect(self) -> None:
1821
self._producer = AIOKafkaProducer(bootstrap_servers=self._servers)
@@ -27,6 +30,7 @@ async def disconnect(self) -> None:
2730
async def subscribe(self, channel: str) -> None:
2831
self._consumer_channels.add(channel)
2932
self._consumer.subscribe(topics=self._consumer_channels)
33+
await self._wait_for_assignment()
3034

3135
async def unsubscribe(self, channel: str) -> None:
3236
self._consumer.unsubscribe()
@@ -35,5 +39,13 @@ async def publish(self, channel: str, message: typing.Any) -> None:
3539
await self._producer.send_and_wait(channel, message.encode("utf8"))
3640

3741
async def next_published(self) -> Event:
42+
await self._ready.wait()
3843
message = await self._consumer.getone()
3944
return Event(channel=message.topic, message=message.value.decode("utf8"))
45+
46+
async def _wait_for_assignment(self) -> None:
47+
"""Wait for the consumer to be assigned to the partition."""
48+
while not self._consumer.assignment():
49+
await asyncio.sleep(0.001)
50+
51+
self._ready.set()

tests/test_broadcast.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from broadcaster import Broadcast, BroadcastBackend, Event
9+
from broadcaster._backends.kafka import KafkaBackend
910

1011

1112
class CustomBackend(BroadcastBackend):
@@ -80,7 +81,6 @@ async def test_postgres():
8081
assert event.message == "hello"
8182

8283

83-
@pytest.mark.skip("Deadlock on `next_published`")
8484
@pytest.mark.asyncio
8585
async def test_kafka():
8686
async with Broadcast("kafka://localhost:9092") as broadcast:
@@ -91,6 +91,16 @@ async def test_kafka():
9191
assert event.message == "hello"
9292

9393

94+
@pytest.mark.asyncio
95+
async def test_kafka_multiple_urls():
96+
async with Broadcast(backend=KafkaBackend(urls=["kafka://localhost:9092", "kafka://localhost:9092"])) as broadcast:
97+
async with broadcast.subscribe("chatroom") as subscriber:
98+
await broadcast.publish("chatroom", "hello")
99+
event = await subscriber.get()
100+
assert event.channel == "chatroom"
101+
assert event.message == "hello"
102+
103+
94104
@pytest.mark.asyncio
95105
async def test_custom():
96106
backend = CustomBackend("")

0 commit comments

Comments
 (0)