Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support native model ability to invoke tools #50

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
b6ba03c
Support native tools (initial commit)
Oleksii-Klimov Dec 14, 2023
99e8322
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 2, 2024
d55b24d
Some intermediate fixes.
Oleksii-Klimov Jan 4, 2024
eb0d90a
Small fixes.
Oleksii-Klimov Jan 5, 2024
d6da3a9
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 5, 2024
cb5cbc9
More fixes.
Oleksii-Klimov Jan 8, 2024
f513d56
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 8, 2024
cad2dc1
Check for a reserved command name.
Oleksii-Klimov Jan 9, 2024
d012fee
Add extra line between commands.
Oleksii-Klimov Jan 10, 2024
41464dc
Update dial sdk to support httpx for opentelemetry.
Oleksii-Klimov Jan 10, 2024
5db6bb0
Remove unused import.
Oleksii-Klimov Jan 10, 2024
076c2ea
Clarify prompts.
Oleksii-Klimov Jan 10, 2024
eb88263
Minor prompt adjustments.
Oleksii-Klimov Jan 10, 2024
762cbc2
Improve prompt formatting for gpt-4-0314.
Oleksii-Klimov Jan 11, 2024
55886a1
Use latest openai api version to support tools.
Oleksii-Klimov Jan 11, 2024
ca0a1db
Address review comments.
Oleksii-Klimov Jan 12, 2024
4b33f4d
Rename method.
Oleksii-Klimov Jan 12, 2024
6581c9c
Remove redundant comment.
Oleksii-Klimov Jan 12, 2024
f7ef289
Fix tests.
Oleksii-Klimov Jan 12, 2024
c24d81e
Fix typo.
Oleksii-Klimov Jan 12, 2024
5663d6b
Remove redundant import.
Oleksii-Klimov Jan 12, 2024
10b613f
Add comment to clarify logic.
Oleksii-Klimov Jan 12, 2024
73d95eb
Prompt clarifications.
Oleksii-Klimov Jan 15, 2024
2dbf2fd
Address review comments.
Oleksii-Klimov Jan 15, 2024
bb11c20
Update .env.example.
Oleksii-Klimov Jan 15, 2024
94a93a0
Use official model name.
Oleksii-Klimov Jan 15, 2024
02844bd
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 15, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
CONFIG_DIR=aidial_assistant/configs
LOG_LEVEL=DEBUG
OPENAI_API_BASE=http://localhost:5001
WEB_CONCURRENCY=1
WEB_CONCURRENCY=1
TOOLS_SUPPORTING_DEPLOYMENTS=gpt-4-1106-preview
13 changes: 7 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,13 @@ make serve

Copy .env.example to .env and customize it for your environment:

| Variable | Default | Description |
|-----------------|--------------------------|--------------------------------------------------------|
| CONFIG_DIR | aidial_assistant/configs | Configuration directory |
| LOG_LEVEL | INFO | Log level. Use DEBUG for dev purposes and INFO in prod |
| OPENAI_API_BASE | N/A | OpenAI API Base |
| WEB_CONCURRENCY | 1 | Number of workers for the server |
| Variable | Default | Description |
|------------------------------|--------------------------|--------------------------------------------------------------------------------|
| CONFIG_DIR | aidial_assistant/configs | Configuration directory |
| LOG_LEVEL | INFO | Log level. Use DEBUG for dev purposes and INFO in prod |
| OPENAI_API_BASE | | OpenAI API Base |
| WEB_CONCURRENCY | 1 | Number of workers for the server |
| TOOLS_SUPPORTING_DEPLOYMENTS | | Comma-separated deployment names that support tools in chat completion request |

### Docker

Expand Down
22 changes: 12 additions & 10 deletions aidial_assistant/app.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
#!/usr/bin/env python3
import logging.config
import os
from pathlib import Path

from aidial_sdk import DIALApp
from aidial_sdk.telemetry.types import TelemetryConfig, TracingConfig
from starlette.responses import Response

from aidial_assistant.application.assistant_application import (
AssistantApplication,
)
from aidial_assistant.utils.log_config import get_log_config

log_level = os.getenv("LOG_LEVEL", "INFO")
config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs"))

logging.config.dictConfig(get_log_config(log_level))

telemetry_config = TelemetryConfig(
service_name="aidial-assistant", tracing=TracingConfig()
)
app = DIALApp(telemetry_config=telemetry_config)
app.add_chat_completion("assistant", AssistantApplication(config_dir))

# A delayed import is necessary to set up the httpx hook before the openai client inherits from AsyncClient.
from aidial_assistant.application.assistant_application import ( # noqa: E402
AssistantApplication,
)

@app.get("/healthcheck/status200")
def status200() -> Response:
return Response("Service is running...", status_code=200)
config_dir = Path(os.getenv("CONFIG_DIR", "aidial_assistant/configs"))
tools_supporting_deployments: set[str] = set(
os.getenv("TOOLS_SUPPORTING_DEPLOYMENTS", "").split(",")
)
app.add_chat_completion(
"assistant",
AssistantApplication(config_dir, tools_supporting_deployments),
)
7 changes: 5 additions & 2 deletions aidial_assistant/application/addons_dialogue_limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
LimitExceededException,
ModelRequestLimiter,
)
from aidial_assistant.model.model_client import Message, ModelClient
from aidial_assistant.model.model_client import (
ChatCompletionMessageParam,
ModelClient,
)


class AddonsDialogueLimiter(ModelRequestLimiter):
Expand All @@ -16,7 +19,7 @@ def __init__(self, max_dialogue_tokens: int, model_client: ModelClient):
self._initial_tokens: int | None = None

@override
async def verify_limit(self, messages: list[Message]):
async def verify_limit(self, messages: list[ChatCompletionMessageParam]):
if self._initial_tokens is None:
self._initial_tokens = await self.model_client.count_tokens(
messages
Expand Down
153 changes: 126 additions & 27 deletions aidial_assistant/application/assistant_application.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
from pathlib import Path
from typing import Tuple

from aidial_sdk.chat_completion import FinishReason
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Addon, Message, Request, Role
from aidial_sdk.chat_completion.response import Response
from openai.lib.azure import AsyncAzureOpenAI
from openai.types.chat import ChatCompletionToolParam
from pydantic import BaseModel

from aidial_assistant.application.addons_dialogue_limiter import (
Expand All @@ -18,18 +21,29 @@
MAIN_BEST_EFFORT_TEMPLATE,
MAIN_SYSTEM_DIALOG_MESSAGE,
)
from aidial_assistant.chain.command_chain import CommandChain, CommandDict
from aidial_assistant.chain.command_chain import (
CommandChain,
CommandConstructor,
CommandDict,
)
from aidial_assistant.chain.history import History
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.commands.run_tool import RunTool
from aidial_assistant.model.model_client import (
ModelClient,
ReasonLengthException,
)
from aidial_assistant.tools_chain.tools_chain import (
CommandToolDict,
ToolsChain,
convert_commands_to_tools,
)
from aidial_assistant.utils.exceptions import (
RequestParameterValidationError,
unhandled_exception_handler,
)
from aidial_assistant.utils.open_ai import construct_tool
from aidial_assistant.utils.open_ai_plugin import (
AddonTokenSource,
get_open_ai_plugin_info,
Expand All @@ -49,8 +63,6 @@ def _get_request_args(request: Request) -> dict[str, str]:
args = {
"model": request.model,
"temperature": request.temperature,
"api_version": request.api_version,
"api_key": request.api_key,
"user": request.user,
}

Expand Down Expand Up @@ -83,68 +95,114 @@ def _validate_messages(messages: list[Message]) -> None:
)


def _construct_tool(name: str, description: str) -> ChatCompletionToolParam:
    """Build the tool definition for an addon.

    Every addon is exposed to the model as a tool taking a single required
    string argument, ``query``, holding the task in natural language.
    """
    query_schema = {
        "query": {
            "type": "string",
            "description": "A task written in natural language",
        }
    }
    return construct_tool(name, description, query_schema, ["query"])


class AssistantApplication(ChatCompletion):
    """Assistant chat-completion application.

    Dispatches each request either to the native-tools flow (for
    deployments listed in ``tools_supporting_deployments``) or to the
    emulated-tools flow.
    """

    # NOTE(review): the scraped diff left the pre-change one-argument
    # signature alongside the new one; only the new signature is kept here.
    def __init__(
        self, config_dir: Path, tools_supporting_deployments: set[str]
    ):
        """
        :param config_dir: directory containing the application configs,
            loaded via ``parse_args``.
        :param tools_supporting_deployments: deployment names whose chat
            completion supports native tools.
        """
        self.args = parse_args(config_dir)
        self.tools_supporting_deployments = tools_supporting_deployments

@unhandled_exception_handler
async def chat_completion(
self, request: Request, response: Response
) -> None:
_validate_messages(request.messages)
addon_references = _validate_addons(request.addons)
chat_args = self.args.openai_conf.dict() | _get_request_args(request)
chat_args = _get_request_args(request)

model = ModelClient(
model_args=chat_args
| {
"deployment_id": chat_args["model"],
"api_type": "azure",
"stream": True,
},
buffer_size=self.args.chat_conf.buffer_size,
client=AsyncAzureOpenAI(
azure_endpoint=self.args.openai_conf.api_base,
api_key=request.api_key,
# 2023-12-01-preview is needed to support tools
api_version="2023-12-01-preview",
),
model_args=chat_args,
)

token_source = AddonTokenSource(
request.headers,
(addon_reference.url for addon_reference in addon_references),
)

addons: dict[str, PluginInfo] = {}
plugins: list[PluginInfo] = []
# DIAL Core has its own names for addons, so in stages we need to map them to the names used by the user
addon_name_mapping: dict[str, str] = {}
for addon_reference in addon_references:
info = await get_open_ai_plugin_info(addon_reference.url)
addons[info.ai_plugin.name_for_model] = PluginInfo(
info=info,
auth=get_plugin_auth(
info.ai_plugin.auth.type,
info.ai_plugin.auth.authorization_type,
addon_reference.url,
token_source,
),
plugins.append(
PluginInfo(
info=info,
auth=get_plugin_auth(
info.ai_plugin.auth.type,
info.ai_plugin.auth.authorization_type,
addon_reference.url,
token_source,
),
)
)

if addon_reference.name:
addon_name_mapping[
info.ai_plugin.name_for_model
] = addon_reference.name

if request.model in self.tools_supporting_deployments:
await AssistantApplication._run_native_tools_chat(
model, plugins, addon_name_mapping, request, response
)
else:
await AssistantApplication._run_emulated_tools_chat(
model, plugins, addon_name_mapping, request, response
)

@staticmethod
async def _run_emulated_tools_chat(
model: ModelClient,
addons: list[PluginInfo],
addon_name_mapping: dict[str, str],
request: Request,
response: Response,
):
# TODO: Add max_addons_dialogue_tokens as a request parameter
max_addons_dialogue_tokens = 1000

def create_command(addon: PluginInfo):
return lambda: RunPlugin(model, addon, max_addons_dialogue_tokens)

command_dict: CommandDict = {
RunPlugin.token(): lambda: RunPlugin(
model, addons, max_addons_dialogue_tokens
),
Reply.token(): Reply,
addon.info.ai_plugin.name_for_model: create_command(addon)
for addon in addons
}
if Reply.token() in command_dict:
RequestParameterValidationError(
f"Addon with name '{Reply.token()}' is not allowed for model {request.model}.",
param="addons",
)

command_dict[Reply.token()] = Reply

chain = CommandChain(
model_client=model, name="ASSISTANT", command_dict=command_dict
)
addon_descriptions = {
name: addon.info.open_api.info.description
addon.info.ai_plugin.name_for_model: addon.info.open_api.info.description
or addon.info.ai_plugin.description_for_human
for name, addon in addons.items()
for addon in addons
}
history = History(
assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
Expand Down Expand Up @@ -187,3 +245,44 @@ async def chat_completion(

if discarded_messages is not None:
response.set_discarded_messages(discarded_messages)

    @staticmethod
    async def _run_native_tools_chat(
        model: ModelClient,
        plugins: list[PluginInfo],
        addon_name_mapping: dict[str, str],
        request: Request,
        response: Response,
    ):
        """Run the chat using the model's native tool-calling ability.

        Each plugin is exposed as one tool; the chain drives the model,
        invoking ``RunTool`` commands for tool calls and streaming the
        result into a single response choice.
        """

        # Pair a lazily-constructed RunTool command with the tool schema
        # advertised to the model for this plugin.
        def create_command_tool(
            plugin: PluginInfo,
        ) -> Tuple[CommandConstructor, ChatCompletionToolParam]:
            return lambda: RunTool(model, plugin), _construct_tool(
                plugin.info.ai_plugin.name_for_model,
                plugin.info.ai_plugin.description_for_human,
            )

        commands: CommandToolDict = {
            plugin.info.ai_plugin.name_for_model: create_command_tool(plugin)
            for plugin in plugins
        }
        chain = ToolsChain(model, commands)

        choice = response.create_single_choice()
        choice.open()

        callback = AssistantChainCallback(choice, addon_name_mapping)
        finish_reason = FinishReason.STOP
        messages = convert_commands_to_tools(parse_history(request.messages))
        try:
            await chain.run_chat(messages, callback)
        except ReasonLengthException:
            # The model ran out of tokens: report LENGTH instead of STOP.
            finish_reason = FinishReason.LENGTH

        # Persist tool invocations in the choice state so a follow-up
        # request can reconstruct the dialogue.
        if callback.invocations:
            choice.set_state(State(invocations=callback.invocations))
        choice.close(finish_reason)

        response.set_usage(
            model.total_prompt_tokens, model.total_completion_tokens
        )
Loading