@@ -26,12 +26,13 @@ class Pool:
26
26
'_connect_args' , '_connect_kwargs' ,
27
27
'_working_addr' , '_working_opts' ,
28
28
'_con_count' , '_max_queries' , '_connections' ,
29
- '_initialized' , '_closed' )
29
+ '_initialized' , '_closed' , '_setup' )
30
30
31
31
def __init__ (self , * connect_args ,
32
32
min_size ,
33
33
max_size ,
34
34
max_queries ,
35
+ setup ,
35
36
loop ,
36
37
** connect_kwargs ):
37
38
@@ -55,6 +56,8 @@ def __init__(self, *connect_args,
55
56
self ._maxsize = max_size
56
57
self ._max_queries = max_queries
57
58
59
+ self ._setup = setup
60
+
58
61
self ._connect_args = connect_args
59
62
self ._connect_kwargs = connect_kwargs
60
63
@@ -65,10 +68,9 @@ def __init__(self, *connect_args,
65
68
66
69
self ._closed = False
67
70
68
- async def _new_connection (self , timeout = None ):
71
+ async def _new_connection (self ):
69
72
if self ._working_addr is None :
70
73
con = await connection .connect (* self ._connect_args ,
71
- timeout = timeout ,
72
74
loop = self ._loop ,
73
75
** self ._connect_kwargs )
74
76
@@ -83,7 +85,6 @@ async def _new_connection(self, timeout=None):
83
85
host , port = self ._working_addr
84
86
85
87
con = await connection .connect (host = host , port = port ,
86
- timeout = timeout ,
87
88
loop = self ._loop ,
88
89
** self ._working_opts )
89
90
@@ -134,27 +135,40 @@ def acquire(self, *, timeout=None):
134
135
return PoolAcquireContext (self , timeout )
135
136
136
137
async def _acquire (self , timeout ):
138
+ if timeout is None :
139
+ return await self ._acquire_impl ()
140
+ else :
141
+ return await asyncio .wait_for (self ._acquire_impl (),
142
+ timeout = timeout ,
143
+ loop = self ._loop )
144
+
145
+ async def _acquire_impl (self ):
137
146
self ._check_init ()
138
147
139
148
try :
140
- return self ._queue .get_nowait ()
149
+ con = self ._queue .get_nowait ()
141
150
except asyncio .QueueEmpty :
142
- pass
151
+ con = None
152
+
153
+ if con is None :
154
+ if self ._con_count < self ._maxsize :
155
+ self ._con_count += 1
156
+ try :
157
+ con = await self ._new_connection ()
158
+ except :
159
+ self ._con_count -= 1
160
+ raise
161
+ else :
162
+ con = await self ._queue .get ()
143
163
144
- if self ._con_count < self ._maxsize :
145
- self ._con_count += 1
164
+ if self ._setup is not None :
146
165
try :
147
- con = await self ._new_connection ( timeout = timeout )
166
+ await self ._setup ( con )
148
167
except :
149
- self ._con_count -= 1
168
+ await self .release ( con )
150
169
raise
151
- return con
152
170
153
- if timeout is None :
154
- return await self ._queue .get ()
155
- else :
156
- return await asyncio .wait_for (self ._queue .get (), timeout = timeout ,
157
- loop = self ._loop )
171
+ return con
158
172
159
173
async def release (self , connection ):
160
174
"""Release a database connection back to the pool."""
@@ -246,6 +260,7 @@ def create_pool(dsn=None, *,
246
260
min_size = 10 ,
247
261
max_size = 10 ,
248
262
max_queries = 50000 ,
263
+ setup = None ,
249
264
loop = None ,
250
265
** connect_kwargs ):
251
266
r"""Create a connection pool.
@@ -281,11 +296,16 @@ def create_pool(dsn=None, *,
281
296
:param int max_size: Max number of connections in the pool.
282
297
:param int max_queries: Number of queries after a connection is closed
283
298
and replaced with a new connection.
299
+ :param coroutine setup: A coroutine to initialize a connection right before
300
+ it is returned from :meth:`~pool.Pool.acquire`.
301
+ An example use case would be to automatically
302
+ set up notifications listeners for all connections
303
+ of a pool.
284
304
:param loop: An asyncio event loop instance. If ``None``, the default
285
305
event loop will be used.
286
306
:return: An instance of :class:`~asyncpg.pool.Pool`.
287
307
"""
288
308
return Pool (dsn ,
289
309
min_size = min_size , max_size = max_size ,
290
- max_queries = max_queries , loop = loop ,
310
+ max_queries = max_queries , loop = loop , setup = setup ,
291
311
** connect_kwargs )
0 commit comments