Skip to content

Commit

Permalink
Improved SQL Generation. (#17)
Browse files Browse the repository at this point in the history
* Improved SQL Generation.

* Bugfix for redaction, linting + typing

* Linting
  • Loading branch information
math280h authored Dec 11, 2022
1 parent da7de52 commit 5b97a73
Show file tree
Hide file tree
Showing 9 changed files with 191 additions and 139 deletions.
74 changes: 23 additions & 51 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,58 +76,8 @@ output:
````

### Configuration Schema
<details>
<summary>Configuration schema</summary>

```python
Schema({
"connection": {
"type": str,
"host": str,
"port": int,
"database": str,
Optional("username"): str,
Optional("password"): str,
},
"redact": {
Optional("columns"): {
str: [
{
"name": str,
"replacement": lambda r: True
if r is None or type(r) is str
else False,
}
]
},
Optional("patterns"): {
Optional("column"): [
{
"pattern": str,
"replacement": lambda r: True
if r is None or type(r) is str
else False,
}
],
Optional("data"): [
{
"pattern": str,
"replacement": lambda r: True
if r is None or type(r) is str
else False,
}
],
},
},
"output": {
"type": lambda t: True if t in ["file", "multi_file"] else False,
"location": str,
Optional("naming"): str,
},
})
```

</details>
The configuration schema is defined in [`redactdump/core/config.py`](redactdump/core/config.py).

## Example

Expand Down Expand Up @@ -191,3 +141,25 @@ INSERT INTO table_name VALUES (99, 'Robin Jefferson');
```

</details>

## Known limitations

### Data types not supported

* box
* bytea
* inet
* interval
* circle
* cidr
* line
* lseg
* macaddr
* macaddr8
* pg_lsn
* pg_snapshot
* point
* polygon
* tsquery
* tsvector
* txid_snapshot
15 changes: 8 additions & 7 deletions redactdump/app.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Union
from typing import Optional

import configargparse
from rich.console import Console
Expand Down Expand Up @@ -90,22 +90,23 @@ def __init__(self) -> None:
self.database = Database(self.config, self.console)
self.file = File(self.config, self.console)

def dump(self, table: str) -> Tuple[str, int, Union[str, None]]:
def dump(self, table: Table) -> tuple[Table, int, Optional[str]]:
"""
Dump a table to a file.
Args:
table (str): Table name.
table (Table): Table name.
"""
self.console.print(f":construction: [blue]Working on table:[/blue] {table}")
self.console.print(
f":construction: [blue]Working on table:[/blue] {table.name}"
)

row_count = (
self.database.count_rows(table)
if "limits" not in self.config.config
or "max_rows_per_table" not in self.config.config["limits"]
else int(self.config.config["limits"]["max_rows_per_table"])
)
rows = self.database.get_row_names(table)

last_num = 0
step = (
Expand All @@ -122,7 +123,7 @@ def dump(self, table: str) -> Tuple[str, int, Union[str, None]]:

limit = step if x + step < row_count else step + row_count - x
location = self.file.write_to_file(
table, self.database.get_data(table, rows, last_num, limit)
table, self.database.get_data(table, last_num, limit)
)
last_num = x

Expand Down Expand Up @@ -162,7 +163,7 @@ async def run(self) -> None:

for res in sorted_output:
table.add_row(
res[0],
res[0].name,
f"{str(res[1])}{row_count_limited}",
res[2] if res[2] is not None else "No data",
)
Expand Down
13 changes: 11 additions & 2 deletions redactdump/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ def load_config(self) -> dict:
Optional("username"): str,
Optional("password"): str,
},
Optional("limits"): {"max_rows_per_table": int, "select_columns": list},
Optional("performance"): {"rows_per_request": int},
Optional("limits"): {
Optional("max_rows_per_table"): int,
Optional("select_columns"): list,
},
Optional("performance"): {Optional("rows_per_request"): int},
Optional("debug"): {"enabled": bool},
"redact": {
Optional("columns"): {
Expand Down Expand Up @@ -90,4 +93,10 @@ def load_config(self) -> dict:
config["debug"] = {}
config["debug"]["enabled"] = False

if "limits" not in config:
config["limits"] = {}

if "select_columns" not in config["limits"]:
config["limits"]["select_columns"] = []

return config
107 changes: 54 additions & 53 deletions redactdump/core/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sqlalchemy import create_engine, text

from redactdump.core.config import Config
from redactdump.core.models import Table, TableColumn
from redactdump.core.redactor import Redactor


Expand Down Expand Up @@ -43,35 +44,58 @@ def __init__(self, config: Config, console: Console) -> None:
future=True,
)

def get_tables(self) -> List[Table]:
    """
    Get every base table in the ``public`` schema together with its columns.

    When ``limits.select_columns`` is non-empty, only columns named in it are
    attached to each table; otherwise all columns are kept.

    Returns:
        List[Table]: One ``Table`` per base table, each carrying its
        ``TableColumn`` metadata (name, data type, nullability, default).
    """
    # NOTE(review): assumes the config loader has already defaulted
    # ``limits.select_columns`` to a list — confirm against Config.load_config.
    select_columns = self.config.config["limits"]["select_columns"]

    tables: List[Table] = []
    with self.engine.connect() as conn:
        conn = conn.execution_options(
            postgresql_readonly=True, postgresql_deferrable=True
        )
        with conn.begin():
            result = conn.execute(
                text(
                    "SELECT table_name FROM information_schema.tables WHERE table_type='BASE TABLE' AND "
                    "table_schema='public' "
                )
            )
            # Collect the names first so the outer result set is fully
            # consumed before further statements run on this connection.
            table_names = [row[0] for row in result]

            for table_name in table_names:
                # Bind the table name instead of formatting it into the SQL,
                # and restrict to the same schema the tables were listed from
                # so same-named tables elsewhere do not contribute columns.
                columns = conn.execute(
                    text(
                        "SELECT column_name, column_default, is_nullable, data_type FROM "
                        "information_schema.columns WHERE table_schema = 'public' "
                        "AND table_name = :table_name"
                    ),
                    {"table_name": table_name},
                )
                # Rows are unpacked positionally in the order of the SELECT list.
                table_columns = [
                    TableColumn(column_name, data_type, is_nullable, column_default)
                    for column_name, column_default, is_nullable, data_type in columns
                    if not select_columns or column_name in select_columns
                ]
                tables.append(Table(table_name, table_columns))
    return tables

def count_rows(self, table: str) -> int:
def count_rows(self, table: Table) -> int:
"""
Get the number of rows in a table.
Args:
table (str): The table name.
table (Table): The table name.
Returns:
int: The number of rows in the table.
Expand All @@ -81,19 +105,20 @@ def count_rows(self, table: str) -> int:
postgresql_readonly=True, postgresql_deferrable=True
)
with conn.begin():
result = conn.execute(text(f"SELECT COUNT(*) FROM {table}"))
result = conn.execute(text(f"SELECT COUNT(*) FROM {table.name}"))

for item in result:
return item[0]
return 0

def get_data(self, table: str, rows: list, offset: int, limit: int) -> list:
def get_data(
self, table: Table, offset: int, limit: int
) -> list[list[TableColumn]]:
"""
Get data from a table.
Args:
table (str): The table name.
rows (list): The list of row names.
table (Table): The table name.
offset (int): The offset.
limit (int): The limit.
Expand All @@ -106,63 +131,39 @@ def get_data(self, table: str, rows: list, offset: int, limit: int) -> list:
postgresql_readonly=True, postgresql_deferrable=True
)

if not set(self.config.config["limits"]["select_columns"]).issubset(rows):
if not set(self.config.config["limits"]["select_columns"]).issubset(
[column.name for column in table.columns]
):
return []

with conn.begin():
select = (
"*"
if "limits" not in self.config.config
or "select_columns" not in self.config.config["limits"]
if not self.config.config["limits"]["select_columns"]
else ",".join(self.config.config["limits"]["select_columns"])
)

if self.config.config["debug"]["enabled"]:
self.console.print(
f"[cyan]DEBUG: Running 'SELECT {select} FROM {table} OFFSET {offset} LIMIT {limit}'[/cyan]"
f"[cyan]DEBUG: Running 'SELECT {select} FROM {table.name} OFFSET {offset} LIMIT {limit}'[/cyan]"
)

result = conn.execute(
text(f"SELECT {select} FROM {table} OFFSET {offset} LIMIT {limit}")
text(
f"SELECT {select} FROM {table.name} OFFSET {offset} LIMIT {limit}"
)
)
records = [dict(zip(row.keys(), row)) for row in result]
for item in records:
if self.redactor.data_rules or self.redactor.column_rules:
item = self.redactor.redact(item, rows)

data.append(item)
modified_column = self.redactor.redact(item, table.columns)
else:
for key, value in item.items():
column = next(
(x for x in table.columns if x.name == key), None
)
if column is not None:
column.value = value
modified_column = table.columns
data.append(modified_column)
return data

def get_row_names(self, table: str) -> list:
    """
    Get the column names of a table.

    Despite the method name, the values returned are column names read from
    ``information_schema.columns``. When ``limits.select_columns`` is
    configured and non-empty, only names listed there are returned.

    Args:
        table (str): The table name.

    Returns:
        list: The matching column names.
    """
    # A missing "limits" section or "select_columns" key both mean "no filter".
    select_columns = self.config.config.get("limits", {}).get("select_columns", [])

    with self.engine.connect() as conn:
        conn = conn.execution_options(
            postgresql_readonly=True, postgresql_deferrable=True
        )
        with conn.begin():
            # Bind the table name rather than formatting it into the SQL text,
            # so names containing quotes cannot break or alter the statement.
            result = conn.execute(
                text(
                    "SELECT column_name FROM information_schema.columns WHERE table_name=:table_name"
                ),
                {"table_name": table},
            )
            return [
                item[0]
                for item in result
                if not select_columns or item[0] in select_columns
            ]
Loading

0 comments on commit 5b97a73

Please sign in to comment.