-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathregistry.py
More file actions
109 lines (76 loc) · 3.68 KB
/
registry.py
File metadata and controls
109 lines (76 loc) · 3.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""Generic tool registry — dispatcher for typed agent tools.
Each tool is a function that takes a single ``StrictModel`` input and returns
a single ``StrictModel`` output. The registry holds a mapping from tool name
to ``(input_schema, callable)`` and provides:
- ``register(name, input_schema)`` — decorator that registers a callable.
- ``dispatch(name, raw_input)`` — validate the dict-shaped ``raw_input``
  against the registered input schema, call the tool, and return its typed
  output.
Layer-wise the registry sits below ``agent`` / ``api`` / ``eval`` (it doesn't
import from them) and above ``models``. Verified by the import-linter
contract in ``pyproject.toml``.
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Any
from src.models._base import StrictModel
# Signature every registered tool implementation must satisfy: a single
# ``StrictModel`` input in, a single ``StrictModel`` output out.
# NOTE(review): ``Callable`` is contravariant in its parameter types, so a
# static type checker may reject a tool whose parameter is annotated with a
# concrete subclass (e.g. ``payload: EchoToolInput``). This is presumably why
# ``echo_tool`` below accepts ``StrictModel`` and narrows with ``isinstance``.
# Confirm before tightening.
ToolFn = Callable[[StrictModel], StrictModel]
class UnknownToolError(KeyError):
    """Raised when ``dispatch`` is called with an unregistered tool name.

    Subclasses ``KeyError``, so existing ``except KeyError`` handlers also
    catch it.
    """

    def __str__(self) -> str:
        # ``KeyError.__str__`` returns ``repr()`` of a lone argument, because it
        # is meant to display the missing key. That would wrap our
        # human-readable message in an extra pair of quotes. Use the plain
        # ``BaseException`` rendering instead: "" for no args, the message
        # itself for one arg, and the args tuple otherwise.
        return BaseException.__str__(self)
class Registry:
    """Name-keyed catalogue of typed tools.

    Each entry pairs the input schema used to validate incoming payloads with
    the callable that implements the tool.
    """

    def __init__(self) -> None:
        # tool name -> (input schema, implementation)
        self._tools: dict[str, tuple[type[StrictModel], ToolFn]] = {}

    def register(
        self,
        name: str,
        input_schema: type[StrictModel],
    ) -> Callable[[ToolFn], ToolFn]:
        """Build a decorator that files a tool implementation under *name*.

        Both the decorator form and the direct-call form work::

            @registry.register("echo", EchoToolInput)
            def echo_tool(payload: StrictModel) -> StrictModel: ...

            registry.register("echo", EchoToolInput)(echo_tool)

        The duplicate-name check runs when the returned decorator is applied,
        not when ``register`` itself is called. The decorated callable is
        handed back unchanged.

        Raises:
            ValueError: if *name* is already registered.
        """

        def _attach(impl: ToolFn) -> ToolFn:
            already_taken = name in self._tools
            if already_taken:
                err = f"Tool {name!r} is already registered."
                raise ValueError(err)
            self._tools[name] = (input_schema, impl)
            return impl

        return _attach

    def dispatch(self, name: str, raw_input: dict[str, Any]) -> StrictModel:
        """Run the tool registered as *name* on *raw_input*.

        *raw_input* is validated with the tool's input schema
        (``model_validate``). The resulting model is passed to the tool, and
        the tool's return value is handed back as-is.

        Raises:
            UnknownToolError: if no tool is registered under *name*; the
                message lists the currently registered names.
            pydantic.ValidationError: propagated unchanged when *raw_input*
                does not satisfy the registered input schema.
        """
        entry = self._tools.get(name)
        if entry is None:
            known = sorted(self._tools)
            err = f"Unknown tool {name!r}. Registered: {known}"
            raise UnknownToolError(err)
        schema, impl = entry
        validated = schema.model_validate(raw_input)
        return impl(validated)

    def names(self) -> list[str]:
        """List every registered tool name in ascending order."""
        ordered = list(self._tools)
        ordered.sort()
        return ordered
# Module-global singleton. Agent / eval consumers import this instance
# directly. Tools add themselves to it at import time by decorating their
# implementation with ``registry.register(...)``, as ``echo_tool`` does below.
registry = Registry()
# ---------------------------------------------------------------------------
# Example tool: echo — exercises the registry layer and demonstrates the
# input/output contract shape every tool follows.
# ---------------------------------------------------------------------------
class EchoToolInput(StrictModel, strict=True):
    """Input contract for the example echo tool.

    ``Registry.dispatch`` builds this from the raw dict via ``model_validate``
    before calling ``echo_tool``. ``strict=True`` is passed as a class keyword
    to ``StrictModel``. NOTE(review): presumably this enables Pydantic strict
    mode (no type coercion); confirm against ``src.models._base``.
    """

    # The string that ``echo_tool`` copies into ``EchoToolOutput.echoed``.
    msg: str
class EchoToolOutput(StrictModel, strict=True):
    """Output contract for the example echo tool.

    ``echo_tool`` returns one of these built from ``EchoToolInput.msg``.
    ``strict=True`` is passed as a class keyword to ``StrictModel``.
    NOTE(review): presumably this enables Pydantic strict mode; confirm
    against ``src.models._base``.
    """

    # Copy of the input ``msg``, returned unchanged.
    echoed: str
@registry.register("echo", EchoToolInput)
def echo_tool(payload: StrictModel) -> StrictModel:
    """Echo the input's ``msg`` back as ``EchoToolOutput.echoed``.

    When invoked through ``Registry.dispatch``, *payload* is the result of
    ``EchoToolInput.model_validate``. The ``isinstance`` check narrows the
    generic ``StrictModel`` signature and rejects direct calls made with
    some other model.

    Raises:
        TypeError: if *payload* is not an ``EchoToolInput``.
    """
    if isinstance(payload, EchoToolInput):
        return EchoToolOutput(echoed=payload.msg)
    detail = f"echo_tool got unexpected payload type: {type(payload)!r}"  # pragma: no cover — defensive
    raise TypeError(detail)  # pragma: no cover — defensive