Skip to content

Commit

Permalink
sanitize function name for langchain tools (#7)
Browse files Browse the repository at this point in the history
* add pytest asyncio with setup

* add a basic async tests for langchain

* add sync version test as well

* add tests for keyword toolname and dashes in tool name

* add sanitizing function and passing test
  • Loading branch information
grll authored Jan 25, 2025
1 parent 320b696 commit e73120c
Show file tree
Hide file tree
Showing 4 changed files with 170 additions and 8 deletions.
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ build-backend = "hatchling.build"
dev = [
"jupyter>=1.1.1",
"pre-commit>=4.0.1",
"pytest-asyncio>=0.25.2",
"pytest>=8.3.4",
"ruff>=0.9.1",
]

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"
asyncio_mode = "auto"
48 changes: 44 additions & 4 deletions src/mcpadapt/langchain_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
>>> print(tools)
"""

import keyword
import logging
import re
from functools import partial
from typing import Any, Callable, Coroutine

Expand All @@ -18,6 +21,8 @@

from mcpadapt.core import ToolAdapter

log = logging.getLogger(__name__)

JSON_SCHEMA_TO_PYTHON_TYPES = {
"string": "str",
"number": "float",
Expand All @@ -29,6 +34,28 @@
}


def _sanitize_function_name(name):
"""
A function to sanitize function names to be used as a tool name.
Prevent the use of dashes or other python keywords as function names by tool.
"""
# Replace dashes with underscores
name = name.replace("-", "_")

# Remove any characters that aren't alphanumeric or underscore
name = re.sub(r"[^\w_]", "", name)

# Ensure it doesn't start with a number
if name[0].isdigit():
name = f"_{name}"

# Check if it's a Python keyword
if keyword.iskeyword(name):
name = f"{name}_"

return name


def _generate_tool_class(
name: str,
description: str,
Expand Down Expand Up @@ -99,6 +126,7 @@ def _instanciate_tool(
Returns:
the instanciated langchain tool
"""

# Create namespace and execute the class definition
namespace = {"tool": langchain_core.tools.tool, "func": func}
try:
Expand All @@ -117,35 +145,47 @@ class LangChainAdapter(ToolAdapter):
Note that `langchain` support both sync and async tools so we
write adapt for both methods.
Warning: if the mcp tool name is a python keyword, starts with digits or contains
dashes, the tool name will be sanitized to become a valid python function name.
"""

def adapt(
self,
func: Callable[[dict | None], mcp.types.CallToolResult],
mcp_tool: mcp.types.Tool,
) -> BaseTool:
mcp_tool_name = _sanitize_function_name(mcp_tool.name)
if mcp_tool_name != mcp_tool.name:
log.warning(f"MCP tool name {mcp_tool.name} sanitized to {mcp_tool_name}")

generate_class_template = partial(
_generate_tool_class,
mcp_tool.name,
mcp_tool_name,
mcp_tool.description,
mcp_tool.inputSchema,
False,
)
return _instanciate_tool(mcp_tool.name, generate_class_template, func)
return _instanciate_tool(mcp_tool_name, generate_class_template, func)

def async_adapt(
self,
afunc: Callable[[dict | None], Coroutine[Any, Any, mcp.types.CallToolResult]],
mcp_tool: mcp.types.Tool,
) -> BaseTool:
mcp_tool_name = _sanitize_function_name(mcp_tool.name)
if mcp_tool_name != mcp_tool.name:
log.warning(f"MCP tool name {mcp_tool.name} sanitized to {mcp_tool_name}")

generate_class_template = partial(
_generate_tool_class,
mcp_tool.name,
mcp_tool_name,
mcp_tool.description,
mcp_tool.inputSchema,
True,
)
return _instanciate_tool(mcp_tool.name, generate_class_template, afunc)
return _instanciate_tool(mcp_tool_name, generate_class_template, afunc)


if __name__ == "__main__":
Expand Down
103 changes: 103 additions & 0 deletions tests/test_langchain_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from textwrap import dedent

import pytest
from mcp import StdioServerParameters

from mcpadapt.core import MCPAdapt
from mcpadapt.langchain_adapter import LangChainAdapter


@pytest.fixture
def echo_server_script():
return dedent(
'''
from mcp.server.fastmcp import FastMCP
mcp = FastMCP("Echo Server")
@mcp.tool()
def echo_tool(text: str) -> str:
"""Echo the input text"""
return f"Echo: {text}"
mcp.run()
'''
)


@pytest.mark.asyncio
async def test_basic_async(echo_server_script):
async with MCPAdapt(
StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
),
LangChainAdapter(),
) as tools:
assert len(tools) == 1 # we expect one tool as defined above
assert tools[0].name == "echo_tool" # we expect the tool to be named echo_tool
response = await tools[0].ainvoke("hello")
assert response == "Echo: hello" # we expect the tool to return "Echo: hello"


def test_basic_sync(echo_server_script):
with MCPAdapt(
StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
),
LangChainAdapter(),
) as tools:
assert len(tools) == 1
assert tools[0].name == "echo_tool"
assert tools[0].invoke("hello") == "Echo: hello"


def test_tool_name_with_dashes():
mcp_server_script = dedent(
'''
from mcp.server.fastmcp import FastMCP
mcp = FastMCP("Echo Server")
@mcp.tool(name="echo-tool")
def echo_tool(text: str) -> str:
"""Echo the input text"""
return f"Echo: {text}"
mcp.run()
'''
)
with MCPAdapt(
StdioServerParameters(
command="uv", args=["run", "python", "-c", mcp_server_script]
),
LangChainAdapter(),
) as tools:
assert len(tools) == 1
assert tools[0].name == "echo_tool"
assert tools[0].invoke("hello") == "Echo: hello"


def test_tool_name_with_keyword():
mcp_server_script = dedent(
'''
from mcp.server.fastmcp import FastMCP
mcp = FastMCP("Echo Server")
@mcp.tool(name="def")
def echo_tool(text: str) -> str:
"""Echo the input text"""
return f"Echo: {text}"
mcp.run()
'''
)
with MCPAdapt(
StdioServerParameters(
command="uv", args=["run", "python", "-c", mcp_server_script]
),
LangChainAdapter(),
) as tools:
assert len(tools) == 1
assert tools[0].name == "def_"
assert tools[0].invoke("hello") == "Echo: hello"
22 changes: 18 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e73120c

Please sign in to comment.