From b6ba03c033a2b23e9a881f20f0fd37917012804a Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Thu, 14 Dec 2023 18:11:07 +0000 Subject: [PATCH 01/23] Support native tools (initial commit) --- .../application/assistant_application.py | 130 ++++++++++++-- aidial_assistant/chain/command_chain.py | 8 +- aidial_assistant/model/model_client.py | 154 ++++++++++++----- aidial_assistant/tools_chain/__init__.py | 0 aidial_assistant/tools_chain/addon_runner.py | 96 +++++++++++ aidial_assistant/tools_chain/http_runner.py | 23 +++ aidial_assistant/tools_chain/tool_runner.py | 11 ++ aidial_assistant/tools_chain/tools_chain.py | 99 +++++++++++ aidial_assistant/utils/exceptions.py | 22 +-- poetry.lock | 163 +++++++++++++++--- pyproject.toml | 6 +- .../chain/test_command_chain_best_effort.py | 4 +- tests/unit_tests/model/test_model_client.py | 47 +++-- tests/unit_tests/tools_chain/__init__.py | 0 .../utils/test_exception_handler.py | 29 ++-- 15 files changed, 644 insertions(+), 148 deletions(-) create mode 100644 aidial_assistant/tools_chain/__init__.py create mode 100644 aidial_assistant/tools_chain/addon_runner.py create mode 100644 aidial_assistant/tools_chain/http_runner.py create mode 100644 aidial_assistant/tools_chain/tool_runner.py create mode 100644 aidial_assistant/tools_chain/tools_chain.py create mode 100644 tests/unit_tests/tools_chain/__init__.py diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index d4a5598..5ef2a47 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -1,11 +1,21 @@ import logging +import os from pathlib import Path +from typing_extensions import override from aidial_sdk.chat_completion import FinishReason from aidial_sdk.chat_completion.base import ChatCompletion -from aidial_sdk.chat_completion.request import Addon, Message, Request, Role +from aidial_sdk.chat_completion.request import ( + Addon, + Message as SdkMessage, + Request, + Role, +) from aidial_sdk.chat_completion.response import Response from aiohttp import hdrs +from openai import AsyncOpenAI +from openai._types import Omit +from openai.lib.azure import AsyncAzureOpenAI from aidial_assistant.application.addons_dialogue_limiter import ( AddonsDialogueLimiter, @@ -25,7 +35,11 @@ from aidial_assistant.model.model_client import ( ModelClient, ReasonLengthException, + Tool, + Message, ) +from aidial_assistant.tools_chain.addon_runner import AddonRunner +from aidial_assistant.tools_chain.tools_chain import ToolsChain from aidial_assistant.utils.exceptions import ( RequestParameterValidationError, unhandled_exception_handler, @@ -44,8 +58,6 @@ def _get_request_args(request: Request) -> dict[str, str]: args = { "model": request.model, "temperature": request.temperature, - "api_version": request.api_version, - "api_key": request.api_key, "user": request.user, "headers": None if request.jwt is None @@ -65,7 +77,7 @@ def _validate_addons(addons: list[Addon] | None): ) -def _validate_messages(messages: list[Message]) -> None: +def _validate_messages(messages: list[SdkMessage]) -> None: if not messages: raise RequestParameterValidationError( "Message list cannot be empty.", param="messages" @@ -82,6 +94,36 @@ def _validate_request(request: Request) -> None: _validate_addons(request.addons) +def _construct_function(name: str, description: str) -> Tool: + return { + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": { + "type": 
"object", + "properties": { + "query": { + "type": "string", + "description": "A task written in natural language", + } + }, + "required": ["query"], + }, + }, + } + + +class MyClient(AsyncAzureOpenAI): + @property + @override + def default_headers(self) -> dict[str, str | Omit]: + headers = super().default_headers + del headers["Authorization"] + + return headers + + class AssistantApplication(ChatCompletion): def __init__(self, config_dir: Path): self.args = parse_args(config_dir) @@ -91,16 +133,15 @@ async def chat_completion( self, request: Request, response: Response ) -> None: _validate_request(request) - chat_args = self.args.openai_conf.dict() | _get_request_args(request) + chat_args = _get_request_args(request) model = ModelClient( - model_args=chat_args - | { - "deployment_id": chat_args["model"], - "api_type": "azure", - "stream": True, - }, - buffer_size=self.args.chat_conf.buffer_size, + client=MyClient( + azure_endpoint=self.args.openai_conf.api_base, + api_key=request.api_key, + api_version="2023-12-01-preview", + ), + model_args=chat_args, ) addons: list[str] = ( @@ -109,7 +150,6 @@ async def chat_completion( token_source = AddonTokenSource(request.headers, addons) tools: dict[str, PluginInfo] = {} - tool_descriptions: dict[str, str] = {} for addon in addons: info = await get_open_ai_plugin_info(addon) tools[info.ai_plugin.name_for_model] = PluginInfo( @@ -122,10 +162,29 @@ async def chat_completion( ), ) - tool_descriptions[info.ai_plugin.name_for_model] = ( - info.open_api.info.description # type: ignore - or info.ai_plugin.description_for_human + if request.model in {"gpt-4-turbo-1106", "gpt-4-1106-preview"}: + await AssistantApplication._run_native_tools_chat( + model, tools, request, response + ) + else: + await AssistantApplication._run_emulated_tools_chat( + model, tools, request, response + ) + + @staticmethod + async def _run_emulated_tools_chat( + model: ModelClient, + tools: dict[str, PluginInfo], + request: Request, + response: Response, + ): + tool_descriptions = { + k: ( + v.info.open_api.info.description + or v.info.ai_plugin.description_for_human ) + for k, v in tools.items() + } # TODO: Add max_addons_dialogue_tokens as a request parameter max_addons_dialogue_tokens = 1000 @@ -179,3 +238,42 @@ async def chat_completion( if discarded_messages is not None: response.set_discarded_messages(discarded_messages) + + @staticmethod + async def _run_native_tools_chat( + model: ModelClient, + tools: dict[str, PluginInfo], + request: Request, + response: Response, + ): + chain = ToolsChain( + model, + [ + _construct_function(k, v.info.ai_plugin.description_for_human) + for k, v in tools.items() + ], + AddonRunner(model, tools), + ) + + choice = response.create_single_choice() + choice.open() + + callback = AssistantChainCallback(choice) + finish_reason = FinishReason.STOP + messages = [ + Message( + role=message.role, + content=message.content or "", + ) + for message in request.messages + ] + try: + await chain.run_chat(messages, callback) + except ReasonLengthException: + finish_reason = FinishReason.LENGTH + + choice.close(finish_reason) + + response.set_usage( + model.total_prompt_tokens, model.total_completion_tokens + ) diff --git a/aidial_assistant/chain/command_chain.py b/aidial_assistant/chain/command_chain.py index 4ae0e8b..acde231 100644 --- a/aidial_assistant/chain/command_chain.py +++ b/aidial_assistant/chain/command_chain.py @@ -4,7 +4,7 @@ from typing import Any, AsyncIterator, Callable, Tuple, cast from aidial_sdk.chat_completion.request import Role 
-from openai import InvalidRequestError +from openai import BadRequestError from aidial_assistant.application.prompts import ENFORCE_JSON_FORMAT_TEMPLATE from aidial_assistant.chain.callbacks.chain_callback import ChainCallback @@ -112,9 +112,9 @@ async def run_chat( else history.to_user_messages() ) await self._generate_result(messages, callback) - except (InvalidRequestError, LimitExceededException) as e: + except (BadRequestError, LimitExceededException) as e: if dialogue.is_empty() or ( - isinstance(e, InvalidRequestError) and e.code == "429" + isinstance(e, BadRequestError) and e.code == "429" ): raise @@ -187,7 +187,7 @@ async def _run_with_protocol_failure_retries( ) finally: self._log_message(Role.ASSISTANT, chunk_stream.buffer) - except (InvalidRequestError, LimitExceededException) as e: + except (BadRequestError, LimitExceededException) as e: if last_error: # Retries can increase the prompt size, which may lead to token overflow. # Thus, if the original error was a protocol error, it should be thrown instead. diff --git a/aidial_assistant/model/model_client.py b/aidial_assistant/model/model_client.py index cb5499b..c14a558 100644 --- a/aidial_assistant/model/model_client.py +++ b/aidial_assistant/model/model_client.py @@ -1,9 +1,11 @@ from abc import ABC +from collections import defaultdict from typing import Any, AsyncIterator, List, TypedDict -import openai from aidial_sdk.chat_completion import Role -from aiohttp import ClientSession +from aidial_sdk.utils.merge_chunks import merge +from openai import AsyncOpenAI +from openai.lib.azure import AsyncAzureOpenAI from pydantic import BaseModel @@ -13,10 +15,34 @@ class ReasonLengthException(Exception): class Message(BaseModel): role: Role - content: str + content: str = "" + name: str = "" + tool_call_id: str = "" + tool_calls: list[dict[str, Any]] = [] def to_openai_message(self) -> dict[str, str]: - return {"role": self.role.value, "content": self.content} + return ( + { + "role": self.role.value, + "content": self.content, + } + | ( + { + "role": "tool", + "name": self.name, + "tool_call_id": self.tool_call_id, + } + if self.name + else {} + ) + | ( + { + "tool_calls": self.tool_calls, + } + if self.tool_calls + else {} + ) + ) @classmethod def system(cls, content): @@ -36,6 +62,35 @@ class Usage(TypedDict): completion_tokens: int +class Parameters(TypedDict): + type: str + properties: dict[str, Any] + required: list[str] + + +class Function(TypedDict): + name: str + description: str + parameters: Parameters + + +class Tool(TypedDict): + type: str + function: Function + + +class FunctionCall(TypedDict): + name: str + arguments: str + + +class ToolCall(TypedDict): + index: int + id: str + type: str + function: FunctionCall + + class ExtraResultsCallback: def on_discarded_messages(self, discarded_messages: int): pass @@ -43,6 +98,9 @@ def on_discarded_messages(self, discarded_messages: int): def on_prompt_tokens(self, prompt_tokens: int): pass + def on_tool_calls(self, tool_calls: list[ToolCall]): + pass + async def _flush_stream(stream: AsyncIterator[str]): try: @@ -53,13 +111,9 @@ async def _flush_stream(stream: AsyncIterator[str]): class ModelClient(ABC): - def __init__( - self, - model_args: dict[str, Any], - buffer_size: int, - ): + def __init__(self, client: AsyncOpenAI, model_args: dict[str, Any]): + self.client = client self.model_args = model_args - self.buffer_size = buffer_size self._total_prompt_tokens: int = 0 self._total_completion_tokens: int = 0 @@ -70,43 +124,53 @@ async def agenerate( extra_results_callback: 
ExtraResultsCallback | None = None, **kwargs, ) -> AsyncIterator[str]: - async with ClientSession(read_bufsize=self.buffer_size) as session: - openai.aiosession.set(session) - - model_result = await openai.ChatCompletion.acreate( - messages=[message.to_openai_message() for message in messages], - **self.model_args | kwargs, - ) - - finish_reason_length = False - async for chunk in model_result: # type: ignore - usage: Usage | None = chunk.get("usage") - if usage: - prompt_tokens = usage["prompt_tokens"] - self._total_prompt_tokens += prompt_tokens - self._total_completion_tokens += usage["completion_tokens"] - if extra_results_callback: - extra_results_callback.on_prompt_tokens(prompt_tokens) + model_result = await self.client.chat.completions.create( + **self.model_args, + extra_body=kwargs, + stream=True, + messages=[message.to_openai_message() for message in messages], + ) + finish_reason_length = False + tool_calls_chunks = [] + async for chunk in model_result: # type: ignore + chunk = chunk.dict() + usage: Usage | None = chunk.get("usage") + if usage: + prompt_tokens = usage["prompt_tokens"] + self._total_prompt_tokens += prompt_tokens + self._total_completion_tokens += usage["completion_tokens"] if extra_results_callback: - discarded_messages: int | None = chunk.get( - "statistics", {} - ).get("discarded_messages") - if discarded_messages is not None: - extra_results_callback.on_discarded_messages( - discarded_messages - ) - - choice = chunk["choices"][0] - text = choice["delta"].get("content") - if text: - yield text - - if choice.get("finish_reason") == "length": - finish_reason_length = True - - if finish_reason_length: - raise ReasonLengthException() + extra_results_callback.on_prompt_tokens(prompt_tokens) + + if extra_results_callback: + discarded_messages: int | None = chunk.get( + "statistics", {} + ).get("discarded_messages") + if discarded_messages is not None: + extra_results_callback.on_discarded_messages( + discarded_messages + ) + + choice = chunk["choices"][0] + delta = choice["delta"] + text = delta.get("content") + if text: + yield text + + tool_calls_chunk = delta.get("tool_calls") + if tool_calls_chunk: + tool_calls_chunks.append(tool_calls_chunk) + + if choice.get("finish_reason") == "length": + finish_reason_length = True + + if finish_reason_length: + raise ReasonLengthException() + + if extra_results_callback and tool_calls_chunks: + tool_calls: list[ToolCall] = merge(*tool_calls_chunks) + extra_results_callback.on_tool_calls(tool_calls) # TODO: Use a dedicated endpoint for counting tokens. # This request may throw an error if the number of tokens is too large. 
diff --git a/aidial_assistant/tools_chain/__init__.py b/aidial_assistant/tools_chain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aidial_assistant/tools_chain/addon_runner.py b/aidial_assistant/tools_chain/addon_runner.py new file mode 100644 index 0000000..c8debc2 --- /dev/null +++ b/aidial_assistant/tools_chain/addon_runner.py @@ -0,0 +1,96 @@ +from typing import Any + +from langchain_community.tools.openapi.utils.api_models import ( + APIOperation, + APIPropertyBase, +) + +from aidial_assistant.commands.base import ( + ExecutionCallback, + ResultObject, + TextResult, +) +from aidial_assistant.commands.plugin_callback import PluginChainCallback +from aidial_assistant.commands.run_plugin import PluginInfo +from aidial_assistant.model.model_client import ( + ModelClient, + Message, + ReasonLengthException, + Tool, +) +from aidial_assistant.open_api.operation_selector import collect_operations +from aidial_assistant.tools_chain.http_runner import HttpRunner +from aidial_assistant.tools_chain.tool_runner import ToolRunner +from aidial_assistant.tools_chain.tools_chain import ToolsChain + + +def build_property(p: APIPropertyBase) -> dict[str, Any]: + parameter = { + "type": p.type, + "description": p.description, + "default": p.default, + } + return {k: v for k, v in parameter.items() if v is not None} + + +def construct_function(op: APIOperation) -> Tool: + properties = {} + required = [] + for p in op.properties: + properties[p.name] = build_property(p) + + if p.required: + required.append(p.name) + + if op.request_body is not None: + for p in op.request_body.properties: + properties[p.name] = build_property(p) + + if p.required: + required.append(p.name) + + return { + "type": "function", + "function": { + "name": op.operation_id, + "description": op.description or "", + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } + + +class AddonRunner(ToolRunner): + def __init__(self, model: ModelClient, addons: dict[str, PluginInfo]): + self.model = model + self.addons = addons + + async def run( + self, + name: str, + arg: dict[str, Any], + execution_callback: ExecutionCallback, + ) -> ResultObject: + query: str = arg["query"] + + addon = self.addons[name] + ops = collect_operations( + addon.info.open_api, addon.info.ai_plugin.api.url + ) + tools = [construct_function(op) for op in ops.values()] + + chain = ToolsChain(self.model, tools, HttpRunner(ops, addon.auth)) + + messages = [ + Message.system(addon.info.ai_plugin.description_for_model), + Message.user(query), + ] + chain_callback = PluginChainCallback(execution_callback) + try: + await chain.run_chat(messages, chain_callback) + return TextResult(chain_callback.result) + except ReasonLengthException: + return TextResult(chain_callback.result) diff --git a/aidial_assistant/tools_chain/http_runner.py b/aidial_assistant/tools_chain/http_runner.py new file mode 100644 index 0000000..7070848 --- /dev/null +++ b/aidial_assistant/tools_chain/http_runner.py @@ -0,0 +1,23 @@ +from typing import Any + +from langchain_community.tools.openapi.utils.api_models import APIOperation + +from aidial_assistant.commands.base import ExecutionCallback, ResultObject +from aidial_assistant.open_api.requester import OpenAPIEndpointRequester +from aidial_assistant.tools_chain.tool_runner import ToolRunner + + +class HttpRunner(ToolRunner): + def __init__(self, ops: dict[str, APIOperation], auth: str): + self.ops = ops + self.auth = auth + + async def run( + self, + name: str, + arg: dict[str, 
Any], + execution_callback: ExecutionCallback, + ) -> ResultObject: + return await OpenAPIEndpointRequester( + self.ops[name], self.auth + ).execute(arg) diff --git a/aidial_assistant/tools_chain/tool_runner.py b/aidial_assistant/tools_chain/tool_runner.py new file mode 100644 index 0000000..253c260 --- /dev/null +++ b/aidial_assistant/tools_chain/tool_runner.py @@ -0,0 +1,11 @@ +from abc import ABC +from typing import Any + +from aidial_assistant.commands.base import ExecutionCallback + + +class ToolRunner(ABC): + async def run( + self, name: str, arg: Any, execution_callback: ExecutionCallback + ): + pass diff --git a/aidial_assistant/tools_chain/tools_chain.py b/aidial_assistant/tools_chain/tools_chain.py new file mode 100644 index 0000000..f6cea15 --- /dev/null +++ b/aidial_assistant/tools_chain/tools_chain.py @@ -0,0 +1,99 @@ +import json +from typing import Any + +from aidial_sdk.chat_completion import Role + +from aidial_assistant.chain.callbacks.chain_callback import ChainCallback +from aidial_assistant.chain.callbacks.command_callback import CommandCallback +from aidial_assistant.chain.history import History +from aidial_assistant.model.model_client import ( + ModelClient, + Message, + ExtraResultsCallback, + ToolCall, + Tool, +) +from aidial_assistant.tools_chain.tool_runner import ToolRunner + + +def _publish_command( + command_callback: CommandCallback, name: str, arguments: str +): + command_callback.on_command(name) + args_callback = command_callback.args_callback() + args_callback.on_args_start() + arg_callback = args_callback.arg_callback() + arg_callback.on_arg(arguments) + arg_callback.on_arg_end() + args_callback.on_args_end() + + +class ToolCallsCallback(ExtraResultsCallback): + def __init__(self): + self.tool_calls: list[ToolCall] = [] + + def on_tool_calls(self, tool_calls: list[ToolCall]): + self.tool_calls = tool_calls + + +class ToolsChain: + def __init__( + self, + model: ModelClient, + tools: list[Tool], + tool_runner: ToolRunner, + ): + self.model = model + self.tools = tools + self.tool_runner = tool_runner + + async def run_chat(self, messages: list[Message], callback: ChainCallback): + result_callback = callback.result_callback() + while True: + tool_calls_callback = ToolCallsCallback() + async for chunk in self.model.agenerate( + messages, tool_calls_callback, tools=self.tools + ): + result_callback.on_result(chunk) + + if not tool_calls_callback.tool_calls: + break + + messages.append( + Message( + role=Role.ASSISTANT, + tool_calls=tool_calls_callback.tool_calls, + ) + ) + + for tool_call in tool_calls_callback.tool_calls: + function = tool_call["function"] + name = function["name"] + arguments = function["arguments"] + with callback.command_callback() as command_callback: + _publish_command(command_callback, name, arguments) + try: + result = await self.tool_runner.run( + name, + json.loads(arguments), + command_callback.execution_callback(), + ) + messages.append( + Message( + role=Role.USER, + name=name, + tool_call_id=tool_call["id"], + content=result.text, + ) + ) + command_callback.on_result(result) + except Exception as e: + messages.append( + Message( + role=Role.USER, + name=name, + tool_call_id=tool_call["id"], + content=str(e), + ) + ) + command_callback.on_error(e) diff --git a/aidial_assistant/utils/exceptions.py b/aidial_assistant/utils/exceptions.py index bb03218..44f7b50 100644 --- a/aidial_assistant/utils/exceptions.py +++ b/aidial_assistant/utils/exceptions.py @@ -2,7 +2,7 @@ from functools import wraps from aidial_sdk import HTTPException 
-from openai import OpenAIError
+from openai import OpenAIError, APIError
 
 logger = logging.getLogger(__name__)
 
@@ -26,18 +26,14 @@ def _to_http_exception(e: Exception) -> HTTPException:
             param=e.param,
         )
 
-    if isinstance(e, OpenAIError):
-        http_status = e.http_status or 500
-        if e.error:
-            return HTTPException(
-                message=e.error.message,
-                status_code=http_status,
-                type=e.error.type,
-                code=e.error.code,
-                param=e.error.param,
-            )
-
-        return HTTPException(message=str(e), status_code=http_status)
+    if isinstance(e, APIError):
+        return HTTPException(
+            message=e.message,
+            status_code=getattr(e, "status_code", None) or 500,
+            type=e.type,
+            code=e.code,
+            param=e.param,
+        )
 
     return HTTPException(
         message=str(e), status_code=500, type="internal_server_error"
diff --git a/poetry.lock b/poetry.lock
index 3819072..b4b3fbb 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -2,13 +2,13 @@
 
 [[package]]
 name = "aidial-sdk"
-version = "0.2.0"
+version = "0.3.0"
 description = "Framework to create applications and model adapters for AI DIAL"
 optional = false
 python-versions = ">=3.8.1,<4.0"
 files = [
-    {file = "aidial_sdk-0.2.0-py3-none-any.whl", hash = "sha256:ce3c2e2ea5ef133d2594bb64c7a70f54970e1f8339608ecfb47b0f955e1536e7"},
-    {file = "aidial_sdk-0.2.0.tar.gz", hash = "sha256:fcb00ccfa6fbed7482d6d78828a95ba7e29f45269708cfc3691db9711b91f3fe"},
+    {file = "aidial_sdk-0.3.0-py3-none-any.whl", hash = "sha256:67f4efd1e44a1442741d089295e560257a42ec28658ad18040ada93ce1cb08bb"},
+    {file = "aidial_sdk-0.3.0.tar.gz", hash = "sha256:f5da5e57e839765d0440c5ddf3586748c20f6ee6233b85a70454bce16a6359cd"},
 ]
 
 [package.dependencies]
@@ -435,6 +435,17 @@ files = [
     {file = "distlib-0.3.7.tar.gz", hash = "sha256:9dafe54b34a028eafd95039d5e5d4851a13734540f1331060d31c9916e7147a8"},
 ]
 
+[[package]]
+name = "distro"
+version = "1.8.0"
+description = "Distro - an OS platform information API"
+optional = false
+python-versions = ">=3.6"
+files = [
+    {file = "distro-1.8.0-py3-none-any.whl", hash = "sha256:99522ca3e365cac527b44bde033f64c6945d90eb9f769703caaec52b09bbd3ff"},
+    {file = "distro-1.8.0.tar.gz", hash = "sha256:02e111d1dc6a50abb8eed6bf31c3e48ed8b0830d1ea2a1b78c61765c2513fdd8"},
+]
+
 [[package]]
 name = "fastapi"
 version = "0.103.2"
@@ -644,6 +655,51 @@ files = [
     {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"},
 ]
 
+[[package]]
+name = "httpcore"
+version = "1.0.2"
+description = "A minimal low-level HTTP client."
+optional = false
+python-versions = ">=3.8"
+files = [
+    {file = "httpcore-1.0.2-py3-none-any.whl", hash = "sha256:096cc05bca73b8e459a1fc3dcf585148f63e534eae4339559c9b8a8d6399acc7"},
+    {file = "httpcore-1.0.2.tar.gz", hash = "sha256:9fc092e4799b26174648e54b74ed5f683132a464e95643b226e00c2ed2fa6535"},
+]
+
+[package.dependencies]
+certifi = "*"
+h11 = ">=0.13,<0.15"
+
+[package.extras]
+asyncio = ["anyio (>=4.0,<5.0)"]
+http2 = ["h2 (>=3,<5)"]
+socks = ["socksio (==1.*)"]
+trio = ["trio (>=0.22.0,<0.23.0)"]
+
+[[package]]
+name = "httpx"
+version = "0.25.2"
+description = "The next generation HTTP client."
+optional = false +python-versions = ">=3.8" +files = [ + {file = "httpx-0.25.2-py3-none-any.whl", hash = "sha256:a05d3d052d9b2dfce0e3896636467f8a5342fb2b902c819428e1ac65413ca118"}, + {file = "httpx-0.25.2.tar.gz", hash = "sha256:8b8fcaa0c8ea7b05edd69a094e63a2094c4efcb48129fb757361bc423c0ad9e8"}, +] + +[package.dependencies] +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" + +[package.extras] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + [[package]] name = "idna" version = "3.4" @@ -727,21 +783,22 @@ files = [ [[package]] name = "langchain" -version = "0.0.329" +version = "0.0.350" description = "Building applications with LLMs through composability" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langchain-0.0.329-py3-none-any.whl", hash = "sha256:5f3e884991271e8b55eda4c63a11105dcd7da119682ce0e3d5d1385b3a4103d2"}, - {file = "langchain-0.0.329.tar.gz", hash = "sha256:488f3cb68a587696f136d4f01f97df8d8270e295b3cc56158057dab0f61f4166"}, + {file = "langchain-0.0.350-py3-none-any.whl", hash = "sha256:11b605f325a4271a7815baaec05bc7622e3ad1f10f26b05c752cafa27663ed38"}, + {file = "langchain-0.0.350.tar.gz", hash = "sha256:f0e68a92d200bb722586688ab7b411b2430bd98ad265ca03b264e7e7acbb6c01"}, ] [package.dependencies] aiohttp = ">=3.8.3,<4.0.0" -anyio = "<4.0" dataclasses-json = ">=0.5.7,<0.7" jsonpatch = ">=1.33,<2.0" -langsmith = ">=0.0.52,<0.1.0" +langchain-community = ">=0.0.2,<0.1" +langchain-core = ">=0.1,<0.2" +langsmith = ">=0.0.63,<0.1.0" numpy = ">=1,<2" pydantic = ">=1,<3" PyYAML = ">=5.3" @@ -750,29 +807,78 @@ SQLAlchemy = ">=1.4,<3" tenacity = ">=8.1.0,<9.0.0" [package.extras] -all = ["O365 (>=2.0.26,<3.0.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "amadeus (>=8.1.0)", "arxiv (>=1.4,<2.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "awadb (>=0.3.9,<0.4.0)", "azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "beautifulsoup4 (>=4,<5)", "clarifai (>=9.1.0)", "clickhouse-connect (>=0.5.14,<0.6.0)", "cohere (>=4,<5)", "deeplake (>=3.8.3,<4.0.0)", "docarray[hnswlib] (>=0.32.0,<0.33.0)", "duckduckgo-search (>=3.8.3,<4.0.0)", "elasticsearch (>=8,<9)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "google-api-python-client (==2.70.0)", "google-auth (>=2.18.1,<3.0.0)", "google-search-results (>=2,<3)", "gptcache (>=0.1.7)", "html2text (>=2020.1.16,<2021.0.0)", "huggingface_hub (>=0,<1)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "lancedb (>=0.1,<0.2)", "langkit (>=0.0.6,<0.1.0)", "lark (>=1.1.5,<2.0.0)", "librosa (>=0.10.0.post2,<0.11.0)", "lxml (>=4.9.2,<5.0.0)", "manifest-ml (>=0.0.1,<0.0.2)", "marqo (>=1.2.4,<2.0.0)", "momento (>=1.10.1,<2.0.0)", "nebula3-python (>=3.4.0,<4.0.0)", "neo4j (>=5.8.1,<6.0.0)", "networkx (>=2.6.3,<4)", "nlpcloud (>=1,<2)", "nltk (>=3,<4)", "nomic (>=1.0.43,<2.0.0)", "openai (>=0,<1)", "openlm (>=0.0.5,<0.0.6)", "opensearch-py (>=2.0.0,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pexpect (>=4.8.0,<5.0.0)", "pgvector (>=0.1.6,<0.2.0)", "pinecone-client (>=2,<3)", "pinecone-text (>=0.4.2,<0.5.0)", "psycopg2-binary (>=2.9.5,<3.0.0)", "pymongo (>=4.3.3,<5.0.0)", "pyowm (>=3.3.0,<4.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pytesseract (>=0.3.10,<0.4.0)", "python-arango (>=7.5.9,<8.0.0)", "pyvespa (>=0.33.0,<0.34.0)", "qdrant-client (>=1.3.1,<2.0.0)", "rdflib 
(>=6.3.2,<7.0.0)", "redis (>=4,<5)", "requests-toolbelt (>=1.0.0,<2.0.0)", "sentence-transformers (>=2,<3)", "singlestoredb (>=0.7.1,<0.8.0)", "tensorflow-text (>=2.11.0,<3.0.0)", "tigrisdb (>=1.0.0b6,<2.0.0)", "tiktoken (>=0.3.2,<0.6.0)", "torch (>=1,<3)", "transformers (>=4,<5)", "weaviate-client (>=3,<4)", "wikipedia (>=1,<2)", "wolframalpha (==5.0.0)"] -azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (>=0,<1)"] +azure = ["azure-ai-formrecognizer (>=3.2.1,<4.0.0)", "azure-ai-textanalytics (>=5.3.0,<6.0.0)", "azure-ai-vision (>=0.11.1b1,<0.12.0)", "azure-cognitiveservices-speech (>=1.28.0,<2.0.0)", "azure-core (>=1.26.4,<2.0.0)", "azure-cosmos (>=4.4.0b1,<5.0.0)", "azure-identity (>=1.12.0,<2.0.0)", "azure-search-documents (==11.4.0b8)", "openai (<2)"] clarifai = ["clarifai (>=9.1.0)"] cli = ["typer (>=0.9.0,<0.10.0)"] cohere = ["cohere (>=4,<5)"] docarray = ["docarray[hnswlib] (>=0.32.0,<0.33.0)"] embeddings = ["sentence-transformers (>=2,<3)"] -extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "dashvector (>=1.0.1,<2.0.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.6.0,<0.7.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "html2text (>=2020.1.16,<2021.0.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (>=0,<1)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "couchbase (>=4.1.9,<5.0.0)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai 
(>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] javascript = ["esprima (>=4.0.1,<5.0.0)"] -llms = ["clarifai (>=9.1.0)", "cohere (>=4,<5)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (>=0,<1)", "openlm (>=0.0.5,<0.0.6)", "torch (>=1,<3)", "transformers (>=4,<5)"] -openai = ["openai (>=0,<1)", "tiktoken (>=0.3.2,<0.6.0)"] +llms = ["clarifai (>=9.1.0)", "cohere (>=4,<5)", "huggingface_hub (>=0,<1)", "manifest-ml (>=0.0.1,<0.0.2)", "nlpcloud (>=1,<2)", "openai (<2)", "openlm (>=0.0.5,<0.0.6)", "torch (>=1,<3)", "transformers (>=4,<5)"] +openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"] qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"] text-helpers = ["chardet (>=5.1.0,<6.0.0)"] +[[package]] +name = "langchain-community" +version = "0.0.2" +description = "Community contributed LangChain integrations." 
+optional = false +python-versions = ">=3.8.1,<4.0" +files = [ + {file = "langchain_community-0.0.2-py3-none-any.whl", hash = "sha256:2af62f7917db26a83b4ca6a46f6fc9d65c6358256cbb516375d89817b9437493"}, + {file = "langchain_community-0.0.2.tar.gz", hash = "sha256:178fb0005c9438a945cabec7cadc89e6ca4eb99abd4cb310bc82618ac8215268"}, +] + +[package.dependencies] +aiohttp = ">=3.8.3,<4.0.0" +dataclasses-json = ">=0.5.7,<0.7" +langchain-core = ">=0.1,<0.2" +langsmith = ">=0.0.63,<0.1.0" +numpy = ">=1,<2" +PyYAML = ">=5.3" +requests = ">=2,<3" +SQLAlchemy = ">=1.4,<3" +tenacity = ">=8.1.0,<9.0.0" + +[package.extras] +cli = ["typer (>=0.9.0,<0.10.0)"] +extended-testing = ["aiosqlite (>=0.19.0,<0.20.0)", "aleph-alpha-client (>=2.15.0,<3.0.0)", "anthropic (>=0.3.11,<0.4.0)", "arxiv (>=1.4,<2.0)", "assemblyai (>=0.17.0,<0.18.0)", "atlassian-python-api (>=3.36.0,<4.0.0)", "beautifulsoup4 (>=4,<5)", "bibtexparser (>=1.4.0,<2.0.0)", "cassio (>=0.1.0,<0.2.0)", "chardet (>=5.1.0,<6.0.0)", "cohere (>=4,<5)", "dashvector (>=1.0.1,<2.0.0)", "databricks-vectorsearch (>=0.21,<0.22)", "datasets (>=2.15.0,<3.0.0)", "dgml-utils (>=0.3.0,<0.4.0)", "esprima (>=4.0.1,<5.0.0)", "faiss-cpu (>=1,<2)", "feedparser (>=6.0.10,<7.0.0)", "fireworks-ai (>=0.9.0,<0.10.0)", "geopandas (>=0.13.1,<0.14.0)", "gitpython (>=3.1.32,<4.0.0)", "google-cloud-documentai (>=2.20.1,<3.0.0)", "gql (>=3.4.1,<4.0.0)", "hologres-vector (>=0.0.6,<0.0.7)", "html2text (>=2020.1.16,<2021.0.0)", "javelin-sdk (>=0.1.8,<0.2.0)", "jinja2 (>=3,<4)", "jq (>=1.4.1,<2.0.0)", "jsonschema (>1)", "lxml (>=4.9.2,<5.0.0)", "markdownify (>=0.11.6,<0.12.0)", "motor (>=3.3.1,<4.0.0)", "msal (>=1.25.0,<2.0.0)", "mwparserfromhell (>=0.6.4,<0.7.0)", "mwxml (>=0.3.3,<0.4.0)", "newspaper3k (>=0.2.8,<0.3.0)", "numexpr (>=2.8.6,<3.0.0)", "openai (<2)", "openapi-pydantic (>=0.3.2,<0.4.0)", "pandas (>=2.0.1,<3.0.0)", "pdfminer-six (>=20221105,<20221106)", "pgvector (>=0.1.6,<0.2.0)", "praw (>=7.7.1,<8.0.0)", "psychicapi (>=0.8.0,<0.9.0)", "py-trello (>=0.19.0,<0.20.0)", "pymupdf (>=1.22.3,<2.0.0)", "pypdf (>=3.4.0,<4.0.0)", "pypdfium2 (>=4.10.0,<5.0.0)", "pyspark (>=3.4.0,<4.0.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "rapidfuzz (>=3.1.1,<4.0.0)", "rapidocr-onnxruntime (>=1.3.2,<2.0.0)", "requests-toolbelt (>=1.0.0,<2.0.0)", "rspace_client (>=2.5.0,<3.0.0)", "scikit-learn (>=1.2.2,<2.0.0)", "sqlite-vss (>=0.1.2,<0.2.0)", "streamlit (>=1.18.0,<2.0.0)", "sympy (>=1.12,<2.0)", "telethon (>=1.28.5,<2.0.0)", "timescale-vector (>=0.0.1,<0.0.2)", "tqdm (>=4.48.0)", "upstash-redis (>=0.15.0,<0.16.0)", "xata (>=1.0.0a7,<2.0.0)", "xmltodict (>=0.13.0,<0.14.0)"] + +[[package]] +name = "langchain-core" +version = "0.1.0" +description = "Building applications with LLMs through composability" +optional = false +python-versions = ">=3.8.1,<4.0" +files = [ + {file = "langchain_core-0.1.0-py3-none-any.whl", hash = "sha256:6b155a175e1f1555860b22333c14161c652b0013e229e7b8a083639c821312a8"}, + {file = "langchain_core-0.1.0.tar.gz", hash = "sha256:4c70aa62905896b65c47a966f87584f72026cbe402655749281df81c794e0d6e"}, +] + +[package.dependencies] +anyio = ">=3,<5" +jsonpatch = ">=1.33,<2.0" +langsmith = ">=0.0.63,<0.1.0" +packaging = ">=23.2,<24.0" +pydantic = ">=1,<3" +PyYAML = ">=5.3" +requests = ">=2,<3" +tenacity = ">=8.1.0,<9.0.0" + +[package.extras] +extended-testing = ["jinja2 (>=3,<4)"] + [[package]] name = "langsmith" -version = "0.0.54" +version = "0.0.69" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." 
optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "langsmith-0.0.54-py3-none-any.whl", hash = "sha256:55eca5967cadb661a49ad32aecda48a824fadef202ca384575209a9d6f823b74"}, - {file = "langsmith-0.0.54.tar.gz", hash = "sha256:76c8e34b4d10ad93541107138089635829f9d60601a7f6bddf5ba582d178e521"}, + {file = "langsmith-0.0.69-py3-none-any.whl", hash = "sha256:49a2546bb83eedb0552673cf81a068bb08078d6d48471f4f1018e1d5c6aa46b1"}, + {file = "langsmith-0.0.69.tar.gz", hash = "sha256:8fb5297f274db0576ec650d9bab0319acfbb6622d62bc5bb9fe31c6235dc0358"}, ] [package.dependencies] @@ -1043,25 +1149,26 @@ files = [ [[package]] name = "openai" -version = "0.28.1" -description = "Python client library for the OpenAI API" +version = "1.3.9" +description = "The official Python library for the openai API" optional = false python-versions = ">=3.7.1" files = [ - {file = "openai-0.28.1-py3-none-any.whl", hash = "sha256:d18690f9e3d31eedb66b57b88c2165d760b24ea0a01f150dd3f068155088ce68"}, - {file = "openai-0.28.1.tar.gz", hash = "sha256:4be1dad329a65b4ce1a660fe6d5431b438f429b5855c883435f0f7fcb6d2dcc8"}, + {file = "openai-1.3.9-py3-none-any.whl", hash = "sha256:d30faeffe5995a2cf6b790c0260a5b59647e81c3a1f3b62f51b5e0a0e52681c9"}, + {file = "openai-1.3.9.tar.gz", hash = "sha256:6f638d96bc89b4394be1d7b37d312f70a055df1a471c92d4c4b2ae3a70c98cb3"}, ] [package.dependencies] -aiohttp = "*" -requests = ">=2.20" -tqdm = "*" +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.5,<5" [package.extras] -datalib = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] -dev = ["black (>=21.6b0,<22.0)", "pytest (==6.*)", "pytest-asyncio", "pytest-mock"] -embeddings = ["matplotlib", "numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "plotly", "scikit-learn (>=1.0.2)", "scipy", "tenacity (>=8.0.1)"] -wandb = ["numpy", "openpyxl (>=3.0.7)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)", "wandb"] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] [[package]] name = "openapi-pydantic" @@ -1778,4 +1885,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "a13c588da4a63c1ae51b331f9dacdb495243077706a5b68b9ccb3520ad0176c4" +content-hash = "73ae85c178411c402c34edf6adde3638a7719ae4f25a6e1aee63b2efa43a673a" diff --git a/pyproject.toml b/pyproject.toml index 6649e52..d28759e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,12 +19,12 @@ clean = "scripts.clean:main" python = "^3.11" aiocache = "^0.12.2" jinja2 = "^3.1.2" -langchain = "^0.0.329" -openai = "^0.28.0" +langchain = "^0.0.350" +openai = "^1.3.9" pydantic = "1.10.13" pyyaml = "^6.0.1" typing-extensions = "^4.8.0" -aidial-sdk = "^0.2.0" +aidial-sdk = "^0.3.0" aiohttp = "^3.9.0" openapi-schema-pydantic = "^1.2.4" openapi-pydantic = "^0.3.2" diff --git a/tests/unit_tests/chain/test_command_chain_best_effort.py b/tests/unit_tests/chain/test_command_chain_best_effort.py index 74244d9..8711861 100644 --- a/tests/unit_tests/chain/test_command_chain_best_effort.py +++ b/tests/unit_tests/chain/test_command_chain_best_effort.py @@ -4,7 +4,7 @@ import pytest from aidial_sdk.chat_completion import Role from jinja2 import Template -from openai import InvalidRequestError +from openai import BadRequestError from aidial_assistant.chain.callbacks.chain_callback import ChainCallback from aidial_assistant.chain.callbacks.result_callback import ResultCallback @@ -154,7 
+154,7 @@ async def test_no_tokens_for_tools(): model_client = Mock(spec=ModelClient) model_client.agenerate.side_effect = [ to_async_string(TEST_COMMAND_REQUEST), - InvalidRequestError(NO_TOKENS_ERROR, ""), + BadRequestError(NO_TOKENS_ERROR), to_async_string(BEST_EFFORT_ANSWER), ] test_command = Mock(spec=Command) diff --git a/tests/unit_tests/model/test_model_client.py b/tests/unit_tests/model/test_model_client.py index 3457901..200ab71 100644 --- a/tests/unit_tests/model/test_model_client.py +++ b/tests/unit_tests/model/test_model_client.py @@ -1,7 +1,8 @@ -from unittest import mock from unittest.mock import Mock, call import pytest +from openai import AsyncOpenAI +from openai.types.chat import ChatCompletionChunk from aidial_assistant.model.model_client import ( ExtraResultsCallback, @@ -12,23 +13,21 @@ from aidial_assistant.utils.text import join_string from tests.utils.async_helper import to_async_iterator -API_METHOD = "openai.ChatCompletion.acreate" MODEL_ARGS = {"model": "args"} -BUFFER_SIZE = 321 -@mock.patch(API_METHOD) @pytest.mark.asyncio -async def test_discarded_messages(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( +async def test_discarded_messages(): + openai_client = Mock() + openai_client.chat.completions.create.return_value = to_async_iterator( [ - { - "choices": [{"delta": {"content": ""}}], - "statistics": {"discarded_messages": 2}, - } + ChatCompletionChunk( + choices=[{"delta": {"content": ""}}], + statistics={"discarded_messages": 2}, + ) ] ) + model_client = ModelClient(openai_client, MODEL_ARGS) extra_results_callback = Mock(spec=ExtraResultsCallback) await join_string(model_client.agenerate([], extra_results_callback)) @@ -38,26 +37,25 @@ async def test_discarded_messages(api): ] -@mock.patch(API_METHOD) @pytest.mark.asyncio -async def test_content(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( +async def test_content(): + openai_client = Mock(spec=AsyncOpenAI) + openai_client.chat.completions.create.return_value = to_async_iterator( [ {"choices": [{"delta": {"content": "one, "}}]}, {"choices": [{"delta": {"content": "two, "}}]}, {"choices": [{"delta": {"content": "three"}}]}, ] ) + model_client = ModelClient(openai_client, MODEL_ARGS) assert await join_string(model_client.agenerate([])) == "one, two, three" -@mock.patch(API_METHOD) @pytest.mark.asyncio -async def test_reason_length_with_usage(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( +async def test_reason_length_with_usage(): + openai_client = Mock(spec=AsyncOpenAI) + openai_client.chat.completions.create.return_value = to_async_iterator( [ {"choices": [{"delta": {"content": "text"}}]}, { @@ -71,6 +69,7 @@ async def test_reason_length_with_usage(api): }, ] ) + model_client = ModelClient(openai_client, MODEL_ARGS) with pytest.raises(ReasonLengthException): async for chunk in model_client.agenerate([]): @@ -80,11 +79,11 @@ async def test_reason_length_with_usage(api): assert model_client.total_completion_tokens == 2 -@mock.patch(API_METHOD) @pytest.mark.asyncio -async def test_api_args(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator([]) +async def test_api_args(): + openai_client = Mock(spec=AsyncOpenAI) + openai_client.chat.completions.create.return_value = to_async_iterator([]) + model_client = ModelClient(openai_client, MODEL_ARGS) messages = [ Message.system(content="a"), 
Message.user(content="b"), @@ -93,7 +92,7 @@ async def test_api_args(api): await join_string(model_client.agenerate(messages)) - assert api.call_args_list == [ + assert openai_client.chat.completions.create.call_args_list == [ call( messages=[ {"role": "system", "content": "a"}, diff --git a/tests/unit_tests/tools_chain/__init__.py b/tests/unit_tests/tools_chain/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/utils/test_exception_handler.py b/tests/unit_tests/utils/test_exception_handler.py index caf73b2..be7ef83 100644 --- a/tests/unit_tests/utils/test_exception_handler.py +++ b/tests/unit_tests/utils/test_exception_handler.py @@ -1,6 +1,7 @@ +import httpx import pytest from aidial_sdk import HTTPException -from openai import OpenAIError +from openai import OpenAIError, APIStatusError from aidial_assistant.utils.exceptions import ( RequestParameterValidationError, @@ -29,19 +30,17 @@ async def function(): @pytest.mark.asyncio async def test_openai_error(): - http_status = 123 - @unhandled_exception_handler async def function(): - raise OpenAIError(message=ERROR_MESSAGE, http_status=http_status) + raise OpenAIError(ERROR_MESSAGE) with pytest.raises(HTTPException) as exc_info: await function() assert ( repr(exc_info.value) - == f"HTTPException(message='{ERROR_MESSAGE}', status_code={http_status}," - f" type='runtime_error', param=None, code=None)" + == f"HTTPException(message='{ERROR_MESSAGE}', status_code=500," + f" type='internal_server_error', param=None, code=None)" ) @@ -51,17 +50,21 @@ async def test_openai_error_with_json_body(): error_type = "" error_code = "" json_body = { - "error": { - "message": ERROR_MESSAGE, - "type": error_type, - "code": error_code, - "param": PARAM, - } + "type": error_type, + "code": error_code, + "param": PARAM, } @unhandled_exception_handler async def function(): - raise OpenAIError(json_body=json_body, http_status=http_status) + raise APIStatusError( + ERROR_MESSAGE, + response=httpx.Response( + request=httpx.Request("GET", "http://localhost"), + status_code=http_status, + ), + body=json_body, + ) with pytest.raises(HTTPException) as exc_info: await function() From d55b24d00bd1fb26d74f63dc8b87b82d60d6f75b Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Thu, 4 Jan 2024 10:50:18 +0000 Subject: [PATCH 02/23] Some intermediate fixes. 
--- .../application/assistant_application.py | 106 +++++++++++++----- .../application/assistant_callback.py | 34 ++---- aidial_assistant/application/prompts.py | 14 +-- .../chain/callbacks/arg_callback.py | 21 ---- .../chain/callbacks/args_callback.py | 21 +--- aidial_assistant/chain/command_result.py | 16 ++- aidial_assistant/model/model_client.py | 43 +++---- aidial_assistant/tools_chain/addon_runner.py | 13 ++- aidial_assistant/tools_chain/history.py | 0 aidial_assistant/tools_chain/tools_chain.py | 47 +++++--- aidial_assistant/utils/exceptions.py | 2 +- poetry.lock | 8 +- pyproject.toml | 2 +- .../utils/test_exception_handler.py | 2 +- 14 files changed, 170 insertions(+), 159 deletions(-) delete mode 100644 aidial_assistant/chain/callbacks/arg_callback.py create mode 100644 aidial_assistant/tools_chain/history.py diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index 23c5691..2d4e937 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -1,20 +1,13 @@ +import json import logging -import os from pathlib import Path -from typing_extensions import override from aidial_sdk.chat_completion import FinishReason from aidial_sdk.chat_completion.base import ChatCompletion -from aidial_sdk.chat_completion.request import ( - Addon, - Message as SdkMessage, - Request, - Role, -) +from aidial_sdk.chat_completion.request import Addon +from aidial_sdk.chat_completion.request import Message as SdkMessage +from aidial_sdk.chat_completion.request import Request, Role from aidial_sdk.chat_completion.response import Response -from aiohttp import hdrs -from openai import AsyncOpenAI -from openai._types import Omit from openai.lib.azure import AsyncAzureOpenAI from aidial_assistant.application.addons_dialogue_limiter import ( @@ -29,14 +22,21 @@ MAIN_SYSTEM_DIALOG_MESSAGE, ) from aidial_assistant.chain.command_chain import CommandChain, CommandDict -from aidial_assistant.chain.history import History +from aidial_assistant.chain.command_result import ( + CommandInvocation, + Commands, + Responses, +) +from aidial_assistant.chain.history import History, ScopedMessage, MessageScope from aidial_assistant.commands.reply import Reply from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin from aidial_assistant.model.model_client import ( + Message, ModelClient, ReasonLengthException, Tool, - Message, + ToolCall, + FunctionCall, ) from aidial_assistant.tools_chain.addon_runner import AddonRunner from aidial_assistant.tools_chain.tools_chain import ToolsChain @@ -111,14 +111,64 @@ def _construct_function(name: str, description: str) -> Tool: } -class MyClient(AsyncAzureOpenAI): - @property - @override - def default_headers(self) -> dict[str, str | Omit]: - headers = super().default_headers - del headers["Authorization"] - - return headers +def _convert_commands_to_tools( + scoped_messages: list[ScopedMessage], +) -> list[Message]: + messages: list[Message] = [] + next_tool_id: int = 0 + last_call_count: int = 0 + for scoped_message in scoped_messages: + message = scoped_message.message + if scoped_message.scope == MessageScope.INTERNAL: + if message.role == Role.ASSISTANT: + commands: Commands = json.loads(message.content) + messages.append( + Message( + role=Role.ASSISTANT, + tool_calls=[ + ToolCall( + index=index, + id=str(next_tool_id + index), + function=FunctionCall( + name=command["args"][0], + arguments=json.dumps( + { + "query": 
command["args"][1], + } + ), + ), + type="function", + ) + for index, command in enumerate( + commands["commands"] + ) + ], + ) + ) + last_call_count = len(commands["commands"]) + next_tool_id += last_call_count + elif message.role == Role.USER: + responses: Responses = json.loads(message.content) + response_count = len(responses["responses"]) + if response_count != last_call_count: + raise RequestParameterValidationError( + f"Expected {last_call_count} responses, but got {response_count}.", + param="messages", + ) + first_tool_id = next_tool_id - last_call_count + messages.extend( + [ + Message( + role=Role.TOOL, + tool_call_id=str(first_tool_id + index), + content=response["response"], + ) + for index, response in enumerate(responses["responses"]) + ] + ) + else: + messages.append(scoped_message.message) + return messages class AssistantApplication(ChatCompletion): @@ -133,7 +183,7 @@ async def chat_completion( chat_args = _get_request_args(request) model = ModelClient( - client=MyClient( + client=AsyncAzureOpenAI( azure_endpoint=self.args.openai_conf.api_base, api_key=request.api_key, api_version="2023-12-01-preview", @@ -159,7 +209,7 @@ async def chat_completion( ), ) - if request.model in {"gpt-4-turbo-1106", "gpt-4-1106-preview"}: + if request.model in {"gpt-4-turbo-1106", "anthropic.claude-v2-1"}: await AssistantApplication._run_native_tools_chat( model, tools, request, response ) @@ -257,18 +307,14 @@ async def _run_native_tools_chat( callback = AssistantChainCallback(choice) finish_reason = FinishReason.STOP - messages = [ - Message( - role=message.role, - content=message.content or "", - ) - for message in request.messages - ] + messages = _convert_commands_to_tools(parse_history(request.messages)) try: await chain.run_chat(messages, callback) except ReasonLengthException: finish_reason = FinishReason.LENGTH + if callback.invocations: + choice.set_state(State(invocations=callback.invocations)) choice.close(finish_reason) response.set_usage( diff --git a/aidial_assistant/application/assistant_callback.py b/aidial_assistant/application/assistant_callback.py index 4fbcb8d..35c2ce8 100644 --- a/aidial_assistant/application/assistant_callback.py +++ b/aidial_assistant/application/assistant_callback.py @@ -1,12 +1,12 @@ +import json from types import TracebackType -from typing import Callable +from typing import Callable, Any from aidial_sdk.chat_completion import Status from aidial_sdk.chat_completion.choice import Choice from aidial_sdk.chat_completion.stage import Stage from typing_extensions import override -from aidial_assistant.chain.callbacks.arg_callback import ArgCallback from aidial_assistant.chain.callbacks.args_callback import ArgsCallback from aidial_assistant.chain.callbacks.chain_callback import ChainCallback from aidial_assistant.chain.callbacks.command_callback import CommandCallback @@ -16,36 +16,16 @@ from aidial_assistant.utils.state import Invocation -class PluginNameArgCallback(ArgCallback): - def __init__(self, callback: Callable[[str], None]): - super().__init__(0, callback) - - @override - def on_arg(self, chunk: str): - chunk = chunk.replace('"', "") - if len(chunk) > 0: - self.callback(chunk) - - @override - def on_arg_end(self): - self.callback("(") - - class RunPluginArgsCallback(ArgsCallback): def __init__(self, callback: Callable[[str], None]): super().__init__(callback) @override - def on_args_start(self): - pass - - @override - def arg_callback(self) -> ArgCallback: - self.arg_index += 1 - if self.arg_index == 0: - return 
PluginNameArgCallback(self.callback)
-        else:
-            return ArgCallback(self.arg_index - 1, self.callback)
+    def on_args(self, args: dict[str, Any]):
+        args = args.copy()
+        name = args["name"]
+        del args["name"]
+        self.callback(name + "(" + json.dumps(args) + ")")
 
 
 class AssistantCommandCallback(CommandCallback):
diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py
index 194a13f..85a406c 100644
--- a/aidial_assistant/application/prompts.py
+++ b/aidial_assistant/application/prompts.py
@@ -34,9 +34,9 @@ def build(self, **kwargs) -> Template:
     "commands": [
         {
             "command": "<command name>",
-            "args": [
-                "<arg1>", "<arg2>", ...
-            ]
+            "arguments": {
+                "<arg name>": "<arg value>"
+            }
         }
     ]
 }
@@ -47,12 +47,12 @@
 * reply
 The command delivers final response to the user.
 Arguments:
- - MESSAGE is a string containing the final and complete result for the user.
+ - <message> is a string containing the final and complete result for the user.
 
 Your goal is to answer user questions. Use relevant commands when they help to achieve the goal.
 
 ## Example
-{"commands": [{"command": "reply", "args": ["Hello, world!"]}]}
+{"commands": [{"command": "reply", "arguments": {"message": "Hello, world!"}}]}
 """.strip()
 
 _SYSTEM_TEXT = """
@@ -72,11 +72,11 @@
 This command executes a specified addon to address a one-time task described in natural language.
 Addons do not see current conversation and require all details to be provided in the query to solve the task.
 Arguments:
- - NAME is one of the following addons:
+ - <name> is one of the following addons:
 {%- for name, description in tools.items() %}
 * {{name}} - {{description | decap}}
 {%- endfor %}
- - QUERY is the query string.
+ - <query> is the query string.
 {%- endif %}
 {{protocol_footer}}
 """.strip()
diff --git a/aidial_assistant/chain/callbacks/arg_callback.py b/aidial_assistant/chain/callbacks/arg_callback.py
deleted file mode 100644
index 5102637..0000000
--- a/aidial_assistant/chain/callbacks/arg_callback.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from typing import Callable
-
-
-class ArgCallback:
-    """Callback for reporting arguments"""
-
-    def __init__(self, arg_index: int, callback: Callable[[str], None]):
-        self.arg_index = arg_index
-        self.callback = callback
-
-    def on_arg_start(self):
-        """Called when the arg starts"""
-        if self.arg_index > 0:
-            self.callback(", ")
-
-    def on_arg(self, chunk: str):
-        """Called when an argument chunk is read"""
-        self.callback(chunk)
-
-    def on_arg_end(self):
-        """Called when the arg ends"""
diff --git a/aidial_assistant/chain/callbacks/args_callback.py b/aidial_assistant/chain/callbacks/args_callback.py
index ca5b59f..fe4423b 100644
--- a/aidial_assistant/chain/callbacks/args_callback.py
+++ b/aidial_assistant/chain/callbacks/args_callback.py
@@ -1,6 +1,5 @@
-from typing import Callable
-
-from aidial_assistant.chain.callbacks.arg_callback import ArgCallback
+import json
+from typing import Callable, Any
 
 
 class ArgsCallback:
@@ -8,17 +7,7 @@ class ArgsCallback:
 
     def __init__(self, callback: Callable[[str], None]):
         self.callback = callback
-        self.arg_index = -1
-
-    def on_args_start(self):
-        """Called when the arguments start"""
-        self.callback("(")
-
-    def arg_callback(self) -> ArgCallback:
-        """Returns a callback for reporting an argument"""
-        self.arg_index += 1
-        return ArgCallback(self.arg_index, self.callback)
 
-    def on_args_end(self):
-        """Called when the arguments end"""
-        self.callback(")")
+    def on_args(self, args: dict[str, Any]):
+        """Called when the argument dict is constructed"""
+ self.callback("(" + json.dumps(args) + ")") diff --git a/aidial_assistant/chain/command_result.py b/aidial_assistant/chain/command_result.py index 133685d..bf8247b 100644 --- a/aidial_assistant/chain/command_result.py +++ b/aidial_assistant/chain/command_result.py @@ -1,6 +1,6 @@ import json from enum import Enum -from typing import List, TypedDict +from typing import List, TypedDict, Any class Status(str, Enum): @@ -18,12 +18,20 @@ class CommandResult(TypedDict): class CommandInvocation(TypedDict): command: str - args: list[str] + args: dict[str, Any] + + +class Commands(TypedDict): + commands: list[CommandInvocation] + + +class Responses(TypedDict): + responses: list[CommandResult] def responses_to_text(responses: List[CommandResult]) -> str: - return json.dumps({"responses": responses}) + return json.dumps(Responses(responses=responses)) def commands_to_text(commands: List[CommandInvocation]) -> str: - return json.dumps({"commands": commands}) + return json.dumps(Commands(commands=commands)) diff --git a/aidial_assistant/model/model_client.py b/aidial_assistant/model/model_client.py index c14a558..f657c89 100644 --- a/aidial_assistant/model/model_client.py +++ b/aidial_assistant/model/model_client.py @@ -1,11 +1,9 @@ from abc import ABC -from collections import defaultdict from typing import Any, AsyncIterator, List, TypedDict from aidial_sdk.chat_completion import Role from aidial_sdk.utils.merge_chunks import merge from openai import AsyncOpenAI -from openai.lib.azure import AsyncAzureOpenAI from pydantic import BaseModel @@ -15,34 +13,23 @@ class ReasonLengthException(Exception): class Message(BaseModel): role: Role - content: str = "" - name: str = "" - tool_call_id: str = "" - tool_calls: list[dict[str, Any]] = [] + content: str | None = None + tool_call_id: str | None = None + tool_calls: list[dict[str, Any]] | None = None def to_openai_message(self) -> dict[str, str]: - return ( - { - "role": self.role.value, - "content": self.content, - } - | ( - { - "role": "tool", - "name": self.name, - "tool_call_id": self.tool_call_id, - } - if self.name - else {} - ) - | ( - { - "tool_calls": self.tool_calls, - } - if self.tool_calls - else {} - ) - ) + result = {"role": self.role.value} + + if self.content is not None: + result["content"] = self.content + + if self.tool_call_id: + result["tool_call_id"] = self.tool_call_id + + if self.tool_calls: + result["tool_calls"] = self.tool_calls + + return result @classmethod def system(cls, content): diff --git a/aidial_assistant/tools_chain/addon_runner.py b/aidial_assistant/tools_chain/addon_runner.py index c8debc2..51178a5 100644 --- a/aidial_assistant/tools_chain/addon_runner.py +++ b/aidial_assistant/tools_chain/addon_runner.py @@ -5,6 +5,9 @@ APIPropertyBase, ) +from aidial_assistant.chain.model_response_reader import ( + AssistantProtocolException, +) from aidial_assistant.commands.base import ( ExecutionCallback, ResultObject, @@ -13,8 +16,8 @@ from aidial_assistant.commands.plugin_callback import PluginChainCallback from aidial_assistant.commands.run_plugin import PluginInfo from aidial_assistant.model.model_client import ( - ModelClient, Message, + ModelClient, ReasonLengthException, Tool, ) @@ -74,9 +77,13 @@ async def run( arg: dict[str, Any], execution_callback: ExecutionCallback, ) -> ResultObject: - query: str = arg["query"] + if name not in self.addons: + raise AssistantProtocolException( + f"Addon '{name}' not found. 
Available addons: {list(self.addons.keys())}" + ) addon = self.addons[name] + ops = collect_operations( addon.info.open_api, addon.info.ai_plugin.api.url ) @@ -86,7 +93,7 @@ async def run( messages = [ Message.system(addon.info.ai_plugin.description_for_model), - Message.user(query), + Message.user(arg["query"]), ] chain_callback = PluginChainCallback(execution_callback) try: diff --git a/aidial_assistant/tools_chain/history.py b/aidial_assistant/tools_chain/history.py new file mode 100644 index 0000000..e69de29 diff --git a/aidial_assistant/tools_chain/tools_chain.py b/aidial_assistant/tools_chain/tools_chain.py index f6cea15..1a901d3 100644 --- a/aidial_assistant/tools_chain/tools_chain.py +++ b/aidial_assistant/tools_chain/tools_chain.py @@ -5,27 +5,29 @@ from aidial_assistant.chain.callbacks.chain_callback import ChainCallback from aidial_assistant.chain.callbacks.command_callback import CommandCallback -from aidial_assistant.chain.history import History +from aidial_assistant.chain.command_result import ( + CommandInvocation, + CommandResult, + Status, + commands_to_text, + responses_to_text, +) from aidial_assistant.model.model_client import ( - ModelClient, - Message, ExtraResultsCallback, - ToolCall, + Message, + ModelClient, Tool, + ToolCall, ) from aidial_assistant.tools_chain.tool_runner import ToolRunner def _publish_command( - command_callback: CommandCallback, name: str, arguments: str + command_callback: CommandCallback, name: str, arguments: dict[str, Any] ): command_callback.on_command(name) args_callback = command_callback.args_callback() - args_callback.on_args_start() - arg_callback = args_callback.arg_callback() - arg_callback.on_arg(arguments) - arg_callback.on_arg_end() - args_callback.on_args_end() + args_callback.on_args(arguments) class ToolCallsCallback(ExtraResultsCallback): @@ -66,34 +68,47 @@ async def run_chat(self, messages: list[Message], callback: ChainCallback): ) ) + commands: list[CommandInvocation] = [] + results: list[CommandResult] = [] for tool_call in tool_calls_callback.tool_calls: function = tool_call["function"] name = function["name"] - arguments = function["arguments"] + arguments = json.loads(function["arguments"]) + commands.append(CommandInvocation(command=name, args=arguments)) with callback.command_callback() as command_callback: _publish_command(command_callback, name, arguments) try: result = await self.tool_runner.run( name, - json.loads(arguments), + arguments, command_callback.execution_callback(), ) messages.append( Message( - role=Role.USER, - name=name, + role=Role.TOOL, tool_call_id=tool_call["id"], content=result.text, ) ) command_callback.on_result(result) + results.append( + CommandResult( + status=Status.SUCCESS, response=result.text + ) + ) except Exception as e: messages.append( Message( - role=Role.USER, - name=name, + role=Role.TOOL, tool_call_id=tool_call["id"], content=str(e), ) ) command_callback.on_error(e) + results.append( + CommandResult(status=Status.ERROR, response=str(e)) + ) + + callback.on_state( + commands_to_text(commands), responses_to_text(results) + ) diff --git a/aidial_assistant/utils/exceptions.py b/aidial_assistant/utils/exceptions.py index 44f7b50..d4b774d 100644 --- a/aidial_assistant/utils/exceptions.py +++ b/aidial_assistant/utils/exceptions.py @@ -2,7 +2,7 @@ from functools import wraps from aidial_sdk import HTTPException -from openai import OpenAIError, APIError +from openai import APIError logger = logging.getLogger(__name__) diff --git a/poetry.lock b/poetry.lock index b4b3fbb..410b913 100644 
--- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "aidial-sdk" -version = "0.3.0" +version = "0.5.0" description = "Framework to create applications and model adapters for AI DIAL" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "aidial_sdk-0.3.0-py3-none-any.whl", hash = "sha256:67f4efd1e44a1442741d089295e560257a42ec28658ad18040ada93ce1cb08bb"}, - {file = "aidial_sdk-0.3.0.tar.gz", hash = "sha256:f5da5e57e839765d0440c5ddf3586748c20f6ee6233b85a70454bce16a6359cd"}, + {file = "aidial_sdk-0.5.0-py3-none-any.whl", hash = "sha256:db0cb45440d055a4361cdd35baf3b7db4d51c3c5b7c63a901ca920638937a26f"}, + {file = "aidial_sdk-0.5.0.tar.gz", hash = "sha256:29df146c44953ed90cecb07fb58c2087c800c511fa6a1a515392ed4de3b44621"}, ] [package.dependencies] @@ -1885,4 +1885,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "73ae85c178411c402c34edf6adde3638a7719ae4f25a6e1aee63b2efa43a673a" +content-hash = "86678261ef71da9d241126587fab83682f3f9c9ef905878db5c7b152fbde619a" diff --git a/pyproject.toml b/pyproject.toml index c572760..4fc3ca5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ openai = "^1.3.9" pydantic = "1.10.13" pyyaml = "^6.0.1" typing-extensions = "^4.8.0" -aidial-sdk = "^0.3.0" +aidial-sdk = "^0.5.0" aiohttp = "^3.9.0" openapi-schema-pydantic = "^1.2.4" openapi-pydantic = "^0.3.2" diff --git a/tests/unit_tests/utils/test_exception_handler.py b/tests/unit_tests/utils/test_exception_handler.py index be7ef83..a9bfc31 100644 --- a/tests/unit_tests/utils/test_exception_handler.py +++ b/tests/unit_tests/utils/test_exception_handler.py @@ -1,7 +1,7 @@ import httpx import pytest from aidial_sdk import HTTPException -from openai import OpenAIError, APIStatusError +from openai import APIStatusError, OpenAIError from aidial_assistant.utils.exceptions import ( RequestParameterValidationError, From eb0d90adba9aca96da9219168c0019ab3a8658f0 Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Fri, 5 Jan 2024 14:02:52 +0000 Subject: [PATCH 03/23] Small fixes. --- aidial_assistant/application/assistant_application.py | 2 +- aidial_assistant/commands/run_plugin.py | 2 +- aidial_assistant/model/model_client.py | 5 +---- aidial_assistant/tools_chain/addon_runner.py | 2 +- 4 files changed, 4 insertions(+), 7 deletions(-) diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index 2d4e937..98f6971 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -186,7 +186,7 @@ async def chat_completion( client=AsyncAzureOpenAI( azure_endpoint=self.args.openai_conf.api_base, api_key=request.api_key, - api_version="2023-12-01-preview", + api_version=request.api_version, ), model_args=chat_args, ) diff --git a/aidial_assistant/commands/run_plugin.py b/aidial_assistant/commands/run_plugin.py index 15cc00f..44c4412 100644 --- a/aidial_assistant/commands/run_plugin.py +++ b/aidial_assistant/commands/run_plugin.py @@ -66,7 +66,7 @@ async def _run_plugin( ) -> ResultObject: if name not in self.plugins: raise ValueError( - f"Unknown addon: {name}. Available addons: {[*self.plugins.keys()]}" + f"Unknown addon: {name}. 
Available addons: {list(self.plugins.keys())}" ) plugin = self.plugins[name] diff --git a/aidial_assistant/model/model_client.py b/aidial_assistant/model/model_client.py index f657c89..a6a7055 100644 --- a/aidial_assistant/model/model_client.py +++ b/aidial_assistant/model/model_client.py @@ -18,10 +18,7 @@ class Message(BaseModel): tool_calls: list[dict[str, Any]] | None = None def to_openai_message(self) -> dict[str, str]: - result = {"role": self.role.value} - - if self.content is not None: - result["content"] = self.content + result = {"role": self.role.value, "content": self.content} if self.tool_call_id: result["tool_call_id"] = self.tool_call_id diff --git a/aidial_assistant/tools_chain/addon_runner.py b/aidial_assistant/tools_chain/addon_runner.py index 51178a5..52fdf7e 100644 --- a/aidial_assistant/tools_chain/addon_runner.py +++ b/aidial_assistant/tools_chain/addon_runner.py @@ -79,7 +79,7 @@ async def run( ) -> ResultObject: if name not in self.addons: raise AssistantProtocolException( - f"Addon '{name}' not found. Available addons: {list(self.addons.keys())}" + f"Unknown addon '{name}. Available addons: {list(self.addons.keys())}" ) addon = self.addons[name] From cb5cbc912bf371174b9838287296e265510c3378 Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Mon, 8 Jan 2024 20:34:51 +0000 Subject: [PATCH 04/23] More fixes. --- .../application/assistant_application.py | 129 ++++++++------- .../application/assistant_callback.py | 29 +--- aidial_assistant/application/prompts.py | 13 +- .../chain/callbacks/args_callback.py | 14 +- aidial_assistant/chain/command_chain.py | 45 +++--- aidial_assistant/chain/command_result.py | 4 +- aidial_assistant/chain/history.py | 3 +- .../chain/model_response_reader.py | 12 +- aidial_assistant/commands/base.py | 4 +- aidial_assistant/commands/open_api.py | 8 +- aidial_assistant/commands/run_plugin.py | 31 ++-- aidial_assistant/commands/run_tool.py | 100 ++++++++++++ aidial_assistant/model/model_client.py | 44 +----- aidial_assistant/tools_chain/addon_runner.py | 103 ------------ aidial_assistant/tools_chain/history.py | 0 aidial_assistant/tools_chain/http_runner.py | 23 --- aidial_assistant/tools_chain/tool_runner.py | 11 -- aidial_assistant/tools_chain/tools_chain.py | 149 +++++++++++------- aidial_assistant/utils/exceptions.py | 2 +- aidial_assistant/utils/open_ai.py | 61 +++++++ aidial_assistant/utils/state.py | 26 ++- .../chain/test_command_chain_best_effort.py | 20 ++- tests/unit_tests/chain/test_history.py | 2 +- tests/unit_tests/chain/test_model_client.py | 105 ------------ tests/unit_tests/model/test_model_client.py | 56 ++++--- tests/unit_tests/utils/test_state.py | 7 +- tests/utils/async_helper.py | 4 + 27 files changed, 479 insertions(+), 526 deletions(-) create mode 100644 aidial_assistant/commands/run_tool.py delete mode 100644 aidial_assistant/tools_chain/addon_runner.py delete mode 100644 aidial_assistant/tools_chain/history.py delete mode 100644 aidial_assistant/tools_chain/http_runner.py delete mode 100644 aidial_assistant/tools_chain/tool_runner.py create mode 100644 aidial_assistant/utils/open_ai.py delete mode 100644 tests/unit_tests/chain/test_model_client.py diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index 57b6a69..a8eaca3 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -23,28 +23,27 @@ MAIN_SYSTEM_DIALOG_MESSAGE, ) from aidial_assistant.chain.command_chain import 
CommandChain, CommandDict -from aidial_assistant.chain.command_result import ( - CommandInvocation, - Commands, - Responses, -) -from aidial_assistant.chain.history import History, ScopedMessage, MessageScope +from aidial_assistant.chain.command_result import Commands, Responses +from aidial_assistant.chain.history import History, MessageScope, ScopedMessage from aidial_assistant.commands.reply import Reply from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin +from aidial_assistant.commands.run_tool import RunTool from aidial_assistant.model.model_client import ( Message, ModelClient, ReasonLengthException, - Tool, ToolCall, - FunctionCall, ) -from aidial_assistant.tools_chain.addon_runner import AddonRunner from aidial_assistant.tools_chain.tools_chain import ToolsChain from aidial_assistant.utils.exceptions import ( RequestParameterValidationError, unhandled_exception_handler, ) +from aidial_assistant.utils.open_ai import ( + FunctionCall, + Tool, + construct_function, +) from aidial_assistant.utils.open_ai_plugin import ( AddonTokenSource, get_open_ai_plugin_info, @@ -102,23 +101,17 @@ def _validate_request(request: Request) -> None: def _construct_function(name: str, description: str) -> Tool: - return { - "type": "function", - "function": { - "name": name, - "description": description, - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": "A task written in natural language", - } - }, - "required": ["query"], - }, + return construct_function( + name, + description, + { + "query": { + "type": "string", + "description": "A task written in natural language", + } }, - } + ["query"], + ) def _convert_commands_to_tools( @@ -130,6 +123,12 @@ def _convert_commands_to_tools( for scoped_message in scoped_messages: message = scoped_message.message if scoped_message.scope == MessageScope.INTERNAL: + if not message.content: + raise RequestParameterValidationError( + "State is broken. 
Content cannot be empty.", + param="messages", + ) + if message.role == Role.ASSISTANT: commands: Commands = json.loads(message.content) messages.append( @@ -140,12 +139,8 @@ def _convert_commands_to_tools( index=index, id=str(next_tool_id + index), function=FunctionCall( - name=command["args"][0], - arguments=json.dumps( - { - "query": command["args"][1], - } - ), + name=command["command"], + arguments=json.dumps(command["arguments"]), ), type="function", ) @@ -207,19 +202,21 @@ async def chat_completion( (addon_reference.url for addon_reference in addon_references), ) - addons: dict[str, PluginInfo] = {} + addons: list[PluginInfo] = [] # DIAL Core has own names for addons, so in stages we need to map them to the names used by the user addon_name_mapping: dict[str, str] = {} for addon_reference in addon_references: info = await get_open_ai_plugin_info(addon_reference.url) - addons[info.ai_plugin.name_for_model] = PluginInfo( - info=info, - auth=get_plugin_auth( - info.ai_plugin.auth.type, - info.ai_plugin.auth.authorization_type, - addon_reference.url, - token_source, - ), + addons.append( + PluginInfo( + info=info, + auth=get_plugin_auth( + info.ai_plugin.auth.type, + info.ai_plugin.auth.authorization_type, + addon_reference.url, + token_source, + ), + ) ) if addon_reference.name: @@ -229,35 +226,40 @@ async def chat_completion( if request.model in {"gpt-4-turbo-1106", "anthropic.claude-v2-1"}: await AssistantApplication._run_native_tools_chat( - model, addons, request, response + model, addons, addon_name_mapping, request, response ) else: await AssistantApplication._run_emulated_tools_chat( - model, addons, request, response + model, addons, addon_name_mapping, request, response ) @staticmethod async def _run_emulated_tools_chat( model: ModelClient, - addons: dict[str, PluginInfo], + addons: list[PluginInfo], + addon_name_mapping: dict[str, str], request: Request, response: Response, ): # TODO: Add max_addons_dialogue_tokens as a request parameter max_addons_dialogue_tokens = 1000 + + def create_command(addon: PluginInfo): + return lambda: RunPlugin(model, addon, max_addons_dialogue_tokens) + command_dict: CommandDict = { - RunPlugin.token(): lambda: RunPlugin( - model, addons, max_addons_dialogue_tokens - ), - Reply.token(): Reply, + addon.info.ai_plugin.name_for_model: create_command(addon) + for addon in addons } + command_dict[Reply.token()] = Reply + chain = CommandChain( model_client=model, name="ASSISTANT", command_dict=command_dict ) addon_descriptions = { - name: addon.info.open_api.info.description + addon.info.ai_plugin.name_for_model: addon.info.open_api.info.description or addon.info.ai_plugin.description_for_human - for name, addon in addons.items() + for addon in addons } history = History( assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build( @@ -304,23 +306,32 @@ async def _run_emulated_tools_chat( @staticmethod async def _run_native_tools_chat( model: ModelClient, - addons: dict[str, PluginInfo], + addons: list[PluginInfo], + addon_name_mapping: dict[str, str], request: Request, response: Response, ): - chain = ToolsChain( - model, - [ - _construct_function(k, v.info.ai_plugin.description_for_human) - for k, v in addons.items() - ], - AddonRunner(model, addons), - ) + tools: list[Tool] = [ + _construct_function( + addon.info.ai_plugin.name_for_model, + addon.info.ai_plugin.description_for_human, + ) + for addon in addons + ] + + def create_command(addon: PluginInfo): + return lambda: RunTool(model, addon) + + command_dict: CommandDict = { + 
addon.info.ai_plugin.name_for_model: create_command(addon)
+            for addon in addons
+        }
+        chain = ToolsChain(model, tools, command_dict)
 
         choice = response.create_single_choice()
         choice.open()
 
-        callback = AssistantChainCallback(choice)
+        callback = AssistantChainCallback(choice, addon_name_mapping)
         finish_reason = FinishReason.STOP
         messages = _convert_commands_to_tools(parse_history(request.messages))
         try:
diff --git a/aidial_assistant/application/assistant_callback.py b/aidial_assistant/application/assistant_callback.py
index ff57b3f..e7f0bf0 100644
--- a/aidial_assistant/application/assistant_callback.py
+++ b/aidial_assistant/application/assistant_callback.py
@@ -1,6 +1,4 @@
-import json
 from types import TracebackType
-from typing import Callable, Any
 
 from aidial_sdk.chat_completion import Status
 from aidial_sdk.chat_completion.choice import Choice
@@ -12,27 +10,9 @@
 from aidial_assistant.chain.callbacks.command_callback import CommandCallback
 from aidial_assistant.chain.callbacks.result_callback import ResultCallback
 from aidial_assistant.commands.base import ExecutionCallback, ResultObject
-from aidial_assistant.commands.run_plugin import RunPlugin
 from aidial_assistant.utils.state import Invocation
 
 
-class RunPluginArgsCallback(ArgsCallback):
-    def __init__(
-        self,
-        callback: Callable[[str], None],
-        addon_name_mapping: dict[str, str],
-    ):
-        super().__init__(callback)
-        self.addon_name_mapping = addon_name_mapping
-
-    @override
-    def on_args(self, args: dict[str, Any]):
-        args = args.copy()
-        name = args["name"]
-        del args["name"]
-        self.callback(self.addon_name_mapping.get(name, name) + "(" + json.dumps(args) + ")")
-
-
 class AssistantCommandCallback(CommandCallback):
     def __init__(self, stage: Stage, addon_name_mapping: dict[str, str]):
         self.stage = stage
@@ -42,12 +22,7 @@ def __init__(self, stage: Stage, addon_name_mapping: dict[str, str]):
 
     @override
     def on_command(self, command: str):
-        if command == RunPlugin.token():
-            self._args_callback = RunPluginArgsCallback(
-                self._on_stage_name, self.addon_name_mapping
-            )
-        else:
-            self._on_stage_name(command)
+        self._on_stage_name(self.addon_name_mapping.get(command, command))
 
     @override
     def execution_callback(self) -> ExecutionCallback:
@@ -55,7 +30,7 @@ def execution_callback(self) -> ExecutionCallback:
 
     @override
     def args_callback(self) -> ArgsCallback:
-        return self._args_callback
+        return ArgsCallback(self._on_stage_name)
 
     @override
     def on_result(self, result: ResultObject):
diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py
index 935d6b0..30f0dad 100644
--- a/aidial_assistant/application/prompts.py
+++ b/aidial_assistant/application/prompts.py
@@ -67,17 +67,12 @@ def build(self, **kwargs) -> Template:
 {{request_format}}
 
 ## Commands
-{%- if addons %}
-* run-addon
-This command executes a specified addon to address a one-time task described in natural language.
-Addons do not see current conversation and require all details to be provided in the query to solve the task.
-Arguments:
- - <name> is one of the following addons:
 {%- for name, description in addons.items() %}
- * {{name}} - {{description | decap}}
-{%- endfor %}
+* {{name}}
+{{description}}
+Arguments:
  - <query> is the query string.
-{%- endif %} +{%- endfor %} {{protocol_footer}} """.strip() diff --git a/aidial_assistant/chain/callbacks/args_callback.py b/aidial_assistant/chain/callbacks/args_callback.py index fe4423b..5730ec9 100644 --- a/aidial_assistant/chain/callbacks/args_callback.py +++ b/aidial_assistant/chain/callbacks/args_callback.py @@ -1,5 +1,4 @@ -import json -from typing import Callable, Any +from typing import Callable class ArgsCallback: @@ -8,6 +7,11 @@ class ArgsCallback: def __init__(self, callback: Callable[[str], None]): self.callback = callback - def on_args(self, args: dict[str, Any]): - """Called when the argument dict is constructed""" - self.callback("(" + json.dumps(args) + ")") + def on_args_start(self): + self.callback("(") + + def on_args_chunk(self, chunk: str): + self.callback(chunk) + + def on_args_end(self): + self.callback(")") diff --git a/aidial_assistant/chain/command_chain.py b/aidial_assistant/chain/command_chain.py index e10bd7c..889e15c 100644 --- a/aidial_assistant/chain/command_chain.py +++ b/aidial_assistant/chain/command_chain.py @@ -27,8 +27,8 @@ from aidial_assistant.commands.base import Command, FinalCommand from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream from aidial_assistant.json_stream.exceptions import JsonParsingException -from aidial_assistant.json_stream.json_node import JsonNode -from aidial_assistant.json_stream.json_parser import JsonParser +from aidial_assistant.json_stream.json_object import JsonObject +from aidial_assistant.json_stream.json_parser import JsonParser, string_node from aidial_assistant.json_stream.json_string import JsonString from aidial_assistant.model.model_client import Message, ModelClient from aidial_assistant.utils.stream import CumulativeStream @@ -74,8 +74,8 @@ def __init__( ) self.max_retry_count = max_retry_count - def _log_message(self, role: Role, content: str): - logger.debug(f"[{self.name}] {role.value}: {content}") + def _log_message(self, role: Role, content: str | None): + logger.debug(f"[{self.name}] {role.value}: {content or ''}") def _log_messages(self, messages: list[Message]): if logger.isEnabledFor(logging.DEBUG): @@ -210,11 +210,11 @@ async def _run_commands( async for invocation in request_reader.parse_invocations(): command_name = await invocation.parse_name() command = self._create_command(command_name) - args = invocation.parse_args() + args = await invocation.parse_args() if isinstance(command, FinalCommand): if len(responses) > 0: continue - message = await anext(args) + message = string_node(await args.get("message")) await CommandChain._to_result( message if isinstance(message, JsonString) @@ -237,7 +237,7 @@ async def _run_commands( def _create_command(self, name: str) -> Command: if name not in self.command_dict: raise AssistantProtocolException( - f"The command '{name}' is expected to be one of {[*self.command_dict.keys()]}" + f"The command '{name}' is expected to be one of {list(self.command_dict.keys())}" ) return self.command_dict[name]() @@ -263,21 +263,19 @@ def _reinforce_json_format(messages: list[Message]) -> list[Message]: @staticmethod async def _to_args( - args: AsyncIterator[JsonNode], callback: CommandCallback - ) -> AsyncIterator[Any]: + args: JsonObject, callback: CommandCallback + ) -> dict[str, Any]: args_callback = callback.args_callback() args_callback.on_args_start() - async for arg in args: - arg_callback = args_callback.arg_callback() - arg_callback.on_arg_start() - result = "" - async for chunk in arg.to_chunks(): - arg_callback.on_arg(chunk) - result += 
chunk - arg_callback.on_arg_end() - yield json.loads(result) + result = "" + async for chunk in args.to_chunks(): + args_callback.on_args_chunk(chunk) + result += chunk + parsed_args = json.loads(result) args_callback.on_args_end() + return parsed_args + @staticmethod async def _to_result(stream: AsyncIterator[str], callback: ResultCallback): try: @@ -294,20 +292,15 @@ async def _to_result(stream: AsyncIterator[str], callback: ResultCallback): async def _execute_command( name: str, command: Command, - args: AsyncIterator[JsonNode], + args: JsonObject, chain_callback: ChainCallback, ) -> CommandResult: try: with chain_callback.command_callback() as command_callback: command_callback.on_command(name) - args_list = [ - arg - async for arg in CommandChain._to_args( - args, command_callback - ) - ] response = await command.execute( - args_list, command_callback.execution_callback() + await CommandChain._to_args(args, command_callback), + command_callback.execution_callback(), ) command_callback.on_result(response) diff --git a/aidial_assistant/chain/command_result.py b/aidial_assistant/chain/command_result.py index bf8247b..1c6ef38 100644 --- a/aidial_assistant/chain/command_result.py +++ b/aidial_assistant/chain/command_result.py @@ -1,6 +1,6 @@ import json from enum import Enum -from typing import List, TypedDict, Any +from typing import Any, List, TypedDict class Status(str, Enum): @@ -18,7 +18,7 @@ class CommandResult(TypedDict): class CommandInvocation(TypedDict): command: str - args: dict[str, Any] + arguments: dict[str, Any] class Commands(TypedDict): diff --git a/aidial_assistant/chain/history.py b/aidial_assistant/chain/history.py index 27bfa07..2bbe1ad 100644 --- a/aidial_assistant/chain/history.py +++ b/aidial_assistant/chain/history.py @@ -85,7 +85,8 @@ def to_protocol_messages(self) -> list[Message]: content = commands_to_text( [ CommandInvocation( - command=Reply.token(), args=[message.content] + command=Reply.token(), + arguments={"message": message.content}, ) ] ) diff --git a/aidial_assistant/chain/model_response_reader.py b/aidial_assistant/chain/model_response_reader.py index dffe11f..8e63f1a 100644 --- a/aidial_assistant/chain/model_response_reader.py +++ b/aidial_assistant/chain/model_response_reader.py @@ -46,18 +46,12 @@ async def parse_name(self) -> str: except (TypeError, KeyError) as e: raise AssistantProtocolException(f"Cannot parse command name: {e}") - async def parse_args(self) -> AsyncIterator[JsonNode]: + async def parse_args(self) -> JsonObject: try: - args = await self.node.get("args") - # HACK: model not always passes args as an array - if isinstance(args, JsonArray): - async for arg in array_node(args): - yield arg - else: - yield args + return object_node(await self.node.get("arguments")) except (TypeError, KeyError) as e: raise AssistantProtocolException( - f"Cannot parse command args array: {e}" + f"Cannot parse command arguments array: {e}" ) diff --git a/aidial_assistant/commands/base.py b/aidial_assistant/commands/base.py index b4f2fca..cb02ea9 100644 --- a/aidial_assistant/commands/base.py +++ b/aidial_assistant/commands/base.py @@ -43,7 +43,7 @@ def token() -> str: pass async def execute( - self, args: List[Any], execution_callback: ExecutionCallback + self, args: dict[str, Any], execution_callback: ExecutionCallback ) -> ResultObject: raise Exception(f"Command {self} isn't implemented") @@ -60,7 +60,7 @@ def assert_arg_count(self, args: List[Any], count: int): class FinalCommand(Command, ABC): @override async def execute( - self, args: List[Any], 
execution_callback: ExecutionCallback + self, args: dict[str, Any], execution_callback: ExecutionCallback ) -> ResultObject: raise Exception( f"Internal error: command {self} is final and can't be executed" diff --git a/aidial_assistant/commands/open_api.py b/aidial_assistant/commands/open_api.py index 4f90303..8f2a679 100644 --- a/aidial_assistant/commands/open_api.py +++ b/aidial_assistant/commands/open_api.py @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any from langchain.tools.openapi.utils.api_models import APIOperation from typing_extensions import override @@ -22,10 +22,8 @@ def __init__(self, op: APIOperation, plugin_auth: str | None): @override async def execute( - self, args: List[Any], execution_callback: ExecutionCallback + self, args: dict[str, Any], execution_callback: ExecutionCallback ) -> ResultObject: - self.assert_arg_count(args, 1) - return await OpenAPIEndpointRequester( self.op, self.plugin_auth - ).execute(args[0]) + ).execute(args) diff --git a/aidial_assistant/commands/run_plugin.py b/aidial_assistant/commands/run_plugin.py index 44c4412..443809a 100644 --- a/aidial_assistant/commands/run_plugin.py +++ b/aidial_assistant/commands/run_plugin.py @@ -1,5 +1,3 @@ -from typing import List - from langchain.tools import APIOperation from pydantic.main import BaseModel from typing_extensions import override @@ -40,11 +38,11 @@ class RunPlugin(Command): def __init__( self, model_client: ModelClient, - plugins: dict[str, PluginInfo], + plugin: PluginInfo, max_completion_tokens: int, ): self.model_client = model_client - self.plugins = plugins + self.plugin = plugin self.max_completion_tokens = max_completion_tokens @staticmethod @@ -53,29 +51,24 @@ def token(): @override async def execute( - self, args: List[str], execution_callback: ExecutionCallback + self, args: dict[str, str], execution_callback: ExecutionCallback ) -> ResultObject: - self.assert_arg_count(args, 2) - name = args[0] - query = args[1] + if "query" not in args: + raise Exception("query is required") + + query = args["query"] - return await self._run_plugin(name, query, execution_callback) + return await self._run_plugin(query, execution_callback) async def _run_plugin( - self, name: str, query: str, execution_callback: ExecutionCallback + self, query: str, execution_callback: ExecutionCallback ) -> ResultObject: - if name not in self.plugins: - raise ValueError( - f"Unknown addon: {name}. 
Available addons: {list(self.plugins.keys())}" - ) - - plugin = self.plugins[name] - info = plugin.info + info = self.plugin.info ops = collect_operations(info.open_api, info.ai_plugin.api.url) api_schema = "\n\n".join([op.to_typescript() for op in ops.values()]) # type: ignore def create_command(op: APIOperation): - return lambda: OpenAPIChatCommand(op, plugin.auth) + return lambda: OpenAPIChatCommand(op, self.plugin.auth) command_dict: dict[str, CommandConstructor] = {} for name, op in ops.items(): @@ -99,7 +92,7 @@ def create_command(op: APIOperation): chat = CommandChain( model_client=self.model_client, - name="PLUGIN:" + name, + name="PLUGIN:" + self.plugin.info.ai_plugin.name_for_model, command_dict=command_dict, max_completion_tokens=self.max_completion_tokens, ) diff --git a/aidial_assistant/commands/run_tool.py b/aidial_assistant/commands/run_tool.py new file mode 100644 index 0000000..629bc18 --- /dev/null +++ b/aidial_assistant/commands/run_tool.py @@ -0,0 +1,100 @@ +from typing import Any + +from langchain_community.tools.openapi.utils.api_models import ( + APIOperation, + APIPropertyBase, +) +from typing_extensions import override + +from aidial_assistant.chain.command_chain import CommandDict +from aidial_assistant.commands.base import ( + Command, + ExecutionCallback, + ResultObject, + TextResult, +) +from aidial_assistant.commands.open_api import OpenAPIChatCommand +from aidial_assistant.commands.plugin_callback import PluginChainCallback +from aidial_assistant.commands.run_plugin import PluginInfo +from aidial_assistant.model.model_client import ( + Message, + ModelClient, + ReasonLengthException, +) +from aidial_assistant.open_api.operation_selector import collect_operations +from aidial_assistant.tools_chain.tools_chain import ToolsChain +from aidial_assistant.utils.open_ai import Tool, construct_function + + +def _construct_property(p: APIPropertyBase) -> dict[str, Any]: + parameter = { + "type": p.type, + "description": p.description, + "default": p.default, + } + return {k: v for k, v in parameter.items() if v is not None} + + +def _construct_function(op: APIOperation) -> Tool: + properties = {} + required = [] + for p in op.properties: + properties[p.name] = _construct_property(p) + + if p.required: + required.append(p.name) + + if op.request_body is not None: + for p in op.request_body.properties: + properties[p.name] = _construct_property(p) + + if p.required: + required.append(p.name) + + return construct_function( + op.operation_id, op.description or "", properties, required + ) + + +class RunTool(Command): + def __init__(self, model: ModelClient, addon: PluginInfo): + self.model = model + self.addon = addon + + @staticmethod + def token(): + return "run-tool" + + @override + async def execute( + self, args: dict[str, Any], execution_callback: ExecutionCallback + ) -> ResultObject: + if "query" not in args: + raise Exception("query is required") + + query = args["query"] + + ops = collect_operations( + self.addon.info.open_api, self.addon.info.ai_plugin.api.url + ) + tools: list[Tool] = [_construct_function(op) for op in ops.values()] + + def create_command(op: APIOperation): + return lambda: OpenAPIChatCommand(op, self.addon.auth) + + command_dict: CommandDict = { + name: create_command(op) for name, op in ops.items() + } + + chain = ToolsChain(self.model, tools, command_dict) + + messages = [ + Message.system(self.addon.info.ai_plugin.description_for_model), + Message.user(query), + ] + chain_callback = PluginChainCallback(execution_callback) + try: + await 
chain.run_chat(messages, chain_callback) + return TextResult(chain_callback.result) + except ReasonLengthException: + return TextResult(chain_callback.result) diff --git a/aidial_assistant/model/model_client.py b/aidial_assistant/model/model_client.py index a6a7055..aa7bcc4 100644 --- a/aidial_assistant/model/model_client.py +++ b/aidial_assistant/model/model_client.py @@ -1,11 +1,13 @@ from abc import ABC -from typing import Any, AsyncIterator, List, TypedDict +from typing import Any, AsyncIterator, List from aidial_sdk.chat_completion import Role from aidial_sdk.utils.merge_chunks import merge from openai import AsyncOpenAI from pydantic import BaseModel +from aidial_assistant.utils.open_ai import ToolCall, Usage + class ReasonLengthException(Exception): pass @@ -15,7 +17,7 @@ class Message(BaseModel): role: Role content: str | None = None tool_call_id: str | None = None - tool_calls: list[dict[str, Any]] | None = None + tool_calls: list[ToolCall] | None = None def to_openai_message(self) -> dict[str, str]: result = {"role": self.role.value, "content": self.content} @@ -41,40 +43,6 @@ def assistant(cls, content): return cls(role=Role.ASSISTANT, content=content) -class Usage(TypedDict): - prompt_tokens: int - completion_tokens: int - - -class Parameters(TypedDict): - type: str - properties: dict[str, Any] - required: list[str] - - -class Function(TypedDict): - name: str - description: str - parameters: Parameters - - -class Tool(TypedDict): - type: str - function: Function - - -class FunctionCall(TypedDict): - name: str - arguments: str - - -class ToolCall(TypedDict): - index: int - id: str - type: str - function: FunctionCall - - class ExtraResultsCallback: def on_discarded_messages(self, discarded_messages: int): pass @@ -112,8 +80,8 @@ async def agenerate( **self.model_args, extra_body=kwargs, stream=True, - messages=[message.to_openai_message() for message in messages], - ) + messages=[message.to_openai_message() for message in messages], # type: ignore + ) # type: ignore finish_reason_length = False tool_calls_chunks = [] diff --git a/aidial_assistant/tools_chain/addon_runner.py b/aidial_assistant/tools_chain/addon_runner.py deleted file mode 100644 index 52fdf7e..0000000 --- a/aidial_assistant/tools_chain/addon_runner.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import Any - -from langchain_community.tools.openapi.utils.api_models import ( - APIOperation, - APIPropertyBase, -) - -from aidial_assistant.chain.model_response_reader import ( - AssistantProtocolException, -) -from aidial_assistant.commands.base import ( - ExecutionCallback, - ResultObject, - TextResult, -) -from aidial_assistant.commands.plugin_callback import PluginChainCallback -from aidial_assistant.commands.run_plugin import PluginInfo -from aidial_assistant.model.model_client import ( - Message, - ModelClient, - ReasonLengthException, - Tool, -) -from aidial_assistant.open_api.operation_selector import collect_operations -from aidial_assistant.tools_chain.http_runner import HttpRunner -from aidial_assistant.tools_chain.tool_runner import ToolRunner -from aidial_assistant.tools_chain.tools_chain import ToolsChain - - -def build_property(p: APIPropertyBase) -> dict[str, Any]: - parameter = { - "type": p.type, - "description": p.description, - "default": p.default, - } - return {k: v for k, v in parameter.items() if v is not None} - - -def construct_function(op: APIOperation) -> Tool: - properties = {} - required = [] - for p in op.properties: - properties[p.name] = build_property(p) - - if p.required: - 
required.append(p.name) - - if op.request_body is not None: - for p in op.request_body.properties: - properties[p.name] = build_property(p) - - if p.required: - required.append(p.name) - - return { - "type": "function", - "function": { - "name": op.operation_id, - "description": op.description or "", - "parameters": { - "type": "object", - "properties": properties, - "required": required, - }, - }, - } - - -class AddonRunner(ToolRunner): - def __init__(self, model: ModelClient, addons: dict[str, PluginInfo]): - self.model = model - self.addons = addons - - async def run( - self, - name: str, - arg: dict[str, Any], - execution_callback: ExecutionCallback, - ) -> ResultObject: - if name not in self.addons: - raise AssistantProtocolException( - f"Unknown addon '{name}. Available addons: {list(self.addons.keys())}" - ) - - addon = self.addons[name] - - ops = collect_operations( - addon.info.open_api, addon.info.ai_plugin.api.url - ) - tools = [construct_function(op) for op in ops.values()] - - chain = ToolsChain(self.model, tools, HttpRunner(ops, addon.auth)) - - messages = [ - Message.system(addon.info.ai_plugin.description_for_model), - Message.user(arg["query"]), - ] - chain_callback = PluginChainCallback(execution_callback) - try: - await chain.run_chat(messages, chain_callback) - return TextResult(chain_callback.result) - except ReasonLengthException: - return TextResult(chain_callback.result) diff --git a/aidial_assistant/tools_chain/history.py b/aidial_assistant/tools_chain/history.py deleted file mode 100644 index e69de29..0000000 diff --git a/aidial_assistant/tools_chain/http_runner.py b/aidial_assistant/tools_chain/http_runner.py deleted file mode 100644 index 7070848..0000000 --- a/aidial_assistant/tools_chain/http_runner.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Any - -from langchain_community.tools.openapi.utils.api_models import APIOperation - -from aidial_assistant.commands.base import ExecutionCallback, ResultObject -from aidial_assistant.open_api.requester import OpenAPIEndpointRequester -from aidial_assistant.tools_chain.tool_runner import ToolRunner - - -class HttpRunner(ToolRunner): - def __init__(self, ops: dict[str, APIOperation], auth: str): - self.ops = ops - self.auth = auth - - async def run( - self, - name: str, - arg: dict[str, Any], - execution_callback: ExecutionCallback, - ) -> ResultObject: - return await OpenAPIEndpointRequester( - self.ops[name], self.auth - ).execute(arg) diff --git a/aidial_assistant/tools_chain/tool_runner.py b/aidial_assistant/tools_chain/tool_runner.py deleted file mode 100644 index 253c260..0000000 --- a/aidial_assistant/tools_chain/tool_runner.py +++ /dev/null @@ -1,11 +0,0 @@ -from abc import ABC -from typing import Any - -from aidial_assistant.commands.base import ExecutionCallback - - -class ToolRunner(ABC): - async def run( - self, name: str, arg: Any, execution_callback: ExecutionCallback - ): - pass diff --git a/aidial_assistant/tools_chain/tools_chain.py b/aidial_assistant/tools_chain/tools_chain.py index 1a901d3..1453f04 100644 --- a/aidial_assistant/tools_chain/tools_chain.py +++ b/aidial_assistant/tools_chain/tools_chain.py @@ -2,9 +2,11 @@ from typing import Any from aidial_sdk.chat_completion import Role +from openai import BadRequestError from aidial_assistant.chain.callbacks.chain_callback import ChainCallback from aidial_assistant.chain.callbacks.command_callback import CommandCallback +from aidial_assistant.chain.command_chain import CommandDict from aidial_assistant.chain.command_result import ( 
CommandInvocation, CommandResult, @@ -12,22 +14,27 @@ commands_to_text, responses_to_text, ) +from aidial_assistant.chain.model_response_reader import ( + AssistantProtocolException, +) +from aidial_assistant.commands.base import Command from aidial_assistant.model.model_client import ( ExtraResultsCallback, Message, ModelClient, - Tool, ToolCall, ) -from aidial_assistant.tools_chain.tool_runner import ToolRunner +from aidial_assistant.utils.open_ai import Tool def _publish_command( - command_callback: CommandCallback, name: str, arguments: dict[str, Any] + command_callback: CommandCallback, name: str, arguments: str ): command_callback.on_command(name) args_callback = command_callback.args_callback() - args_callback.on_args(arguments) + args_callback.on_args_start() + args_callback.on_args_chunk(arguments) + args_callback.on_args_end() class ToolCallsCallback(ExtraResultsCallback): @@ -43,72 +50,106 @@ def __init__( self, model: ModelClient, tools: list[Tool], - tool_runner: ToolRunner, + command_dict: CommandDict, ): self.model = model self.tools = tools - self.tool_runner = tool_runner + self.command_dict = command_dict async def run_chat(self, messages: list[Message], callback: ChainCallback): result_callback = callback.result_callback() + dialogue: list[Message] = [] + last_message_message_count = 0 while True: tool_calls_callback = ToolCallsCallback() - async for chunk in self.model.agenerate( - messages, tool_calls_callback, tools=self.tools - ): - result_callback.on_result(chunk) + try: + async for chunk in self.model.agenerate( + messages + dialogue, tool_calls_callback, tools=self.tools + ): + result_callback.on_result(chunk) + except BadRequestError as e: + if len(dialogue) == 0 or e.code == "429": + raise + + dialogue = dialogue[:-last_message_message_count] + async for chunk in self.model.agenerate( + messages + dialogue, tool_calls_callback + ): + result_callback.on_result(chunk) + break if not tool_calls_callback.tool_calls: break - messages.append( + result_messages = await self._process_tools( + tool_calls_callback.tool_calls, callback + ) + dialogue.append( Message( role=Role.ASSISTANT, tool_calls=tool_calls_callback.tool_calls, ) ) + dialogue.extend(result_messages) + last_message_message_count = len(result_messages) + 1 + + def _create_command(self, name: str) -> Command: + if name not in self.command_dict: + raise AssistantProtocolException( + f"The tool '{name}' is expected to be one of {list(self.command_dict.keys())}" + ) + + return self.command_dict[name]() + + async def _process_tools( + self, tool_calls: list[ToolCall], callback: ChainCallback + ): + commands: list[CommandInvocation] = [] + command_results: list[CommandResult] = [] + result_messages: list[Message] = [] + for tool_call in tool_calls: + function = tool_call["function"] + name = function["name"] + arguments: dict[str, Any] = json.loads(function["arguments"]) + with callback.command_callback() as command_callback: + _publish_command(command_callback, name, json.dumps(arguments)) + command = self._create_command(name) + result = await self._execute_command( + command, + arguments, + command_callback, + ) + result_messages.append( + Message( + role=Role.TOOL, + tool_call_id=tool_call["id"], + content=result["response"], + ) + ) + command_results.append(result) + + commands.append( + CommandInvocation(command=name, arguments=arguments) + ) + + callback.on_state( + commands_to_text(commands), responses_to_text(command_results) + ) + + return result_messages - commands: list[CommandInvocation] = [] - 
results: list[CommandResult] = [] - for tool_call in tool_calls_callback.tool_calls: - function = tool_call["function"] - name = function["name"] - arguments = json.loads(function["arguments"]) - commands.append(CommandInvocation(command=name, args=arguments)) - with callback.command_callback() as command_callback: - _publish_command(command_callback, name, arguments) - try: - result = await self.tool_runner.run( - name, - arguments, - command_callback.execution_callback(), - ) - messages.append( - Message( - role=Role.TOOL, - tool_call_id=tool_call["id"], - content=result.text, - ) - ) - command_callback.on_result(result) - results.append( - CommandResult( - status=Status.SUCCESS, response=result.text - ) - ) - except Exception as e: - messages.append( - Message( - role=Role.TOOL, - tool_call_id=tool_call["id"], - content=str(e), - ) - ) - command_callback.on_error(e) - results.append( - CommandResult(status=Status.ERROR, response=str(e)) - ) - - callback.on_state( - commands_to_text(commands), responses_to_text(results) + @staticmethod + async def _execute_command( + command: Command, + args: dict[str, Any], + command_callback: CommandCallback, + ) -> CommandResult: + try: + result = await command.execute( + args, command_callback.execution_callback() ) + command_callback.on_result(result) + return CommandResult(status=Status.SUCCESS, response=result.text) + except Exception as e: + command_callback.on_error(e) + return CommandResult(status=Status.ERROR, response=str(e)) diff --git a/aidial_assistant/utils/exceptions.py b/aidial_assistant/utils/exceptions.py index d4b774d..791fb9b 100644 --- a/aidial_assistant/utils/exceptions.py +++ b/aidial_assistant/utils/exceptions.py @@ -30,7 +30,7 @@ def _to_http_exception(e: Exception) -> HTTPException: raise HTTPException( message=e.message, status_code=getattr(e, "status_code") or 500, - type=e.type, + type=e.type or "runtime_error", code=e.code, param=e.param, ) diff --git a/aidial_assistant/utils/open_ai.py b/aidial_assistant/utils/open_ai.py new file mode 100644 index 0000000..0aefdf8 --- /dev/null +++ b/aidial_assistant/utils/open_ai.py @@ -0,0 +1,61 @@ +from typing import Any, TypedDict + + +class Usage(TypedDict): + prompt_tokens: int + completion_tokens: int + + +class Property(TypedDict, total=False): + type: str + description: str + default: Any + + +class Parameters(TypedDict): + type: str + properties: dict[str, Property] + required: list[str] + + +class Function(TypedDict): + name: str + description: str + parameters: Parameters + + +class Tool(TypedDict): + type: str + function: Function + + +class FunctionCall(TypedDict): + name: str + arguments: str + + +class ToolCall(TypedDict): + index: int + id: str + type: str + function: FunctionCall + + +def construct_function( + name: str, + description: str, + properties: dict[str, Property], + required: list[str], +) -> Tool: + return { + "type": "function", + "function": { + "name": name, + "description": description, + "parameters": { + "type": "object", + "properties": properties, + "required": required, + }, + }, + } diff --git a/aidial_assistant/utils/state.py b/aidial_assistant/utils/state.py index a1a6734..f6c4414 100644 --- a/aidial_assistant/utils/state.py +++ b/aidial_assistant/utils/state.py @@ -1,7 +1,12 @@ +import json from typing import TypedDict from aidial_sdk.chat_completion.request import CustomContent, Message, Role +from aidial_assistant.chain.command_result import ( + CommandInvocation, + commands_to_text, +) from aidial_assistant.chain.history import MessageScope, 
ScopedMessage from aidial_assistant.model.model_client import Message as ModelMessage @@ -32,6 +37,23 @@ def _get_invocations(custom_content: CustomContent | None) -> list[Invocation]: return invocations +def _normalize_commands(string: str) -> str: + commands = json.loads(string) + result: list[CommandInvocation] = [] + + for command in commands["commands"]: + command_name = command["command"] + if command_name in ("run-addon", "run-plugin"): + args = command["args"] + result.append( + CommandInvocation(command=args[0], arguments={"query": args[1]}) + ) + else: + result.append(command) + + return commands_to_text(result) + + def parse_history(history: list[Message]) -> list[ScopedMessage]: messages: list[ScopedMessage] = [] for message in history: @@ -41,7 +63,9 @@ def parse_history(history: list[Message]) -> list[ScopedMessage]: messages.append( ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.assistant(invocation["request"]), + message=ModelMessage.assistant( + _normalize_commands(invocation["request"]) + ), ) ) messages.append( diff --git a/tests/unit_tests/chain/test_command_chain_best_effort.py b/tests/unit_tests/chain/test_command_chain_best_effort.py index 8711861..66daf1f 100644 --- a/tests/unit_tests/chain/test_command_chain_best_effort.py +++ b/tests/unit_tests/chain/test_command_chain_best_effort.py @@ -1,6 +1,7 @@ import json from unittest.mock import MagicMock, Mock, call +import httpx import pytest from aidial_sdk.chat_completion import Role from jinja2 import Template @@ -28,7 +29,11 @@ TEST_COMMAND_NAME = "" TEST_COMMAND_OUTPUT = "" TEST_COMMAND_REQUEST = json.dumps( - {"commands": [{"command": TEST_COMMAND_NAME, "args": ["test_arg"]}]} + { + "commands": [ + {"command": TEST_COMMAND_NAME, "arguments": {"arg": "value"}} + ] + } ) TEST_COMMAND_RESPONSE = json.dumps( {"responses": [{"status": "SUCCESS", "response": TEST_COMMAND_OUTPUT}]} @@ -154,7 +159,18 @@ async def test_no_tokens_for_tools(): model_client = Mock(spec=ModelClient) model_client.agenerate.side_effect = [ to_async_string(TEST_COMMAND_REQUEST), - BadRequestError(NO_TOKENS_ERROR), + BadRequestError( + message=NO_TOKENS_ERROR, + response=httpx.Response( + request=httpx.Request("GET", "http://localhost"), + status_code=400, + ), + body={ + "type": "", + "code": "", + "param": "", + }, + ), to_async_string(BEST_EFFORT_ANSWER), ] test_command = Mock(spec=Command) diff --git a/tests/unit_tests/chain/test_history.py b/tests/unit_tests/chain/test_history.py index 0916ec1..2a71fae 100644 --- a/tests/unit_tests/chain/test_history.py +++ b/tests/unit_tests/chain/test_history.py @@ -125,6 +125,6 @@ def test_protocol_messages_with_system_message(): Message.system(f"system message={system_message}"), Message.user(user_message), Message.assistant( - f'{{"commands": [{{"command": "reply", "args": ["{assistant_message}"]}}]}}' + f'{{"commands": [{{"command": "reply", "arguments": {{"message": "{assistant_message}"}}}}]}}' ), ] diff --git a/tests/unit_tests/chain/test_model_client.py b/tests/unit_tests/chain/test_model_client.py deleted file mode 100644 index 3457901..0000000 --- a/tests/unit_tests/chain/test_model_client.py +++ /dev/null @@ -1,105 +0,0 @@ -from unittest import mock -from unittest.mock import Mock, call - -import pytest - -from aidial_assistant.model.model_client import ( - ExtraResultsCallback, - Message, - ModelClient, - ReasonLengthException, -) -from aidial_assistant.utils.text import join_string -from tests.utils.async_helper import to_async_iterator - -API_METHOD = 
"openai.ChatCompletion.acreate" -MODEL_ARGS = {"model": "args"} -BUFFER_SIZE = 321 - - -@mock.patch(API_METHOD) -@pytest.mark.asyncio -async def test_discarded_messages(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( - [ - { - "choices": [{"delta": {"content": ""}}], - "statistics": {"discarded_messages": 2}, - } - ] - ) - extra_results_callback = Mock(spec=ExtraResultsCallback) - - await join_string(model_client.agenerate([], extra_results_callback)) - - assert extra_results_callback.on_discarded_messages.call_args_list == [ - call(2) - ] - - -@mock.patch(API_METHOD) -@pytest.mark.asyncio -async def test_content(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( - [ - {"choices": [{"delta": {"content": "one, "}}]}, - {"choices": [{"delta": {"content": "two, "}}]}, - {"choices": [{"delta": {"content": "three"}}]}, - ] - ) - - assert await join_string(model_client.agenerate([])) == "one, two, three" - - -@mock.patch(API_METHOD) -@pytest.mark.asyncio -async def test_reason_length_with_usage(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator( - [ - {"choices": [{"delta": {"content": "text"}}]}, - { - "choices": [ - {"delta": {"content": ""}, "finish_reason": "length"} # type: ignore - ] - }, - { - "choices": [{"delta": {"content": ""}}], - "usage": {"prompt_tokens": 1, "completion_tokens": 2}, - }, - ] - ) - - with pytest.raises(ReasonLengthException): - async for chunk in model_client.agenerate([]): - assert chunk == "text" - - assert model_client.total_prompt_tokens == 1 - assert model_client.total_completion_tokens == 2 - - -@mock.patch(API_METHOD) -@pytest.mark.asyncio -async def test_api_args(api): - model_client = ModelClient(MODEL_ARGS, BUFFER_SIZE) - api.return_value = to_async_iterator([]) - messages = [ - Message.system(content="a"), - Message.user(content="b"), - Message.assistant(content="c"), - ] - - await join_string(model_client.agenerate(messages)) - - assert api.call_args_list == [ - call( - messages=[ - {"role": "system", "content": "a"}, - {"role": "user", "content": "b"}, - {"role": "assistant", "content": "c"}, - ], - **MODEL_ARGS, - ) - ] diff --git a/tests/unit_tests/model/test_model_client.py b/tests/unit_tests/model/test_model_client.py index 200ab71..a7735e8 100644 --- a/tests/unit_tests/model/test_model_client.py +++ b/tests/unit_tests/model/test_model_client.py @@ -1,8 +1,9 @@ +from typing import Any from unittest.mock import Mock, call import pytest from openai import AsyncOpenAI -from openai.types.chat import ChatCompletionChunk +from pydantic import BaseModel from aidial_assistant.model.model_client import ( ExtraResultsCallback, @@ -10,18 +11,26 @@ ModelClient, ReasonLengthException, ) +from aidial_assistant.utils.open_ai import Usage from aidial_assistant.utils.text import join_string -from tests.utils.async_helper import to_async_iterator +from tests.utils.async_helper import to_awaitable_iterator MODEL_ARGS = {"model": "args"} +class Chunk(BaseModel): + choices: list[dict[str, Any]] + statistics: dict[str, int] | None = None + usage: Usage | None = None + + @pytest.mark.asyncio async def test_discarded_messages(): - openai_client = Mock() - openai_client.chat.completions.create.return_value = to_async_iterator( + openai_client = Mock(spec=AsyncOpenAI) + openai_client.chat = Mock() + openai_client.chat.completions.create.return_value = to_awaitable_iterator( [ - ChatCompletionChunk( + Chunk( choices=[{"delta": 
{"content": ""}}], statistics={"discarded_messages": 2}, ) @@ -40,11 +49,12 @@ async def test_discarded_messages(): @pytest.mark.asyncio async def test_content(): openai_client = Mock(spec=AsyncOpenAI) - openai_client.chat.completions.create.return_value = to_async_iterator( + openai_client.chat = Mock() + openai_client.chat.completions.create.return_value = to_awaitable_iterator( [ - {"choices": [{"delta": {"content": "one, "}}]}, - {"choices": [{"delta": {"content": "two, "}}]}, - {"choices": [{"delta": {"content": "three"}}]}, + Chunk(choices=[{"delta": {"content": "one, "}}]), + Chunk(choices=[{"delta": {"content": "two, "}}]), + Chunk(choices=[{"delta": {"content": "three"}}]), ] ) model_client = ModelClient(openai_client, MODEL_ARGS) @@ -55,18 +65,19 @@ async def test_content(): @pytest.mark.asyncio async def test_reason_length_with_usage(): openai_client = Mock(spec=AsyncOpenAI) - openai_client.chat.completions.create.return_value = to_async_iterator( + openai_client.chat = Mock() + openai_client.chat.completions.create.return_value = to_awaitable_iterator( [ - {"choices": [{"delta": {"content": "text"}}]}, - { - "choices": [ + Chunk(choices=[{"delta": {"content": "text"}}]), + Chunk( + choices=[ {"delta": {"content": ""}, "finish_reason": "length"} # type: ignore ] - }, - { - "choices": [{"delta": {"content": ""}}], - "usage": {"prompt_tokens": 1, "completion_tokens": 2}, - }, + ), + Chunk( + choices=[{"delta": {"content": ""}}], + usage={"prompt_tokens": 1, "completion_tokens": 2}, + ), ] ) model_client = ModelClient(openai_client, MODEL_ARGS) @@ -82,7 +93,10 @@ async def test_reason_length_with_usage(): @pytest.mark.asyncio async def test_api_args(): openai_client = Mock(spec=AsyncOpenAI) - openai_client.chat.completions.create.return_value = to_async_iterator([]) + openai_client.chat = Mock() + openai_client.chat.completions.create.return_value = to_awaitable_iterator( + [] + ) model_client = ModelClient(openai_client, MODEL_ARGS) messages = [ Message.system(content="a"), @@ -90,7 +104,7 @@ async def test_api_args(): Message.assistant(content="c"), ] - await join_string(model_client.agenerate(messages)) + await join_string(model_client.agenerate(messages, extra="args")) assert openai_client.chat.completions.create.call_args_list == [ call( @@ -100,5 +114,7 @@ async def test_api_args(): {"role": "assistant", "content": "c"}, ], **MODEL_ARGS, + stream=True, + extra_body={"extra": "args"}, ) ] diff --git a/tests/unit_tests/utils/test_state.py b/tests/unit_tests/utils/test_state.py index d7b29a2..ba93972 100644 --- a/tests/unit_tests/utils/test_state.py +++ b/tests/unit_tests/utils/test_state.py @@ -8,8 +8,9 @@ SECOND_USER_MESSAGE = "" FIRST_ASSISTANT_MESSAGE = "" SECOND_ASSISTANT_MESSAGE = "" -FIRST_REQUEST = "" -SECOND_REQUEST = "" +FIRST_REQUEST = '{"commands": [{"command": "run-addon", "args": ["", ""]}]}' +FIRST_REQUEST_FIXED = '{"commands": [{"command": "", "arguments": {"query": ""}}]}' +SECOND_REQUEST = '{"commands": [{"command": "", "arguments": {"query": ""}}]}' FIRST_RESPONSE = "" SECOND_RESPONSE = "" @@ -48,7 +49,7 @@ def test_parse_history(): ), ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.assistant(FIRST_REQUEST), + message=ModelMessage.assistant(FIRST_REQUEST_FIXED), ), ScopedMessage( scope=MessageScope.INTERNAL, diff --git a/tests/utils/async_helper.py b/tests/utils/async_helper.py index 00e3bbb..5021323 100644 --- a/tests/utils/async_helper.py +++ b/tests/utils/async_helper.py @@ -20,3 +20,7 @@ def to_async_repeated_string( async def 
to_async_iterator(sequence: Iterable[T]) -> AsyncIterator[T]:
     for item in sequence:
         yield item
+
+
+async def to_awaitable_iterator(sequence: Iterable[T]) -> AsyncIterator[T]:
+    return to_async_iterator(sequence)

From cad2dc10e069e65deca8b0255b841301d29b5e38 Mon Sep 17 00:00:00 2001
From: Alexey Klimov
Date: Tue, 9 Jan 2024 15:58:46 +0000
Subject: [PATCH 05/23] Check for a reserved command name.

---
 aidial_assistant/application/assistant_application.py | 6 ++++++
 aidial_assistant/commands/run_plugin.py               | 2 ++
 2 files changed, 8 insertions(+)

diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py
index a8eaca3..db51ddf 100644
--- a/aidial_assistant/application/assistant_application.py
+++ b/aidial_assistant/application/assistant_application.py
@@ -251,6 +251,12 @@ def create_command(addon: PluginInfo):
             addon.info.ai_plugin.name_for_model: create_command(addon)
             for addon in addons
         }
+        if Reply.token() in command_dict:
+            raise RequestParameterValidationError(
+                f"Addon with name '{Reply.token()}' is not allowed for model {request.model}.",
+                param="addons",
+            )
+
         command_dict[Reply.token()] = Reply

         chain = CommandChain(
diff --git a/aidial_assistant/commands/run_plugin.py b/aidial_assistant/commands/run_plugin.py
index 443809a..ff7bf7d 100644
--- a/aidial_assistant/commands/run_plugin.py
+++ b/aidial_assistant/commands/run_plugin.py
@@ -75,6 +75,8 @@ def create_command(op: APIOperation):
             # The function is necessary to capture the current value of op.
             # Otherwise, only first op will be used for all commands
             command_dict[name] = create_command(op)
+        if Reply.token() in command_dict:
+            raise Exception(f"Operation with name '{Reply.token()}' is not allowed.")

         command_dict[Reply.token()] = Reply

From d012fee5477567fbafde39ed0b41a44454af1c21 Mon Sep 17 00:00:00 2001
From: Alexey Klimov
Date: Wed, 10 Jan 2024 09:34:37 +0000
Subject: [PATCH 06/23] Add extra line between commands.

---
 aidial_assistant/application/prompts.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py
index 30f0dad..38517b2 100644
--- a/aidial_assistant/application/prompts.py
+++ b/aidial_assistant/application/prompts.py
@@ -72,6 +72,7 @@ def build(self, **kwargs) -> Template:
 {{description}}
 Arguments:
  - is the query string.
+
 {%- endfor %}

 {{protocol_footer}}
 """.strip()
@@ -97,6 +98,7 @@ def build(self, **kwargs) -> Template:
 * {{command_name}}
 Arguments:
  -
+
 {%- endfor %}

 {{protocol_footer}}
 """.strip()

From 41464dc92567b6fe173eb37085c25a16b476bf56 Mon Sep 17 00:00:00 2001
From: Alexey Klimov
Date: Wed, 10 Jan 2024 14:57:46 +0000
Subject: [PATCH 07/23] Update dial sdk to support httpx for opentelemetry.
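
The httpx hook has to be installed before the openai client is imported,
because the client builds on httpx.AsyncClient and would otherwise bind the
unpatched class. A minimal sketch of the ordering constraint, assuming the
standard opentelemetry-instrumentation-httpx API; the snippet is illustrative
rather than code from this repository:

    from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor

    # Install the hook first: the instrumentor patches the httpx client
    # classes in place so that outgoing requests produce spans.
    HTTPXClientInstrumentor().instrument()

    # Import code built on httpx.AsyncClient only afterwards; importing
    # it earlier would capture the unpatched class.
    from openai import AsyncOpenAI  # noqa: E402

This is why the AssistantApplication import is moved below the DIALApp
construction, which performs the telemetry setup.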
--- aidial_assistant/app.py | 16 ++++++------ poetry.lock | 56 +++++++++++++++++++++++------------------ pyproject.toml | 2 +- 3 files changed, 40 insertions(+), 34 deletions(-) diff --git a/aidial_assistant/app.py b/aidial_assistant/app.py index db7ecb3..ff5d302 100644 --- a/aidial_assistant/app.py +++ b/aidial_assistant/app.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 import logging.config import os from pathlib import Path @@ -7,9 +6,6 @@ from aidial_sdk.telemetry.types import TelemetryConfig, TracingConfig from starlette.responses import Response -from aidial_assistant.application.assistant_application import ( - AssistantApplication, -) from aidial_assistant.utils.log_config import get_log_config log_level = os.getenv("LOG_LEVEL", "INFO") @@ -21,9 +17,13 @@ service_name="aidial-assistant", tracing=TracingConfig() ) app = DIALApp(telemetry_config=telemetry_config) -app.add_chat_completion("assistant", AssistantApplication(config_dir)) +# A delayed import is necessary to set up the httpx hook before the openai client inherits from AsyncClient. +from aidial_assistant.application.assistant_application import ( # noqa: E402 + AssistantApplication, +) -@app.get("/healthcheck/status200") -def status200() -> Response: - return Response("Service is running...", status_code=200) +app.add_chat_completion( + "assistant", + AssistantApplication(config_dir), +) diff --git a/poetry.lock b/poetry.lock index bbf6051..55e2b0a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,13 +2,13 @@ [[package]] name = "aidial-sdk" -version = "0.5.0" +version = "0.5.1" description = "Framework to create applications and model adapters for AI DIAL" optional = false python-versions = ">=3.8.1,<4.0" files = [ - {file = "aidial_sdk-0.5.0-py3-none-any.whl", hash = "sha256:db0cb45440d055a4361cdd35baf3b7db4d51c3c5b7c63a901ca920638937a26f"}, - {file = "aidial_sdk-0.5.0.tar.gz", hash = "sha256:29df146c44953ed90cecb07fb58c2087c800c511fa6a1a515392ed4de3b44621"}, + {file = "aidial_sdk-0.5.1-py3-none-any.whl", hash = "sha256:345e8f59593adf616be9b9bad6f46b98b0a6e7fcd7cc17932fabe8c266b3cfe4"}, + {file = "aidial_sdk-0.5.1.tar.gz", hash = "sha256:5bb327882c90719b3054b52f1e211c00fb9667b2c2010aeb6bbd60f6f40ea1d4"}, ] [package.dependencies] @@ -21,19 +21,20 @@ opentelemetry-exporter-prometheus = {version = "1.12.0rc1", optional = true, mar opentelemetry-instrumentation = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-aiohttp-client = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-fastapi = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} +opentelemetry-instrumentation-httpx = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-logging = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-requests = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-system-metrics = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-instrumentation-urllib = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} opentelemetry-sdk = {version = "1.20.0", optional = true, markers = "extra == \"telemetry\""} +prometheus-client = {version = "0.17.1", optional = true, markers = "extra == \"telemetry\""} pydantic = ">=1.10,<3" requests = ">=2.19,<3.0" -starlette-exporter = {version = "0.16.0", optional = true, markers 
= "extra == \"telemetry\""} uvicorn = ">=0.19,<1.0" wrapt = ">=1.14,<2.0" [package.extras] -telemetry = ["opentelemetry-api (==1.20.0)", "opentelemetry-distro (==0.41b0)", "opentelemetry-exporter-otlp-proto-grpc (==1.20.0)", "opentelemetry-exporter-prometheus (==1.12.0rc1)", "opentelemetry-instrumentation (==0.41b0)", "opentelemetry-instrumentation-aiohttp-client (==0.41b0)", "opentelemetry-instrumentation-fastapi (==0.41b0)", "opentelemetry-instrumentation-logging (==0.41b0)", "opentelemetry-instrumentation-requests (==0.41b0)", "opentelemetry-instrumentation-system-metrics (==0.41b0)", "opentelemetry-instrumentation-urllib (==0.41b0)", "opentelemetry-sdk (==1.20.0)", "starlette-exporter (==0.16.0)"] +telemetry = ["opentelemetry-api (==1.20.0)", "opentelemetry-distro (==0.41b0)", "opentelemetry-exporter-otlp-proto-grpc (==1.20.0)", "opentelemetry-exporter-prometheus (==1.12.0rc1)", "opentelemetry-instrumentation (==0.41b0)", "opentelemetry-instrumentation-aiohttp-client (==0.41b0)", "opentelemetry-instrumentation-fastapi (==0.41b0)", "opentelemetry-instrumentation-httpx (==0.41b0)", "opentelemetry-instrumentation-logging (==0.41b0)", "opentelemetry-instrumentation-requests (==0.41b0)", "opentelemetry-instrumentation-system-metrics (==0.41b0)", "opentelemetry-instrumentation-urllib (==0.41b0)", "opentelemetry-sdk (==1.20.0)", "prometheus-client (==0.17.1)"] [[package]] name = "aiocache" @@ -1526,6 +1527,26 @@ opentelemetry-util-http = "0.41b0" instruments = ["fastapi (>=0.58,<1.0)"] test = ["httpx (>=0.22,<1.0)", "opentelemetry-instrumentation-fastapi[instruments]", "opentelemetry-test-utils (==0.41b0)", "requests (>=2.23,<3.0)"] +[[package]] +name = "opentelemetry-instrumentation-httpx" +version = "0.41b0" +description = "OpenTelemetry HTTPX Instrumentation" +optional = false +python-versions = ">=3.7" +files = [ + {file = "opentelemetry_instrumentation_httpx-0.41b0-py3-none-any.whl", hash = "sha256:6ada84b7caa95a2889b2d883c089a977546b0102c815658b88f1c2dae713e9b2"}, + {file = "opentelemetry_instrumentation_httpx-0.41b0.tar.gz", hash = "sha256:96ebc54f3f41bfcd2fc043349c8cee4b11737602512383d437e24c39a1e4adff"}, +] + +[package.dependencies] +opentelemetry-api = ">=1.12,<2.0" +opentelemetry-instrumentation = "0.41b0" +opentelemetry-semantic-conventions = "0.41b0" + +[package.extras] +instruments = ["httpx (>=0.18.0)"] +test = ["opentelemetry-instrumentation-httpx[instruments]", "opentelemetry-sdk (>=1.12,<2.0)", "opentelemetry-test-utils (==0.41b0)"] + [[package]] name = "opentelemetry-instrumentation-logging" version = "0.41b0" @@ -1711,13 +1732,13 @@ testing = ["pytest", "pytest-benchmark"] [[package]] name = "prometheus-client" -version = "0.19.0" +version = "0.17.1" description = "Python client for the Prometheus monitoring system." 
optional = false -python-versions = ">=3.8" +python-versions = ">=3.6" files = [ - {file = "prometheus_client-0.19.0-py3-none-any.whl", hash = "sha256:c88b1e6ecf6b41cd8fb5731c7ae919bf66df6ec6fafa555cd6c0e16ca169ae92"}, - {file = "prometheus_client-0.19.0.tar.gz", hash = "sha256:4585b0d1223148c27a225b10dbec5ae9bc4c81a99a3fa80774fa6209935324e1"}, + {file = "prometheus_client-0.17.1-py3-none-any.whl", hash = "sha256:e537f37160f6807b8202a6fc4764cdd19bac5480ddd3e0d463c3002b34462101"}, + {file = "prometheus_client-0.17.1.tar.gz", hash = "sha256:21e674f39831ae3f8acde238afd9a27a37d0d2fb5a28ea094f0ce25d2cbf2091"}, ] [package.extras] @@ -2117,21 +2138,6 @@ anyio = ">=3.4.0,<5" [package.extras] full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] -[[package]] -name = "starlette-exporter" -version = "0.16.0" -description = "Prometheus metrics exporter for Starlette applications." -optional = false -python-versions = "*" -files = [ - {file = "starlette_exporter-0.16.0-py3-none-any.whl", hash = "sha256:9dbe8dc647acbeb8680d53cedbbb8042ca75ca1b6987f609c5601ea96ddb7422"}, - {file = "starlette_exporter-0.16.0.tar.gz", hash = "sha256:728cccf975c85d3cf2844b0110b51e1fa2dce628ef68bc38da58ad691f9b5d68"}, -] - -[package.dependencies] -prometheus-client = ">=0.12" -starlette = "*" - [[package]] name = "tenacity" version = "8.2.3" @@ -2436,4 +2442,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "ed5461ea43711f379d54e7e11a0f8bb7e425d7fb6b0be92049547d7445320eed" +content-hash = "147e4e4a549e48fafe8133a5aeb00839fb1ee7dea31f2720c8630e9b10dc7f4c" diff --git a/pyproject.toml b/pyproject.toml index 24f0249..6271286 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ openai = "^1.3.9" pydantic = "1.10.13" pyyaml = "^6.0.1" typing-extensions = "^4.8.0" -aidial-sdk = { version = "^0.5.0", extras = ["telemetry"] } +aidial-sdk = { version = "^0.5.1", extras = ["telemetry"] } aiohttp = "^3.9.0" openapi-schema-pydantic = "^1.2.4" openapi-pydantic = "^0.3.2" From 5db6bb0c8c0dda93232d33e02dec7502ae112d3c Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Wed, 10 Jan 2024 15:02:14 +0000 Subject: [PATCH 08/23] Remove unused import. --- aidial_assistant/app.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aidial_assistant/app.py b/aidial_assistant/app.py index ff5d302..7a4c466 100644 --- a/aidial_assistant/app.py +++ b/aidial_assistant/app.py @@ -4,7 +4,6 @@ from aidial_sdk import DIALApp from aidial_sdk.telemetry.types import TelemetryConfig, TracingConfig -from starlette.responses import Response from aidial_assistant.utils.log_config import get_log_config From 076c2ea5074e9d46a12abbe6aa31479ebc213296 Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Wed, 10 Jan 2024 16:42:49 +0000 Subject: [PATCH 09/23] Clarify prompts. --- aidial_assistant/application/prompts.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py index 38517b2..c1af1db 100644 --- a/aidial_assistant/application/prompts.py +++ b/aidial_assistant/application/prompts.py @@ -69,11 +69,10 @@ def build(self, **kwargs) -> Template: ## Commands {%- for name, description in addons.items() %} * {{name}} -{{description}} +{{description.strip()}} Arguments: - - is the query string. - -{%- endfor %} + - is the query string written in natural language. 
+{% endfor %} {{protocol_footer}} """.strip() @@ -98,8 +97,7 @@ def build(self, **kwargs) -> Template: * {{command_name}} Arguments: - - -{%- endfor %} +{% endfor %} {{protocol_footer}} """.strip() From eb88263253be23d5f906c4f24ff20773f7a0acdd Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Wed, 10 Jan 2024 17:14:49 +0000 Subject: [PATCH 10/23] Minor prompt adjustments. --- aidial_assistant/application/prompts.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py index c1af1db..7766e86 100644 --- a/aidial_assistant/application/prompts.py +++ b/aidial_assistant/application/prompts.py @@ -30,6 +30,7 @@ def build(self, **kwargs) -> Template: _REQUEST_FORMAT_TEXT = """ You should ALWAYS reply with a JSON containing an array of commands: +```json { "commands": [ { @@ -40,19 +41,22 @@ def build(self, **kwargs) -> Template: } ] } -The commands are invoked by system on user's behalf. +``` +The commands are invoked by the system on the user's behalf. """.strip() _PROTOCOL_FOOTER = """ * reply -The command delivers final response to the user. +The command delivers the final response to the user. Arguments: - - is a string containing the final and complete result for the user. +- `message` is a string containing the final and complete result for the user. Your goal is to answer user questions. Use relevant commands when they help to achieve the goal. ## Example +```json {"commands": [{"command": "reply", "arguments": {"message": "Hello, world!"}}]} +``` """.strip() _SYSTEM_TEXT = """ @@ -71,7 +75,7 @@ def build(self, **kwargs) -> Template: * {{name}} {{description.strip()}} Arguments: - - is the query string written in natural language. +- `query` is a query string written in natural language. {% endfor %} {{protocol_footer}} """.strip() From 762cbc218123fbb43001dcf7882423b22a4833ad Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Thu, 11 Jan 2024 15:26:52 +0000 Subject: [PATCH 11/23] Improve prompt formatting for gpt-4-0314. --- aidial_assistant/application/prompts.py | 9 ++++++--- aidial_assistant/commands/base.py | 18 +++++++++++------- aidial_assistant/commands/run_plugin.py | 6 ++---- aidial_assistant/commands/run_tool.py | 6 ++---- 4 files changed, 21 insertions(+), 18 deletions(-) diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py index 7766e86..f9cf757 100644 --- a/aidial_assistant/application/prompts.py +++ b/aidial_assistant/application/prompts.py @@ -48,8 +48,9 @@ def build(self, **kwargs) -> Template: _PROTOCOL_FOOTER = """ * reply The command delivers the final response to the user. + Arguments: -- `message` is a string containing the final and complete result for the user. + - 'message' is a string containing the final and complete result for the user. Your goal is to answer user questions. Use relevant commands when they help to achieve the goal. @@ -74,8 +75,9 @@ def build(self, **kwargs) -> Template: {%- for name, description in addons.items() %} * {{name}} {{description.strip()}} + Arguments: -- `query` is a query string written in natural language. + - 'query' is a query written in natural language. 
{% endfor %} {{protocol_footer}} """.strip() @@ -99,8 +101,9 @@ def build(self, **kwargs) -> Template: ## Commands {%- for command_name in command_names %} * {{command_name}} + Arguments: - - + - {% endfor %} {{protocol_footer}} """.strip() diff --git a/aidial_assistant/commands/base.py b/aidial_assistant/commands/base.py index cb02ea9..cd5c301 100644 --- a/aidial_assistant/commands/base.py +++ b/aidial_assistant/commands/base.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Callable, List, TypedDict +from typing import Any, Callable, List, TypedDict, TypeVar from typing_extensions import override @@ -50,12 +50,6 @@ async def execute( def __str__(self) -> str: return self.token() - def assert_arg_count(self, args: List[Any], count: int): - if len(args) != count: - raise ValueError( - f"Command {self} expects {count} args, but got {len(args)}" - ) - class FinalCommand(Command, ABC): @override @@ -70,3 +64,13 @@ async def execute( class CommandObject(TypedDict): command: str args: List[str] + + +T = TypeVar("T") + + +def get_required_field(args: dict[str, T], field: str) -> T: + value = args.get(field) + if value is None: + raise Exception(f"Parameter '{field}' is required") + return value diff --git a/aidial_assistant/commands/run_plugin.py b/aidial_assistant/commands/run_plugin.py index ff7bf7d..4a26a47 100644 --- a/aidial_assistant/commands/run_plugin.py +++ b/aidial_assistant/commands/run_plugin.py @@ -16,6 +16,7 @@ ExecutionCallback, ResultObject, TextResult, + get_required_field, ) from aidial_assistant.commands.open_api import OpenAPIChatCommand from aidial_assistant.commands.plugin_callback import PluginChainCallback @@ -53,10 +54,7 @@ def token(): async def execute( self, args: dict[str, str], execution_callback: ExecutionCallback ) -> ResultObject: - if "query" not in args: - raise Exception("query is required") - - query = args["query"] + query = get_required_field(args, "query") return await self._run_plugin(query, execution_callback) diff --git a/aidial_assistant/commands/run_tool.py b/aidial_assistant/commands/run_tool.py index 629bc18..546ed6f 100644 --- a/aidial_assistant/commands/run_tool.py +++ b/aidial_assistant/commands/run_tool.py @@ -12,6 +12,7 @@ ExecutionCallback, ResultObject, TextResult, + get_required_field, ) from aidial_assistant.commands.open_api import OpenAPIChatCommand from aidial_assistant.commands.plugin_callback import PluginChainCallback @@ -69,10 +70,7 @@ def token(): async def execute( self, args: dict[str, Any], execution_callback: ExecutionCallback ) -> ResultObject: - if "query" not in args: - raise Exception("query is required") - - query = args["query"] + query = get_required_field(args, "query") ops = collect_operations( self.addon.info.open_api, self.addon.info.ai_plugin.api.url From 55886a10018dc399dfd05c169965634ffec79250 Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Thu, 11 Jan 2024 16:16:58 +0000 Subject: [PATCH 12/23] Use latest openai api version to support tools. 
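
Per the comment in the diff below, the tools parameter is only accepted
starting with this API version. A minimal sketch of the kind of request it
enables, with a placeholder endpoint, key and deployment name; the query-only
function schema mirrors the one used elsewhere in this series:

    import asyncio

    from openai.lib.azure import AsyncAzureOpenAI

    client = AsyncAzureOpenAI(
        azure_endpoint="https://example.openai.azure.com",  # placeholder
        api_key="dummy-key",  # placeholder
        api_version="2023-12-01-preview",
    )

    async def demo() -> None:
        stream = await client.chat.completions.create(
            model="gpt-4-turbo-1106",  # placeholder deployment name
            stream=True,
            messages=[{"role": "user", "content": "What is in the news today?"}],
            tools=[
                {
                    "type": "function",
                    "function": {
                        "name": "web_search",
                        "description": "Searches the web",
                        "parameters": {
                            "type": "object",
                            "properties": {
                                "query": {
                                    "type": "string",
                                    "description": "A task written in natural language",
                                }
                            },
                            "required": ["query"],
                        },
                    },
                }
            ],
        )
        async for chunk in stream:
            if chunk.choices:
                delta = chunk.choices[0].delta
                # The model streams either content or incremental tool_calls
                # fragments that must be merged by index before execution.
                print(delta.content or delta.tool_calls)

    asyncio.run(demo())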
--- aidial_assistant/application/assistant_application.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index db51ddf..34d9325 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -192,7 +192,8 @@ async def chat_completion( client=AsyncAzureOpenAI( azure_endpoint=self.args.openai_conf.api_base, api_key=request.api_key, - api_version=request.api_version, + # 2023-12-01-preview is needed to support tools + api_version="2023-12-01-preview", ), model_args=chat_args, ) From ca0a1db463ecd54165d4c587ab94a47712a26c11 Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Fri, 12 Jan 2024 15:51:15 +0000 Subject: [PATCH 13/23] Address review comments. --- aidial_assistant/app.py | 9 +- .../application/addons_dialogue_limiter.py | 7 +- .../application/assistant_application.py | 145 +++++------------ aidial_assistant/chain/command_chain.py | 60 +++---- aidial_assistant/chain/dialogue.py | 9 +- aidial_assistant/chain/history.py | 60 +++---- aidial_assistant/commands/base.py | 3 + aidial_assistant/commands/run_plugin.py | 9 +- aidial_assistant/commands/run_tool.py | 48 +++--- aidial_assistant/model/model_client.py | 88 +++++------ aidial_assistant/tools_chain/tools_chain.py | 148 +++++++++++++----- aidial_assistant/utils/open_ai.py | 90 ++++++----- aidial_assistant/utils/state.py | 36 +++-- .../test_addons_dialogue_limiter.py | 21 ++- .../chain/test_command_chain_best_effort.py | 78 ++++----- tests/unit_tests/chain/test_history.py | 51 +++--- tests/unit_tests/model/test_model_client.py | 16 +- tests/unit_tests/utils/test_state.py | 18 +-- 18 files changed, 474 insertions(+), 422 deletions(-) diff --git a/aidial_assistant/app.py b/aidial_assistant/app.py index 7a4c466..1f1e385 100644 --- a/aidial_assistant/app.py +++ b/aidial_assistant/app.py @@ -8,7 +8,6 @@ from aidial_assistant.utils.log_config import get_log_config log_level = os.getenv("LOG_LEVEL", "INFO") -config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs")) logging.config.dictConfig(get_log_config(log_level)) @@ -22,7 +21,13 @@ AssistantApplication, ) +config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs")) +tools_supporting_deployments: set[str] = set( + os.getenv( + "TOOLS_SUPPORTING_DEPLOYMENTS", "gpt-4-turbo-1106,anthropic.claude-v2-1" + ).split(",") +) app.add_chat_completion( "assistant", - AssistantApplication(config_dir), + AssistantApplication(config_dir, tools_supporting_deployments), ) diff --git a/aidial_assistant/application/addons_dialogue_limiter.py b/aidial_assistant/application/addons_dialogue_limiter.py index fc3b130..e850931 100644 --- a/aidial_assistant/application/addons_dialogue_limiter.py +++ b/aidial_assistant/application/addons_dialogue_limiter.py @@ -4,7 +4,10 @@ LimitExceededException, ModelRequestLimiter, ) -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ( + ChatCompletionMessageParam, + ModelClient, +) class AddonsDialogueLimiter(ModelRequestLimiter): @@ -16,7 +19,7 @@ def __init__(self, max_dialogue_tokens: int, model_client: ModelClient): self._initial_tokens: int | None = None @override - async def verify_limit(self, messages: list[Message]): + async def verify_limit(self, messages: list[ChatCompletionMessageParam]): if self._initial_tokens is None: self._initial_tokens = await self.model_client.count_tokens( 
messages diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index 34d9325..3ba163a 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -1,14 +1,13 @@ -import json import logging from pathlib import Path +from typing import Tuple from aidial_sdk.chat_completion import FinishReason from aidial_sdk.chat_completion.base import ChatCompletion -from aidial_sdk.chat_completion.request import Addon -from aidial_sdk.chat_completion.request import Message as SdkMessage -from aidial_sdk.chat_completion.request import Request, Role +from aidial_sdk.chat_completion.request import Addon, Message, Request, Role from aidial_sdk.chat_completion.response import Response from openai.lib.azure import AsyncAzureOpenAI +from openai.types.chat import ChatCompletionToolParam from pydantic import BaseModel from aidial_assistant.application.addons_dialogue_limiter import ( @@ -22,28 +21,29 @@ MAIN_BEST_EFFORT_TEMPLATE, MAIN_SYSTEM_DIALOG_MESSAGE, ) -from aidial_assistant.chain.command_chain import CommandChain, CommandDict -from aidial_assistant.chain.command_result import Commands, Responses -from aidial_assistant.chain.history import History, MessageScope, ScopedMessage +from aidial_assistant.chain.command_chain import ( + CommandChain, + CommandConstructor, + CommandDict, +) +from aidial_assistant.chain.history import History from aidial_assistant.commands.reply import Reply from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin from aidial_assistant.commands.run_tool import RunTool from aidial_assistant.model.model_client import ( - Message, ModelClient, ReasonLengthException, - ToolCall, ) -from aidial_assistant.tools_chain.tools_chain import ToolsChain +from aidial_assistant.tools_chain.tools_chain import ( + CommandToolDict, + ToolsChain, + convert_commands_to_tools, +) from aidial_assistant.utils.exceptions import ( RequestParameterValidationError, unhandled_exception_handler, ) -from aidial_assistant.utils.open_ai import ( - FunctionCall, - Tool, - construct_function, -) +from aidial_assistant.utils.open_ai import construct_tool from aidial_assistant.utils.open_ai_plugin import ( AddonTokenSource, get_open_ai_plugin_info, @@ -83,7 +83,7 @@ def _validate_addons(addons: list[Addon] | None) -> list[AddonReference]: return addon_references -def _validate_messages(messages: list[SdkMessage]) -> None: +def _validate_messages(messages: list[Message]) -> None: if not messages: raise RequestParameterValidationError( "Message list cannot be empty.", param="messages" @@ -95,13 +95,8 @@ def _validate_messages(messages: list[SdkMessage]) -> None: ) -def _validate_request(request: Request) -> None: - _validate_messages(request.messages) - _validate_addons(request.addons) - - -def _construct_function(name: str, description: str) -> Tool: - return construct_function( +def _construct_tool(name: str, description: str) -> ChatCompletionToolParam: + return construct_tool( name, description, { @@ -114,71 +109,12 @@ def _construct_function(name: str, description: str) -> Tool: ) -def _convert_commands_to_tools( - scoped_messages: list[ScopedMessage], -) -> list[Message]: - messages: list[Message] = [] - next_tool_id: int = 0 - last_call_count: int = 0 - for scoped_message in scoped_messages: - message = scoped_message.message - if scoped_message.scope == MessageScope.INTERNAL: - if not message.content: - raise RequestParameterValidationError( - "State is broken. 
Content cannot be empty.", - param="messages", - ) - - if message.role == Role.ASSISTANT: - commands: Commands = json.loads(message.content) - messages.append( - Message( - role=Role.ASSISTANT, - tool_calls=[ - ToolCall( - index=index, - id=str(next_tool_id + index), - function=FunctionCall( - name=command["command"], - arguments=json.dumps(command["arguments"]), - ), - type="function", - ) - for index, command in enumerate( - commands["commands"] - ) - ], - ) - ) - last_call_count = len(commands["commands"]) - next_tool_id += last_call_count - elif message.role == Role.USER: - responses: Responses = json.loads(message.content) - response_count = len(responses["responses"]) - if response_count != last_call_count: - raise RequestParameterValidationError( - f"Expected {last_call_count} responses, but got {response_count}.", - param="messages", - ) - first_tool_id = next_tool_id - last_call_count - messages.extend( - [ - Message( - role=Role.TOOL, - tool_call_id=str(first_tool_id + index), - content=response["response"], - ) - for index, response in enumerate(responses["responses"]) - ] - ) - else: - messages.append(scoped_message.message) - return messages - - class AssistantApplication(ChatCompletion): - def __init__(self, config_dir: Path): + def __init__( + self, config_dir: Path, tools_supporting_deployments: set[str] + ): self.args = parse_args(config_dir) + self.tools_supporting_deployments = tools_supporting_deployments @unhandled_exception_handler async def chat_completion( @@ -203,12 +139,12 @@ async def chat_completion( (addon_reference.url for addon_reference in addon_references), ) - addons: list[PluginInfo] = [] + plugins: list[PluginInfo] = [] # DIAL Core has own names for addons, so in stages we need to map them to the names used by the user addon_name_mapping: dict[str, str] = {} for addon_reference in addon_references: info = await get_open_ai_plugin_info(addon_reference.url) - addons.append( + plugins.append( PluginInfo( info=info, auth=get_plugin_auth( @@ -225,13 +161,13 @@ async def chat_completion( info.ai_plugin.name_for_model ] = addon_reference.name - if request.model in {"gpt-4-turbo-1106", "anthropic.claude-v2-1"}: + if request.model in self.tools_supporting_deployments: await AssistantApplication._run_native_tools_chat( - model, addons, addon_name_mapping, request, response + model, plugins, addon_name_mapping, request, response ) else: await AssistantApplication._run_emulated_tools_chat( - model, addons, addon_name_mapping, request, response + model, plugins, addon_name_mapping, request, response ) @staticmethod @@ -313,34 +249,31 @@ def create_command(addon: PluginInfo): @staticmethod async def _run_native_tools_chat( model: ModelClient, - addons: list[PluginInfo], + plugins: list[PluginInfo], addon_name_mapping: dict[str, str], request: Request, response: Response, ): - tools: list[Tool] = [ - _construct_function( - addon.info.ai_plugin.name_for_model, - addon.info.ai_plugin.description_for_human, + def create_command_tool( + plugin: PluginInfo, + ) -> Tuple[CommandConstructor, ChatCompletionToolParam]: + return lambda: RunTool(model, plugin), _construct_tool( + plugin.info.ai_plugin.name_for_model, + plugin.info.ai_plugin.description_for_human, ) - for addon in addons - ] - def create_command(addon: PluginInfo): - return lambda: RunTool(model, addon) - - command_dict: CommandDict = { - addon.info.ai_plugin.name_for_model: create_command(addon) - for addon in addons + command_tool_dict: CommandToolDict = { + plugin.info.ai_plugin.name_for_model: 
create_command_tool(plugin) + for plugin in plugins } - chain = ToolsChain(model, tools, command_dict) + chain = ToolsChain(model, command_tool_dict) choice = response.create_single_choice() choice.open() callback = AssistantChainCallback(choice, addon_name_mapping) finish_reason = FinishReason.STOP - messages = _convert_commands_to_tools(parse_history(request.messages)) + messages = convert_commands_to_tools(parse_history(request.messages)) try: await chain.run_chat(messages, callback) except ReasonLengthException: diff --git a/aidial_assistant/chain/command_chain.py b/aidial_assistant/chain/command_chain.py index 889e15c..ea76ab4 100644 --- a/aidial_assistant/chain/command_chain.py +++ b/aidial_assistant/chain/command_chain.py @@ -1,14 +1,13 @@ import json import logging from abc import ABC, abstractmethod -from typing import Any, AsyncIterator, Callable, Tuple, cast +from typing import Any, AsyncIterator, Tuple, cast -from aidial_sdk.chat_completion.request import Role from openai import BadRequestError from aidial_assistant.application.prompts import ENFORCE_JSON_FORMAT_TEMPLATE +from aidial_assistant.chain.callbacks.args_callback import ArgsCallback from aidial_assistant.chain.callbacks.chain_callback import ChainCallback -from aidial_assistant.chain.callbacks.command_callback import CommandCallback from aidial_assistant.chain.callbacks.result_callback import ResultCallback from aidial_assistant.chain.command_result import ( CommandInvocation, @@ -24,13 +23,20 @@ CommandsReader, skip_to_json_start, ) -from aidial_assistant.commands.base import Command, FinalCommand +from aidial_assistant.commands.base import ( + Command, + CommandConstructor, + FinalCommand, +) from aidial_assistant.json_stream.chunked_char_stream import ChunkedCharStream from aidial_assistant.json_stream.exceptions import JsonParsingException from aidial_assistant.json_stream.json_object import JsonObject from aidial_assistant.json_stream.json_parser import JsonParser, string_node from aidial_assistant.json_stream.json_string import JsonString -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ( + ChatCompletionMessageParam, + ModelClient, +) from aidial_assistant.utils.stream import CumulativeStream logger = logging.getLogger(__name__) @@ -41,7 +47,6 @@ # Later, the upper limit will be provided by the DIAL Core (proxy). 
MAX_MODEL_COMPLETION_CHUNKS = 32000 -CommandConstructor = Callable[[], Command] CommandDict = dict[str, CommandConstructor] @@ -51,7 +56,7 @@ class LimitExceededException(Exception): class ModelRequestLimiter(ABC): @abstractmethod - async def verify_limit(self, messages: list[Message]): + async def verify_limit(self, messages: list[ChatCompletionMessageParam]): pass @@ -74,13 +79,13 @@ def __init__( ) self.max_retry_count = max_retry_count - def _log_message(self, role: Role, content: str | None): - logger.debug(f"[{self.name}] {role.value}: {content or ''}") + def _log_message(self, role: str, content: str | None): + logger.debug(f"[{self.name}] {role}: {content or ''}") - def _log_messages(self, messages: list[Message]): + def _log_messages(self, messages: list[ChatCompletionMessageParam]): if logger.isEnabledFor(logging.DEBUG): for message in messages: - self._log_message(message.role, message.content) + self._log_message(message["role"], message.get("content")) async def run_chat( self, @@ -128,7 +133,7 @@ async def run_chat( async def _run_with_protocol_failure_retries( self, callback: ChainCallback, - messages: list[Message], + messages: list[ChatCompletionMessageParam], model_request_limiter: ModelRequestLimiter | None = None, ) -> DialogueTurn | None: last_error: Exception | None = None @@ -186,7 +191,7 @@ async def _run_with_protocol_failure_retries( ) ) finally: - self._log_message(Role.ASSISTANT, chunk_stream.buffer) + self._log_message("assistant", chunk_stream.buffer) except (BadRequestError, LimitExceededException) as e: if last_error: # Retries can increase the prompt size, which may lead to token overflow. @@ -243,29 +248,28 @@ def _create_command(self, name: str) -> Command: return self.command_dict[name]() async def _generate_result( - self, messages: list[Message], callback: ChainCallback + self, + messages: list[ChatCompletionMessageParam], + callback: ChainCallback, ): stream = self.model_client.agenerate(messages) await CommandChain._to_result(stream, callback.result_callback()) @staticmethod - def _reinforce_json_format(messages: list[Message]) -> list[Message]: - last_message = messages[-1] - return messages[:-1] + [ - Message( - role=last_message.role, - content=ENFORCE_JSON_FORMAT_TEMPLATE.render( - response=last_message.content - ), - ), - ] + def _reinforce_json_format( + messages: list[ChatCompletionMessageParam], + ) -> list[ChatCompletionMessageParam]: + last_message = messages[-1].copy() + last_message["content"] = ENFORCE_JSON_FORMAT_TEMPLATE.render( + response=last_message.get("content", "") + ) + return messages[:-1] + [last_message] @staticmethod async def _to_args( - args: JsonObject, callback: CommandCallback + args: JsonObject, args_callback: ArgsCallback ) -> dict[str, Any]: - args_callback = callback.args_callback() args_callback.on_args_start() result = "" async for chunk in args.to_chunks(): @@ -299,7 +303,9 @@ async def _execute_command( with chain_callback.command_callback() as command_callback: command_callback.on_command(name) response = await command.execute( - await CommandChain._to_args(args, command_callback), + await CommandChain._to_args( + args, command_callback.args_callback() + ), command_callback.execution_callback(), ) command_callback.on_result(response) diff --git a/aidial_assistant/chain/dialogue.py b/aidial_assistant/chain/dialogue.py index b8b3077..f1abda5 100644 --- a/aidial_assistant/chain/dialogue.py +++ b/aidial_assistant/chain/dialogue.py @@ -1,6 +1,7 @@ +from openai.types.chat import ChatCompletionMessageParam from 
pydantic import BaseModel

-from aidial_assistant.model.model_client import Message
+from aidial_assistant.utils.open_ai import assistant_message, user_message


 class DialogueTurn(BaseModel):
@@ -10,11 +11,11 @@ class DialogueTurn(BaseModel):

 class Dialogue:
     def __init__(self):
-        self.messages: list[Message] = []
+        self.messages: list[ChatCompletionMessageParam] = []

     def append(self, dialogue_turn: DialogueTurn):
-        self.messages.append(Message.assistant(dialogue_turn.assistant_message))
-        self.messages.append(Message.user(dialogue_turn.user_message))
+        self.messages.append(assistant_message(dialogue_turn.assistant_message))
+        self.messages.append(user_message(dialogue_turn.user_message))

     def pop(self):
         self.messages.pop()
diff --git a/aidial_assistant/chain/history.py b/aidial_assistant/chain/history.py
index 2bbe1ad..2324bda 100644
--- a/aidial_assistant/chain/history.py
+++ b/aidial_assistant/chain/history.py
@@ -1,17 +1,17 @@
 from enum import Enum

-from aidial_sdk.chat_completion import Role
 from jinja2 import Template
+from openai.types.chat import ChatCompletionMessageParam
 from pydantic import BaseModel

-from aidial_assistant.application.prompts import ENFORCE_JSON_FORMAT_TEMPLATE
 from aidial_assistant.chain.command_result import (
     CommandInvocation,
     commands_to_text,
 )
 from aidial_assistant.chain.dialogue import Dialogue
 from aidial_assistant.commands.reply import Reply
-from aidial_assistant.model.model_client import Message, ModelClient
+from aidial_assistant.model.model_client import ModelClient
+from aidial_assistant.utils.open_ai import assistant_message, system_message


 class ContextLengthExceeded(Exception):
@@ -25,19 +25,7 @@ class MessageScope(str, Enum):

 class ScopedMessage(BaseModel):
     scope: MessageScope = MessageScope.USER
-    message: Message
-
-
-def enforce_json_format(messages: list[Message]) -> list[Message]:
-    last_message = messages[-1]
-    return messages[:-1] + [
-        Message(
-            role=last_message.role,
-            content=ENFORCE_JSON_FORMAT_TEMPLATE.render(
-                response=last_message.content
-            ),
-        ),
-    ]
+    message: ChatCompletionMessageParam


 class History:
@@ -58,45 +46,45 @@ def __init__(
             if message.scope == MessageScope.USER
         )

-    def to_protocol_messages(self) -> list[Message]:
-        messages: list[Message] = []
+    def to_protocol_messages(self) -> list[ChatCompletionMessageParam]:
+        messages: list[ChatCompletionMessageParam] = []
         for index, scoped_message in enumerate(self.scoped_messages):
             message = scoped_message.message
             scope = scoped_message.scope

             if index == 0:
-                if message.role == Role.SYSTEM:
+                if message["role"] == "system":
                     messages.append(
-                        Message.system(
+                        system_message(
                             self.assistant_system_message_template.render(
-                                system_prefix=message.content
+                                system_prefix=message["content"]
                             )
                         )
                     )
                 else:
                     messages.append(
-                        Message.system(
+                        system_message(
                             self.assistant_system_message_template.render()
                         )
                     )
                     messages.append(message)
-            elif scope == MessageScope.USER and message.role == Role.ASSISTANT:
+            elif scope == MessageScope.USER and message["role"] == "assistant":
                 # Clients see replies in plain text, but the model should understand how to reply appropriately.
content = commands_to_text( [ CommandInvocation( command=Reply.token(), - arguments={"message": message.content}, + arguments={"message": message.get("content", "")}, ) ] ) - messages.append(Message.assistant(content=content)) + messages.append(assistant_message(content)) else: messages.append(message) return messages - def to_user_messages(self) -> list[Message]: + def to_user_messages(self) -> list[ChatCompletionMessageParam]: return [ scoped_message.message for scoped_message in self.scoped_messages @@ -105,18 +93,16 @@ def to_user_messages(self) -> list[Message]: def to_best_effort_messages( self, error: str, dialogue: Dialogue - ) -> list[Message]: + ) -> list[ChatCompletionMessageParam]: messages = self.to_user_messages() - last_message = messages[-1] - messages[-1] = Message( - role=last_message.role, - content=self.best_effort_template.render( - message=last_message.content, - error=error, - dialogue=dialogue.messages, - ), + last_message = messages[-1].copy() + last_message["content"] = self.best_effort_template.render( + message=last_message.get("content", ""), + error=error, + dialogue=dialogue.messages, ) + messages[-1] = last_message return messages @@ -147,7 +133,7 @@ def _skip_messages(self, discarded_messages: int) -> list[ScopedMessage]: message_iterator = iter(self.scoped_messages) for _ in range(discarded_messages): current_message = next(message_iterator) - while current_message.message.role == Role.SYSTEM: + while current_message.message["role"] == "system": # System messages should be kept in the history messages.append(current_message) current_message = next(message_iterator) @@ -158,7 +144,7 @@ def _skip_messages(self, discarded_messages: int) -> list[ScopedMessage]: # Internal messages (i.e. addon requests/responses) are always followed by an assistant reply assert ( - current_message.message.role == Role.ASSISTANT + current_message.message["role"] == "assistant" ), "Internal messages must be followed by an assistant reply." 
remaining_messages = list(message_iterator) diff --git a/aidial_assistant/commands/base.py b/aidial_assistant/commands/base.py index cd5c301..1c12d8b 100644 --- a/aidial_assistant/commands/base.py +++ b/aidial_assistant/commands/base.py @@ -66,6 +66,9 @@ class CommandObject(TypedDict): args: List[str] +CommandConstructor = Callable[[], Command] + + T = TypeVar("T") diff --git a/aidial_assistant/commands/run_plugin.py b/aidial_assistant/commands/run_plugin.py index 4a26a47..96f2913 100644 --- a/aidial_assistant/commands/run_plugin.py +++ b/aidial_assistant/commands/run_plugin.py @@ -22,11 +22,11 @@ from aidial_assistant.commands.plugin_callback import PluginChainCallback from aidial_assistant.commands.reply import Reply from aidial_assistant.model.model_client import ( - Message, ModelClient, ReasonLengthException, ) from aidial_assistant.open_api.operation_selector import collect_operations +from aidial_assistant.utils.open_ai import user_message from aidial_assistant.utils.open_ai_plugin import OpenAIPluginInfo @@ -87,7 +87,7 @@ def create_command(op: APIOperation): best_effort_template=ADDON_BEST_EFFORT_TEMPLATE.build( api_schema=api_schema ), - scoped_messages=[ScopedMessage(message=Message.user(query))], + scoped_messages=[ScopedMessage(message=user_message(query))], ) chat = CommandChain( @@ -100,6 +100,7 @@ def create_command(op: APIOperation): callback = PluginChainCallback(execution_callback) try: await chat.run_chat(history, callback) - return TextResult(callback.result) except ReasonLengthException: - return TextResult(callback.result) + pass + + return TextResult(callback.result) diff --git a/aidial_assistant/commands/run_tool.py b/aidial_assistant/commands/run_tool.py index 546ed6f..dcf527e 100644 --- a/aidial_assistant/commands/run_tool.py +++ b/aidial_assistant/commands/run_tool.py @@ -4,9 +4,9 @@ APIOperation, APIPropertyBase, ) +from openai.types.chat import ChatCompletionToolParam from typing_extensions import override -from aidial_assistant.chain.command_chain import CommandDict from aidial_assistant.commands.base import ( Command, ExecutionCallback, @@ -18,25 +18,31 @@ from aidial_assistant.commands.plugin_callback import PluginChainCallback from aidial_assistant.commands.run_plugin import PluginInfo from aidial_assistant.model.model_client import ( - Message, ModelClient, ReasonLengthException, ) from aidial_assistant.open_api.operation_selector import collect_operations -from aidial_assistant.tools_chain.tools_chain import ToolsChain -from aidial_assistant.utils.open_ai import Tool, construct_function +from aidial_assistant.tools_chain.tools_chain import ( + CommandTool, + CommandToolDict, + ToolsChain, +) +from aidial_assistant.utils.open_ai import ( + construct_tool, + system_message, + user_message, +) def _construct_property(p: APIPropertyBase) -> dict[str, Any]: parameter = { "type": p.type, "description": p.description, - "default": p.default, } return {k: v for k, v in parameter.items() if v is not None} -def _construct_function(op: APIOperation) -> Tool: +def _construct_function(op: APIOperation) -> ChatCompletionToolParam: properties = {} required = [] for p in op.properties: @@ -52,15 +58,15 @@ def _construct_function(op: APIOperation) -> Tool: if p.required: required.append(p.name) - return construct_function( + return construct_tool( op.operation_id, op.description or "", properties, required ) class RunTool(Command): - def __init__(self, model: ModelClient, addon: PluginInfo): + def __init__(self, model: ModelClient, plugin: PluginInfo): self.model = model 
- self.addon = addon + self.plugin = plugin @staticmethod def token(): @@ -73,26 +79,28 @@ async def execute( query = get_required_field(args, "query") ops = collect_operations( - self.addon.info.open_api, self.addon.info.ai_plugin.api.url + self.plugin.info.open_api, self.plugin.info.ai_plugin.api.url ) - tools: list[Tool] = [_construct_function(op) for op in ops.values()] - def create_command(op: APIOperation): - return lambda: OpenAPIChatCommand(op, self.addon.auth) + def create_command_tool(op: APIOperation) -> CommandTool: + return lambda: OpenAPIChatCommand( + op, self.plugin.auth + ), _construct_function(op) - command_dict: CommandDict = { - name: create_command(op) for name, op in ops.items() + command_tool_dict: CommandToolDict = { + name: create_command_tool(op) for name, op in ops.items() } - chain = ToolsChain(self.model, tools, command_dict) + chain = ToolsChain(self.model, command_tool_dict) messages = [ - Message.system(self.addon.info.ai_plugin.description_for_model), - Message.user(query), + system_message(self.plugin.info.ai_plugin.description_for_model), + user_message(query), ] chain_callback = PluginChainCallback(execution_callback) try: await chain.run_chat(messages, chain_callback) - return TextResult(chain_callback.result) except ReasonLengthException: - return TextResult(chain_callback.result) + pass + + return TextResult(chain_callback.result) diff --git a/aidial_assistant/model/model_client.py b/aidial_assistant/model/model_client.py index aa7bcc4..926ea94 100644 --- a/aidial_assistant/model/model_client.py +++ b/aidial_assistant/model/model_client.py @@ -1,48 +1,20 @@ from abc import ABC from typing import Any, AsyncIterator, List -from aidial_sdk.chat_completion import Role from aidial_sdk.utils.merge_chunks import merge from openai import AsyncOpenAI -from pydantic import BaseModel +from openai.types.chat import ( + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, +) -from aidial_assistant.utils.open_ai import ToolCall, Usage +from aidial_assistant.utils.open_ai import Usage class ReasonLengthException(Exception): pass -class Message(BaseModel): - role: Role - content: str | None = None - tool_call_id: str | None = None - tool_calls: list[ToolCall] | None = None - - def to_openai_message(self) -> dict[str, str]: - result = {"role": self.role.value, "content": self.content} - - if self.tool_call_id: - result["tool_call_id"] = self.tool_call_id - - if self.tool_calls: - result["tool_calls"] = self.tool_calls - - return result - - @classmethod - def system(cls, content): - return cls(role=Role.SYSTEM, content=content) - - @classmethod - def user(cls, content): - return cls(role=Role.USER, content=content) - - @classmethod - def assistant(cls, content): - return cls(role=Role.ASSISTANT, content=content) - - class ExtraResultsCallback: def on_discarded_messages(self, discarded_messages: int): pass @@ -50,7 +22,9 @@ def on_discarded_messages(self, discarded_messages: int): def on_prompt_tokens(self, prompt_tokens: int): pass - def on_tool_calls(self, tool_calls: list[ToolCall]): + def on_tool_calls( + self, tool_calls: list[ChatCompletionMessageToolCallParam] + ): pass @@ -72,7 +46,7 @@ def __init__(self, client: AsyncOpenAI, model_args: dict[str, Any]): async def agenerate( self, - messages: List[Message], + messages: List[ChatCompletionMessageParam], extra_results_callback: ExtraResultsCallback | None = None, **kwargs, ) -> AsyncIterator[str]: @@ -80,14 +54,14 @@ async def agenerate( **self.model_args, extra_body=kwargs, stream=True, - 
messages=[message.to_openai_message() for message in messages], # type: ignore - ) # type: ignore + messages=messages, + ) finish_reason_length = False - tool_calls_chunks = [] + tool_calls_chunks: list[list[dict[str, Any]]] = [] async for chunk in model_result: # type: ignore - chunk = chunk.dict() - usage: Usage | None = chunk.get("usage") + all_values = chunk.dict() + usage: Usage | None = all_values.get("usage") if usage: prompt_tokens = usage["prompt_tokens"] self._total_prompt_tokens += prompt_tokens @@ -96,7 +70,7 @@ async def agenerate( extra_results_callback.on_prompt_tokens(prompt_tokens) if extra_results_callback: - discarded_messages: int | None = chunk.get( + discarded_messages: int | None = all_values.get( "statistics", {} ).get("discarded_messages") if discarded_messages is not None: @@ -104,29 +78,37 @@ async def agenerate( discarded_messages ) - choice = chunk["choices"][0] - delta = choice["delta"] - text = delta.get("content") - if text: - yield text + choice = chunk.choices[0] + delta = choice.delta + if delta.content: + yield delta.content - tool_calls_chunk = delta.get("tool_calls") - if tool_calls_chunk: - tool_calls_chunks.append(tool_calls_chunk) + if delta.tool_calls: + tool_calls_chunks.append( + [ + tool_call_chunk.dict() + for tool_call_chunk in delta.tool_calls + ] + ) - if choice.get("finish_reason") == "length": + if choice.finish_reason == "length": finish_reason_length = True if finish_reason_length: raise ReasonLengthException() if extra_results_callback and tool_calls_chunks: - tool_calls: list[ToolCall] = merge(*tool_calls_chunks) + tool_calls: list[ChatCompletionMessageToolCallParam] = [ + ChatCompletionMessageToolCallParam(**tool_call) + for tool_call in merge(*tool_calls_chunks) + ] extra_results_callback.on_tool_calls(tool_calls) # TODO: Use a dedicated endpoint for counting tokens. # This request may throw an error if the number of tokens is too large. - async def count_tokens(self, messages: list[Message]) -> int: + async def count_tokens( + self, messages: list[ChatCompletionMessageParam] + ) -> int: class PromptTokensCallback(ExtraResultsCallback): def __init__(self): self.token_count: int | None = None @@ -147,7 +129,7 @@ def on_prompt_tokens(self, prompt_tokens: int): # TODO: Use a dedicated endpoint for discarded_messages. 
async def get_discarded_messages( - self, messages: list[Message], max_prompt_tokens: int + self, messages: list[ChatCompletionMessageParam], max_prompt_tokens: int ) -> int: class DiscardedMessagesCallback(ExtraResultsCallback): def __init__(self): diff --git a/aidial_assistant/tools_chain/tools_chain.py b/aidial_assistant/tools_chain/tools_chain.py index 1453f04..97b170c 100644 --- a/aidial_assistant/tools_chain/tools_chain.py +++ b/aidial_assistant/tools_chain/tools_chain.py @@ -1,30 +1,97 @@ import json -from typing import Any +from typing import Any, Tuple, cast -from aidial_sdk.chat_completion import Role from openai import BadRequestError +from openai.types.chat import ( + ChatCompletionMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionToolParam, +) +from openai.types.chat.chat_completion_message_tool_call_param import Function from aidial_assistant.chain.callbacks.chain_callback import ChainCallback from aidial_assistant.chain.callbacks.command_callback import CommandCallback -from aidial_assistant.chain.command_chain import CommandDict +from aidial_assistant.chain.command_chain import CommandConstructor from aidial_assistant.chain.command_result import ( CommandInvocation, CommandResult, + Commands, + Responses, Status, commands_to_text, responses_to_text, ) +from aidial_assistant.chain.history import MessageScope, ScopedMessage from aidial_assistant.chain.model_response_reader import ( AssistantProtocolException, ) from aidial_assistant.commands.base import Command from aidial_assistant.model.model_client import ( ExtraResultsCallback, - Message, ModelClient, - ToolCall, ) -from aidial_assistant.utils.open_ai import Tool +from aidial_assistant.utils.exceptions import RequestParameterValidationError +from aidial_assistant.utils.open_ai import tool_calls_message, tool_message + + +def convert_commands_to_tools( + scoped_messages: list[ScopedMessage], +) -> list[ChatCompletionMessageParam]: + messages: list[ChatCompletionMessageParam] = [] + next_tool_id: int = 0 + last_call_count: int = 0 + for scoped_message in scoped_messages: + message = scoped_message.message + if scoped_message.scope == MessageScope.INTERNAL: + content = cast(str, message.get("content")) + if not content: + raise RequestParameterValidationError( + "State is broken. 
Content cannot be empty.", + param="messages", + ) + + if message["role"] == "assistant": + commands: Commands = json.loads(content) + messages.append( + tool_calls_message( + [ + ChatCompletionMessageToolCallParam( + id=str(next_tool_id + index), + function=Function( + name=command["command"], + arguments=json.dumps(command["arguments"]), + ), + type="function", + ) + for index, command in enumerate( + commands["commands"] + ) + ], + ) + ) + last_call_count = len(commands["commands"]) + next_tool_id += last_call_count + elif message["role"] == "user": + responses: Responses = json.loads(content) + response_count = len(responses["responses"]) + if response_count != last_call_count: + raise RequestParameterValidationError( + f"Expected {last_call_count} responses, but got {response_count}.", + param="messages", + ) + first_tool_id = next_tool_id - last_call_count + messages.extend( + [ + tool_message( + content=response["response"], + tool_call_id=str(first_tool_id + index), + ) + for index, response in enumerate(responses["responses"]) + ] + ) + else: + messages.append(scoped_message.message) + return messages def _publish_command( @@ -37,41 +104,46 @@ def _publish_command( args_callback.on_args_end() +CommandTool = Tuple[CommandConstructor, ChatCompletionToolParam] +CommandToolDict = dict[str, CommandTool] + + class ToolCallsCallback(ExtraResultsCallback): def __init__(self): - self.tool_calls: list[ToolCall] = [] + self.tool_calls: list[ChatCompletionMessageToolCallParam] = [] - def on_tool_calls(self, tool_calls: list[ToolCall]): + def on_tool_calls( + self, tool_calls: list[ChatCompletionMessageToolCallParam] + ): self.tool_calls = tool_calls class ToolsChain: - def __init__( - self, - model: ModelClient, - tools: list[Tool], - command_dict: CommandDict, - ): + def __init__(self, model: ModelClient, command_tool_dict: CommandToolDict): self.model = model - self.tools = tools - self.command_dict = command_dict + self.command_tool_dict = command_tool_dict - async def run_chat(self, messages: list[Message], callback: ChainCallback): + async def run_chat( + self, + messages: list[ChatCompletionMessageParam], + callback: ChainCallback, + ): result_callback = callback.result_callback() - dialogue: list[Message] = [] - last_message_message_count = 0 + dialogue: list[ChatCompletionMessageParam] = [] + last_message_block_length = 0 + tools = [tool for _, tool in self.command_tool_dict.values()] while True: tool_calls_callback = ToolCallsCallback() try: async for chunk in self.model.agenerate( - messages + dialogue, tool_calls_callback, tools=self.tools + messages + dialogue, tool_calls_callback, tools=tools ): result_callback.on_result(chunk) except BadRequestError as e: if len(dialogue) == 0 or e.code == "429": raise - dialogue = dialogue[:-last_message_message_count] + dialogue = dialogue[:-last_message_block_length] async for chunk in self.model.agenerate( messages + dialogue, tool_calls_callback ): @@ -81,32 +153,35 @@ async def run_chat(self, messages: list[Message], callback: ChainCallback): if not tool_calls_callback.tool_calls: break - result_messages = await self._process_tools( - tool_calls_callback.tool_calls, callback - ) dialogue.append( - Message( - role=Role.ASSISTANT, - tool_calls=tool_calls_callback.tool_calls, + tool_calls_message( + tool_calls_callback.tool_calls, ) ) + result_messages = await self._run_tools( + tool_calls_callback.tool_calls, callback + ) dialogue.extend(result_messages) - last_message_message_count = len(result_messages) + 1 + last_message_block_length = 
len(result_messages) + 1 def _create_command(self, name: str) -> Command: - if name not in self.command_dict: + if name not in self.command_tool_dict: raise AssistantProtocolException( - f"The tool '{name}' is expected to be one of {list(self.command_dict.keys())}" + f"The tool '{name}' is expected to be one of {list(self.command_tool_dict.keys())}" ) - return self.command_dict[name]() + command, _ = self.command_tool_dict[name] + + return command() - async def _process_tools( - self, tool_calls: list[ToolCall], callback: ChainCallback + async def _run_tools( + self, + tool_calls: list[ChatCompletionMessageToolCallParam], + callback: ChainCallback, ): commands: list[CommandInvocation] = [] command_results: list[CommandResult] = [] - result_messages: list[Message] = [] + result_messages: list[ChatCompletionMessageParam] = [] for tool_call in tool_calls: function = tool_call["function"] name = function["name"] @@ -120,10 +195,9 @@ async def _process_tools( command_callback, ) result_messages.append( - Message( - role=Role.TOOL, - tool_call_id=tool_call["id"], + tool_message( content=result["response"], + tool_call_id=tool_call["id"], ) ) command_results.append(result) diff --git a/aidial_assistant/utils/open_ai.py b/aidial_assistant/utils/open_ai.py index 0aefdf8..b72acfc 100644 --- a/aidial_assistant/utils/open_ai.py +++ b/aidial_assistant/utils/open_ai.py @@ -1,4 +1,14 @@ -from typing import Any, TypedDict +from typing import TypedDict + +from openai.types.chat import ( + ChatCompletionAssistantMessageParam, + ChatCompletionMessageToolCallParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolMessageParam, + ChatCompletionToolParam, + ChatCompletionUserMessageParam, +) +from openai.types.shared_params import FunctionDefinition class Usage(TypedDict): @@ -9,53 +19,55 @@ class Usage(TypedDict): class Property(TypedDict, total=False): type: str description: str - default: Any -class Parameters(TypedDict): - type: str - properties: dict[str, Property] - required: list[str] +def construct_tool( + name: str, + description: str, + properties: dict[str, Property], + required: list[str], +) -> ChatCompletionToolParam: + return ChatCompletionToolParam( + type="function", + function=FunctionDefinition( + name=name, + description=description, + parameters={ + "type": "object", + "properties": properties, + "required": required, + }, + ), + ) -class Function(TypedDict): - name: str - description: str - parameters: Parameters +def system_message(content: str) -> ChatCompletionSystemMessageParam: + return ChatCompletionSystemMessageParam(role="system", content=content) -class Tool(TypedDict): - type: str - function: Function +def user_message(content: str) -> ChatCompletionUserMessageParam: + return ChatCompletionUserMessageParam(role="user", content=content) -class FunctionCall(TypedDict): - name: str - arguments: str +def assistant_message(content: str) -> ChatCompletionAssistantMessageParam: + return ChatCompletionAssistantMessageParam( + role="assistant", content=content + ) -class ToolCall(TypedDict): - index: int - id: str - type: str - function: FunctionCall +def tool_calls_message( + tool_calls: list[ChatCompletionMessageToolCallParam], +) -> ChatCompletionAssistantMessageParam: + return ChatCompletionAssistantMessageParam( + role="assistant", tool_calls=tool_calls + ) -def construct_function( - name: str, - description: str, - properties: dict[str, Property], - required: list[str], -) -> Tool: - return { - "type": "function", - "function": { - "name": name, - "description": description, 
- "parameters": { - "type": "object", - "properties": properties, - "required": required, - }, - }, - } +def tool_message( + content: str, tool_call_id: str +) -> ChatCompletionToolMessageParam: + return ChatCompletionToolMessageParam( + role="tool", + content=content, + tool_call_id=tool_call_id, + ) diff --git a/aidial_assistant/utils/state.py b/aidial_assistant/utils/state.py index f6c4414..3d12e54 100644 --- a/aidial_assistant/utils/state.py +++ b/aidial_assistant/utils/state.py @@ -8,7 +8,12 @@ commands_to_text, ) from aidial_assistant.chain.history import MessageScope, ScopedMessage -from aidial_assistant.model.model_client import Message as ModelMessage +from aidial_assistant.utils.exceptions import RequestParameterValidationError +from aidial_assistant.utils.open_ai import ( + assistant_message, + system_message, + user_message, +) class Invocation(TypedDict): @@ -37,7 +42,8 @@ def _get_invocations(custom_content: CustomContent | None) -> list[Invocation]: return invocations -def _normalize_commands(string: str) -> str: +def _convert_old_commands(string: str) -> str: + """Converts old commands to new format.""" commands = json.loads(string) result: list[CommandInvocation] = [] @@ -63,24 +69,32 @@ def parse_history(history: list[Message]) -> list[ScopedMessage]: messages.append( ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.assistant( - _normalize_commands(invocation["request"]) + message=assistant_message( + _convert_old_commands(invocation["request"]) ), ) ) messages.append( ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.user(invocation["response"]), + message=user_message(invocation["response"]), ) ) - messages.append( - ScopedMessage( - message=ModelMessage( - role=message.role, content=message.content or "" - ) + messages.append( + ScopedMessage(message=assistant_message(message.content or "")) + ) + elif message.role == Role.USER: + messages.append( + ScopedMessage(message=user_message(message.content or "")) + ) + elif message.role == Role.SYSTEM: + messages.append( + ScopedMessage(message=system_message(message.content or "")) + ) + else: + raise RequestParameterValidationError( + f"Role {message.role} is not supported.", param="messages" ) - ) return messages diff --git a/tests/unit_tests/application/test_addons_dialogue_limiter.py b/tests/unit_tests/application/test_addons_dialogue_limiter.py index b323b0d..871f4f6 100644 --- a/tests/unit_tests/application/test_addons_dialogue_limiter.py +++ b/tests/unit_tests/application/test_addons_dialogue_limiter.py @@ -6,7 +6,12 @@ AddonsDialogueLimiter, ) from aidial_assistant.chain.command_chain import LimitExceededException -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ModelClient +from aidial_assistant.utils.open_ai import ( + assistant_message, + system_message, + user_message, +) MAX_TOKENS = 1 @@ -17,8 +22,11 @@ async def test_dialogue_size_is_ok(): model.count_tokens.side_effect = [1, 2] limiter = AddonsDialogueLimiter(MAX_TOKENS, model) - initial_messages = [Message.system("a"), Message.user("b")] - dialogue_messages = [Message.assistant("c"), Message.user("d")] + initial_messages = [system_message("a"), user_message("b")] + dialogue_messages = [ + assistant_message("c"), + user_message("d"), + ] await limiter.verify_limit(initial_messages) await limiter.verify_limit(initial_messages + dialogue_messages) @@ -35,8 +43,11 @@ async def test_dialogue_overflow(): model.count_tokens.side_effect = [1, 3] limiter = 
AddonsDialogueLimiter(MAX_TOKENS, model) - initial_messages = [Message.system("a"), Message.user("b")] - dialogue_messages = [Message.assistant("c"), Message.user("d")] + initial_messages = [system_message("a"), user_message("b")] + dialogue_messages = [ + assistant_message("c"), + user_message("d"), + ] await limiter.verify_limit(initial_messages) with pytest.raises(LimitExceededException) as exc_info: diff --git a/tests/unit_tests/chain/test_command_chain_best_effort.py b/tests/unit_tests/chain/test_command_chain_best_effort.py index 66daf1f..f17fd4e 100644 --- a/tests/unit_tests/chain/test_command_chain_best_effort.py +++ b/tests/unit_tests/chain/test_command_chain_best_effort.py @@ -3,7 +3,6 @@ import httpx import pytest -from aidial_sdk.chat_completion import Role from jinja2 import Template from openai import BadRequestError @@ -16,7 +15,12 @@ ) from aidial_assistant.chain.history import History, ScopedMessage from aidial_assistant.commands.base import Command, TextResult -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ModelClient +from aidial_assistant.utils.open_ai import ( + assistant_message, + system_message, + user_message, +) from tests.utils.async_helper import to_async_string, to_async_strings SYSTEM_MESSAGE = "" @@ -46,10 +50,8 @@ "user_message={{message}}, error={{error}}, dialogue={{dialogue}}" ), scoped_messages=[ - ScopedMessage( - message=Message(role=Role.SYSTEM, content=SYSTEM_MESSAGE) - ), - ScopedMessage(message=Message(role=Role.USER, content=USER_MESSAGE)), + ScopedMessage(message=system_message(SYSTEM_MESSAGE)), + ScopedMessage(message=user_message(USER_MESSAGE)), ], ) @@ -81,14 +83,14 @@ async def test_model_doesnt_support_protocol(): assert model_client.agenerate.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(SYSTEM_MESSAGE), - Message.user(USER_MESSAGE), + system_message(SYSTEM_MESSAGE), + user_message(USER_MESSAGE), ] ), ] @@ -116,8 +118,8 @@ async def test_model_partially_supports_protocol(): result_callback = Mock(spec=ResultCallback) chain_callback.result_callback.return_value = result_callback succeeded_dialogue = [ - Message.assistant(TEST_COMMAND_REQUEST), - Message.user(TEST_COMMAND_RESPONSE), + assistant_message(TEST_COMMAND_REQUEST), + user_message(TEST_COMMAND_RESPONSE), ] await command_chain.run_chat(history=TEST_HISTORY, callback=chain_callback) @@ -131,22 +133,22 @@ async def test_model_partially_supports_protocol(): assert model_client.agenerate.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(USER_MESSAGE), - Message.assistant(TEST_COMMAND_REQUEST), - Message.user(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(USER_MESSAGE), + assistant_message(TEST_COMMAND_REQUEST), + user_message(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(SYSTEM_MESSAGE), - Message.user( + system_message(SYSTEM_MESSAGE), + user_message( f"user_message={USER_MESSAGE}, error={FAILED_PROTOCOL_ERROR}, 
dialogue={succeeded_dialogue}" ), ] @@ -196,22 +198,22 @@ async def test_no_tokens_for_tools(): assert model_client.agenerate.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(USER_MESSAGE), - Message.assistant(TEST_COMMAND_REQUEST), - Message.user(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(USER_MESSAGE), + assistant_message(TEST_COMMAND_REQUEST), + user_message(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(SYSTEM_MESSAGE), - Message.user( + system_message(SYSTEM_MESSAGE), + user_message( f"user_message={USER_MESSAGE}, error={NO_TOKENS_ERROR}, dialogue=[]" ), ] @@ -254,14 +256,14 @@ async def test_model_request_limit_exceeded(): assert model_client.agenerate.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(SYSTEM_MESSAGE), - Message.user( + system_message(SYSTEM_MESSAGE), + user_message( f"user_message={USER_MESSAGE}, error={LIMIT_EXCEEDED_ERROR}, dialogue=[]" ), ] @@ -270,16 +272,16 @@ async def test_model_request_limit_exceeded(): assert model_request_limiter.verify_limit.call_args_list == [ call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(f"{USER_MESSAGE}{ENFORCE_JSON_FORMAT}"), ] ), call( [ - Message.system(f"system_prefix={SYSTEM_MESSAGE}"), - Message.user(USER_MESSAGE), - Message.assistant(TEST_COMMAND_REQUEST), - Message.user(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), + system_message(f"system_prefix={SYSTEM_MESSAGE}"), + user_message(USER_MESSAGE), + assistant_message(TEST_COMMAND_REQUEST), + user_message(f"{TEST_COMMAND_RESPONSE}{ENFORCE_JSON_FORMAT}"), ] ), ] diff --git a/tests/unit_tests/chain/test_history.py b/tests/unit_tests/chain/test_history.py index 2a71fae..c3e6317 100644 --- a/tests/unit_tests/chain/test_history.py +++ b/tests/unit_tests/chain/test_history.py @@ -4,7 +4,12 @@ from jinja2 import Template from aidial_assistant.chain.history import History, MessageScope, ScopedMessage -from aidial_assistant.model.model_client import Message, ModelClient +from aidial_assistant.model.model_client import ModelClient +from aidial_assistant.utils.open_ai import ( + assistant_message, + system_message, + user_message, +) TRUNCATION_TEST_DATA = [ (0, [0, 1, 2, 3, 4, 5, 6]), @@ -28,18 +33,19 @@ async def test_history_truncation( assistant_system_message_template=Template(""), best_effort_template=Template(""), scoped_messages=[ - ScopedMessage(message=Message.system(content="a")), - ScopedMessage(message=Message.user(content="b")), - ScopedMessage(message=Message.system(content="c")), + ScopedMessage(message=system_message("a")), + ScopedMessage(message=user_message("b")), + ScopedMessage(message=system_message("c")), ScopedMessage( - message=Message.assistant(content="d"), + message=assistant_message("d"), scope=MessageScope.INTERNAL, ), ScopedMessage( - message=Message.user(content="e"), scope=MessageScope.INTERNAL + 
message=user_message(content="e"), + scope=MessageScope.INTERNAL, ), - ScopedMessage(message=Message.assistant(content="f")), - ScopedMessage(message=Message.user(content="g")), + ScopedMessage(message=assistant_message("f")), + ScopedMessage(message=user_message("g")), ], ) @@ -64,8 +70,8 @@ async def test_truncation_overflow(): assistant_system_message_template=Template(""), best_effort_template=Template(""), scoped_messages=[ - ScopedMessage(message=Message.system(content="a")), - ScopedMessage(message=Message.user(content="b")), + ScopedMessage(message=system_message("a")), + ScopedMessage(message=user_message("b")), ], ) @@ -87,9 +93,10 @@ async def test_truncation_with_incorrect_message_sequence(): best_effort_template=Template(""), scoped_messages=[ ScopedMessage( - message=Message.user(content="a"), scope=MessageScope.INTERNAL + message=user_message("a"), + scope=MessageScope.INTERNAL, ), - ScopedMessage(message=Message.user(content="b")), + ScopedMessage(message=user_message("b")), ], ) @@ -106,25 +113,25 @@ async def test_truncation_with_incorrect_message_sequence(): def test_protocol_messages_with_system_message(): - system_message = "" - user_message = "" - assistant_message = "" + system_content = "" + user_content = "" + assistant_content = "" history = History( assistant_system_message_template=Template( "system message={{system_prefix}}" ), best_effort_template=Template(""), scoped_messages=[ - ScopedMessage(message=Message.system(system_message)), - ScopedMessage(message=Message.user(user_message)), - ScopedMessage(message=Message.assistant(assistant_message)), + ScopedMessage(message=system_message(system_content)), + ScopedMessage(message=user_message(user_content)), + ScopedMessage(message=assistant_message(assistant_content)), ], ) assert history.to_protocol_messages() == [ - Message.system(f"system message={system_message}"), - Message.user(user_message), - Message.assistant( - f'{{"commands": [{{"command": "reply", "arguments": {{"message": "{assistant_message}"}}}}]}}' + system_message(f"system message={system_content}"), + user_message(user_content), + assistant_message( + f'{{"commands": [{{"command": "reply", "arguments": {{"message": "{assistant_content}"}}}}]}}' ), ] diff --git a/tests/unit_tests/model/test_model_client.py b/tests/unit_tests/model/test_model_client.py index a7735e8..6ffb2ba 100644 --- a/tests/unit_tests/model/test_model_client.py +++ b/tests/unit_tests/model/test_model_client.py @@ -7,11 +7,15 @@ from aidial_assistant.model.model_client import ( ExtraResultsCallback, - Message, ModelClient, ReasonLengthException, ) -from aidial_assistant.utils.open_ai import Usage +from aidial_assistant.utils.open_ai import ( + Usage, + assistant_message, + system_message, + user_message, +) from aidial_assistant.utils.text import join_string from tests.utils.async_helper import to_awaitable_iterator @@ -76,7 +80,7 @@ async def test_reason_length_with_usage(): ), Chunk( choices=[{"delta": {"content": ""}}], - usage={"prompt_tokens": 1, "completion_tokens": 2}, + usage=Usage(prompt_tokens=1, completion_tokens=2), ), ] ) @@ -99,9 +103,9 @@ async def test_api_args(): ) model_client = ModelClient(openai_client, MODEL_ARGS) messages = [ - Message.system(content="a"), - Message.user(content="b"), - Message.assistant(content="c"), + system_message("a"), + user_message("b"), + assistant_message("c"), ] await join_string(model_client.agenerate(messages, extra="args")) diff --git a/tests/unit_tests/utils/test_state.py b/tests/unit_tests/utils/test_state.py index 
ba93972..58569e8 100644 --- a/tests/unit_tests/utils/test_state.py +++ b/tests/unit_tests/utils/test_state.py @@ -1,7 +1,7 @@ from aidial_sdk.chat_completion import CustomContent, Message, Role from aidial_assistant.chain.history import MessageScope, ScopedMessage -from aidial_assistant.model.model_client import Message as ModelMessage +from aidial_assistant.utils.open_ai import assistant_message, user_message from aidial_assistant.utils.state import parse_history FIRST_USER_MESSAGE = "" @@ -45,34 +45,34 @@ def test_parse_history(): assert parse_history(messages) == [ ScopedMessage( scope=MessageScope.USER, - message=ModelMessage.user(FIRST_USER_MESSAGE), + message=user_message(FIRST_USER_MESSAGE), ), ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.assistant(FIRST_REQUEST_FIXED), + message=assistant_message(FIRST_REQUEST_FIXED), ), ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.user(FIRST_RESPONSE), + message=user_message(FIRST_RESPONSE), ), ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.assistant(SECOND_REQUEST), + message=assistant_message(SECOND_REQUEST), ), ScopedMessage( scope=MessageScope.INTERNAL, - message=ModelMessage.user(content=SECOND_RESPONSE), + message=user_message(content=SECOND_RESPONSE), ), ScopedMessage( scope=MessageScope.USER, - message=ModelMessage.assistant(FIRST_ASSISTANT_MESSAGE), + message=assistant_message(FIRST_ASSISTANT_MESSAGE), ), ScopedMessage( scope=MessageScope.USER, - message=ModelMessage.user(SECOND_USER_MESSAGE), + message=user_message(SECOND_USER_MESSAGE), ), ScopedMessage( scope=MessageScope.USER, - message=ModelMessage.assistant(SECOND_ASSISTANT_MESSAGE), + message=assistant_message(SECOND_ASSISTANT_MESSAGE), ), ] From 4b33f4d1b5a1e15402ae9a3705ccdd8f00cd4f0d Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Fri, 12 Jan 2024 15:53:03 +0000 Subject: [PATCH 14/23] Rename method. --- aidial_assistant/commands/run_tool.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aidial_assistant/commands/run_tool.py b/aidial_assistant/commands/run_tool.py index dcf527e..2d7cbe9 100644 --- a/aidial_assistant/commands/run_tool.py +++ b/aidial_assistant/commands/run_tool.py @@ -42,7 +42,7 @@ def _construct_property(p: APIPropertyBase) -> dict[str, Any]: return {k: v for k, v in parameter.items() if v is not None} -def _construct_function(op: APIOperation) -> ChatCompletionToolParam: +def _construct_tool(op: APIOperation) -> ChatCompletionToolParam: properties = {} required = [] for p in op.properties: @@ -85,7 +85,7 @@ async def execute( def create_command_tool(op: APIOperation) -> CommandTool: return lambda: OpenAPIChatCommand( op, self.plugin.auth - ), _construct_function(op) + ), _construct_tool(op) command_tool_dict: CommandToolDict = { name: create_command_tool(op) for name, op in ops.items() From 6581c9c94e32af9265d2792d06f117be9a34b1f8 Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Fri, 12 Jan 2024 15:54:25 +0000 Subject: [PATCH 15/23] Remove redundant comment. 
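A note on the `CommandTool` pattern that PATCH 14 settles on above: each tool name maps to a pair of (zero-argument command factory, OpenAI tool schema), so the schemas can be advertised to the model up front while command objects are created lazily, one per tool call. The sketch below illustrates the shape of that pairing; `EchoCommand` and its schema are illustrative stand-ins for `OpenAPIChatCommand` and `_construct_tool(op)`, not code from the patch.

```python
from openai.types.chat import ChatCompletionToolParam


class EchoCommand:
    """Hypothetical command standing in for OpenAPIChatCommand."""

    async def execute(self, arguments: dict) -> str:
        return f"echo: {arguments.get('query', '')}"


echo_tool: ChatCompletionToolParam = {
    "type": "function",
    "function": {
        "name": "echo",
        "description": "Echoes the query back to the caller.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string"}},
            "required": ["query"],
        },
    },
}

# Shape of the command/tool dict: tool name -> (zero-argument factory, tool schema).
commands = {"echo": (EchoCommand, echo_tool)}

tools = [tool for _, tool in commands.values()]  # advertised to the model
factory, _ = commands["echo"]  # looked up when a tool call arrives
command = factory()  # fresh command instance per call
```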
--- aidial_assistant/model/model_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aidial_assistant/model/model_client.py b/aidial_assistant/model/model_client.py index 926ea94..81ddab5 100644 --- a/aidial_assistant/model/model_client.py +++ b/aidial_assistant/model/model_client.py @@ -59,7 +59,7 @@ async def agenerate( finish_reason_length = False tool_calls_chunks: list[list[dict[str, Any]]] = [] - async for chunk in model_result: # type: ignore + async for chunk in model_result: all_values = chunk.dict() usage: Usage | None = all_values.get("usage") if usage: From f7ef289416a2531fbc3abff57d34ce4fbee8fcbb Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Fri, 12 Jan 2024 16:09:01 +0000 Subject: [PATCH 16/23] Fix tests. --- tests/unit_tests/model/test_model_client.py | 27 +++++++++++++++------ 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/tests/unit_tests/model/test_model_client.py b/tests/unit_tests/model/test_model_client.py index 6ffb2ba..6e8d533 100644 --- a/tests/unit_tests/model/test_model_client.py +++ b/tests/unit_tests/model/test_model_client.py @@ -3,6 +3,7 @@ import pytest from openai import AsyncOpenAI +from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall from pydantic import BaseModel from aidial_assistant.model.model_client import ( @@ -22,8 +23,18 @@ MODEL_ARGS = {"model": "args"} +class Delta(BaseModel): + content: str + tool_calls: list[ChoiceDeltaToolCall] | None = None + + +class Choice(BaseModel): + delta: Delta + finish_reason: str | None = None + + class Chunk(BaseModel): - choices: list[dict[str, Any]] + choices: list[Choice] statistics: dict[str, int] | None = None usage: Usage | None = None @@ -35,7 +46,7 @@ async def test_discarded_messages(): openai_client.chat.completions.create.return_value = to_awaitable_iterator( [ Chunk( - choices=[{"delta": {"content": ""}}], + choices=[Choice(delta=Delta(content=""))], statistics={"discarded_messages": 2}, ) ] @@ -56,9 +67,9 @@ async def test_content(): openai_client.chat = Mock() openai_client.chat.completions.create.return_value = to_awaitable_iterator( [ - Chunk(choices=[{"delta": {"content": "one, "}}]), - Chunk(choices=[{"delta": {"content": "two, "}}]), - Chunk(choices=[{"delta": {"content": "three"}}]), + Chunk(choices=[Choice(delta=Delta(content="one, "))]), + Chunk(choices=[Choice(delta=Delta(content="two, "))]), + Chunk(choices=[Choice(delta=Delta(content="three"))]), ] ) model_client = ModelClient(openai_client, MODEL_ARGS) @@ -72,14 +83,14 @@ async def test_reason_length_with_usage(): openai_client.chat = Mock() openai_client.chat.completions.create.return_value = to_awaitable_iterator( [ - Chunk(choices=[{"delta": {"content": "text"}}]), + Chunk(choices=[Choice(delta=Delta(content="text"))]), Chunk( choices=[ - {"delta": {"content": ""}, "finish_reason": "length"} # type: ignore + Choice(delta=Delta(content=""), finish_reason="length") ] ), Chunk( - choices=[{"delta": {"content": ""}}], + choices=[Choice(delta=Delta(content=""))], usage=Usage(prompt_tokens=1, completion_tokens=2), ), ] From c24d81eaa6721a02d9eebe46e146f1aea77b53de Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Fri, 12 Jan 2024 16:15:08 +0000 Subject: [PATCH 17/23] Fix typo. 
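The typed `Chunk`/`Choice`/`Delta` models that PATCH 16 introduces above mirror the shape of OpenAI streaming chunks, so the tests get pydantic validation instead of passing raw dicts around. A small usage sketch under the same models follows; the three-chunk stream and the assertion are illustrative only.

```python
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from pydantic import BaseModel


class Delta(BaseModel):
    content: str
    tool_calls: list[ChoiceDeltaToolCall] | None = None


class Choice(BaseModel):
    delta: Delta
    finish_reason: str | None = None


class Chunk(BaseModel):
    choices: list[Choice]


# A mocked three-chunk stream that ends with a length cut-off.
stream = [
    Chunk(choices=[Choice(delta=Delta(content="partial "))]),
    Chunk(choices=[Choice(delta=Delta(content="answer"))]),
    Chunk(choices=[Choice(delta=Delta(content=""), finish_reason="length")]),
]
assert "".join(c.choices[0].delta.content for c in stream) == "partial answer"
```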
---
 aidial_assistant/chain/history.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/aidial_assistant/chain/history.py b/aidial_assistant/chain/history.py
index 2324bda..6e8db05 100644
--- a/aidial_assistant/chain/history.py
+++ b/aidial_assistant/chain/history.py
@@ -53,7 +53,7 @@ def to_protocol_messages(self) -> list[ChatCompletionMessageParam]:
         scope = scoped_message.scope
 
         if index == 0:
-            if message["role"] == "role":
+            if message["role"] == "system":
                 messages.append(
                     system_message(
                         self.assistant_system_message_template.render(

From 5663d6be5e069026f2dada61ab5371d5af335bf8 Mon Sep 17 00:00:00 2001
From: Alexey Klimov
Date: Fri, 12 Jan 2024 16:16:56 +0000
Subject: [PATCH 18/23] Remove redundant import.

---
 tests/unit_tests/model/test_model_client.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/unit_tests/model/test_model_client.py b/tests/unit_tests/model/test_model_client.py
index 6e8d533..a5ed1cf 100644
--- a/tests/unit_tests/model/test_model_client.py
+++ b/tests/unit_tests/model/test_model_client.py
@@ -1,4 +1,3 @@
-from typing import Any
 from unittest.mock import Mock, call
 
 import pytest

From 10b613f42e4546b5ff492167ced5e57d619a319e Mon Sep 17 00:00:00 2001
From: Alexey Klimov
Date: Fri, 12 Jan 2024 17:17:01 +0000
Subject: [PATCH 19/23] Add comment to clarify logic.

---
 aidial_assistant/tools_chain/tools_chain.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/aidial_assistant/tools_chain/tools_chain.py b/aidial_assistant/tools_chain/tools_chain.py
index 97b170c..9c47677 100644
--- a/aidial_assistant/tools_chain/tools_chain.py
+++ b/aidial_assistant/tools_chain/tools_chain.py
@@ -143,6 +143,8 @@ async def run_chat(
             if len(dialogue) == 0 or e.code == "429":
                 raise
 
+            # If the dialogue size exceeds the model's context size, remove the
+            # last message block and try again without tools.
             dialogue = dialogue[:-last_message_block_length]
             async for chunk in self.model.agenerate(
                 messages + dialogue, tool_calls_callback

From 73d95ebcae6b27a6dfa504de376e69b31b236bad Mon Sep 17 00:00:00 2001
From: Alexey Klimov
Date: Mon, 15 Jan 2024 11:59:52 +0000
Subject: [PATCH 20/23] Prompt clarifications.

---
 aidial_assistant/application/prompts.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/aidial_assistant/application/prompts.py b/aidial_assistant/application/prompts.py
index f9cf757..a54d169 100644
--- a/aidial_assistant/application/prompts.py
+++ b/aidial_assistant/application/prompts.py
@@ -47,7 +47,7 @@ def build(self, **kwargs) -> Template:
 _PROTOCOL_FOOTER = """
 * reply
-The command delivers the final response to the user.
+The last command that delivers the final response to the user.
 Arguments:
  - 'message' is a string containing the final and complete result for the user.
@@ -58,6 +58,8 @@ def build(self, **kwargs) -> Template:
 ```json
 {"commands": [{"command": "reply", "arguments": {"message": "Hello, world!"}}]}
 ```
+
+End of the protocol.
 """.strip()
 
 _SYSTEM_TEXT = """

From 2dbf2fd0134c5b396249a10b3ee3a18a04260d5f Mon Sep 17 00:00:00 2001
From: Alexey Klimov
Date: Mon, 15 Jan 2024 12:29:38 +0000
Subject: [PATCH 21/23] Address review comments.
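The fallback that PATCH 19 documents above works like this: when generation with tools overflows the model context, the chain drops the most recent request/response block from the addon dialogue and retries without tools; a first-turn overflow or a rate-limit error is re-raised as-is. Below is a condensed sketch of that control flow, in which `ContextOverflowError`, `model.generate`, and `generate_with_fallback` are stand-ins for the real exception type and client, not the patch's code.

```python
class ContextOverflowError(Exception):
    """Stand-in for the BadRequestError the real chain inspects."""

    def __init__(self, code: str):
        self.code = code


async def generate_with_fallback(model, messages, dialogue, tools, last_block_len):
    try:
        return await model.generate(messages + dialogue, tools=tools)
    except ContextOverflowError as e:
        # Nothing to truncate on the first turn, and rate limiting ("429")
        # is not fixed by truncation, so re-raise in both cases.
        if len(dialogue) == 0 or e.code == "429":
            raise
        # Drop the last message block and try again without tools.
        return await model.generate(messages + dialogue[:-last_block_len])
```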
--- README.md | 13 +++++++------ aidial_assistant/app.py | 4 +--- .../application/assistant_application.py | 4 ++-- aidial_assistant/commands/run_tool.py | 4 ++-- aidial_assistant/model/model_client.py | 6 +++--- aidial_assistant/tools_chain/tools_chain.py | 12 ++++++------ aidial_assistant/utils/state.py | 10 +++++++++- 7 files changed, 30 insertions(+), 23 deletions(-) diff --git a/README.md b/README.md index a581f6d..939717d 100644 --- a/README.md +++ b/README.md @@ -81,12 +81,13 @@ make serve Copy .env.example to .env and customize it for your environment: -| Variable | Default | Description | -|-----------------|--------------------------|--------------------------------------------------------| -| CONFIG_DIR | aidial_assistant/configs | Configuration directory | -| LOG_LEVEL | INFO | Log level. Use DEBUG for dev purposes and INFO in prod | -| OPENAI_API_BASE | N/A | OpenAI API Base | -| WEB_CONCURRENCY | 1 | Number of workers for the server | +| Variable | Default | Description | +|------------------------------|--------------------------|--------------------------------------------------------------------------------| +| CONFIG_DIR | aidial_assistant/configs | Configuration directory | +| LOG_LEVEL | INFO | Log level. Use DEBUG for dev purposes and INFO in prod | +| OPENAI_API_BASE | | OpenAI API Base | +| WEB_CONCURRENCY | 1 | Number of workers for the server | +| TOOLS_SUPPORTING_DEPLOYMENTS | | Comma-separated deployment names that support tools in chat completion request | ### Docker diff --git a/aidial_assistant/app.py b/aidial_assistant/app.py index 1f1e385..d14541d 100644 --- a/aidial_assistant/app.py +++ b/aidial_assistant/app.py @@ -23,9 +23,7 @@ config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs")) tools_supporting_deployments: set[str] = set( - os.getenv( - "TOOLS_SUPPORTING_DEPLOYMENTS", "gpt-4-turbo-1106,anthropic.claude-v2-1" - ).split(",") + os.getenv("TOOLS_SUPPORTING_DEPLOYMENTS", "").split(",") ) app.add_chat_completion( "assistant", diff --git a/aidial_assistant/application/assistant_application.py b/aidial_assistant/application/assistant_application.py index 3ba163a..0001617 100644 --- a/aidial_assistant/application/assistant_application.py +++ b/aidial_assistant/application/assistant_application.py @@ -262,11 +262,11 @@ def create_command_tool( plugin.info.ai_plugin.description_for_human, ) - command_tool_dict: CommandToolDict = { + commands: CommandToolDict = { plugin.info.ai_plugin.name_for_model: create_command_tool(plugin) for plugin in plugins } - chain = ToolsChain(model, command_tool_dict) + chain = ToolsChain(model, commands) choice = response.create_single_choice() choice.open() diff --git a/aidial_assistant/commands/run_tool.py b/aidial_assistant/commands/run_tool.py index 2d7cbe9..b1a147e 100644 --- a/aidial_assistant/commands/run_tool.py +++ b/aidial_assistant/commands/run_tool.py @@ -87,11 +87,11 @@ def create_command_tool(op: APIOperation) -> CommandTool: op, self.plugin.auth ), _construct_tool(op) - command_tool_dict: CommandToolDict = { + commands: CommandToolDict = { name: create_command_tool(op) for name, op in ops.items() } - chain = ToolsChain(self.model, command_tool_dict) + chain = ToolsChain(self.model, commands) messages = [ system_message(self.plugin.info.ai_plugin.description_for_model), diff --git a/aidial_assistant/model/model_client.py b/aidial_assistant/model/model_client.py index 81ddab5..83dfa8e 100644 --- a/aidial_assistant/model/model_client.py +++ b/aidial_assistant/model/model_client.py @@ -60,8 +60,8 @@ async 
def agenerate(
         finish_reason_length = False
         tool_calls_chunks: list[list[dict[str, Any]]] = []
         async for chunk in model_result:
-            all_values = chunk.dict()
-            usage: Usage | None = all_values.get("usage")
+            chunk_dict = chunk.dict()
+            usage: Usage | None = chunk_dict.get("usage")
             if usage:
                 prompt_tokens = usage["prompt_tokens"]
                 self._total_prompt_tokens += prompt_tokens
@@ -70,7 +70,7 @@ async def agenerate(
             if extra_results_callback:
-                discarded_messages: int | None = all_values.get(
+                discarded_messages: int | None = chunk_dict.get(
                     "statistics", {}
                 ).get("discarded_messages")
                 if discarded_messages is not None:

diff --git a/aidial_assistant/tools_chain/tools_chain.py b/aidial_assistant/tools_chain/tools_chain.py
index 9c47677..87fbadb 100644
--- a/aidial_assistant/tools_chain/tools_chain.py
+++ b/aidial_assistant/tools_chain/tools_chain.py
@@ -119,9 +119,9 @@ def on_tool_calls(
 class ToolsChain:
-    def __init__(self, model: ModelClient, command_tool_dict: CommandToolDict):
+    def __init__(self, model: ModelClient, commands: CommandToolDict):
         self.model = model
-        self.command_tool_dict = command_tool_dict
+        self.commands = commands
 
     async def run_chat(
         self,
@@ -131,7 +131,7 @@ async def run_chat(
         result_callback = callback.result_callback()
         dialogue: list[ChatCompletionMessageParam] = []
         last_message_block_length = 0
-        tools = [tool for _, tool in self.command_tool_dict.values()]
+        tools = [tool for _, tool in self.commands.values()]
         while True:
             tool_calls_callback = ToolCallsCallback()
             try:
@@ -167,12 +167,12 @@ async def run_chat(
 
     def _create_command(self, name: str) -> Command:
-        if name not in self.command_tool_dict:
+        if name not in self.commands:
             raise AssistantProtocolException(
-                f"The tool '{name}' is expected to be one of {list(self.command_tool_dict.keys())}"
+                f"The tool '{name}' is expected to be one of {list(self.commands.keys())}"
             )
 
-        command, _ = self.command_tool_dict[name]
+        command, _ = self.commands[name]
 
         return command()

diff --git a/aidial_assistant/utils/state.py b/aidial_assistant/utils/state.py
index 3d12e54..f8c9dd5 100644
--- a/aidial_assistant/utils/state.py
+++ b/aidial_assistant/utils/state.py
@@ -43,12 +43,20 @@ def _get_invocations(custom_content: CustomContent | None) -> list[Invocation]:
 
 def _convert_old_commands(string: str) -> str:
-    """Converts old commands to new format."""
+    """Converts old commands to the new format.
+    Previously saved conversations with the assistant will stop working if the state is not updated.
+
+    Old format:
+    {"commands": [{"command": "run-addon", "args": ["<addon-name>", "<query>"]}]}
+    New format:
+    {"commands": [{"command": "<addon-name>", "arguments": {"query": "<query>"}}]}
+    """
     commands = json.loads(string)
     result: list[CommandInvocation] = []
 
     for command in commands["commands"]:
         command_name = command["command"]
+        # run-addon was previously called run-plugin
        if command_name in ("run-addon", "run-plugin"):
             args = command["args"]
             result.append(

From bb11c200c452c9e5a0570e18fdd4f79b5b479131 Mon Sep 17 00:00:00 2001
From: Alexey Klimov
Date: Mon, 15 Jan 2024 13:22:03 +0000
Subject: [PATCH 22/23] Update .env.example.
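To make the old-to-new mapping in the `_convert_old_commands` docstring above concrete: the addon name moves from `args[0]` into the command itself, and the query from `args[1]` into an `arguments` object. The standalone sketch below assumes the argument order implied by the docstring; "wolfram" is an illustrative addon name, and the real implementation lives in `aidial_assistant/utils/state.py`.

```python
import json


def convert_old_commands(string: str) -> str:
    """Sketch of the conversion, not the real implementation."""
    result = []
    for command in json.loads(string)["commands"]:
        if command["command"] in ("run-addon", "run-plugin"):
            addon_name, query = command["args"]
            result.append({"command": addon_name, "arguments": {"query": query}})
        else:
            result.append(command)
    return json.dumps({"commands": result})


old = '{"commands": [{"command": "run-addon", "args": ["wolfram", "2+2"]}]}'
print(convert_old_commands(old))
# {"commands": [{"command": "wolfram", "arguments": {"query": "2+2"}}]}
```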
--- .env.example | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.env.example b/.env.example index a8584bc..9af28af 100644 --- a/.env.example +++ b/.env.example @@ -1,4 +1,5 @@ CONFIG_DIR=aidial_assistant/configs LOG_LEVEL=DEBUG OPENAI_API_BASE=http://localhost:5001 -WEB_CONCURRENCY=1 \ No newline at end of file +WEB_CONCURRENCY=1 +TOOLS_SUPPORTING_DEPLOYMENTS=gpt-4-turbo-1106,anthropic.claude-v2-1 \ No newline at end of file From 94a93a0489ab5a88bdaddda1bcbb177ad14a0bca Mon Sep 17 00:00:00 2001 From: Alexey Klimov Date: Mon, 15 Jan 2024 13:23:53 +0000 Subject: [PATCH 23/23] Use official model name. --- .env.example | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.env.example b/.env.example index 9af28af..1c225bc 100644 --- a/.env.example +++ b/.env.example @@ -2,4 +2,4 @@ CONFIG_DIR=aidial_assistant/configs LOG_LEVEL=DEBUG OPENAI_API_BASE=http://localhost:5001 WEB_CONCURRENCY=1 -TOOLS_SUPPORTING_DEPLOYMENTS=gpt-4-turbo-1106,anthropic.claude-v2-1 \ No newline at end of file +TOOLS_SUPPORTING_DEPLOYMENTS=gpt-4-1106-preview \ No newline at end of file
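One last note on the empty `TOOLS_SUPPORTING_DEPLOYMENTS` default introduced in PATCH 21: in Python, `"".split(",")` yields `[""]`, so when the variable is unset, `set(os.getenv(..., "").split(","))` contains an empty string rather than being empty. That is harmless here, since no deployment is named by the empty string, but a defensive variant is cheap. The sketch below uses `parse_deployments`, an illustrative helper that is not part of these patches.

```python
def parse_deployments(env_value: str | None) -> set[str]:
    """Parse a comma-separated deployment list, dropping empty entries."""
    if not env_value:
        return set()
    return {name.strip() for name in env_value.split(",") if name.strip()}


# The naive form keeps the empty string when the variable is unset:
assert set("".split(",")) == {""}

assert parse_deployments(None) == set()
assert parse_deployments("gpt-4-1106-preview, anthropic.claude-v2-1") == {
    "gpt-4-1106-preview",
    "anthropic.claude-v2-1",
}
```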