refactor exceptions some more
psyb0t committed Jul 3, 2024
1 parent 0e5a862 commit adf135e
Showing 5 changed files with 61 additions and 40 deletions.
31 changes: 31 additions & 0 deletions src/ezpyai/exceptions.py
@@ -1,3 +1,4 @@
+# LLMProvider Exceptions
 class UnsupportedModelError(Exception):
     """Exception raised when an unsupported model is used."""
 
@@ -12,6 +13,29 @@ def __init__(self, message="Unsupported lora", *args):
         super().__init__(message, *args)
 
 
+# LLM Exceptions
+class PromptUserMessageMissingError(Exception):
+    """Exception raised when there is no user message in the prompt."""
+
+    def __init__(self, message="Prompt user message missing", *args):
+        super().__init__(message, *args)
+
+
+class LLMResponseEmptyError(Exception):
+    """Exception raised when there is no LLM response message."""
+
+    def __init__(self, message="LLM response empty", *args):
+        super().__init__(message, *args)
+
+
+class LLMInferenceError(Exception):
+    """Exception raised when there is an error during the LLM inference."""
+
+    def __init__(self, message="LLM inference error", *args):
+        super().__init__(message, *args)
+
+
+# General Exceptions
 class JSONParseError(Exception):
     """Exception raised when a JSON parse error occurs."""
 
@@ -38,3 +62,10 @@ class FileProcessingError(Exception):
 
     def __init__(self, message="Error during file processing", *args):
         super().__init__(message, *args)
+
+
+class FileNotFoundError(Exception):
+    """Exception raised when a file is not found."""
+
+    def __init__(self, message="File not found", *args):
+        super().__init__(message, *args)
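
Every class added above follows the same pattern: the constructor takes a message with a sensible default and forwards everything to Exception, so callers can raise with no arguments or override the message and attach extra context. A minimal illustrative sketch (not part of the commit):

from ezpyai.exceptions import LLMInferenceError

# Raised bare, the default message carries through to str(e).
try:
    raise LLMInferenceError()
except LLMInferenceError as e:
    print(e)  # -> LLM inference error

# The default can be overridden, and extra positional args land in e.args.
try:
    raise LLMInferenceError("rate limit hit", 429)
except LLMInferenceError as e:
    print(e.args)  # -> ('rate limit hit', 429)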
10 changes: 9 additions & 1 deletion src/ezpyai/llm/dataset/sources/telegram.py
@@ -1,3 +1,11 @@
+import os
+
+from ezpyai.exceptions import FileNotFoundError
+
+
 class DatasetSourceTelegram:
     def __init__(self, json_export_file_path: str) -> None:
-        pass
+        if not os.path.exists(json_export_file_path):
+            raise FileNotFoundError(f"File not found: {json_export_file_path}")
+
+        self._json_export_file_path = json_export_file_path
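
A hypothetical caller of the validated constructor might look as follows (the export path is illustrative, and the import path assumes the usual src/ layout). Note that importing ezpyai's FileNotFoundError shadows Python's builtin of the same name in the caller's scope, and because it derives from Exception rather than OSError it is a distinct type from the builtin:

from ezpyai.exceptions import FileNotFoundError  # shadows the builtin here
from ezpyai.llm.dataset.sources.telegram import DatasetSourceTelegram

try:
    source = DatasetSourceTelegram("exports/telegram_export.json")
except FileNotFoundError as e:
    print(f"Telegram export not found: {e}")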
19 changes: 0 additions & 19 deletions src/ezpyai/llm/exceptions.py

This file was deleted.

17 changes: 9 additions & 8 deletions src/ezpyai/llm/knowledge/_knowledge_gatherer.py
@@ -7,7 +7,12 @@
 import hashlib
 import pandas as pd
 import xml.etree.ElementTree as ET
-import ezpyai.exceptions as exceptions
+
+from ezpyai.exceptions import (
+    UnsupportedFileTypeError,
+    FileReadError,
+    FileProcessingError,
+)
 
 from bs4 import BeautifulSoup
 from typing import Dict
@@ -184,13 +189,9 @@ def _process_file(self, file_path: str):
                 self._process_zip(file_path)
                 return
             else:
-                raise exceptions.UnsupportedFileTypeError(
-                    f"Unsupported file type for {file_path}"
-                )
+                raise UnsupportedFileTypeError(f"Unsupported file type for {file_path}")
         except Exception as e:
-            raise exceptions.FileReadError(
-                f"Error reading {file_path}: {str(e)}"
-            ) from e
+            raise FileReadError(f"Error reading {file_path}: {str(e)}") from e
 
         paragraphs = content.split("\n")
         paragraph_counter = 1
@@ -235,7 +236,7 @@ def _process_zip(self, zip_path: str):
 
             self._process_directory(temp_dir)
         except Exception as e:
-            raise exceptions.FileProcessingError(
+            raise FileProcessingError(
                 f"Error processing ZIP file {zip_path}: {str(e)}"
             ) from e
         finally:
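
All three raises above keep the from e suffix, so the low-level error survives as __cause__ and the full traceback is preserved. A self-contained sketch of that chaining pattern, using a hypothetical helper rather than the gatherer itself:

from ezpyai.exceptions import FileReadError

def read_text(path: str) -> str:
    # Wrap any low-level failure in the domain exception; "from e" keeps
    # the original error reachable via __cause__ in tracebacks.
    try:
        with open(path, encoding="utf-8") as f:
            return f.read()
    except Exception as e:
        raise FileReadError(f"Error reading {path}: {str(e)}") from e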
24 changes: 12 additions & 12 deletions src/ezpyai/llm/providers/openai.py
@@ -1,12 +1,17 @@
 import os
-import ezpyai.llm.exceptions as exceptions
 
-from typing import Annotated
+from typing import List, Dict
 from openai import OpenAI as _OpenAI
 from ezpyai._logger import logger
 from ezpyai.llm.providers._llm_provider import BaseLLMProvider
 from ezpyai.llm.prompt import Prompt
 
+from ezpyai.exceptions import (
+    PromptUserMessageMissingError,
+    LLMInferenceError,
+    LLMResponseEmptyError,
+)
+
 from ezpyai._constants import (
     ENV_VAR_NAME_OPENAI_API_KEY,
     ENV_VAR_NAME_OPENAI_ORGANIZATION,
@@ -91,9 +96,7 @@ def _get_system_message(self, message: str) -> dict:
     def _get_user_message(self, message: str) -> dict:
         return {"role": "user", "content": message}
 
-    def _prompt_to_messages(
-        self, prompt: Prompt
-    ) -> Annotated[list[dict], "Raises exceptions.NoUserMessage"]:
+    def _prompt_to_messages(self, prompt: Prompt) -> List[Dict]:
         messages = []
         if prompt.has_system_message():
             messages.append(self._get_system_message(prompt.get_system_message()))
@@ -102,16 +105,13 @@ def _prompt_to_messages(
             messages.append(self._get_user_message(prompt.get_context_as_string()))
 
         if not prompt.has_user_message():
-            raise exceptions.NoUserMessage()
+            raise PromptUserMessageMissingError()
 
         messages.append(self._get_user_message(prompt.get_user_message()))
 
         return messages
 
-    def get_response(self, prompt: Prompt) -> Annotated[
-        str,
-        "Raises exceptions.NoUserMessage, exceptions.NoLLMResponseMessage, exceptions.InvokeError",
-    ]:
+    def get_response(self, prompt: Prompt) -> str:
         messages = self._prompt_to_messages(prompt)
 
         try:
@@ -124,9 +124,9 @@ def get_response(self, prompt: Prompt) -> Annotated[
                 messages=messages,
             )
         except Exception as e:
-            raise exceptions.InvokeError() from e
+            raise LLMInferenceError() from e
 
         if not response.choices:
-            raise exceptions.NoLLMResponseMessage()
+            raise LLMResponseEmptyError()
 
         return response.choices[0].message.content
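
With the Annotated[...] "Raises ..." hints removed, the error surface of get_response is now expressed only through the exception types, so callers handle them directly. A hypothetical wrapper (provider and prompt are assumed to be constructed elsewhere):

from ezpyai.exceptions import (
    PromptUserMessageMissingError,
    LLMInferenceError,
    LLMResponseEmptyError,
)

def ask(provider, prompt):
    # provider: any BaseLLMProvider (e.g. the OpenAI provider above);
    # prompt: an ezpyai.llm.prompt.Prompt instance.
    try:
        return provider.get_response(prompt)
    except PromptUserMessageMissingError:
        return None  # prompt was built without a user message
    except (LLMInferenceError, LLMResponseEmptyError) as e:
        raise RuntimeError(f"OpenAI call failed: {e}") from e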
