|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 | 3 | import asyncio
|
| 4 | +import inspect |
| 5 | +import sys |
4 | 6 | import typing
|
5 | 7 |
|
6 | 8 | from redis import asyncio as redis
|
| 9 | +from pydantic import BaseModel |
7 | 10 |
|
8 | 11 | from .._base import Event
|
9 | 12 | from .base import BroadcastBackend
|
@@ -108,3 +111,57 @@ async def next_published(self) -> Event:
|
108 | 111 | channel=stream.decode("utf-8"),
|
109 | 112 | message=message.get(b"message", b"").decode("utf-8"),
|
110 | 113 | )
|
| 114 | + |
| 115 | + |
| 116 | +class RedisPydanticStreamBackend(RedisStreamBackend): |
| 117 | + """Redis Stream backend for broadcasting messages using Pydantic models.""" |
| 118 | + |
| 119 | + def __init__(self: typing.Self, url: str) -> None: |
| 120 | + """Create a new Redis Stream backend.""" |
| 121 | + url = url.replace("redis-pydantic-stream", "redis", 1) |
| 122 | + self.streams: dict[bytes | str | memoryview, int | bytes | str | memoryview] = {} |
| 123 | + self._ready = asyncio.Event() |
| 124 | + self._producer = redis.Redis.from_url(url) |
| 125 | + self._consumer = redis.Redis.from_url(url) |
| 126 | + self._module_cache: dict[str, type(BaseModel)] = {} |
| 127 | + |
| 128 | + def _build_module_cache(self: typing.Self) -> None: |
| 129 | + """Build a cache of Pydantic models.""" |
| 130 | + modules = list(sys.modules.keys()) |
| 131 | + for module_name in modules: |
| 132 | + for _, obj in inspect.getmembers(sys.modules[module_name]): |
| 133 | + if inspect.isclass(obj) and issubclass(obj, BaseModel): |
| 134 | + self._module_cache[obj.__name__] = obj |
| 135 | + |
| 136 | + async def publish(self: typing.Self, channel: str, message: BaseModel) -> None: |
| 137 | + """Publish a message to a channel.""" |
| 138 | + msg_type: str = message.__class__.__name__ |
| 139 | + message_json: str = message.model_dump_json() |
| 140 | + await self._producer.xadd(channel, {"msg_type": msg_type, "message": message_json}) |
| 141 | + |
| 142 | + async def wait_for_messages(self: typing.Self) -> list[StreamMessageType]: |
| 143 | + """Wait for messages to be published.""" |
| 144 | + await self._ready.wait() |
| 145 | + self._build_module_cache() |
| 146 | + messages = None |
| 147 | + while not messages: |
| 148 | + messages = await self._consumer.xread(self.streams, count=1, block=100) |
| 149 | + return messages |
| 150 | + |
| 151 | + async def next_published(self: typing.Self) -> Event | None: |
| 152 | + """Get the next published message.""" |
| 153 | + messages = await self.wait_for_messages() |
| 154 | + stream, events = messages[0] |
| 155 | + _msg_id, message = events[0] |
| 156 | + self.streams[stream.decode("utf-8")] = _msg_id.decode("utf-8") |
| 157 | + msg_type = message.get(b"msg_type", b"").decode("utf-8") |
| 158 | + message_data = message.get(b"message", b"").decode("utf-8") |
| 159 | + message_obj: BaseModel | None = None |
| 160 | + if msg_type in self._module_cache: |
| 161 | + message_obj = self._module_cache[msg_type].model_validate_json(message_data) |
| 162 | + if not message_obj: |
| 163 | + return None |
| 164 | + return Event( |
| 165 | + channel=stream.decode("utf-8"), |
| 166 | + message=message_obj, |
| 167 | + ) |
0 commit comments