Skip to content

Commit bbb1c12

Browse files
Merge pull request #1 from random-things/feat/redis-pydantic-stream-backend
✨ Add `RedisPydanticStream` backend
2 parents a422d8a + 6fd3b75 commit bbb1c12

File tree

5 files changed

+91
-1
lines changed

5 files changed

+91
-1
lines changed

broadcaster/_base.py

+5
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,11 @@ def _create_backend(self, url: str) -> BroadcastBackend:
4343

4444
return RedisStreamBackend(url)
4545

46+
elif parsed_url.scheme == "redis-pydantic-stream":
47+
from broadcaster.backends.redis import RedisPydanticStreamBackend
48+
49+
return RedisPydanticStreamBackend(url)
50+
4651
elif parsed_url.scheme in ("postgres", "postgresql"):
4752
from broadcaster.backends.postgres import PostgresBackend
4853

broadcaster/backends/redis.py

+57
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
from __future__ import annotations
22

33
import asyncio
4+
import inspect
5+
import sys
46
import typing
57

68
from redis import asyncio as redis
9+
from pydantic import BaseModel
710

811
from .._base import Event
912
from .base import BroadcastBackend
@@ -108,3 +111,57 @@ async def next_published(self) -> Event:
108111
channel=stream.decode("utf-8"),
109112
message=message.get(b"message", b"").decode("utf-8"),
110113
)
114+
115+
116+
class RedisPydanticStreamBackend(RedisStreamBackend):
    """Redis Stream backend that broadcasts Pydantic models.

    Messages are serialized with ``model_dump_json`` on publish; on receipt the
    class name stored alongside the payload is resolved against a cache of every
    ``BaseModel`` subclass found in ``sys.modules``, and the payload is
    re-validated into that model.  Events whose type cannot be resolved are
    dropped (``next_published`` returns ``None``).
    """

    def __init__(self: typing.Self, url: str) -> None:
        """Create a new Redis Stream backend.

        Args:
            url: A ``redis-pydantic-stream://`` URL; the scheme is rewritten to
                plain ``redis://`` before connecting.
        """
        url = url.replace("redis-pydantic-stream", "redis", 1)
        # Maps stream name -> last-seen message id, consumed by xread().
        self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {}
        self._ready = asyncio.Event()
        self._producer = redis.Redis.from_url(url)
        self._consumer = redis.Redis.from_url(url)
        # Class-name -> model-class lookup used to deserialize incoming events.
        # NOTE: keyed by bare class name, so two models with the same name in
        # different modules collide (last one scanned wins).
        self._module_cache: dict[str, type[BaseModel]] = {}

    def _build_module_cache(self: typing.Self) -> None:
        """Scan ``sys.modules`` and cache every importable ``BaseModel`` subclass."""
        # Snapshot the keys: imports running concurrently may mutate sys.modules.
        for module_name in list(sys.modules.keys()):
            module = sys.modules.get(module_name)
            if module is None:
                # Entry removed (or a blocked-import placeholder) since snapshot.
                continue
            try:
                members = inspect.getmembers(module, inspect.isclass)
            except Exception:
                # Some modules raise from module-level __getattr__ during
                # introspection; skip them rather than kill the consumer loop.
                continue
            for _, obj in members:
                if issubclass(obj, BaseModel):
                    self._module_cache[obj.__name__] = obj

    async def publish(self: typing.Self, channel: str, message: BaseModel) -> None:
        """Publish a Pydantic model to a channel.

        The model's class name travels with the JSON payload so the consumer
        can pick the right model to validate against.
        """
        msg_type: str = message.__class__.__name__
        message_json: str = message.model_dump_json()
        await self._producer.xadd(channel, {"msg_type": msg_type, "message": message_json})

    async def wait_for_messages(self: typing.Self) -> list[StreamMessageType]:
        """Block until at least one message is available on a subscribed stream."""
        await self._ready.wait()
        messages = None
        while not messages:
            # block=100 bounds each wait so newly-subscribed streams are picked up.
            messages = await self._consumer.xread(self.streams, count=1, block=100)
        return messages

    async def next_published(self: typing.Self) -> Event | None:
        """Return the next published event, or ``None`` if its type is unknown.

        Advances the per-stream cursor before deserializing, so an unresolvable
        message is skipped rather than redelivered forever.
        """
        messages = await self.wait_for_messages()
        stream, events = messages[0]
        _msg_id, message = events[0]
        self.streams[stream.decode("utf-8")] = _msg_id.decode("utf-8")
        msg_type = message.get(b"msg_type", b"").decode("utf-8")
        message_data = message.get(b"message", b"").decode("utf-8")
        if msg_type not in self._module_cache:
            # Rebuild only on a miss: models defined after startup are still
            # found, without rescanning every module for every message.
            self._build_module_cache()
        message_obj: BaseModel | None = None
        model_cls = self._module_cache.get(msg_type)
        if model_cls is not None:
            message_obj = model_cls.model_validate_json(message_data)
        if message_obj is None:
            return None
        return Event(
            channel=stream.decode("utf-8"),
            message=message_obj,
        )

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ dependencies = [
3535
redis = ["redis"]
3636
postgres = ["asyncpg"]
3737
kafka = ["aiokafka"]
38+
pydantic = ["pydantic", "redis"]
3839
test = ["pytest", "pytest-asyncio"]
3940

4041
[project.urls]

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
-e .[redis,postgres,kafka]
1+
-e .[redis,postgres,kafka,pydantic]
22

33
# Documentation
44
mkdocs==1.5.3

tests/test_broadcast.py

+27
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@
55

66
import pytest
77

8+
from pydantic import BaseModel
9+
810
from broadcaster import Broadcast, BroadcastBackend, Event
911
from broadcaster.backends.kafka import KafkaBackend
1012

1113

14+
class PydanticEvent(BaseModel):
    """Minimal Pydantic model used to exercise the redis-pydantic-stream backend."""

    # Event name, e.g. "on_message".
    event: str
    # Arbitrary string payload carried with the event.
    data: str
17+
18+
1219
class CustomBackend(BroadcastBackend):
1320
def __init__(self, url: str):
1421
self._subscribed: set[str] = set()
@@ -71,6 +78,26 @@ async def test_redis_stream():
7178
assert event.message == "hello"
7279

7380

81+
@pytest.mark.asyncio
async def test_redis_pydantic_stream():
    """Round-trip a Pydantic model through the redis-pydantic-stream backend.

    Requires a Redis server listening on localhost:6379 (integration test).
    Publishes on two different channels to check that the backend tracks
    multiple streams and rehydrates the message back into the original model
    class on each of them.
    """
    async with Broadcast("redis-pydantic-stream://localhost:6379") as broadcast:
        async with broadcast.subscribe("chatroom") as subscriber:
            message = PydanticEvent(event="on_message", data="hello")
            await broadcast.publish("chatroom", message)
            event = await subscriber.get()
            assert event.channel == "chatroom"
            # The backend must deserialize back into the concrete model class.
            assert isinstance(event.message, PydanticEvent)
            assert event.message.event == message.event
            assert event.message.data == message.data
        # Second channel: verifies the stream-cursor bookkeeping handles
        # more than one subscribed stream.
        async with broadcast.subscribe("chatroom1") as subscriber:
            await broadcast.publish("chatroom1", message)
            event = await subscriber.get()
            assert event.channel == "chatroom1"
            assert isinstance(event.message, PydanticEvent)
            assert event.message.event == message.event
            assert event.message.data == message.data
99+
100+
74101
@pytest.mark.asyncio
75102
async def test_postgres():
76103
async with Broadcast("postgres://postgres:postgres@localhost:5432/broadcaster") as broadcast:

0 commit comments

Comments
 (0)