Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 31 additions & 26 deletions src/mcp/client/session_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections.abc import Callable
from dataclasses import dataclass
from types import TracebackType
from typing import Any, TypeAlias
from typing import Any, TypeAlias, cast

import anyio
import httpx
Expand Down Expand Up @@ -332,6 +332,8 @@ async def _establish_session(
async def _aggregate_components(self, server_info: types.Implementation, session: mcp.ClientSession) -> None:
"""Aggregates prompts, resources, and tools from a given session."""

capabilities = cast(types.InitializeResult, session.initialize_result).capabilities

# Create a reverse index so we can find all prompts, resources, and
# tools belonging to this session. Used for removing components from
# the session group via self.disconnect_from_server.
Expand All @@ -345,35 +347,38 @@ async def _aggregate_components(self, server_info: types.Implementation, session
tool_to_session_temp: dict[str, mcp.ClientSession] = {}

# Query the server for its prompts and aggregate to list.
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
component_names.prompts.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch prompts: {err}")
if capabilities.prompts is not None:
try:
prompts = (await session.list_prompts()).prompts
for prompt in prompts:
name = self._component_name(prompt.name, server_info)
prompts_temp[name] = prompt
component_names.prompts.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch prompts: {err}")

# Query the server for its resources and aggregate to list.
try:
resources = (await session.list_resources()).resources
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch resources: {err}")
if capabilities.resources is not None:
try:
resources = (await session.list_resources()).resources
for resource in resources:
name = self._component_name(resource.name, server_info)
resources_temp[name] = resource
component_names.resources.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch resources: {err}")

# Query the server for its tools and aggregate to list.
try:
tools = (await session.list_tools()).tools
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
tool_to_session_temp[name] = session
component_names.tools.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch tools: {err}")
if capabilities.tools is not None:
try:
tools = (await session.list_tools()).tools
for tool in tools:
name = self._component_name(tool.name, server_info)
tools_temp[name] = tool
tool_to_session_temp[name] = session
component_names.tools.add(name)
except MCPError as err: # pragma: no cover
logging.warning(f"Could not fetch tools: {err}")

# Clean up exit stack for session if we couldn't retrieve anything
# from the server.
Expand Down
78 changes: 78 additions & 0 deletions tests/client/test_session_group.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import contextlib
import logging
from unittest import mock

import httpx
Expand Down Expand Up @@ -125,6 +126,83 @@ async def test_client_session_group_connect_to_server(mock_exit_stack: contextli
mock_session.list_prompts.assert_awaited_once()


@pytest.mark.anyio
async def test_client_session_group_connect_with_session_respects_negotiated_capabilities(
caplog: pytest.LogCaptureFixture,
):
from mcp import Client
from mcp.server import Server, ServerRequestContext

async def handle_list_tools(
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
) -> types.ListToolsResult:
return types.ListToolsResult(
tools=[
types.Tool(
name="ping",
description="Ping",
input_schema={"type": "object", "properties": {}},
)
]
)

async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult:
return types.CallToolResult(content=[types.TextContent(type="text", text="pong")])

server = Server(
"tools-only-server",
on_list_tools=handle_list_tools,
on_call_tool=handle_call_tool,
)

group = ClientSessionGroup()

with caplog.at_level(logging.WARNING):
async with Client(server) as client:
assert client.initialize_result.capabilities.prompts is None
assert client.initialize_result.capabilities.resources is None

client.session.list_prompts = mock.AsyncMock(side_effect=AssertionError("list_prompts() was called"))
client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called"))

await group.connect_with_session(client.initialize_result.server_info, client.session)
await group.call_tool("ping")

assert not caplog.records


@pytest.mark.anyio
async def test_client_session_group_skips_unadvertised_tools_and_resources(
caplog: pytest.LogCaptureFixture,
):
from mcp import Client
from mcp.server import Server, ServerRequestContext

async def handle_list_prompts(
ctx: ServerRequestContext, params: types.PaginatedRequestParams | None
) -> types.ListPromptsResult:
return types.ListPromptsResult(prompts=[types.Prompt(name="hello", description="Hello", arguments=[])])

server = Server(
"prompts-only-server",
on_list_prompts=handle_list_prompts,
)

group = ClientSessionGroup()

with caplog.at_level(logging.WARNING):
async with Client(server) as client:
assert client.initialize_result.capabilities.tools is None
assert client.initialize_result.capabilities.resources is None

client.session.list_tools = mock.AsyncMock(side_effect=AssertionError("list_tools() was called"))
client.session.list_resources = mock.AsyncMock(side_effect=AssertionError("list_resources() was called"))

await group.connect_with_session(client.initialize_result.server_info, client.session)

assert not caplog.records


@pytest.mark.anyio
async def test_client_session_group_connect_to_server_with_name_hook(mock_exit_stack: contextlib.AsyncExitStack):
"""Test connecting with a component name hook."""
Expand Down
Loading