Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
psyb0t committed Jun 18, 2024
1 parent 365d614 commit bd87846
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 11 deletions.
22 changes: 14 additions & 8 deletions src/ezpyai/llm/knowledge/_knowledge_gatherer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ezpyai._logger import logger
from ezpyai._constants import _DICT_KEY_SUMMARY
from ezpyai.llm._llm import LLM
from ezpyai.llm.prompt import Prompt, SUMMARIZER_SYSTEM_MESSAGE
from ezpyai.llm.prompt import Prompt, get_summarizer_prompt
from ezpyai.llm.knowledge.knowledge_item import KnowledgeItem

_MIMETYPE_TEXT = "text/plain"
Expand All @@ -42,6 +42,7 @@ class KnowledgeGatherer:
Attributes:
_items (Dict[str, KnowledgeItem]): A dictionary containing file paths
and their processed content indexed by SHA256 hashes of the content.
_summarizer (LLM): The LLM summarizer to use for knowledge collection.
"""

def __init__(self, summarizer: LLM = None) -> None:
Expand Down Expand Up @@ -81,7 +82,7 @@ def _get_knowledge_item_from_file_paragraph(
file_name = os.path.splitext(os.path.basename(file_path))[0]
file_ext = os.path.splitext(file_path)[1]

return KnowledgeItem(
knowledge_item = KnowledgeItem(
id=id,
content=paragraph,
metadata={
Expand All @@ -92,7 +93,17 @@ def _get_knowledge_item_from_file_paragraph(
},
)

self._summarize(knowledge_item)

return knowledge_item

def _summarize(self, knowledge_item: KnowledgeItem) -> None:
"""
Summarize the given knowledge item.
Args:
knowledge_item (KnowledgeItem): The knowledge item to summarize.
"""
if self._summarizer is None:
return ""

Expand All @@ -101,10 +112,7 @@ def _summarize(self, knowledge_item: KnowledgeItem) -> None:

logger.debug(f"Summarizing knowledge item: {knowledge_item}")

prompt: Prompt = Prompt(
system_message=SUMMARIZER_SYSTEM_MESSAGE,
user_message=f"Summarize the following text: {knowledge_item.content}",
)
prompt: Prompt = get_summarizer_prompt(knowledge_item.content)

knowledge_item.summary = self._summarizer.get_structured_response(
prompt, response_format={_DICT_KEY_SUMMARY: ""}
Expand Down Expand Up @@ -193,8 +201,6 @@ def _process_file(self, file_path: str):
paragraph_number=paragraph_counter,
)

self._summarize(knowledge_item)

self._items[knowledge_item.id] = knowledge_item

logger.debug(f"Added {knowledge_item.id} to _items dictionary")
Expand Down
4 changes: 1 addition & 3 deletions src/ezpyai/llm/knowledge/chroma_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def store(self, collection: str, data_path: str, summarizer: LLM = None) -> None
knowledge_gatherer.gather(data_path)
knowledge_items = knowledge_gatherer.get_items()

logger.debug(
f"Collected the following knowledge items: \n{json.dumps(knowledge_items, indent=2)}"
)
logger.debug(f"Collected knowlege items: {knowledge_items}")

collection: chromadb.Collection = self._client.get_or_create_collection(
name=collection,
Expand Down
26 changes: 26 additions & 0 deletions src/ezpyai/llm/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@


class Prompt:
"""
A class to store prompt data.
Attributes:
system_message (str): The system message of the prompt.
user_message (str): The user message of the prompt.
context (List[str]): The context of the prompt.
"""

def __init__(
self,
user_message: str,
Expand Down Expand Up @@ -60,3 +69,20 @@ def get_user_message(self) -> str:

def set_user_message(self, user_message: str) -> None:
self._user_message = user_message


def get_summarizer_prompt(to_summarize: str) -> Prompt:
    """
    Build the prompt that asks for a summary of the given text.

    Args:
        to_summarize (str): The text to summarize.

    Returns:
        Prompt: A prompt whose system message is SUMMARIZER_SYSTEM_MESSAGE
            and whose user message requests a summary of ``to_summarize``.
    """
    # The text is interpolated verbatim into the user-facing request.
    summarize_request = f"Summarize the following text: {to_summarize}"

    return Prompt(
        system_message=SUMMARIZER_SYSTEM_MESSAGE,
        user_message=summarize_request,
    )

0 comments on commit bd87846

Please sign in to comment.