Skip to content

Commit

Permalink
Add ToolError with log error-message that shows what arguments were r…
Browse files Browse the repository at this point in the history
…eceived, and what arguments were expected.
  • Loading branch information
mc-dorzo committed Jan 5, 2025
1 parent 3b6dda6 commit 00ba6e3
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 5 deletions.
4 changes: 3 additions & 1 deletion src/parlant/core/services/tools/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
ToolParameter,
ToolParameterType,
ToolContext,
validate_tool_arguments,
)
from parlant.core.common import ItemNotFoundError, JSONSerializable, UniqueId
from parlant.core.tools import ToolService
Expand Down Expand Up @@ -211,5 +212,6 @@ async def call_tool(
arguments: Mapping[str, JSONSerializable],
) -> ToolResult:
_ = context
await self.read_tool(name)
tool = await self.read_tool(name)
validate_tool_arguments(tool, arguments)
return await self._tools[name].func(**arguments)
4 changes: 4 additions & 0 deletions src/parlant/core/services/tools/plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
EnumValueType,
ToolResultError,
normalize_tool_arguments,
validate_tool_arguments,
)
from parlant.core.common import DefaultBaseModel, ItemNotFoundError, JSONSerializable, UniqueId
from parlant.core.contextual_correlator import ContextualCorrelator
Expand Down Expand Up @@ -561,6 +562,9 @@ async def call_tool(
arguments: Mapping[str, JSONSerializable],
) -> ToolResult:
try:
tool = await self.read_tool(name)
validate_tool_arguments(tool, arguments)

async with self._http_client.stream(
method="post",
url=self._get_url(f"/tools/{name}/calls"),
Expand Down
21 changes: 21 additions & 0 deletions src/parlant/core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,11 +265,16 @@ async def call_tool(
raise ToolImportError(name) from e

try:
tool = await self.read_tool(name)
validate_tool_arguments(tool, arguments)

func_params = inspect.signature(func).parameters
result: ToolResult = func(**normalize_tool_arguments(func_params, arguments))

if inspect.isawaitable(result):
result = await result
except ToolError as e:
raise e
except Exception as e:
raise ToolExecutionError(name) from e

Expand All @@ -279,6 +284,22 @@ async def call_tool(
return result


def validate_tool_arguments(
tool: Tool,
arguments: Mapping[str, Any],
) -> None:
expected = set(tool.parameters.keys())
received = set(arguments.keys())

extra_args = received - expected

missing_required = [p for p in tool.required if p not in arguments]

if extra_args or missing_required:
message = "Argument mismatch.\n" f" - Expected parameters: {sorted(expected)}\n"
raise ToolError(message)


def normalize_tool_arguments(
parameters: Mapping[str, inspect.Parameter],
arguments: Mapping[str, Any],
Expand Down
32 changes: 30 additions & 2 deletions tests/core/stable/services/tools/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.

from typing import Any
from pytest import mark
from pytest import mark, raises

from parlant.core.tools import ToolContext
from parlant.core.tools import ToolContext, ToolError
from parlant.core.services.tools.openapi import OpenAPIClient

from tests.test_utilities import (
Expand Down Expand Up @@ -101,3 +101,31 @@ async def test_that_a_tool_can_be_called_via_an_openapi_server(
)
result = await client.call_tool(tool_name, stub_context, tool_args)
assert result.data == expected_result


@mark.parametrize(
"tool_name,arguments",
[
(one_required_query_param.__name__, {}),
(one_required_query_param.__name__, {"query_param": 123, "bogus": 999}),
],
)
async def test_that_openapi_client_raises_tool_error_on_argument_mismatch(
tool_name: str,
arguments: dict[str, Any],
) -> None:
async with run_openapi_server(rng_app()):
openapi_json = await get_openapi_spec(OPENAPI_SERVER_URL)

async with OpenAPIClient(OPENAPI_SERVER_URL, openapi_json) as client:
stub_context = ToolContext(
agent_id="test-agent",
session_id="test_session",
customer_id="test_customer",
)

with raises(ToolError) as exc_info:
await client.call_tool(tool_name, stub_context, arguments)

error_msg = str(exc_info.value)
assert "Expected parameters" in error_msg
33 changes: 31 additions & 2 deletions tests/core/stable/services/tools/test_plugin_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@

import asyncio
import enum
from typing import Mapping, Optional, cast
from typing import Any, Mapping, Optional, cast
from lagom import Container
from pytest import fixture, raises
import pytest

from parlant.core.tools import ToolContext, ToolResult, ToolResultError
from parlant.core.tools import ToolContext, ToolError, ToolResult, ToolResultError
from parlant.core.services.tools.plugins import PluginServer, tool
from parlant.core.agents import Agent, AgentId, AgentStore
from parlant.core.contextual_correlator import ContextualCorrelator
Expand Down Expand Up @@ -353,3 +353,32 @@ def huge_payload_tool(context: ToolContext) -> ToolResult:
await client.call_tool(huge_payload_tool.tool.name, tool_context, arguments={})

assert "Response exceeds 16KB limit" in str(exc.value)


@pytest.mark.parametrize(
"arguments",
[
{},
{"paramA": 123, "paramX": 999},
],
)
async def test_that_a_plugin_raises_tool_error_for_argument_mismatch(
tool_context: ToolContext,
container: Container,
arguments: dict[str, Any],
) -> None:
@tool
def mismatch_tool(context: ToolContext, paramA: int) -> ToolResult:
return ToolResult(paramA)

async with run_service_server([mismatch_tool]) as server:
async with create_client(server, container[EventBufferFactory]) as client:
with pytest.raises(ToolError) as exc_info:
await client.call_tool(
mismatch_tool.tool.name,
tool_context,
arguments=arguments,
)

error_msg = str(exc_info.value)
assert "Expected parameters" in error_msg

0 comments on commit 00ba6e3

Please sign in to comment.