
chore: create LlamaClient class
cbrzn committed Jun 17, 2024
1 parent 4ee9fa9 commit 1c994c8
Showing 5 changed files with 133 additions and 5 deletions.
11 changes: 9 additions & 2 deletions autotx/AutoTx.py
@@ -88,6 +88,9 @@ def __init__(
         config: Config,
         on_notify_user: Callable[[str], None] | None = None
     ):
+        if len(agents) == 0:
+            raise Exception("Agents attribute can not be an empty list")
+
         self.web3 = web3
         self.wallet = wallet
         self.network = network
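
This guard makes an empty agent list fail at construction time instead of surfacing later during a run. A minimal sketch of the new behavior (the remaining constructor arguments are hypothetical placeholders; the full signature is not shown in this diff):

    from autotx import AutoTx

    try:
        # agents=[] now raises immediately (other arguments are placeholders, not the real signature).
        AutoTx(web3, wallet, network, agents=[], config=config)
    except Exception as e:
        print(e)  # "Agents attribute can not be an empty list"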
@@ -195,9 +198,13 @@ async def try_run(self, prompt: str, non_interactive: bool, summary_method: str
             helper_agents.append(clarifier_agent)

         autogen_agents = [agent.build_autogen_agent(self, user_proxy_agent, self.get_llm_config(), self.custom_model) for agent in self.agents]
-        manager_agent = manager.build(autogen_agents + helper_agents, self.max_rounds, not non_interactive, self.get_llm_config, self.custom_model)

-        recipient_agent = manager_agent if len(autogen_agents) > 1 else autogen_agents[0]
+        recipient_agent = None
+        if len(autogen_agents) > 1:
+            recipient_agent = manager.build(autogen_agents + helper_agents, self.max_rounds, not non_interactive, self.get_llm_config, self.custom_model)
+        else:
+            recipient_agent = autogen_agents[0]

         chat = await user_proxy_agent.a_initiate_chat(
             recipient_agent,
             message=dedent(
3 changes: 2 additions & 1 deletion autotx/__init__.py
@@ -1,5 +1,6 @@
 from autotx.AutoTx import AutoTx
 from autotx.autotx_agent import AutoTxAgent
 from autotx.autotx_tool import AutoTxTool
+from autotx.utils.LlamaClient import LlamaClient

-__all__ = ['AutoTx', 'AutoTxAgent', 'AutoTxTool']
+__all__ = ['AutoTx', 'AutoTxAgent', 'AutoTxTool', 'LlamaClient']
96 changes: 96 additions & 0 deletions autotx/utils/LlamaClient.py
@@ -0,0 +1,96 @@
from types import SimpleNamespace
from typing import Any, Dict, Union, cast
from autogen import ModelClient
from llama_cpp import (
    ChatCompletion,
    ChatCompletionRequestAssistantMessage,
    ChatCompletionRequestFunctionMessage,
    ChatCompletionRequestMessage,
    ChatCompletionRequestToolMessage,
    ChatCompletionResponseMessage,
    Completion,
    CreateChatCompletionResponse,
    Llama,
)


class LlamaClient(ModelClient):  # type: ignore
    def __init__(self, _: dict[str, Any], **args: Any):
        self.llm: Llama = args["llm"]

    def create(self, params: Dict[str, Any]) -> SimpleNamespace:
        sanitized_messages = self._sanitize_chat_completion_messages(
            cast(list[ChatCompletionRequestMessage], params.get("messages"))
        )
        response = self.llm.create_chat_completion(
            messages=sanitized_messages,
            tools=params.get("tools"),
            model=params.get("model"),
        )

        return SimpleNamespace(**{**response, "cost": "0"})  # type: ignore

    def message_retrieval(
        self, response: CreateChatCompletionResponse
    ) -> list[ChatCompletionResponseMessage]:
        choices = response["choices"]
        return [choice["message"] for choice in choices]

    def cost(self, _: Union[ChatCompletion, Completion]) -> float:
        return 0.0

    def get_usage(self, _: Union[ChatCompletion, Completion]) -> dict[str, Any]:
        return {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "cost": 0,
            "model": "meetkai/functionary-small-v2.4-GGUF",
        }

    def _sanitize_chat_completion_messages(
        self, messages: list[ChatCompletionRequestMessage]
    ) -> list[ChatCompletionRequestMessage]:
        sanitized_messages: list[ChatCompletionRequestMessage] = []

        for message in messages:
            if "tool_calls" in message:
                function_to_call = message["tool_calls"][0]  # type: ignore
                sanitized_messages.append(
                    ChatCompletionRequestAssistantMessage(
                        role="assistant",
                        function_call=function_to_call["function"],
                        content=None,
                    )
                )
            elif "tool_call_id" in message:
                id: str = cast(ChatCompletionRequestToolMessage, message)[
                    "tool_call_id"
                ]

                def get_tool_name(messages, id: str) -> Union[str, None]:  # type: ignore
                    return next(
                        (
                            message["tool_calls"][0]["function"]["name"]
                            for message in messages
                            if "tool_calls" in message
                            and message["tool_calls"][0]["id"] == id
                        ),
                        None,
                    )

                function_name = get_tool_name(messages, id)
                if function_name is None:
                    raise Exception(f"No tool response for this tool call with id {id}")

                sanitized_messages.append(
                    ChatCompletionRequestFunctionMessage(
                        role="function",
                        name=function_name,
                        content=cast(Union[str | None], message["content"]),
                    )
                )
            else:
                sanitized_messages.append(message)

        return sanitized_messages
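
LlamaClient adapts a llama_cpp Llama instance to autogen's ModelClient protocol: create() forwards chat messages (after rewriting tool-call messages into function-call form) and message_retrieval() pulls the assistant messages back out. A minimal usage sketch, assuming a locally downloaded GGUF model (the model path, label, and prompt are placeholders):

    from llama_cpp import Llama
    from autotx import LlamaClient

    # Placeholder path to a local GGUF model such as functionary-small-v2.4 (assumption).
    llm = Llama(model_path="./model.gguf")
    client = LlamaClient({}, llm=llm)

    response = client.create({
        "messages": [{"role": "user", "content": "Say hello"}],
        "tools": None,
        "model": "functionary-small-v2.4",
    })
    # create() returns a SimpleNamespace built from the llama_cpp response dict.
    print(response.choices[0]["message"]["content"])

AutoTx already threads self.custom_model into build_autogen_agent (visible in the AutoTx.py hunk above), which appears to be where a client like this gets wired into the agents.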
25 changes: 24 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "autotx"
-version = "0.1.1-beta.1"
+version = "0.1.1-beta.2"
 description = ""
 authors = ["Nestor Amesty <nestor09amesty@gmail.com>"]
 readme = "README.md"
@@ -19,6 +19,7 @@ web3 = "^6.19.0"
 safe-eth-py = "^5.8.0"
 uvicorn = "^0.29.0"
 supabase = "^2.5.0"
+llama-cpp-python = "^0.2.78"

 [tool.poetry.group.dev.dependencies]
 mypy = "^1.8.0"
