From 5b97a73acf8e7e3926980068487f807c2cd4fed0 Mon Sep 17 00:00:00 2001 From: "Mathias V. Nielsen" <1547127+math280h@users.noreply.github.com> Date: Sun, 11 Dec 2022 17:21:52 -0500 Subject: [PATCH] Improved SQL Generation. (#17) * Improved SQL Generation. * Bugfix for redaction, linting + typing * Linting --- README.md | 74 ++++++++----------------- redactdump/app.py | 15 ++--- redactdump/core/config.py | 13 ++++- redactdump/core/database.py | 107 ++++++++++++++++++------------------ redactdump/core/file.py | 46 ++++++++++++---- redactdump/core/models.py | 21 +++++++ redactdump/core/redactor.py | 34 +++++++++--- tests/config.yaml | 2 + tests/test_redactor.py | 18 ++++-- 9 files changed, 191 insertions(+), 139 deletions(-) create mode 100644 redactdump/core/models.py diff --git a/README.md b/README.md index 04f5090..ee14581 100644 --- a/README.md +++ b/README.md @@ -76,58 +76,8 @@ output: ```` ### Configuration Schema -
-Configuration schema - -```python -Schema({ - "connection": { - "type": str, - "host": str, - "port": int, - "database": str, - Optional("username"): str, - Optional("password"): str, - }, - "redact": { - Optional("columns"): { - str: [ - { - "name": str, - "replacement": lambda r: True - if r is None or type(r) is str - else False, - } - ] - }, - Optional("patterns"): { - Optional("column"): [ - { - "pattern": str, - "replacement": lambda r: True - if r is None or type(r) is str - else False, - } - ], - Optional("data"): [ - { - "pattern": str, - "replacement": lambda r: True - if r is None or type(r) is str - else False, - } - ], - }, - }, - "output": { - "type": lambda t: True if t in ["file", "multi_file"] else False, - "location": str, - Optional("naming"): str, - }, -}) -``` -
+The configuration schema can be found [here](redactdump/core/config.py) ## Example @@ -191,3 +141,25 @@ INSERT INTO table_name VALUES (99, 'Robin Jefferson'); ``` + +## Known limitations + +### Data types not supported + +* box +* bytea +* inet +* interval +* circle +* cidr +* line +* lseg +* macaddr +* macaddr8 +* pg_lsn +* pg_snapshot +* point +* polygon +* tsquery +* tsvector +* txid_snapshot diff --git a/redactdump/app.py b/redactdump/app.py index 5c98a79..fbe480e 100644 --- a/redactdump/app.py +++ b/redactdump/app.py @@ -1,5 +1,5 @@ from concurrent.futures import ThreadPoolExecutor -from typing import Tuple, Union +from typing import Optional import configargparse from rich.console import Console @@ -90,14 +90,16 @@ def __init__(self) -> None: self.database = Database(self.config, self.console) self.file = File(self.config, self.console) - def dump(self, table: str) -> Tuple[str, int, Union[str, None]]: + def dump(self, table: Table) -> tuple[Table, int, Optional[str]]: """ Dump a table to a file. Args: - table (str): Table name. + table (Table): Table name. """ - self.console.print(f":construction: [blue]Working on table:[/blue] {table}") + self.console.print( + f":construction: [blue]Working on table:[/blue] {table.name}" + ) row_count = ( self.database.count_rows(table) @@ -105,7 +107,6 @@ def dump(self, table: str) -> Tuple[str, int, Union[str, None]]: or "max_rows_per_table" not in self.config.config["limits"] else int(self.config.config["limits"]["max_rows_per_table"]) ) - rows = self.database.get_row_names(table) last_num = 0 step = ( @@ -122,7 +123,7 @@ def dump(self, table: str) -> Tuple[str, int, Union[str, None]]: limit = step if x + step < row_count else step + row_count - x location = self.file.write_to_file( - table, self.database.get_data(table, rows, last_num, limit) + table, self.database.get_data(table, last_num, limit) ) last_num = x @@ -162,7 +163,7 @@ async def run(self) -> None: for res in sorted_output: table.add_row( - res[0], + res[0].name, f"{str(res[1])}{row_count_limited}", res[2] if res[2] is not None else "No data", ) diff --git a/redactdump/core/config.py b/redactdump/core/config.py index b2fddd4..b6f7beb 100644 --- a/redactdump/core/config.py +++ b/redactdump/core/config.py @@ -37,8 +37,11 @@ def load_config(self) -> dict: Optional("username"): str, Optional("password"): str, }, - Optional("limits"): {"max_rows_per_table": int, "select_columns": list}, - Optional("performance"): {"rows_per_request": int}, + Optional("limits"): { + Optional("max_rows_per_table"): int, + Optional("select_columns"): list, + }, + Optional("performance"): {Optional("rows_per_request"): int}, Optional("debug"): {"enabled": bool}, "redact": { Optional("columns"): { @@ -90,4 +93,10 @@ def load_config(self) -> dict: config["debug"] = {} config["debug"]["enabled"] = False + if "limits" not in config: + config["limits"] = {} + + if "select_columns" not in config["limits"]: + config["limits"]["select_columns"] = [] + return config diff --git a/redactdump/core/database.py b/redactdump/core/database.py index c66e4fc..928736b 100644 --- a/redactdump/core/database.py +++ b/redactdump/core/database.py @@ -4,6 +4,7 @@ from sqlalchemy import create_engine, text from redactdump.core.config import Config +from redactdump.core.models import Table, TableColumn from redactdump.core.redactor import Redactor @@ -43,14 +44,14 @@ def __init__(self, config: Config, console: Console) -> None: future=True, ) - def get_tables(self) -> List[str]: + def get_tables(self) -> List[Table]: """ Get a list of tables. Returns: List[str]: A list of tables. """ - tables = [] + tables: List[Table] = [] with self.engine.connect() as conn: conn = conn.execution_options( postgresql_readonly=True, postgresql_deferrable=True @@ -58,20 +59,43 @@ def get_tables(self) -> List[str]: with conn.begin(): result = conn.execute( text( - "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND table_schema='public'" + "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND " + "table_schema='public' " ) ) - for item in result: - tables.append(item[0]) + for table in result: + table_columns = [] + columns = conn.execute( + text( + f"SELECT column_name, column_default, is_nullable, data_type FROM " + f"information_schema.columns WHERE table_name = '{table[0]}'" + ) + ) + for column in columns: + if ( + not self.config.config["limits"]["select_columns"] + or column["column_name"] + in self.config.config["limits"]["select_columns"] + ): + table_columns.append( + TableColumn( + column["column_name"], + column["data_type"], + column["is_nullable"], + column["column_default"], + ) + ) + + tables.append(Table(table[0], table_columns)) return tables - def count_rows(self, table: str) -> int: + def count_rows(self, table: Table) -> int: """ Get the number of rows in a table. Args: - table (str): The table name. + table (Table): The table name. Returns: int: The number of rows in the table. @@ -81,19 +105,20 @@ def count_rows(self, table: str) -> int: postgresql_readonly=True, postgresql_deferrable=True ) with conn.begin(): - result = conn.execute(text(f"SELECT COUNT(*) FROM {table}")) + result = conn.execute(text(f"SELECT COUNT(*) FROM {table.name}")) for item in result: return item[0] return 0 - def get_data(self, table: str, rows: list, offset: int, limit: int) -> list: + def get_data( + self, table: Table, offset: int, limit: int + ) -> list[list[TableColumn]]: """ Get data from a table. Args: - table (str): The table name. - rows (list): The list of row names. + table (Table): The table name. offset (int): The offset. limit (int): The limit. @@ -106,63 +131,39 @@ def get_data(self, table: str, rows: list, offset: int, limit: int) -> list: postgresql_readonly=True, postgresql_deferrable=True ) - if not set(self.config.config["limits"]["select_columns"]).issubset(rows): + if not set(self.config.config["limits"]["select_columns"]).issubset( + [column.name for column in table.columns] + ): return [] with conn.begin(): select = ( "*" - if "limits" not in self.config.config - or "select_columns" not in self.config.config["limits"] + if not self.config.config["limits"]["select_columns"] else ",".join(self.config.config["limits"]["select_columns"]) ) if self.config.config["debug"]["enabled"]: self.console.print( - f"[cyan]DEBUG: Running 'SELECT {select} FROM {table} OFFSET {offset} LIMIT {limit}'[/cyan]" + f"[cyan]DEBUG: Running 'SELECT {select} FROM {table.name} OFFSET {offset} LIMIT {limit}'[/cyan]" ) result = conn.execute( - text(f"SELECT {select} FROM {table} OFFSET {offset} LIMIT {limit}") + text( + f"SELECT {select} FROM {table.name} OFFSET {offset} LIMIT {limit}" + ) ) records = [dict(zip(row.keys(), row)) for row in result] for item in records: if self.redactor.data_rules or self.redactor.column_rules: - item = self.redactor.redact(item, rows) - - data.append(item) + modified_column = self.redactor.redact(item, table.columns) + else: + for key, value in item.items(): + column = next( + (x for x in table.columns if x.name == key), None + ) + if column is not None: + column.value = value + modified_column = table.columns + data.append(modified_column) return data - - def get_row_names(self, table: str) -> list: - """ - Get the row names from a table. - - Args: - table (str): The table name. - - Returns: - list: The row names. - """ - names = [] - with self.engine.connect() as conn: - conn = conn.execution_options( - postgresql_readonly=True, postgresql_deferrable=True - ) - with conn.begin(): - result = conn.execute( - text( - f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}'" - ) - ) - - select_columns = ( - [] - if "limits" not in self.config.config - or "select_columns" not in self.config.config["limits"] - else self.config.config["limits"]["select_columns"] - ) - - for item in result: - if not select_columns or item[0] in select_columns: - names.append(item[0]) - return names diff --git a/redactdump/core/file.py b/redactdump/core/file.py index 348ae66..2a01c6a 100644 --- a/redactdump/core/file.py +++ b/redactdump/core/file.py @@ -1,10 +1,11 @@ from datetime import datetime, timezone import os -from typing import Union +from typing import List, Union from rich.console import Console from redactdump.core.config import Config +from redactdump.core.models import Table, TableColumn class File: @@ -54,13 +55,13 @@ def create_output_locations(self) -> None: self.console.print() @staticmethod - def get_name(output: dict, table: str) -> str: + def get_name(output: dict, table: Table) -> str: """ Get the formatted name of the file. Args: output (dict): Output configuration. - table (str): Table name. + table (Table): Table. Returns: str: Name of the file. @@ -70,20 +71,22 @@ def get_name(output: dict, table: str) -> str: naming = ( output["naming"] .replace("[timestamp]", time.strftime("%Y-%m-%d-%H-%M-%S")) - .replace("[table_name]", table) + .replace("[table_name]", table.name) ) name = f"{naming}.sql" else: - name = f"{table}-{time.strftime('%Y-%m-%d-%H-%M-%S')}.sql" + name = f"{table.name}-{time.strftime('%Y-%m-%d-%H-%M-%S')}.sql" return name - def write_to_file(self, table: str, data: list) -> Union[str, None]: + def write_to_file( + self, table: Table, rows: List[List[TableColumn]] + ) -> Union[str, None]: """ Write data to file. Args: - table (str): Table name. - data (list): Data to write. + table (Table): Table name. + rows (List[List[TableColumn]]): Data to write. Returns: Union[str, None]: Name of the file. @@ -92,10 +95,29 @@ def write_to_file(self, table: str, data: list) -> Union[str, None]: if output["type"] == "multi_file": name = self.get_name(output, table) with open(f"{output['location']}/{name}", "a") as file: - for entry in data: + for row in rows: + values = [] - for value in entry.values(): - values.append(str(value)) - file.write(f"INSERT INTO {table} VALUES ({', '.join(values)});\n") + for column in row: + if ( + column.data_type == "bigint" + or column.data_type == "integer" + or column.data_type == "smallint" + or column.data_type == "double precision" + or column.data_type == "numeric" + ): + values.append(str(column.value)) + elif ( + column.data_type == "bit" + or column.data_type == "bit varying" + ): + values.append(str(f"b'{column.value}'")) + else: + values.append(str(f"'{column.value}'")) + + columns = '"' + '", "'.join([column.name for column in row]) + '"' + file.write( + f"INSERT INTO {table.name} ({columns}) VALUES ({', '.join(values)});\n" + ) return name return "" diff --git a/redactdump/core/models.py b/redactdump/core/models.py new file mode 100644 index 0000000..023991a --- /dev/null +++ b/redactdump/core/models.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass +from typing import List, Union + + +@dataclass +class TableColumn: + """TableColumn.""" + + name: str + data_type: str + is_nullable: bool + default: str + value: Union[str, None] = None + + +@dataclass +class Table: + """Table.""" + + name: str + columns: List[TableColumn] diff --git a/redactdump/core/redactor.py b/redactdump/core/redactor.py index 0f44da5..9aab235 100644 --- a/redactdump/core/redactor.py +++ b/redactdump/core/redactor.py @@ -5,6 +5,7 @@ from faker import Faker from redactdump.core.config import Config +from redactdump.core.models import TableColumn @dataclass @@ -71,29 +72,44 @@ def get_replacement(self, replacement: str) -> Union[str, Any]: if replacement is not None: func = getattr(self.fake, replacement) value = func() - if type(value) is not str: - return value - return f"'{value}'" + return value return "NULL" - def redact(self, data: dict, rows: list) -> dict: + def redact(self, data: dict, columns: List[TableColumn]) -> list[TableColumn]: """ Redact data. Args: data (dict): Data to redact. - rows (list): Rows to redact. + columns (list): Rows to redact. Returns: dict: Redacted data. """ + columns_redacted = [] for rule in self.column_rules: - for row in [row for row in rows if rule.pattern.search(row)]: - data[row] = self.get_replacement(rule.replacement) + for column in [ + column + for column in columns + if rule.pattern.search(column.name) + and column.name not in columns_redacted + ]: + column.value = self.get_replacement(rule.replacement) + columns_redacted.append(column.name) for rule in self.data_rules: for key, value in data.items(): + discovered_column = next((x for x in columns if x.name == key), None) + + if discovered_column is None: + raise LookupError + if discovered_column.name in columns_redacted: + continue + if rule.pattern.search(str(value)): - data[key] = self.get_replacement(rule.replacement) + discovered_column.value = self.get_replacement(rule.replacement) + columns_redacted.append(discovered_column.name) + else: + discovered_column.value = value - return data + return columns diff --git a/tests/config.yaml b/tests/config.yaml index eab5090..97648cd 100644 --- a/tests/config.yaml +++ b/tests/config.yaml @@ -12,6 +12,8 @@ redact: data: - pattern: '192.168.0.1' replacement: ipv4 + - pattern: 'my@email.com' + replacement: email - pattern: 'John Doe' replacement: name diff --git a/tests/test_redactor.py b/tests/test_redactor.py index 4c8d48d..a149fde 100644 --- a/tests/test_redactor.py +++ b/tests/test_redactor.py @@ -1,6 +1,7 @@ from configargparse import Namespace from redactdump.core import Config +from redactdump.core.models import TableColumn from redactdump.core.redactor import Redactor @@ -29,8 +30,15 @@ def test_redaction() -> None: for idx, item in enumerate(data): if redactor.data_rules or redactor.column_rules: - redactor.redact(item, ["full_name", "email"]) - assert data[idx]["full_name"] != original[idx]["full_name"] - assert data[idx]["secondary_name"] != original[idx]["secondary_name"] - assert data[idx]["ip"] != original[idx]["ip"] - assert data[idx]["email"] == original[idx]["email"] + results = redactor.redact( + item, + [ + TableColumn("full_name", "character varying", True, "", None), + TableColumn("secondary_name", "character varying", True, "", None), + TableColumn("ip", "character varying", True, "", None), + TableColumn("email", "character varying", True, "", None), + ], + ) + + for result in results: + assert result.value != original[idx][result.name]