Skip to content

Commit 733527f

Browse files
authored
Implement tool name restrictions from the MCP spec (#97)
1 parent b85dd7b commit 733527f

2 files changed

Lines changed: 60 additions & 1 deletion

File tree

splunklib/ai/registry.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
1212
# License for the specific language governing permissions and limitations
1313
# under the License.
14+
1415
import asyncio
1516
import inspect
1617
import logging
18+
import string
1719
from collections.abc import Callable, Sequence
1820
from dataclasses import asdict, dataclass
1921
from logging import Logger
@@ -416,6 +418,12 @@ def wrapper(func: Callable[_P, _R]) -> Callable[_P, _R]:
416418
if name is None:
417419
name = func.__name__
418420

421+
if not is_tool_name_valid(name):
422+
raise ToolRegistryRuntimeError(
423+
f"Tool name {name} doesn't conform to MCP spec, see: "
424+
+ "https://modelcontextprotocol.io/specification/latest/server/tools#tool-names"
425+
)
426+
419427
if self._executing:
420428
raise ToolRegistryRuntimeError(
421429
"ToolRegistry is already running, cannot define new tools"
@@ -497,3 +505,16 @@ def _drop_type_annotations_of(
497505
new_func.__annotations__ = new_annotations
498506

499507
return new_func
508+
509+
510+
MCP_ALLOWED_CHARS = string.ascii_letters + string.digits + "_-."
511+
512+
513+
def is_tool_name_valid(name: str) -> bool:
514+
"""Checks compliance with the MCP spec restrictions, see:
515+
https://modelcontextprotocol.io/specification/latest/server/tools#tool-names
516+
"""
517+
if not (1 <= len(name) <= 128):
518+
return False
519+
520+
return set(name).issubset(MCP_ALLOWED_CHARS)
Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# pyright: reportPrivateUsage=false, reportUnusedFunction=false, reportUnusedParameter=false
1616

1717
import os
18+
import string
1819
import sys
1920
import unittest
2021
from collections.abc import AsyncGenerator
@@ -27,7 +28,12 @@
2728
from mcp.client.stdio import stdio_client
2829
from mcp.types import TextContent
2930

30-
from splunklib.ai.registry import ToolContext, ToolRegistry, ToolRegistryRuntimeError
31+
from splunklib.ai.registry import (
32+
ToolContext,
33+
ToolRegistry,
34+
ToolRegistryRuntimeError,
35+
is_tool_name_valid,
36+
)
3137

3238

3339
class TestJSONSchemaInference(unittest.TestCase):
@@ -407,6 +413,38 @@ def tool(foo: int) -> int:
407413
register_name(r)
408414

409415

416+
@pytest.mark.parametrize(
417+
argnames="name",
418+
argvalues=[
419+
".",
420+
"." * 128,
421+
"func.tool-name_v2",
422+
string.ascii_letters + string.digits,
423+
],
424+
)
425+
def test_valid_name_passes(name: str) -> None:
426+
assert is_tool_name_valid(name)
427+
428+
429+
@pytest.mark.parametrize(
430+
argnames="name",
431+
argvalues=[
432+
"",
433+
"—",
434+
"." * 129,
435+
"tool^name+=|/",
436+
string.punctuation,
437+
],
438+
)
439+
def test_tool_decorator_raises_on_invalid_name(name: str) -> None:
440+
reg = ToolRegistry()
441+
442+
with pytest.raises(ToolRegistryRuntimeError, match=r"Tool name .*"):
443+
444+
@reg.tool(name)
445+
def mock_tool() -> None: ...
446+
447+
410448
class TestRegistryTestCase(unittest.IsolatedAsyncioTestCase):
411449
@asynccontextmanager
412450
async def connect(self, name: str) -> AsyncGenerator[ClientSession, Any]:

0 commit comments

Comments
 (0)