def _sanitize_function_name(name: str) -> str:
    """Turn an MCP tool name into a valid Python function name.

    The following rewrites are applied, in order:

    1. dashes are replaced by underscores (``echo-tool`` -> ``echo_tool``);
    2. every remaining character that is not a word character (letter,
       digit or underscore) is removed;
    3. a leading digit gets an underscore prefix (``1tool`` -> ``_1tool``);
    4. Python keywords get an underscore suffix (``def`` -> ``def_``).

    Args:
        name: the tool name to sanitize.

    Returns:
        The sanitized name; identical to ``name`` when it was already valid.

    Raises:
        ValueError: if no valid identifier can be derived from ``name``,
            e.g. when it is empty or made only of punctuation.
    """
    # Map dashes to underscores first so they survive the stripping step
    # below and the result stays readable.
    sanitized = name.replace("-", "_")

    # Drop anything that is not a letter, digit or underscore.
    sanitized = re.sub(r"\W", "", sanitized)

    # Identifiers cannot start with a digit. Slicing (instead of indexing)
    # keeps this safe when nothing is left after stripping.
    if sanitized[:1].isdigit():
        sanitized = f"_{sanitized}"

    # Reserved words cannot be used as function names.
    if keyword.iskeyword(sanitized):
        sanitized = f"{sanitized}_"

    # Final guard: covers the empty result as well as exotic word characters
    # that the regex keeps but Python does not accept in identifiers.
    if not sanitized.isidentifier():
        raise ValueError(
            f"Cannot derive a valid Python function name from tool name {name!r}"
        )

    return sanitized
mcp_tool.inputSchema, True, ) - return _instanciate_tool(mcp_tool.name, generate_class_template, afunc) + return _instanciate_tool(mcp_tool_name, generate_class_template, afunc) if __name__ == "__main__": diff --git a/tests/test_langchain_adapter.py b/tests/test_langchain_adapter.py new file mode 100644 index 0000000..1dbffdd --- /dev/null +++ b/tests/test_langchain_adapter.py @@ -0,0 +1,103 @@ +from textwrap import dedent + +import pytest +from mcp import StdioServerParameters + +from mcpadapt.core import MCPAdapt +from mcpadapt.langchain_adapter import LangChainAdapter + + +@pytest.fixture +def echo_server_script(): + return dedent( + ''' + from mcp.server.fastmcp import FastMCP + + mcp = FastMCP("Echo Server") + + @mcp.tool() + def echo_tool(text: str) -> str: + """Echo the input text""" + return f"Echo: {text}" + + mcp.run() + ''' + ) + + +@pytest.mark.asyncio +async def test_basic_async(echo_server_script): + async with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + LangChainAdapter(), + ) as tools: + assert len(tools) == 1 # we expect one tool as defined above + assert tools[0].name == "echo_tool" # we expect the tool to be named echo_tool + response = await tools[0].ainvoke("hello") + assert response == "Echo: hello" # we expect the tool to return "Echo: hello" + + +def test_basic_sync(echo_server_script): + with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + LangChainAdapter(), + ) as tools: + assert len(tools) == 1 + assert tools[0].name == "echo_tool" + assert tools[0].invoke("hello") == "Echo: hello" + + +def test_tool_name_with_dashes(): + mcp_server_script = dedent( + ''' + from mcp.server.fastmcp import FastMCP + + mcp = FastMCP("Echo Server") + + @mcp.tool(name="echo-tool") + def echo_tool(text: str) -> str: + """Echo the input text""" + return f"Echo: {text}" + + mcp.run() + ''' + ) + with MCPAdapt( + StdioServerParameters( + 
command="uv", args=["run", "python", "-c", mcp_server_script] + ), + LangChainAdapter(), + ) as tools: + assert len(tools) == 1 + assert tools[0].name == "echo_tool" + assert tools[0].invoke("hello") == "Echo: hello" + + +def test_tool_name_with_keyword(): + mcp_server_script = dedent( + ''' + from mcp.server.fastmcp import FastMCP + + mcp = FastMCP("Echo Server") + + @mcp.tool(name="def") + def echo_tool(text: str) -> str: + """Echo the input text""" + return f"Echo: {text}" + + mcp.run() + ''' + ) + with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", mcp_server_script] + ), + LangChainAdapter(), + ) as tools: + assert len(tools) == 1 + assert tools[0].name == "def_" + assert tools[0].invoke("hello") == "Echo: hello" diff --git a/uv.lock b/uv.lock index f7f4637..8659fc2 100644 --- a/uv.lock +++ b/uv.lock @@ -471,7 +471,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -955,7 +955,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "sys_platform == 'darwin'" }, + { name = "appnope", marker = "platform_system == 'Darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -1687,7 +1687,7 @@ wheels = [ [[package]] name = "mcpadapt" -version = "0.0.6" +version = "0.0.9" source = { editable = "." 
} dependencies = [ { name = "jsonref" }, @@ -1710,6 +1710,7 @@ dev = [ { name = "jupyter" }, { name = "pre-commit" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "ruff" }, ] @@ -1729,6 +1730,7 @@ dev = [ { name = "jupyter", specifier = ">=1.1.1" }, { name = "pre-commit", specifier = ">=4.0.1" }, { name = "pytest", specifier = ">=8.3.4" }, + { name = "pytest-asyncio", specifier = ">=0.25.2" }, { name = "ruff", specifier = ">=0.9.1" }, ] @@ -2657,6 +2659,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/92/76a1c94d3afee238333bc0a42b82935dd8f9cf8ce9e336ff87ee14d9e1cf/pytest-8.3.4-py3-none-any.whl", hash = "sha256:50e16d954148559c9a74109af1eaf0c945ba2d8f30f0a3d3335edde19788b6f6", size = 343083 }, ] +[[package]] +name = "pytest-asyncio" +version = "0.25.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pytest" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/df/adcc0d60f1053d74717d21d58c0048479e9cab51464ce0d2965b086bd0e2/pytest_asyncio-0.25.2.tar.gz", hash = "sha256:3f8ef9a98f45948ea91a0ed3dc4268b5326c0e7bce73892acc654df4262ad45f", size = 53950 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/61/d8/defa05ae50dcd6019a95527200d3b3980043df5aa445d40cb0ef9f7f98ab/pytest_asyncio-0.25.2-py3-none-any.whl", hash = "sha256:0d0bb693f7b99da304a0634afc0a4b19e49d5e0de2d670f38dc4bfa5727c5075", size = 19400 }, +] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -3443,7 +3457,7 @@ name = "tqdm" version = "4.67.1" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "sys_platform == 'win32'" }, + { name = "colorama", marker = "platform_system == 'Windows'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/a8/4b/29b4ef32e036bb34e4ab51796dd745cdba7ed47ad142a9f4a1eb8e0c744d/tqdm-4.67.1.tar.gz", hash = "sha256:f8aef9c52c08c13a65f30ea34f4e5aac3fd1a34959879d7e59e63027286627f2", size = 169737 } wheels = [