Skip to content

Commit 5d6790e

Browse files
author
Nicolas ESTRADA
committed
fix: various fixes related to pyarrow backends
- refactor: changing closing semantics for db conns (using contextlib.closing) - fix: it is necessary sometimes to use the same reflection level as with the initial snapshot for arrow schemas - fix: timezone flag is now an acceptable seamless schema migration - fix: aligned precision for fixed integer types to match the ones inferred from the sql_database source (I guess to account for signed values) - chore: removed test case with changing the precision of a byte array with pyarrow (absurd one to begin with and no longer possible with the new rows_to_arrow implementation)
1 parent a591618 commit 5d6790e

File tree

4 files changed

+47
-46
lines changed

4 files changed

+47
-46
lines changed

sources/pg_legacy_replication/helpers.py

+32-35
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import hashlib
22
from collections import defaultdict
3+
from contextlib import closing
34
from dataclasses import dataclass, field
45
from functools import partial
56
from typing import (
@@ -164,8 +165,7 @@ def get_max_lsn(
164165
Returns None if the replication slot is empty.
165166
Does not consume the slot, i.e. messages are not flushed.
166167
"""
167-
conn = _get_conn(credentials)
168-
try:
168+
with closing(_get_conn(credentials)) as conn:
169169
with conn.cursor() as cur:
170170
pg_version = get_pg_version(cur)
171171
lsn_field = "lsn" if pg_version >= 100000 else "location"
@@ -181,8 +181,6 @@ def get_max_lsn(
181181
)
182182
row = cur.fetchone()
183183
return row[0] if row else None # type: ignore[no-any-return]
184-
finally:
185-
conn.close()
186184

187185

188186
def lsn_int_to_hex(lsn: int) -> str:
@@ -204,16 +202,13 @@ def advance_slot(
204202
the behavior of that method seems odd when used outside of `consume_stream`.
205203
"""
206204
assert upto_lsn > 0
207-
conn = _get_conn(credentials)
208-
try:
205+
with closing(_get_conn(credentials)) as conn:
209206
with conn.cursor() as cur:
210207
# There is unfortunately no way in pg9.6 to manually advance the replication slot
211208
if get_pg_version(cur) > 100000:
212209
cur.execute(
213210
f"SELECT * FROM pg_replication_slot_advance('{slot_name}', '{lsn_int_to_hex(upto_lsn)}');"
214211
)
215-
finally:
216-
conn.close()
217212

218213

219214
def _get_conn(
@@ -371,7 +366,7 @@ def get_table_schema(self, msg: RowMessage) -> TTableSchema:
371366
retained_schema = compare_schemas(last_schema, new_schema)
372367
self.last_table_schema[table_name] = retained_schema
373368
except AssertionError as e:
374-
logger.debug(str(e))
369+
logger.info(str(e))
375370
raise StopReplication
376371

377372
return new_schema
@@ -381,13 +376,14 @@ def _fetch_table_schema_with_sqla(
381376
) -> TTableSchema:
382377
"""Last resort function used to fetch the table schema from the database"""
383378
engine = engine_from_credentials(self.credentials)
379+
options = self.repl_options[table_name]
384380
to_col_schema = partial(
385-
sqla_col_to_column_schema, reflection_level="full_with_precision"
381+
sqla_col_to_column_schema,
382+
reflection_level=options.get("reflection_level", "full"),
386383
)
387384
try:
388385
metadata = MetaData(schema=schema)
389386
table = Table(table_name, metadata, autoload_with=engine)
390-
options = self.repl_options[table_name]
391387
included_columns = options.get("included_columns")
392388
columns = {
393389
col["name"]: col
@@ -427,6 +423,7 @@ class ItemGenerator:
427423
start_lsn: int
428424
repl_options: DefaultDict[str, ReplicationOptions]
429425
target_batch_size: int = 1000
426+
keepalive_interval: Optional[int] = None
430427
last_commit_lsn: Optional[int] = field(default=None, init=False)
431428
generated_all: bool = False
432429

@@ -438,30 +435,27 @@ def __iter__(self) -> Iterator[TableItems]:
438435
Maintains LSN of last consumed commit message in object state.
439436
Advances the slot only when all messages have been consumed.
440437
"""
441-
conn = get_rep_conn(self.credentials)
442-
consumer = MessageConsumer(
443-
credentials=self.credentials,
444-
upto_lsn=self.upto_lsn,
445-
table_qnames=self.table_qnames,
446-
repl_options=self.repl_options,
447-
target_batch_size=self.target_batch_size,
448-
)
449-
450-
cur = conn.cursor()
451-
try:
452-
cur.start_replication(slot_name=self.slot_name, start_lsn=self.start_lsn)
453-
cur.consume_stream(consumer)
454-
except StopReplication: # completed batch or reached `upto_lsn`
455-
yield from self.flush_batch(cur, consumer)
456-
finally:
457-
logger.debug(
458-
"Closing connection... last_commit_lsn: %s, generated_all: %s, feedback_ts: %s",
459-
self.last_commit_lsn,
460-
self.generated_all,
461-
cur.feedback_timestamp,
462-
)
463-
cur.close()
464-
conn.close()
438+
with closing(get_rep_conn(self.credentials)) as rep_conn:
439+
with rep_conn.cursor() as rep_cur:
440+
try:
441+
consumer = MessageConsumer(
442+
credentials=self.credentials,
443+
upto_lsn=self.upto_lsn,
444+
table_qnames=self.table_qnames,
445+
repl_options=self.repl_options,
446+
target_batch_size=self.target_batch_size,
447+
)
448+
rep_cur.start_replication(self.slot_name, start_lsn=self.start_lsn)
449+
rep_cur.consume_stream(consumer, self.keepalive_interval)
450+
except StopReplication: # completed batch or reached `upto_lsn`
451+
yield from self.flush_batch(rep_cur, consumer)
452+
finally:
453+
logger.debug(
454+
"Closing connection... last_commit_lsn: %s, generated_all: %s, feedback_ts: %s",
455+
self.last_commit_lsn,
456+
self.generated_all,
457+
rep_cur.feedback_timestamp,
458+
)
465459

466460
def flush_batch(
467461
self, cur: ReplicationCursor, consumer: MessageConsumer
@@ -662,6 +656,7 @@ def _actual_column_name(column: DatumMessage) -> str:
662656
"nullable",
663657
"precision",
664658
"scale",
659+
"timezone",
665660
}
666661

667662

@@ -703,6 +698,8 @@ def compare_schemas(last: TTableSchema, new: TTableSchema) -> TTableSchema:
703698
col_schema["precision"] = s1.get("precision", s2.get("precision"))
704699
if "scale" in s1 or "scale" in s2:
705700
col_schema["scale"] = s1.get("scale", s2.get("scale"))
701+
if "timezone" in s1 or "timezone" in s2:
702+
col_schema["timezone"] = s1.get("timezone", s2.get("timezone"))
706703

707704
# Update with the more detailed schema per column
708705
table_schema["columns"][name] = col_schema

sources/pg_legacy_replication/schema_types.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@
6767
"""Maps decoderbuf's datum msg type to dlt type."""
6868

6969
_FIXED_PRECISION_TYPES: Dict[int, Tuple[int, Optional[int]]] = {
70-
21: (16, None), # smallint
71-
23: (32, None), # integer
70+
21: (32, None), # smallint
71+
23: (64, None), # integer
7272
20: (64, None), # bigint
73-
700: (32, None), # real
73+
700: (64, None), # real
7474
}
7575
"""Dict for fixed precision types"""
7676

tests/pg_legacy_replication/cases.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
"col4_precision": "2022-05-23T13:26:46.167231+00:00",
4141
"col5_precision": "string data 2 \n \r \x8e 🦆",
4242
"col6_precision": Decimal("2323.34"),
43-
"col7_precision": b"binary data 2 \n \r \x8e",
43+
# "col7_precision": b"binary data 2 \n \r \x8e", # FIXME This is no longer possible in pyarrow and it's absurd to begin with
4444
"col11_precision": "13:26:45.176451",
4545
}
4646
TABLE_UPDATE: List[TColumnSchema] = [
@@ -86,12 +86,12 @@
8686
"scale": 2,
8787
"nullable": False,
8888
},
89-
{
90-
"name": "col7_precision",
91-
"data_type": "binary",
92-
"precision": 19,
93-
"nullable": False,
94-
},
89+
# {
90+
# "name": "col7_precision",
91+
# "data_type": "binary",
92+
# "precision": 19,
93+
# "nullable": False,
94+
# }, # FIXME See comment above
9595
{"name": "col11_precision", "data_type": "time", "precision": 6, "nullable": False},
9696
]
9797

@@ -987,6 +987,7 @@ class SchemaChoice(IntEnum):
987987
"name": "filled_at",
988988
"nullable": True,
989989
"data_type": "timestamp",
990+
"timezone": True,
990991
"precision": 6,
991992
},
992993
"_pg_lsn": {"name": "_pg_lsn", "nullable": True, "data_type": "bigint"},

tests/pg_legacy_replication/test_pg_replication.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -298,11 +298,14 @@ def items(data):
298298
if init_load and give_hints:
299299
snapshot.items.apply_hints(columns=column_schema)
300300

301+
repl_options = {"items": {"backend": backend}}
302+
if give_hints:
303+
repl_options["items"]["column_hints"] = column_schema
301304
changes = replication_source(
302305
slot_name=slot_name,
303306
schema=src_pl.dataset_name,
304307
table_names="items",
305-
repl_options={"items": {"backend": backend}},
308+
repl_options=repl_options,
306309
)
307310
changes.items.apply_hints(
308311
write_disposition="merge", primary_key="col1", columns=merge_hints

0 commit comments

Comments
 (0)