Skip to content

Commit

Permalink
Fix issues with tags formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
Francisco-Montanez committed Nov 18, 2024
1 parent 48b1b30 commit 8713623
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 23 deletions.
2 changes: 0 additions & 2 deletions docs/examples/01_file_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,6 @@
.build()
)

print(csv_format_with_escape.to_sql())

# 1.3 Create JSON File Format
json_format = (
FileFormat.builder("json_format")
Expand Down
9 changes: 7 additions & 2 deletions src/snowforge/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from enum import Enum
from typing import Dict, Optional, Union

from snowforge.utilities import sql_format_dict, sql_quote_comment, sql_quote_string
from snowforge.utilities import (
sql_format_dict,
sql_format_tags,
sql_quote_comment,
sql_quote_string,
)

from .file_format import FileFormatSpecification

Expand Down Expand Up @@ -359,7 +364,7 @@ def to_sql(self) -> str:
parts.append(f"COMMENT = {sql_quote_comment(self.comment)}")

if self.tags:
parts.append(f"TAGS = {sql_format_dict(self.tags)}")
parts.append(f"TAG {sql_format_tags(self.tags)}")

return " ".join(parts)

Expand Down
4 changes: 3 additions & 1 deletion src/snowforge/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,9 @@ def to_sql(self) -> str:
parts.append("COPY GRANTS")

if self.cluster_by:
parts.append(f"CLUSTER BY {sql_format_list(self.cluster_by)}")
parts.append(
f"CLUSTER BY {sql_format_list(self.cluster_by, quote_values=False)}"
)

if self.row_access_policy:
policy = self.row_access_policy
Expand Down
63 changes: 51 additions & 12 deletions src/snowforge/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,17 +37,21 @@ def sql_format_boolean(value: bool) -> str:
return str(value).upper()


def sql_format_list(values: List[str]) -> str:
"""Formats a list of strings for SQL, with proper escaping.
def sql_format_list(values: List[str], quote_values: bool = True) -> str:
"""Formats a list of strings for SQL, with optional quoting.
Args:
values: List of strings to format
quote_values: Whether to quote the values (default: True)
Returns:
SQL-formatted string representation of the list
"""
quoted_values = [sql_quote_string(val) for val in values]
return f"({', '.join(quoted_values)})"
if quote_values:
formatted_values = [sql_quote_string(val) for val in values]
else:
formatted_values = values
return f"({', '.join(formatted_values)})"


def sql_format_value(value: Union[bool, str, int, float, None]) -> str:
Expand All @@ -69,19 +73,23 @@ def sql_format_value(value: Union[bool, str, int, float, None]) -> str:
return sql_quote_string(str(value))


def sql_format_dict(values: Dict[str, str]) -> str:
"""Formats a dictionary for SQL, with proper escaping.
def sql_format_dict(d: Dict[str, str]) -> str:
"""Formats a dictionary for SQL.
For tags in Snowflake, the format should be:
tag (key1 = 'value1', key2 = 'value2')
Args:
values: Dictionary to format
d: Dictionary to format
Returns:
SQL-formatted string representation of the dictionary
Formatted string for SQL
"""
parts = []
for key, value in values.items():
parts.append(f"{sql_quote_string(key)} = {sql_format_value(value)}")
return f"({', '.join(parts)})"
if not d:
return ""

pairs = [f"{k} = {sql_quote_string(v)}" for k, v in d.items()]
return f"({', '.join(pairs)})"


def sql_escape_comment(value: str) -> str:
Expand Down Expand Up @@ -111,3 +119,34 @@ def sql_quote_comment(value: str) -> str:
Quoted and escaped comment string safe for SQL
"""
return f"'{sql_escape_comment(value)}'"


def sql_format_tag(key: str, value: str) -> str:
    """Render one tag assignment as a complete ``TAG (...)`` clause.

    The key is emitted verbatim (it is not quoted); the value is passed
    through ``sql_quote_string`` before being interpolated.

    Args:
        key: Tag key
        value: Tag value

    Returns:
        String of the form ``TAG (key = <quoted value>)``
    """
    quoted_value = sql_quote_string(value)
    return f"TAG ({key} = {quoted_value})"


def sql_format_tags(tags: Dict[str, str]) -> str:
    """Formats multiple tags for SQL.

    Produces the parenthesised, comma-separated assignment list that is
    placed after the ``TAG`` keyword, e.g.
    ``(env = 'test', owner = 'data_team')``.

    Args:
        tags: Dictionary of tags

    Returns:
        Formatted string for SQL, or an empty string when ``tags`` is empty
    """
    if not tags:
        return ""

    # Use the dictionary formatter's output as-is. The previous
    # implementation split that output on ", " and re-joined it with spaces,
    # which dropped the commas between assignments and also rewrote any tag
    # value that itself contained ", ".
    return sql_format_dict(tags)
5 changes: 3 additions & 2 deletions tests/test_stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ def test_internal_stage_full_config(internal_stage_params, file_format):
.with_tag("env", "test")
.build()
)

expected = (
'CREATE STAGE TEST_INTERNAL '
"ENCRYPTION = (TYPE = 'SNOWFLAKE_SSE') "
"DIRECTORY = (ENABLE = TRUE REFRESH_ON_CREATE = TRUE) "
"FILE_FORMAT = (FORMAT_NAME = 'TEST_FORMAT') "
"COMMENT = 'Test internal stage' "
"TAGS = ('env' = 'test')"
"TAG (env = 'test')"
)
assert stage.to_sql() == expected

Expand All @@ -80,7 +81,7 @@ def test_azure_external_stage():
'CREATE STAGE TEST_AZURE '
"URL = 'azure://container/path' "
"STORAGE_INTEGRATION = AZURE_INT "
"ENCRYPTION = ('TYPE' = 'AZURE_CSE')"
"ENCRYPTION = (TYPE = 'AZURE_CSE')"
)
assert stage.to_sql() == expected

Expand Down
4 changes: 2 additions & 2 deletions tests/test_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_stream_with_tags():
)
expected = (
"CREATE STREAM TEST_STREAM "
"WITH TAG ('env' = 'test', 'owner' = 'data_team') "
"WITH TAG (env = 'test', owner = 'data_team') "
"ON TABLE TEST_TABLE"
)
assert stream.to_sql() == expected
Expand All @@ -127,7 +127,7 @@ def test_complex_stream_configuration(complex_stream):
"""Test creation of a stream with all available options."""
expected = (
"CREATE OR REPLACE STREAM COMPLEX_STREAM "
"WITH TAG ('env' = 'test', 'owner' = 'data_team') "
"WITH TAG (env = 'test', owner = 'data_team') "
"ON TABLE TEST_DB.TEST_SCHEMA.TEST_TABLE "
"APPEND_ONLY = TRUE "
"SHOW_INITIAL_ROWS = TRUE "
Expand Down
4 changes: 2 additions & 2 deletions tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,8 @@ def test_table_creation_complex(complex_table):
"created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP()"
") "
"COMMENT = 'User accounts table' "
"CLUSTER BY ('email') "
"WITH TAG ('department' = 'hr', 'security_level' = 'high')"
"CLUSTER BY (email) "
"WITH TAG (department = 'hr', security_level = 'high')"
)
assert complex_table.to_sql() == expected

Expand Down

0 comments on commit 8713623

Please sign in to comment.