Skip to content

Commit 32c7d70

Browse files
committed
Add 'setup' parameter to pool.create_pool(); issue #4.
1 parent d316004 commit 32c7d70

File tree

2 files changed

+50
-17
lines changed

2 files changed

+50
-17
lines changed

asyncpg/pool.py

+37-17
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,13 @@ class Pool:
2626
'_connect_args', '_connect_kwargs',
2727
'_working_addr', '_working_opts',
2828
'_con_count', '_max_queries', '_connections',
29-
'_initialized', '_closed')
29+
'_initialized', '_closed', '_setup')
3030

3131
def __init__(self, *connect_args,
3232
min_size,
3333
max_size,
3434
max_queries,
35+
setup,
3536
loop,
3637
**connect_kwargs):
3738

@@ -55,6 +56,8 @@ def __init__(self, *connect_args,
5556
self._maxsize = max_size
5657
self._max_queries = max_queries
5758

59+
self._setup = setup
60+
5861
self._connect_args = connect_args
5962
self._connect_kwargs = connect_kwargs
6063

@@ -65,10 +68,9 @@ def __init__(self, *connect_args,
6568

6669
self._closed = False
6770

68-
async def _new_connection(self, timeout=None):
71+
async def _new_connection(self):
6972
if self._working_addr is None:
7073
con = await connection.connect(*self._connect_args,
71-
timeout=timeout,
7274
loop=self._loop,
7375
**self._connect_kwargs)
7476

@@ -83,7 +85,6 @@ async def _new_connection(self, timeout=None):
8385
host, port = self._working_addr
8486

8587
con = await connection.connect(host=host, port=port,
86-
timeout=timeout,
8788
loop=self._loop,
8889
**self._working_opts)
8990

@@ -134,27 +135,40 @@ def acquire(self, *, timeout=None):
134135
return PoolAcquireContext(self, timeout)
135136

136137
async def _acquire(self, timeout):
138+
if timeout is None:
139+
return await self._acquire_impl()
140+
else:
141+
return await asyncio.wait_for(self._acquire_impl(),
142+
timeout=timeout,
143+
loop=self._loop)
144+
145+
async def _acquire_impl(self):
137146
self._check_init()
138147

139148
try:
140-
return self._queue.get_nowait()
149+
con = self._queue.get_nowait()
141150
except asyncio.QueueEmpty:
142-
pass
151+
con = None
152+
153+
if con is None:
154+
if self._con_count < self._maxsize:
155+
self._con_count += 1
156+
try:
157+
con = await self._new_connection()
158+
except:
159+
self._con_count -= 1
160+
raise
161+
else:
162+
con = await self._queue.get()
143163

144-
if self._con_count < self._maxsize:
145-
self._con_count += 1
164+
if self._setup is not None:
146165
try:
147-
con = await self._new_connection(timeout=timeout)
166+
await self._setup(con)
148167
except:
149-
self._con_count -= 1
168+
await self.release(con)
150169
raise
151-
return con
152170

153-
if timeout is None:
154-
return await self._queue.get()
155-
else:
156-
return await asyncio.wait_for(self._queue.get(), timeout=timeout,
157-
loop=self._loop)
171+
return con
158172

159173
async def release(self, connection):
160174
"""Release a database connection back to the pool."""
@@ -246,6 +260,7 @@ def create_pool(dsn=None, *,
246260
min_size=10,
247261
max_size=10,
248262
max_queries=50000,
263+
setup=None,
249264
loop=None,
250265
**connect_kwargs):
251266
r"""Create a connection pool.
@@ -281,11 +296,16 @@ def create_pool(dsn=None, *,
281296
:param int max_size: Max number of connections in the pool.
282297
:param int max_queries: Number of queries after a connection is closed
283298
and replaced with a new connection.
299+
:param coroutine setup: A coroutine to initialize a connection right before
300+
it is returned from :meth:`~pool.Pool.acquire`.
301+
An example use case would be to automatically
302+
set up notifications listeners for all connections
303+
of a pool.
284304
:param loop: An asyncio event loop instance. If ``None``, the default
285305
event loop will be used.
286306
:return: An instance of :class:`~asyncpg.pool.Pool`.
287307
"""
288308
return Pool(dsn,
289309
min_size=min_size, max_size=max_size,
290-
max_queries=max_queries, loop=loop,
310+
max_queries=max_queries, loop=loop, setup=setup,
291311
**connect_kwargs)

tests/test_pool.py

+13
Original file line numberDiff line numberDiff line change
@@ -81,3 +81,16 @@ async def worker():
8181
tasks = [worker() for _ in range(n)]
8282
await asyncio.gather(*tasks, loop=self.loop)
8383
await pool.close()
84+
85+
async def test_pool_06(self):
86+
fut = asyncio.Future(loop=self.loop)
87+
88+
async def setup(con):
89+
fut.set_result(con)
90+
91+
async with self.create_pool(database='postgres',
92+
min_size=5, max_size=5,
93+
setup=setup) as pool:
94+
con = await pool.acquire()
95+
96+
self.assertIs(con, await fut)

0 commit comments

Comments
 (0)