feat: support native model ability to invoke tools #50

Merged

Commits (27)
b6ba03c
Support native tools (initial commit)
Oleksii-Klimov Dec 14, 2023
99e8322
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 2, 2024
d55b24d
Some intermediate fixes.
Oleksii-Klimov Jan 4, 2024
eb0d90a
Small fixes.
Oleksii-Klimov Jan 5, 2024
d6da3a9
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 5, 2024
cb5cbc9
More fixes.
Oleksii-Klimov Jan 8, 2024
f513d56
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 8, 2024
cad2dc1
Check for a reserved command name.
Oleksii-Klimov Jan 9, 2024
d012fee
Add extra line between commands.
Oleksii-Klimov Jan 10, 2024
41464dc
Update dial sdk to support httpx for opentelemetry.
Oleksii-Klimov Jan 10, 2024
5db6bb0
Remove unused import.
Oleksii-Klimov Jan 10, 2024
076c2ea
Clarify prompts.
Oleksii-Klimov Jan 10, 2024
eb88263
Minor prompt adjustments.
Oleksii-Klimov Jan 10, 2024
762cbc2
Improve prompt formatting for gpt-4-0314.
Oleksii-Klimov Jan 11, 2024
55886a1
Use latest openai api version to support tools.
Oleksii-Klimov Jan 11, 2024
ca0a1db
Address review comments.
Oleksii-Klimov Jan 12, 2024
4b33f4d
Rename method.
Oleksii-Klimov Jan 12, 2024
6581c9c
Remove redundant comment.
Oleksii-Klimov Jan 12, 2024
f7ef289
Fix tests.
Oleksii-Klimov Jan 12, 2024
c24d81e
Fix typo.
Oleksii-Klimov Jan 12, 2024
5663d6b
Remove redundant import.
Oleksii-Klimov Jan 12, 2024
10b613f
Add comment to clarify logic.
Oleksii-Klimov Jan 12, 2024
73d95eb
Prompt clarifications.
Oleksii-Klimov Jan 15, 2024
2dbf2fd
Address review comments.
Oleksii-Klimov Jan 15, 2024
bb11c20
Update .env.example.
Oleksii-Klimov Jan 15, 2024
94a93a0
Use official model name.
Oleksii-Klimov Jan 15, 2024
02844bd
Merge branch 'development' into 41-support-native-model-ability-to-in…
Oleksii-Klimov Jan 15, 2024
17 changes: 8 additions & 9 deletions aidial_assistant/app.py
@@ -1,15 +1,10 @@
#!/usr/bin/env python3
import logging.config
import os
from pathlib import Path

from aidial_sdk import DIALApp
from aidial_sdk.telemetry.types import TelemetryConfig, TracingConfig
from starlette.responses import Response

from aidial_assistant.application.assistant_application import (
AssistantApplication,
)
from aidial_assistant.utils.log_config import get_log_config

log_level = os.getenv("LOG_LEVEL", "INFO")
@@ -21,9 +16,13 @@
service_name="aidial-assistant", tracing=TracingConfig()
)
app = DIALApp(telemetry_config=telemetry_config)
app.add_chat_completion("assistant", AssistantApplication(config_dir))

# A delayed import is necessary to set up the httpx hook before the openai client inherits from AsyncClient.
from aidial_assistant.application.assistant_application import ( # noqa: E402
AssistantApplication,
)

@app.get("/healthcheck/status200")
def status200() -> Response:
return Response("Service is running...", status_code=200)
app.add_chat_completion(
"assistant",
AssistantApplication(config_dir),
)
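For context on the delayed import above: the telemetry-enabled DIALApp is constructed first, which is presumably what installs the httpx instrumentation, and only afterwards is the module imported that pulls in the openai client built on httpx.AsyncClient. A minimal sketch of the ordering constraint, with instrument_httpx as a hypothetical stand-in for the SDK's setup:

    # instrument_httpx() is a hypothetical stand-in for the instrumentation
    # that DIALApp(telemetry_config=...) performs internally.
    def instrument_httpx() -> None:
        ...  # patch httpx.AsyncClient so outgoing requests are traced

    instrument_httpx()  # the hook must be in place first

    # Imported only afterwards, so the openai client (pulled in transitively)
    # subclasses the already-instrumented httpx.AsyncClient.
    from aidial_assistant.application.assistant_application import (  # noqa: E402
        AssistantApplication,
    )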
221 changes: 193 additions & 28 deletions aidial_assistant/application/assistant_application.py
@@ -1,10 +1,14 @@
import json
import logging
from pathlib import Path

from aidial_sdk.chat_completion import FinishReason
from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import Addon, Message, Request, Role
from aidial_sdk.chat_completion.request import Addon
from aidial_sdk.chat_completion.request import Message as SdkMessage
from aidial_sdk.chat_completion.request import Request, Role
from aidial_sdk.chat_completion.response import Response
from openai.lib.azure import AsyncAzureOpenAI
from pydantic import BaseModel

from aidial_assistant.application.addons_dialogue_limiter import (
@@ -19,17 +23,27 @@
MAIN_SYSTEM_DIALOG_MESSAGE,
)
from aidial_assistant.chain.command_chain import CommandChain, CommandDict
from aidial_assistant.chain.history import History
from aidial_assistant.chain.command_result import Commands, Responses
from aidial_assistant.chain.history import History, MessageScope, ScopedMessage
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.commands.run_tool import RunTool
from aidial_assistant.model.model_client import (
Message,
ModelClient,
ReasonLengthException,
ToolCall,
)
from aidial_assistant.tools_chain.tools_chain import ToolsChain
from aidial_assistant.utils.exceptions import (
RequestParameterValidationError,
unhandled_exception_handler,
)
from aidial_assistant.utils.open_ai import (
FunctionCall,
Tool,
construct_function,
)
from aidial_assistant.utils.open_ai_plugin import (
AddonTokenSource,
get_open_ai_plugin_info,
@@ -49,8 +63,6 @@ def _get_request_args(request: Request) -> dict[str, str]:
args = {
"model": request.model,
"temperature": request.temperature,
"api_version": request.api_version,
"api_key": request.api_key,
"user": request.user,
}

@@ -71,7 +83,7 @@ def _validate_addons(addons: list[Addon] | None) -> list[AddonReference]:
return addon_references


def _validate_messages(messages: list[Message]) -> None:
def _validate_messages(messages: list[SdkMessage]) -> None:
if not messages:
raise RequestParameterValidationError(
"Message list cannot be empty.", param="messages"
@@ -83,6 +95,87 @@ def _validate_messages(messages: list[Message]) -> None:
)


def _validate_request(request: Request) -> None:
_validate_messages(request.messages)
_validate_addons(request.addons)


def _construct_function(name: str, description: str) -> Tool:
return construct_function(
name,
description,
{
"query": {
"type": "string",
"description": "A task written in natural language",
}
},
["query"],
)
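# For reference, the result corresponds to the standard OpenAI function-tool
# shape (the exact wrapping is whatever construct_function produces):
#   {"type": "function",
#    "function": {"name": name, "description": description,
#                 "parameters": {"type": "object",
#                                "properties": {"query": {"type": "string", ...}},
#                                "required": ["query"]}}}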


def _convert_commands_to_tools(
scoped_messages: list[ScopedMessage],
) -> list[Message]:
messages: list[Message] = []
next_tool_id: int = 0
last_call_count: int = 0
for scoped_message in scoped_messages:
message = scoped_message.message
if scoped_message.scope == MessageScope.INTERNAL:
if not message.content:
raise RequestParameterValidationError(
"State is broken. Content cannot be empty.",
param="messages",
)

if message.role == Role.ASSISTANT:
commands: Commands = json.loads(message.content)
messages.append(
Message(
role=Role.ASSISTANT,
tool_calls=[
ToolCall(
index=index,
id=str(next_tool_id + index),
function=FunctionCall(
name=command["command"],
arguments=json.dumps(command["arguments"]),
),
type="function",
)
for index, command in enumerate(
commands["commands"]
)
],
)
)
last_call_count = len(commands["commands"])
next_tool_id += last_call_count
elif message.role == Role.USER:
responses: Responses = json.loads(message.content)
response_count = len(responses["responses"])
if response_count != last_call_count:
raise RequestParameterValidationError(
f"Expected {last_call_count} responses, but got {response_count}.",
param="messages",
)
first_tool_id = next_tool_id - last_call_count
messages.extend(
[
Message(
role=Role.TOOL,
tool_call_id=str(first_tool_id + index),
content=response["response"],
)
for index, response in enumerate(responses["responses"])
]
)
else:
messages.append(scoped_message.message)
return messages
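# Worked example (illustrative values): an internal assistant state message of
#   {"commands": [{"command": "weather", "arguments": {"query": "Paris"}}]}
# becomes one assistant Message carrying a single ToolCall with id "0", and
# the paired internal user state message
#   {"responses": [{"response": "Sunny"}]}
# becomes one Role.TOOL Message with tool_call_id="0".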


class AssistantApplication(ChatCompletion):
def __init__(self, config_dir: Path):
self.args = parse_args(config_dir)
@@ -93,58 +186,86 @@ async def chat_completion(
) -> None:
_validate_messages(request.messages)
addon_references = _validate_addons(request.addons)
chat_args = self.args.openai_conf.dict() | _get_request_args(request)
chat_args = _get_request_args(request)

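        # The API key and version now configure the AsyncAzureOpenAI client
        # itself instead of being threaded through the per-request model args.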
model = ModelClient(
model_args=chat_args
| {
"deployment_id": chat_args["model"],
"api_type": "azure",
"stream": True,
},
buffer_size=self.args.chat_conf.buffer_size,
client=AsyncAzureOpenAI(
azure_endpoint=self.args.openai_conf.api_base,
api_key=request.api_key,
api_version=request.api_version,
),
model_args=chat_args,
)

token_source = AddonTokenSource(
request.headers,
(addon_reference.url for addon_reference in addon_references),
)

addons: dict[str, PluginInfo] = {}
addons: list[PluginInfo] = []
        # DIAL Core has its own names for addons, so in stages we need to map them to the names used by the user
addon_name_mapping: dict[str, str] = {}
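        # Illustrative shape: {"<addon name_for_model>": "<name from the request>"}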
for addon_reference in addon_references:
info = await get_open_ai_plugin_info(addon_reference.url)
addons[info.ai_plugin.name_for_model] = PluginInfo(
info=info,
auth=get_plugin_auth(
info.ai_plugin.auth.type,
info.ai_plugin.auth.authorization_type,
addon_reference.url,
token_source,
),
addons.append(
PluginInfo(
info=info,
auth=get_plugin_auth(
info.ai_plugin.auth.type,
info.ai_plugin.auth.authorization_type,
addon_reference.url,
token_source,
),
)
)

if addon_reference.name:
addon_name_mapping[
info.ai_plugin.name_for_model
] = addon_reference.name

if request.model in {"gpt-4-turbo-1106", "anthropic.claude-v2-1"}:
await AssistantApplication._run_native_tools_chat(
model, addons, addon_name_mapping, request, response
)
else:
await AssistantApplication._run_emulated_tools_chat(
model, addons, addon_name_mapping, request, response
)

@staticmethod
async def _run_emulated_tools_chat(
model: ModelClient,
addons: list[PluginInfo],
addon_name_mapping: dict[str, str],
request: Request,
response: Response,
):
# TODO: Add max_addons_dialogue_tokens as a request parameter
max_addons_dialogue_tokens = 1000

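        # A factory function is used so each lambda captures its own addon;
        # a bare lambda in the comprehension below would late-bind the variable.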
def create_command(addon: PluginInfo):
return lambda: RunPlugin(model, addon, max_addons_dialogue_tokens)

command_dict: CommandDict = {
RunPlugin.token(): lambda: RunPlugin(
model, addons, max_addons_dialogue_tokens
),
Reply.token(): Reply,
addon.info.ai_plugin.name_for_model: create_command(addon)
for addon in addons
}
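        # An addon must not shadow the built-in reply command.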
        if Reply.token() in command_dict:
            raise RequestParameterValidationError(
                f"Addon with name '{Reply.token()}' is not allowed for model {request.model}.",
                param="addons",
            )

command_dict[Reply.token()] = Reply

chain = CommandChain(
model_client=model, name="ASSISTANT", command_dict=command_dict
)
addon_descriptions = {
name: addon.info.open_api.info.description
addon.info.ai_plugin.name_for_model: addon.info.open_api.info.description
or addon.info.ai_plugin.description_for_human
for name, addon in addons.items()
for addon in addons
}
history = History(
assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
@@ -187,3 +308,47 @@ async def chat_completion(

if discarded_messages is not None:
response.set_discarded_messages(discarded_messages)

@staticmethod
async def _run_native_tools_chat(
model: ModelClient,
addons: list[PluginInfo],
addon_name_mapping: dict[str, str],
request: Request,
response: Response,
):
tools: list[Tool] = [
_construct_function(
addon.info.ai_plugin.name_for_model,
addon.info.ai_plugin.description_for_human,
)
for addon in addons
]

def create_command(addon: PluginInfo):
return lambda: RunTool(model, addon)

command_dict: CommandDict = {
addon.info.ai_plugin.name_for_model: create_command(addon)
for addon in addons
}
chain = ToolsChain(model, tools, command_dict)

choice = response.create_single_choice()
choice.open()

callback = AssistantChainCallback(choice, addon_name_mapping)
finish_reason = FinishReason.STOP
messages = _convert_commands_to_tools(parse_history(request.messages))
try:
await chain.run_chat(messages, callback)
except ReasonLengthException:
finish_reason = FinishReason.LENGTH

if callback.invocations:
choice.set_state(State(invocations=callback.invocations))
choice.close(finish_reason)

response.set_usage(
model.total_prompt_tokens, model.total_completion_tokens
)