Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/mcp/server/mcpserver/utilities/context_injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ def find_context_parameter(fn: Callable[..., Any]) -> str | None:
Returns:
The name of the context parameter, or None if not found
"""
target = fn
if not inspect.isfunction(fn) and not inspect.ismethod(fn):
target = fn.__call__

# Get type hints to properly resolve string annotations
try:
hints = typing.get_type_hints(fn)
hints = typing.get_type_hints(target)
except Exception: # pragma: lax no cover
# If we can't resolve type hints, we can't find the context parameter
return None
Expand Down
34 changes: 34 additions & 0 deletions tests/server/mcpserver/test_tool_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,40 @@ def something(a: int, ctx: Context) -> int: # pragma: no cover
assert "ctx" not in tool.fn_metadata.arg_model.model_fields


def test_context_arg_excluded_from_callable_object_schema():
class MyTool:
def __init__(self):
self.__name__ = "MyTool"

async def __call__(self, query: str, ctx: Context) -> str: # pragma: no cover
return query

manager = ToolManager()
tool = manager.add_tool(MyTool())

assert tool.context_kwarg == "ctx"
assert "ctx" not in json.dumps(tool.parameters)
assert "Context" not in json.dumps(tool.parameters)
assert "ctx" not in tool.fn_metadata.arg_model.model_fields


@pytest.mark.anyio
async def test_context_injected_into_callable_object():
class MyTool:
def __init__(self):
self.__name__ = "MyTool"

async def __call__(self, query: str, ctx: Context) -> str:
assert isinstance(ctx, Context)
return query

manager = ToolManager()
manager.add_tool(MyTool())

result = await manager.call_tool("MyTool", {"query": "hello"}, context=Context())
assert result == "hello"


class TestContextHandling:
"""Test context handling in the tool manager."""

Expand Down
Loading