batch_sync.py
import os
import sys
import json
import time
import logging
import argparse
from datetime import date, datetime

import pandas as pd
import requests
from snowflake.snowpark import Session

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


# Custom JSON encoder to handle date objects
class DateTimeEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, (date, datetime)):
            # Convert to RFC 3339 format; plain dates become midnight UTC
            if isinstance(obj, date) and not isinstance(obj, datetime):
                obj = datetime.combine(obj, datetime.min.time())
            return obj.strftime("%Y-%m-%dT%H:%M:%SZ")
        return super(DateTimeEncoder, self).default(obj)
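
# Example: DateTimeEncoder serializes date/datetime values as RFC 3339 strings,
# e.g. (the key name is illustrative):
#   json.dumps({"signup": date(2024, 1, 15)}, cls=DateTimeEncoder)
#   -> '{"signup": "2024-01-15T00:00:00Z"}'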

def sync_stream_to_batch(conn, project_key, source_stream, id_column, date_columns='', url_columns='', api_credentials_table='BATCH_API_CREDENTIALS'):
    """
    Sync a Snowflake stream to Batch.com

    Args:
        conn: Snowflake connection
        project_key (str): Batch project key
        source_stream (str): Name of the source stream
        id_column (str): Column to use as the customer ID
        date_columns (str): Comma-separated list of date columns
        url_columns (str): Comma-separated list of URL columns
        api_credentials_table (str): Name of the table containing API credentials

    Returns:
        str: Result message
    """
    try:
        # Convert comma-separated lists to sets for easy lookup
        date_columns_set = {col.strip().upper() for col in date_columns.split(',')} if date_columns else set()
        url_columns_set = {col.strip().upper() for col in url_columns.split(',')} if url_columns else set()

        cursor = conn.cursor()

        # Retrieve API credentials (bind the project key rather than
        # interpolating it into the SQL string)
        logger.info(f"Retrieving API credentials for project key: {project_key}")
        cursor.execute(f"SELECT * FROM {api_credentials_table} WHERE project_key = %s", (project_key,))
        creds = cursor.fetchone()
        if not creds:
            error_msg = f"No API credentials found for project key: {project_key}"
            logger.error(error_msg)
            return error_msg

        # Get the column names from cursor description
        col_names = [desc[0] for desc in cursor.description]
        # Find the index of the REST_API_KEY column
        rest_api_key_idx = next((i for i, col in enumerate(col_names) if col.upper() == 'REST_API_KEY'), None)
        if rest_api_key_idx is None:
            error_msg = f"REST_API_KEY column not found in {api_credentials_table}"
            logger.error(error_msg)
            return error_msg

        rest_api_key = creds[rest_api_key_idx]
        api_url = "https://api.batch.com/2.4/profiles/update"

        # Begin a transaction for stream consumption
        cursor.execute("BEGIN TRANSACTION")
        try:
            # Get column information by directly querying the stream
            logger.info(f"Extracting columns directly from stream: {source_stream}")

            # Try to get a sample row to extract column names
            cursor.execute(f"SELECT * FROM {source_stream} LIMIT 1")
            sample_row = cursor.fetchone()
            if not sample_row:
                # No data in stream, commit to mark the position and exit
                cursor.execute("COMMIT")
                message = f"No data found in stream {source_stream}."
                logger.info(message)
                return message

            # Extract column names from cursor description
            column_names = [desc[0] for desc in cursor.description if not desc[0].startswith('METADATA$')]
            if not column_names:
                error_msg = f"No non-metadata columns found in stream {source_stream}"
                logger.error(error_msg)
                cursor.execute("ROLLBACK")
                return error_msg

            logger.info(f"Found {len(column_names)} columns in stream: {column_names}")

            # Verify the ID column exists (case-insensitive search)
            id_column_found = False
            for col in column_names:
                if col.upper() == id_column.upper():
                    id_column = col  # Use the exact case from the stream
                    id_column_found = True
                    break
            if not id_column_found:
                error_msg = f"ID column '{id_column}' not found in stream columns: {column_names}"
                logger.error(error_msg)
                cursor.execute("ROLLBACK")
                return error_msg

            # Build the dynamic SQL query with all columns plus stream metadata
            # Quote the column names to handle case sensitivity and special characters
            columns_str = ', '.join([f'"{col}"' for col in column_names])
            query = f"""
                SELECT {columns_str},
                       METADATA$ACTION,
                       METADATA$ISUPDATE,
                       METADATA$ROW_ID
                FROM {source_stream}
            """
            # Fetch data from the stream
            logger.info(f"Fetching changes from stream {source_stream}")
            cursor.execute(query)
            rows = cursor.fetchall()

            # Get all column names including metadata
            all_columns = column_names + ["METADATA$ACTION", "METADATA$ISUPDATE", "METADATA$ROW_ID"]

            # Convert to pandas DataFrame for easier processing
            changes_df = pd.DataFrame(rows, columns=all_columns)
            if changes_df.empty:
                # Commit the transaction to mark the current stream position
                cursor.execute("COMMIT")
                message = f"No rows found in stream {source_stream} despite having schema."
                logger.info(message)
                return message
        except Exception as e:
            error_msg = f"Error accessing stream {source_stream}: {str(e)}"
            logger.error(error_msg)
            cursor.execute("ROLLBACK")
            return error_msg
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + rest_api_key,
            "X-Batch-Project": project_key
        }

        success_count = 0
        fail_count = 0
        error_logs = []
        user_data_batch = []

        # Process each row in the dataframe
        logger.info(f"Processing {len(changes_df)} change records from {source_stream}")
        for index, row in changes_df.iterrows():
            try:
                # Get the action type (INSERT, UPDATE, or DELETE)
                action = row["METADATA$ACTION"]

                # Skip deleted records as Batch doesn't support deletion
                if action == "DELETE":
                    logger.debug(f"Skipping DELETE action for row {index}")
                    continue

                custom_id = str(row[id_column])

                # Process attributes with proper data typing
                attributes = {}
                for col_name, value in row.items():
                    # Skip metadata columns and the ID column
                    if col_name in ["METADATA$ACTION", "METADATA$ISUPDATE", "METADATA$ROW_ID"] or col_name == id_column:
                        continue
                    if pd.isna(value):
                        continue  # Skip None/NaN values

                    # Convert column name to lowercase for consistency in Batch
                    attr_name = col_name.lower()

                    # Process based on field type, with appropriate attribute name wrapping
                    # Use case-insensitive matching for date and URL columns
                    if col_name.upper() in date_columns_set:
                        # Use date() wrapper for date field names
                        attributes[f"date({attr_name})"] = value
                    elif col_name.upper() in url_columns_set:
                        # Use url() wrapper for URL field names
                        attributes[f"url({attr_name})"] = value
                    else:
                        # Keep all other values with their native types
                        attributes[attr_name] = value

                user_data_batch.append({
                    "identifiers": {
                        "custom_id": custom_id,
                    },
                    "attributes": attributes
                })

                # Send a full batch of 1000 records; any remainder is flushed after the loop
                if len(user_data_batch) == 1000:
                    try:
                        # Use the custom encoder to handle date objects
                        json_data = json.dumps(user_data_batch, cls=DateTimeEncoder)
                        logger.debug(f"Sending batch of {len(user_data_batch)} records to Batch API")
                        response = requests.post(api_url, headers=headers, data=json_data)
                        if response.status_code == 202:
                            success_count += len(user_data_batch)
                            logger.debug(f"Successfully sent {len(user_data_batch)} records")
                        else:
                            fail_count += len(user_data_batch)
                            error_msg = f"Failed for batch starting with custom_id {user_data_batch[0]['identifiers']['custom_id']}: {response.text[:500]}"
                            logger.error(error_msg)
                            error_logs.append(error_msg)
                    except Exception as e:
                        fail_count += len(user_data_batch)
                        error_msg = f"Exception for batch starting with custom_id {user_data_batch[0]['identifiers']['custom_id']}: {str(e)}"
                        logger.error(error_msg)
                        error_logs.append(error_msg)
                    user_data_batch = []
                    time.sleep(1)  # Rate limiting
            except Exception as e:
                fail_count += 1
                error_msg = f"Error processing row {index}: {str(e)}"
                logger.error(error_msg)
                error_logs.append(error_msg)

        # Flush any remaining records that did not fill a complete batch
        if user_data_batch:
            try:
                json_data = json.dumps(user_data_batch, cls=DateTimeEncoder)
                logger.debug(f"Sending final batch of {len(user_data_batch)} records to Batch API")
                response = requests.post(api_url, headers=headers, data=json_data)
                if response.status_code == 202:
                    success_count += len(user_data_batch)
                else:
                    fail_count += len(user_data_batch)
                    error_msg = f"Failed for final batch starting with custom_id {user_data_batch[0]['identifiers']['custom_id']}: {response.text[:500]}"
                    logger.error(error_msg)
                    error_logs.append(error_msg)
            except Exception as e:
                fail_count += len(user_data_batch)
                error_msg = f"Exception for final batch starting with custom_id {user_data_batch[0]['identifiers']['custom_id']}: {str(e)}"
                logger.error(error_msg)
                error_logs.append(error_msg)

        # If everything was successful, commit the transaction to mark the stream as consumed
        if fail_count == 0:
            logger.info("All records processed successfully, committing transaction to consume stream data")
            cursor.execute("COMMIT")
        else:
            # If there were any failures, roll back so we can retry
            logger.warning(f"{fail_count} records failed processing, rolling back transaction")
            cursor.execute("ROLLBACK")

        # Save results to a log table here if desired
        result_message = f"Stream sync complete for {source_stream}: {success_count} records succeeded, {fail_count} failed."
        logger.info(result_message)
        if error_logs:
            error_detail = "\n".join(error_logs)
            logger.warning(f"Errors encountered during sync:\n{error_detail}")
            result_message += "\n" + error_detail

        return result_message
    except Exception as e:
        error_msg = f"Error in sync_stream_to_batch: {str(e)}"
        logger.error(error_msg)
        try:
            cursor.execute("ROLLBACK")
            logger.info("Transaction rolled back due to error")
        except Exception:
            pass
        return error_msg
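
# sync_stream_to_batch() can also be called directly with an existing
# snowflake.connector connection, e.g. (all names below are illustrative):
#   result = sync_stream_to_batch(
#       conn,
#       project_key="MY_BATCH_PROJECT_KEY",
#       source_stream="ANALYTICS.PUBLIC.CUSTOMER_CHANGES_STREAM",
#       id_column="CUSTOMER_ID",
#       date_columns="SIGNUP_DATE,LAST_ORDER_DATE",
#       url_columns="AVATAR_URL",
#   )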

def parse_arguments():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description='Sync Snowflake stream to Batch.com')
    parser.add_argument('--project-key', required=True, help='Batch project key')
    parser.add_argument('--source-stream', required=True, help='Source stream name (format: DATABASE.SCHEMA.STREAM)')
    parser.add_argument('--id-column', required=True, help='Column to use as the custom ID')
    parser.add_argument('--date-columns', default='', help='Comma-separated list of date columns')
    parser.add_argument('--url-columns', default='', help='Comma-separated list of URL columns')
    parser.add_argument('--api-credentials-table', default='BATCH_API_CREDENTIALS', help='Table containing API credentials')
    parser.add_argument('--connection-parameters', default='', help='JSON string of connection parameters')
    return parser.parse_args()
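
# Example invocation (values are illustrative):
#   python batch_sync.py \
#       --project-key MY_BATCH_PROJECT_KEY \
#       --source-stream ANALYTICS.PUBLIC.CUSTOMER_CHANGES_STREAM \
#       --id-column CUSTOMER_ID \
#       --date-columns SIGNUP_DATE,LAST_ORDER_DATE \
#       --url-columns AVATAR_URL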

def create_session(connection_params=None):
    """
    Create a Snowpark session using either connection parameters or environment variables

    Args:
        connection_params (dict, optional): Connection parameters as a dictionary

    Returns:
        snowflake.snowpark.Session: Snowpark session
    """
    try:
        # If connection parameters are provided, use them
        if connection_params:
            logger.info("Creating session using provided connection parameters")
            return Session.builder.configs(connection_params).create()

        # For Snowpark Container Services, first check if we're running in a
        # container environment with default connection parameters
        logger.info("Checking for Snowflake connection method")

        # In a Container Service, a plain connect() may succeed using the
        # container's default credentials
        try:
            from snowflake.connector import connect
            try:
                logger.info("Attempting to use default container service connection")
                conn = connect(application='BATCH_CONNECTOR')
                return Session._from_connection(conn)
            except Exception as container_err:
                logger.info(f"Container service connection not available: {str(container_err)}")
        except ImportError:
            logger.info("snowflake.connector module not available for direct connection")

        # Fall back to environment variables
        logger.info("Creating session using environment variables")
        connection_parameters = {
            "account": os.environ.get("SNOWFLAKE_ACCOUNT"),
            "user": os.environ.get("SNOWFLAKE_USER"),
            "password": os.environ.get("SNOWFLAKE_PASSWORD"),
            "private_key_path": os.environ.get("SNOWFLAKE_PRIVATE_KEY_PATH"),
            "private_key_passphrase": os.environ.get("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE"),
            "role": os.environ.get("SNOWFLAKE_ROLE", "ACCOUNTADMIN"),
            "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE"),
            "database": os.environ.get("SNOWFLAKE_DATABASE"),
            "schema": os.environ.get("SNOWFLAKE_SCHEMA")
        }
        # Filter out None values
        connection_parameters = {k: v for k, v in connection_parameters.items() if v is not None}
        if not connection_parameters.get("account") or not (connection_parameters.get("user") or connection_parameters.get("private_key_path")):
            logger.warning("No account or authentication method specified in environment variables")
        return Session.builder.configs(connection_parameters).create()
    except Exception as e:
        logger.error(f"Error creating Snowpark session: {str(e)}")
        raise
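
# Example environment for create_session() when no --connection-parameters
# JSON is supplied (illustrative values):
#   export SNOWFLAKE_ACCOUNT=myorg-myaccount
#   export SNOWFLAKE_USER=SYNC_USER
#   export SNOWFLAKE_PASSWORD=********
#   export SNOWFLAKE_WAREHOUSE=SYNC_WH
#   export SNOWFLAKE_DATABASE=ANALYTICS
#   export SNOWFLAKE_SCHEMA=PUBLIC
#
# Alternatively, pass the same keys as a JSON string:
#   --connection-parameters '{"account": "myorg-myaccount", "user": "SYNC_USER", "password": "...", "warehouse": "SYNC_WH", "database": "ANALYTICS", "schema": "PUBLIC"}'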

def main():
    """Main function to execute the script"""
    try:
        # Parse command line arguments
        args = parse_arguments()

        # Create Snowflake session
        connection_params = json.loads(args.connection_parameters) if args.connection_parameters else None
        session = create_session(connection_params)

        # Get the underlying connection
        conn = session._conn._conn

        # Call the sync function
        logger.info("Starting stream sync process")
        result = sync_stream_to_batch(
            conn,
            args.project_key,
            args.source_stream,
            args.id_column,
            args.date_columns,
            args.url_columns,
            args.api_credentials_table
        )
        logger.info(f"Final result: {result}")

        # Close the session
        session.close()
        return result
    except Exception as e:
        logger.error(f"Error in main function: {str(e)}")
        sys.exit(1)


if __name__ == "__main__":
    main()