1
+ from __future__ import annotations
2
+
1
3
import asyncio
2
4
from contextlib import asynccontextmanager
3
- from typing import (
4
- TYPE_CHECKING ,
5
- Any ,
6
- AsyncGenerator ,
7
- AsyncIterator ,
8
- Dict ,
9
- Optional ,
10
- cast ,
11
- )
5
+ from typing import TYPE_CHECKING , Any , AsyncGenerator , AsyncIterator , cast
12
6
from urllib .parse import urlparse
13
7
14
8
if TYPE_CHECKING : # pragma: no cover
@@ -21,11 +15,7 @@ def __init__(self, channel: str, message: str) -> None:
21
15
self .message = message
22
16
23
17
def __eq__ (self , other : object ) -> bool :
24
- return (
25
- isinstance (other , Event )
26
- and self .channel == other .channel
27
- and self .message == other .message
28
- )
18
+ return isinstance (other , Event ) and self .channel == other .channel and self .message == other .message
29
19
30
20
def __repr__ (self ) -> str :
31
21
return f"Event(channel={ self .channel !r} , message={ self .message !r} )"
@@ -36,14 +26,12 @@ class Unsubscribed(Exception):
36
26
37
27
38
28
class Broadcast :
39
- def __init__ (
40
- self , url : Optional [str ] = None , * , backend : Optional ["BroadcastBackend" ] = None
41
- ) -> None :
29
+ def __init__ (self , url : str | None = None , * , backend : BroadcastBackend | None = None ) -> None :
42
30
assert url or backend , "Either `url` or `backend` must be provided."
43
31
self ._backend = backend or self ._create_backend (cast (str , url ))
44
- self ._subscribers : Dict [str , Any ] = {}
32
+ self ._subscribers : dict [str , set [ asyncio . Queue [ Event | None ]] ] = {}
45
33
46
- def _create_backend (self , url : str ) -> " BroadcastBackend" :
34
+ def _create_backend (self , url : str ) -> BroadcastBackend :
47
35
parsed_url = urlparse (url )
48
36
if parsed_url .scheme in ("redis" , "rediss" ):
49
37
from broadcaster ._backends .redis import RedisBackend
@@ -66,7 +54,7 @@ def _create_backend(self, url: str) -> "BroadcastBackend":
66
54
return MemoryBackend (url )
67
55
raise ValueError (f"Unsupported backend: { parsed_url .scheme } " )
68
56
69
- async def __aenter__ (self ) -> " Broadcast" :
57
+ async def __aenter__ (self ) -> Broadcast :
70
58
await self .connect ()
71
59
return self
72
60
@@ -94,8 +82,8 @@ async def publish(self, channel: str, message: Any) -> None:
94
82
await self ._backend .publish (channel , message )
95
83
96
84
@asynccontextmanager
97
- async def subscribe (self , channel : str ) -> AsyncIterator [" Subscriber" ]:
98
- queue : asyncio .Queue = asyncio .Queue ()
85
+ async def subscribe (self , channel : str ) -> AsyncIterator [Subscriber ]:
86
+ queue : asyncio .Queue [ Event | None ] = asyncio .Queue ()
99
87
100
88
try :
101
89
if not self ._subscribers .get (channel ):
@@ -114,10 +102,10 @@ async def subscribe(self, channel: str) -> AsyncIterator["Subscriber"]:
114
102
115
103
116
104
class Subscriber :
117
- def __init__ (self , queue : asyncio .Queue ) -> None :
105
+ def __init__ (self , queue : asyncio .Queue [ Event | None ] ) -> None :
118
106
self ._queue = queue
119
107
120
- async def __aiter__ (self ) -> Optional [ AsyncGenerator ] :
108
+ async def __aiter__ (self ) -> AsyncGenerator [ Event | None , None ] | None :
121
109
try :
122
110
while True :
123
111
yield await self .get ()
0 commit comments