@@ -116,30 +116,34 @@ async def next_published(self) -> Event:
116
116
class RedisPydanticStreamBackend (RedisStreamBackend ):
117
117
"""Redis Stream backend for broadcasting messages using Pydantic models."""
118
118
119
- def __init__ (self : typing . Self , url : str ) -> None :
119
+ def __init__ (self , url : str ) -> None :
120
120
"""Create a new Redis Stream backend."""
121
121
url = url .replace ("redis-pydantic-stream" , "redis" , 1 )
122
122
self .streams : dict [bytes | str | memoryview , int | bytes | str | memoryview ] = {}
123
123
self ._ready = asyncio .Event ()
124
124
self ._producer = redis .Redis .from_url (url )
125
125
self ._consumer = redis .Redis .from_url (url )
126
- self ._module_cache : dict [str , type ( BaseModel ) ] = {}
126
+ self ._module_cache : dict [str , type [ BaseModel ] ] = {}
127
127
128
- def _build_module_cache (self : typing . Self ) -> None :
128
+ def _build_module_cache (self ) -> None :
129
129
"""Build a cache of Pydantic models."""
130
130
modules = list (sys .modules .keys ())
131
131
for module_name in modules :
132
132
for _ , obj in inspect .getmembers (sys .modules [module_name ]):
133
133
if inspect .isclass (obj ) and issubclass (obj , BaseModel ):
134
134
self ._module_cache [obj .__name__ ] = obj
135
135
136
- async def publish (self : typing . Self , channel : str , message : BaseModel ) -> None :
136
+ async def publish (self , channel : str , message : BaseModel ) -> None :
137
137
"""Publish a message to a channel."""
138
138
msg_type : str = message .__class__ .__name__
139
+
140
+ if msg_type not in self ._module_cache :
141
+ self ._module_cache [msg_type ] = message .__class__
142
+
139
143
message_json : str = message .model_dump_json ()
140
144
await self ._producer .xadd (channel , {"msg_type" : msg_type , "message" : message_json })
141
145
142
- async def wait_for_messages (self : typing . Self ) -> list [StreamMessageType ]:
146
+ async def wait_for_messages (self ) -> list [StreamMessageType ]:
143
147
"""Wait for messages to be published."""
144
148
await self ._ready .wait ()
145
149
self ._build_module_cache ()
@@ -148,7 +152,7 @@ async def wait_for_messages(self: typing.Self) -> list[StreamMessageType]:
148
152
messages = await self ._consumer .xread (self .streams , count = 1 , block = 100 )
149
153
return messages
150
154
151
- async def next_published (self : typing . Self ) -> Event | None :
155
+ async def next_published (self ) -> Event :
152
156
"""Get the next published message."""
153
157
messages = await self .wait_for_messages ()
154
158
stream , events = messages [0 ]
@@ -160,7 +164,7 @@ async def next_published(self: typing.Self) -> Event | None:
160
164
if msg_type in self ._module_cache :
161
165
message_obj = self ._module_cache [msg_type ].model_validate_json (message_data )
162
166
if not message_obj :
163
- return None
167
+ return Event ( stream . decode ( "utf-8" ), message_data )
164
168
return Event (
165
169
channel = stream .decode ("utf-8" ),
166
170
message = message_obj ,
0 commit comments