diff --git a/src/parlant/core/services/tools/openapi.py b/src/parlant/core/services/tools/openapi.py index 973d525f8..c4f1970c2 100644 --- a/src/parlant/core/services/tools/openapi.py +++ b/src/parlant/core/services/tools/openapi.py @@ -34,6 +34,7 @@ ToolParameter, ToolParameterType, ToolContext, + validate_tool_arguments, ) from parlant.core.common import ItemNotFoundError, JSONSerializable, UniqueId from parlant.core.tools import ToolService @@ -211,5 +212,6 @@ async def call_tool( arguments: Mapping[str, JSONSerializable], ) -> ToolResult: _ = context - await self.read_tool(name) + tool = await self.read_tool(name) + validate_tool_arguments(tool, arguments) return await self._tools[name].func(**arguments) diff --git a/src/parlant/core/services/tools/plugins.py b/src/parlant/core/services/tools/plugins.py index ac23561c7..88410aa3d 100644 --- a/src/parlant/core/services/tools/plugins.py +++ b/src/parlant/core/services/tools/plugins.py @@ -55,6 +55,7 @@ EnumValueType, ToolResultError, normalize_tool_arguments, + validate_tool_arguments, ) from parlant.core.common import DefaultBaseModel, ItemNotFoundError, JSONSerializable, UniqueId from parlant.core.contextual_correlator import ContextualCorrelator @@ -561,6 +562,9 @@ async def call_tool( arguments: Mapping[str, JSONSerializable], ) -> ToolResult: try: + tool = await self.read_tool(name) + validate_tool_arguments(tool, arguments) + async with self._http_client.stream( method="post", url=self._get_url(f"/tools/{name}/calls"), diff --git a/src/parlant/core/tools.py b/src/parlant/core/tools.py index c3104d90e..b11b871b2 100644 --- a/src/parlant/core/tools.py +++ b/src/parlant/core/tools.py @@ -265,11 +265,16 @@ async def call_tool( raise ToolImportError(name) from e try: + tool = await self.read_tool(name) + validate_tool_arguments(tool, arguments) + func_params = inspect.signature(func).parameters result: ToolResult = func(**normalize_tool_arguments(func_params, arguments)) if inspect.isawaitable(result): result = await result + except ToolError as e: + raise e except Exception as e: raise ToolExecutionError(name) from e @@ -279,6 +284,22 @@ async def call_tool( return result +def validate_tool_arguments( + tool: Tool, + arguments: Mapping[str, Any], +) -> None: + expected = set(tool.parameters.keys()) + received = set(arguments.keys()) + + extra_args = received - expected + + missing_required = [p for p in tool.required if p not in arguments] + + if extra_args or missing_required: + message = "Argument mismatch.\n" f" - Expected parameters: {sorted(expected)}\n" + raise ToolError(message) + + def normalize_tool_arguments( parameters: Mapping[str, inspect.Parameter], arguments: Mapping[str, Any], diff --git a/tests/core/stable/services/tools/test_openapi.py b/tests/core/stable/services/tools/test_openapi.py index d696fb185..b6c3e54e1 100644 --- a/tests/core/stable/services/tools/test_openapi.py +++ b/tests/core/stable/services/tools/test_openapi.py @@ -13,9 +13,9 @@ # limitations under the License. from typing import Any -from pytest import mark +from pytest import mark, raises -from parlant.core.tools import ToolContext +from parlant.core.tools import ToolContext, ToolError from parlant.core.services.tools.openapi import OpenAPIClient from tests.test_utilities import ( @@ -101,3 +101,31 @@ async def test_that_a_tool_can_be_called_via_an_openapi_server( ) result = await client.call_tool(tool_name, stub_context, tool_args) assert result.data == expected_result + + +@mark.parametrize( + "tool_name,arguments", + [ + (one_required_query_param.__name__, {}), + (one_required_query_param.__name__, {"query_param": 123, "bogus": 999}), + ], +) +async def test_that_openapi_client_raises_tool_error_on_argument_mismatch( + tool_name: str, + arguments: dict[str, Any], +) -> None: + async with run_openapi_server(rng_app()): + openapi_json = await get_openapi_spec(OPENAPI_SERVER_URL) + + async with OpenAPIClient(OPENAPI_SERVER_URL, openapi_json) as client: + stub_context = ToolContext( + agent_id="test-agent", + session_id="test_session", + customer_id="test_customer", + ) + + with raises(ToolError) as exc_info: + await client.call_tool(tool_name, stub_context, arguments) + + error_msg = str(exc_info.value) + assert "Expected parameters" in error_msg diff --git a/tests/core/stable/services/tools/test_plugin_client.py b/tests/core/stable/services/tools/test_plugin_client.py index a97501e67..e112e8cd8 100644 --- a/tests/core/stable/services/tools/test_plugin_client.py +++ b/tests/core/stable/services/tools/test_plugin_client.py @@ -14,12 +14,12 @@ import asyncio import enum -from typing import Mapping, Optional, cast +from typing import Any, Mapping, Optional, cast from lagom import Container from pytest import fixture, raises import pytest -from parlant.core.tools import ToolContext, ToolResult, ToolResultError +from parlant.core.tools import ToolContext, ToolError, ToolResult, ToolResultError from parlant.core.services.tools.plugins import PluginServer, tool from parlant.core.agents import Agent, AgentId, AgentStore from parlant.core.contextual_correlator import ContextualCorrelator @@ -353,3 +353,32 @@ def huge_payload_tool(context: ToolContext) -> ToolResult: await client.call_tool(huge_payload_tool.tool.name, tool_context, arguments={}) assert "Response exceeds 16KB limit" in str(exc.value) + + +@pytest.mark.parametrize( + "arguments", + [ + {}, + {"paramA": 123, "paramX": 999}, + ], +) +async def test_that_a_plugin_raises_tool_error_for_argument_mismatch( + tool_context: ToolContext, + container: Container, + arguments: dict[str, Any], +) -> None: + @tool + def mismatch_tool(context: ToolContext, paramA: int) -> ToolResult: + return ToolResult(paramA) + + async with run_service_server([mismatch_tool]) as server: + async with create_client(server, container[EventBufferFactory]) as client: + with pytest.raises(ToolError) as exc_info: + await client.call_tool( + mismatch_tool.tool.name, + tool_context, + arguments=arguments, + ) + + error_msg = str(exc_info.value) + assert "Expected parameters" in error_msg