From 866f9196b50748c9809bb2f5981585d40a4bfd56 Mon Sep 17 00:00:00 2001 From: Simatwa Date: Sat, 9 Nov 2024 18:49:09 +0300 Subject: [PATCH] feat: Add logging module fix: Update pytgpt import fix: Modify history file path handling fix: Correct SQL query syntax fix: Improve AI-generated command processing fix: Add redo functionality fix: Enhance error handling for system commands fix: Improve table display logic fix: Fix regex pattern for SELECT statements fix: Update requirements.txt --- manager.py | 53 ++++++++++++++++++++++++++++++++---------------- requirements.txt | 5 ----- 2 files changed, 36 insertions(+), 22 deletions(-) delete mode 100644 requirements.txt diff --git a/manager.py b/manager.py index 4869995..4074a80 100755 --- a/manager.py +++ b/manager.py @@ -6,6 +6,7 @@ import time import rich import click +import logging import sqlite3 import getpass import datetime @@ -39,6 +40,12 @@ table_headers = ("_", "name", "type", "_", "_", "_") +logging.basicConfig( + format="%(asctime)s - %(levelname)s : %(message)s", + datefmt="%d-%b-%Y %H:%M:%S", + level=logging.INFO, +) + def cli_error_handler(func): """Decorator for handling exceptions accordingly""" @@ -127,7 +134,7 @@ class TextToSql: def __init__(self, db_manager: Sqlite3Manager): """Initializes `TextToSql`""" try: - import pytgpt.auto as auto + from pytgpt.auto import AUTO except ImportError: raise Exception( "Looks like pytgpt isn't installed. Install it before using TextToSql - " @@ -136,7 +143,7 @@ def __init__(self, db_manager: Sqlite3Manager): history_file = Path.home() / ".sqlite-cli-manager-ai-chat-history.txt" if history_file.exists(): os.remove(history_file) - self.ai = auto.AUTO(filepath=history_file) + self.ai = AUTO(filepath=str(history_file)) assert isinstance( db_manager, Sqlite3Manager ), f"db_manager must be an instance of {Sqlite3Manager} not {type(db_manager)}" @@ -170,7 +177,7 @@ def context_prompt(self) -> str: """ \n For example: - User: List top 10 entries in the Linux table where distro contains letter 'a' + User: List first 10 entries in the Linux table where distro contains letter 'a' LLM : {SELECT * FROM Linux WHERE distro LIKE '%a%';} User : Remove entries from table Linux whose id is greater than 10. @@ -179,7 +186,7 @@ def context_prompt(self) -> str: If the user's request IS UNDOUBTEDBLY INCOMPLETE, seek clarification. For example: User: Add column to Linux table. - LLM: What kind of data will be stored in the column and suggest column name if possible? + LLM: Describe the data to be stored in the column and suggest column name if possible? User: The column will be storing maintainance status of the linux distros. LLM: {ALTER TABLE Linux ADD COLUMN is_maintained BOOLEAN;} @@ -211,7 +218,8 @@ def generate(self, prompt: str): """Main method""" self.ai.intro = self.context_prompt assert prompt, f"Prompt cannot be null!" - return self.process_response(self.ai.chat(prompt)) + ai_response = self.ai.chat(prompt) + return self.process_response(ai_response) class HistoryCompletions(Completer): @@ -475,12 +483,16 @@ def do_columns(self, line): else: click.secho("Table name is required.", fg="yellow") + def do_redo(self, line): + """Re-run previous sql command""" + history = self.completer_session.history.get_strings() + return self.default(history[-2], prompt_confirmation=True) + @cli_error_handler def default(self, line: str, prompt_confirmation: bool = False, ai_generated=False): """Run sql statemnt against database""" - if line.startswith("./"): - self.do_sys(line[2:]) + self.do_sys(line[2:].strip()) return elif line.startswith("!"): @@ -495,20 +507,22 @@ def default(self, line: str, prompt_confirmation: bool = False, ai_generated=Fal return elif line.startswith("/sql"): - line = [line[4:]] + line = [line[4:].strip()] elif line.startswith("/ai"): - line = TextToSql(self.db_manager).generate(line[3:]) - ai_generated = prompt_confirmation = True + line = TextToSql(self.db_manager).generate(line[3:].strip()) + prompt_confirmation = True + ai_generated = True elif self.ai: line = self.text_to_sql.generate(line) - ai_generated = prompt_confirmation = True + ai_generated = True + prompt_confirmation = True else: line = [line] self.__start_time = time.time() for sql_statement in line: if ( - not self.yes - and prompt_confirmation + prompt_confirmation + and not self.yes and not click.confirm("[Exc] - " + sql_statement) ): continue @@ -568,8 +582,13 @@ def stdout_data( table.add_column("Index", justify="center") def add_headers(header_values: list[str]): - for header in header_values: - table.add_column(header) + if data and len(header_values) == len(data[0]): + for header in header_values: + table.add_column(header) + else: + logging.debug( + f"No data to be displayed or length of data and headers don't match." + ) if headers: add_headers(headers) @@ -586,7 +605,7 @@ def add_headers(header_values: list[str]): specific_column_names_string = re.findall( r"^select\s+([\w_,\s]+)\s+from.+", *re_args ) - if re.match(r"^select\s+\.*", *re_args): + if re.match(r"^select\s+\*.*", *re_args): table_name = re.findall(r".+from\s+([\w_]+).*", *re_args) if table_name: tbl_name = table_name[0] @@ -615,7 +634,7 @@ def add_headers(header_values: list[str]): else: for index, entry in enumerate(data): table.add_row(*[str(index)] + [str(token) for token in entry]) - rich.print(table) + rich.print(table) @staticmethod @click.command() diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index 3ef7a14..0000000 --- a/requirements.txt +++ /dev/null @@ -1,5 +0,0 @@ -click==8.1.3 -rich==13.3.4 -colorama==0.4.6 -prompt-toolkit==3.0.48 -python-tgpt==0.7.7 \ No newline at end of file