From ff904b035ed9dc37c01017d0824ef7a043508cc3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=86=AF=E5=9F=BA=E9=AD=81?= <1412414664@qq.com> Date: Mon, 8 Jun 2026 12:39:50 +0800 Subject: [PATCH] fix: detect context on callable tool objects --- .../mcpserver/utilities/context_injection.py | 6 +++- tests/server/mcpserver/test_tool_manager.py | 34 +++++++++++++++++++ 2 files changed, 39 insertions(+), 1 deletion(-) diff --git a/src/mcp/server/mcpserver/utilities/context_injection.py b/src/mcp/server/mcpserver/utilities/context_injection.py index ac7ab82d05..cb1e744d08 100644 --- a/src/mcp/server/mcpserver/utilities/context_injection.py +++ b/src/mcp/server/mcpserver/utilities/context_injection.py @@ -22,9 +22,13 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None: Returns: The name of the context parameter, or None if not found """ + target = fn + if not inspect.isfunction(fn) and not inspect.ismethod(fn): + target = fn.__call__ + # Get type hints to properly resolve string annotations try: - hints = typing.get_type_hints(fn) + hints = typing.get_type_hints(target) except Exception: # pragma: lax no cover # If we can't resolve type hints, we can't find the context parameter return None diff --git a/tests/server/mcpserver/test_tool_manager.py b/tests/server/mcpserver/test_tool_manager.py index e4dfd4ff9b..8a98d3c4be 100644 --- a/tests/server/mcpserver/test_tool_manager.py +++ b/tests/server/mcpserver/test_tool_manager.py @@ -328,6 +328,40 @@ def something(a: int, ctx: Context) -> int: # pragma: no cover assert "ctx" not in tool.fn_metadata.arg_model.model_fields +def test_context_arg_excluded_from_callable_object_schema(): + class MyTool: + def __init__(self): + self.__name__ = "MyTool" + + async def __call__(self, query: str, ctx: Context) -> str: # pragma: no cover + return query + + manager = ToolManager() + tool = manager.add_tool(MyTool()) + + assert tool.context_kwarg == "ctx" + assert "ctx" not in json.dumps(tool.parameters) + assert "Context" not in json.dumps(tool.parameters) + assert "ctx" not in tool.fn_metadata.arg_model.model_fields + + +@pytest.mark.anyio +async def test_context_injected_into_callable_object(): + class MyTool: + def __init__(self): + self.__name__ = "MyTool" + + async def __call__(self, query: str, ctx: Context) -> str: + assert isinstance(ctx, Context) + return query + + manager = ToolManager() + manager.add_tool(MyTool()) + + result = await manager.call_tool("MyTool", {"query": "hello"}, context=Context()) + assert result == "hello" + + class TestContextHandling: """Test context handling in the tool manager."""