From c19e5c2e4b77fb9fdc96812a6d7a09e3976aded4 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 09:38:24 -0600 Subject: [PATCH 01/27] Add layout traversal utilities for Dash component trees --- dash/layout.py | 228 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 228 insertions(+) create mode 100644 dash/layout.py diff --git a/dash/layout.py b/dash/layout.py new file mode 100644 index 0000000000..fdca86edca --- /dev/null +++ b/dash/layout.py @@ -0,0 +1,228 @@ +"""Reusable layout utilities for traversing and inspecting Dash component trees.""" + +from __future__ import annotations + +import json +from typing import Any, Generator + +from dash import get_app +from dash._pages import PAGE_REGISTRY +from dash.dependencies import Wildcard +from dash.development.base_component import Component + +_WILDCARD_VALUES = frozenset(w.value for w in Wildcard) + + +def traverse( + start: Component | None = None, +) -> Generator[tuple[Component, tuple[Component, ...]], None, None]: + """Yield ``(component, ancestors)`` for every Component in the tree. + + If ``start`` is ``None``, the full app layout is resolved via + ``dash.get_app()``, preferring ``validation_layout`` for completeness. + """ + if start is None: + app = get_app() + start = getattr(app, "validation_layout", None) or app.get_layout() + + yield from _walk(start, ()) + + +def _walk( + node: Any, + ancestors: tuple[Component, ...], +) -> Generator[tuple[Component, tuple[Component, ...]], None, None]: + if node is None: + return + if isinstance(node, (list, tuple)): + for item in node: + yield from _walk(item, ancestors) + return + if not isinstance(node, Component): + return + + yield node, ancestors + + child_ancestors = (*ancestors, node) + for _prop_name, child in iter_children(node): + yield from _walk(child, child_ancestors) + + +def iter_children( + component: Component, +) -> Generator[tuple[str, Component], None, None]: + """Yield ``(prop_name, child_component)`` for all component-valued props. + + Walks ``children`` plus any props declared in the component's + ``_children_props`` list. Supports nested path expressions like + ``control_groups[].children`` and ``insights.title``. + """ + props_to_walk = ["children"] + getattr(component, "_children_props", []) + for prop_path in props_to_walk: + for child in get_children(component, prop_path): + yield prop_path, child + + +def get_children(component: Any, prop_path: str) -> list[Component]: + """Resolve a ``_children_props`` path expression to child Components. + + Mirrors the dash-renderer's path parsing in ``DashWrapper.tsx``. + Supports: + - ``"children"`` — simple prop + - ``"control_groups[].children"`` — array, then sub-prop per element + - ``"insights.title"`` — nested object prop + """ + clean_path = prop_path.replace("[]", "").replace("{}", "") + + if "." not in prop_path: + return _collect_components(getattr(component, clean_path, None)) + + parts = prop_path.split(".") + array_idx = next((i for i, p in enumerate(parts) if "[]" in p), len(parts)) + front = [p.replace("[]", "").replace("{}", "") for p in parts[: array_idx + 1]] + back = [p.replace("{}", "") for p in parts[array_idx + 1 :]] + + node = _resolve_path(component, front) + if node is None: + return [] + + if back and isinstance(node, (list, tuple)): + results: list[Component] = [] + for element in node: + child = _resolve_path(element, back) + results.extend(_collect_components(child)) + return results + + return _collect_components(node) + + +def _resolve_path(node: Any, keys: list[str]) -> Any: + """Walk a chain of keys through Components and dicts.""" + for key in keys: + if isinstance(node, Component): + node = getattr(node, key, None) + elif isinstance(node, dict): + node = node.get(key) + else: + return None + if node is None: + return None + return node + + +def _collect_components(value: Any) -> list[Component]: + """Extract Components from a value (single, list, or None).""" + if value is None: + return [] + if isinstance(value, Component): + return [value] + if isinstance(value, (list, tuple)): + return [item for item in value if isinstance(item, (Component, list, tuple))] + return [] + + +def find_component( + component_id: str | dict, + layout: Component | None = None, + page: str | None = None, +) -> Component | None: + """Find a component by ID. + + If neither ``layout`` nor ``page`` is provided, searches the full + app layout (preferring ``validation_layout`` for completeness). + """ + if page is not None: + layout = _resolve_page_layout(page) + + if layout is None: + app = get_app() + layout = getattr(app, "validation_layout", None) or app.get_layout() + + for comp, _ in traverse(layout): + if getattr(comp, "id", None) == component_id: + return comp + return None + + +def parse_wildcard_id(pid: Any) -> dict | None: + """Parse a component ID and return it as a dict if it contains a wildcard. + + Accepts string (JSON-encoded) or dict IDs. Returns ``None`` + if the ID is not a wildcard pattern. + + Example:: + + >>> parse_wildcard_id('{"type":"input","index":["ALL"]}') + {"type": "input", "index": ["ALL"]} + >>> parse_wildcard_id("my-dropdown") + None + """ + if isinstance(pid, str) and pid.startswith("{"): + try: + pid = json.loads(pid) + except (json.JSONDecodeError, ValueError): + return None + if not isinstance(pid, dict): + return None + for v in pid.values(): + if isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES: + return pid + return None + + +def find_matching_components(pattern: dict) -> list[Component]: + """Find all components whose dict ID matches a wildcard pattern. + + Non-wildcard keys must match exactly. Wildcard keys are ignored. + """ + non_wildcard_keys = { + k: v + for k, v in pattern.items() + if not (isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES) + } + matches = [] + for comp, _ in traverse(): + comp_id = getattr(comp, "id", None) + if not isinstance(comp_id, dict): + continue + if all(comp_id.get(k) == v for k, v in non_wildcard_keys.items()): + matches.append(comp) + return matches + + +def extract_text(component: Component) -> str: + """Recursively extract plain text from a component's children tree. + + Mimics the browser's ``element.textContent``. + """ + children = getattr(component, "children", None) + if children is None: + return "" + if isinstance(children, str): + return children + if isinstance(children, Component): + return extract_text(children) + if isinstance(children, (list, tuple)): + parts: list[str] = [] + for child in children: + if isinstance(child, str): + parts.append(child) + elif isinstance(child, Component): + parts.append(extract_text(child)) + return "".join(parts).strip() + return "" + + +def _resolve_page_layout(page: str) -> Any | None: + if not PAGE_REGISTRY: + return None + for _module, page_info in PAGE_REGISTRY.items(): + if page_info.get("path") == page: + page_layout = page_info.get("layout") + if callable(page_layout): + try: + page_layout = page_layout() + except (TypeError, RuntimeError): + return None + return page_layout + return None From 9283b66ba9a30c8ba270edc35e5ad7eac0a81d0e Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 10:01:59 -0600 Subject: [PATCH 02/27] Make Dash components compatible with Pydantic types --- dash/development/_py_components_generation.py | 5 +- dash/development/base_component.py | 17 ++++ dash/types.py | 67 ++++++++++++++- requirements/install.txt | 1 + tests/unit/test_layout.py | 83 +++++++++++++++++++ tests/unit/test_pydantic_types.py | 36 ++++++++ 6 files changed, 204 insertions(+), 5 deletions(-) create mode 100644 tests/unit/test_layout.py create mode 100644 tests/unit/test_pydantic_types.py diff --git a/dash/development/_py_components_generation.py b/dash/development/_py_components_generation.py index 73545ea4a5..b597283a04 100644 --- a/dash/development/_py_components_generation.py +++ b/dash/development/_py_components_generation.py @@ -24,6 +24,7 @@ import typing # noqa: F401 from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args +from dash.types import NumberType # noqa: F401 {custom_imports} ComponentSingleType = typing.Union[str, int, float, Component, None] ComponentType = typing.Union[ @@ -31,10 +32,6 @@ typing.Sequence[ComponentSingleType], ] -NumberType = typing.Union[ - typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex -] - """ diff --git a/dash/development/base_component.py b/dash/development/base_component.py index 02579ff2e2..5382c5aafc 100644 --- a/dash/development/base_component.py +++ b/dash/development/base_component.py @@ -117,6 +117,23 @@ class Component(metaclass=ComponentMeta): _valid_wildcard_attributes: typing.List[str] available_wildcard_properties: typing.List[str] + @classmethod + def __get_pydantic_core_schema__(cls, source_type, handler): + from pydantic_core import core_schema + return core_schema.any_schema() + + @classmethod + def __get_pydantic_json_schema__(cls, schema, handler): + namespaces = list(ComponentRegistry.namespace_to_package.keys()) + return { + "type": "object", + "properties": { + "type": {"type": "string"}, + "namespace": {"type": "string", "enum": namespaces} if namespaces else {"type": "string"}, + "props": {"type": "object"}, + }, + } + class _UNDEFINED: def __repr__(self): return "undefined" diff --git a/dash/types.py b/dash/types.py index 9a39adb43e..43bf16dc30 100644 --- a/dash/types.py +++ b/dash/types.py @@ -1,4 +1,29 @@ -from typing_extensions import TypedDict, NotRequired +import typing +from typing import Any, Dict, List, Union + +from pydantic import Field, GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic_core import core_schema +from typing_extensions import Annotated, TypedDict, NotRequired + + +class _NumberSchema: # pylint: disable=too-few-public-methods + @classmethod + def __get_pydantic_core_schema__( + cls, _source_type: Any, _handler: GetCoreSchemaHandler + ) -> Any: + return core_schema.float_schema() + + @classmethod + def __get_pydantic_json_schema__( + cls, _schema: Any, _handler: GetJsonSchemaHandler + ) -> dict: + return {"type": "number"} + + +NumberType = Annotated[ + Union[typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex], + _NumberSchema, +] class RendererHooks(TypedDict): # pylint: disable=too-many-ancestors @@ -8,3 +33,43 @@ class RendererHooks(TypedDict): # pylint: disable=too-many-ancestors request_post: NotRequired[str] callback_resolved: NotRequired[str] request_refresh_jwt: NotRequired[str] + + +class CallbackDependency(TypedDict): + id: Union[str, Dict[str, Any]] + property: str + + +class CallbackInput(TypedDict): + id: Union[str, Dict[str, Any]] + property: str + value: Any + + +class CallbackDispatchBody(TypedDict): + output: str + outputs: List[CallbackDependency] + inputs: List[CallbackInput] + state: List[CallbackInput] + changedPropIds: List[str] + + +CallbackOutput = Annotated[ + Dict[str, Any], + Field( + description="The return values of the callback. A mapping of component & property names to their updated values." + ), +] + +CallbackSideOutput = Annotated[ + Dict[str, Any], + Field( + description="Side-effect updates that the callback performed but did not declare ahead of time. A mapping of component & property names to their updated values." + ), +] + + +class CallbackDispatchResponse(TypedDict): + multi: NotRequired[bool] + response: NotRequired[Dict[str, CallbackOutput]] + sideUpdate: NotRequired[Dict[str, CallbackSideOutput]] diff --git a/requirements/install.txt b/requirements/install.txt index df0e1299e3..5b425cf5a9 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -7,3 +7,4 @@ requests retrying nest-asyncio setuptools +pydantic>=2.12.5 diff --git a/tests/unit/test_layout.py b/tests/unit/test_layout.py new file mode 100644 index 0000000000..76a72f7fb4 --- /dev/null +++ b/tests/unit/test_layout.py @@ -0,0 +1,83 @@ +"""Tests for dash.layout — layout traversal and component lookup utilities.""" + +import pytest + +from dash import html, dcc +from dash.layout import ( + traverse, + find_component, + extract_text, + parse_wildcard_id, +) + + +@pytest.fixture +def sample_layout(): + return html.Div( + [ + html.Label("Name:", htmlFor="name-input"), + " ", + dcc.Input(id="name-input", value="World"), + html.Div( + [html.Span(id="deep-child", children="deep text")], + id="inner", + ), + ], + id="root", + ) + + +class TestTraverse: + def test_yields_all_components_with_correct_ancestors(self, sample_layout): + results = { + getattr(c, "id", None): len(ancestors) + for c, ancestors in traverse(sample_layout) + } + assert results["root"] == 0 + assert results["name-input"] == 1 + assert results["deep-child"] == 2 + + def test_empty_layout(self): + results = list(traverse(html.Div())) + assert len(results) == 1 # just the Div itself + + +class TestFindComponent: + def test_finds_by_string_id(self, sample_layout): + comp = find_component("deep-child", layout=sample_layout) + assert comp is not None and comp.id == "deep-child" + + def test_returns_none_for_missing_id(self, sample_layout): + assert find_component("nope", layout=sample_layout) is None + + def test_finds_by_dict_id(self): + layout = html.Div([html.Div(id={"type": "item", "index": 0})]) + assert find_component({"type": "item", "index": 0}, layout=layout) is not None + + +class TestExtractText: + def test_extracts_all_text_content(self, sample_layout): + assert extract_text(sample_layout) == "Name: deep text" + + def test_none_children(self): + assert extract_text(html.Div()) == "" + + +class TestParseWildcardId: + @pytest.mark.parametrize("wildcard", ["ALL", "MATCH", "ALLSMALLER"]) + def test_returns_dict_for_wildcard(self, wildcard): + result = parse_wildcard_id({"type": "input", "index": [wildcard]}) + assert result == {"type": "input", "index": [wildcard]} + + def test_parses_json_string(self): + result = parse_wildcard_id('{"type":"input","index":["ALL"]}') + assert result == {"type": "input", "index": ["ALL"]} + + def test_returns_none_for_plain_string(self): + assert parse_wildcard_id("my-dropdown") is None + + def test_returns_none_for_non_wildcard_dict(self): + assert parse_wildcard_id({"type": "input", "index": 0}) is None + + def test_returns_none_for_invalid_json(self): + assert parse_wildcard_id("{not valid}") is None diff --git a/tests/unit/test_pydantic_types.py b/tests/unit/test_pydantic_types.py new file mode 100644 index 0000000000..75d1dc7f41 --- /dev/null +++ b/tests/unit/test_pydantic_types.py @@ -0,0 +1,36 @@ +"""Tests for dash.types — Pydantic-compatible types and schemas.""" + +from pydantic import TypeAdapter + +from dash.types import NumberType, CallbackDispatchBody, CallbackDispatchResponse +from dash.development.base_component import Component + + +class TestNumberType: + def test_json_schema_is_number(self): + schema = TypeAdapter(NumberType).json_schema() + assert schema["type"] == "number" + + +class TestComponentPydanticSchema: + def test_produces_object_schema(self): + schema = TypeAdapter(Component).json_schema() + assert schema["type"] == "object" + assert "properties" in schema + + def test_schema_has_type_and_props(self): + schema = TypeAdapter(Component).json_schema() + props = schema["properties"] + assert "type" in props + assert "props" in props + + +class TestCallbackDispatchTypes: + def test_dispatch_body_schema(self): + schema = TypeAdapter(CallbackDispatchBody).json_schema() + assert "output" in schema["properties"] + assert "inputs" in schema["properties"] + + def test_dispatch_response_schema(self): + schema = TypeAdapter(CallbackDispatchResponse).json_schema() + assert "response" in schema["properties"] From 402d8b96286307a2d25e7278309d8567c8fefafd Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 10:31:58 -0600 Subject: [PATCH 03/27] Extract get_layout() from serve_layout() --- dash/dash.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 122cf54dd6..1418483187 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -907,15 +907,22 @@ def index_string(self, value: str) -> None: self._index_string = value @with_app_context - def serve_layout(self): - layout = self._layout_value() + def get_layout(self): + """Return the resolved layout with all hooks applied. + This is the canonical way to obtain the app's layout — it + calls the layout function (if callable), includes extra + components, and runs layout hooks. + """ + layout = self._layout_value() for hook in self._hooks.get_hooks("layout"): layout = hook(layout) + return layout + def serve_layout(self): # TODO - Set browser cache limit - pass hash into frontend return flask.Response( - to_json(layout), + to_json(self.get_layout()), mimetype="application/json", ) From f82288da0a7253cfdd0f40485355a167e3f586e5 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 10:46:25 -0600 Subject: [PATCH 04/27] Fix build issues for dash-table and dash-core-components --- components/dash-core-components/package.json | 4 ++-- components/dash-table/package.json | 2 +- dash/dash-renderer/babel.config.js | 2 +- dash/dash-renderer/package.json | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/components/dash-core-components/package.json b/components/dash-core-components/package.json index ac9d88c80c..e430a00b6a 100644 --- a/components/dash-core-components/package.json +++ b/components/dash-core-components/package.json @@ -27,7 +27,7 @@ "build:js": "webpack --mode production", "build:backends": "dash-generate-components ./src/components dash_core_components -p package-info.json && cp dash_core_components_base/** dash_core_components/ && dash-generate-components ./src/components dash_core_components -p package-info.json -k RangeSlider,Slider,Dropdown,RadioItems,Checklist,DatePickerSingle,DatePickerRange,Input,Link --r-prefix 'dcc' --r-suggests 'dash,dashHtmlComponents,jsonlite,plotly' --jl-prefix 'dcc' && black dash_core_components", "build": "run-s prepublishOnly build:js build:backends", - "postbuild": "es-check es2015 dash_core_components/*.js", + "postbuild": "es-check es2017 dash_core_components/*.js", "build:watch": "watch 'npm run build' src", "format": "run-s private::format.*", "lint": "run-s private::lint.*" @@ -126,6 +126,6 @@ "react-dom": "16 - 19" }, "browserslist": [ - "last 9 years and not dead" + "last 11 years and not dead" ] } diff --git a/components/dash-table/package.json b/components/dash-table/package.json index b5d65499e9..fd905c6f40 100644 --- a/components/dash-table/package.json +++ b/components/dash-table/package.json @@ -119,6 +119,6 @@ "npm": ">=6.1.0" }, "browserslist": [ - "last 9 years and not dead" + "last 11 years and not dead" ] } diff --git a/dash/dash-renderer/babel.config.js b/dash/dash-renderer/babel.config.js index d7b0c89e8e..6e6cc5d957 100644 --- a/dash/dash-renderer/babel.config.js +++ b/dash/dash-renderer/babel.config.js @@ -3,7 +3,7 @@ module.exports = { '@babel/preset-typescript', ['@babel/preset-env', { "targets": { - "browsers": ["last 10 years and not dead"] + "browsers": ["last 11 years and not dead"] } }], '@babel/preset-react' diff --git a/dash/dash-renderer/package.json b/dash/dash-renderer/package.json index f92d22cfc5..ce0c26b4a6 100644 --- a/dash/dash-renderer/package.json +++ b/dash/dash-renderer/package.json @@ -89,6 +89,6 @@ ], "prettier": "@plotly/prettier-config-dash", "browserslist": [ - "last 10 years and not dead" + "last 11 years and not dead" ] } From 0efcec5edea9431d5337ab2c735f7ec77255322a Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 15:13:51 -0600 Subject: [PATCH 05/27] Add CallbackDispatchBody type hints to dispatch methods --- dash/dash.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dash/dash.py b/dash/dash.py index 1418483187..9fa9f1e8e6 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -81,7 +81,7 @@ _import_layouts_from_pages, ) from ._jupyter import jupyter_dash, JupyterDisplayMode -from .types import RendererHooks +from .types import CallbackDispatchBody, RendererHooks RouteCallable = Callable[..., Any] @@ -1472,7 +1472,7 @@ def callback(self, *_args, **_kwargs) -> Callable[..., Any]: ) # pylint: disable=R0915 - def _initialize_context(self, body): + def _initialize_context(self, body: CallbackDispatchBody): """Initialize the global context for the request.""" g = AttributeDict({}) g.inputs_list = body.get("inputs", []) @@ -1493,7 +1493,7 @@ def _initialize_context(self, body): g.updated_props = {} return g - def _prepare_callback(self, g, body): + def _prepare_callback(self, g, body: CallbackDispatchBody): """Prepare callback-related data.""" output = body["output"] try: From 200240c759a59406d4d9433f6771bbb7cda31251 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 16:28:56 -0600 Subject: [PATCH 06/27] Use python3.8 compatible pydantic --- requirements/install.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/install.txt b/requirements/install.txt index 5b425cf5a9..89bd8a5595 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -7,4 +7,4 @@ requests retrying nest-asyncio setuptools -pydantic>=2.12.5 +pydantic>=2.10 From a01a01622c77cd8c9ed8a768449cd88f072ea7f4 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 16:43:50 -0600 Subject: [PATCH 07/27] lint --- dash/development/base_component.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dash/development/base_component.py b/dash/development/base_component.py index 5382c5aafc..32357c94ea 100644 --- a/dash/development/base_component.py +++ b/dash/development/base_component.py @@ -120,6 +120,7 @@ class Component(metaclass=ComponentMeta): @classmethod def __get_pydantic_core_schema__(cls, source_type, handler): from pydantic_core import core_schema + return core_schema.any_schema() @classmethod @@ -129,7 +130,9 @@ def __get_pydantic_json_schema__(cls, schema, handler): "type": "object", "properties": { "type": {"type": "string"}, - "namespace": {"type": "string", "enum": namespaces} if namespaces else {"type": "string"}, + "namespace": {"type": "string", "enum": namespaces} + if namespaces + else {"type": "string"}, "props": {"type": "object"}, }, } From 26fc936fd9c9e34c8bd22f1b4f890426593dd7ad Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 17:03:03 -0600 Subject: [PATCH 08/27] Fix lint error on CI --- dash/development/base_component.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/dash/development/base_component.py b/dash/development/base_component.py index 32357c94ea..a7aec5a3f0 100644 --- a/dash/development/base_component.py +++ b/dash/development/base_component.py @@ -118,13 +118,13 @@ class Component(metaclass=ComponentMeta): available_wildcard_properties: typing.List[str] @classmethod - def __get_pydantic_core_schema__(cls, source_type, handler): - from pydantic_core import core_schema + def __get_pydantic_core_schema__(cls, _source_type, _handler): + from pydantic_core import core_schema # pylint: disable=import-outside-toplevel return core_schema.any_schema() @classmethod - def __get_pydantic_json_schema__(cls, schema, handler): + def __get_pydantic_json_schema__(cls, _schema, _handler): namespaces = list(ComponentRegistry.namespace_to_package.keys()) return { "type": "object", From 98ea9eea46fa50e4554a73b5b71d4e8f5e77bbd9 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 12:00:13 -0600 Subject: [PATCH 09/27] Rename layout.py to _layout_utils.py --- dash/{layout.py => _layout_utils.py} | 0 tests/unit/test_layout.py | 4 ++-- 2 files changed, 2 insertions(+), 2 deletions(-) rename dash/{layout.py => _layout_utils.py} (100%) diff --git a/dash/layout.py b/dash/_layout_utils.py similarity index 100% rename from dash/layout.py rename to dash/_layout_utils.py diff --git a/tests/unit/test_layout.py b/tests/unit/test_layout.py index 76a72f7fb4..64fff724a1 100644 --- a/tests/unit/test_layout.py +++ b/tests/unit/test_layout.py @@ -1,9 +1,9 @@ -"""Tests for dash.layout — layout traversal and component lookup utilities.""" +"""Tests for dash._layout_utils — layout traversal and component lookup utilities.""" import pytest from dash import html, dcc -from dash.layout import ( +from dash._layout_utils import ( traverse, find_component, extract_text, From f3a14f90ed695dcfe29ec343aaacbbed7680b3e4 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 13:41:38 -0600 Subject: [PATCH 10/27] Make NumberType import backwards-compatible in generated components --- dash/development/_py_components_generation.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/dash/development/_py_components_generation.py b/dash/development/_py_components_generation.py index b597283a04..2fd6a6cdb7 100644 --- a/dash/development/_py_components_generation.py +++ b/dash/development/_py_components_generation.py @@ -24,7 +24,12 @@ import typing # noqa: F401 from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args -from dash.types import NumberType # noqa: F401 +try: + from dash.types import NumberType # noqa: F401 +except ImportError: + NumberType = typing.Union[ # noqa: F401 + typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex + ] {custom_imports} ComponentSingleType = typing.Union[str, int, float, Component, None] ComponentType = typing.Union[ From d2d2be9cf63c7f606e78513428cc493a319d5406 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 14:34:19 -0600 Subject: [PATCH 11/27] Fix type checker for NumberType implementation --- dash/development/_py_components_generation.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/dash/development/_py_components_generation.py b/dash/development/_py_components_generation.py index 2fd6a6cdb7..1690a1744e 100644 --- a/dash/development/_py_components_generation.py +++ b/dash/development/_py_components_generation.py @@ -22,14 +22,18 @@ import_string = """# AUTO GENERATED FILE - DO NOT EDIT import typing # noqa: F401 +from typing import TYPE_CHECKING # noqa: F401 from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args -try: +if TYPE_CHECKING: from dash.types import NumberType # noqa: F401 -except ImportError: - NumberType = typing.Union[ # noqa: F401 - typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex - ] +else: + try: + from dash.types import NumberType # noqa: F401 + except ImportError: + NumberType = typing.Union[ # noqa: F401 + typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex + ] {custom_imports} ComponentSingleType = typing.Union[str, int, float, Component, None] ComponentType = typing.Union[ From 3f9fbcac6211dfdb68e4833061f29762d6697dc8 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Tue, 7 Apr 2026 10:30:13 -0600 Subject: [PATCH 12/27] Clean up import for pyright --- dash/development/_py_components_generation.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/dash/development/_py_components_generation.py b/dash/development/_py_components_generation.py index 1690a1744e..87211c7cc8 100644 --- a/dash/development/_py_components_generation.py +++ b/dash/development/_py_components_generation.py @@ -22,18 +22,17 @@ import_string = """# AUTO GENERATED FILE - DO NOT EDIT import typing # noqa: F401 -from typing import TYPE_CHECKING # noqa: F401 from typing_extensions import TypedDict, NotRequired, Literal # noqa: F401 from dash.development.base_component import Component, _explicitize_args -if TYPE_CHECKING: +try: from dash.types import NumberType # noqa: F401 -else: - try: - from dash.types import NumberType # noqa: F401 - except ImportError: - NumberType = typing.Union[ # noqa: F401 - typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex - ] +except ImportError: + # Backwards compatibility for dash<=4.1.0 + if typing.TYPE_CHECKING: + raise + NumberType = typing.Union[ # noqa: F401 + typing.SupportsFloat, typing.SupportsInt, typing.SupportsComplex + ] {custom_imports} ComponentSingleType = typing.Union[str, int, float, Component, None] ComponentType = typing.Union[ From 34e834391ec90072ed977db131d1c0c58cdcae40 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 10:37:02 -0600 Subject: [PATCH 13/27] Add callback adapter core for MCP tool generation --- dash/_callback.py | 12 +- dash/mcp/primitives/tools/callback_adapter.py | 461 ++++++++++++++++++ .../tools/callback_adapter_collection.py | 154 ++++++ dash/mcp/primitives/tools/callback_utils.py | 36 ++ .../primitives/tools/descriptions/__init__.py | 7 + .../tools/input_schemas/__init__.py | 5 + .../tools/output_schemas/__init__.py | 5 + dash/mcp/types/__init__.py | 26 + dash/mcp/types/callback_types.py | 33 ++ dash/mcp/types/component_types.py | 20 + dash/mcp/types/exceptions.py | 30 ++ dash/mcp/types/typing_utils.py | 28 ++ requirements/install.txt | 1 + tests/unit/mcp/conftest.py | 6 + tests/unit/mcp/tools/test_callback_adapter.py | 227 +++++++++ .../tools/test_callback_adapter_collection.py | 145 ++++++ 16 files changed, 1193 insertions(+), 3 deletions(-) create mode 100644 dash/mcp/primitives/tools/callback_adapter.py create mode 100644 dash/mcp/primitives/tools/callback_adapter_collection.py create mode 100644 dash/mcp/primitives/tools/callback_utils.py create mode 100644 dash/mcp/primitives/tools/descriptions/__init__.py create mode 100644 dash/mcp/primitives/tools/input_schemas/__init__.py create mode 100644 dash/mcp/primitives/tools/output_schemas/__init__.py create mode 100644 dash/mcp/types/__init__.py create mode 100644 dash/mcp/types/callback_types.py create mode 100644 dash/mcp/types/component_types.py create mode 100644 dash/mcp/types/exceptions.py create mode 100644 dash/mcp/types/typing_utils.py create mode 100644 tests/unit/mcp/conftest.py create mode 100644 tests/unit/mcp/tools/test_callback_adapter.py create mode 100644 tests/unit/mcp/tools/test_callback_adapter_collection.py diff --git a/dash/_callback.py b/dash/_callback.py index 3785df7166..dc73dd0792 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -41,6 +41,7 @@ from . import _validate from .background_callback.managers import BaseBackgroundCallbackManager from ._callback_context import context_value +from .types import CallbackDispatchResponse from ._no_update import NoUpdate @@ -80,6 +81,7 @@ def callback( api_endpoint: Optional[str] = None, optional: Optional[bool] = False, hidden: Optional[bool] = None, + mcp_enabled: bool = True, **_kwargs, ) -> Callable[..., Any]: """ @@ -231,6 +233,7 @@ def callback( api_endpoint=api_endpoint, optional=optional, hidden=hidden, + mcp_enabled=mcp_enabled, ) @@ -278,6 +281,7 @@ def insert_callback( no_output=False, optional=False, hidden=None, + mcp_enabled=True, ): if prevent_initial_call is None: prevent_initial_call = config_prevent_initial_callbacks @@ -318,6 +322,7 @@ def insert_callback( "manager": manager, "allow_dynamic_callbacks": dynamic_creator, "no_output": no_output, + "mcp_enabled": mcp_enabled, } callback_list.append(callback_spec) @@ -523,7 +528,7 @@ def _prepare_response( output_value, output_spec, multi, - response, + response: CallbackDispatchResponse, callback_ctx, app, original_packages, @@ -652,6 +657,7 @@ def register_callback( no_output=not has_output, optional=_kwargs.get("optional", False), hidden=_kwargs.get("hidden", None), + mcp_enabled=_kwargs.get("mcp_enabled", True), ) # pylint: disable=too-many-locals @@ -686,7 +692,7 @@ def add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response: dict = {"multi": True} # type: ignore + response: CallbackDispatchResponse = {"multi": True} jsonResponse = None try: @@ -758,7 +764,7 @@ async def async_add_context(*args, **kwargs): args, kwargs, inputs_state_indices, has_output, insert_output ) - response = {"multi": True} + response: CallbackDispatchResponse = {"multi": True} try: if background is not None: diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py new file mode 100644 index 0000000000..0f50d15c03 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -0,0 +1,461 @@ +"""Adapter: Dash callback → MCP tool interface. + +Wraps a raw ``callback_map`` entry and exposes MCP-facing +properties (tool name, params, outputs) lazily. +""" + +from __future__ import annotations + +import inspect +import json +import typing +from functools import cached_property +from typing import Any + +from mcp.types import Tool + +from dash import get_app +from dash.layout import ( + _WILDCARD_VALUES, + find_component, + find_matching_components, + parse_wildcard_id, +) +from dash.mcp.types import is_nullable +from dash._grouping import flatten_grouping +from dash._utils import clean_property_name, split_callback_id +from dash.mcp.types import MCPInput, MCPOutput +from .callback_utils import run_callback +from .descriptions import build_tool_description +from .input_schemas import get_input_schema +from .output_schemas import get_output_schema + + +class CallbackAdapter: + """Adapts a single Dash callback_map entry to the MCP tool interface.""" + + def __init__(self, callback_output_id: str): + self._output_id = callback_output_id + + # ------------------------------------------------------------------- + # Projections + # ------------------------------------------------------------------- + + @cached_property + def as_mcp_tool(self) -> Tool: + """Stub — will be implemented in a future PR.""" + raise NotImplementedError("as_mcp_tool will be implemented in a future PR.") + + def as_callback_body(self, kwargs: dict[str, Any]) -> dict[str, Any]: + """Transforms the given kwargs to a dict suitable for calling this callback. + + Mirrors how the Dash renderer assembles the callback payload — + see ``fillVals()`` in ``dash-renderer/src/actions/callbacks.ts``. + + For pattern-matching callbacks, wildcard deps are expanded into + nested arrays with concrete component IDs. + """ + coerced = {k: _coerce_value(v) for k, v in kwargs.items()} + + raw_inputs = self._cb_info.get("inputs", []) + raw_state = self._cb_info.get("state", []) + n_deps = len(raw_inputs) + len(raw_state) + + flat_values = [None] * n_deps + for i, name in enumerate(self._param_names): + if i < n_deps and name in coerced: + flat_values[i] = coerced[name] + + inputs_with_values = [ + _expand_dep(dep, flat_values[i]) for i, dep in enumerate(raw_inputs) + ] + state_with_values = [ + _expand_dep(dep, flat_values[len(raw_inputs) + i]) + for i, dep in enumerate(raw_state) + ] + + outputs_spec = _expand_output_spec( + self._output_id, self._cb_info, inputs_with_values + ) + + # changedPropIds: only inputs with non-None values. + # This determines ctx.triggered_id in the callback. + changed = [] + for entry in inputs_with_values: + if isinstance(entry, dict) and entry.get("value") is not None: + eid = entry.get("id") + if isinstance(eid, dict): + changed.append( + f"{json.dumps(eid, sort_keys=True)}.{entry['property']}" + ) + elif isinstance(eid, str): + changed.append(f"{eid}.{entry['property']}") + + return { + "output": self._output_id, + "outputs": outputs_spec, + "inputs": inputs_with_values, + "state": state_with_values, + "changedPropIds": changed, + } + + # ------------------------------------------------------------------- + # Public identity and metadata + # ------------------------------------------------------------------- + + @cached_property + def is_valid(self) -> bool: + """Whether all input components exist in the layout.""" + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + for dep in all_deps: + dep_id = str(dep.get("id", "")) + if dep_id.startswith("{"): + continue + if find_component(dep_id) is None: + return False + return True + + @property + def output_id(self) -> str: + return self._output_id + + @property + def tool_name(self) -> str: + return get_app().mcp_callback_map._tool_names_map[self._output_id] + + @cached_property + def prevents_initial_call(self) -> bool: + for cb in get_app()._callback_list: + if cb["output"] == self._output_id: + return cb.get("prevent_initial_call", False) + return False + + # ------------------------------------------------------------------- + # Private: computed fields for the MCP Tool + # ------------------------------------------------------------------- + + @cached_property + def _description(self) -> str: + return build_tool_description(self.outputs, self._docstring) + + @cached_property + def _input_schema(self) -> dict[str, Any]: + properties = {p["name"]: get_input_schema(p) for p in self.inputs} + required = [p["name"] for p in self.inputs if p["required"]] + + input_schema: dict[str, Any] = {"type": "object", "properties": properties} + if required: + input_schema["required"] = required + return input_schema + + @cached_property + def _output_schema(self) -> dict[str, Any]: + return get_output_schema() + + # ------------------------------------------------------------------- + # Private: callback metadata + # ------------------------------------------------------------------- + + @cached_property + def _docstring(self) -> str | None: + return getattr(self._original_func, "__doc__", None) + + @cached_property + def _initial_output(self) -> dict[str, dict[str, Any]]: + """Run this callback with initial input values. + + Returns the ``response`` portion of the dispatch result: + ``{component_id: {property: value}}``. + + Skipped for callbacks with ``prevent_initial_call=True``, + matching how the Dash renderer skips them on page load. + """ + if self.prevents_initial_call: + return {} + + callback_map = get_app().mcp_callback_map + kwargs = {} + for p in self.inputs: + upstream = callback_map.find_by_output(p["id_and_prop"]) + if upstream is self: + kwargs[p["name"]] = getattr( + find_component(p["component_id"]), p["property"], None + ) + else: + kwargs[p["name"]] = callback_map.get_initial_value(p["id_and_prop"]) + try: + result = run_callback(self, kwargs) + return result.get("response", {}) + except Exception: + return {} + + def initial_output_value(self, id_and_prop: str) -> Any: + """Return the initial value for a specific output ``"component_id.property"``.""" + component_id, prop = id_and_prop.rsplit(".", 1) + return self._initial_output.get(component_id, {}).get(prop) + + @cached_property + def outputs(self) -> list[MCPOutput]: + if self._cb_info.get("no_output"): + return [] + parsed = split_callback_id(self._output_id) + if isinstance(parsed, dict): + parsed = [parsed] + result: list[MCPOutput] = [] + for p in parsed: + comp_id = p["id"] + prop = clean_property_name(p["property"]) + id_and_prop = f"{comp_id}.{prop}" + comp = find_component(comp_id) + result.append( + { + "id_and_prop": id_and_prop, + "component_id": comp_id, + "property": prop, + "component_type": getattr(comp, "_type", None), + "initial_value": self.initial_output_value(id_and_prop), + "tool_name": self.tool_name, + } + ) + return result + + @cached_property + def inputs(self) -> list[MCPInput]: + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + callback_map = get_app().mcp_callback_map + + result: list[MCPInput] = [] + for dep, name, annotation in zip( + all_deps, self._param_names, self._param_annotations + ): + comp_id = str(dep.get("id", "unknown")) + comp = find_component(comp_id) + prop = dep.get("property", "unknown") + id_and_prop = f"{comp_id}.{prop}" + + upstream_cb = callback_map.find_by_output(id_and_prop) + upstream_output = None + if upstream_cb is not None and upstream_cb is not self: + if not upstream_cb.prevents_initial_call: + for out in upstream_cb.outputs: + if out["id_and_prop"] == id_and_prop: + upstream_output = out + break + + initial_value = ( + upstream_output["initial_value"] + if upstream_output is not None + else getattr(comp, prop, None) + ) + + if annotation is not None: + required = not is_nullable(annotation) + else: + required = initial_value is not None + + result.append( + { + "name": name, + "id_and_prop": id_and_prop, + "component_id": comp_id, + "property": prop, + "annotation": annotation, + "component_type": getattr(comp, "_type", None), + "component": comp, + "required": required, + "initial_value": initial_value, + "upstream_output": upstream_output, + } + ) + return result + + # ------------------------------------------------------------------- + # Helpers + # ------------------------------------------------------------------- + + @cached_property + def _cb_info(self) -> dict[str, Any]: + return get_app().callback_map[self._output_id] + + @cached_property + def _original_func(self) -> Any | None: + func = self._cb_info.get("callback") + return getattr(func, "__wrapped__", func) + + @cached_property + def _func_signature(self) -> inspect.Signature | None: + if self._original_func is None: + return None + try: + return inspect.signature(self._original_func) + except (ValueError, TypeError): + return None + + @cached_property + def _dep_param_map(self) -> list[tuple[str, str]]: + """(func_param_name, mcp_param_name) per dep, in dep order. + + Single source of truth for mapping deps to param names. + All dict-vs-list branching is confined here. + """ + all_deps = self._cb_info.get("inputs", []) + self._cb_info.get("state", []) + n_deps = len(all_deps) + indices = self._cb_info.get("inputs_state_indices") + + if isinstance(indices, dict): + entries: list[tuple[int, str, str]] = [] + for func_name, idx in indices.items(): + positions = flatten_grouping(idx) + if len(positions) == 1: + entries.append((positions[0], func_name, func_name)) + else: + for pos in positions: + dep = all_deps[pos] if pos < n_deps else {} + comp_id = str(dep.get("id", "unknown")).replace("-", "_") + prop = dep.get("property", "unknown") + entries.append( + (pos, func_name, f"{func_name}_{comp_id}__{prop}") + ) + entries.sort(key=lambda e: e[0]) + result = [(f, m) for _, f, m in entries] + elif self._func_signature is not None: + names = list(self._func_signature.parameters.keys()) + result = [(n, n) for n in names] + else: + result = [] + + while len(result) < n_deps: + fallback = f"param_{len(result)}" + result.append((fallback, fallback)) + return result + + @cached_property + def _param_names(self) -> list[str]: + """MCP param name per dep, in dep order.""" + return [mcp for _, mcp in self._dep_param_map] + + @cached_property + def _param_annotations(self) -> list[Any | None]: + """One annotation per dep, in dep order.""" + if self._func_signature is None: + return [None] * len(self._dep_param_map) + try: + hints = typing.get_type_hints(self._original_func) + except Exception: + hints = getattr(self._original_func, "__annotations__", {}) + return [hints.get(func_name) for func_name, _ in self._dep_param_map] + + +def _expand_dep(dep: dict, value: Any) -> Any: + """Expand a dependency into the dispatch format. + + For regular deps, returns ``{id, property, value}``. + For ALL/ALLSMALLER: passes through the list of ``{id, property, value}`` dicts. + For MATCH: passes through the single ``{id, property, value}`` dict. + """ + pattern = parse_wildcard_id(dep.get("id", "")) + if pattern is None: + return {**dep, "value": value} + + # LLM provides browser-like format + if isinstance(value, list): + return value + if isinstance(value, dict) and "id" in value: + return value + return {**dep, "value": value} + + +def _expand_output_spec(output_id: str, cb_info: dict, resolved_inputs: list) -> Any: + """Build the outputs spec, expanding wildcards to concrete IDs. + + For wildcard outputs, derives concrete IDs from the resolved inputs. + The browser does the same: input and output wildcards resolve against + the same set of matching components. + """ + if cb_info.get("no_output"): + return [] + + parsed = split_callback_id(output_id) + if isinstance(parsed, dict): + parsed = [parsed] + + results = [] + for p in parsed: + pid = p["id"] + prop = clean_property_name(p["property"]) + pattern = parse_wildcard_id(pid) + if pattern is not None: + concrete_ids = _derive_output_ids(pattern, resolved_inputs) + if not concrete_ids: + concrete_ids = [comp.id for comp in find_matching_components(pattern)] + expanded = [{"id": cid, "property": prop} for cid in concrete_ids] + # ALL/ALLSMALLER → nested list; MATCH → single dict + if len(expanded) == 1: + results.append(expanded[0]) + else: + results.append(expanded) + else: + results.append({"id": pid, "property": prop}) + + if len(results) == 1: + return results[0] + return results + + +def _derive_output_ids( + output_pattern: dict, resolved_inputs: list +) -> list[dict] | None: + """Derive concrete output IDs from the resolved input entries. + + Extracts the wildcard key values from the LLM-provided concrete + input IDs and substitutes them into the output pattern. + """ + wildcard_keys = [ + k + for k, v in output_pattern.items() + if isinstance(v, list) and len(v) == 1 and v[0] in _WILDCARD_VALUES + ] + if not wildcard_keys: + return None + + def _substitute(item_id: dict) -> dict | None: + if not isinstance(item_id, dict): + return None + output_id = dict(output_pattern) + for wk in wildcard_keys: + if wk in item_id: + output_id[wk] = item_id[wk] + return output_id + + for entry in resolved_inputs: + # ALL/ALLSMALLER: nested array of {id, property, value} dicts + if isinstance(entry, list) and entry: + concrete_ids = [] + for item in entry: + out = _substitute(item.get("id")) + if out: + concrete_ids.append(out) + if concrete_ids: + return concrete_ids + # MATCH: single {id, property, value} dict + elif isinstance(entry, dict) and isinstance(entry.get("id"), dict): + out = _substitute(entry["id"]) + if out: + return [out] + + return None + + +def _coerce_value(value: Any) -> Any: + """Parse JSON strings back to Python objects. + + MCP tool parameters arrive as strings. This recovers the + intended type (list, dict, number, bool, null) via json.loads. + Plain strings that aren't valid JSON pass through unchanged. + """ + if not isinstance(value, str): + return value + try: + return json.loads(value) + except (json.JSONDecodeError, ValueError): + return value diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py new file mode 100644 index 0000000000..60e9e2efe5 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -0,0 +1,154 @@ +"""Collection of CallbackAdapters with cross-adapter queries. + +Stored as a singleton on ``app.mcp_callback_map``. +""" + +from __future__ import annotations + +import hashlib +import re +from functools import cached_property +from typing import Any + +from mcp.types import Tool + +from dash import get_app +from dash._utils import clean_property_name, split_callback_id +from dash.layout import extract_text, find_component, traverse +from .callback_adapter import CallbackAdapter + + +class CallbackAdapterCollection: + def __init__(self, app): + callback_map = getattr(app, "callback_map", {}) + + raw: list[tuple[str, dict]] = [] + for output_id, cb_info in callback_map.items(): + if cb_info.get("mcp_enabled") is False: + continue + if "callback" not in cb_info: + continue + raw.append((output_id, cb_info)) + + self._tool_names_map = self._build_tool_names(raw) + self._callbacks = [ + CallbackAdapter(callback_output_id=output_id) + for output_id in self._tool_names_map + ] + # TODO: enable_mcp_server() will replace this with a direct assignment on app + app.mcp_callback_map = self + + @staticmethod + def _sanitize_name(name: str) -> str: + + max_len = 64 + sanitized = re.sub(r"[^a-zA-Z0-9_]", "_", name) + sanitized = re.sub(r"_+", "_", sanitized).strip("_") + if sanitized and sanitized[0].isdigit(): + sanitized = "cb_" + sanitized + full = sanitized or "unnamed_callback" + if len(full) <= max_len: + return full + hash_suffix = hashlib.sha256(full.encode()).hexdigest()[:8] + truncated = sanitized[: max_len - 9].rstrip("_") + return f"{truncated}_{hash_suffix}" + + @classmethod + def _build_tool_names(cls, raw: list[tuple[str, dict]]) -> dict[str, str]: + func_name_counts: dict[str, int] = {} + for _output_id, cb_info in raw: + func = cb_info.get("callback") + original = getattr(func, "__wrapped__", func) + fn = getattr(original, "__name__", "") or "" + func_name_counts[fn] = func_name_counts.get(fn, 0) + 1 + + name_map: dict[str, str] = {} + for output_id, cb_info in raw: + func = cb_info.get("callback") + original = getattr(func, "__wrapped__", func) + fn = getattr(original, "__name__", "") or "" + raw_name = fn if fn and func_name_counts[fn] == 1 else output_id + name_map[output_id] = cls._sanitize_name(raw_name) + return name_map + + def __iter__(self): + return iter(self._callbacks) + + def __len__(self): + return len(self._callbacks) + + def __getitem__(self, index): + return self._callbacks[index] + + def find_by_tool_name(self, name: str) -> CallbackAdapter | None: + for cb in self._callbacks: + if cb.tool_name == name: + return cb + return None + + def find_by_output(self, id_and_prop: str) -> CallbackAdapter | None: + """Find the adapter that outputs to ``id_and_prop`` (``"component_id.property"``).""" + for cb in self._callbacks: + try: + parsed = split_callback_id(cb.output_id) + except ValueError: + continue + if isinstance(parsed, dict): + parsed = [parsed] + for p in parsed: + if f"{p['id']}.{clean_property_name(p['property'])}" == id_and_prop: + return cb + return None + + def get_initial_value(self, id_and_prop: str) -> Any: + """Return the initial value for ``id_and_prop`` (``"component_id.property"``). + + If a callback outputs to this property, runs it (recursively + resolving its inputs). Otherwise returns the layout default. + """ + upstream_cb = self.find_by_output(id_and_prop) + if upstream_cb is not None: + return upstream_cb.initial_output_value(id_and_prop) + else: + component_id, prop = id_and_prop.rsplit(".", 1) + layout_component = find_component(component_id) + return getattr(layout_component, prop, None) + + def as_mcp_tools(self) -> list[Tool]: + """Stub — will be implemented in a future PR.""" + raise NotImplementedError("as_mcp_tools will be implemented in a future PR.") + + @property + def tool_names(self) -> set[str]: + return set(self._tool_names_map.values()) + + @cached_property + def component_label_map(self) -> dict[str, list[str]]: + """Map component ID → list of label texts from html.Label containers + and/or `htmlFor` associations. + """ + layout = get_app().get_layout() + if layout is None: + return {} + + labels: dict[str, list[str]] = {} + for comp, ancestors in traverse(layout): + if getattr(comp, "_type", None) == "Label": + html_for = getattr(comp, "htmlFor", None) + if html_for is not None: + text = extract_text(comp) + if text: + labels.setdefault(str(html_for), []).append(text) + + comp_id = getattr(comp, "id", None) + if comp_id is not None: + for ancestor in reversed(ancestors): + if getattr(ancestor, "_type", None) == "Label": + text = extract_text(ancestor) + if text: + sid = str(comp_id) + if text not in labels.get(sid, []): + labels.setdefault(sid, []).append(text) + break + + return labels diff --git a/dash/mcp/primitives/tools/callback_utils.py b/dash/mcp/primitives/tools/callback_utils.py new file mode 100644 index 0000000000..ec157b6037 --- /dev/null +++ b/dash/mcp/primitives/tools/callback_utils.py @@ -0,0 +1,36 @@ +"""Callback introspection utilities for MCP tools.""" + +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any + +from dash import get_app + +if TYPE_CHECKING: + from .callback_adapter import CallbackAdapter + + +def run_callback(callback: CallbackAdapter, kwargs: dict[str, Any]) -> dict[str, Any]: + """Execute a callback via Dash's dispatch pipeline.""" + from dash.mcp.types import CallbackExecutionError + + body = callback.as_callback_body(kwargs) + + app = get_app() + with app.server.test_request_context( + "/_dash-update-component", + method="POST", + data=json.dumps(body, default=str), + content_type="application/json", + ): + response = app.dispatch() + + response_text = response.get_data(as_text=True) + if response.status_code != 200: + raise CallbackExecutionError( + f"Callback {callback.output_id} failed " + f"(HTTP {response.status_code}): {response_text[:500]}" + ) + + return json.loads(response_text) diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py new file mode 100644 index 0000000000..67ec78c9ff --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -0,0 +1,7 @@ +"""Stub — real implementation in a later PR.""" + + +def build_tool_description(outputs, docstring=None): + if docstring: + return docstring.strip() + return "Dash callback" diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py new file mode 100644 index 0000000000..f306042a0c --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -0,0 +1,5 @@ +"""Stub — real implementation in a later PR.""" + + +def get_input_schema(param): + return {} diff --git a/dash/mcp/primitives/tools/output_schemas/__init__.py b/dash/mcp/primitives/tools/output_schemas/__init__.py new file mode 100644 index 0000000000..d2d70c3552 --- /dev/null +++ b/dash/mcp/primitives/tools/output_schemas/__init__.py @@ -0,0 +1,5 @@ +"""Stub — real implementation in a later PR.""" + + +def get_output_schema(): + return {} diff --git a/dash/mcp/types/__init__.py b/dash/mcp/types/__init__.py new file mode 100644 index 0000000000..af588e0808 --- /dev/null +++ b/dash/mcp/types/__init__.py @@ -0,0 +1,26 @@ +"""MCP types, exceptions, and typing utilities.""" + +from dash.mcp.types.callback_types import MCPInput, MCPOutput +from dash.mcp.types.component_types import ( + ComponentPropertyInfo, + ComponentQueryResult, +) +from dash.mcp.types.exceptions import ( + CallbackExecutionError, + InvalidParamsError, + MCPError, + ToolNotFoundError, +) +from dash.mcp.types.typing_utils import is_nullable + +__all__ = [ + "CallbackExecutionError", + "ComponentPropertyInfo", + "ComponentQueryResult", + "InvalidParamsError", + "MCPError", + "MCPInput", + "MCPOutput", + "ToolNotFoundError", + "is_nullable", +] diff --git a/dash/mcp/types/callback_types.py b/dash/mcp/types/callback_types.py new file mode 100644 index 0000000000..9c65dcb9d8 --- /dev/null +++ b/dash/mcp/types/callback_types.py @@ -0,0 +1,33 @@ +"""Typed dicts for MCP callback adapter data.""" + +from __future__ import annotations + +from typing import Any + +from typing_extensions import TypedDict + + +class MCPOutput(TypedDict): + """A single callback output, with component type and initial value resolved.""" + + id_and_prop: str + component_id: str + property: str + component_type: str | None + initial_value: Any + tool_name: str + + +class MCPInput(TypedDict): + """A single callback parameter (input or state), fully resolved.""" + + name: str + id_and_prop: str + component_id: str + property: str + annotation: Any | None + component_type: str | None + component: Any | None + required: bool + initial_value: Any + upstream_output: MCPOutput | None diff --git a/dash/mcp/types/component_types.py b/dash/mcp/types/component_types.py new file mode 100644 index 0000000000..0cac3ad689 --- /dev/null +++ b/dash/mcp/types/component_types.py @@ -0,0 +1,20 @@ +"""Typed dicts for component data in MCP.""" + +from __future__ import annotations + +from typing import Any + +from typing_extensions import NotRequired, TypedDict + + +class ComponentPropertyInfo(TypedDict): + initial_value: Any + modified_by_tool: list[str] + input_to_tool: list[str] + + +class ComponentQueryResult(TypedDict): + component_id: str + component_type: str + label: NotRequired[list[str] | None] + properties: dict[str, ComponentPropertyInfo] diff --git a/dash/mcp/types/exceptions.py b/dash/mcp/types/exceptions.py new file mode 100644 index 0000000000..7fb962db85 --- /dev/null +++ b/dash/mcp/types/exceptions.py @@ -0,0 +1,30 @@ +"""MCP error types with JSON-RPC error codes.""" + +from __future__ import annotations + + +class MCPError(Exception): + """Base MCP error carrying a JSON-RPC error code.""" + + code = -32603 + + def __init__(self, message: str): + super().__init__(message) + + +class ToolNotFoundError(MCPError): + """Tool name not found in the callback registry.""" + + code = -32601 + + +class InvalidParamsError(MCPError): + """Invalid or missing parameters for a tool call.""" + + code = -32602 + + +class CallbackExecutionError(MCPError): + """Callback raised an exception during execution.""" + + code = -32603 diff --git a/dash/mcp/types/typing_utils.py b/dash/mcp/types/typing_utils.py new file mode 100644 index 0000000000..9a96d4135d --- /dev/null +++ b/dash/mcp/types/typing_utils.py @@ -0,0 +1,28 @@ +"""Shared typing utilities for the MCP layer.""" + +from __future__ import annotations + +import typing +from typing import Any + + +def is_nullable(annotation: Any) -> bool: + """Check if a type annotation includes NoneType (is nullable/Optional).""" + origin = getattr(annotation, "__origin__", None) + args = getattr(annotation, "__args__", ()) + + _is_union = origin is typing.Union + if not _is_union: + try: + import types as _types + + if isinstance(annotation, _types.UnionType): + _is_union = True + args = annotation.__args__ + except AttributeError: + pass + + if _is_union and args: + return type(None) in args + + return False diff --git a/requirements/install.txt b/requirements/install.txt index 89bd8a5595..b813a6ce55 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -8,3 +8,4 @@ retrying nest-asyncio setuptools pydantic>=2.10 +mcp>=1.0.0; python_version>="3.10" diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py new file mode 100644 index 0000000000..437a71db5c --- /dev/null +++ b/tests/unit/mcp/conftest.py @@ -0,0 +1,6 @@ +import sys + +collect_ignore_glob = [] + +if sys.version_info < (3, 10): + collect_ignore_glob.append("*") diff --git a/tests/unit/mcp/tools/test_callback_adapter.py b/tests/unit/mcp/tools/test_callback_adapter.py new file mode 100644 index 0000000000..91808d304e --- /dev/null +++ b/tests/unit/mcp/tools/test_callback_adapter.py @@ -0,0 +1,227 @@ +"""Tests for CallbackAdapter.""" + +import pytest +from dash import Dash, Input, Output, dcc, html +from dash._get_app import app_context + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def simple_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp", type="text"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + """Update output.""" + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def duplicate_names_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def cb(v): + return v + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb(v): # noqa: F811 + return v + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestFromApp: + def test_returns_list(self, simple_app): + assert len(app_context.get().mcp_callback_map) == 1 + + def test_excludes_clientside(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="cs-out"), + html.Div(id="srv-out"), + ] + ) + app.clientside_callback( + "function(n) { return n; }", + Output("cs-out", "children"), + Input("btn", "n_clicks"), + ) + + @app.callback(Output("srv-out", "children"), Input("btn", "n_clicks")) + def server_cb(n): + return str(n) + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + names = [a.tool_name for a in app.mcp_callback_map] + assert names == ["server_cb"] + + def test_excludes_mcp_disabled(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp"), + html.Div(id="out1"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("inp", "value")) + def visible(val): + return val + + @app.callback( + Output("out2", "children"), Input("inp", "value"), mcp_enabled=False + ) + def hidden(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + names = [a.tool_name for a in app.mcp_callback_map] + assert "visible" in names + assert "hidden" not in names + + +class TestToolName: + def test_uses_func_name(self, simple_app): + assert app_context.get().mcp_callback_map[0].tool_name == "update" + + def test_duplicates_get_unique_names(self, duplicate_names_app): + names = [a.tool_name for a in app_context.get().mcp_callback_map] + assert len(names) == 2 + assert names[0] != names[1] + + +class TestGetInitialValue: + def test_returns_layout_value(self, simple_app): + callback_map = app_context.get().mcp_callback_map + # Input with no value set — returns None (layout default for dcc.Input) + assert callback_map.get_initial_value("inp.value") is None + + def test_returns_set_value(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(selected): + return selected + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map.get_initial_value("dd.value") == "a" + + def test_initial_callback_makes_param_required(self): + """A param with None in layout but set by an initial callback is required.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + id="country", options=["France", "Germany"], value="France" + ), + dcc.Dropdown(id="city"), # value=None in layout + html.Div(id="out"), + ] + ) + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("country", "value"), + ) + def update_cities(country): + return [{"label": "Paris", "value": "Paris"}], "Paris" + + @app.callback(Output("out", "children"), Input("city", "value")) + def show_city(city): + return f"Selected: {city}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + # city.value is None in layout but "Paris" after initial callback + with app.server.test_request_context(): + show_city_cb = app.mcp_callback_map.find_by_tool_name("show_city") + city_param = show_city_cb.inputs[0] + assert city_param["name"] == "city" + assert city_param["required"] is True # not optional despite None in layout + + +class TestIsValid: + def test_valid_when_inputs_in_layout(self, simple_app): + assert app_context.get().mcp_callback_map[0].is_valid + + def test_invalid_when_input_not_in_layout(self): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("nonexistent", "value")) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert not app.mcp_callback_map[0].is_valid + + def test_pattern_matching_ids_always_valid(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="a"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input({"type": "field", "index": 0}, "value"), + ) + def update(val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + assert app.mcp_callback_map[0].is_valid diff --git a/tests/unit/mcp/tools/test_callback_adapter_collection.py b/tests/unit/mcp/tools/test_callback_adapter_collection.py new file mode 100644 index 0000000000..c120a2df8b --- /dev/null +++ b/tests/unit/mcp/tools/test_callback_adapter_collection.py @@ -0,0 +1,145 @@ +"""Tests for CallbackAdapterCollection.""" + +from dash import Dash, Input, Output, dcc, html +from dash._get_app import app_context + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +def _setup(app): + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + +class TestToolNameCollisions: + @staticmethod + def _make_duplicate_cb_app(n=3): + ids = [f"dd{i + 1}" for i in range(n)] + app = Dash(__name__) + app.layout = html.Div( + [ + item + for i in ids + for item in [ + dcc.Dropdown( + id=i, options=[chr(97 + j) for j in range(1)], value="a" + ), + html.Div(id=f"{i}-output"), + ] + ] + ) + for idx, dd_id in enumerate(ids): + + @app.callback(Output(f"{dd_id}-output", "children"), Input(dd_id, "value")) + def cb(value, _id=dd_id): # noqa: F811 + return f"{_id}: {value}" + + return app + + def test_duplicate_func_names_get_unique_tools(self): + app = self._make_duplicate_cb_app(3) + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert len(tool_names) == 3 + assert len(set(tool_names)) == 3, f"Tool names are not unique: {tool_names}" + for name in tool_names: + assert "dd" in name, f"Expected output ID in tool name: {name}" + + def test_unique_func_names_use_func_name(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def alpha_handler(value): + return value + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def beta_handler(value): + return value + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "alpha_handler" in tool_names + assert "beta_handler" in tool_names + + def test_duplicate_func_names_use_output_id(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="out1"), + html.Div(id="out2"), + html.Div(id="out3"), + html.Div(id="in1"), + html.Div(id="in2"), + html.Div(id="in3"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def unique_func(v): + return v + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb(v): + return v + + @app.callback(Output("out3", "children"), Input("in3", "children")) + def cb(v): # noqa: F811 + return v + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "unique_func" in tool_names + non_unique = [n for n in tool_names if n != "unique_func"] + assert len(non_unique) == 2 + assert non_unique[0] != non_unique[1] + + +class TestAllCallbacksVisibleByDefault: + def test_all_callbacks_visible_by_default(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="in1"), + html.Div(id="out1"), + html.Div(id="in2"), + html.Div(id="out2"), + ] + ) + + @app.callback(Output("out1", "children"), Input("in1", "children")) + def cb_one(value): + return value + + @app.callback(Output("out2", "children"), Input("in2", "children")) + def cb_two(value): + return value + + _setup(app) + tool_names = [a.tool_name for a in app.mcp_callback_map] + assert "cb_one" in tool_names + assert "cb_two" in tool_names + + +class TestAdapterCollection: + def test_adapter_has_expected_properties(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + return val + + _setup(app) + adapter = app.mcp_callback_map[0] + assert adapter.tool_name == "update" + assert adapter.output_id == "out.children" From 33fb6df0fc0ca026fedde1f1be3ac9980a20a50d Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 16:59:57 -0600 Subject: [PATCH 14/27] Fix type errors --- dash/_callback.py | 2 +- dash/types.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/dash/_callback.py b/dash/_callback.py index dc73dd0792..5900bbe0fc 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -540,7 +540,7 @@ def _prepare_response( allow_dynamic_callbacks, ): """Prepare the response object based on the callback output.""" - component_ids = collections.defaultdict(dict) + component_ids: dict = collections.defaultdict(dict) if has_output: if not multi: diff --git a/dash/types.py b/dash/types.py index 43bf16dc30..e392a2d599 100644 --- a/dash/types.py +++ b/dash/types.py @@ -73,3 +73,4 @@ class CallbackDispatchResponse(TypedDict): multi: NotRequired[bool] response: NotRequired[Dict[str, CallbackOutput]] sideUpdate: NotRequired[Dict[str, CallbackSideOutput]] + dist: NotRequired[List[Any]] From 74cd47768ea65f6fdd3092d740fb5eddf10a401e Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 12:11:03 -0600 Subject: [PATCH 15/27] Fix import path --- dash/mcp/primitives/tools/callback_adapter.py | 2 +- dash/mcp/primitives/tools/callback_adapter_collection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 0f50d15c03..743453af10 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -15,7 +15,7 @@ from mcp.types import Tool from dash import get_app -from dash.layout import ( +from dash._layout_utils import ( _WILDCARD_VALUES, find_component, find_matching_components, diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py index 60e9e2efe5..59c1a7ac47 100644 --- a/dash/mcp/primitives/tools/callback_adapter_collection.py +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -14,7 +14,7 @@ from dash import get_app from dash._utils import clean_property_name, split_callback_id -from dash.layout import extract_text, find_component, traverse +from dash._layout_utils import extract_text, find_component, traverse from .callback_adapter import CallbackAdapter From dfabf376599046dc5d9aecd8cc97bfd8d4c8b7eb Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 12:46:29 -0600 Subject: [PATCH 16/27] Add MCP resource providers for layout, components, pages, and clientside callbacks --- dash/mcp/primitives/resources/__init__.py | 52 ++++++++++ .../resource_clientside_callbacks.py | 95 +++++++++++++++++++ .../resources/resource_components.py | 59 ++++++++++++ .../primitives/resources/resource_layout.py | 44 +++++++++ .../resources/resource_page_layout.py | 77 +++++++++++++++ .../primitives/resources/resource_pages.py | 76 +++++++++++++++ .../test_resource_clientside_callbacks.py | 54 +++++++++++ .../resources/test_resource_layout.py | 59 ++++++++++++ .../resources/test_resource_page_layout.py | 52 ++++++++++ .../resources/test_resource_pages.py | 78 +++++++++++++++ 10 files changed, 646 insertions(+) create mode 100644 dash/mcp/primitives/resources/__init__.py create mode 100644 dash/mcp/primitives/resources/resource_clientside_callbacks.py create mode 100644 dash/mcp/primitives/resources/resource_components.py create mode 100644 dash/mcp/primitives/resources/resource_layout.py create mode 100644 dash/mcp/primitives/resources/resource_page_layout.py create mode 100644 dash/mcp/primitives/resources/resource_pages.py create mode 100644 tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py create mode 100644 tests/unit/mcp/primitives/resources/test_resource_layout.py create mode 100644 tests/unit/mcp/primitives/resources/test_resource_page_layout.py create mode 100644 tests/unit/mcp/primitives/resources/test_resource_pages.py diff --git a/dash/mcp/primitives/resources/__init__.py b/dash/mcp/primitives/resources/__init__.py new file mode 100644 index 0000000000..da93feae04 --- /dev/null +++ b/dash/mcp/primitives/resources/__init__.py @@ -0,0 +1,52 @@ +"""MCP resource listing and read handling. + +Each resource module exports: +- ``URI`` — the URI prefix this module handles +- ``get_resource() -> Resource | None`` +- ``get_template() -> ResourceTemplate | None`` +- ``read_resource(uri) -> ReadResourceResult`` + +Dispatch is by prefix match: more specific prefixes must come first. +""" + +from __future__ import annotations + +from mcp.types import ( + ListResourcesResult, + ListResourceTemplatesResult, + ReadResourceResult, +) + +from . import ( + resource_clientside_callbacks as _clientside, + resource_components as _components, + resource_layout as _layout, + resource_page_layout as _page_layout, + resource_pages as _pages, +) + +_RESOURCE_MODULES = [_layout, _components, _pages, _clientside, _page_layout] + + +def list_resources() -> ListResourcesResult: + """Build the MCP resources/list response.""" + resources = [ + r for mod in _RESOURCE_MODULES for r in [mod.get_resource()] if r is not None + ] + return ListResourcesResult(resources=resources) + + +def list_resource_templates() -> ListResourceTemplatesResult: + """Build the MCP resources/templates/list response.""" + templates = [ + t for mod in _RESOURCE_MODULES for t in [mod.get_template()] if t is not None + ] + return ListResourceTemplatesResult(resourceTemplates=templates) + + +def read_resource(uri: str) -> ReadResourceResult: + """Dispatch a resources/read request by URI prefix match.""" + for mod in _RESOURCE_MODULES: + if uri.startswith(mod.URI): + return mod.read_resource(uri) + raise ValueError(f"Unknown resource URI: {uri}") diff --git a/dash/mcp/primitives/resources/resource_clientside_callbacks.py b/dash/mcp/primitives/resources/resource_clientside_callbacks.py new file mode 100644 index 0000000000..dbc3009edb --- /dev/null +++ b/dash/mcp/primitives/resources/resource_clientside_callbacks.py @@ -0,0 +1,95 @@ +"""Clientside callbacks resource.""" + +from __future__ import annotations + +import json +from typing import Any + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +from dash import get_app +from dash._utils import clean_property_name, split_callback_id + +URI = "dash://clientside-callbacks" + + +def get_resource() -> Resource | None: + if not _get_clientside_callbacks(): + return None + return Resource( + uri=URI, + name="dash_clientside_callbacks", + description=( + "Actions the user can take manually in the browser " + "to affect clientside state. Inputs describe the " + "components that can be changed to trigger an effect. " + "Outputs describe the components that will change " + "in response." + ), + mimeType="application/json", + ) + + +def get_template() -> ResourceTemplate | None: + return None + + +def read_resource(uri: str = "") -> ReadResourceResult: + data = { + "description": ( + "These are actions that the user can take manually in the " + "browser to affect the clientside state. Inputs describe " + "the components that can be changed to trigger an effect. " + "Outputs describe the components that will change in " + "response to the effect." + ), + "callbacks": _get_clientside_callbacks(), + } + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=URI, + mimeType="application/json", + text=json.dumps(data, default=str), + ) + ] + ) + + +def _get_clientside_callbacks() -> list[dict[str, Any]]: + """Get input/output mappings for clientside callbacks.""" + app = get_app() + callbacks = [] + callback_map = getattr(app, "callback_map", {}) + + for output_id, callback_info in callback_map.items(): + if "callback" in callback_info: + continue + normalize_deps = lambda deps: [ + { + "component_id": str(d.get("id", "unknown")), + "property": d.get("property", "unknown"), + } + for d in deps + ] + parsed = split_callback_id(output_id) + if isinstance(parsed, dict): + parsed = [parsed] + outputs = [ + {"component_id": p["id"], "property": clean_property_name(p["property"])} + for p in parsed + ] + callbacks.append( + { + "outputs": outputs, + "inputs": normalize_deps(callback_info.get("inputs", [])), + "state": normalize_deps(callback_info.get("state", [])), + } + ) + + return callbacks diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py new file mode 100644 index 0000000000..e6441d7aee --- /dev/null +++ b/dash/mcp/primitives/resources/resource_components.py @@ -0,0 +1,59 @@ +"""Component list resource.""" + +from __future__ import annotations + +import json + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +from dash import get_app +from dash.layout import traverse + +URI = "dash://components" + + +def get_resource() -> Resource | None: + return Resource( + uri=URI, + name="dash_components", + description=( + "All components with IDs in the app layout. " + "Use get_dash_component with any of these IDs " + "to inspect their properties and values. " + "See dash://layout for the tree structure showing " + "how these components are nested in the page." + ), + mimeType="application/json", + ) + + +def get_template() -> ResourceTemplate | None: + return None + + +def read_resource(uri: str = "") -> ReadResourceResult: + app = get_app() + layout = app.get_layout() + components = sorted( + [ + {"id": str(comp.id), "type": getattr(comp, "_type", type(comp).__name__)} + for comp, _ in traverse(layout) + if getattr(comp, "id", None) is not None + ], + key=lambda c: c["id"], + ) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=URI, + mimeType="application/json", + text=json.dumps(components), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_layout.py b/dash/mcp/primitives/resources/resource_layout.py new file mode 100644 index 0000000000..01d0be046d --- /dev/null +++ b/dash/mcp/primitives/resources/resource_layout.py @@ -0,0 +1,44 @@ +"""Layout tree resource for the whole app.""" + +from __future__ import annotations + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +from dash import get_app +from dash._utils import to_json + +URI = "dash://layout" + + +def get_resource() -> Resource | None: + return Resource( + uri=URI, + name="dash_app_layout", + description=( + "Full component tree of the Dash app. " + "See dash://components for a compact list of component IDs." + ), + mimeType="application/json", + ) + + +def get_template() -> ResourceTemplate | None: + return None + + +def read_resource(uri: str = "") -> ReadResourceResult: + app = get_app() + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=URI, + mimeType="application/json", + text=to_json(app.get_layout()), + ) + ] + ) diff --git a/dash/mcp/primitives/resources/resource_page_layout.py b/dash/mcp/primitives/resources/resource_page_layout.py new file mode 100644 index 0000000000..d82d366298 --- /dev/null +++ b/dash/mcp/primitives/resources/resource_page_layout.py @@ -0,0 +1,77 @@ +"""Per-page layout resource template for multi-page apps.""" + +from __future__ import annotations + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +from dash._utils import to_json + +URI = "dash://page-layout/" +_URI_TEMPLATE = "dash://page-layout/{path}" + + +def get_resource() -> Resource | None: + return None + + +def get_template() -> ResourceTemplate | None: + if not _has_pages(): + return None + return ResourceTemplate( + uriTemplate=_URI_TEMPLATE, + name="dash_page_layout", + description="Component tree for a specific page in the app.", + mimeType="application/json", + ) + + +def read_resource(uri: str) -> ReadResourceResult: + path = uri[len(URI) :] + if not path.startswith("/"): + path = "/" + path + + try: + from dash._pages import PAGE_REGISTRY + except ImportError: + raise ValueError("Dash Pages is not available.") + + page_layout = None + for _module, page in PAGE_REGISTRY.items(): + if page.get("path") == path: + page_layout = page.get("layout") + break + + if page_layout is None: + raise ValueError(f"Page not found: {path}") + + if callable(page_layout): + page_layout = page_layout() + + if isinstance(page_layout, (list, tuple)): + from dash import html + + page_layout = html.Div(list(page_layout)) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=uri, + mimeType="application/json", + text=to_json(page_layout), + ) + ] + ) + + +def _has_pages() -> bool: + try: + from dash._pages import PAGE_REGISTRY + + return bool(PAGE_REGISTRY) + except ImportError: + return False diff --git a/dash/mcp/primitives/resources/resource_pages.py b/dash/mcp/primitives/resources/resource_pages.py new file mode 100644 index 0000000000..51a61b9f00 --- /dev/null +++ b/dash/mcp/primitives/resources/resource_pages.py @@ -0,0 +1,76 @@ +"""Pages resource for multi-page apps.""" + +from __future__ import annotations + +import json + +from mcp.types import ( + ReadResourceResult, + Resource, + ResourceTemplate, + TextResourceContents, +) + +URI = "dash://pages" + + +def _has_pages() -> bool: + try: + from dash._pages import PAGE_REGISTRY + + return bool(PAGE_REGISTRY) + except ImportError: + return False + + +def get_resource() -> Resource | None: + if not _has_pages(): + return None + return Resource( + uri=URI, + name="dash_app_pages", + description=( + "List of all pages in this multi-page Dash app " + "with paths, names, titles, and descriptions." + ), + mimeType="application/json", + ) + + +def get_template() -> ResourceTemplate | None: + return None + + +def read_resource(uri: str = "") -> ReadResourceResult: + try: + from dash._pages import PAGE_REGISTRY + except ImportError: + return ReadResourceResult( + contents=[ + TextResourceContents(uri=URI, mimeType="application/json", text="[]") + ] + ) + + pages = [] + for module, page in PAGE_REGISTRY.items(): + title = page.get("title", "") + description = page.get("description", "") + pages.append( + { + "module": module, + "path": page.get("path", ""), + "name": page.get("name", ""), + "title": title if not callable(title) else page.get("name", ""), + "description": description if not callable(description) else "", + } + ) + + return ReadResourceResult( + contents=[ + TextResourceContents( + uri=URI, + mimeType="application/json", + text=json.dumps(pages, default=str), + ) + ] + ) diff --git a/tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py b/tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py new file mode 100644 index 0000000000..3ba2ce7996 --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_clientside_callbacks.py @@ -0,0 +1,54 @@ +"""Tests for the dash://clientside-callbacks resource.""" + +import json + +from dash import Dash, Input, Output, clientside_callback, html + +from dash.mcp.primitives.resources import list_resources, read_resource + + +class TestClientsideCallbacksResource: + @staticmethod + def _make_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn", children="Click"), + html.Div(id="out"), + html.Div(id="server-out"), + ] + ) + + clientside_callback( + "function(n) { return n; }", + Output("out", "children"), + Input("btn", "n_clicks"), + ) + + @app.callback(Output("server-out", "children"), Input("btn", "n_clicks")) + def server_cb(n): + return str(n) + + with app.server.test_request_context(): + app._setup_server() + + return app + + def test_resource_listed(self): + app = self._make_app() + with app.server.test_request_context(): + result = list_resources() + uris = [str(r.uri) for r in result.resources] + assert "dash://clientside-callbacks" in uris + + def test_resource_read(self): + app = self._make_app() + with app.server.test_request_context(): + result = read_resource("dash://clientside-callbacks") + data = json.loads(result.contents[0].text) + assert "description" in data + callbacks = data["callbacks"] + assert len(callbacks) == 1 + assert callbacks[0]["inputs"][0]["component_id"] == "btn" + assert callbacks[0]["inputs"][0]["property"] == "n_clicks" + assert callbacks[0]["outputs"][0]["component_id"] == "out" diff --git a/tests/unit/mcp/primitives/resources/test_resource_layout.py b/tests/unit/mcp/primitives/resources/test_resource_layout.py new file mode 100644 index 0000000000..ade207b1f3 --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_layout.py @@ -0,0 +1,59 @@ +"""Tests for the dash://layout resource.""" + +import json +from unittest.mock import patch + +from dash import Dash, dcc, html + +from dash.mcp.primitives.resources import list_resources, read_resource + +EXPECTED_LAYOUT = { + "type": "Div", + "namespace": "dash_html_components", + "props": { + "children": [ + { + "type": "Dropdown", + "namespace": "dash_core_components", + "props": { + "id": "test-dd", + "options": ["a", "b"], + "value": "a", + }, + }, + { + "type": "Div", + "namespace": "dash_html_components", + "props": { + "children": None, + "id": "output", + }, + }, + ] + }, +} + + +class TestLayoutResource: + def test_listed_in_resources(self): + app = Dash(__name__) + app.layout = html.Div(id="main") + with app.server.test_request_context(): + result = list_resources() + uris = [str(r.uri) for r in result.resources] + assert "dash://layout" in uris + + def test_read_returns_layout(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="test-dd", options=["a", "b"], value="a"), + html.Div(id="output"), + ] + ) + with app.server.test_request_context(): + with patch.object(app, "get_layout", wraps=app.get_layout) as mock: + result = read_resource("dash://layout") + mock.assert_called_once() + layout = json.loads(result.contents[0].text) + assert layout == EXPECTED_LAYOUT diff --git a/tests/unit/mcp/primitives/resources/test_resource_page_layout.py b/tests/unit/mcp/primitives/resources/test_resource_page_layout.py new file mode 100644 index 0000000000..88ffd82118 --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_page_layout.py @@ -0,0 +1,52 @@ +"""Tests for the dash://page-layout/{path} resource template.""" + +import json +from unittest.mock import patch + +from dash import Dash, dcc, html + +from dash.mcp.primitives.resources import read_resource + +EXPECTED_PAGE_LAYOUT = { + "type": "Div", + "namespace": "dash_html_components", + "props": { + "children": [ + { + "type": "Dropdown", + "namespace": "dash_core_components", + "props": { + "id": "page-dd", + "options": ["a", "b"], + "value": "a", + }, + } + ] + }, +} + + +class TestPageLayoutResource: + def test_read_page_layout(self): + app = Dash(__name__) + app.layout = html.Div(id="main") + + page_layout = html.Div( + [ + dcc.Dropdown(id="page-dd", options=["a", "b"], value="a"), + ] + ) + fake_registry = { + "pages.test": { + "path": "/test", + "name": "Test", + "title": "Test Page", + "description": "", + "layout": page_layout, + }, + } + with app.server.test_request_context(): + with patch("dash._pages.PAGE_REGISTRY", fake_registry): + result = read_resource("dash://page-layout/test") + layout = json.loads(result.contents[0].text) + assert layout == EXPECTED_PAGE_LAYOUT diff --git a/tests/unit/mcp/primitives/resources/test_resource_pages.py b/tests/unit/mcp/primitives/resources/test_resource_pages.py new file mode 100644 index 0000000000..b2307d6fef --- /dev/null +++ b/tests/unit/mcp/primitives/resources/test_resource_pages.py @@ -0,0 +1,78 @@ +"""Tests for the dash://pages resource.""" + +import json +from unittest.mock import patch + +from dash import Dash, html + +from dash.mcp.primitives.resources import list_resources, read_resource + +EXPECTED_PAGES = [ + { + "path": "/", + "name": "Home", + "title": "Home Page", + "description": "The landing page", + "module": "pages.home", + }, + { + "path": "/analytics", + "name": "Analytics", + "title": "Analytics Dashboard", + "description": "View analytics", + "module": "pages.analytics", + }, +] + + +class TestPagesResource: + @staticmethod + def _make_app(): + app = Dash(__name__) + app.layout = html.Div(id="main") + return app + + def test_listed_for_multi_page_app(self): + app = self._make_app() + fake_registry = { + "pages.home": { + "path": "/", + "name": "Home", + "title": "Home", + "description": "", + } + } + with app.server.test_request_context(): + with patch("dash._pages.PAGE_REGISTRY", fake_registry): + result = list_resources() + uris = [str(r.uri) for r in result.resources] + assert "dash://pages" in uris + + def test_returns_page_info(self): + app = self._make_app() + fake_registry = { + "pages.home": EXPECTED_PAGES[0], + "pages.analytics": EXPECTED_PAGES[1], + } + with app.server.test_request_context(): + with patch("dash._pages.PAGE_REGISTRY", fake_registry): + result = read_resource("dash://pages") + content = json.loads(result.contents[0].text) + assert content == EXPECTED_PAGES + + def test_callable_title_falls_back_to_name(self): + app = self._make_app() + fake_registry = { + "pages.dynamic": { + "path": "/item/", + "name": "Item Detail", + "title": lambda **kwargs: f"Item {kwargs.get('item_id', '')}", + "description": lambda **kwargs: f"Details for {kwargs.get('item_id', '')}", + }, + } + with app.server.test_request_context(): + with patch("dash._pages.PAGE_REGISTRY", fake_registry): + result = read_resource("dash://pages") + page = json.loads(result.contents[0].text)[0] + assert page["title"] == "Item Detail" + assert page["description"] == "" From c397a386a1523fb5da30eacde10e987a97f7f142 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Mon, 6 Apr 2026 12:13:28 -0600 Subject: [PATCH 17/27] Fix import path --- dash/mcp/primitives/resources/resource_components.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/mcp/primitives/resources/resource_components.py b/dash/mcp/primitives/resources/resource_components.py index e6441d7aee..8cf366f95c 100644 --- a/dash/mcp/primitives/resources/resource_components.py +++ b/dash/mcp/primitives/resources/resource_components.py @@ -12,7 +12,7 @@ ) from dash import get_app -from dash.layout import traverse +from dash._layout_utils import traverse URI = "dash://components" From 73ba233c3e6bf73cc2bde5625c999318632e731c Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 13:05:07 -0600 Subject: [PATCH 18/27] Implement callbacks as tools with rich input/output schema and description generation --- dash/mcp/primitives/tools/callback_adapter.py | 13 +- .../tools/callback_adapter_collection.py | 3 +- .../primitives/tools/descriptions/__init__.py | 35 +- .../descriptions/description_docstring.py | 15 + .../tools/descriptions/description_outputs.py | 56 +++ .../tools/input_schemas/__init__.py | 44 +- .../input_descriptions/__init__.py | 31 ++ .../description_component_props.py | 81 ++++ .../description_docstrings.py | 71 +++ .../description_html_labels.py | 23 + .../schema_callback_type_annotations.py | 67 +++ .../schema_component_proptypes.py | 32 ++ .../schema_component_proptypes_overrides.py | 70 +++ .../tools/output_schemas/__init__.py | 28 +- .../schema_callback_response.py | 16 + tests/unit/mcp/conftest.py | 81 ++++ .../unit/mcp/tools/input_schemas/__init__.py | 0 .../input_descriptions/__init__.py | 0 .../input_descriptions/test_descriptions.py | 424 ++++++++++++++++++ .../tools/input_schemas/test_input_schemas.py | 331 ++++++++++++++ .../test_schema_component_proptypes.py | 15 + tests/unit/mcp/tools/test_callback_adapter.py | 167 ++++++- tests/unit/mcp/tools/test_tool_schema.py | 64 +++ 23 files changed, 1652 insertions(+), 15 deletions(-) create mode 100644 dash/mcp/primitives/tools/descriptions/description_docstring.py create mode 100644 dash/mcp/primitives/tools/descriptions/description_outputs.py create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py create mode 100644 dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py create mode 100644 dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py create mode 100644 dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py create mode 100644 dash/mcp/primitives/tools/output_schemas/schema_callback_response.py create mode 100644 tests/unit/mcp/tools/input_schemas/__init__.py create mode 100644 tests/unit/mcp/tools/input_schemas/input_descriptions/__init__.py create mode 100644 tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py create mode 100644 tests/unit/mcp/tools/input_schemas/test_input_schemas.py create mode 100644 tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py create mode 100644 tests/unit/mcp/tools/test_tool_schema.py diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 743453af10..7fd528dd01 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -43,8 +43,17 @@ def __init__(self, callback_output_id: str): @cached_property def as_mcp_tool(self) -> Tool: - """Stub — will be implemented in a future PR.""" - raise NotImplementedError("as_mcp_tool will be implemented in a future PR.") + """Transforms the internal Dash callback to a structured MCP tool. + + This tool can be serialized for LLM consumption or used internally for + its computed data. + """ + return Tool( + name=self.tool_name, + description=self._description, + inputSchema=self._input_schema, + outputSchema=self._output_schema, + ) def as_callback_body(self, kwargs: dict[str, Any]) -> dict[str, Any]: """Transforms the given kwargs to a dict suitable for calling this callback. diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py index 59c1a7ac47..b53cf53a9d 100644 --- a/dash/mcp/primitives/tools/callback_adapter_collection.py +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -115,8 +115,7 @@ def get_initial_value(self, id_and_prop: str) -> Any: return getattr(layout_component, prop, None) def as_mcp_tools(self) -> list[Tool]: - """Stub — will be implemented in a future PR.""" - raise NotImplementedError("as_mcp_tools will be implemented in a future PR.") + return [cb.as_mcp_tool for cb in self._callbacks if cb.is_valid] @property def tool_names(self) -> set[str]: diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py index 67ec78c9ff..d464677251 100644 --- a/dash/mcp/primitives/tools/descriptions/__init__.py +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -1,7 +1,32 @@ -"""Stub — real implementation in a later PR.""" +"""Tool-level description generation for MCP tools. +Each source shares the same signature: +``(outputs, docstring) -> list[str]`` -def build_tool_description(outputs, docstring=None): - if docstring: - return docstring.strip() - return "Dash callback" +This is distinct from per-parameter descriptions +(in ``input_schemas/input_descriptions/``) which populate +``inputSchema.properties.{param}.description``. +""" + +from __future__ import annotations + +from typing import Any + +from .description_docstring import callback_docstring +from .description_outputs import output_summary + +_SOURCES = [ + output_summary, + callback_docstring, +] + + +def build_tool_description( + outputs: list[dict[str, Any]], + docstring: str | None = None, +) -> str: + """Build a human-readable description for an MCP tool.""" + lines: list[str] = [] + for source in _SOURCES: + lines.extend(source(outputs, docstring)) + return "\n".join(lines) if lines else "Dash callback" diff --git a/dash/mcp/primitives/tools/descriptions/description_docstring.py b/dash/mcp/primitives/tools/descriptions/description_docstring.py new file mode 100644 index 0000000000..71cf4d3d5a --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_docstring.py @@ -0,0 +1,15 @@ +"""Callback docstring for tool descriptions.""" + +from __future__ import annotations + +from typing import Any + + +def callback_docstring( + outputs: list[dict[str, Any]], + docstring: str | None = None, +) -> list[str]: + """Return the callback's docstring as description lines.""" + if docstring: + return ["", docstring.strip()] + return [] diff --git a/dash/mcp/primitives/tools/descriptions/description_outputs.py b/dash/mcp/primitives/tools/descriptions/description_outputs.py new file mode 100644 index 0000000000..c174c177f4 --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_outputs.py @@ -0,0 +1,56 @@ +"""Output summary for tool descriptions.""" + +from __future__ import annotations + +from typing import Any + +_OUTPUT_SEMANTICS: dict[tuple[str | None, str], str] = { + ("Graph", "figure"): "Returns chart/visualization data", + ("DataTable", "data"): "Returns tabular data", + ("DataTable", "columns"): "Returns table column definitions", + ("Dropdown", "options"): "Returns selection options", + ("Dropdown", "value"): "Updates a selection value", + ("RadioItems", "options"): "Returns selection options", + ("Checklist", "options"): "Returns selection options", + ("Store", "data"): "Returns stored data", + ("Download", "data"): "Returns downloadable content", + ("Markdown", "children"): "Returns formatted text", + (None, "figure"): "Returns chart/visualization data", + (None, "data"): "Returns data", + (None, "options"): "Returns selection options", + (None, "columns"): "Returns column definitions", + (None, "children"): "Returns content", + (None, "value"): "Returns a value", + (None, "style"): "Updates styling", + (None, "disabled"): "Updates enabled/disabled state", +} + + +def output_summary( + outputs: list[dict[str, Any]], + docstring: str | None = None, +) -> list[str]: + """Produce a short summary of what the callback outputs represent.""" + if not outputs: + return ["Dash callback"] + + lines: list[str] = [] + for out in outputs: + comp_id = out["component_id"] + prop = out["property"] + comp_type = out.get("component_type") + + semantic = _OUTPUT_SEMANTICS.get((comp_type, prop)) + if semantic is None: + semantic = _OUTPUT_SEMANTICS.get((None, prop)) + + if semantic is not None: + lines.append(f"- {comp_id}.{prop}: {semantic}") + else: + lines.append(f"- {comp_id}.{prop}") + + n = len(outputs) + if n == 1: + return [lines[0].lstrip("- ")] + header = f"Returns {n} output{'s' if n > 1 else ''}:" + return [header] + lines diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py index f306042a0c..2c1646f56a 100644 --- a/dash/mcp/primitives/tools/input_schemas/__init__.py +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -1,5 +1,43 @@ -"""Stub — real implementation in a later PR.""" +"""Input schema generation for MCP tool inputSchema fields. +Mirrors ``output_schemas/`` which generates ``outputSchema``. -def get_input_schema(param): - return {} +Each source is tried in priority order. All share the same signature: +``(param: MCPInput) -> dict | None``. +""" + +from __future__ import annotations + +from typing import Any + +from dash.mcp.types import MCPInput +from .schema_callback_type_annotations import annotation_to_schema +from .schema_component_proptypes_overrides import get_override_schema +from .schema_component_proptypes import get_component_prop_schema +from .input_descriptions import get_property_description + +_SOURCES = [ + annotation_to_schema, + get_override_schema, + get_component_prop_schema, +] + + +def get_input_schema(param: MCPInput) -> dict[str, Any]: + """Return the complete JSON Schema for a callback input parameter. + + Type sources provide ``type``/``enum`` (first non-None wins). + Description is assembled by ``input_descriptions``. + """ + schema: dict[str, Any] = {} + for source in _SOURCES: + result = source(param) + if result is not None: + schema = result + break + + description = get_property_description(param) + if description: + schema = {**schema, "description": description} + + return schema diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py new file mode 100644 index 0000000000..e1d1e9f47c --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py @@ -0,0 +1,31 @@ +"""Per-property description generation for MCP tool input parameters. + +Each source shares the same signature: +``(param: MCPInput) -> list[str]`` + +Sources are tried in order from most generic to most instance-specific. +All sources that produce lines are combined. +""" + +from __future__ import annotations + +from dash.mcp.types import MCPInput +from .description_component_props import component_props_description +from .description_docstrings import docstring_prop_description +from .description_html_labels import label_description + +_SOURCES = [ + docstring_prop_description, + label_description, + component_props_description, +] + + +def get_property_description(param: MCPInput) -> str | None: + """Build a complete description string for a callback input parameter.""" + lines: list[str] = [] + if not param.get("required", True): + lines.append("Input is optional.") + for source in _SOURCES: + lines.extend(source(param)) + return "\n".join(lines) if lines else None diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py new file mode 100644 index 0000000000..6934918260 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_component_props.py @@ -0,0 +1,81 @@ +"""Generic component property descriptions. + +Generate a description for each component prop that has a value (either set +directly in the layout or by an upstream callback). +""" + +from __future__ import annotations + +from typing import Any + +from dash import get_app +from dash.mcp.types import MCPInput + +_MAX_VALUE_LENGTH = 200 + +_MCP_EXCLUDED_PROPS = {"id", "className", "style"} + +_PROP_TEMPLATES: dict[tuple[str | None, str], str] = { + ("Store", "storage_type"): ( + "storage_type: {value}. Describes how to store the value client-side" + "'memory' resets on page refresh. " + "'session' persists for the duration of this session. " + "'local' persists on disk until explicitly cleared." + ), +} + + +def component_props_description(param: MCPInput) -> list[str]: + component = param.get("component") + if component is None: + return [] + + component_id = param["component_id"] + cbmap = get_app().mcp_callback_map + prop_lines: list[str] = [] + + for prop_name in getattr(component, "_prop_names", []): + if prop_name in _MCP_EXCLUDED_PROPS: + continue + + upstream = cbmap.find_by_output(f"{component_id}.{prop_name}") + if upstream is not None and not upstream.prevents_initial_call: + value = upstream.initial_output_value(f"{component_id}.{prop_name}") + else: + value = getattr(component, prop_name, None) + tool_name = upstream.tool_name if upstream is not None else None + + if value is None and tool_name is None: + continue + + component_type = param.get("component_type") + template = _PROP_TEMPLATES.get((component_type, prop_name)) + formatted_value = ( + _truncate_large_values(value, component_id, prop_name) + if value is not None + else None + ) + + if template and formatted_value is not None: + line = template.format(value=formatted_value) + elif formatted_value is not None: + line = f"{prop_name}: {formatted_value}" + else: + line = prop_name + + if tool_name: + line += f" (can be updated by tool: `{tool_name}`)" + + prop_lines.append(line) + + if not prop_lines: + return [] + return [f"Component properties for {component_id}:"] + prop_lines + + +def _truncate_large_values(value: Any, component_id: str, prop_name: str) -> str: + text = repr(value) + if len(text) > _MAX_VALUE_LENGTH: + hint = f"Use get_dash_component('{component_id}', '{prop_name}') for the full value" + return f"{text[:_MAX_VALUE_LENGTH]}... ({hint})" + return text diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py new file mode 100644 index 0000000000..1f67c3c0f2 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_docstrings.py @@ -0,0 +1,71 @@ +"""Extract property descriptions from component class docstrings. + +Dash component classes have structured docstrings generated by +``dash-generate-components`` in the format:: + + Keyword arguments: + + - prop_name (type_string; optional): + Description text that may span + multiple lines. + +This module parses that format and returns the first sentence of the +description for a given property. +""" + +from __future__ import annotations + +import re + +from dash.mcp.types import MCPInput + +_PROP_RE = re.compile( + r"^[ ]*- (\w+) \([^)]+\):\s*\n((?:[ ]+.+\n)*)", + re.MULTILINE, +) + +_cache: dict[type, dict[str, str]] = {} + +_SENTENCE_END = re.compile(r"(?<=[.!?])\s") + + +def docstring_prop_description(param: MCPInput) -> list[str]: + component = param.get("component") + if component is None: + return [] + desc = _get_prop_description(type(component), param["property"]) + return [desc] if desc else [] + + +def _get_prop_description(cls: type, prop: str) -> str | None: + props = _parse_docstring(cls) + return props.get(prop) + + +def _parse_docstring(cls: type) -> dict[str, str]: + if cls in _cache: + return _cache[cls] + + doc = getattr(cls, "__doc__", None) + if not doc: + _cache[cls] = {} + return _cache[cls] + + props: dict[str, str] = {} + for match in _PROP_RE.finditer(doc): + prop_name = match.group(1) + raw_desc = match.group(2) + lines = [line.strip() for line in raw_desc.strip().splitlines()] + desc = " ".join(lines) + if desc: + props[prop_name] = _first_sentence(desc) + + _cache[cls] = props + return props + + +def _first_sentence(text: str) -> str: + m = _SENTENCE_END.search(text) + if m: + return text[: m.start() + 1].rstrip() + return text diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py new file mode 100644 index 0000000000..2c9cd8dea9 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_html_labels.py @@ -0,0 +1,23 @@ +"""Label-based property descriptions. + +Reads the label map from the ``CallbackAdapterCollection``, +which builds it from the layout using ``htmlFor`` and +containment associations. +""" + +from __future__ import annotations + +from dash import get_app +from dash.mcp.types import MCPInput + + +def label_description(param: MCPInput) -> list[str]: + """Return the label text for this component, if any.""" + component_id = param.get("component_id") + if not component_id: + return [] + label_map = get_app().mcp_callback_map.component_label_map + texts = label_map.get(component_id, []) + if texts: + return [f"Labeled with: {'; '.join(texts)}"] + return [] diff --git a/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py new file mode 100644 index 0000000000..aee5b17c6f --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_callback_type_annotations.py @@ -0,0 +1,67 @@ +"""Map callback function type annotations to JSON Schema. + +When a callback function has explicit type annotations, those take +priority over all other schema sources (static overrides, component +introspection). + +Unlike component annotations (where nullable means "not required"), +callback annotations preserve ``null`` in the schema type when the +user writes ``Optional[X]`` — the user is explicitly saying the +value can be null. + +Also provides ``annotation_to_json_schema``, the shared low-level +converter used by both callback and component annotation pipelines. +""" + +from __future__ import annotations + +import inspect +from typing import Any + +from pydantic import TypeAdapter + +from dash.development.base_component import Component +from dash.mcp.types import MCPInput, is_nullable + + +def annotation_to_json_schema(annotation: type) -> dict[str, Any] | None: + """Convert a Python type annotation to a JSON Schema dict. + + Returns ``None`` if the annotation cannot be translated. + """ + if annotation is inspect.Parameter.empty or annotation is type(None): + return None + + if isinstance(annotation, type) and issubclass(annotation, Component): + return {"type": "string"} + + try: + return TypeAdapter(annotation).json_schema() + except Exception: + return None + + +def annotation_to_schema(param: MCPInput) -> dict[str, Any] | None: + """Convert a callback parameter's type annotation to a JSON Schema dict. + + Returns ``None`` if the annotation is not recognised, meaning the + caller should fall through to the next schema source. + + ``Optional[X]`` produces ``{"type": ["X", "null"]}`` — the user + explicitly chose a nullable type. + """ + annotation = param.get("annotation") + if annotation is None: + return None + schema = annotation_to_json_schema(annotation) + if schema is None: + return None + + if is_nullable(annotation) and schema: + t = schema.get("type") + if isinstance(t, str): + schema = {**schema, "type": [t, "null"]} + elif isinstance(t, list) and "null" not in t: + schema = {**schema, "type": [*t, "null"]} + + return schema diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py new file mode 100644 index 0000000000..151e391cf4 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes.py @@ -0,0 +1,32 @@ +"""Derive JSON Schema from a component's ``__init__`` type annotations.""" + +from __future__ import annotations + +import inspect +from typing import Any + +from dash.mcp.types import MCPInput +from .schema_callback_type_annotations import annotation_to_json_schema + + +def get_component_prop_schema(param: MCPInput) -> dict[str, Any] | None: + """Return the JSON Schema for a component property. + + Inspects the ``__init__`` signature of the component's class. + Returns ``None`` if the prop has no annotation. + """ + component = param.get("component") + prop = param["property"] + if component is None: + return None + + try: + sig = inspect.signature(type(component).__init__) + except (ValueError, TypeError): + return None + + sig_param = sig.parameters.get(prop) + if sig_param is None or sig_param.annotation is inspect.Parameter.empty: + return None + + return annotation_to_json_schema(sig_param.annotation) diff --git a/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py new file mode 100644 index 0000000000..25086896e7 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_component_proptypes_overrides.py @@ -0,0 +1,70 @@ +"""A place to manually define Schemas that override component-defined prop types +where type generation produces insufficient results. +""" + +from __future__ import annotations + +from typing import Any + +from dash.mcp.types import MCPInput +from .schema_component_proptypes import get_component_prop_schema + +_DATE_SCHEMA = { + "type": "string", + "format": "date", + "pattern": r"^\d{4}-\d{2}-\d{2}$", +} + + +def _compute_dropdown_value_schema(param: MCPInput) -> dict[str, Any] | None: + """Dropdown values are an array if `multi=True`; scalar values otherwise.""" + schema = get_component_prop_schema(param) + if schema is None: + return None + + component = param.get("component") + t = schema.get("type") + if not isinstance(t, list): + return schema + + if getattr(component, "multi", False): + items_schema = schema.get("items", {}) + return ( + {"type": "array", "items": items_schema} + if items_schema + else {"type": "array"} + ) + + scalar_types = [x for x in t if x != "array"] + refined = dict(schema) + refined["type"] = scalar_types[0] if len(scalar_types) == 1 else scalar_types + refined.pop("items", None) + return refined + + +_OVERRIDES: dict[tuple[str, str], dict[str, Any] | callable] = { + ("DatePickerSingle", "date"): _DATE_SCHEMA, + ("DatePickerRange", "start_date"): _DATE_SCHEMA, + ("DatePickerRange", "end_date"): _DATE_SCHEMA, + # Graph — annotation says "object", we add structured properties. + ("Graph", "figure"): { + "type": "object", + "properties": { + "data": {"type": "array", "items": {"type": "object"}}, + "layout": {"type": "object"}, + "frames": {"type": "array", "items": {"type": "object"}}, + }, + }, + ("Dropdown", "value"): _compute_dropdown_value_schema, +} + + +def get_override_schema(param: MCPInput) -> dict[str, Any] | None: + """Return a schema override, or None to fall through to introspection.""" + key = (param.get("component_type"), param["property"]) + override = _OVERRIDES.get(key) + if override is None: + return None + if callable(override): + return override(param) + return dict(override) diff --git a/dash/mcp/primitives/tools/output_schemas/__init__.py b/dash/mcp/primitives/tools/output_schemas/__init__.py index d2d70c3552..41ddfd8d49 100644 --- a/dash/mcp/primitives/tools/output_schemas/__init__.py +++ b/dash/mcp/primitives/tools/output_schemas/__init__.py @@ -1,5 +1,29 @@ -"""Stub — real implementation in a later PR.""" +"""Output schema generation for MCP tool outputSchema fields. +Mirrors ``input_schemas/`` which generates ``inputSchema``. -def get_output_schema(): +Each source shares the same signature: ``() -> dict | None``. +""" + +from __future__ import annotations + +from typing import Any + +from .schema_callback_response import callback_response_schema + +_SOURCES = [ + callback_response_schema, +] + + +def get_output_schema() -> dict[str, Any]: + """Return the JSON Schema for a callback tool's output. + + Tries each source in order, returning the first non-None result. + Falls back to ``{}`` (any type). + """ + for source in _SOURCES: + schema = source() + if schema is not None: + return schema return {} diff --git a/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py b/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py new file mode 100644 index 0000000000..e61a482cba --- /dev/null +++ b/dash/mcp/primitives/tools/output_schemas/schema_callback_response.py @@ -0,0 +1,16 @@ +"""Output schema derived from CallbackDispatchResponse.""" + +from __future__ import annotations + +from typing import Any + +from pydantic import TypeAdapter + +from dash.types import CallbackDispatchResponse + +_schema = TypeAdapter(CallbackDispatchResponse).json_schema() + + +def callback_response_schema() -> dict[str, Any]: + """Return the JSON Schema for a callback dispatch response.""" + return _schema diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py index 437a71db5c..97b8d9c137 100644 --- a/tests/unit/mcp/conftest.py +++ b/tests/unit/mcp/conftest.py @@ -4,3 +4,84 @@ if sys.version_info < (3, 10): collect_ignore_glob.append("*") + +"""Shared helpers for MCP unit tests. + +These helpers work directly with Tool objects from CallbackAdapterCollection, +avoiding the MCP server so they can be used before the server is wired up. +""" + +from dash import Dash, Input, Output, html +from dash._get_app import app_context +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + +BUILTINS = {"get_dash_component"} + + +def _setup_mcp(app): + """Set up MCP for an app in tests.""" + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +def _make_app(**kwargs): + """Create a minimal Dash app with a layout and one callback.""" + app = Dash(__name__, **kwargs) + app.layout = html.Div( + [ + html.Div(id="my-input"), + html.Div(id="my-output"), + ] + ) + + @app.callback(Output("my-output", "children"), Input("my-input", "children")) + def update_output(value): + """Test callback docstring.""" + return f"echo: {value}" + + return _setup_mcp(app) + + +def _tools_list(app): + """Return tools as Tool objects via as_mcp_tools().""" + _setup_mcp(app) + with app.server.test_request_context(): + return app.mcp_callback_map.as_mcp_tools() + + +def _user_tool(tools): + """Return the first tool that isn't a builtin.""" + return next(t for t in tools if t.name not in BUILTINS) + + +def _app_with_callback(component, input_prop="value", output_id="out"): + """Create a Dash app with one callback using ``component`` as Input.""" + app = Dash(__name__) + app.layout = html.Div([component, html.Div(id=output_id)]) + + @app.callback(Output(output_id, "children"), Input(component.id, input_prop)) + def update(val): + return f"got: {val}" + + return _setup_mcp(app) + + +def _schema_for(tool, param_name=None): + """Extract the JSON schema dict for a parameter, without description.""" + props = tool.inputSchema["properties"] + if param_name is None: + param_name = next(iter(props)) + schema = dict(props[param_name]) + schema.pop("description", None) + return schema + + +def _desc_for(tool, param_name=None): + """Extract the description string for a parameter, or ''.""" + props = tool.inputSchema["properties"] + if param_name is None: + param_name = next(iter(props)) + return props[param_name].get("description", "") diff --git a/tests/unit/mcp/tools/input_schemas/__init__.py b/tests/unit/mcp/tools/input_schemas/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/mcp/tools/input_schemas/input_descriptions/__init__.py b/tests/unit/mcp/tools/input_schemas/input_descriptions/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py b/tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py new file mode 100644 index 0000000000..bc6758d2d3 --- /dev/null +++ b/tests/unit/mcp/tools/input_schemas/input_descriptions/test_descriptions.py @@ -0,0 +1,424 @@ +"""Description tests — verifies per-property description generation. + +Tests are organized by description source: +- Labels (htmlFor, containment, text extraction) +- Component-specific (date pickers, sliders) +- Options (Dropdown, RadioItems, Checklist) +- Generic props (placeholder, default value, min/max/step) +- Chained callbacks (dynamic prop/options detection) +- Combinations (label + component-specific) +""" + +import pytest + +from dash import Dash, Input, Output, dcc, html + +from tests.unit.mcp.conftest import ( + _app_with_callback, + _desc_for, + _tools_list, + _user_tool, +) + + +def _app_with_layout(layout, *inputs): + app = Dash(__name__) + app.layout = layout + + @app.callback( + Output("out", "children"), + [Input(cid, prop) for cid, prop in inputs], + ) + def update(*args): + return str(args) + + return app + + +def _tool_for(component, input_prop="value"): + app = _app_with_callback(component, input_prop=input_prop) + return _user_tool(_tools_list(app)) + + +# --------------------------------------------------------------------------- +# Labels +# --------------------------------------------------------------------------- + + +class TestLabels: + def test_html_for(self): + app = _app_with_layout( + html.Div( + [ + html.Label("Your Name", htmlFor="inp"), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Your Name" in _desc_for(tool) + + def test_html_for_not_adjacent(self): + app = _app_with_layout( + html.Div( + [ + html.Div(html.Label("Remote Label", htmlFor="inp")), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Remote Label" in _desc_for(tool) + + def test_containment(self): + app = _app_with_layout( + html.Div( + [ + html.Label( + [ + "Pick a city", + dcc.Dropdown(id="city_dd", options=["NYC", "LA"]), + ] + ), + html.Div(id="out"), + ] + ), + ("city_dd", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Pick a city" in _desc_for(tool) + + def test_deeply_nested_containment(self): + app = _app_with_layout( + html.Div( + [ + html.Label( + [ + html.Span("Nested Label"), + html.Div(dcc.Input(id="nested_inp")), + ] + ), + html.Div(id="out"), + ] + ), + ("nested_inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Nested Label" in _desc_for(tool) + + def test_both_htmlfor_and_containment_captured(self): + app = _app_with_layout( + html.Div( + [ + html.Label(["Containment Label", dcc.Input(id="inp")]), + html.Label("HtmlFor Label", htmlFor="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "HtmlFor Label" in desc + assert "Containment Label" in desc + + def test_deep_text_extraction(self): + app = _app_with_layout( + html.Div( + [ + html.Label( + html.Div(html.Span(html.B("Deep Text"))), + htmlFor="inp", + ), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + assert "Deep Text" in _desc_for(tool) + + def test_multiple_text_nodes(self): + app = _app_with_layout( + html.Div( + [ + html.Label( + [html.B("First"), " ", html.I("Second")], + htmlFor="inp", + ), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ), + ("inp", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Labeled with: First Second" in desc + + def test_unrelated_label_excluded(self): + app = _app_with_layout( + html.Div( + [ + html.Label("Other Field", htmlFor="other"), + dcc.Input(id="other"), + dcc.Input(id="target"), + html.Div(id="out"), + ] + ), + ("target", "value"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Other Field" not in (desc or "") + + +# --------------------------------------------------------------------------- +# Component-specific: date pickers +# --------------------------------------------------------------------------- + + +class TestDatePickerDescriptions: + def test_single_full_range(self): + dp = dcc.DatePickerSingle( + id="dp", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "2020-01-01" in desc + assert "2025-12-31" in desc + + def test_single_min_only(self): + dp = dcc.DatePickerSingle(id="dp", min_date_allowed="2020-01-01") + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "min_date_allowed: '2020-01-01'" in desc + + def test_single_default_date(self): + dp = dcc.DatePickerSingle(id="dp", date="2024-06-15") + desc = _desc_for(_tool_for(dp, "date"), "val") + assert "date: '2024-06-15'" in desc + + def test_range_with_constraints(self): + dpr = dcc.DatePickerRange( + id="dpr", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + desc = _desc_for(_tool_for(dpr, "start_date"), "val") + assert "2020-01-01" in desc + + +# --------------------------------------------------------------------------- +# Component-specific: sliders +# --------------------------------------------------------------------------- + + +class TestSliderDescriptions: + def test_min_max(self): + sl = dcc.Slider(id="sl", min=0, max=100) + desc = _desc_for(_tool_for(sl), "val") + assert "min: 0" in desc + assert "max: 100" in desc + + def test_step(self): + sl = dcc.Slider(id="sl", min=0, max=100, step=5) + desc = _desc_for(_tool_for(sl), "val") + assert "step: 5" in desc + + def test_default_value(self): + sl = dcc.Slider(id="sl", min=0, max=100, value=50) + desc = _desc_for(_tool_for(sl), "val") + assert "value: 50" in desc + + def test_marks(self): + sl = dcc.Slider(id="sl", min=0, max=100, marks={0: "Low", 100: "High"}) + desc = _desc_for(_tool_for(sl), "val") + assert "marks: {0: 'Low', 100: 'High'}" in desc + + def test_range_slider_min_max(self): + rs = dcc.RangeSlider(id="rs", min=0, max=100) + desc = _desc_for(_tool_for(rs), "val") + assert "min: 0" in desc + assert "max: 100" in desc + + +# --------------------------------------------------------------------------- +# Options (parametrized across Dropdown, RadioItems, Checklist) +# --------------------------------------------------------------------------- + + +_OPTIONS_COMPONENTS = [ + ("Dropdown", lambda **kw: dcc.Dropdown(id="comp", **kw), "comp"), + ("RadioItems", lambda **kw: dcc.RadioItems(id="comp", **kw), "comp"), + ("Checklist", lambda **kw: dcc.Checklist(id="comp", **kw), "comp"), +] + + +class TestOptionsDescriptions: + @pytest.mark.parametrize( + "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] + ) + def test_options_shown(self, name, factory, cid): + comp = factory(options=["X", "Y", "Z"]) + desc = _desc_for(_tool_for(comp), "val") + assert "options: ['X', 'Y', 'Z']" in desc + + @pytest.mark.parametrize( + "name,factory,cid", _OPTIONS_COMPONENTS, ids=[c[0] for c in _OPTIONS_COMPONENTS] + ) + def test_default_shown(self, name, factory, cid): + value = ["a"] if name == "Checklist" else "a" + comp = factory(options=["a", "b"], value=value) + desc = _desc_for(_tool_for(comp), "val") + assert f"value: {value!r}" in desc + + def test_dropdown_dict_options(self): + dd = dcc.Dropdown( + id="dd", + options=[ + {"label": "New York", "value": "NYC"}, + ], + ) + assert "NYC" in _desc_for(_tool_for(dd), "val") + + def test_store_storage_type_template(self): + store = dcc.Store(id="store", storage_type="session") + app = _app_with_callback(store, input_prop="data") + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool, "val") + assert ( + "storage_type: 'session'. Describes how to store the value client-side" + in desc + ) + + def test_many_options_truncated(self): + dd = dcc.Dropdown(id="big", options=[str(i) for i in range(50)], value="0") + app = _app_with_callback(dd) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool, "val") + assert "options:" in desc + assert "Use get_dash_component('big', 'options') for the full value" in desc + + +# --------------------------------------------------------------------------- +# Generic props +# --------------------------------------------------------------------------- + + +class TestGenericDescriptions: + def test_placeholder(self): + inp = dcc.Input(id="inp", placeholder="Enter your name") + assert "placeholder: 'Enter your name'" in _desc_for(_tool_for(inp), "val") + + def test_numeric_min_max(self): + inp = dcc.Input(id="inp", type="number", min=0, max=999) + desc = _desc_for(_tool_for(inp), "val") + assert "min: 0" in desc + assert "max: 999" in desc + + def test_step(self): + inp = dcc.Input(id="inp", type="number", min=0, max=100, step=0.1) + assert "step: 0.1" in _desc_for(_tool_for(inp), "val") + + def test_default_value(self): + inp = dcc.Input(id="inp", value="hello") + desc = _desc_for(_tool_for(inp), "val") + assert "value: 'hello'" in desc + + def test_non_text_type(self): + inp = dcc.Input(id="inp", type="email") + assert "type: 'email'" in _desc_for(_tool_for(inp), "val") + + def test_store_default(self): + store = dcc.Store(id="store", data={"key": "value"}) + app = _app_with_callback(store, input_prop="data") + tool = _user_tool(_tools_list(app)) + assert "data: {'key': 'value'}" in _desc_for(tool, "val") + + +# --------------------------------------------------------------------------- +# Chained callbacks +# --------------------------------------------------------------------------- + + +class TestChainedCallbacks: + def test_options_set_by_upstream(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["US", "CA"], value="US"), + dcc.Dropdown(id="city", options=[], value=None), + html.Div(id="result"), + ] + ) + + @app.callback(Output("city", "options"), Input("country", "value")) + def update_cities(country): + return ["NYC", "LA"] if country == "US" else ["Toronto"] + + @app.callback(Output("result", "children"), Input("city", "value")) + def show_city(city): + return city + + tools = _tools_list(app) + tool = next(t for t in tools if "show_city" in t.name) + desc = _desc_for(tool, "city") + assert "can be updated by tool: `update_cities`" in desc + assert "options:" in desc + + def test_value_set_by_upstream(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="source", value=""), + html.Div(id="derived", children=""), + html.Div(id="result"), + ] + ) + + @app.callback(Output("derived", "children"), Input("source", "value")) + def compute_derived(val): + return f"derived: {val}" + + @app.callback(Output("result", "children"), Input("derived", "children")) + def use_derived(val): + return val + + tools = _tools_list(app) + tool = next(t for t in tools if "use_derived" in t.name) + desc = _desc_for(tool, "val") + assert "can be updated by tool: `compute_derived`" in desc + + +# --------------------------------------------------------------------------- +# Combinations +# --------------------------------------------------------------------------- + + +class TestCombinations: + def test_label_with_date_picker(self): + dp = dcc.DatePickerSingle( + id="dp", + min_date_allowed="2020-01-01", + max_date_allowed="2025-12-31", + ) + app = _app_with_layout( + html.Div( + [ + html.Label("Departure Date", htmlFor="dp"), + dp, + html.Div(id="out"), + ] + ), + ("dp", "date"), + ) + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Departure Date" in desc + assert "2020-01-01" in desc diff --git a/tests/unit/mcp/tools/input_schemas/test_input_schemas.py b/tests/unit/mcp/tools/input_schemas/test_input_schemas.py new file mode 100644 index 0000000000..5350bd955e --- /dev/null +++ b/tests/unit/mcp/tools/input_schemas/test_input_schemas.py @@ -0,0 +1,331 @@ +"""Input schema tests — verifies JSON Schema generation for component properties. + +Tests are organized by concern: +- Static overrides (date pickers, graph, interval, sliders) +- Component introspection (representative samples — full type coverage in test_json_prop_typing) +- Callback annotation overrides (highest priority) +- Required/nullable behavior +""" + +import pytest +from typing import Optional + +from dash import Dash, Input, Output, State, dcc, html + +from tests.unit.mcp.conftest import ( + _app_with_callback, + _schema_for, + _tools_list, + _user_tool, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_schema(component_type, prop): + _factories = { + "DatePickerSingle": lambda: dcc.DatePickerSingle(id="dp"), + "DatePickerRange": lambda: dcc.DatePickerRange(id="dpr"), + "Graph": lambda: dcc.Graph(id="graph"), + "Interval": lambda: dcc.Interval(id="intv"), + "Input": lambda: dcc.Input(id="inp"), + "Textarea": lambda: dcc.Textarea(id="ta"), + "Tabs": lambda: dcc.Tabs(id="tabs"), + "Dropdown": lambda: dcc.Dropdown(id="dd"), + "RadioItems": lambda: dcc.RadioItems(id="ri"), + "Checklist": lambda: dcc.Checklist(id="cl"), + "Store": lambda: dcc.Store(id="store"), + "Upload": lambda: dcc.Upload(id="upload"), + "Slider": lambda: dcc.Slider(id="sl"), + "RangeSlider": lambda: dcc.RangeSlider(id="rs"), + } + app = _app_with_callback(_factories[component_type](), input_prop=prop) + tool = _user_tool(_tools_list(app)) + return _schema_for(tool) + + +# --------------------------------------------------------------------------- +# Static overrides take priority over introspection +# --------------------------------------------------------------------------- + + +class TestStaticOverrides: + """Verify that overrides win over component introspection.""" + + def test_override_beats_introspection(self): + schema = _get_schema("DatePickerSingle", "date") + # Introspection would return None for this prop; + # override provides a date format with pattern + assert schema["type"] == "string" + assert schema["format"] == "date" + assert "pattern" in schema + + +# --------------------------------------------------------------------------- +# Introspection — representative samples (not exhaustive per-component) +# --------------------------------------------------------------------------- + +INTROSPECTION_CASES = [ + # (component_type, prop, expected_schema) — one per distinct type shape + ( + "Input", + "value", + {"anyOf": [{"type": "string"}, {"type": "number"}, {"type": "null"}]}, + ), + ( + "Input", + "disabled", + { + "anyOf": [ + {"type": "boolean"}, + {"const": "disabled", "type": "string"}, + {"const": "DISABLED", "type": "string"}, + {"type": "null"}, + ] + }, + ), + ("Input", "n_submit", {"anyOf": [{"type": "number"}, {"type": "null"}]}), + ( + "Dropdown", + "value", + { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + { + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + ] + }, + "type": "array", + }, + {"type": "null"}, + ] + }, + ), + ("Dropdown", "options", {"anyOf": [{}, {"type": "null"}]}), + ( + "Checklist", + "value", + { + "anyOf": [ + { + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "number"}, + {"type": "boolean"}, + ] + }, + "type": "array", + }, + {"type": "null"}, + ] + }, + ), + ( + "Store", + "data", + { + "anyOf": [ + {"additionalProperties": True, "type": "object"}, + {"items": {}, "type": "array"}, + {"type": "number"}, + {"type": "string"}, + {"type": "boolean"}, + {"type": "null"}, + ] + }, + ), + ( + "Upload", + "contents", + { + "anyOf": [ + {"type": "string"}, + {"items": {"type": "string"}, "type": "array"}, + {"type": "null"}, + ] + }, + ), + ( + "RangeSlider", + "value", + {"anyOf": [{"items": {"type": "number"}, "type": "array"}, {"type": "null"}]}, + ), + ("Tabs", "value", {"anyOf": [{"type": "string"}, {"type": "null"}]}), +] + + +class TestIntrospection: + """Representative introspection tests — full type coverage in test_json_prop_typing.""" + + @pytest.mark.parametrize( + "component_type,prop,expected", + INTROSPECTION_CASES, + ids=[f"{c}.{p}" for c, p, _ in INTROSPECTION_CASES], + ) + def test_introspected_schema(self, component_type, prop, expected): + assert _get_schema(component_type, prop) == expected + + +# --------------------------------------------------------------------------- +# Callback annotation overrides +# --------------------------------------------------------------------------- + + +def _app_with_annotated_callback(annotation_type, input_prop="disabled"): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + if annotation_type is None: + + @app.callback(Output("out", "children"), Input("inp", input_prop)) + def update(val): + return str(val) + + else: + + @app.callback(Output("out", "children"), Input("inp", input_prop)) + def update(val: annotation_type): + return str(val) + + return app + + +ANNOTATION_CASES = [ + (str, "disabled", {"type": "string"}), + (int, "value", {"type": "integer"}), + (float, "value", {"type": "number"}), + (bool, "value", {"type": "boolean"}), + (list, "value", {"items": {}, "type": "array"}), + (dict, "value", {"additionalProperties": True, "type": "object"}), + (Optional[int], "value", {"anyOf": [{"type": "integer"}, {"type": "null"}]}), + (Optional[str], "value", {"anyOf": [{"type": "string"}, {"type": "null"}]}), +] + + +class TestAnnotationOverrides: + """Callback type annotations override component schemas.""" + + @pytest.mark.parametrize( + "ann,prop,expected", + ANNOTATION_CASES, + ids=[ + f"{a.__name__ if hasattr(a, '__name__') else a}-{p}" + for a, p, _ in ANNOTATION_CASES + ], + ) + def test_annotation(self, ann, prop, expected): + app = _app_with_annotated_callback(ann, input_prop=prop) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == expected + + def test_no_annotation_uses_introspection(self): + app = _app_with_annotated_callback(None) + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == { + "anyOf": [ + {"type": "boolean"}, + {"const": "disabled", "type": "string"}, + {"const": "DISABLED", "type": "string"}, + {"type": "null"}, + ] + } + + +class TestAnnotationNullability: + """Annotations control nullable vs non-nullable schemas.""" + + def test_str_removes_null(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: str): + return val + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == {"type": "string"} + + def test_optional_preserves_null(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: Optional[str]): + return val or "" + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == { + "anyOf": [{"type": "string"}, {"type": "null"}] + } + + def test_optional_param_not_required(self): + app = Dash(__name__) + app.layout = html.Div([dcc.Dropdown(id="dd"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val: Optional[str]): + return val or "" + + tool = _user_tool(_tools_list(app)) + assert "val" not in tool.inputSchema.get("required", []) + + +class TestAnnotationWithState: + """Annotations work for State parameters too.""" + + def test_state_annotation_overrides(self): + app = Dash(__name__) + app.layout = html.Div( + [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] + ) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + State("store", "data"), + ) + def update(val: str, data: dict): + return str(val) + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == {"type": "string"} + assert _schema_for(tool, "data") == { + "additionalProperties": True, + "type": "object", + } + + def test_partial_annotations(self): + app = Dash(__name__) + app.layout = html.Div( + [dcc.Input(id="inp"), dcc.Store(id="store"), html.Div(id="out")] + ) + + @app.callback( + Output("out", "children"), + Input("inp", "value"), + State("store", "data"), + ) + def update(val: int, data): + return str(val) + + tool = _user_tool(_tools_list(app)) + assert _schema_for(tool, "val") == {"type": "integer"} + assert _schema_for(tool, "data") == { + "anyOf": [ + {"additionalProperties": True, "type": "object"}, + {"items": {}, "type": "array"}, + {"type": "number"}, + {"type": "string"}, + {"type": "boolean"}, + {"type": "null"}, + ] + } diff --git a/tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py b/tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py new file mode 100644 index 0000000000..10b6ae5543 --- /dev/null +++ b/tests/unit/mcp/tools/input_schemas/test_schema_component_proptypes.py @@ -0,0 +1,15 @@ +"""Tests for schema_component_proptypes. + +Only tests our custom logic — pydantic's type-to-schema conversion +is tested by pydantic itself. +""" + +from dash.development.base_component import Component +from dash.mcp.primitives.tools.input_schemas.schema_callback_type_annotations import ( + annotation_to_json_schema, +) + + +class TestComponentTypes: + def test_component_type_maps_to_string(self): + assert annotation_to_json_schema(Component) == {"type": "string"} diff --git a/tests/unit/mcp/tools/test_callback_adapter.py b/tests/unit/mcp/tools/test_callback_adapter.py index 91808d304e..dc3fc041fc 100644 --- a/tests/unit/mcp/tools/test_callback_adapter.py +++ b/tests/unit/mcp/tools/test_callback_adapter.py @@ -1,8 +1,9 @@ """Tests for CallbackAdapter.""" import pytest -from dash import Dash, Input, Output, dcc, html +from dash import Dash, Input, Output, State, dcc, html from dash._get_app import app_context +from mcp.types import Tool from dash.mcp.primitives.tools.callback_adapter_collection import ( CallbackAdapterCollection, @@ -35,6 +36,68 @@ def update(val): return app +@pytest.fixture +def multi_output_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [], val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def state_app(): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("btn", "n_clicks"), + State("inp", "value"), + ) + def update(clicks, val): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +@pytest.fixture +def typed_app(): + app = Dash(__name__) + app.layout = html.Div([dcc.Input(id="inp"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val: str): + return val + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + @pytest.fixture def duplicate_names_app(): app = Dash(__name__) @@ -131,6 +194,52 @@ def test_duplicates_get_unique_names(self, duplicate_names_app): assert names[0] != names[1] +class TestTool: + def test_returns_tool_instance(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert isinstance(tool, Tool) + assert tool.name == "update" + + def test_description_includes_docstring(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "Update output." in tool.description + + def test_description_includes_output_target(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "out.children" in tool.description + + def test_param_name_from_function_signature(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "val" in tool.inputSchema["properties"] + + def test_param_has_label_description(self, simple_app): + with simple_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + desc = tool.inputSchema["properties"]["val"].get("description", "") + assert "Your Name" in desc + + def test_state_params_included(self, state_app): + with state_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + props = tool.inputSchema["properties"] + assert set(props.keys()) == {"clicks", "val"} + + def test_multi_output_description(self, multi_output_app): + with multi_output_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert "dd2.options" in tool.description + assert "out.children" in tool.description + + def test_typed_annotation_narrows_schema(self, typed_app): + with typed_app.server.test_request_context(): + tool = app_context.get().mcp_callback_map[0].as_mcp_tool + assert tool.inputSchema["properties"]["val"]["type"] == "string" + + class TestGetInitialValue: def test_returns_layout_value(self, simple_app): callback_map = app_context.get().mcp_callback_map @@ -225,3 +334,59 @@ def update(val): app_context.set(app) app.mcp_callback_map = CallbackAdapterCollection(app) assert app.mcp_callback_map[0].is_valid + + +class TestNoInfiniteLoop: + @pytest.mark.timeout(5) + def test_initial_output_does_not_loop(self): + """Building a tool must not trigger infinite re-entry in _initial_output.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Slider(id="sl", min=0, max=10, value=5), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("sl", "value")) + def show(value): + return f"Value: {value}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + tool = app.mcp_callback_map[0].as_mcp_tool + assert tool.name == "show" + + @pytest.mark.timeout(5) + def test_chained_callbacks_do_not_loop(self): + """Chained callbacks with initial value resolution must not loop.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Slider(id="sl", min=0, max=10, value=5), + dcc.Slider(id="sl2", min=0, max=10), + html.Div(id="out"), + ] + ) + + @app.callback(Output("sl2", "value"), Input("sl", "value")) + def sync(v): + return v + + @app.callback( + Output("out", "children"), + Input("sl", "value"), + Input("sl2", "value"), + ) + def show(v1, v2): + return f"{v1} + {v2}" + + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + + with app.server.test_request_context(): + for cb in app.mcp_callback_map: + tool = cb.as_mcp_tool + assert tool.name is not None diff --git a/tests/unit/mcp/tools/test_tool_schema.py b/tests/unit/mcp/tools/test_tool_schema.py new file mode 100644 index 0000000000..49b639834c --- /dev/null +++ b/tests/unit/mcp/tools/test_tool_schema.py @@ -0,0 +1,64 @@ +"""Tool schema tests — what a Dash MCP tool looks like. + +The EXPECTED_TOOL dict below is the canonical reference for the shape of +a callback-generated MCP tool. It doubles as human-readable documentation +and as a test fixture. + +Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools +""" + +from tests.unit.mcp.conftest import ( + _make_app, + _tools_list, + _user_tool, +) + +from pydantic import TypeAdapter +from dash.development.base_component import Component +from dash.types import CallbackDispatchResponse + +_DASH_COMPONENT_SCHEMA = TypeAdapter(Component).json_schema() + +EXPECTED_TOOL = { + "name": "update_output", + "description": ( + "my-output.children: Returns content\n" "\n" "Test callback docstring." + ), + "inputSchema": { + "type": "object", + "properties": { + "value": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "number"}, + _DASH_COMPONENT_SCHEMA, + { + "items": { + "anyOf": [ + {"type": "string"}, + {"type": "integer"}, + {"type": "number"}, + _DASH_COMPONENT_SCHEMA, + {"type": "null"}, + ] + }, + "type": "array", + }, + {"type": "null"}, + ], + "description": "Input is optional.\nThe children of this component.", + }, + }, + }, + "outputSchema": TypeAdapter(CallbackDispatchResponse).json_schema(), +} + + +class TestToolSchema: + """Verify that the generated tool matches EXPECTED_TOOL exactly.""" + + def test_full_tool(self): + """The entire tool dict matches the expected shape.""" + tool = _user_tool(_tools_list(_make_app())) + assert tool.model_dump(exclude_none=True) == EXPECTED_TOOL From cac68555c37709e894d6f842414d343909a942e7 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 8 Apr 2026 10:51:47 -0600 Subject: [PATCH 19/27] Refactor description sources to accept CallbackAdapter instances --- dash/mcp/primitives/tools/callback_adapter.py | 2 +- .../primitives/tools/descriptions/__init__.py | 16 +++++++++------- .../tools/descriptions/description_docstring.py | 11 ++++++----- .../tools/descriptions/description_outputs.py | 11 ++++++----- 4 files changed, 22 insertions(+), 18 deletions(-) diff --git a/dash/mcp/primitives/tools/callback_adapter.py b/dash/mcp/primitives/tools/callback_adapter.py index 7fd528dd01..98fa66aacb 100644 --- a/dash/mcp/primitives/tools/callback_adapter.py +++ b/dash/mcp/primitives/tools/callback_adapter.py @@ -145,7 +145,7 @@ def prevents_initial_call(self) -> bool: @cached_property def _description(self) -> str: - return build_tool_description(self.outputs, self._docstring) + return build_tool_description(self) @cached_property def _input_schema(self) -> dict[str, Any]: diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py index d464677251..29cc2840d0 100644 --- a/dash/mcp/primitives/tools/descriptions/__init__.py +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -1,7 +1,7 @@ """Tool-level description generation for MCP tools. Each source shares the same signature: -``(outputs, docstring) -> list[str]`` +``(adapter: CallbackAdapter) -> list[str]`` This is distinct from per-parameter descriptions (in ``input_schemas/input_descriptions/``) which populate @@ -10,23 +10,25 @@ from __future__ import annotations -from typing import Any +from __future__ import annotations + +from typing import TYPE_CHECKING from .description_docstring import callback_docstring from .description_outputs import output_summary +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + _SOURCES = [ output_summary, callback_docstring, ] -def build_tool_description( - outputs: list[dict[str, Any]], - docstring: str | None = None, -) -> str: +def build_tool_description(adapter: CallbackAdapter) -> str: """Build a human-readable description for an MCP tool.""" lines: list[str] = [] for source in _SOURCES: - lines.extend(source(outputs, docstring)) + lines.extend(source(adapter)) return "\n".join(lines) if lines else "Dash callback" diff --git a/dash/mcp/primitives/tools/descriptions/description_docstring.py b/dash/mcp/primitives/tools/descriptions/description_docstring.py index 71cf4d3d5a..21dbeed804 100644 --- a/dash/mcp/primitives/tools/descriptions/description_docstring.py +++ b/dash/mcp/primitives/tools/descriptions/description_docstring.py @@ -2,14 +2,15 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter -def callback_docstring( - outputs: list[dict[str, Any]], - docstring: str | None = None, -) -> list[str]: + +def callback_docstring(adapter: CallbackAdapter) -> list[str]: """Return the callback's docstring as description lines.""" + docstring = adapter._docstring if docstring: return ["", docstring.strip()] return [] diff --git a/dash/mcp/primitives/tools/descriptions/description_outputs.py b/dash/mcp/primitives/tools/descriptions/description_outputs.py index c174c177f4..986344c75c 100644 --- a/dash/mcp/primitives/tools/descriptions/description_outputs.py +++ b/dash/mcp/primitives/tools/descriptions/description_outputs.py @@ -2,7 +2,10 @@ from __future__ import annotations -from typing import Any +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter _OUTPUT_SEMANTICS: dict[tuple[str | None, str], str] = { ("Graph", "figure"): "Returns chart/visualization data", @@ -26,11 +29,9 @@ } -def output_summary( - outputs: list[dict[str, Any]], - docstring: str | None = None, -) -> list[str]: +def output_summary(adapter: CallbackAdapter) -> list[str]: """Produce a short summary of what the callback outputs represent.""" + outputs = adapter.outputs if not outputs: return ["Dash callback"] From 37572429abc9595bd4b6e0898b3bda1440958a3e Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 8 Apr 2026 12:12:17 -0600 Subject: [PATCH 20/27] Fix pylint error --- tests/unit/mcp/conftest.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py index 97b8d9c137..83f6e5378c 100644 --- a/tests/unit/mcp/conftest.py +++ b/tests/unit/mcp/conftest.py @@ -1,21 +1,17 @@ +"""Shared helpers for MCP unit tests.""" + import sys -collect_ignore_glob = [] +from dash import Dash, Input, Output, html +from dash._get_app import app_context +collect_ignore_glob = [] if sys.version_info < (3, 10): collect_ignore_glob.append("*") - -"""Shared helpers for MCP unit tests. - -These helpers work directly with Tool objects from CallbackAdapterCollection, -avoiding the MCP server so they can be used before the server is wired up. -""" - -from dash import Dash, Input, Output, html -from dash._get_app import app_context -from dash.mcp.primitives.tools.callback_adapter_collection import ( - CallbackAdapterCollection, -) +else: + from dash.mcp.primitives.tools.callback_adapter_collection import ( # pylint: disable=wrong-import-position + CallbackAdapterCollection, + ) BUILTINS = {"get_dash_component"} From 7be911dffd24252a01a1210f24b8e76f2ffe3afa Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 14:05:37 -0600 Subject: [PATCH 21/27] Add pattern-matching callback support to input schemas and descriptions --- .../tools/input_schemas/__init__.py | 2 + .../input_descriptions/__init__.py | 2 + .../description_pattern_matching.py | 69 ++++++++++ .../input_schemas/schema_pattern_matching.py | 84 ++++++++++++ .../input_schemas/test_pattern_matching.py | 121 ++++++++++++++++++ 5 files changed, 278 insertions(+) create mode 100644 dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py create mode 100644 dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py create mode 100644 tests/unit/mcp/tools/input_schemas/test_pattern_matching.py diff --git a/dash/mcp/primitives/tools/input_schemas/__init__.py b/dash/mcp/primitives/tools/input_schemas/__init__.py index 2c1646f56a..67a56cfcfd 100644 --- a/dash/mcp/primitives/tools/input_schemas/__init__.py +++ b/dash/mcp/primitives/tools/input_schemas/__init__.py @@ -14,9 +14,11 @@ from .schema_callback_type_annotations import annotation_to_schema from .schema_component_proptypes_overrides import get_override_schema from .schema_component_proptypes import get_component_prop_schema +from .schema_pattern_matching import get_pattern_matching_schema from .input_descriptions import get_property_description _SOURCES = [ + get_pattern_matching_schema, annotation_to_schema, get_override_schema, get_component_prop_schema, diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py index e1d1e9f47c..9e94737293 100644 --- a/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/__init__.py @@ -13,11 +13,13 @@ from .description_component_props import component_props_description from .description_docstrings import docstring_prop_description from .description_html_labels import label_description +from .description_pattern_matching import pattern_matching_description _SOURCES = [ docstring_prop_description, label_description, component_props_description, + pattern_matching_description, ] diff --git a/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py new file mode 100644 index 0000000000..53bd92d7a0 --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/input_descriptions/description_pattern_matching.py @@ -0,0 +1,69 @@ +"""Description for pattern-matching callback inputs. + +Explains that the input corresponds to a pattern-matching callback +(ALL, MATCH, ALLSMALLER) and describes the expected format. +See: https://dash.plotly.com/pattern-matching-callbacks +""" + +from __future__ import annotations + +import json + +from dash.dependencies import Wildcard +from dash.mcp.types import MCPInput + +_WILDCARD_VALUES = frozenset(w.value for w in Wildcard) + + +def pattern_matching_description(param: MCPInput) -> list[str]: + """Describe pattern-matching behavior for wildcard inputs.""" + dep_id = _parse_dep_id(param["component_id"]) + if dep_id is None: + return [] + + wildcard_key, wildcard_type = _find_wildcard(dep_id) + if wildcard_key is None: + return [] + + non_wildcard = {k: v for k, v in dep_id.items() if k != wildcard_key} + pattern_desc = ", ".join(f'{k}="{v}"' for k, v in non_wildcard.items()) + prop = param["property"] + + wildcard_descriptions = { + "ALL": ( + f"Pattern-matching input (ALL): provide an array of `{prop}` values, " + f"one per component matching {{{pattern_desc}}}. " + f"All matching components are included." + ), + "MATCH": ( + f"Pattern-matching input (MATCH): provide the `{prop}` value " + f"for the specific component matching {{{pattern_desc}}} " + f"that triggered this callback." + ), + "ALLSMALLER": ( + f"Pattern-matching input (ALLSMALLER): provide an array of `{prop}` values " + f"from components matching {{{pattern_desc}}} " + f"whose `{wildcard_key}` is smaller than the triggering component's `{wildcard_key}`." + ), + } + + desc = wildcard_descriptions.get(wildcard_type) + return [desc] if desc else [] + + +def _parse_dep_id(component_id: str) -> dict | None: + if not component_id.startswith("{"): + return None + try: + return json.loads(component_id) + except (json.JSONDecodeError, ValueError): + return None + + +def _find_wildcard(dep_id: dict) -> tuple[str | None, str | None]: + """Return (key, wildcard_type) for the first wildcard found.""" + for key, value in dep_id.items(): + if isinstance(value, list) and len(value) == 1: + if value[0] in _WILDCARD_VALUES: + return key, value[0] + return None, None diff --git a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py new file mode 100644 index 0000000000..68a06def3f --- /dev/null +++ b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py @@ -0,0 +1,84 @@ +"""Schema for pattern-matching callback inputs (ALL, MATCH, ALLSMALLER). + +When a callback input uses a wildcard ID, the callback receives a +list of values — one per matching component. This source detects +wildcard IDs and produces an array schema. If matching components +exist in the layout, the item type is inferred from a concrete match. +""" + +from __future__ import annotations + +import json +from typing import Any + +from dash.layout import find_matching_components, _WILDCARD_VALUES +from dash.mcp.types import MCPInput + + +def get_pattern_matching_schema(param: MCPInput) -> dict[str, Any] | None: + """Return a schema for pattern-matching inputs. + + For ALL/ALLSMALLER: array of ``{id, property, value}`` objects. + For MATCH: a single ``{id, property, value}`` object. + """ + dep_id = _parse_dep_id(param["component_id"]) + if dep_id is None: + return None + + wildcard_type = _get_wildcard_type(dep_id) + if wildcard_type is None: + return None + + value_schema = _infer_value_schema(param) + + item_schema: dict[str, Any] = { + "type": "object", + "properties": { + "id": {"type": "object"}, + "property": {"type": "string"}, + "value": value_schema or {}, + }, + "required": ["id", "property", "value"], + } + + if wildcard_type == "MATCH": + return item_schema + + return {"type": "array", "items": item_schema} + + +def _parse_dep_id(component_id: str) -> dict | None: + if not component_id.startswith("{"): + return None + try: + return json.loads(component_id) + except (json.JSONDecodeError, ValueError): + return None + + +def _get_wildcard_type(dep_id: dict) -> str | None: + """Return the wildcard type (ALL, MATCH, ALLSMALLER) or None.""" + for value in dep_id.values(): + if isinstance(value, list) and len(value) == 1: + if value[0] in _WILDCARD_VALUES: + return value[0] + return None + + +def _infer_value_schema(param: MCPInput) -> dict[str, Any] | None: + """Infer the JSON Schema for the ``value`` field from a matching component.""" + matches = find_matching_components(_parse_dep_id(param["component_id"])) + if not matches: + return None + + from . import get_input_schema + + concrete_param: MCPInput = { + **param, + "component": matches[0], + "component_id": str(getattr(matches[0], "id", "")), + "component_type": getattr(matches[0], "_type", None), + } + schema = get_input_schema(concrete_param) + schema.pop("description", None) + return schema or None diff --git a/tests/unit/mcp/tools/input_schemas/test_pattern_matching.py b/tests/unit/mcp/tools/input_schemas/test_pattern_matching.py new file mode 100644 index 0000000000..37b8a642ee --- /dev/null +++ b/tests/unit/mcp/tools/input_schemas/test_pattern_matching.py @@ -0,0 +1,121 @@ +"""Tests for pattern-matching schema and description generation.""" + +from dash import Dash, html, Input, Output, ALL, MATCH + +from tests.unit.mcp.conftest import _tools_list, _user_tool, _schema_for, _desc_for + + +class TestPatternMatchingSchema: + def test_all_produces_array_schema(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}, children="A"), + html.Div(id={"type": "item", "index": 1}, children="B"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": ALL}, "children"), + ) + def combine(values): + return ", ".join(values) + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "array" + assert schema["items"]["type"] == "object" + assert "value" in schema["items"]["properties"] + + def test_match_produces_object_schema(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}, children="A"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": MATCH}, "children"), + ) + def echo(value): + return value + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "object" + assert "value" in schema["properties"] + + def test_annotation_narrows_value_schema(self): + from dash import dcc + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id={"type": "filter", "index": 0}, options=["a", "b"]), + dcc.Dropdown(id={"type": "filter", "index": 1}, options=["c", "d"]), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "filter", "index": ALL}, "options"), + ) + def combine(options: list[str]): + return str(options) + + tool = _user_tool(_tools_list(app)) + schema = _schema_for(tool) + assert schema["type"] == "array" + value_schema = schema["items"]["properties"]["value"] + # Annotation narrows value to list[str] instead of the broad introspected type + assert value_schema == {"items": {"type": "string"}, "type": "array"} + + +class TestPatternMatchingDescription: + def test_all_description(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": ALL}, "children"), + ) + def combine(values): + return str(values) + + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Pattern-matching input (ALL)" in desc + assert 'type="item"' in desc + + def test_match_description(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id={"type": "item", "index": 0}), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input({"type": "item", "index": MATCH}, "children"), + ) + def echo(value): + return value + + tool = _user_tool(_tools_list(app)) + desc = _desc_for(tool) + assert "Pattern-matching input (MATCH)" in desc + assert 'type="item"' in desc From ed469f59751cfd70e4b51662534d8332a9ff2216 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 8 Apr 2026 12:49:11 -0600 Subject: [PATCH 22/27] Fix import in tests --- .../primitives/tools/input_schemas/schema_pattern_matching.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py index 68a06def3f..bc910dca43 100644 --- a/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py +++ b/dash/mcp/primitives/tools/input_schemas/schema_pattern_matching.py @@ -11,7 +11,7 @@ import json from typing import Any -from dash.layout import find_matching_components, _WILDCARD_VALUES +from dash._layout_utils import find_matching_components, _WILDCARD_VALUES from dash.mcp.types import MCPInput From 323b3d6c05628fcbe8edd71522cdd48c5cd6de43 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 14:23:37 -0600 Subject: [PATCH 23/27] Add result formatters for Plotly figures and tabular data --- dash/mcp/primitives/tools/results/__init__.py | 52 ++++++++++ .../tools/results/result_dataframe.py | 55 +++++++++++ .../tools/results/result_plotly_figure.py | 52 ++++++++++ .../tools/results/test_callback_response.py | 98 +++++++++++++++++++ .../unit/mcp/tools/results/test_dataframe.py | 63 ++++++++++++ .../mcp/tools/results/test_plotly_figure.py | 55 +++++++++++ 6 files changed, 375 insertions(+) create mode 100644 dash/mcp/primitives/tools/results/__init__.py create mode 100644 dash/mcp/primitives/tools/results/result_dataframe.py create mode 100644 dash/mcp/primitives/tools/results/result_plotly_figure.py create mode 100644 tests/unit/mcp/tools/results/test_callback_response.py create mode 100644 tests/unit/mcp/tools/results/test_dataframe.py create mode 100644 tests/unit/mcp/tools/results/test_plotly_figure.py diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py new file mode 100644 index 0000000000..e2f91a67a8 --- /dev/null +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -0,0 +1,52 @@ +"""Tool result formatting for MCP tools/call responses. + +Each result formatter shares the same signature: +``(output: MCPOutput, value: Any) -> list[TextContent | ImageContent]`` + +Formatters decide for themselves whether they care about a given output. +The structuredContent is always the full dispatch response. +""" + +from __future__ import annotations + +import json +from typing import Any + +from mcp.types import CallToolResult, TextContent + +from dash.types import CallbackDispatchResponse +from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + +from .result_dataframe import dataframe_result +from .result_plotly_figure import plotly_figure_result + +_RESULT_FORMATTERS = [ + plotly_figure_result, + dataframe_result, +] + + +def format_callback_response( + response: CallbackDispatchResponse, + callback: CallbackAdapter, +) -> CallToolResult: + """Format a dispatch response as a CallToolResult. + + The response is always returned as structuredContent. Result + formatters are called per output property and may add additional + content items (images, markdown, etc.). + """ + content: list[Any] = [ + TextContent(type="text", text=json.dumps(response, default=str)), + ] + + resp = response.get("response") or {} + for callback_output in callback.outputs: + value = resp.get(callback_output["component_id"], {}).get(callback_output["property"]) + for result_fn in _RESULT_FORMATTERS: + content.extend(result_fn(callback_output, value)) + + return CallToolResult( + content=content, + structuredContent=response, + ) diff --git a/dash/mcp/primitives/tools/results/result_dataframe.py b/dash/mcp/primitives/tools/results/result_dataframe.py new file mode 100644 index 0000000000..652c31589b --- /dev/null +++ b/dash/mcp/primitives/tools/results/result_dataframe.py @@ -0,0 +1,55 @@ +"""Tabular data result: render as a markdown table. + +Detects tabular output by component type and prop name: +- DataTable.data +- AgGrid.rowData +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import TextContent + +from dash.mcp.types import MCPOutput + +MAX_ROWS = 50 + +_TABULAR_PROPS = { + ("DataTable", "data"), + ("AgGrid", "rowData"), +} + + +def _to_markdown_table(rows: list[dict], max_rows: int = MAX_ROWS) -> str: + """Render a list of row dicts as a markdown table.""" + columns = list(rows[0].keys()) + total = len(rows) + + lines: list[str] = [] + lines.append(f"*{total} rows \u00d7 {len(columns)} columns*") + lines.append("") + lines.append("| " + " | ".join(columns) + " |") + lines.append("| " + " | ".join("---" for _ in columns) + " |") + + for row in rows[:max_rows]: + cells = [ + str(row.get(col, "")).replace("|", "\\|").replace("\n", " ") + for col in columns + ] + lines.append("| " + " | ".join(cells) + " |") + + if total > max_rows: + lines.append(f"\n(\u2026 {total - max_rows} more rows)") + + return "\n".join(lines) + + +def dataframe_result(callback_output: MCPOutput, callback_output_value: Any) -> list: + """Produce a markdown table for tabular component output values.""" + key = (callback_output.get("component_type"), callback_output.get("property")) + if key not in _TABULAR_PROPS: + return [] + if not isinstance(callback_output_value, list) or not callback_output_value or not isinstance(callback_output_value[0], dict): + return [] + return [TextContent(type="text", text=_to_markdown_table(callback_output_value))] diff --git a/dash/mcp/primitives/tools/results/result_plotly_figure.py b/dash/mcp/primitives/tools/results/result_plotly_figure.py new file mode 100644 index 0000000000..d837e17eed --- /dev/null +++ b/dash/mcp/primitives/tools/results/result_plotly_figure.py @@ -0,0 +1,52 @@ +"""Plotly figure tool result: rendered image.""" + +from __future__ import annotations + +import base64 +import logging +from typing import Any + +from mcp.types import ImageContent + +from dash.mcp.types import MCPOutput + +logger = logging.getLogger(__name__) + +IMAGE_WIDTH = 700 +IMAGE_HEIGHT = 450 + + +def _render_image(figure: Any) -> ImageContent | None: + """Render the figure as a base64 PNG ImageContent. + + Returns None if kaleido is not installed. + """ + try: + img_bytes = figure.to_image( + format="png", + width=IMAGE_WIDTH, + height=IMAGE_HEIGHT, + ) + except (ValueError, ImportError): + logger.debug("MCP: kaleido not available, skipping image render") + return None + + b64 = base64.b64encode(img_bytes).decode("ascii") + return ImageContent(type="image", data=b64, mimeType="image/png") + + +def plotly_figure_result(callback_output: MCPOutput, callback_output_value: Any) -> list: + """Produce a rendered PNG for Graph.figure output values.""" + if callback_output.get("component_type") != "Graph" or callback_output.get("property") != "figure": + return [] + if not isinstance(callback_output_value, dict): + return [] + + try: + import plotly.graph_objects as go + except ImportError: + return [] + + fig = go.Figure(callback_output_value) + image = _render_image(fig) + return [image] if image is not None else [] diff --git a/tests/unit/mcp/tools/results/test_callback_response.py b/tests/unit/mcp/tools/results/test_callback_response.py new file mode 100644 index 0000000000..ff8cca5e20 --- /dev/null +++ b/tests/unit/mcp/tools/results/test_callback_response.py @@ -0,0 +1,98 @@ +"""Tests for the callback response formatter.""" + +from unittest.mock import Mock + +from dash.mcp.primitives.tools.results import format_callback_response + + +def _mock_callback(outputs=None): + cb = Mock() + cb.outputs = outputs or [] + return cb + + +class TestFormatCallbackResponse: + def test_wraps_as_structured_content(self): + response = { + "multi": True, + "response": {"out": {"children": "hello"}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent == response + + def test_content_has_json_text_fallback(self): + """Per MCP spec, structuredContent SHOULD include a TextContent fallback.""" + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert len(result.content) >= 1 + assert result.content[0].type == "text" + assert '"multi": true' in result.content[0].text + + def test_is_error_defaults_false(self): + response = {"multi": True, "response": {}} + result = format_callback_response(response, _mock_callback()) + assert result.isError is False + + def test_preserves_side_update(self): + response = { + "multi": True, + "response": {"out": {"children": "x"}}, + "sideUpdate": {"other": {"value": 42}}, + } + result = format_callback_response(response, _mock_callback()) + assert result.structuredContent["sideUpdate"] == {"other": {"value": 42}} + + def test_datatable_result_includes_markdown_table(self): + response = { + "multi": True, + "response": { + "my-table": {"data": [{"name": "Alice", "age": 30}]}, + }, + } + outputs = [ + { + "component_id": "my-table", + "component_type": "DataTable", + "property": "data", + "id_and_prop": "my-table.data", + "initial_value": None, + "tool_name": "update", + } + ] + result = format_callback_response(response, _mock_callback(outputs)) + texts = [c.text for c in result.content if c.type == "text"] + assert any("| name | age |" in t for t in texts) + + def test_plotly_figure_includes_image(self): + from unittest.mock import patch + + try: + import plotly.graph_objects as go + except ImportError: + return + + response = { + "multi": True, + "response": { + "my-graph": { + "figure": { + "data": [{"type": "bar", "x": ["A"], "y": [1]}], + "layout": {}, + } + } + }, + } + outputs = [ + { + "component_id": "my-graph", + "component_type": "Graph", + "property": "figure", + "id_and_prop": "my-graph.figure", + "initial_value": None, + "tool_name": "update", + } + ] + with patch.object(go.Figure, "to_image", return_value=b"\x89PNGfake"): + result = format_callback_response(response, _mock_callback(outputs)) + images = [c for c in result.content if c.type == "image"] + assert len(images) == 1 diff --git a/tests/unit/mcp/tools/results/test_dataframe.py b/tests/unit/mcp/tools/results/test_dataframe.py new file mode 100644 index 0000000000..a7f9e42fca --- /dev/null +++ b/tests/unit/mcp/tools/results/test_dataframe.py @@ -0,0 +1,63 @@ +"""Tests for the tabular data result formatter.""" + +from dash.mcp.primitives.tools.results.result_dataframe import ( + MAX_ROWS, + dataframe_result, +) + +EXPECTED_TABLE = ( + "*2 rows \u00d7 2 columns*\n" + "\n" + "| name | age |\n" + "| --- | --- |\n" + "| Alice | 30 |\n" + "| Bob | 25 |" +) + +SAMPLE_ROWS = [{"name": "Alice", "age": 30}, {"name": "Bob", "age": 25}] + +DATATABLE_OUTPUT = { + "component_type": "DataTable", + "property": "data", + "component_id": "t", + "id_and_prop": "t.data", + "initial_value": None, + "tool_name": "update", +} + +AGGRID_OUTPUT = { + "component_type": "AgGrid", + "property": "rowData", + "component_id": "g", + "id_and_prop": "g.rowData", + "initial_value": None, + "tool_name": "update", +} + + +class TestDataframeResult: + def test_datatable_data_renders_markdown(self): + result = dataframe_result(DATATABLE_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + def test_aggrid_rowdata_renders_markdown(self): + result = dataframe_result(AGGRID_OUTPUT, SAMPLE_ROWS) + assert len(result) == 1 + assert result[0].text == EXPECTED_TABLE + + def test_ignores_non_tabular_props(self): + non_tabular = {**DATATABLE_OUTPUT, "property": "columns"} + assert dataframe_result(non_tabular, SAMPLE_ROWS) == [] + + def test_ignores_empty_or_non_dict_rows(self): + assert dataframe_result(DATATABLE_OUTPUT, []) == [] + assert dataframe_result(DATATABLE_OUTPUT, ["a", "b"]) == [] + + def test_truncates_large_tables(self): + rows = [{"i": n} for n in range(MAX_ROWS + 50)] + result = dataframe_result(DATATABLE_OUTPUT, rows) + text = result[0].text + assert f"| {MAX_ROWS - 1} |" in text + assert f"| {MAX_ROWS} |" not in text + assert "50 more rows" in text diff --git a/tests/unit/mcp/tools/results/test_plotly_figure.py b/tests/unit/mcp/tools/results/test_plotly_figure.py new file mode 100644 index 0000000000..8e336ba687 --- /dev/null +++ b/tests/unit/mcp/tools/results/test_plotly_figure.py @@ -0,0 +1,55 @@ +"""Tests for the Plotly figure tool result formatter.""" + +import base64 +from unittest.mock import patch + +import pytest + +from dash.mcp.primitives.tools.results.result_plotly_figure import ( + plotly_figure_result, +) + +go = pytest.importorskip("plotly.graph_objects") + +FAKE_PNG = b"\x89PNG\r\n\x1a\nfakedata" +FAKE_B64 = base64.b64encode(FAKE_PNG).decode("ascii") + +GRAPH_FIGURE_OUTPUT = { + "component_type": "Graph", + "property": "figure", + "component_id": "g", + "id_and_prop": "g.figure", + "initial_value": None, + "tool_name": "update", +} + + +class TestPlotlyFigureResult: + def test_returns_image_when_kaleido_available(self): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", return_value=FAKE_PNG): + result = plotly_figure_result(GRAPH_FIGURE_OUTPUT, fig_dict) + assert len(result) == 1 + assert result[0].type == "image" + assert result[0].data == FAKE_B64 + + def test_returns_empty_when_kaleido_unavailable(self): + fig_dict = go.Figure(data=[go.Bar(x=["A", "B"], y=[1, 2])]).to_plotly_json() + with patch.object(go.Figure, "to_image", side_effect=ImportError): + result = plotly_figure_result(GRAPH_FIGURE_OUTPUT, fig_dict) + assert result == [] + + def test_ignores_non_graph_components(self): + output = { + **GRAPH_FIGURE_OUTPUT, + "component_type": "Div", + "property": "children", + } + assert plotly_figure_result(output, {}) == [] + + def test_ignores_non_figure_props(self): + output = {**GRAPH_FIGURE_OUTPUT, "property": "clickData"} + assert plotly_figure_result(output, {}) == [] + + def test_ignores_non_dict_values(self): + assert plotly_figure_result(GRAPH_FIGURE_OUTPUT, "not a dict") == [] From 1a307e7579ae4021f22edbdf8513f9665b679138 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 15:02:04 -0600 Subject: [PATCH 24/27] Add get_dash_component tool and callback tool dispatch pipeline --- dash/mcp/primitives/tools/__init__.py | 43 ++++++ .../tools/tool_get_dash_component.py | 123 ++++++++++++++++ dash/mcp/primitives/tools/tools_callbacks.py | 47 ++++++ tests/unit/mcp/conftest.py | 12 +- .../mcp/tools/test_tool_get_dash_component.py | 117 +++++++++++++++ tests/unit/mcp/tools/test_tools_callbacks.py | 137 ++++++++++++++++++ 6 files changed, 477 insertions(+), 2 deletions(-) create mode 100644 dash/mcp/primitives/tools/__init__.py create mode 100644 dash/mcp/primitives/tools/tool_get_dash_component.py create mode 100644 dash/mcp/primitives/tools/tools_callbacks.py create mode 100644 tests/unit/mcp/tools/test_tool_get_dash_component.py create mode 100644 tests/unit/mcp/tools/test_tools_callbacks.py diff --git a/dash/mcp/primitives/tools/__init__.py b/dash/mcp/primitives/tools/__init__.py new file mode 100644 index 0000000000..64f89dc3d0 --- /dev/null +++ b/dash/mcp/primitives/tools/__init__.py @@ -0,0 +1,43 @@ +"""MCP tool listing and call handling. + +Each tool module exports: +- ``get_tool_names() -> set[str]`` +- ``get_tools() -> list[Tool]`` +- ``call_tool(tool_name, arguments) -> CallToolResult`` + +The __init__ assembles the list and dispatches calls by name. +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, ListToolsResult + +from dash.mcp.types import ToolNotFoundError + +from . import tool_get_dash_component as _get_component +from . import tools_callbacks as _callbacks + +_TOOL_MODULES = [_callbacks, _get_component] + + +def list_tools() -> ListToolsResult: + """Build the MCP tools/list response.""" + tools = [] + for mod in _TOOL_MODULES: + tools.extend(mod.get_tools()) + return ListToolsResult(tools=tools) + + +def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + """Dispatch a tools/call request by tool name.""" + for mod in _TOOL_MODULES: + if tool_name in mod.get_tool_names(): + result = mod.call_tool(tool_name, arguments) + return result + raise ToolNotFoundError( + f"Tool not found: {tool_name}." + " The app's callbacks may have changed." + " Please call tools/list to refresh your tool list." + ) diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py new file mode 100644 index 0000000000..8584242333 --- /dev/null +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -0,0 +1,123 @@ +"""Built-in tool: get_dash_component.""" + +from __future__ import annotations + +import json +from typing import Any + +from mcp.types import CallToolResult, TextContent, Tool +from pydantic import Field, TypeAdapter +from typing_extensions import Annotated, NotRequired, TypedDict + +from dash import get_app +from dash._layout_utils import find_component +from dash.mcp.types import ComponentPropertyInfo, ComponentQueryResult + + +class _ComponentQueryInput(TypedDict): + component_id: Annotated[str, Field(description="The component ID to query")] + property: NotRequired[ + Annotated[ + str, + Field( + description="The property name to read (e.g. 'options', 'value'). Omit to list all defined properties." + ), + ] + ] + + +_INPUT_SCHEMA = TypeAdapter(_ComponentQueryInput).json_schema() +_OUTPUT_SCHEMA = TypeAdapter(ComponentQueryResult).json_schema() + +NAME = "get_dash_component" + + +def get_tool_names() -> set[str]: + return {NAME} + + +def get_tools() -> list[Tool]: + return [_build_tool()] + + +def _build_tool() -> Tool: + return Tool( + name=NAME, + description=( + "Get a component's properties, values, and tool relationships. " + "If property is omitted, returns all defined properties. " + "If property is specified, returns only that property. " + "See the dash://components resource for available component IDs." + ), + inputSchema=_INPUT_SCHEMA, + outputSchema=_OUTPUT_SCHEMA, + ) + + +def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + comp_id = arguments.get("component_id", "") + if not comp_id: + raise ValueError("component_id is required") + + prop_filter = arguments.get("property", "") + component = find_component(comp_id) + + if component is None: + callback_map = get_app().mcp_callback_map + rendering_tools = [ + cb.tool_name + for cb in callback_map + if any(out["component_id"] == comp_id for out in cb.outputs) + ] + msg = f"Component '{comp_id}' not found in static layout." + if rendering_tools: + msg += f" However, the following tools would modify it: {rendering_tools}." + msg += " Use the dash://components resource to see statically available component IDs." + return CallToolResult( + content=[TextContent(type="text", text=msg)], + isError=True, + ) + + callback_map = get_app().mcp_callback_map + + properties: dict[str, ComponentPropertyInfo] = {} + for prop_name in getattr(component, "_prop_names", []): + if prop_filter and prop_name != prop_filter: + continue + + value = callback_map.get_initial_value(f"{comp_id}.{prop_name}") + if value is None: + value = getattr(component, prop_name, None) + if value is None: + continue + + modified_by: list[str] = [] + input_to: list[str] = [] + id_and_prop = f"{comp_id}.{prop_name}" + for cb in callback_map: + for out in cb.outputs: + if out["id_and_prop"] == id_and_prop: + modified_by.append(cb.tool_name) + for inp in cb.inputs: + if inp["id_and_prop"] == id_and_prop: + input_to.append(cb.tool_name) + + properties[prop_name] = ComponentPropertyInfo( + initial_value=value, + modified_by_tool=modified_by, + input_to_tool=input_to, + ) + + labels = callback_map.component_label_map.get(comp_id, []) + + structured: ComponentQueryResult = ComponentQueryResult( + component_id=comp_id, + component_type=type(component).__name__, + label=labels if labels else None, + properties=properties, + ) + + return CallToolResult( + content=[TextContent(type="text", text=json.dumps(structured, default=str))], + structuredContent=structured, + ) diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py new file mode 100644 index 0000000000..ba08795d35 --- /dev/null +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -0,0 +1,47 @@ +"""Dynamic callback tools for MCP. + +Handles listing, naming, and executing callback-based tools. +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import CallToolResult, TextContent, Tool + +from dash import get_app +from dash.mcp.types import CallbackExecutionError, ToolNotFoundError + +from .results import format_callback_response + + +def get_tool_names() -> set[str]: + return get_app().mcp_callback_map.tool_names + + +def get_tools() -> list[Tool]: + """Return one Tool per server-callable callback.""" + return get_app().mcp_callback_map.as_mcp_tools() + + +def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: + """Execute a callback tool by name.""" + from .callback_utils import run_callback + + callback_map = get_app().mcp_callback_map + cb = callback_map.find_by_tool_name(tool_name) + if cb is None: + raise ToolNotFoundError( + f"Tool not found: {tool_name}." + " The app's callbacks may have changed." + " Please call tools/list to refresh your tool list." + ) + + try: + dispatch_response = run_callback(cb, arguments) + except CallbackExecutionError as e: + return CallToolResult( + content=[TextContent(type="text", text=str(e))], + isError=True, + ) + return format_callback_response(dispatch_response, cb) diff --git a/tests/unit/mcp/conftest.py b/tests/unit/mcp/conftest.py index 83f6e5378c..70d6e0f663 100644 --- a/tests/unit/mcp/conftest.py +++ b/tests/unit/mcp/conftest.py @@ -9,6 +9,7 @@ if sys.version_info < (3, 10): collect_ignore_glob.append("*") else: + from dash.mcp.primitives.tools import call_tool, list_tools # pylint: disable=wrong-import-position from dash.mcp.primitives.tools.callback_adapter_collection import ( # pylint: disable=wrong-import-position CallbackAdapterCollection, ) @@ -42,10 +43,10 @@ def update_output(value): def _tools_list(app): - """Return tools as Tool objects via as_mcp_tools().""" + """Return all tools (callbacks + builtins) as Tool objects.""" _setup_mcp(app) with app.server.test_request_context(): - return app.mcp_callback_map.as_mcp_tools() + return list_tools().tools def _user_tool(tools): @@ -81,3 +82,10 @@ def _desc_for(tool, param_name=None): if param_name is None: param_name = next(iter(props)) return props[param_name].get("description", "") + + +def _call_tool(app, tool_name, arguments=None): + """Call a tool via the dispatch pipeline and return the CallToolResult.""" + _setup_mcp(app) + with app.server.test_request_context(): + return call_tool(tool_name, arguments or {}) diff --git a/tests/unit/mcp/tools/test_tool_get_dash_component.py b/tests/unit/mcp/tools/test_tool_get_dash_component.py new file mode 100644 index 0000000000..5a8a454068 --- /dev/null +++ b/tests/unit/mcp/tools/test_tool_get_dash_component.py @@ -0,0 +1,117 @@ +"""Tests for the get_dash_component built-in tool.""" + +from dash import Dash, Input, Output, dcc, html + +from tests.unit.mcp.conftest import _call_tool, _make_app, _tools_list + + +class TestGetDashComponent: + def test_present_in_tools_list(self): + app = _make_app() + tool_names = [t.name for t in _tools_list(app)] + assert "get_dash_component" in tool_names + + def test_returns_structured_output_with_prop(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), + ] + ) + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + "property": "value", + }, + ) + sc = result.structuredContent + assert sc["component_id"] == "my-dd" + assert sc["component_type"] == "Dropdown" + assert "value" in sc["properties"] + assert sc["properties"]["value"]["initial_value"] == "b" + assert "options" not in sc["properties"] + + def test_returns_all_props_without_property(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="my-dd", options=["a", "b"], value="b"), + ] + ) + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + }, + ) + sc = result.structuredContent + assert "options" in sc["properties"] + assert "value" in sc["properties"] + assert sc["properties"]["value"]["initial_value"] == "b" + + def test_includes_label(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label("Pick one", htmlFor="my-dd"), + dcc.Dropdown(id="my-dd", options=["a", "b"], value="a"), + ] + ) + + @app.callback(Output("my-dd", "value"), Input("my-dd", "options")) + def noop(o): + return "a" + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "my-dd", + }, + ) + sc = result.structuredContent + assert sc["label"] == ["Pick one"] + + def test_includes_tool_references(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "dd", + "property": "value", + }, + ) + prop_info = result.structuredContent["properties"]["value"] + assert "update" in prop_info["input_to_tool"] + + def test_missing_id_returns_hint(self): + app = _make_app() + result = _call_tool( + app, + "get_dash_component", + { + "component_id": "nonexistent", + "property": "value", + }, + ) + text = result.content[0].text + assert "nonexistent" in text + assert "not found" in text + assert "dash://components" in text diff --git a/tests/unit/mcp/tools/test_tools_callbacks.py b/tests/unit/mcp/tools/test_tools_callbacks.py new file mode 100644 index 0000000000..751f85bedd --- /dev/null +++ b/tests/unit/mcp/tools/test_tools_callbacks.py @@ -0,0 +1,137 @@ +"""Tool definition tests — MCP spec compliance and Dash conventions. + +Verifies that generated tools conform to the MCP specification (2025-11-25) +and Dash-specific conventions. Focuses on shape/structure, not inputSchema +values (those are covered by input_schemas/). + +Reference: https://modelcontextprotocol.io/specification/2025-11-25/server/tools +""" + +import re + +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) +from dash.mcp.primitives.tools.descriptions import build_tool_description + +from tests.unit.mcp.conftest import ( + _make_app, + _tools_list, +) + +_TOOL_NAME_RE = re.compile(r"^[A-Za-z0-9_\-.]+$") + + +class TestToolSpecCompliance: + """Every tool must conform to the MCP 2025-11-25 specification.""" + + def test_all_tools_conform_to_mcp_spec(self): + tools = _tools_list(_make_app()) + names = [t.name for t in tools] + + assert len(names) == len(set(names)), f"Duplicate tool names: {names}" + + for tool in tools: + assert tool.name + assert tool.inputSchema + assert 1 <= len(tool.name) <= 128 + assert _TOOL_NAME_RE.match(tool.name), f"Invalid tool name: {tool.name}" + + schema = tool.inputSchema + assert isinstance(schema, dict) + assert schema.get("type") == "object" + assert isinstance(schema.get("properties", {}), dict) + + required = set(schema.get("required", [])) + props = set(schema.get("properties", {}).keys()) + assert ( + required <= props + ), f"{tool.name}: required {required - props} not in properties" + + +class TestBuiltinToolDefinitions: + def _tools(self): + return _tools_list(_make_app()) + + def _builtin(self, name): + return next(t for t in self._tools() if t.name == name) + + def test_query_component_always_present(self): + names = {t.name for t in self._tools()} + assert "get_dash_component" in names + + def test_query_component_has_required_params(self): + tool = self._builtin("get_dash_component") + assert "component_id" in tool.inputSchema["properties"] + assert "property" in tool.inputSchema["properties"] + assert set(tool.inputSchema.get("required", [])) == {"component_id"} + + +class TestSanitizeToolName: + def test_simple_name(self): + assert ( + CallbackAdapterCollection._sanitize_name("update_output") == "update_output" + ) + + def test_special_characters_replaced(self): + assert ( + CallbackAdapterCollection._sanitize_name("my-func.name") == "my_func_name" + ) + + def test_leading_digit(self): + assert CallbackAdapterCollection._sanitize_name("123func") == "cb_123func" + + def test_empty_name(self): + assert CallbackAdapterCollection._sanitize_name("") == "unnamed_callback" + + def test_consecutive_underscores_collapsed(self): + assert CallbackAdapterCollection._sanitize_name("a---b___c") == "a_b_c" + + def test_long_name_truncated_to_64_chars(self): + result = CallbackAdapterCollection._sanitize_name("a" * 200) + assert len(result) <= 64 + assert result[-8:].isalnum() + + def test_long_name_uniqueness(self): + result_a = CallbackAdapterCollection._sanitize_name("a" * 200) + result_b = CallbackAdapterCollection._sanitize_name("b" * 200) + assert result_a != result_b + + def test_short_name_not_truncated(self): + assert CallbackAdapterCollection._sanitize_name("short_name") == "short_name" + + +class TestOutputSemanticSummary: + """Test the _OUTPUT_SEMANTICS mapping in description_outputs.py. + + Other description tests (docstring, output target, multi-output) are + covered by TestTool in test_callback_adapter.py using real adapters. + """ + + @staticmethod + def _adapter_with_outputs(outputs, docstring=None): + from unittest.mock import Mock + adapter = Mock() + adapter.outputs = outputs + adapter._docstring = docstring + return adapter + + @staticmethod + def _out(comp_id, prop, comp_type=None): + return { + "id_and_prop": f"{comp_id}.{prop}", + "component_id": comp_id, + "property": prop, + "component_type": comp_type, + "initial_value": None, + } + + def test_semantic_summary_with_component_type(self): + adapter = self._adapter_with_outputs([self._out("my-graph", "figure", "Graph")]) + desc = build_tool_description(adapter) + assert "Returns chart/visualization data" in desc + + def test_semantic_summary_fallback_by_property(self): + adapter = self._adapter_with_outputs([self._out("unknown-id", "figure")]) + desc = build_tool_description(adapter) + assert "Returns chart/visualization data" in desc From 390f03974126485962aded56ba9a7fec30a32147 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 1 Apr 2026 15:35:13 -0600 Subject: [PATCH 25/27] Wire MCP server, SSE transport, and Dash app integration --- dash/_configs.py | 2 + dash/dash.py | 31 + dash/mcp/__init__.py | 7 + dash/mcp/_server.py | 277 +++++ dash/mcp/_sse.py | 67 ++ dash/mcp/notifications/__init__.py | 7 + .../notification_tools_changed.py | 30 + dash/mcp/primitives/__init__.py | 17 + .../tools/callback_adapter_collection.py | 2 - tests/integration/mcp/conftest.py | 53 + .../primitives/resources/test_resources.py | 51 + .../tools/test_callback_signatures.py | 958 ++++++++++++++++++ .../tools/test_duplicate_outputs.py | 128 +++ .../primitives/tools/test_input_schemas.py | 66 ++ .../tools/test_tool_get_dash_component.py | 54 + .../mcp/primitives/tools/test_tools_list.py | 118 +++ tests/integration/mcp/test_server.py | 304 ++++++ tests/unit/mcp/test_server.py | 92 ++ tests/unit/mcp/tools/test_run_callback.py | 246 +++++ 19 files changed, 2508 insertions(+), 2 deletions(-) create mode 100644 dash/mcp/__init__.py create mode 100644 dash/mcp/_server.py create mode 100644 dash/mcp/_sse.py create mode 100644 dash/mcp/notifications/__init__.py create mode 100644 dash/mcp/notifications/notification_tools_changed.py create mode 100644 dash/mcp/primitives/__init__.py create mode 100644 tests/integration/mcp/conftest.py create mode 100644 tests/integration/mcp/primitives/resources/test_resources.py create mode 100644 tests/integration/mcp/primitives/tools/test_callback_signatures.py create mode 100644 tests/integration/mcp/primitives/tools/test_duplicate_outputs.py create mode 100644 tests/integration/mcp/primitives/tools/test_input_schemas.py create mode 100644 tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py create mode 100644 tests/integration/mcp/primitives/tools/test_tools_list.py create mode 100644 tests/integration/mcp/test_server.py create mode 100644 tests/unit/mcp/test_server.py create mode 100644 tests/unit/mcp/tools/test_run_callback.py diff --git a/dash/_configs.py b/dash/_configs.py index edbf7b50d1..f6df4001f1 100644 --- a/dash/_configs.py +++ b/dash/_configs.py @@ -32,6 +32,8 @@ def load_dash_env_vars(): "DASH_DISABLE_VERSION_CHECK", "DASH_PRUNE_ERRORS", "DASH_COMPRESS", + "DASH_MCP_ENABLED", + "DASH_MCP_PATH", "HOST", "PORT", ) diff --git a/dash/dash.py b/dash/dash.py index 9fa9f1e8e6..c12662a4e9 100644 --- a/dash/dash.py +++ b/dash/dash.py @@ -472,6 +472,8 @@ def __init__( # pylint: disable=too-many-statements on_error: Optional[Callable[[Exception], Any]] = None, use_async: Optional[bool] = None, health_endpoint: Optional[str] = None, + enable_mcp: Optional[bool] = None, + mcp_path: Optional[str] = None, **obsolete, ): @@ -573,6 +575,13 @@ def __init__( # pylint: disable=too-many-statements # keep title as a class property for backwards compatibility self.title = title + # MCP (Model Context Protocol) configuration + self._enable_mcp = get_combined_config("mcp_enabled", enable_mcp, True) + _mcp_path = get_combined_config("mcp_path", mcp_path, "_mcp") + self._mcp_path = ( + _mcp_path.lstrip("/") if isinstance(_mcp_path, str) else _mcp_path + ) + # list of dependencies - this one is used by the back end for dispatching self.callback_map: dict = {} # same deps as a list to catch duplicate outputs, and to send to the front end @@ -793,6 +802,21 @@ def _setup_routes(self): hook.data["methods"], ) + if self._enable_mcp: + from .mcp import ( # pylint: disable=import-outside-toplevel + enable_mcp_server, + ) + + try: + enable_mcp_server(self, self._mcp_path) + except Exception as e: # pylint: disable=broad-exception-caught + self._enable_mcp = False + self.logger.warning( + "MCP server could not be started at '%s': %s", + self._mcp_path, + e, + ) + # catch-all for front-end routes, used by dcc.Location self._add_url("", self.index) @@ -2526,6 +2550,13 @@ def verify_url_part(served_part, url_part, part_name): if not jupyter_dash or not jupyter_dash.in_ipython: self.logger.info("Dash is running on %s://%s%s%s\n", *display_url) + if self._enable_mcp: + self.logger.info( + " * MCP available at %s://%s%s%s%s\n", + *display_url[:3], + self.config.routes_pathname_prefix, + self._mcp_path, + ) if self.config.extra_hot_reload_paths: extra_files = flask_run_options["extra_files"] = [] diff --git a/dash/mcp/__init__.py b/dash/mcp/__init__.py new file mode 100644 index 0000000000..2677ea141b --- /dev/null +++ b/dash/mcp/__init__.py @@ -0,0 +1,7 @@ +"""Dash MCP (Model Context Protocol) server integration.""" + +from dash.mcp._server import enable_mcp_server + +__all__ = [ + enable_mcp_server, +] diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py new file mode 100644 index 0000000000..1c6279290b --- /dev/null +++ b/dash/mcp/_server.py @@ -0,0 +1,277 @@ +"""Flask route setup, Streamable HTTP transport, and MCP message handling.""" + +from __future__ import annotations + +import atexit +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from flask import Response, request + +from dash.mcp.types import MCPError + +if TYPE_CHECKING: + from dash import Dash + +from dash import get_app + +from mcp.types import ( + LATEST_PROTOCOL_VERSION, + ErrorData, + Implementation, + InitializeResult, + JSONRPCError, + JSONRPCResponse, + ResourcesCapability, + ServerCapabilities, + ToolsCapability, +) + +from dash.version import __version__ +from dash.mcp._sse import ( + close_sse_stream, + create_sse_stream, + shutdown_all_streams, +) +from dash.mcp.primitives import ( + call_tool, + list_resource_templates, + list_resources, + list_tools, + read_resource, +) +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + +logger = logging.getLogger(__name__) + + +def enable_mcp_server(app: Dash, mcp_path: str) -> None: + """ + Add MCP routes to a Dash/Flask app. + + Registers a single Streamable HTTP endpoint for the MCP protocol. + Uses ``app._add_url()`` so that ``routes_pathname_prefix`` is applied + automatically. + + Args: + app: The Dash application instance. + mcp_path: Route prefix for MCP endpoints. + """ + # Session storage: session_id -> metadata + sessions: dict[str, dict[str, Any]] = {} + + def _create_session() -> str: + sid = str(uuid.uuid4()) + sessions[sid] = {} + return sid + + # -- Streamable HTTP endpoint -------------------------------------------- + + def mcp_handler() -> Response: + if request.method == "POST": + return _handle_post() + if request.method == "GET": + return _handle_get() + if request.method == "DELETE": + return _handle_delete() + return Response( + json.dumps({"error": "Method not allowed"}), + content_type="application/json", + status=405, + ) + + def _handle_get() -> Response: + session_id = request.headers.get("mcp-session-id") + if not session_id or session_id not in sessions: + return Response( + json.dumps({"error": "Session not found"}), + content_type="application/json", + status=404, + ) + return create_sse_stream(sessions, session_id) + + def _handle_post() -> Response: + content_type = request.content_type or "" + if "application/json" not in content_type: + return Response( + json.dumps({"error": "Content-Type must be application/json"}), + content_type="application/json", + status=415, + ) + + try: + data = request.get_json() + except Exception: + return Response( + json.dumps({"error": "Invalid JSON"}), + content_type="application/json", + status=400, + ) + + method = data.get("method", "") + request_id = data.get("id") + session_id = request.headers.get("mcp-session-id") + + stale_session = False + if method == "initialize": + session_id = _create_session() + elif session_id and session_id not in sessions: + stale_session = True + sessions[session_id] = {} + elif not session_id: + session_id = _create_session() + + response_data = _process_mcp_message(data) + + if response_data is None: + return Response("", status=202) + + if stale_session: + _inject_warning(response_data, _STALE_SESSION_WARNING) + + return Response( + json.dumps(response_data), + content_type="application/json", + status=200, + headers={"mcp-session-id": session_id}, + ) + + def _handle_delete() -> Response: + session_id = request.headers.get("mcp-session-id") + if not session_id or session_id not in sessions: + return Response( + json.dumps({"error": "Session not found"}), + content_type="application/json", + status=404, + ) + close_sse_stream(sessions[session_id]) + del sessions[session_id] + logger.info("MCP session terminated: %s", session_id) + return Response("", status=204) + + # -- Register routes ----------------------------------------------------- + + from dash._get_app import with_app_context_factory + + app._add_url( + mcp_path, with_app_context_factory(mcp_handler, app), ["GET", "POST", "DELETE"] + ) + + # Close all SSE streams on server shutdown so MCP clients see a + # clean stream end and can reconnect promptly. + atexit.register(shutdown_all_streams, sessions) + + logger.info( + "MCP routes registered at %s%s", + app.config.routes_pathname_prefix, + mcp_path, + ) + + +_STALE_SESSION_WARNING = ( + "[Warning: your session was not recognised" + " — the app may have restarted." + " Please call tools/list to refresh your tool list." + " Please ask the user to reconnect to the MCP server.]" +) + + +def _inject_warning(response_data: dict[str, Any], warning: str) -> None: + """Append a warning to a JSON-RPC response dict. + + For successful ``tools/call`` responses the warning is added as an + extra text content block so the agent sees it alongside the result. + For error responses the warning is appended to the error message. + Other responses (tools/list, resources/*) are left unchanged — the + JSON-RPC spec forbids extra top-level keys. + """ + # tools/call success: result has a "content" list + result = response_data.get("result") + if isinstance(result, dict) and isinstance(result.get("content"), list): + result["content"].append({"type": "text", "text": warning}) + return + + # Error response + error = response_data.get("error") + if isinstance(error, dict) and "message" in error: + error["message"] += " " + warning + + +def _handle_initialize() -> InitializeResult: + return InitializeResult( + protocolVersion=LATEST_PROTOCOL_VERSION, + capabilities=ServerCapabilities( + tools=ToolsCapability(listChanged=True), + resources=ResourcesCapability(), + ), + serverInfo=Implementation(name="Plotly Dash", version=__version__), + instructions=( + "This is a Dash web application. " + "Dash apps are stateless: calling a tool executes " + "a callback and returns its result to you, but does " + "NOT update the user's browser. " + "Use tool results to answer questions about what " + "the app would produce for given inputs." + ), + ) + + +def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: + """ + Process an MCP JSON-RPC message and return the response dict. + + Returns ``None`` for notifications (no ``id`` field). + """ + method = data.get("method", "") + params = data.get("params", {}) or {} + request_id = data.get("id") + + app = get_app() + if not hasattr(app, "mcp_callback_map"): + app.mcp_callback_map = CallbackAdapterCollection(app) + + mcp_methods = { + "initialize": _handle_initialize, + "tools/list": lambda: list_tools(), + "tools/call": lambda: call_tool( + params.get("name", ""), params.get("arguments", {}) + ), + "resources/list": lambda: list_resources(), + "resources/templates/list": lambda: list_resource_templates(), + "resources/read": lambda: read_resource(params.get("uri", "")), + } + + try: + handler = mcp_methods.get(method) + if handler is None: + if method.startswith("notifications/"): + return None + raise ValueError(f"Unknown method: {method}") + + result = handler() + + response = JSONRPCResponse( + jsonrpc="2.0", + id=request_id, + result=result.model_dump(exclude_none=True, mode="json"), + ) + return response.model_dump(exclude_none=True, mode="json") + + except MCPError as e: + logger.error("MCP error: %s", e) + return JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=e.code, message=str(e)), + ).model_dump(exclude_none=True) + except Exception as e: + logger.error("MCP error: %s", e, exc_info=True) + return JSONRPCError( + jsonrpc="2.0", + id=request_id, + error=ErrorData(code=-32603, message=f"{type(e).__name__}: {e}"), + ).model_dump(exclude_none=True) diff --git a/dash/mcp/_sse.py b/dash/mcp/_sse.py new file mode 100644 index 0000000000..4928dc68b2 --- /dev/null +++ b/dash/mcp/_sse.py @@ -0,0 +1,67 @@ +"""SSE stream generation and queue management.""" + +from __future__ import annotations + +import queue +from typing import Any + +from flask import Response + + +def create_sse_stream(sessions: dict[str, dict[str, Any]], session_id: str) -> Response: + """Create a Server-Sent Events stream for the given session. + + Stores a :class:`queue.Queue` in ``sessions[session_id]["sse_queue"]`` + and returns a Flask streaming ``Response``. The generator yields + events pushed to the queue, with keepalive comments every 30 seconds. + """ + event_queue: queue.Queue[str | None] = queue.Queue() + # Replace any prior SSE queue for this session (client reconnect). + sessions[session_id]["sse_queue"] = event_queue + + def _generate(): + try: + while True: + try: + event = event_queue.get(timeout=30) + if event is None: + return # Sentinel: server closing stream + yield f"event: message\ndata: {event}\n\n" + except queue.Empty: + yield ": keepalive\n\n" + except GeneratorExit: + pass + finally: + # Clean up queue reference if it's still ours. + if sessions.get(session_id, {}).get("sse_queue") is event_queue: + sessions[session_id].pop("sse_queue", None) + + return Response( + _generate(), + content_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "mcp-session-id": session_id, + }, + ) + + +def close_sse_stream(session_data: dict[str, Any]) -> None: + """Send a sentinel to shut down the session's SSE stream cleanly.""" + sse_queue = session_data.get("sse_queue") + if sse_queue is not None: + try: + sse_queue.put_nowait(None) + except queue.Full: + pass + + +def shutdown_all_streams(sessions: dict[str, dict[str, Any]]) -> None: + """Close all active SSE streams. + + Called during server shutdown (via ``atexit``) so that connected + MCP clients see a clean stream end and can reconnect promptly. + """ + for session_data in list(sessions.values()): + close_sse_stream(session_data) diff --git a/dash/mcp/notifications/__init__.py b/dash/mcp/notifications/__init__.py new file mode 100644 index 0000000000..b1fe9e8665 --- /dev/null +++ b/dash/mcp/notifications/__init__.py @@ -0,0 +1,7 @@ +"""Server-initiated MCP notifications.""" + +from .notification_tools_changed import broadcast_tools_changed + +__all__ = [ + "broadcast_tools_changed", +] diff --git a/dash/mcp/notifications/notification_tools_changed.py b/dash/mcp/notifications/notification_tools_changed.py new file mode 100644 index 0000000000..1970667d1a --- /dev/null +++ b/dash/mcp/notifications/notification_tools_changed.py @@ -0,0 +1,30 @@ +"""Tool list change notifications.""" + +from __future__ import annotations + +import json +import queue +from typing import Any + + +def broadcast_tools_changed( + sessions: dict[str, dict[str, Any]], +) -> None: + """Push a tools/list_changed notification to all active SSE streams. + + Not called automatically yet — available for future hot-reload + or dynamic callback registration. + """ + notification = json.dumps( + { + "jsonrpc": "2.0", + "method": "notifications/tools/list_changed", + } + ) + for data in sessions.values(): + sse_queue = data.get("sse_queue") + if sse_queue is not None: + try: + sse_queue.put_nowait(notification) + except queue.Full: + pass diff --git a/dash/mcp/primitives/__init__.py b/dash/mcp/primitives/__init__.py new file mode 100644 index 0000000000..b14839f1e1 --- /dev/null +++ b/dash/mcp/primitives/__init__.py @@ -0,0 +1,17 @@ +from .resources import ( + list_resources, + list_resource_templates, + read_resource, +) +from .tools import ( + call_tool, + list_tools, +) + +__all__ = [ + call_tool, + list_resources, + list_resource_templates, + list_tools, + read_resource, +] diff --git a/dash/mcp/primitives/tools/callback_adapter_collection.py b/dash/mcp/primitives/tools/callback_adapter_collection.py index b53cf53a9d..ed6dafa65e 100644 --- a/dash/mcp/primitives/tools/callback_adapter_collection.py +++ b/dash/mcp/primitives/tools/callback_adapter_collection.py @@ -35,8 +35,6 @@ def __init__(self, app): CallbackAdapter(callback_output_id=output_id) for output_id in self._tool_names_map ] - # TODO: enable_mcp_server() will replace this with a direct assignment on app - app.mcp_callback_map = self @staticmethod def _sanitize_name(name: str) -> str: diff --git a/tests/integration/mcp/conftest.py b/tests/integration/mcp/conftest.py new file mode 100644 index 0000000000..0f212d1763 --- /dev/null +++ b/tests/integration/mcp/conftest.py @@ -0,0 +1,53 @@ +"""Shared helpers for MCP integration tests.""" + +import requests + + +def _mcp_post(server_url, method, params=None, session_id=None, request_id=1): + headers = {"Content-Type": "application/json"} + if session_id: + headers["mcp-session-id"] = session_id + return requests.post( + f"{server_url}/_mcp", + json={ + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + }, + headers=headers, + timeout=5, + ) + + +def _mcp_session(server_url): + resp = _mcp_post(server_url, "initialize") + resp.raise_for_status() + return resp.headers["mcp-session-id"] + + +def _mcp_tools(server_url): + sid = _mcp_session(server_url) + resp = _mcp_post(server_url, "tools/list", session_id=sid, request_id=2) + resp.raise_for_status() + return resp.json()["result"]["tools"] + + +def _mcp_call_tool(server_url, tool_name, arguments=None): + sid = _mcp_session(server_url) + resp = _mcp_post( + server_url, + "tools/call", + {"name": tool_name, "arguments": arguments or {}}, + session_id=sid, + request_id=2, + ) + resp.raise_for_status() + return resp.json() + + +def _mcp_method(server_url, method, params=None): + sid = _mcp_session(server_url) + resp = _mcp_post(server_url, method, params, session_id=sid, request_id=2) + resp.raise_for_status() + return resp.json() diff --git a/tests/integration/mcp/primitives/resources/test_resources.py b/tests/integration/mcp/primitives/resources/test_resources.py new file mode 100644 index 0000000000..dfc1e09f9b --- /dev/null +++ b/tests/integration/mcp/primitives/resources/test_resources.py @@ -0,0 +1,51 @@ +"""Integration tests for MCP resources.""" + +import json + +from dash import Dash, dcc, html + +from tests.integration.mcp.conftest import _mcp_method + + +def test_resources_list_includes_layout(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + dash_duo.start_server(app) + result = _mcp_method(dash_duo.server.url, "resources/list") + + assert "result" in result + uris = [r["uri"] for r in result["result"]["resources"]] + assert "dash://layout" in uris + + +def test_read_layout_resource(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="res-dd", options=["x", "y"], value="x"), + html.Div(id="out"), + ] + ) + + dash_duo.start_server(app) + result = _mcp_method( + dash_duo.server.url, + "resources/read", + {"uri": "dash://layout"}, + ) + + assert "result" in result + layout = json.loads(result["result"]["contents"][0]["text"]) + assert layout["type"] == "Div" + children = layout["props"]["children"] + dd = next( + c for c in children if isinstance(c, dict) and c.get("type") == "Dropdown" + ) + assert dd["props"]["id"] == "res-dd" + assert dd["props"]["options"] == ["x", "y"] diff --git a/tests/integration/mcp/primitives/tools/test_callback_signatures.py b/tests/integration/mcp/primitives/tools/test_callback_signatures.py new file mode 100644 index 0000000000..db325f2046 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_callback_signatures.py @@ -0,0 +1,958 @@ +""" +Integration tests for all Dash callback signature types. + +Each test verifies that: +1. The MCP tool schema accurately reflects the callback's parameters +2. Calling the tool with those parameters produces the expected result + +Assertions are derived from the callback definition, not the implementation. + +See: https://dash.plotly.com/flexible-callback-signatures +""" + +from dash import Dash, Input, Output, State, dcc, html + +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools + + +def _find_tool(tools, name): + return next(t for t in tools if t["name"] == name) + + +def _get_response(result): + return result["result"]["structuredContent"]["response"] + + +def test_positional_callback(dash_duo): + """Standard positional: Input("fruit", "value") → param named 'value'.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="fruit", options=["apple", "banana"], value="apple"), + html.Div(id="out"), + ] + ) + + # Callback: 1 Input → 1 param named "value" (from function signature) + # Returns string → Output("out", "children") + @app.callback(Output("out", "children"), Input("fruit", "value")) + def show_fruit(value): + return f"Selected: {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Selected: apple") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_fruit") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"value"} + assert any(s.get("type") == "string" for s in props["value"]["anyOf"]) + + # Tool description reflects initial state + value_desc = props["value"].get("description", "") + assert "value: 'apple'" in value_desc + assert "options: ['apple', 'banana']" in value_desc + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "apple"}) + response = _get_response(result) + assert response["out"]["children"] == "Selected: apple" + + # MCP tool with different inputs + result = _mcp_call_tool(dash_duo.server.url, "show_fruit", {"value": "banana"}) + response = _get_response(result) + assert response["out"]["children"] == "Selected: banana" + + +def test_positional_with_state(dash_duo): + """Positional with State: Input + State both appear as params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + # Callback: 1 Input + 1 State → 2 params named "n_clicks" and "value" + @app.callback( + Output("out", "children"), + Input("btn", "n_clicks"), + State("inp", "value"), + ) + def update(n_clicks, value): + return f"Clicked {n_clicks} with {value}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Clicked None with hello") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"n_clicks", "value"} + assert any(s.get("type") == "number" for s in props["n_clicks"]["anyOf"]) + + # Tool description reflects initial state + assert "value: 'hello'" in props["value"].get("description", "") + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool( + dash_duo.server.url, "update", {"n_clicks": None, "value": "hello"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Clicked None with hello" + + result = _mcp_call_tool( + dash_duo.server.url, "update", {"n_clicks": 3, "value": "world"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Clicked 3 with world" + + +def test_multi_output_positional(dash_duo): + """Multi-output: returns tuple → both outputs updated in response.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="test"), + html.Div(id="out1"), + html.Div(id="out2"), + ] + ) + + # Callback: 1 Input → 2 Outputs via tuple return + @app.callback( + Output("out1", "children"), + Output("out2", "children"), + Input("inp", "value"), + ) + def split_case(value): + return value.upper(), value.lower() + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out1", "TEST") + dash_duo.wait_for_text_to_equal("#out2", "test") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "split_case") + props = tool["inputSchema"]["properties"] + assert set(props.keys()) == {"value"} + + # Tool description reflects initial state + assert "value: 'test'" in props["value"].get("description", "") + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool(dash_duo.server.url, "split_case", {"value": "test"}) + response = _get_response(result) + assert response["out1"]["children"] == "TEST" + assert response["out2"]["children"] == "test" + + +def test_dict_based_inputs_and_state(dash_duo): + """Dict-based: inputs=dict(trigger=...), state=dict(name=...) → dict keys are param names.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="name-input", value="world"), + html.Div(id="out"), + ] + ) + + # Callback: dict keys "trigger" and "name" become param names + @app.callback( + Output("out", "children"), + inputs=dict(trigger=Input("btn", "n_clicks")), + state=dict(name=State("name-input", "value")), + ) + def greet(trigger, name): + return f"Hello, {name}!" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Hello, world!") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"trigger", "name"} + assert any(s.get("type") == "number" for s in props["trigger"]["anyOf"]) + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool( + dash_duo.server.url, "greet", {"trigger": None, "name": "world"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Hello, world!" + + result = _mcp_call_tool( + dash_duo.server.url, "greet", {"trigger": 1, "name": "Dash"} + ) + response = _get_response(result) + assert response["out"]["children"] == "Hello, Dash!" + + +def test_dict_based_outputs(dash_duo): + """Dict-based outputs: output=dict(...) → callback returns dict, both outputs updated.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="upper-out"), + html.Div(id="lower-out"), + ] + ) + + # Callback: dict output keys "upper" and "lower" map to components + @app.callback( + output=dict( + upper=Output("upper-out", "children"), + lower=Output("lower-out", "children"), + ), + inputs=dict(val=Input("inp", "value")), + ) + def transform(val): + return dict(upper=val.upper(), lower=val.lower()) + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#upper-out", "HELLO") + dash_duo.wait_for_text_to_equal("#lower-out", "hello") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "transform") + props = tool["inputSchema"]["properties"] + assert set(props.keys()) == {"val"} + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool(dash_duo.server.url, "transform", {"val": "hello"}) + response = _get_response(result) + assert response["upper-out"]["children"] == "HELLO" + assert response["lower-out"]["children"] == "hello" + + +def test_mixed_input_state_in_inputs(dash_duo): + """Mixed: State inside inputs=dict alongside Input → all appear as params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="first", value="Jane"), + dcc.Input(id="last", value="Doe"), + html.Div(id="out"), + ] + ) + + # Callback: Input and State mixed in same dict → all keys are params + @app.callback( + Output("out", "children"), + inputs=dict( + clicks=Input("btn", "n_clicks"), + first=State("first", "value"), + last=State("last", "value"), + ), + ) + def full_name(clicks, first, last): + return f"{first} {last}" + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#out", "Jane Doe") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "full_name") + props = tool["inputSchema"]["properties"] + + assert set(props.keys()) == {"clicks", "first", "last"} + assert any(s.get("type") == "number" for s in props["clicks"]["anyOf"]) + + # MCP tool with initial inputs matches browser + result = _mcp_call_tool( + dash_duo.server.url, + "full_name", + {"clicks": None, "first": "Jane", "last": "Doe"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "Jane Doe" + + result = _mcp_call_tool( + dash_duo.server.url, + "full_name", + {"clicks": 1, "first": "John", "last": "Smith"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "John Smith" + + +def test_tuple_grouped_inputs(dash_duo): + """Tuple grouping: pair=(Input("a",...), Input("b",...)) → expands to two named params.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a", value="1"), + dcc.Input(id="b", value="2"), + html.Div(id="out"), + ] + ) + + # Callback: tuple group "pair" maps to 2 deps → 2 params named pair___ + @app.callback( + Output("out", "children"), + inputs=dict(pair=(Input("a", "value"), Input("b", "value"))), + ) + def combine(pair): + return f"{pair[0]}+{pair[1]}" + + dash_duo.start_server(app) + tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") + props = tool["inputSchema"]["properties"] + + # Tuple expands: one param per dep, named with group prefix + component info + assert set(props.keys()) == {"pair_a__value", "pair_b__value"} + for schema in props.values(): + assert any(s.get("type") == "string" for s in schema["anyOf"]) + + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"pair_a__value": "x", "pair_b__value": "y"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "x+y" + + +def test_initial_values_from_chained_callbacks(dash_duo): + """Querying components reflects post-initial-callback values. + + 3-link chain: country (default "France") → update_states → + state (should become "Ile-de-France") → update_cities → + city (should become "Paris"). + """ + DATA = { + "France": { + "Ile-de-France": ["Paris", "Versailles"], + "Provence": ["Marseille", "Nice"], + }, + "Germany": { + "Bavaria": ["Munich", "Nuremberg"], + "Berlin": ["Berlin"], + }, + } + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=list(DATA.keys()), value="France"), + dcc.Dropdown(id="state"), + dcc.Dropdown(id="city"), + ] + ) + + @app.callback( + Output("state", "options"), + Output("state", "value"), + Input("country", "value"), + ) + def update_states(country): + if not country: + return [], None + states = list(DATA[country].keys()) + return [{"label": s, "value": s} for s in states], states[0] + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("state", "value"), + Input("country", "value"), + ) + def update_cities(state, country): + if not state or not country: + return [], None + cities = DATA[country][state] + return [{"label": c, "value": c} for c in cities], cities[0] + + dash_duo.start_server(app) + + # Tool descriptions should reflect post-initial-callback state + tools = _mcp_tools(dash_duo.server.url) + update_cities_tool = _find_tool(tools, "update_cities") + state_desc = update_cities_tool["inputSchema"]["properties"]["state"].get( + "description", "" + ) + # state.value was set to "Ile-de-France" by update_states initial callback + assert "Ile-de-France" in state_desc + + # state.value should be "Ile-de-France" (first state for France) + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "state", "property": "value"}, + ) + state_props = result["result"]["structuredContent"]["properties"] + assert state_props["value"]["initial_value"] == "Ile-de-France" + + # city.value should be "Paris" (first city for Ile-de-France) + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "city", "property": "value"}, + ) + city_props = result["result"]["structuredContent"]["properties"] + assert city_props["value"]["initial_value"] == "Paris" + + +def test_dict_based_reordered_state_input(dash_duo): + """Dict-based callback with State before Input: call works, schema types correct. + + State is listed before Input in the dict. The callback should still + work correctly via MCP, and the schema types should match the + function annotations (name: str, trigger: int), not be swapped. + """ + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + dcc.Input(id="inp", value="World"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + inputs=dict(name=State("inp", "value"), trigger=Input("btn", "n_clicks")), + ) + def greet(name: str, trigger: int): + return f"Hello {name}" + + dash_duo.start_server(app) + + # First: verify the callback actually works with these args + result = _mcp_call_tool( + dash_duo.server.url, + "greet", + {"name": "Dash", "trigger": 1}, + ) + assert _get_response(result)["out"]["children"] == "Hello Dash" + + # Second: verify schema types match annotations + tool = _find_tool(_mcp_tools(dash_duo.server.url), "greet") + props = tool["inputSchema"]["properties"] + assert props["trigger"]["type"] == "integer" + assert props["name"]["type"] == "string" + + # Third: verify each param describes the correct component + trigger_desc = props["trigger"].get("description", "") + assert "number of times that this element has been clicked on" in trigger_desc + name_desc = props["name"].get("description", "") + assert "The value of the input" in name_desc + + +def test_pattern_matching_callback(dash_duo): + """Pattern-matching dict IDs: tool works with correct params and results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="hello"), + dcc.Input(id={"type": "field", "index": 1}, value="world"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input({"type": "field", "index": 0}, "value"), + Input({"type": "field", "index": 1}, "value"), + ) + def combine(first, second): + return f"{first} {second}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "combine") + assert tool is not None + props = tool["inputSchema"]["properties"] + assert "first" in props + assert "second" in props + + # Verify initial output matches what the browser shows + dash_duo.wait_for_text_to_equal("#out", "hello world") + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"first": "hello", "second": "world"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "hello world" + + # Verify with different values + result = _mcp_call_tool( + dash_duo.server.url, + "combine", + {"first": "foo", "second": "bar"}, + ) + response = _get_response(result) + assert response["out"]["children"] == "foo bar" + + +def test_pattern_matching_with_all_wildcard(dash_duo): + """ALL wildcard: one callback receives values from all matching components.""" + from dash import ALL + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "input", "index": 0}, value="alpha"), + dcc.Input(id={"type": "input", "index": 1}, value="beta"), + html.Div(id="summary"), + ] + ) + + @app.callback( + Output("summary", "children"), + Input({"type": "input", "index": ALL}, "value"), + ) + def summarize(values): + return ", ".join(v for v in values if v) + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#summary", "alpha, beta") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "summarize") + assert tool is not None + + # Schema must describe values as an array of {id, property, value} objects + values_schema = tool["inputSchema"]["properties"]["values"] + assert ( + values_schema["type"] == "array" + ), f"ALL wildcard param should be typed as array, got: {values_schema}" + assert "items" in values_schema, "Array schema should include items definition" + items = values_schema["items"] + assert items["type"] == "object" + assert "id" in items["properties"] + assert "value" in items["properties"] + assert "Pattern-matching input (ALL)" in values_schema.get( + "description", "" + ), "ALL wildcard param description should explain the pattern-matching behavior" + + # MCP tool call with browser-like format: concrete IDs + values + result = _mcp_call_tool( + dash_duo.server.url, + "summarize", + { + "values": [ + { + "id": {"type": "input", "index": 0}, + "property": "value", + "value": "alpha", + }, + { + "id": {"type": "input", "index": 1}, + "property": "value", + "value": "beta", + }, + ] + }, + ) + response = _get_response(result) + assert response["summary"]["children"] == "alpha, beta" + + # Different values + result = _mcp_call_tool( + dash_duo.server.url, + "summarize", + { + "values": [ + { + "id": {"type": "input", "index": 0}, + "property": "value", + "value": "one", + }, + { + "id": {"type": "input", "index": 1}, + "property": "value", + "value": "two", + }, + ] + }, + ) + response = _get_response(result) + assert response["summary"]["children"] == "one, two" + + +def test_pattern_matching_mixed_outputs(dash_duo): + """Mixed outputs: one regular + one ALL wildcard in the same callback.""" + from dash import ALL + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id={"type": "field", "index": 0}, value="a"), + dcc.Input(id={"type": "field", "index": 1}, value="b"), + html.Div(id={"type": "echo", "index": 0}), + html.Div(id={"type": "echo", "index": 1}), + html.Div(id="total"), + ] + ) + + @app.callback( + Output({"type": "echo", "index": ALL}, "children"), + Output("total", "children"), + Input({"type": "field", "index": ALL}, "value"), + ) + def echo_and_total(values): + echoes = [f"Echo: {v}" for v in values] + total = f"Total: {len(values)} items" + return echoes, total + + dash_duo.start_server(app) + dash_duo.wait_for_text_to_equal("#total", "Total: 2 items") + + result = _mcp_call_tool( + dash_duo.server.url, + "echo_and_total", + { + "values": [ + { + "id": {"type": "field", "index": 0}, + "property": "value", + "value": "x", + }, + { + "id": {"type": "field", "index": 1}, + "property": "value", + "value": "y", + }, + ] + }, + ) + response = _get_response(result) + assert response["total"]["children"] == "Total: 2 items" + + +def test_pattern_matching_with_match_wildcard(dash_duo): + """MATCH wildcard: callback fires per-component with matching index. + + Based on https://dash.plotly.com/pattern-matching-callbacks + """ + from dash import MATCH + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + ["NYC", "MTL", "LA", "TOKYO"], + "NYC", + id={"type": "city-dd", "index": 0}, + ), + html.Div(id={"type": "city-out", "index": 0}), + dcc.Dropdown( + ["NYC", "MTL", "LA", "TOKYO"], + "LA", + id={"type": "city-dd", "index": 1}, + ), + html.Div(id={"type": "city-out", "index": 1}), + ] + ) + + @app.callback( + Output({"type": "city-out", "index": MATCH}, "children"), + Input({"type": "city-dd", "index": MATCH}, "value"), + ) + def show_city(value): + return f"Selected: {value}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") + assert tool is not None + + # Schema describes MATCH input + value_schema = tool["inputSchema"]["properties"]["value"] + assert "Pattern-matching input (MATCH)" in value_schema.get("description", "") + + # Call with concrete ID for index 0 (MATCH takes a single entry, not an array) + result = _mcp_call_tool( + dash_duo.server.url, + "show_city", + { + "value": { + "id": {"type": "city-dd", "index": 0}, + "property": "value", + "value": "MTL", + } + }, + ) + response = _get_response(result) + # Find the output key containing "city-out" (Dash may serialize dict IDs differently) + out_key = next(k for k in response if "city-out" in k) + assert response[out_key]["children"] == "Selected: MTL" + + +def test_pattern_matching_with_allsmaller_wildcard(dash_duo): + """ALLSMALLER wildcard: receives values from components with smaller index. + + Based on https://dash.plotly.com/pattern-matching-callbacks + """ + from dash import MATCH, ALLSMALLER + + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + ["France", "Germany", "Japan"], + "France", + id={"type": "country-dd", "index": 0}, + ), + html.Div(id={"type": "country-out", "index": 0}), + dcc.Dropdown( + ["France", "Germany", "Japan"], + "Germany", + id={"type": "country-dd", "index": 1}, + ), + html.Div(id={"type": "country-out", "index": 1}), + dcc.Dropdown( + ["France", "Germany", "Japan"], + "Japan", + id={"type": "country-dd", "index": 2}, + ), + html.Div(id={"type": "country-out", "index": 2}), + ] + ) + + @app.callback( + Output({"type": "country-out", "index": MATCH}, "children"), + Input({"type": "country-dd", "index": MATCH}, "value"), + Input({"type": "country-dd", "index": ALLSMALLER}, "value"), + ) + def show_countries(current, previous): + all_selected = [current] + list(reversed(previous)) + return f"All: {', '.join(all_selected)}" + + dash_duo.start_server(app) + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_countries") + assert tool is not None + + # Schema describes both MATCH and ALLSMALLER inputs + props = tool["inputSchema"]["properties"] + assert "Pattern-matching input (MATCH)" in props["current"].get("description", "") + assert "Pattern-matching input (ALLSMALLER)" in props["previous"].get( + "description", "" + ) + + # Call for index 2: MATCH is a single dict, ALLSMALLER is a list + result = _mcp_call_tool( + dash_duo.server.url, + "show_countries", + { + "current": { + "id": {"type": "country-dd", "index": 2}, + "property": "value", + "value": "Japan", + }, + "previous": [ + { + "id": {"type": "country-dd", "index": 0}, + "property": "value", + "value": "France", + }, + { + "id": {"type": "country-dd", "index": 1}, + "property": "value", + "value": "Germany", + }, + ], + }, + ) + response = _get_response(result) + out_key = next(k for k in response if "country-out" in k) + assert response[out_key]["children"] == "All: Japan, Germany, France" + + +def test_prevent_initial_call_uses_layout_default(dash_duo): + """prevent_initial_call=True: initial value stays as the layout default. + + The dropdown has value="original" in the layout. The callback has + prevent_initial_call=True so it doesn't run on page load. The MCP + tool description should show value: 'a' (layout default). + """ + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b", "c"], value="a"), + html.Div(id="out", children="not yet"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + prevent_initial_call=True, + ) + def update(val): + return f"Changed to: {val}" + + dash_duo.start_server(app) + # Browser shows layout default — callback hasn't fired + dash_duo.wait_for_text_to_equal("#out", "not yet") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "update") + val_desc = tool["inputSchema"]["properties"]["val"].get("description", "") + + # Tool description reflects layout default, not callback output + assert "value: 'a'" in val_desc + + +def test_initial_callback_overrides_layout_value(dash_duo): + """Initial callback overrides layout value in tool description. + + The city dropdown has value="default-city" in the layout. + update_city runs on page load (no prevent_initial_call) and + sets city.value to "Paris". The MCP tool should show "Paris" + as the default, not "default-city". + """ + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="country", options=["France", "Germany"], value="France"), + dcc.Dropdown(id="city", options=[], value="default-city"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("city", "options"), + Output("city", "value"), + Input("country", "value"), + ) + def update_city(country): + if country == "France": + return [{"label": "Paris", "value": "Paris"}], "Paris" + return [{"label": "Berlin", "value": "Berlin"}], "Berlin" + + @app.callback(Output("out", "children"), Input("city", "value")) + def show_city(city): + return f"City: {city}" + + dash_duo.start_server(app) + # Browser shows "Paris" — the initial callback overrode "default-city" + dash_duo.wait_for_text_to_equal("#out", "City: Paris") + + tool = _find_tool(_mcp_tools(dash_duo.server.url), "show_city") + city_desc = tool["inputSchema"]["properties"]["city"].get("description", "") + + # Tool description should show the post-initial-callback value + assert "value: 'Paris'" in city_desc + assert "default-city" not in city_desc + + +def test_callback_context_triggered_id(dash_duo): + """Callbacks using dash.ctx.triggered_id work via MCP. + + Based on https://dash.plotly.com/determining-which-callback-input-changed + """ + from dash import ctx + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Button 1", id="btn-1"), + html.Button("Button 2", id="btn-2"), + html.Button("Button 3", id="btn-3"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("btn-1", "n_clicks"), + Input("btn-2", "n_clicks"), + Input("btn-3", "n_clicks"), + ) + def display(btn1, btn2, btn3): + if not ctx.triggered_id: + return "No button clicked yet" + return f"Last clicked: {ctx.triggered_id}" + + dash_duo.start_server(app) + + # Browser initial state: no button clicked + dash_duo.wait_for_text_to_equal("#output", "No button clicked yet") + + # Tool should have all three button params + tool = _find_tool(_mcp_tools(dash_duo.server.url), "display") + props = tool["inputSchema"]["properties"] + assert "btn1" in props + assert "btn2" in props + assert "btn3" in props + + # Click btn-2 via MCP — ctx.triggered_id should be "btn-2" + result = _mcp_call_tool( + dash_duo.server.url, + "display", + {"btn1": None, "btn2": 1, "btn3": None}, + ) + response = _get_response(result) + assert response["output"]["children"] == "Last clicked: btn-2" + + # Click btn-3 via MCP + result = _mcp_call_tool( + dash_duo.server.url, + "display", + {"btn1": None, "btn2": None, "btn3": 5}, + ) + response = _get_response(result) + assert response["output"]["children"] == "Last clicked: btn-3" + + +def test_no_output_callback_does_not_crash_tools_list(dash_duo): + """A callback with no Output should not crash tools/list. + + No-output callbacks use set_props for side effects. They produce + a hash-only output_id with no dot separator. + """ + from dash import set_props + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button("Log", id="log-btn"), + dcc.Dropdown(id="picker", options=["a", "b"], value="a"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("log-btn", "n_clicks"), prevent_initial_call=True) + def log_click(n): + set_props("display", {"children": f"Logged {n} clicks"}) + + @app.callback(Output("display", "children"), Input("picker", "value")) + def show_selection(val): + return f"Selected: {val}" + + dash_duo.start_server(app) + + tools = _mcp_tools(dash_duo.server.url) + tool_names = [t["name"] for t in tools] + + # show_selection should appear as a tool + assert "show_selection" in tool_names + + # log_click has no declared output but uses set_props — still a valid tool + assert "log_click" in tool_names + + # Call log_click — sideUpdate should show the set_props effect + result = _mcp_call_tool( + dash_duo.server.url, + "log_click", + {"n": 3}, + ) + structured = result["result"]["structuredContent"] + assert "sideUpdate" in structured + assert structured["sideUpdate"]["display"]["children"] == "Logged 3 clicks" + + # get_dash_component shows show_selection as modifier (declared output). + # log_click uses set_props which bypasses the declarative graph — + # its effect is only visible via sideUpdate in tool call results. + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "display", "property": "children"}, + ) + prop_info = result["result"]["structuredContent"]["properties"]["children"] + assert "show_selection" in prop_info["modified_by_tool"] diff --git a/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py b/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py new file mode 100644 index 0000000000..4ad00641f8 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_duplicate_outputs.py @@ -0,0 +1,128 @@ +"""Integration test for duplicate callback outputs. + +Multiple callbacks can output to the same component.property +when using ``allow_duplicate=True``. The MCP server must handle +this correctly — both callbacks should appear as tools, and +calling either should work. +""" + +from dash import Dash, Input, Output, dcc, html + +from tests.integration.mcp.conftest import _mcp_call_tool, _mcp_tools + + +def _find_tool(tools, name): + return next((t for t in tools if t["name"] == name), None) + + +def _get_response(result): + return result["result"]["structuredContent"]["response"] + + +def test_duplicate_outputs_both_tools_listed(dash_duo): + """Both callbacks outputting to the same component appear as tools.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + first_tool = _find_tool(tools, "greet_by_first") + last_tool = _find_tool(tools, "greet_by_last") + + assert first_tool is not None, "greet_by_first should be listed" + assert last_tool is not None, "greet_by_last should be listed" + + +def test_duplicate_outputs_both_callable(dash_duo): + """Both callbacks can be called and produce correct results.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + result1 = _mcp_call_tool(dash_duo.server.url, "greet_by_first", {"first": "Alice"}) + assert _get_response(result1)["greeting"]["children"] == "Hello, Alice!" + + result2 = _mcp_call_tool(dash_duo.server.url, "greet_by_last", {"last": "Smith"}) + assert _get_response(result2)["greeting"]["children"] == "Hi, Smith!" + + +def test_duplicate_outputs_find_by_output_returns_primary(dash_duo): + """find_by_output returns the primary (non-duplicate) callback.""" + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="first-name", value="Jane"), + dcc.Input(id="last-name", value="Doe"), + html.Div(id="greeting"), + ] + ) + + @app.callback( + Output("greeting", "children"), + Input("first-name", "value"), + ) + def greet_by_first(first): + return f"Hello, {first}!" + + @app.callback( + Output("greeting", "children", allow_duplicate=True), + Input("last-name", "value"), + prevent_initial_call=True, + ) + def greet_by_last(last): + return f"Hi, {last}!" + + dash_duo.start_server(app) + + # Query the component — should reflect initial callback (greet_by_first) + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "greeting", "property": "children"}, + ) + structured = result["result"]["structuredContent"] + assert structured["properties"]["children"]["initial_value"] == "Hello, Jane!" diff --git a/tests/integration/mcp/primitives/tools/test_input_schemas.py b/tests/integration/mcp/primitives/tools/test_input_schemas.py new file mode 100644 index 0000000000..6ee3510ddd --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_input_schemas.py @@ -0,0 +1,66 @@ +""" +Integration tests for MCP tool schema generation. + +Starts a real Dash server via ``dash_duo`` and verifies that tools +are generated with correct inputSchema, descriptions, and labels. +""" + +from dash import Dash, Input, Output, dcc, html + +from tests.integration.mcp.conftest import _mcp_tools + + +def test_mcp_tool_with_label_and_date_picker_schema(dash_duo): + """Full assertion on a tool with an html.Label and DatePickerSingle constraints.""" + + # -- Test data: change these to update the test -- + label_text = "Departure Date" + component_id = "dp" + min_date = "2020-01-01" + max_date = "2025-12-31" + default_date = "2024-06-15" + func_name = "select_date" + param_name = "date" # function parameter name + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Label(label_text, htmlFor=component_id), + dcc.DatePickerSingle( + id=component_id, + min_date_allowed=min_date, + max_date_allowed=max_date, + date=default_date, + ), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input(component_id, "date")) + def select_date(date): + return f"Selected: {date}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + + # Find the callback tool + tool = next(t for t in tools if t["name"] not in ("get_dash_component",)) + + # -- Tool-level fields -- + assert func_name in tool["name"] + + # -- inputSchema structure -- + schema = tool["inputSchema"] + assert schema["type"] == "object" + assert param_name in schema["required"] + assert param_name in schema["properties"] + + # -- Property schema: type + format + description -- + prop = schema["properties"][param_name] + assert prop["type"] == "string" + assert prop["format"] == "date" + + # description includes all source values (label, constraints, default) + desc = prop["description"] + for expected in (label_text, min_date, max_date, default_date): + assert expected in desc, f"Expected {expected!r} in description: {desc!r}" diff --git a/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py b/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py new file mode 100644 index 0000000000..97472a16d7 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_tool_get_dash_component.py @@ -0,0 +1,54 @@ +"""Integration tests for the get_dash_component tool.""" + +from dash import Dash, dcc, html + +from tests.integration.mcp.conftest import _mcp_call_tool + +EXPECTED_DROPDOWN_OPTIONS = { + "component_id": "my-dropdown", + "component_type": "Dropdown", + "label": None, + "properties": { + "options": { + "initial_value": [ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + "modified_by_tool": [], + "input_to_tool": [], + }, + }, +} + + +def test_query_component_returns_structured_output(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown( + id="my-dropdown", + options=[ + {"label": "New York", "value": "NYC"}, + {"label": "Montreal", "value": "MTL"}, + ], + value="NYC", + ), + ] + ) + + dash_duo.start_server(app) + + result = _mcp_call_tool( + dash_duo.server.url, + "get_dash_component", + {"component_id": "my-dropdown", "property": "options"}, + ) + + assert "result" in result, f"Expected result in response: {result}" + structured = result["result"]["structuredContent"] + assert structured["component_id"] == EXPECTED_DROPDOWN_OPTIONS["component_id"] + assert structured["component_type"] == EXPECTED_DROPDOWN_OPTIONS["component_type"] + assert ( + structured["properties"]["options"] + == EXPECTED_DROPDOWN_OPTIONS["properties"]["options"] + ) diff --git a/tests/integration/mcp/primitives/tools/test_tools_list.py b/tests/integration/mcp/primitives/tools/test_tools_list.py new file mode 100644 index 0000000000..dc3d977146 --- /dev/null +++ b/tests/integration/mcp/primitives/tools/test_tools_list.py @@ -0,0 +1,118 @@ +"""Integration tests for tools/list — naming, dedup, and spec compliance.""" + +from dash import Dash, Input, Output, dcc, html + +from tests.integration.mcp.conftest import _mcp_tools + + +def test_tool_names_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"], value="a"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("dd", "value")) + def update(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + for param_name in tool.get("inputSchema", {}).get("properties", {}): + assert len(param_name) <= 64, f"Param name exceeds 64 chars: {param_name}" + + +def test_long_callback_ids_within_64_chars(dash_duo): + app = Dash(__name__) + long_id = "a" * 120 + app.layout = html.Div( + [ + dcc.Input(id=long_id, value="test"), + html.Div(id=f"{long_id}-output"), + ] + ) + + @app.callback(Output(f"{long_id}-output", "children"), Input(long_id, "value")) + def process(val): + return val + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_pattern_matching_ids_within_64_chars(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div( + [ + dcc.Input( + id={"type": "filter-input", "index": i, "category": "primary"}, + value=f"val-{i}", + ) + for i in range(3) + ] + ), + html.Div(id="pm-output"), + ] + ) + + @app.callback( + Output("pm-output", "children"), + Input({"type": "filter-input", "index": 0, "category": "primary"}, "value"), + ) + def filter_update(v0): + return str(v0) + + dash_duo.start_server(app) + for tool in _mcp_tools(dash_duo.server.url): + assert len(tool["name"]) <= 64, f"Tool name exceeds 64 chars: {tool['name']}" + + +def test_duplicate_func_names_produce_unique_tools(dash_duo): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd1", options=["a"], value="a"), + html.Div(id="dd1-output"), + dcc.Dropdown(id="dd2", options=["b"], value="b"), + html.Div(id="dd2-output"), + dcc.Dropdown(id="dd3", options=["c"], value="c"), + html.Div(id="dd3-output"), + ] + ) + + @app.callback(Output("dd1-output", "children"), Input("dd1", "value")) + def cb(value): + return f"first: {value}" + + @app.callback(Output("dd2-output", "children"), Input("dd2", "value")) + def cb(value): # noqa: F811 + return f"second: {value}" + + @app.callback(Output("dd3-output", "children"), Input("dd3", "value")) + def cb(value): # noqa: F811 + return f"third: {value}" + + dash_duo.start_server(app) + tools = _mcp_tools(dash_duo.server.url) + cb_tools = [t for t in tools if t["name"] not in ("get_dash_component",)] + tool_names = [t["name"] for t in cb_tools] + + assert ( + len(tool_names) == 3 + ), f"Expected 3 callback tools, got {len(tool_names)}: {tool_names}" + assert len(set(tool_names)) == 3, f"Tool names not unique: {tool_names}" + + +def test_builtin_tools_always_present(dash_duo): + app = Dash(__name__) + app.layout = html.Div(id="root") + + dash_duo.start_server(app) + tool_names = [t["name"] for t in _mcp_tools(dash_duo.server.url)] + assert "get_dash_component" in tool_names diff --git a/tests/integration/mcp/test_server.py b/tests/integration/mcp/test_server.py new file mode 100644 index 0000000000..7af88bfbff --- /dev/null +++ b/tests/integration/mcp/test_server.py @@ -0,0 +1,304 @@ +"""Integration tests for the MCP Streamable HTTP endpoint. + +These tests use Flask's test_client to exercise the HTTP transport layer +(POST/GET/DELETE at /_mcp), session management, content-type handling, +and route registration/configuration. +""" + +import json +import os + +from dash import Dash, Input, Output, html +from mcp.types import LATEST_PROTOCOL_VERSION + +MCP_PATH = "_mcp" + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(**kwargs): + """Create a minimal Dash app with a layout and one callback.""" + app = Dash(__name__, **kwargs) + app.layout = html.Div( + [ + html.Div(id="my-input"), + html.Div(id="my-output"), + ] + ) + + @app.callback(Output("my-output", "children"), Input("my-input", "children")) + def update_output(value): + """Test callback docstring.""" + return f"echo: {value}" + + return app + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestMCPEndpoint: + """Tests for the Streamable HTTP MCP endpoint at /_mcp.""" + + def test_post_initialize_creates_session(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert "mcp-session-id" in r.headers + data = json.loads(r.data) + assert data["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + + def test_post_without_session_auto_assigns(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert "mcp-session-id" in r.headers + data = json.loads(r.data) + assert "tools" in data["result"] + + def test_stale_session_error_includes_hint(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + { + "jsonrpc": "2.0", + "method": "tools/call", + "id": 1, + "params": {"name": "no_such_tool", "arguments": {}}, + } + ), + content_type="application/json", + headers={"mcp-session-id": "old-session-from-before-restart"}, + ) + assert r.status_code == 200 + data = json.loads(r.data) + assert "session was not recognised" in data["error"]["message"] + assert "tools/list" in data["error"]["message"] + + def test_post_with_valid_session(self): + app = _make_app() + client = app.server.test_client() + # Initialize to get session + r1 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = r1.headers["mcp-session-id"] + # Use session for tools/list + r2 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} + ), + content_type="application/json", + headers={"mcp-session-id": session_id}, + ) + assert r2.status_code == 200 + data = json.loads(r2.data) + assert "result" in data + assert "tools" in data["result"] + + def test_notification_returns_202(self): + app = _make_app() + client = app.server.test_client() + # Initialize to get session + r1 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = r1.headers["mcp-session-id"] + # Send notification (no id field) + r2 = client.post( + f"/{MCP_PATH}", + data=json.dumps({"jsonrpc": "2.0", "method": "notifications/initialized"}), + content_type="application/json", + headers={"mcp-session-id": session_id}, + ) + assert r2.status_code == 202 + + def test_delete_terminates_session(self): + app = _make_app() + client = app.server.test_client() + # Initialize + r1 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = r1.headers["mcp-session-id"] + # Delete + r2 = client.delete( + f"/{MCP_PATH}", + headers={"mcp-session-id": session_id}, + ) + assert r2.status_code == 204 + # Post-delete requests still succeed + r3 = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "tools/list", "id": 2, "params": {}} + ), + content_type="application/json", + headers={"mcp-session-id": session_id}, + ) + assert r3.status_code == 200 + + def test_delete_nonexistent_session_returns_404(self): + app = _make_app() + client = app.server.test_client() + r = client.delete( + f"/{MCP_PATH}", + headers={"mcp-session-id": "nonexistent"}, + ) + assert r.status_code == 404 + + def test_get_without_session_returns_404(self): + app = _make_app() + client = app.server.test_client() + r = client.get(f"/{MCP_PATH}") + assert r.status_code == 404 + + def test_get_with_stale_session_returns_404(self): + app = _make_app() + client = app.server.test_client() + r = client.get( + f"/{MCP_PATH}", + headers={"mcp-session-id": "nonexistent"}, + ) + assert r.status_code == 404 + + def test_get_returns_sse_stream(self): + app = _make_app() + client = app.server.test_client() + # First create a session via POST initialize + init = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + session_id = init.headers["mcp-session-id"] + # GET with valid session returns SSE stream + r = client.get( + f"/{MCP_PATH}", + headers={"mcp-session-id": session_id}, + ) + assert r.status_code == 200 + assert r.content_type == "text/event-stream" + assert r.headers.get("Cache-Control") == "no-cache" + + def test_post_rejects_wrong_content_type(self): + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data="not json", + content_type="text/plain", + ) + assert r.status_code == 415 + + def test_routes_not_registered_when_disabled(self): + app = _make_app(enable_mcp=False) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + # With MCP disabled, the route doesn't exist — response is HTML, not JSON + assert r.content_type != "application/json" + + def test_routes_respect_pathname_prefix(self): + app = _make_app(routes_pathname_prefix="/app/") + client = app.server.test_client() + + ok = client.post( + f"/app/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert ok.status_code == 200 + + miss = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert miss.status_code == 404 + + def test_enable_mcp_env_var_false(self): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app() + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.content_type != "application/json" + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old + + def test_constructor_overrides_env_var(self): + old = os.environ.get("DASH_MCP_ENABLED") + try: + os.environ["DASH_MCP_ENABLED"] = "false" + app = _make_app(enable_mcp=True) + client = app.server.test_client() + r = client.post( + f"/{MCP_PATH}", + data=json.dumps( + {"jsonrpc": "2.0", "method": "initialize", "id": 1, "params": {}} + ), + content_type="application/json", + ) + assert r.status_code == 200 + assert b"protocolVersion" in r.data + finally: + if old is None: + os.environ.pop("DASH_MCP_ENABLED", None) + else: + os.environ["DASH_MCP_ENABLED"] = old diff --git a/tests/unit/mcp/test_server.py b/tests/unit/mcp/test_server.py new file mode 100644 index 0000000000..93238faf19 --- /dev/null +++ b/tests/unit/mcp/test_server.py @@ -0,0 +1,92 @@ +"""Tests for MCP server (_server.py) — JSON-RPC message processing.""" + +from dash._get_app import app_context +from dash.mcp._server import _process_mcp_message +from mcp.types import LATEST_PROTOCOL_VERSION + +from tests.unit.mcp.conftest import _make_app, _setup_mcp + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool(app, tool_name, arguments=None, request_id=1): + return _mcp( + app, "tools/call", {"name": tool_name, "arguments": arguments or {}}, request_id + ) + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + result = _call_tool(app, tool_name, arguments) + structured = result["result"]["structuredContent"] + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +class TestProcessMCPMessage: + def test_initialize(self): + app = _make_app() + result = _mcp(app, "initialize") + + assert result is not None + assert result["id"] == 1 + assert result["jsonrpc"] == "2.0" + assert result["result"]["protocolVersion"] == LATEST_PROTOCOL_VERSION + assert "serverInfo" in result["result"] + + def test_initialize_advertises_list_changed(self): + app = _make_app() + result = _mcp(app, "initialize") + caps = result["result"]["capabilities"] + assert caps["tools"]["listChanged"] is True + + def test_tools_call(self): + app = _make_app() + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update_output" in t["name"]) + + result = _call_tool(app, tool_name, {"value": "hello"}, request_id=2) + + assert result is not None + assert result["id"] == 2 + assert _call_tool_output(app, tool_name, {"value": "hello"}) == "echo: hello" + + def test_tools_call_unknown_tool_returns_error(self): + app = _make_app() + result = _call_tool(app, "nonexistent_tool") + + assert result is not None + assert "error" in result + assert result["error"]["code"] == -32601 + + def test_unknown_method_returns_error(self): + app = _make_app() + result = _mcp(app, "unknown/method") + + assert result is not None + assert "error" in result + + def test_notification_returns_none(self): + app = _make_app() + data = {"jsonrpc": "2.0", "method": "notifications/initialized"} + with app.server.test_request_context(): + app_context.set(app) + result = _process_mcp_message(data) + assert result is None diff --git a/tests/unit/mcp/tools/test_run_callback.py b/tests/unit/mcp/tools/test_run_callback.py new file mode 100644 index 0000000000..00f4e5b7b1 --- /dev/null +++ b/tests/unit/mcp/tools/test_run_callback.py @@ -0,0 +1,246 @@ +"""Tests for callback dispatch execution via MCP tools.""" + +from dash import Dash, Input, Output, State, dcc, html +from dash.exceptions import PreventUpdate +from dash.mcp._server import _process_mcp_message + +from tests.unit.mcp.conftest import _setup_mcp + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _tools_list(app): + return _mcp(app, "tools/list")["result"]["tools"] + + +def _call_tool_structured(app, tool_name, arguments=None): + result = _mcp(app, "tools/call", {"name": tool_name, "arguments": arguments or {}}) + return result["result"]["structuredContent"] + + +def _call_tool_output( + app, tool_name, arguments=None, component_id=None, prop="children" +): + structured = _call_tool_structured(app, tool_name, arguments) + response = structured["response"] + if component_id is None: + component_id = next(iter(response)) + return response[component_id][prop] + + +class TestRunCallback: + def test_multi_output(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a", "b"], value="a"), + dcc.Dropdown(id="dd2"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("dd2", "options"), + Output("out", "children"), + Input("dd", "value"), + ) + def update(val): + return [{"label": val, "value": val}], f"selected: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + structured = _call_tool_structured(app, tool_name, {"val": "b"}) + assert structured["response"]["dd2"]["options"] == [ + {"label": "b", "value": "b"} + ] + assert structured["response"]["out"]["children"] == "selected: b" + + def test_omitted_kwargs_default_to_none(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Dropdown(id="dd", options=["a"]), + dcc.Input(id="inp"), + html.Div(id="out"), + ] + ) + + @app.callback( + Output("out", "children"), + Input("dd", "value"), + State("inp", "value"), + ) + def update(selected, text): + return f"{selected}-{text}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"selected": "a"}, "out") == "a-None" + + def test_no_output_callback(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Button(id="btn"), + html.Div(id="display"), + ] + ) + + @app.callback(Input("btn", "n_clicks")) + def server_cb(n): + from dash import set_props + + set_props("display", {"children": f"Clicked {n} times"}) + + tools = _tools_list(app) + tool_names = [t["name"] for t in tools] + assert "server_cb" in tool_names + + def test_prevent_update(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hello"), + html.Div(id="out"), + ] + ) + + @app.callback(Output("out", "children"), Input("inp", "value")) + def update(val): + if val == "block": + raise PreventUpdate + return f"got: {val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "update" in t["name"]) + assert _call_tool_output(app, tool_name, {"val": "test"}, "out") == "got: test" + + def test_with_state(self): + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="trigger"), + html.Div(id="store"), + html.Div(id="result"), + ] + ) + + @app.callback( + Output("result", "children"), + Input("trigger", "children"), + State("store", "children"), + ) + def with_state(trigger, store): + return f"{trigger}-{store}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + { + "trigger": "click", + "store": "data", + }, + "result", + ) + == "click-data" + ) + + def test_dict_inputs(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="x-input", value="hello"), + dcc.Input(id="y-input", value="world"), + html.Div(id="dict-out"), + ] + ) + + @app.callback( + Output("dict-out", "children"), + inputs={ + "x_val": Input("x-input", "value"), + "y_val": Input("y-input", "value"), + }, + ) + def combine(**kwargs): + return f"{kwargs['x_val']}-{kwargs['y_val']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "combine" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + { + "x_val": "foo", + "y_val": "bar", + }, + "dict-out", + ) + == "foo-bar" + ) + + def test_positional_inputs(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="a-input", value="A"), + html.Div(id="pos-out"), + ] + ) + + @app.callback(Output("pos-out", "children"), Input("a-input", "value")) + def echo(val): + return f"got:{val}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "echo" in t["name"]) + assert ( + _call_tool_output(app, tool_name, {"val": "test"}, "pos-out") == "got:test" + ) + + def test_dict_inputs_with_state(self): + app = Dash(__name__) + app.layout = html.Div( + [ + dcc.Input(id="inp", value="hi"), + html.Div(id="st", children="state-val"), + html.Div(id="ds-out"), + ] + ) + + @app.callback( + Output("ds-out", "children"), + inputs={"trigger": Input("inp", "value")}, + state={"kept": State("st", "children")}, + ) + def with_dict_state(**kwargs): + return f"{kwargs['trigger']}+{kwargs['kept']}" + + tools = _tools_list(app) + tool_name = next(t["name"] for t in tools if "with_dict_state" in t["name"]) + assert ( + _call_tool_output( + app, + tool_name, + { + "trigger": "hey", + "kept": "saved", + }, + "ds-out", + ) + == "hey+saved" + ) From 36c231d6efb7264c4a9ba6415e9adb1ac31befe1 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 8 Apr 2026 17:06:29 -0600 Subject: [PATCH 26/27] Implement background callback support --- dash/_callback.py | 5 + dash/mcp/_server.py | 8 +- dash/mcp/primitives/tools/__init__.py | 10 +- .../primitives/tools/descriptions/__init__.py | 2 + .../description_background_callbacks.py | 25 ++ dash/mcp/primitives/tools/results/__init__.py | 27 +- .../primitives/tools/tool_background_tasks.py | 97 ++++++ .../tools/tool_get_dash_component.py | 2 +- dash/mcp/primitives/tools/tools_callbacks.py | 18 +- dash/mcp/tasks/__init__.py | 5 + dash/mcp/tasks/tasks.py | 147 +++++++++ requirements/install.txt | 2 +- .../mcp/test_background_callbacks.py | 133 ++++++++ .../mcp/tools/test_background_callbacks.py | 300 ++++++++++++++++++ 14 files changed, 770 insertions(+), 11 deletions(-) create mode 100644 dash/mcp/primitives/tools/descriptions/description_background_callbacks.py create mode 100644 dash/mcp/primitives/tools/tool_background_tasks.py create mode 100644 dash/mcp/tasks/__init__.py create mode 100644 dash/mcp/tasks/tasks.py create mode 100644 tests/integration/mcp/test_background_callbacks.py create mode 100644 tests/unit/mcp/tools/test_background_callbacks.py diff --git a/dash/_callback.py b/dash/_callback.py index 5900bbe0fc..321123c422 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -1,6 +1,7 @@ import collections import hashlib import inspect +from datetime import datetime, timezone from functools import wraps from typing import Callable, Optional, Any, List, Tuple, Union, Dict @@ -421,6 +422,10 @@ def _setup_background_callback( ctx_value, ) + callback_manager.handle.set( + f"{cache_key}-created_at", datetime.now(timezone.utc).isoformat() + ) + data = { "cacheKey": cache_key, "job": job, diff --git a/dash/mcp/_server.py b/dash/mcp/_server.py index 1c6279290b..4029982535 100644 --- a/dash/mcp/_server.py +++ b/dash/mcp/_server.py @@ -42,6 +42,7 @@ list_tools, read_resource, ) +from dash.mcp.tasks import get_task, get_task_result, cancel_task from dash.mcp.primitives.tools.callback_adapter_collection import ( CallbackAdapterCollection, ) @@ -238,11 +239,16 @@ def _process_mcp_message(data: dict[str, Any]) -> dict[str, Any] | None: "initialize": _handle_initialize, "tools/list": lambda: list_tools(), "tools/call": lambda: call_tool( - params.get("name", ""), params.get("arguments", {}) + tool_name=params.get("name", ""), + arguments=params.get("arguments", {}), + task=params.get("task"), ), "resources/list": lambda: list_resources(), "resources/templates/list": lambda: list_resource_templates(), "resources/read": lambda: read_resource(params.get("uri", "")), + "tasks/get": lambda: get_task(task_id=params.get("taskId", "")), + "tasks/result": lambda: get_task_result(task_id=params.get("taskId", "")), + "tasks/cancel": lambda: cancel_task(task_id=params.get("taskId", "")), } try: diff --git a/dash/mcp/primitives/tools/__init__.py b/dash/mcp/primitives/tools/__init__.py index 64f89dc3d0..45aec6559b 100644 --- a/dash/mcp/primitives/tools/__init__.py +++ b/dash/mcp/primitives/tools/__init__.py @@ -16,10 +16,11 @@ from dash.mcp.types import ToolNotFoundError +from . import tool_background_tasks as _background_tasks from . import tool_get_dash_component as _get_component from . import tools_callbacks as _callbacks -_TOOL_MODULES = [_callbacks, _get_component] +_TOOL_MODULES = [_callbacks, _get_component, _background_tasks] def list_tools() -> ListToolsResult: @@ -30,12 +31,13 @@ def list_tools() -> ListToolsResult: return ListToolsResult(tools=tools) -def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: +def call_tool( + tool_name: str, arguments: dict[str, Any], task: dict | None = None +) -> CallToolResult: """Dispatch a tools/call request by tool name.""" for mod in _TOOL_MODULES: if tool_name in mod.get_tool_names(): - result = mod.call_tool(tool_name, arguments) - return result + return mod.call_tool(tool_name, arguments, task=task) raise ToolNotFoundError( f"Tool not found: {tool_name}." " The app's callbacks may have changed." diff --git a/dash/mcp/primitives/tools/descriptions/__init__.py b/dash/mcp/primitives/tools/descriptions/__init__.py index 29cc2840d0..3529e79803 100644 --- a/dash/mcp/primitives/tools/descriptions/__init__.py +++ b/dash/mcp/primitives/tools/descriptions/__init__.py @@ -14,6 +14,7 @@ from typing import TYPE_CHECKING +from .description_background_callbacks import background_callback_description from .description_docstring import callback_docstring from .description_outputs import output_summary @@ -23,6 +24,7 @@ _SOURCES = [ output_summary, callback_docstring, + background_callback_description, ] diff --git a/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py new file mode 100644 index 0000000000..13129f5640 --- /dev/null +++ b/dash/mcp/primitives/tools/descriptions/description_background_callbacks.py @@ -0,0 +1,25 @@ +"""Description for background (long-running) callbacks. + +Informs the LLM that the tool returns a taskId immediately +and must be polled via get_background_task_result. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter + + +def background_callback_description(adapter: CallbackAdapter) -> list[str]: + """Add async polling instructions for background callbacks.""" + if not adapter._cb_info.get("background"): + return [] + + return [ + "", + "This is a long-running background operation. " + "It returns a taskId immediately. " + "Call tool `get_background_task_result` with the taskId to poll for the result.", + ] diff --git a/dash/mcp/primitives/tools/results/__init__.py b/dash/mcp/primitives/tools/results/__init__.py index e2f91a67a8..96a8b35d0a 100644 --- a/dash/mcp/primitives/tools/results/__init__.py +++ b/dash/mcp/primitives/tools/results/__init__.py @@ -12,7 +12,7 @@ import json from typing import Any -from mcp.types import CallToolResult, TextContent +from mcp.types import CallToolResult, CreateTaskResult, TextContent from dash.types import CallbackDispatchResponse from dash.mcp.primitives.tools.callback_adapter import CallbackAdapter @@ -50,3 +50,28 @@ def format_callback_response( content=content, structuredContent=response, ) + + +def task_result_to_tool_result(create_task_result: CreateTaskResult) -> CallToolResult: + """Wrap a CreateTaskResult as a CallToolResult with polling instructions. + + MCP Tasks are not yet supported by LLM clients, so this converts the + task metadata into a tool response that guides the LLM to poll via + the get_background_task_result tool. + """ + task = create_task_result.task + return CallToolResult( + content=[TextContent( + type="text", + text=json.dumps({ + "taskId": task.taskId, + "status": task.status, + "pollInterval": task.pollInterval, + "message": ( + "This is a long-running background callback. " + "Call the get_background_task_result tool with this taskId " + "to poll for the result." + ), + }), + )], + ) diff --git a/dash/mcp/primitives/tools/tool_background_tasks.py b/dash/mcp/primitives/tools/tool_background_tasks.py new file mode 100644 index 0000000000..53fd620cb2 --- /dev/null +++ b/dash/mcp/primitives/tools/tool_background_tasks.py @@ -0,0 +1,97 @@ +"""Built-in tools for background callback task lifecycle. + +Thin wrappers around the spec-aligned core in dash.mcp.tasks. +Only registered when the app has background callbacks. +""" + +from __future__ import annotations + +from typing import Any + +from mcp.types import ( + CallToolResult, + CancelTaskRequestParams, + GetTaskRequestParams, + TextContent, + Tool, +) + +from dash import get_app +from dash.mcp.tasks import get_task, get_task_result, cancel_task + + +def _input_schema_from(params_type, description: str) -> dict: + """Derive a clean tool inputSchema from an MCP request params type.""" + schema = params_type.model_json_schema() + return { + "type": "object", + "properties": { + "taskId": { + **schema["properties"]["taskId"], + "description": description, + }, + }, + "required": schema["required"], + } + + +_TOOL_NAMES = {"get_background_task_result", "cancel_background_task"} + +_GET_RESULT_TOOL = Tool( + name="get_background_task_result", + description=( + "Poll for the result of a long-running background callback. " + "Pass the taskId returned by the original tool call. " + "If the task is still running, call this tool again. " + "If complete, returns the callback result." + ), + inputSchema=_input_schema_from( + GetTaskRequestParams, + "The taskId returned by the background callback tool.", + ), +) + +_CANCEL_TASK_TOOL = Tool( + name="cancel_background_task", + description="Cancel a running background callback.", + inputSchema=_input_schema_from( + CancelTaskRequestParams, + "The taskId of the background task to cancel.", + ), +) + + +def _has_background_callbacks() -> bool: + app = get_app() + return any( + cb_info.get("background") + for cb_info in app.callback_map.values() + ) + + +def get_tool_names() -> set[str]: + return _TOOL_NAMES if _has_background_callbacks() else set() + + +def get_tools() -> list[Tool]: + return [_GET_RESULT_TOOL, _CANCEL_TASK_TOOL] if _has_background_callbacks() else [] + + +def call_tool(tool_name: str, arguments: dict[str, Any], task: dict | None = None) -> CallToolResult: + task_id = arguments.get("taskId", "") + + if tool_name == "get_background_task_result": + task_status = get_task(task_id) + if task_status.status == "completed": + return get_task_result(task_id) + return CallToolResult( + content=[TextContent(type="text", text=task_status.model_dump_json())], + ) + + if tool_name == "cancel_background_task": + result = cancel_task(task_id) + return CallToolResult( + content=[TextContent(type="text", text=result.model_dump_json())], + ) + + raise ValueError(f"Unknown tool: {tool_name}") diff --git a/dash/mcp/primitives/tools/tool_get_dash_component.py b/dash/mcp/primitives/tools/tool_get_dash_component.py index 8584242333..5ba95a8a54 100644 --- a/dash/mcp/primitives/tools/tool_get_dash_component.py +++ b/dash/mcp/primitives/tools/tool_get_dash_component.py @@ -54,7 +54,7 @@ def _build_tool() -> Tool: ) -def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: +def call_tool(tool_name: str, arguments: dict[str, Any], task: dict | None = None) -> CallToolResult: comp_id = arguments.get("component_id", "") if not comp_id: raise ValueError("component_id is required") diff --git a/dash/mcp/primitives/tools/tools_callbacks.py b/dash/mcp/primitives/tools/tools_callbacks.py index ba08795d35..17bcde0ff7 100644 --- a/dash/mcp/primitives/tools/tools_callbacks.py +++ b/dash/mcp/primitives/tools/tools_callbacks.py @@ -7,12 +7,13 @@ from typing import Any -from mcp.types import CallToolResult, TextContent, Tool +from mcp.types import CallToolResult, CreateTaskResult, TextContent, Tool from dash import get_app from dash.mcp.types import CallbackExecutionError, ToolNotFoundError -from .results import format_callback_response +from .results import format_callback_response, task_result_to_tool_result +from dash.mcp.tasks import create_task def get_tool_names() -> set[str]: @@ -24,7 +25,9 @@ def get_tools() -> list[Tool]: return get_app().mcp_callback_map.as_mcp_tools() -def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: +def call_tool( + tool_name: str, arguments: dict[str, Any], task: dict | None = None +) -> CallToolResult | CreateTaskResult: """Execute a callback tool by name.""" from .callback_utils import run_callback @@ -37,6 +40,8 @@ def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: " Please call tools/list to refresh your tool list." ) + is_background = bool(cb._cb_info.get("background")) + try: dispatch_response = run_callback(cb, arguments) except CallbackExecutionError as e: @@ -44,4 +49,11 @@ def call_tool(tool_name: str, arguments: dict[str, Any]) -> CallToolResult: content=[TextContent(type="text", text=str(e))], isError=True, ) + + if is_background: + task_result = create_task(dispatch_response, cb) + if task is not None: + return task_result + return task_result_to_tool_result(task_result) + return format_callback_response(dispatch_response, cb) diff --git a/dash/mcp/tasks/__init__.py b/dash/mcp/tasks/__init__.py new file mode 100644 index 0000000000..8b78741d60 --- /dev/null +++ b/dash/mcp/tasks/__init__.py @@ -0,0 +1,5 @@ +"""MCP Tasks — lifecycle management for background callback execution.""" + +from .tasks import create_task, get_task, get_task_result, cancel_task + +__all__ = ["create_task", "get_task", "get_task_result", "cancel_task"] diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py new file mode 100644 index 0000000000..6217fc7b1e --- /dev/null +++ b/dash/mcp/tasks/tasks.py @@ -0,0 +1,147 @@ +"""Handler functions for MCP tasks/* methods.""" + +from __future__ import annotations + +import json +from datetime import datetime, timezone +from typing import Any + +from mcp.types import CreateTaskResult, GetTaskResult, Task + +from dash import get_app +from dash.mcp.primitives.tools.results import format_callback_response +from dash.mcp.types import MCPError + + +def parse_task_id(task_id: str) -> tuple[str, str, str]: + """Parse a taskId into (tool_name, job_id, cache_key).""" + return task_id.split(":", 2) + + +def _get_callback_manager(): + """Get the background callback manager from the app's callback_map.""" + app = get_app() + for cb_info in app.callback_map.values(): + manager = cb_info.get("manager") + if manager is not None: + return manager + return None + + +def create_task(dispatch_response: dict[str, Any], callback) -> CreateTaskResult: + """Create a Task from a background callback's initial dispatch response.""" + cache_key = dispatch_response["cacheKey"] + job_id = str(dispatch_response["job"]) + task_id = f"{callback.tool_name}:{job_id}:{cache_key}" + interval = callback._cb_info.get("background", {}).get("interval", 1000) + now = datetime.now(timezone.utc) + return CreateTaskResult( + task=Task( + taskId=task_id, + status="working", + createdAt=now, + lastUpdatedAt=now, + ttl=None, + pollInterval=interval, + ), + ) + + +def get_task(task_id: str) -> GetTaskResult: + """Handle tasks/get — derive status from the callback manager.""" + tool_name, job_id, cache_key = parse_task_id(task_id) + + manager = _get_callback_manager() + if manager is None: + return GetTaskResult( + taskId=task_id, + status="failed", + statusMessage="No background callback manager configured.", + createdAt=datetime.now(timezone.utc), + lastUpdatedAt=datetime.now(timezone.utc), + ttl=None, + ) + + running = manager.job_running(job_id) + progress = manager.get_progress(cache_key) + + if running: + status = "working" + elif manager.result_ready(cache_key): + status = "completed" + else: + status = "failed" + + adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) + interval = None + if adapter is not None: + interval = adapter._cb_info.get("background", {}).get("interval", 1000) + + now = datetime.now(timezone.utc) + return GetTaskResult( + taskId=task_id, + status=status, + statusMessage=str(progress) if progress else None, + createdAt=datetime.fromisoformat(manager.handle.get(f"{cache_key}-created_at") or now.isoformat()), + lastUpdatedAt=now, + ttl=manager.expire * 1000 if manager.expire else None, + pollInterval=interval, + ) + + +def get_task_result(task_id: str) -> Any: + """Handle tasks/result — retrieve and format the callback result. + + Mirrors the Dash renderer: calls get_result() which clears from cache. + """ + tool_name, job_id, cache_key = parse_task_id(task_id) + + manager = _get_callback_manager() + if manager is None: + raise MCPError("No background callback manager configured.") + + # Mirror the renderer: dispatch with cacheKey/job query params. + # The framework handles result retrieval, wrapping, and cleanup. + adapter = get_app().mcp_callback_map.find_by_tool_name(tool_name) + body = adapter.as_callback_body({}) + app = get_app() + + with app.server.test_request_context( + f"/_dash-update-component?cacheKey={cache_key}&job={job_id}", + method="POST", + data=json.dumps(body, default=str), + content_type="application/json", + ): + response = app.dispatch() + + response_data = json.loads(response.get_data(as_text=True)) + + if "response" not in response_data: + raise MCPError("Task result not ready. Poll tasks/get until status is 'completed'.") + + return format_callback_response(response_data, adapter) + + +def cancel_task(task_id: str) -> Any: + """Handle tasks/cancel — terminate the background job. + + Same underlying mechanism as the renderer's cancelJob query param. + """ + from mcp.types import CancelTaskResult + + tool_name, job_id, cache_key = parse_task_id(task_id) + + manager = _get_callback_manager() + if manager is None: + raise MCPError("No background callback manager configured.") + + manager.terminate_job(job_id) + + now = datetime.now(timezone.utc) + return CancelTaskResult( + taskId=task_id, + status="cancelled", + createdAt=datetime.fromisoformat(manager.handle.get(f"{cache_key}-created_at") or now.isoformat()), + lastUpdatedAt=now, + ttl=manager.expire * 1000 if manager.expire else None, + ) diff --git a/requirements/install.txt b/requirements/install.txt index b813a6ce55..caf1e34d0d 100644 --- a/requirements/install.txt +++ b/requirements/install.txt @@ -8,4 +8,4 @@ retrying nest-asyncio setuptools pydantic>=2.10 -mcp>=1.0.0; python_version>="3.10" +mcp>=1.23.0; python_version>="3.10" diff --git a/tests/integration/mcp/test_background_callbacks.py b/tests/integration/mcp/test_background_callbacks.py new file mode 100644 index 0000000000..75ae009aeb --- /dev/null +++ b/tests/integration/mcp/test_background_callbacks.py @@ -0,0 +1,133 @@ +"""Integration tests for background callback support via MCP.""" + +import json +import time + +import diskcache +from dash import Dash, Input, Output, html +from dash.background_callback.managers.diskcache_manager import DiskcacheManager + +MCP_PATH = "_mcp" + + +def _make_background_app(): + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + time.sleep(0.5) + return f"done: {value}" + + return app + + +def _post(client, method, params=None, session_id=None, request_id=1): + headers = {"Content-Type": "application/json"} + if session_id: + headers["mcp-session-id"] = session_id + return client.post( + f"/{MCP_PATH}", + data=json.dumps( + { + "jsonrpc": "2.0", + "method": method, + "id": request_id, + "params": params or {}, + } + ), + headers=headers, + ) + + +def _init_session(client): + r = _post(client, "initialize") + return r.headers["mcp-session-id"] + + +class TestBackgroundCallbackLifecycle: + """Full lifecycle: trigger → poll → get result, over HTTP.""" + + def test_trigger_poll_and_retrieve(self): + app = _make_background_app() + client = app.server.test_client() + sid = _init_session(client) + + # Trigger + r = _post( + client, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + }, + session_id=sid, + ) + assert r.status_code == 200 + data = json.loads(r.data) + task_info = json.loads(data["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + assert task_info["status"] == "working" + + # Poll — should be working initially + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + session_id=sid, + request_id=2, + ) + assert r.status_code == 200 + + # Wait for completion + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + 5 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # Get result + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + session_id=sid, + request_id=3, + ) + assert r.status_code == 200 + data = json.loads(r.data) + text = data["result"]["content"][0]["text"] + assert "done:" in text + + def test_background_tools_in_tools_list(self): + app = _make_background_app() + client = app.server.test_client() + sid = _init_session(client) + + r = _post(client, "tools/list", session_id=sid) + data = json.loads(r.data) + names = [t["name"] for t in data["result"]["tools"]] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + assert "slow_callback" in names diff --git a/tests/unit/mcp/tools/test_background_callbacks.py b/tests/unit/mcp/tools/test_background_callbacks.py new file mode 100644 index 0000000000..8dc99164f5 --- /dev/null +++ b/tests/unit/mcp/tools/test_background_callbacks.py @@ -0,0 +1,300 @@ +"""Tests for background callback support via MCP Tasks API.""" + +import time + +from dash import Dash, Input, Output, html +from dash._get_app import app_context +from dash.background_callback.managers.diskcache_manager import DiskcacheManager +from dash.mcp._server import _process_mcp_message +from dash.mcp.primitives.tools.callback_adapter_collection import ( + CallbackAdapterCollection, +) + + +def _setup_mcp(app): + app_context.set(app) + app.mcp_callback_map = CallbackAdapterCollection(app) + return app + + +def _msg(method, params=None, request_id=1): + d = {"jsonrpc": "2.0", "method": method, "id": request_id} + d["params"] = params if params is not None else {} + return d + + +def _mcp(app, method, params=None, request_id=1): + with app.server.test_request_context(): + _setup_mcp(app) + return _process_mcp_message(_msg(method, params, request_id)) + + +def _make_background_app(): + import diskcache + + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def slow_callback(value): + """A background callback.""" + time.sleep(0.3) + return f"done: {value}" + + return app + + +class TestCancelBackgroundTaskTool: + """cancel_background_task tool wrapper.""" + + def test_cancel_via_tool(self): + import json + + app = _make_background_app() + trigger = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + }, + ) + task_id = json.loads(trigger["result"]["content"][0]["text"])["taskId"] + + cancel = _mcp( + app, + "tools/call", + { + "name": "cancel_background_task", + "arguments": {"taskId": task_id}, + }, + ) + assert cancel["result"].get("isError") is not True + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +class TestBackgroundToolRegistration: + """Background task tools only appear when app has background callbacks.""" + + def test_present_with_background_callbacks(self): + app = _make_background_app() + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" in names + assert "cancel_background_task" in names + + def test_absent_without_background_callbacks(self): + app = Dash(__name__) + app.layout = html.Div([html.Div(id="in"), html.Div(id="out")]) + + @app.callback(Output("out", "children"), Input("in", "children")) + def normal_cb(v): + return v + + tools = _mcp(app, "tools/list")["result"]["tools"] + names = [t["name"] for t in tools] + assert "get_background_task_result" not in names + assert "cancel_background_task" not in names + + +class TestGetBackgroundTaskResult: + """get_background_task_result tool: poll and retrieve results.""" + + def _trigger(self, app): + import json + + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + }, + ) + return json.loads(result["result"]["content"][0]["text"])["taskId"] + + def test_returns_working_while_running(self): + app = _make_background_app() + task_id = self._trigger(app) + poll = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = poll["result"]["content"][0]["text"] + assert "working" in text.lower() + + def test_returns_result_when_complete(self): + app = _make_background_app() + task_id = self._trigger(app) + _, job_id, _ = task_id.split(":", 2) + + # Wait for completion + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + 3 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + result = _mcp( + app, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + ) + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +class TestBackgroundCallbackTrigger: + """Calling a background callback tool returns taskId immediately.""" + + def test_returns_task_id(self): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + }, + ) + text = result["result"]["content"][0]["text"] + assert "taskId" in text + assert "slow_callback:" in text + + +class TestTasksGet: + """tasks/get derives status from the callback manager.""" + + def test_working_status_while_running(self): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + # Immediately poll — job should still be running + get_result = _mcp(app, "tasks/get", {"taskId": task_id}) + assert get_result["result"]["status"] == "working" + assert get_result["result"]["taskId"] == task_id + + +class TestTasksResult: + """tasks/result retrieves and formats the callback result.""" + + def test_returns_formatted_result_when_complete(self): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + tool_name, job_id, cache_key = task_id.split(":", 2) + + # Wait for the background job to finish + manager = app.callback_map["output.children"]["manager"] + deadline = time.time() + 3 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # Fetch the result + result = _mcp(app, "tasks/result", {"taskId": task_id}) + assert "content" in result["result"] + text = result["result"]["content"][0]["text"] + assert "done:" in text + + +class TestTasksCancel: + """tasks/cancel terminates the background job.""" + + def test_cancel_terminates_job(self): + app = _make_background_app() + create_result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = create_result["result"]["task"]["taskId"] + + cancel_result = _mcp(app, "tasks/cancel", {"taskId": task_id}) + assert "error" not in cancel_result + + _, job_id, _ = task_id.split(":", 2) + manager = app.callback_map["output.children"]["manager"] + assert not manager.job_running(job_id) + + +class TestBackgroundCallbackWithTask: + """When tools/call includes task metadata, return CreateTaskResult.""" + + def test_returns_create_task_result(self): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task = result["result"]["task"] + assert task["status"] == "working" + assert "taskId" in task + assert "pollInterval" in task + + def test_task_id_encodes_tool_name_job_id_cache_key(self): + app = _make_background_app() + result = _mcp( + app, + "tools/call", + { + "name": "slow_callback", + "arguments": {"value": "hello"}, + "task": {"ttl": 60000}, + }, + ) + task_id = result["result"]["task"]["taskId"] + tool_name, job_id, cache_key = task_id.split(":", 2) + assert tool_name == "slow_callback" + assert len(cache_key) == 64 # SHA256 hex From 4ead0dbee8ee78aeec576483a40d4dcea4d93308 Mon Sep 17 00:00:00 2001 From: Adrian Borrmann Date: Wed, 8 Apr 2026 17:43:51 -0600 Subject: [PATCH 27/27] Ensure that background callback expiry is communicated in MCP tool --- dash/_callback.py | 4 +- dash/mcp/tasks/tasks.py | 5 +- .../mcp/test_background_callbacks.py | 183 +++++++++++++++++- 3 files changed, 189 insertions(+), 3 deletions(-) diff --git a/dash/_callback.py b/dash/_callback.py index 321123c422..0407d5359e 100644 --- a/dash/_callback.py +++ b/dash/_callback.py @@ -423,7 +423,9 @@ def _setup_background_callback( ) callback_manager.handle.set( - f"{cache_key}-created_at", datetime.now(timezone.utc).isoformat() + f"{cache_key}-created_at", + datetime.now(timezone.utc).isoformat(), + expire=callback_manager.expire, ) data = { diff --git a/dash/mcp/tasks/tasks.py b/dash/mcp/tasks/tasks.py index 6217fc7b1e..877b84795b 100644 --- a/dash/mcp/tasks/tasks.py +++ b/dash/mcp/tasks/tasks.py @@ -138,10 +138,13 @@ def cancel_task(task_id: str) -> Any: manager.terminate_job(job_id) now = datetime.now(timezone.utc) + created_at = manager.handle.get(f"{cache_key}-created_at") + manager.handle.delete(f"{cache_key}-created_at") + return CancelTaskResult( taskId=task_id, status="cancelled", - createdAt=datetime.fromisoformat(manager.handle.get(f"{cache_key}-created_at") or now.isoformat()), + createdAt=datetime.fromisoformat(created_at) if created_at else now, lastUpdatedAt=now, ttl=manager.expire * 1000 if manager.expire else None, ) diff --git a/tests/integration/mcp/test_background_callbacks.py b/tests/integration/mcp/test_background_callbacks.py index 75ae009aeb..e0e3dead53 100644 --- a/tests/integration/mcp/test_background_callbacks.py +++ b/tests/integration/mcp/test_background_callbacks.py @@ -2,6 +2,7 @@ import json import time +from datetime import datetime import diskcache from dash import Dash, Input, Output, html @@ -82,7 +83,14 @@ def test_trigger_poll_and_retrieve(self): task_id = task_info["taskId"] assert task_info["status"] == "working" - # Poll — should be working initially + # Read createdAt from the callback manager directly + _, _, cache_key = task_id.split(":", 2) + stored_created_at = app.callback_map["output.children"]["manager"].handle.get( + f"{cache_key}-created_at" + ) + assert stored_created_at is not None + + # Poll — should be working, with createdAt matching the stored value r = _post( client, "tools/call", @@ -94,6 +102,10 @@ def test_trigger_poll_and_retrieve(self): request_id=2, ) assert r.status_code == 200 + poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + assert datetime.fromisoformat(poll_data["createdAt"]) == datetime.fromisoformat( + stored_created_at + ) # Wait for completion _, job_id, _ = task_id.split(":", 2) @@ -120,6 +132,175 @@ def test_trigger_poll_and_retrieve(self): text = data["result"]["content"][0]["text"] assert "done:" in text + def test_result_expires(self): + """Result and createdAt are available until the cache expires.""" + cache = diskcache.Cache() + manager = DiskcacheManager(cache, cache_by=[lambda: "fixed"], expire=2) + + app = Dash(__name__) + app.layout = html.Div([html.Div(id="input"), html.Div(id="output")]) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + background=True, + manager=manager, + ) + def fast_cb(value): + return f"done: {value}" + + client = app.server.test_client() + sid = _init_session(client) + + # Trigger + r = _post( + client, + "tools/call", + { + "name": "fast_cb", + "arguments": {"value": "hi"}, + }, + session_id=sid, + ) + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + _, job_id, cache_key = task_id.split(":", 2) + + # Wait for job to finish + deadline = time.time() + 3 + while time.time() < deadline: + if not manager.job_running(job_id): + break + time.sleep(0.1) + + # First retrieval — result and createdAt available + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + session_id=sid, + request_id=2, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + assert "done:" in text + created_at = manager.handle.get(f"{cache_key}-created_at") + assert created_at is not None + + # Second retrieval — still available (cache_by keeps it) + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + session_id=sid, + request_id=3, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + assert "done:" in text + assert manager.handle.get(f"{cache_key}-created_at") == created_at + + # Wait for expiry + time.sleep(2.5) + + # After expiry — tool reports failure, createdAt is fresh (stored value gone) + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + session_id=sid, + request_id=4, + ) + poll_data = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + assert poll_data["status"] == "failed" + assert datetime.fromisoformat(poll_data["createdAt"]) > datetime.fromisoformat( + created_at + ) + + def test_progress_in_poll_response(self): + """Progress reported via set_progress appears in poll statusMessage.""" + cache = diskcache.Cache() + manager = DiskcacheManager(cache) + + app = Dash(__name__) + app.layout = html.Div( + [ + html.Div(id="input"), + html.Div(id="status"), + html.Div(id="output"), + ] + ) + + @app.callback( + Output("output", "children"), + Input("input", "children"), + progress=Output("status", "children"), + background=True, + manager=manager, + interval=200, + ) + def progress_cb(set_progress, value): + for i in range(10): + set_progress(f"Step {i + 1} of 10") + time.sleep(0.2) + return f"done: {value}" + + client = app.server.test_client() + sid = _init_session(client) + + # Trigger + r = _post( + client, + "tools/call", + { + "name": "progress_cb", + "arguments": {"value": "hi"}, + }, + session_id=sid, + ) + task_info = json.loads(json.loads(r.data)["result"]["content"][0]["text"]) + task_id = task_info["taskId"] + + # Poll and collect all progress messages + import re + + progress_pattern = re.compile(r"Step \d+ of 10") + progress_messages = [] + deadline = time.time() + 10 + while time.time() < deadline: + r = _post( + client, + "tools/call", + { + "name": "get_background_task_result", + "arguments": {"taskId": task_id}, + }, + session_id=sid, + request_id=2, + ) + text = json.loads(r.data)["result"]["content"][0]["text"] + try: + poll_data = json.loads(text) + msg = poll_data.get("statusMessage") + if msg is not None: + progress_messages.append(msg) + if poll_data.get("status") == "completed": + break + except (json.JSONDecodeError, KeyError): + break + time.sleep(0.3) + + assert len(progress_messages) > 0, "Expected progress updates during polling" + for msg in progress_messages: + assert progress_pattern.search(msg), f"Unexpected progress format: {msg}" + def test_background_tools_in_tools_list(self): app = _make_background_app() client = app.server.test_client()