From b3eac058bdb25b52daaaace37dff98ad7bdff310 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 15 May 2026 08:28:56 +0000 Subject: [PATCH] fix(a2a): Support to_a2a(Workflow) --- .../adk/a2a/utils/agent_card_builder.py | 96 ++++++++--- src/google/adk/a2a/utils/agent_to_a2a.py | 10 +- .../a2a/utils/test_agent_card_builder.py | 159 +++++++++++++++++- .../unittests/a2a/utils/test_agent_to_a2a.py | 56 ++++-- 4 files changed, 278 insertions(+), 43 deletions(-) diff --git a/src/google/adk/a2a/utils/agent_card_builder.py b/src/google/adk/a2a/utils/agent_card_builder.py index 1e8cecad79..733a5c8d2d 100644 --- a/src/google/adk/a2a/utils/agent_card_builder.py +++ b/src/google/adk/a2a/utils/agent_card_builder.py @@ -32,6 +32,9 @@ from ...agents.parallel_agent import ParallelAgent from ...agents.sequential_agent import SequentialAgent from ...tools.example_tool import ExampleTool +from ...workflow._base_node import BaseNode +from ...workflow._base_node import START +from ...workflow._workflow import Workflow from ..experimental import a2a_experimental logger = logging.getLogger('google_adk.' + __name__) @@ -39,17 +42,17 @@ @a2a_experimental class AgentCardBuilder: - """Builder class for creating agent cards from ADK agents. + """Builder class for creating agent cards from ADK agents or workflows. - This class provides functionality to convert ADK agents into A2A agent cards, - including extracting skills, capabilities, and metadata from various agent - types. + This class provides functionality to convert an ADK BaseAgent (e.g. LlmAgent) + or a Workflow into an A2A agent card, including extracting skills, + capabilities, and metadata. """ def __init__( self, *, - agent: BaseAgent, + agent: BaseAgent | Workflow, rpc_url: Optional[str] = None, capabilities: Optional[AgentCapabilities] = None, doc_url: Optional[str] = None, @@ -59,6 +62,11 @@ def __init__( ): if not agent: raise ValueError('Agent cannot be None or empty.') + if not isinstance(agent, (BaseAgent, Workflow)): + raise TypeError( + 'AgentCardBuilder requires a BaseAgent or Workflow, got ' + f'{type(agent).__name__}.' + ) self._agent = agent self._rpc_url = rpc_url or 'http://localhost:80/a2a' @@ -96,8 +104,17 @@ async def build(self) -> AgentCard: # Module-level helper functions -async def _build_primary_skills(agent: BaseAgent) -> List[AgentSkill]: - """Build skills for any agent type.""" +def _iter_child_nodes(agent: BaseNode) -> List[BaseNode]: + """Returns the immediate child nodes of an agent or a workflow.""" + if isinstance(agent, BaseAgent): + return list(agent.sub_agents) + if isinstance(agent, Workflow) and agent.graph is not None: + return [n for n in agent.graph.nodes if n.name != START.name] + return [] + + +async def _build_primary_skills(agent: BaseNode) -> List[AgentSkill]: + """Build skills for any node type.""" if isinstance(agent, LlmAgent): return await _build_llm_agent_skills(agent) else: @@ -140,10 +157,10 @@ async def _build_llm_agent_skills(agent: LlmAgent) -> List[AgentSkill]: return skills -async def _build_sub_agent_skills(agent: BaseAgent) -> List[AgentSkill]: - """Build skills for all sub-agents.""" +async def _build_sub_agent_skills(agent: BaseNode) -> List[AgentSkill]: + """Build skills for all child nodes (sub-agents or workflow nodes).""" sub_agent_skills = [] - for sub_agent in agent.sub_agents: + for sub_agent in _iter_child_nodes(agent): try: sub_skills = await _build_primary_skills(sub_agent) for skill in sub_skills: @@ -225,8 +242,8 @@ def _build_code_executor_skill(agent: LlmAgent) -> AgentSkill: ) -async def _build_non_llm_agent_skills(agent: BaseAgent) -> List[AgentSkill]: - """Build skills for non-LLM agents.""" +async def _build_non_llm_agent_skills(agent: BaseNode) -> List[AgentSkill]: + """Build skills for non-LLM agents and workflow nodes.""" skills = [] # 1. Agent skill (main agent skill) @@ -249,8 +266,8 @@ async def _build_non_llm_agent_skills(agent: BaseAgent) -> List[AgentSkill]: ) ) - # 2. Sub-agent orchestration skill (for agents with sub-agents) - if agent.sub_agents: + # 2. Orchestration skill (for agents/workflows with child nodes) + if _iter_child_nodes(agent): orchestration_skill = _build_orchestration_skill(agent, agent_type) if orchestration_skill: skills.append(orchestration_skill) @@ -259,11 +276,11 @@ async def _build_non_llm_agent_skills(agent: BaseAgent) -> List[AgentSkill]: def _build_orchestration_skill( - agent: BaseAgent, agent_type: str + agent: BaseNode, agent_type: str ) -> Optional[AgentSkill]: - """Build orchestration skill for agents with sub-agents.""" + """Build orchestration skill for agents/workflows with child nodes.""" sub_agent_descriptions = [] - for sub_agent in agent.sub_agents: + for sub_agent in _iter_child_nodes(agent): description = sub_agent.description or 'No description' sub_agent_descriptions.append(f'{sub_agent.name}: {description}') @@ -281,7 +298,7 @@ def _build_orchestration_skill( ) -def _get_agent_type(agent: BaseAgent) -> str: +def _get_agent_type(agent: BaseNode) -> str: """Get the agent type for tagging.""" if isinstance(agent, LlmAgent): return 'llm' @@ -291,21 +308,23 @@ def _get_agent_type(agent: BaseAgent) -> str: return 'parallel_workflow' elif isinstance(agent, LoopAgent): return 'loop_workflow' + elif isinstance(agent, Workflow): + return 'graph_workflow' else: return 'custom_agent' -def _get_agent_skill_name(agent: BaseAgent) -> str: +def _get_agent_skill_name(agent: BaseNode) -> str: """Get the skill name based on agent type.""" if isinstance(agent, LlmAgent): return 'model' - elif isinstance(agent, (SequentialAgent, ParallelAgent, LoopAgent)): + elif isinstance(agent, (SequentialAgent, ParallelAgent, LoopAgent, Workflow)): return 'workflow' else: return 'custom' -def _build_agent_description(agent: BaseAgent) -> str: +def _build_agent_description(agent: BaseNode) -> str: """Build agent description from agent.description and workflow-specific descriptions.""" description_parts = [] @@ -382,9 +401,9 @@ def _replace_pronouns(text: str) -> str: ) -def _get_workflow_description(agent: BaseAgent) -> Optional[str]: - """Get workflow-specific description for non-LLM agents.""" - if not agent.sub_agents: +def _get_workflow_description(agent: BaseNode) -> Optional[str]: + """Get workflow-specific description for non-LLM agents and workflows.""" + if not _iter_child_nodes(agent): return None if isinstance(agent, SequentialAgent): @@ -393,6 +412,8 @@ def _get_workflow_description(agent: BaseAgent) -> Optional[str]: return _build_parallel_description(agent) elif isinstance(agent, LoopAgent): return _build_loop_description(agent) + elif isinstance(agent, Workflow): + return _build_graph_workflow_description(agent) return None @@ -448,13 +469,32 @@ def _build_loop_description(agent: LoopAgent) -> str: ) -def _get_default_description(agent: BaseAgent) -> str: +def _build_graph_workflow_description(workflow: Workflow) -> str: + """Build description for a graph-based Workflow.""" + child_nodes = _iter_child_nodes(workflow) + descriptions = [] + for node in child_nodes: + node_description = ( + node.description.rstrip('.') + if node.description + else f'execute the {node.name} node' + ) + descriptions.append(f'{node.name}: {node_description}') + return ( + 'This workflow orchestrates the following nodes: ' + + '; '.join(descriptions) + + '.' + ) + + +def _get_default_description(agent: BaseNode) -> str: """Get default description based on agent type.""" agent_type_descriptions = { LlmAgent: 'An LLM-based agent', SequentialAgent: 'A sequential workflow agent', ParallelAgent: 'A parallel workflow agent', LoopAgent: 'A loop workflow agent', + Workflow: 'A graph-based workflow agent', } for agent_type, description in agent_type_descriptions.items(): @@ -492,7 +532,7 @@ def _extract_inputs_from_examples(examples: Optional[list[dict]]) -> list[str]: async def _extract_examples_from_agent( - agent: BaseAgent, + agent: BaseNode, ) -> Optional[List[Dict]]: """Extract examples from example_tool if configured; otherwise, from agent instruction.""" if not isinstance(agent, LlmAgent): @@ -558,7 +598,7 @@ def _extract_examples_from_instruction( return examples if examples else None -def _get_input_modes(agent: BaseAgent) -> Optional[List[str]]: +def _get_input_modes(agent: BaseNode) -> Optional[List[str]]: """Get input modes based on agent model.""" if not isinstance(agent, LlmAgent): return None @@ -568,7 +608,7 @@ def _get_input_modes(agent: BaseAgent) -> Optional[List[str]]: return None -def _get_output_modes(agent: BaseAgent) -> Optional[List[str]]: +def _get_output_modes(agent: BaseNode) -> Optional[List[str]]: """Get output modes from Agent.generate_content_config.response_modalities.""" if not isinstance(agent, LlmAgent): return None diff --git a/src/google/adk/a2a/utils/agent_to_a2a.py b/src/google/adk/a2a/utils/agent_to_a2a.py index 3e8ed461e2..3cef1c8215 100644 --- a/src/google/adk/a2a/utils/agent_to_a2a.py +++ b/src/google/adk/a2a/utils/agent_to_a2a.py @@ -35,6 +35,7 @@ from ...memory.in_memory_memory_service import InMemoryMemoryService from ...runners import Runner from ...sessions.in_memory_session_service import InMemorySessionService +from ...workflow._workflow import Workflow from ..executor.a2a_agent_executor import A2aAgentExecutor from ..experimental import a2a_experimental from .agent_card_builder import AgentCardBuilder @@ -77,7 +78,7 @@ def _load_agent_card( @a2a_experimental def to_a2a( - agent: BaseAgent, + agent: Union[BaseAgent, Workflow], *, host: str = "localhost", port: int = 8000, @@ -87,10 +88,11 @@ def to_a2a( runner: Optional[Runner] = None, lifespan: Optional[Callable[[Starlette], AsyncIterator[None]]] = None, ) -> Starlette: - """Convert an ADK agent to a A2A Starlette application. + """Convert an ADK BaseAgent or Workflow to an A2A Starlette application. Args: - agent: The ADK agent to convert + agent: The ADK BaseAgent (e.g. LlmAgent) or Workflow to + convert. host: The host for the A2A RPC URL (default: "localhost") port: The port for the A2A RPC URL (default: 8000) protocol: The protocol for the A2A RPC URL (default: "http") @@ -106,7 +108,7 @@ def to_a2a( events. Use this to run startup/shutdown logic (e.g. initializing database connections or loading resources). The context manager receives the Starlette app instance and can set state on - ``app.state``. + app.state. Returns: A Starlette application that can be run with uvicorn diff --git a/tests/unittests/a2a/utils/test_agent_card_builder.py b/tests/unittests/a2a/utils/test_agent_card_builder.py index 8549c16ec8..c979ad5307 100644 --- a/tests/unittests/a2a/utils/test_agent_card_builder.py +++ b/tests/unittests/a2a/utils/test_agent_card_builder.py @@ -42,8 +42,11 @@ from google.adk.agents.loop_agent import LoopAgent from google.adk.agents.parallel_agent import ParallelAgent from google.adk.agents.sequential_agent import SequentialAgent -from google.adk.examples import Example from google.adk.tools.example_tool import ExampleTool +from google.adk.workflow import FunctionNode +from google.adk.workflow import START +from google.adk.workflow import Workflow +from pydantic import BaseModel import pytest @@ -112,6 +115,31 @@ def test_init_with_empty_agent(self): with pytest.raises(ValueError, match="Agent cannot be None or empty."): AgentCardBuilder(agent=mock_agent) + def test_init_rejects_function_node(self): + """__init__ raises TypeError for a bare FunctionNode. + + FunctionNode is a BaseNode but is intended for use inside a + Workflow, not as a standalone A2A root. Without this guard the + builder would silently produce a degenerate "custom agent" card. + """ + + async def my_fn(node_input): + return f"echo: {node_input}" + + fn_node = FunctionNode(func=my_fn, name="echo_fn") + + with pytest.raises( + TypeError, match="requires a BaseAgent or Workflow, got FunctionNode" + ): + AgentCardBuilder(agent=fn_node) + + def test_init_rejects_arbitrary_object(self): + """__init__ raises TypeError for non-BaseNode objects.""" + with pytest.raises( + TypeError, match="requires a BaseAgent or Workflow, got str" + ): + AgentCardBuilder(agent="not an agent") + @patch("google.adk.a2a.utils.agent_card_builder._build_primary_skills") @patch("google.adk.a2a.utils.agent_card_builder._build_sub_agent_skills") async def test_build_success( @@ -211,6 +239,89 @@ async def test_build_raises_runtime_error_on_failure( ): await builder.build() + async def test_build_succeeds_for_llm_agent(self): + """AgentCardBuilder.build succeeds for a standalone LlmAgent. + + Regression coverage for the type-narrowing to BaseAgent | Workflow: + LlmAgent (a BaseAgent subclass) must continue to work end-to-end. + """ + agent = LlmAgent( + name="writer", + model="gemini-2.5-flash", + description="Writes a short reply.", + instruction="Write a short reply.", + ) + builder = AgentCardBuilder(agent=agent, rpc_url="http://localhost:8000/") + + card = await builder.build() + + assert isinstance(card, AgentCard) + assert card.name == "writer" + assert card.description == "Writes a short reply." + skill_ids = [skill.id for skill in card.skills] + assert "writer" in skill_ids + + async def test_build_succeeds_for_workflow_with_llm_agent_node(self): + """AgentCardBuilder.build succeeds for a Workflow (no sub_agents).""" + writer = LlmAgent( + name="writer", + model="gemini-2.5-flash", + description="Writes the reply.", + instruction="Write a short reply.", + ) + workflow = Workflow( + name="pipe", + description="A simple pipeline.", + edges=[(START, writer)], + ) + builder = AgentCardBuilder(agent=workflow, rpc_url="http://localhost:8000/") + + card = await builder.build() + + assert isinstance(card, AgentCard) + assert card.name == "pipe" + skill_ids = [skill.id for skill in card.skills] + assert "pipe" in skill_ids # primary workflow skill + assert any("writer" in sid for sid in skill_ids) # child node skill + + async def test_build_succeeds_for_workflow_with_output_schema_node(self): + """AgentCardBuilder.build succeeds for a Workflow whose LlmAgent has output_schema. + + Mirrors the exact repro from + https://github.com/google/adk-python/issues/5487. + """ + + class _Out(BaseModel): + text: str + + writer = LlmAgent( + name="writer", + model="gemini-2.5-flash", + instruction="Write a short reply.", + output_schema=_Out, + ) + workflow = Workflow(name="pipe", edges=[(START, writer)]) + builder = AgentCardBuilder(agent=workflow, rpc_url="http://localhost:8000/") + + card = await builder.build() + + assert card.name == "pipe" + primary_skill = next(s for s in card.skills if s.id == "pipe") + assert "graph_workflow" in primary_skill.tags + + async def test_build_succeeds_for_empty_workflow(self): + """AgentCardBuilder.build succeeds for a Workflow with no edges.""" + workflow = Workflow(name="empty_wf", description="An empty workflow.") + builder = AgentCardBuilder(agent=workflow, rpc_url="http://localhost:8000/") + + card = await builder.build() + + assert card.name == "empty_wf" + assert card.description == "An empty workflow." + # Only the primary skill, no orchestration skill since no child nodes. + assert len(card.skills) == 1 + assert "graph_workflow" in card.skills[0].tags + class TestHelperFunctions: """Test suite for helper functions.""" @@ -304,6 +415,22 @@ def test_get_agent_skill_name_custom_agent(self): # Assert assert result == "custom" + def test_get_agent_type_workflow(self): + """Test _get_agent_type for the v2 graph-based Workflow.""" + workflow = Workflow(name="wf") + + result = _get_agent_type(workflow) + + assert result == "graph_workflow" + + def test_get_agent_skill_name_workflow(self): + """Test _get_agent_skill_name for the v2 graph-based Workflow.""" + workflow = Workflow(name="wf") + + result = _get_agent_skill_name(workflow) + + assert result == "workflow" + def test_replace_pronouns_basic(self): """Test _replace_pronouns with basic pronoun replacement.""" # Arrange @@ -698,6 +825,36 @@ def test_get_workflow_description_custom_agent(self): # Assert assert result is None + def test_get_workflow_description_workflow_with_nodes(self): + """_get_workflow_description lists graph nodes for a Workflow.""" + writer = LlmAgent( + name="writer", + model="gemini-2.5-flash", + description="Writes the reply", + ) + reviewer = LlmAgent( + name="reviewer", + model="gemini-2.5-flash", + description="Reviews the reply", + ) + workflow = Workflow( + name="pipe", edges=[(START, writer), (writer, reviewer)] + ) + + result = _get_workflow_description(workflow) + + assert result is not None + assert "writer: Writes the reply" in result + assert "reviewer: Reviews the reply" in result + + def test_get_workflow_description_empty_workflow(self): + """_get_workflow_description returns None for a workflow with no nodes.""" + workflow = Workflow(name="empty_wf") + + result = _get_workflow_description(workflow) + + assert result is None + def test_build_sequential_description_single_agent(self): """Test _build_sequential_description with single sub-agent.""" # Arrange diff --git a/tests/unittests/a2a/utils/test_agent_to_a2a.py b/tests/unittests/a2a/utils/test_agent_to_a2a.py index a9e2458ebd..feb9b6149b 100644 --- a/tests/unittests/a2a/utils/test_agent_to_a2a.py +++ b/tests/unittests/a2a/utils/test_agent_to_a2a.py @@ -26,11 +26,15 @@ from google.adk.a2a.utils.agent_card_builder import AgentCardBuilder from google.adk.a2a.utils.agent_to_a2a import to_a2a from google.adk.agents.base_agent import BaseAgent +from google.adk.agents.llm_agent import LlmAgent from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService from google.adk.auth.credential_service.in_memory_credential_service import InMemoryCredentialService from google.adk.memory.in_memory_memory_service import InMemoryMemoryService from google.adk.runners import Runner from google.adk.sessions.in_memory_session_service import InMemorySessionService +from google.adk.workflow import FunctionNode +from google.adk.workflow import START +from google.adk.workflow import Workflow import pytest from starlette.applications import Starlette @@ -514,19 +518,18 @@ def test_to_a2a_with_none_agent(self): with pytest.raises(ValueError, match="Agent cannot be None or empty."): to_a2a(None) - async def test_to_a2a_with_invalid_agent_type(self): - """Test that to_a2a raises error when agent is not a BaseAgent.""" - # Arrange - invalid_agent = "not an agent" + def test_to_a2a_rejects_non_agent_non_workflow(self): + """to_a2a raises TypeError immediately for unsupported types. - # Act & Assert - # The error occurs during lifespan startup when building the agent card - app = to_a2a(invalid_agent) + Only BaseAgent (e.g. LlmAgent) and Workflow are valid + A2A roots. Other BaseNode subclasses (e.g. FunctionNode) and + arbitrary objects must be rejected at call time, not silently served + as a degenerate "custom agent". + """ with pytest.raises( - AttributeError, match="'str' object has no attribute 'name'" + TypeError, match="requires a BaseAgent or Workflow, got str" ): - async with app.router.lifespan_context(app): - pass + to_a2a("not an agent") @patch("google.adk.a2a.utils.agent_to_a2a.A2aAgentExecutor") @patch("google.adk.a2a.utils.agent_to_a2a.DefaultRequestHandler") @@ -1051,3 +1054,36 @@ async def custom_lifespan(app): "user_startup", "user_shutdown", ] + + async def test_to_a2a_succeeds_for_workflow(self): + """to_a2a accepts a Workflow and the Starlette lifespan completes.""" + writer = LlmAgent( + name="writer", + model="gemini-2.5-flash", + instruction="Write a short reply.", + ) + workflow = Workflow(name="pipe", edges=[(START, writer)]) + + app = to_a2a(workflow, port=8001) + + async with app.router.lifespan_context(app): + pass + + def test_to_a2a_rejects_function_node(self): + """to_a2a raises TypeError for a bare FunctionNode. + + FunctionNode is a BaseNode but is intended for use inside a + Workflow, not as a standalone A2A root. Passing one directly used + to silently produce a degenerate "custom agent" card; it now fails + fast at to_a2a() call time. + """ + + async def my_fn(node_input): + return f"echo: {node_input}" + + fn_node = FunctionNode(func=my_fn, name="echo_fn") + + with pytest.raises( + TypeError, match="requires a BaseAgent or Workflow, got FunctionNode" + ): + to_a2a(fn_node)