Skip to content

Commit 79aabf6

Browse files
add TrinoBatchInsert (#15)
Signed-off-by: Erik Erlandson <eerlands@redhat.com>
1 parent 8001770 commit 79aabf6

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

osc_ingest_trino/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@
2323
"upload_directory_to_s3",
2424
"load_credentials_dotenv",
2525
"attach_trino_engine",
26+
"TrinoBatchInsert",
2627
]
2728

osc_ingest_trino/trino_utils.py

+48
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
__all__ = [
77
"attach_trino_engine",
8+
"TrinoBatchInsert",
89
]
910

1011
def attach_trino_engine(env_var_prefix = 'TRINO', catalog = None, schema = None, verbose = False):
@@ -31,3 +32,50 @@ def attach_trino_engine(env_var_prefix = 'TRINO', catalog = None, schema = None,
3132
connection = engine.connect()
3233
return engine
3334

35+
class TrinoBatchInsert:
    """Callable that inserts rows in batches with multi-row ``insert`` statements.

    An instance is meant to be passed as the ``method`` argument of
    ``pandas.DataFrame.to_sql``. Each batch of at most ``batch_size`` rows is
    sent as a single ``insert into <table> values (...), (...)`` statement.

    Args:
        catalog: Optional catalog name prepended to the table name.
        schema: Optional schema name; when truthy it overrides the schema
            carried by the table object handed to ``__call__``.
        batch_size: Maximum number of rows placed in one ``insert`` statement.
        verbose: When true, progress information is printed to stdout.
    """

    def __init__(self, catalog=None, schema=None, batch_size=1000, verbose=False):
        self.catalog = catalog
        self.schema = schema
        self.batch_size = batch_size
        self.verbose = verbose

    # conforms to signature expected by pandas 'callable' value for method kw arg
    # https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.to_sql.html
    # https://pandas.pydata.org/docs/user_guide/io.html#io-sql-method
    def __call__(self, sqltbl, dbcxn, columns, data_iter):
        """Insert every row of ``data_iter`` into ``sqltbl`` in batches.

        Args:
            sqltbl: Table object exposing ``name`` and ``schema`` attributes.
            dbcxn: Connection object exposing ``execute(sql)``.
            columns: Column names; accepted for signature compatibility but
                not used, values are inserted positionally.
            data_iter: Iterable of rows, each row an iterable of values.
        """
        batch = []
        for row in data_iter:
            batch.append(self._sql_row(row))
            # possible alternative: dispatch batches by total batch size in bytes
            if len(batch) >= self.batch_size:
                self._do_insert(dbcxn, sqltbl, batch)
                batch = []
        # flush the final, partially filled batch
        if batch:
            self._do_insert(dbcxn, sqltbl, batch)

    @staticmethod
    def _sql_literal(value):
        """Render one Python value as the text of a SQL literal."""
        if value is None:
            return 'NULL'
        # bool must be tested before the generic fallback: it is an int subclass
        if isinstance(value, bool):
            return 'TRUE' if value else 'FALSE'
        if isinstance(value, str):
            # SQL string literals are single-quoted; an embedded quote is doubled
            return "'" + value.replace("'", "''") + "'"
        # numbers and anything else: rely on str(); the caller is responsible
        # for values whose str() form is not a valid SQL literal
        return str(value)

    @classmethod
    def _sql_row(cls, row):
        """Render one row as a parenthesized, comma-separated SQL value list."""
        # joining the values (instead of str(tuple)) avoids the trailing comma
        # python puts in the repr of a one-element tuple
        return '(' + ', '.join(cls._sql_literal(v) for v in row) + ')'

    def _do_insert(self, dbcxn, sqltbl, batch_rows):
        """Execute a single ``insert`` statement holding all of ``batch_rows``.

        ``batch_rows`` is a list of already rendered ``(v1, v2, ...)`` strings.
        """
        if self.verbose:
            print(f'inserting {len(batch_rows)} records')
        valclause = ',\n'.join(batch_rows)
        sql = f'insert into {self._full_table_name(sqltbl)} values\n{valclause}'
        # could add something that prints only summary here, but
        # generally too much data to print reasonably
        # NOTE(review): a plain string is passed to execute(); newer SQLAlchemy
        # releases may require wrapping it in text() - confirm against the
        # version this package pins
        qres = dbcxn.execute(sql)
        x = qres.fetchall()
        if self.verbose:
            print(f'batch insert result: {x}')

    def _full_table_name(self, sqltbl):
        """Return the table name qualified by schema and catalog when available.

        Raises:
            ValueError: If a catalog is configured but no schema is available,
                since a catalog cannot qualify a table name without a schema.
        """
        # schema - allow override from this class
        schema = self.schema or sqltbl.schema
        if schema:
            name = f'{schema}.{sqltbl.name}'
        elif self.catalog is not None:
            raise ValueError(
                f'catalog "{self.catalog}" was given but no schema is available '
                f'for table "{sqltbl.name}"'
            )
        else:
            # no schema anywhere: leave the name unqualified instead of
            # emitting the literal text "None" as a schema
            name = f'{sqltbl.name}'
        # prepend catalog, if provided
        if self.catalog is not None:
            name = f'{self.catalog}.{name}'
        if self.verbose:
            print(f'constructed fully qualified table name as: "{name}"')
        return name

0 commit comments

Comments
 (0)