1
1
import hashlib
2
2
from collections import defaultdict
3
+ from contextlib import closing
3
4
from dataclasses import dataclass , field
4
5
from functools import partial
5
6
from typing import (
@@ -164,8 +165,7 @@ def get_max_lsn(
164
165
Returns None if the replication slot is empty.
165
166
Does not consume the slot, i.e. messages are not flushed.
166
167
"""
167
- conn = _get_conn (credentials )
168
- try :
168
+ with closing (_get_conn (credentials )) as conn :
169
169
with conn .cursor () as cur :
170
170
pg_version = get_pg_version (cur )
171
171
lsn_field = "lsn" if pg_version >= 100000 else "location"
@@ -181,8 +181,6 @@ def get_max_lsn(
181
181
)
182
182
row = cur .fetchone ()
183
183
return row [0 ] if row else None # type: ignore[no-any-return]
184
- finally :
185
- conn .close ()
186
184
187
185
188
186
def lsn_int_to_hex (lsn : int ) -> str :
@@ -204,16 +202,13 @@ def advance_slot(
204
202
the behavior of that method seems odd when used outside of `consume_stream`.
205
203
"""
206
204
assert upto_lsn > 0
207
- conn = _get_conn (credentials )
208
- try :
205
+ with closing (_get_conn (credentials )) as conn :
209
206
with conn .cursor () as cur :
210
207
# There is unfortunately no way in pg9.6 to manually advance the replication slot
211
208
if get_pg_version (cur ) > 100000 :
212
209
cur .execute (
213
210
f"SELECT * FROM pg_replication_slot_advance('{ slot_name } ', '{ lsn_int_to_hex (upto_lsn )} ');"
214
211
)
215
- finally :
216
- conn .close ()
217
212
218
213
219
214
def _get_conn (
@@ -371,7 +366,7 @@ def get_table_schema(self, msg: RowMessage) -> TTableSchema:
371
366
retained_schema = compare_schemas (last_schema , new_schema )
372
367
self .last_table_schema [table_name ] = retained_schema
373
368
except AssertionError as e :
374
- logger .debug (str (e ))
369
+ logger .info (str (e ))
375
370
raise StopReplication
376
371
377
372
return new_schema
@@ -381,13 +376,14 @@ def _fetch_table_schema_with_sqla(
381
376
) -> TTableSchema :
382
377
"""Last resort function used to fetch the table schema from the database"""
383
378
engine = engine_from_credentials (self .credentials )
379
+ options = self .repl_options [table_name ]
384
380
to_col_schema = partial (
385
- sqla_col_to_column_schema , reflection_level = "full_with_precision"
381
+ sqla_col_to_column_schema ,
382
+ reflection_level = options .get ("reflection_level" , "full" ),
386
383
)
387
384
try :
388
385
metadata = MetaData (schema = schema )
389
386
table = Table (table_name , metadata , autoload_with = engine )
390
- options = self .repl_options [table_name ]
391
387
included_columns = options .get ("included_columns" )
392
388
columns = {
393
389
col ["name" ]: col
@@ -427,6 +423,7 @@ class ItemGenerator:
427
423
start_lsn : int
428
424
repl_options : DefaultDict [str , ReplicationOptions ]
429
425
target_batch_size : int = 1000
426
+ keepalive_interval : Optional [int ] = None
430
427
last_commit_lsn : Optional [int ] = field (default = None , init = False )
431
428
generated_all : bool = False
432
429
@@ -438,30 +435,27 @@ def __iter__(self) -> Iterator[TableItems]:
438
435
Maintains LSN of last consumed commit message in object state.
439
436
Advances the slot only when all messages have been consumed.
440
437
"""
441
- conn = get_rep_conn (self .credentials )
442
- consumer = MessageConsumer (
443
- credentials = self .credentials ,
444
- upto_lsn = self .upto_lsn ,
445
- table_qnames = self .table_qnames ,
446
- repl_options = self .repl_options ,
447
- target_batch_size = self .target_batch_size ,
448
- )
449
-
450
- cur = conn .cursor ()
451
- try :
452
- cur .start_replication (slot_name = self .slot_name , start_lsn = self .start_lsn )
453
- cur .consume_stream (consumer )
454
- except StopReplication : # completed batch or reached `upto_lsn`
455
- yield from self .flush_batch (cur , consumer )
456
- finally :
457
- logger .debug (
458
- "Closing connection... last_commit_lsn: %s, generated_all: %s, feedback_ts: %s" ,
459
- self .last_commit_lsn ,
460
- self .generated_all ,
461
- cur .feedback_timestamp ,
462
- )
463
- cur .close ()
464
- conn .close ()
438
+ with closing (get_rep_conn (self .credentials )) as rep_conn :
439
+ with rep_conn .cursor () as rep_cur :
440
+ try :
441
+ consumer = MessageConsumer (
442
+ credentials = self .credentials ,
443
+ upto_lsn = self .upto_lsn ,
444
+ table_qnames = self .table_qnames ,
445
+ repl_options = self .repl_options ,
446
+ target_batch_size = self .target_batch_size ,
447
+ )
448
+ rep_cur .start_replication (self .slot_name , start_lsn = self .start_lsn )
449
+ rep_cur .consume_stream (consumer , self .keepalive_interval )
450
+ except StopReplication : # completed batch or reached `upto_lsn`
451
+ yield from self .flush_batch (rep_cur , consumer )
452
+ finally :
453
+ logger .debug (
454
+ "Closing connection... last_commit_lsn: %s, generated_all: %s, feedback_ts: %s" ,
455
+ self .last_commit_lsn ,
456
+ self .generated_all ,
457
+ rep_cur .feedback_timestamp ,
458
+ )
465
459
466
460
def flush_batch (
467
461
self , cur : ReplicationCursor , consumer : MessageConsumer
@@ -662,6 +656,7 @@ def _actual_column_name(column: DatumMessage) -> str:
662
656
"nullable" ,
663
657
"precision" ,
664
658
"scale" ,
659
+ "timezone" ,
665
660
}
666
661
667
662
@@ -703,6 +698,8 @@ def compare_schemas(last: TTableSchema, new: TTableSchema) -> TTableSchema:
703
698
col_schema ["precision" ] = s1 .get ("precision" , s2 .get ("precision" ))
704
699
if "scale" in s1 or "scale" in s2 :
705
700
col_schema ["scale" ] = s1 .get ("scale" , s2 .get ("scale" ))
701
+ if "timezone" in s1 or "timezone" in s2 :
702
+ col_schema ["timezone" ] = s1 .get ("timezone" , s2 .get ("timezone" ))
706
703
707
704
# Update with the more detailed schema per column
708
705
table_schema ["columns" ][name ] = col_schema
0 commit comments