Skip to content

Commit

Permalink
telegram source to chat dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
psyb0t committed Jul 7, 2024
1 parent 58cd374 commit 5a6e799
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 115 deletions.
1 change: 1 addition & 0 deletions src/ezpyai/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@
from ezpyai.constants._http_methods import *
from ezpyai.constants._names import *
from ezpyai.constants._chat_ids import *
from ezpyai.constants._chat_roles import *

LIB_NAME: str = "ezpyai"
3 changes: 3 additions & 0 deletions src/ezpyai/constants/_chat_roles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
CHAT_ROLE_SYSTEM: str = "system"
CHAT_ROLE_USER: str = "user"
CHAT_ROLE_ASSISTANT: str = "assistant"
1 change: 1 addition & 0 deletions src/ezpyai/constants/_dict_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DICT_KEY_TYPE: str = "type"
DICT_KEY_NAME: str = "name"
DICT_KEY_MESSAGES: str = "messages"
DICT_KEY_NUM_MESSAGES: str = "num_messages"
DICT_KEY_DATE: str = "date"
DICT_KEY_DATE_UNIXTIME: str = "date_unixtime"
DICT_KEY_TEXT: str = "text"
Expand Down
2 changes: 1 addition & 1 deletion src/ezpyai/dataset/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ezpyai.dataset.chat.chat import DatasetChat
from ezpyai.dataset.chat.dataset_chat import DatasetChat, DatasetChatEntry # type: ignore
49 changes: 0 additions & 49 deletions src/ezpyai/dataset/chat/chat.py

This file was deleted.

21 changes: 21 additions & 0 deletions src/ezpyai/dataset/chat/dataset_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import List


class DatasetChatEntry:
def __init__(self, role: str, content: str) -> None:
self.role: str = role
self.content: str = content

def __str__(self) -> str:
return f"{self.role}: {self.content}"


class DatasetChat:
def __init__(self, entries: List[DatasetChatEntry] | None = None) -> None:
if entries is None:
entries = []

self.entries: List[DatasetChatEntry] = entries

def add_entry(self, role: str, content: str) -> None:
self.entries.append(DatasetChatEntry(role=role, content=content))
2 changes: 1 addition & 1 deletion src/ezpyai/dataset/chat/sources/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ezpyai.dataset.chat.sources.telegram import DatasetSourceTelegram
from ezpyai.dataset.chat.sources.telegram import DatasetSourceTelegram # type: ignore
10 changes: 10 additions & 0 deletions src/ezpyai/dataset/chat/sources/_dataset_source.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from abc import ABC, abstractmethod
from typing import List

from ezpyai.dataset.chat import DatasetChat


class DatasetSource(ABC):
@abstractmethod
def to_dataset_chats(self, system_message_tpl: str) -> List[DatasetChat]:
pass
68 changes: 53 additions & 15 deletions src/ezpyai/dataset/chat/sources/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,38 @@
import json
from typing import List, Dict, Any
from datetime import datetime
from jinja2 import Template

from ezpyai._logger import logger
from ezpyai.exceptions import FileNotFoundError, JSONParseError
from ezpyai.dataset.chat.sources._dataset_source import DatasetSource
from ezpyai.dataset.chat import DatasetChat, DatasetChatEntry

from ezpyai.constants import (
DICT_KEY_CHATS,
DICT_KEY_LIST,
DICT_KEY_TYPE,
DICT_KEY_ID,
DICT_KEY_NAME,
DICT_KEY_MESSAGES,
DICT_KEY_NUM_MESSAGES,
DICT_KEY_DATE,
DICT_KEY_DATE_UNIXTIME,
DICT_KEY_TEXT,
DICT_KEY_TEXT_ENTITIES,
DICT_KEY_FROM,
DICT_KEY_FROM_ID,
NAME_UNKNOWN,
CHAT_ID_TELEGRAM,
CHAT_ROLE_SYSTEM,
CHAT_ROLE_USER,
CHAT_ROLE_ASSISTANT,
)

_TELEGRAM_CHAT_TYPE_PERSONAL: str = "personal_chat"
_TELEGRAM_MESSAGE_TYPE_MESSAGE: str = "message"


# dict_keys(['id', 'type', 'date', 'date_unixtime', 'from', 'from_id', 'text', 'text_entities'])
class _TelegramChatMessage:
def __init__(
self,
Expand All @@ -49,13 +57,13 @@ def __str__(self) -> str:

def to_dict(self) -> Dict[str, Any]:
return {
"id": self.id,
"type": self.type,
"date": self.date,
"date_unixtime": self.date_unixtime,
"from": self.from_name,
"from_id": self.from_id,
"text": self.text,
DICT_KEY_ID: self.id,
DICT_KEY_TYPE: self.type,
DICT_KEY_DATE: self.date,
DICT_KEY_DATE_UNIXTIME: self.date_unixtime,
DICT_KEY_FROM: self.from_name,
DICT_KEY_FROM_ID: self.from_id,
DICT_KEY_TEXT: self.text,
}


Expand All @@ -72,20 +80,20 @@ def __init__(

def __str__(self) -> str:
dict_repr = self.to_dict()
dict_repr["num_messages"] = len(dict_repr["messages"])
dict_repr.pop("messages")
dict_repr[DICT_KEY_NUM_MESSAGES] = len(dict_repr[DICT_KEY_MESSAGES])
dict_repr.pop(DICT_KEY_MESSAGES)

return f"{self.__class__.__name__}: {dict_repr}"

def to_dict(self) -> Dict[str, Any]:
return {
"id": self.id,
"name": self.name,
"messages": [message.to_dict() for message in self.messages],
DICT_KEY_ID: self.id,
DICT_KEY_NAME: self.name,
DICT_KEY_MESSAGES: [message.to_dict() for message in self.messages],
}


class DatasetSourceTelegram:
class DatasetSourceTelegram(DatasetSource):
def __init__(
self,
json_export_file_path: str,
Expand All @@ -105,7 +113,7 @@ def __init__(
def __str__(self) -> str:
return f"{self.__class__.__name__}(json_export_file_path={self._json_export_file_path}, assistant_from_id={self._assistant_from_id}, entries={len(self._chats)})"

def get_chats(
def _get_chats(
self,
with_zero_messages: bool = True,
) -> List[_TelegramChat]:
Expand Down Expand Up @@ -260,3 +268,33 @@ def _get_processed_message(self, message: Dict[str, Any]) -> _TelegramChatMessag
message_from_id,
message_text,
)

def to_dataset_chats(self, system_message_tpl: str = "") -> List[DatasetChat]:
dataset_chats: List[DatasetChat] = []
chats = self._get_chats(with_zero_messages=False)

for chat in chats:
dataset_chat_entries: List[DatasetChatEntry] = []
system_message_is_set = False
for message in chat.messages:
if not system_message_is_set and system_message_tpl:
template = Template(system_message_tpl)
content = template.render(chat=chat, message=message)

dataset_chat_entries.append(
DatasetChatEntry(role=CHAT_ROLE_SYSTEM, content=content)
)

system_message_is_set = True

role: str = CHAT_ROLE_USER
if message.from_id == self._assistant_from_id:
role = CHAT_ROLE_ASSISTANT

dataset_chat_entries.append(
DatasetChatEntry(role=role, content=message.text)
)

dataset_chats.append(DatasetChat(dataset_chat_entries))

return dataset_chats
49 changes: 0 additions & 49 deletions src/ezpyai/llm/providers/_llm_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,13 @@ def get_structured_response(


class BaseLLMProvider(LLMProvider):
"""
Base class for LLM providers.
"""

@abstractmethod
def get_response(self, prompt: Prompt) -> str:
"""
Get the response for the given prompt.
Args:
prompt (Prompt): The input prompt.
Returns:
str: The response.
"""

return ""

def _validate_response_format(
self, data: Any, response_format: Dict[Any, Any] | List[Any]
) -> bool:
"""
Validate the response format.
Args:
data (Any): The data to validate.
response_format (Union[Dict, List]): The response format.
Returns:
bool: True if the response format is valid, False otherwise.
"""

if not response_format:
return True

Expand All @@ -75,16 +50,6 @@ def _validate_response_format(
)

def remove_artifacts(self, response: str) -> str:
"""
Remove artifacts from the response.
Args:
response (str): The response to remove artifacts from.
Returns:
str: The response without artifacts.
"""

artifacts = ["```json", "```"]
for artifact in artifacts:
response = response.replace(artifact, "")
Expand All @@ -94,20 +59,6 @@ def remove_artifacts(self, response: str) -> str:
def get_structured_response(
self, prompt: Prompt, response_format: Dict[Any, Any] | List[Any]
) -> Dict[Any, Any] | List[Any] | None:
"""
Get the structured response for the given prompt and response format.
Args:
prompt (Prompt): The input prompt.
response_format Dict[Any, Any] | List[Any]: The response format.
Returns:
Dict[Any, Any] | List[Any]: The structured response.
Raises:
JSONParseError: If the response cannot be parsed as JSON.
"""

prompt = Prompt(
user_message=prompt.get_user_message(),
context=prompt.get_context(),
Expand Down

0 comments on commit 5a6e799

Please sign in to comment.