Fix snowflake uploader array bind variable #382

Merged · 3 commits · Feb 12, 2025
11 changes: 5 additions & 6 deletions CHANGELOG.md
@@ -1,15 +1,14 @@
-## 0.5.2-dev1
+## 0.5.2
 
-### Enchancements
+### Enhancements
 
 * **Only embed elements with text** - Only embed elements with text to avoid errors from embedders and optimize calls to APIs.
+* **Improved google drive precheck mechanism**
+* **Added integration tests for google drive precheck and connector**
 
-## 0.5.2-dev0
-
-### Enhancements
+### Fixes
 
-* **Only embed elements with text** - Only embed elements with text to avoid errors from embedders and optimize calls to APIs.
 * **Fix Snowflake Uploader error with array variable binding**
 
 ## 0.5.1

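For context on the connector change below: Snowflake's Python connector binds query parameters as scalars, so passing a Python list (for example an `embeddings` vector) as a bind variable fails. The fix serializes list-valued columns to JSON strings on the client and has Snowflake rebuild them server-side with PARSE_JSON. A minimal sketch of the idea, assuming a `snowflake.connector` cursor and a hypothetical `elements` table with a VARIANT `embeddings` column:

import json

def insert_element(cursor, element_id: str, embeddings: list[float]) -> None:
    # Snowflake does not allow expressions such as PARSE_JSON inside a
    # VALUES clause, so the row goes in via INSERT ... SELECT instead.
    stmt = "INSERT INTO elements (id, embeddings) SELECT ?, PARSE_JSON(?)"
    # The list is bound as a single JSON string, a scalar the driver accepts.
    cursor.execute(stmt, (element_id, json.dumps(embeddings)))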
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
@@ -1 +1 @@
-__version__ = "0.5.2-dev1" # pragma: no cover
+__version__ = "0.5.2" # pragma: no cover
56 changes: 53 additions & 3 deletions unstructured_ingest/v2/processes/connectors/sql/snowflake.py
@@ -1,6 +1,7 @@
+import json
 from contextlib import contextmanager
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Generator, Optional
+from typing import TYPE_CHECKING, Any, Generator, Optional
 
 import numpy as np
 import pandas as pd
@@ -15,6 +16,7 @@
     SourceRegistryEntry,
 )
 from unstructured_ingest.v2.processes.connectors.sql.sql import (
+    _DATE_COLUMNS,
     SQLAccessConfig,
     SqlBatchFileData,
     SQLConnectionConfig,
@@ -26,6 +28,7 @@
     SQLUploaderConfig,
     SQLUploadStager,
     SQLUploadStagerConfig,
+    parse_date_string,
 )
 
 if TYPE_CHECKING:
@@ -34,6 +37,17 @@
 
 CONNECTOR_TYPE = "snowflake"
 
+_ARRAY_COLUMNS = (
+    "embeddings",
+    "languages",
+    "link_urls",
+    "link_texts",
+    "sent_from",
+    "sent_to",
+    "emphasized_text_contents",
+    "emphasized_text_tags",
+)
+
 
 class SnowflakeAccessConfig(SQLAccessConfig):
     password: Optional[str] = Field(default=None, description="DB password")
@@ -160,6 +174,42 @@ class SnowflakeUploader(SQLUploader):
     connector_type: str = CONNECTOR_TYPE
     values_delimiter: str = "?"
 
+    def prepare_data(
+        self, columns: list[str], data: tuple[tuple[Any, ...], ...]
+    ) -> list[tuple[Any, ...]]:
+        output = []
+        for row in data:
+            parsed = []
+            for column_name, value in zip(columns, row):
+                if column_name in _DATE_COLUMNS:
+                    if value is None or pd.isna(value):  # pandas is nan
+                        parsed.append(None)
+                    else:
+                        parsed.append(parse_date_string(value))
+                elif column_name in _ARRAY_COLUMNS:
+                    if not isinstance(value, list) and (
+                        value is None or pd.isna(value)
+                    ):  # pandas is nan
+                        parsed.append(None)
+                    else:
+                        parsed.append(json.dumps(value))
+                else:
+                    parsed.append(value)
+            output.append(tuple(parsed))
+        return output
+
+    def _parse_values(self, columns: list[str]) -> str:
+        return ",".join(
+            [
+                (
+                    f"PARSE_JSON({self.values_delimiter})"
+                    if col in _ARRAY_COLUMNS
+                    else self.values_delimiter
+                )
+                for col in columns
+            ]
+        )
+
     def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
         if self.can_delete():
             self.delete_by_record_id(file_data=file_data)
@@ -173,10 +223,10 @@ def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
         self._fit_to_schema(df=df)
 
         columns = list(df.columns)
-        stmt = "INSERT INTO {table_name} ({columns}) VALUES({values})".format(
+        stmt = "INSERT INTO {table_name} ({columns}) SELECT {values}".format(
             table_name=self.upload_config.table_name,
             columns=",".join(columns),
-            values=",".join([self.values_delimiter for _ in columns]),
+            values=self._parse_values(columns),
         )
         logger.info(
             f"writing a total of {len(df)} elements via"
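Taken together, the new prepare_data and _parse_values produce an INSERT ... SELECT statement in which every array column's placeholder is wrapped in PARSE_JSON(?) while scalar columns keep a plain ?. A standalone sketch that mirrors this logic (illustrative names build_stmt and bind_row; runnable without a Snowflake connection):

import json
from typing import Any

ARRAY_COLUMNS = ("embeddings", "languages")  # subset of the connector's list
DELIMITER = "?"  # qmark paramstyle, as in values_delimiter above

def build_stmt(table: str, columns: list[str]) -> str:
    # Array columns get PARSE_JSON around their placeholder.
    values = ",".join(
        f"PARSE_JSON({DELIMITER})" if c in ARRAY_COLUMNS else DELIMITER
        for c in columns
    )
    return f"INSERT INTO {table} ({','.join(columns)}) SELECT {values}"

def bind_row(columns: list[str], row: tuple[Any, ...]) -> tuple[Any, ...]:
    # Lists are serialized to JSON strings so every bind variable is a scalar.
    return tuple(
        json.dumps(v) if c in ARRAY_COLUMNS and v is not None else v
        for c, v in zip(columns, row)
    )

columns = ["id", "text", "embeddings"]
print(build_stmt("elements", columns))
# INSERT INTO elements (id,text,embeddings) SELECT ?,?,PARSE_JSON(?)
print(bind_row(columns, ("el-1", "hello", [0.1, 0.2])))
# ('el-1', 'hello', '[0.1, 0.2]')

The same pattern extends to any column named in _ARRAY_COLUMNS above, without changing the staged data or the scalar columns.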