Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sanitize function name for langchain tools #7

Merged
merged 5 commits into from
Jan 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ build-backend = "hatchling.build"
dev = [
"jupyter>=1.1.1",
"pre-commit>=4.0.1",
"pytest-asyncio>=0.25.2",
"pytest>=8.3.4",
"ruff>=0.9.1",
]

[tool.pytest.ini_options]
asyncio_default_fixture_loop_scope = "function"
asyncio_mode = "auto"
48 changes: 44 additions & 4 deletions src/mcpadapt/langchain_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
>>> print(tools)
"""

import keyword
import logging
import re
from functools import partial
from typing import Any, Callable, Coroutine

Expand All @@ -18,6 +21,8 @@

from mcpadapt.core import ToolAdapter

log = logging.getLogger(__name__)

JSON_SCHEMA_TO_PYTHON_TYPES = {
"string": "str",
"number": "float",
Expand All @@ -29,6 +34,28 @@
}


def _sanitize_function_name(name):
"""
A function to sanitize function names to be used as a tool name.
Prevent the use of dashes or other python keywords as function names by tool.
"""
# Replace dashes with underscores
name = name.replace("-", "_")

# Remove any characters that aren't alphanumeric or underscore
name = re.sub(r"[^\w_]", "", name)

# Ensure it doesn't start with a number
if name[0].isdigit():
name = f"_{name}"

# Check if it's a Python keyword
if keyword.iskeyword(name):
name = f"{name}_"

return name


def _generate_tool_class(
name: str,
description: str,
Expand Down Expand Up @@ -99,6 +126,7 @@ def _instanciate_tool(
Returns:
the instanciated langchain tool
"""

# Create namespace and execute the class definition
namespace = {"tool": langchain_core.tools.tool, "func": func}
try:
Expand All @@ -117,35 +145,47 @@ class LangChainAdapter(ToolAdapter):

Note that `langchain` support both sync and async tools so we
write adapt for both methods.

Warning: if the mcp tool name is a python keyword, starts with digits or contains
dashes, the tool name will be sanitized to become a valid python function name.

"""

def adapt(
self,
func: Callable[[dict | None], mcp.types.CallToolResult],
mcp_tool: mcp.types.Tool,
) -> BaseTool:
mcp_tool_name = _sanitize_function_name(mcp_tool.name)
if mcp_tool_name != mcp_tool.name:
log.warning(f"MCP tool name {mcp_tool.name} sanitized to {mcp_tool_name}")

generate_class_template = partial(
_generate_tool_class,
mcp_tool.name,
mcp_tool_name,
mcp_tool.description,
mcp_tool.inputSchema,
False,
)
return _instanciate_tool(mcp_tool.name, generate_class_template, func)
return _instanciate_tool(mcp_tool_name, generate_class_template, func)

def async_adapt(
self,
afunc: Callable[[dict | None], Coroutine[Any, Any, mcp.types.CallToolResult]],
mcp_tool: mcp.types.Tool,
) -> BaseTool:
mcp_tool_name = _sanitize_function_name(mcp_tool.name)
if mcp_tool_name != mcp_tool.name:
log.warning(f"MCP tool name {mcp_tool.name} sanitized to {mcp_tool_name}")

generate_class_template = partial(
_generate_tool_class,
mcp_tool.name,
mcp_tool_name,
mcp_tool.description,
mcp_tool.inputSchema,
True,
)
return _instanciate_tool(mcp_tool.name, generate_class_template, afunc)
return _instanciate_tool(mcp_tool_name, generate_class_template, afunc)


if __name__ == "__main__":
Expand Down
103 changes: 103 additions & 0 deletions tests/test_langchain_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
from textwrap import dedent

import pytest
from mcp import StdioServerParameters

from mcpadapt.core import MCPAdapt
from mcpadapt.langchain_adapter import LangChainAdapter


@pytest.fixture
def echo_server_script():
return dedent(
'''
from mcp.server.fastmcp import FastMCP

mcp = FastMCP("Echo Server")

@mcp.tool()
def echo_tool(text: str) -> str:
"""Echo the input text"""
return f"Echo: {text}"

mcp.run()
'''
)


@pytest.mark.asyncio
async def test_basic_async(echo_server_script):
async with MCPAdapt(
StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
),
LangChainAdapter(),
) as tools:
assert len(tools) == 1 # we expect one tool as defined above
assert tools[0].name == "echo_tool" # we expect the tool to be named echo_tool
response = await tools[0].ainvoke("hello")
assert response == "Echo: hello" # we expect the tool to return "Echo: hello"


def test_basic_sync(echo_server_script):
with MCPAdapt(
StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
),
LangChainAdapter(),
) as tools:
assert len(tools) == 1
assert tools[0].name == "echo_tool"
assert tools[0].invoke("hello") == "Echo: hello"


def test_tool_name_with_dashes():
mcp_server_script = dedent(
'''
from mcp.server.fastmcp import FastMCP

mcp = FastMCP("Echo Server")

@mcp.tool(name="echo-tool")
def echo_tool(text: str) -> str:
"""Echo the input text"""
return f"Echo: {text}"

mcp.run()
'''
)
with MCPAdapt(
StdioServerParameters(
command="uv", args=["run", "python", "-c", mcp_server_script]
),
LangChainAdapter(),
) as tools:
assert len(tools) == 1
assert tools[0].name == "echo_tool"
assert tools[0].invoke("hello") == "Echo: hello"


def test_tool_name_with_keyword():
mcp_server_script = dedent(
'''
from mcp.server.fastmcp import FastMCP

mcp = FastMCP("Echo Server")

@mcp.tool(name="def")
def echo_tool(text: str) -> str:
"""Echo the input text"""
return f"Echo: {text}"

mcp.run()
'''
)
with MCPAdapt(
StdioServerParameters(
command="uv", args=["run", "python", "-c", mcp_server_script]
),
LangChainAdapter(),
) as tools:
assert len(tools) == 1
assert tools[0].name == "def_"
assert tools[0].invoke("hello") == "Echo: hello"
22 changes: 18 additions & 4 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading