Skip to content

Commit

Permalink
feat: Add logging module
Browse files Browse the repository at this point in the history
fix: Update pytgpt import
fix: Modify history file path handling
fix: Correct SQL query syntax
fix: Improve AI-generated command processing
fix: Add redo functionality
fix: Enhance error handling for system commands
fix: Improve table display logic
fix: Fix regex pattern for SELECT statements
fix: Update requirements.txt
  • Loading branch information
Simatwa committed Nov 9, 2024
1 parent 0ce2987 commit 866f919
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 22 deletions.
53 changes: 36 additions & 17 deletions manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import rich
import click
import logging
import sqlite3
import getpass
import datetime
Expand Down Expand Up @@ -39,6 +40,12 @@

table_headers = ("_", "name", "type", "_", "_", "_")

logging.basicConfig(
format="%(asctime)s - %(levelname)s : %(message)s",
datefmt="%d-%b-%Y %H:%M:%S",
level=logging.INFO,
)


def cli_error_handler(func):
"""Decorator for handling exceptions accordingly"""
Expand Down Expand Up @@ -127,7 +134,7 @@ class TextToSql:
def __init__(self, db_manager: Sqlite3Manager):
"""Initializes `TextToSql`"""
try:
import pytgpt.auto as auto
from pytgpt.auto import AUTO
except ImportError:
raise Exception(
"Looks like pytgpt isn't installed. Install it before using TextToSql - "
Expand All @@ -136,7 +143,7 @@ def __init__(self, db_manager: Sqlite3Manager):
history_file = Path.home() / ".sqlite-cli-manager-ai-chat-history.txt"
if history_file.exists():
os.remove(history_file)
self.ai = auto.AUTO(filepath=history_file)
self.ai = AUTO(filepath=str(history_file))
assert isinstance(
db_manager, Sqlite3Manager
), f"db_manager must be an instance of {Sqlite3Manager} not {type(db_manager)}"
Expand Down Expand Up @@ -170,7 +177,7 @@ def context_prompt(self) -> str:
"""
\n
For example:
User: List top 10 entries in the Linux table where distro contains letter 'a'
User: List first 10 entries in the Linux table where distro contains letter 'a'
LLM : {SELECT * FROM Linux WHERE distro LIKE '%a%';}
User : Remove entries from table Linux whose id is greater than 10.
Expand All @@ -179,7 +186,7 @@ def context_prompt(self) -> str:
If the user's request IS UNDOUBTEDBLY INCOMPLETE, seek clarification.
For example:
User: Add column to Linux table.
LLM: What kind of data will be stored in the column and suggest column name if possible?
LLM: Describe the data to be stored in the column and suggest column name if possible?
User: The column will be storing maintainance status of the linux distros.
LLM: {ALTER TABLE Linux ADD COLUMN is_maintained BOOLEAN;}
Expand Down Expand Up @@ -211,7 +218,8 @@ def generate(self, prompt: str):
"""Main method"""
self.ai.intro = self.context_prompt
assert prompt, f"Prompt cannot be null!"
return self.process_response(self.ai.chat(prompt))
ai_response = self.ai.chat(prompt)
return self.process_response(ai_response)


class HistoryCompletions(Completer):
Expand Down Expand Up @@ -475,12 +483,16 @@ def do_columns(self, line):
else:
click.secho("Table name is required.", fg="yellow")

def do_redo(self, line):
"""Re-run previous sql command"""
history = self.completer_session.history.get_strings()
return self.default(history[-2], prompt_confirmation=True)

@cli_error_handler
def default(self, line: str, prompt_confirmation: bool = False, ai_generated=False):
"""Run sql statemnt against database"""

if line.startswith("./"):
self.do_sys(line[2:])
self.do_sys(line[2:].strip())
return

elif line.startswith("!"):
Expand All @@ -495,20 +507,22 @@ def default(self, line: str, prompt_confirmation: bool = False, ai_generated=Fal
return

elif line.startswith("/sql"):
line = [line[4:]]
line = [line[4:].strip()]
elif line.startswith("/ai"):
line = TextToSql(self.db_manager).generate(line[3:])
ai_generated = prompt_confirmation = True
line = TextToSql(self.db_manager).generate(line[3:].strip())
prompt_confirmation = True
ai_generated = True
elif self.ai:
line = self.text_to_sql.generate(line)
ai_generated = prompt_confirmation = True
ai_generated = True
prompt_confirmation = True
else:
line = [line]
self.__start_time = time.time()
for sql_statement in line:
if (
not self.yes
and prompt_confirmation
prompt_confirmation
and not self.yes
and not click.confirm("[Exc] - " + sql_statement)
):
continue
Expand Down Expand Up @@ -568,8 +582,13 @@ def stdout_data(
table.add_column("Index", justify="center")

def add_headers(header_values: list[str]):
for header in header_values:
table.add_column(header)
if data and len(header_values) == len(data[0]):
for header in header_values:
table.add_column(header)
else:
logging.debug(
f"No data to be displayed or length of data and headers don't match."
)

if headers:
add_headers(headers)
Expand All @@ -586,7 +605,7 @@ def add_headers(header_values: list[str]):
specific_column_names_string = re.findall(
r"^select\s+([\w_,\s]+)\s+from.+", *re_args
)
if re.match(r"^select\s+\.*", *re_args):
if re.match(r"^select\s+\*.*", *re_args):
table_name = re.findall(r".+from\s+([\w_]+).*", *re_args)
if table_name:
tbl_name = table_name[0]
Expand Down Expand Up @@ -615,7 +634,7 @@ def add_headers(header_values: list[str]):
else:
for index, entry in enumerate(data):
table.add_row(*[str(index)] + [str(token) for token in entry])
rich.print(table)
rich.print(table)

@staticmethod
@click.command()
Expand Down
5 changes: 0 additions & 5 deletions requirements.txt

This file was deleted.

0 comments on commit 866f919

Please sign in to comment.