Skip to content

Commit c910961

Browse files
wuliang229 and copybara-github
authored and committed
feat: Add SSE streaming support to conformance tests
Co-authored-by: Liang Wu <wuliang@google.com> PiperOrigin-RevId: 882293566
1 parent 0847f51 commit c910961

11 files changed

Lines changed: 450 additions & 222 deletions

src/google/adk/cli/cli_tools_click.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from . import cli_create
3636
from . import cli_deploy
3737
from .. import version
38+
from ..agents.run_config import StreamingMode
3839
from ..evaluation.constants import MISSING_EVAL_DEPENDENCIES_MESSAGE
3940
from ..features import FeatureName
4041
from ..features import override_feature_enabled
@@ -230,10 +231,21 @@ def conformance():
230231
exists=True, dir_okay=True, file_okay=False, resolve_path=True
231232
),
232233
)
234+
@click.argument(
235+
"streaming-mode",
236+
type=click.Choice(
237+
[str(m.value) for m in StreamingMode], case_sensitive=False
238+
),
239+
callback=lambda ctx, param, value: next(
240+
(m for m in StreamingMode if str(m.value).lower() == value.lower()),
241+
value,
242+
),
243+
)
233244
@click.pass_context
234245
def cli_conformance_record(
235246
ctx,
236247
paths: tuple[str, ...],
248+
streaming_mode: StreamingMode,
237249
):
238250
"""Generate ADK conformance test YAML files from TestCaseInput specifications.
239251
@@ -273,7 +285,7 @@ def cli_conformance_record(
273285

274286
# Default to tests/ directory if no paths provided
275287
test_paths = [Path(p) for p in paths] if paths else [Path("tests").resolve()]
276-
asyncio.run(run_conformance_record(test_paths))
288+
asyncio.run(run_conformance_record(test_paths, streaming_mode))
277289

278290

279291
@conformance.command("test", cls=HelpfulCommand)
@@ -309,13 +321,28 @@ def cli_conformance_record(
309321
" directory."
310322
),
311323
)
324+
@click.option(
325+
"--streaming-mode",
326+
type=click.Choice(
327+
[str(m.value) for m in StreamingMode], case_sensitive=False
328+
),
329+
callback=lambda ctx, param, value: next(
330+
(m for m in StreamingMode if str(m.value).lower() == value.lower()),
331+
value,
332+
)
333+
if value is not None
334+
else None,
335+
required=False,
336+
default=None,
337+
)
312338
@click.pass_context
313339
def cli_conformance_test(
314340
ctx,
315341
paths: tuple[str, ...],
316342
mode: str,
317343
generate_report: bool,
318344
report_dir: Optional[str] = None,
345+
streaming_mode: Optional[StreamingMode] = None,
319346
):
320347
"""Run conformance tests to verify agent behavior consistency.
321348
@@ -342,9 +369,11 @@ def cli_conformance_test(
342369
\b
343370
category/
344371
test_name/
345-
spec.yaml # Test specification
346-
generated-recordings.yaml # Recorded interactions (replay mode)
347-
generated-session.yaml # Session data (replay mode)
372+
spec.yaml # Test specification
373+
generated-recordings.yaml # Recorded interactions (replay mode)
374+
generated-session.yaml # Session data (replay mode)
375+
generated-recordings-sse.yaml # Recorded SSE interactions (replay mode)
376+
generated-session-sse.yaml # SSE Session data (replay mode)
348377
349378
REPORT GENERATION:
350379
@@ -377,7 +406,6 @@ def cli_conformance_test(
377406
# Generate a test report in a specific directory
378407
adk conformance test --generate_report --report_dir=reports
379408
"""
380-
381409
try:
382410
from .conformance.cli_test import run_conformance_test
383411
except ImportError as e:
@@ -403,6 +431,7 @@ def cli_conformance_test(
403431
mode=mode.lower(),
404432
generate_report=generate_report,
405433
report_dir=report_dir,
434+
streaming_mode=streaming_mode,
406435
)
407436
)
408437

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright 2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from __future__ import annotations
17+
18+
import logging
19+
from typing import Any
20+
from typing import AsyncGenerator
21+
from typing import TYPE_CHECKING
22+
23+
from ...models.google_llm import Gemini
24+
25+
if TYPE_CHECKING:
26+
from ...models.llm_request import LlmRequest
27+
from ...models.llm_response import LlmResponse
28+
29+
logger = logging.getLogger('google_adk.' + __name__)
30+
31+
32+
class ReplayVerificationError(Exception):
  """Exception raised when replay verification fails.

  Raised when the runtime sends more LLM requests than were recorded, or
  when a live LLM request does not match its recorded counterpart.
  """
35+
36+
class _ConformanceTestGemini(Gemini):
  """A mocked Gemini model for conformance test replay mode.

  This class is used to mock the Gemini model in conformance test replay mode.
  It is a subclass of Gemini and overrides the `generate_content_async` method
  to return a mocked response from the provided recordings.
  """

  def __init__(
      self,
      *,
      config: dict[str, Any],
      **kwargs: Any,
  ) -> None:
    """Initializes the mock model from the replay config.

    Args:
      config: Replay configuration carrying the recorded interactions
        (`_adk_replay_recordings`), the agent name, the user message index,
        and the current replay index.
      **kwargs: Forwarded to the `Gemini` base class.
    """
    super().__init__(**kwargs)
    # NOTE(review): if '_adk_replay_recordings' is absent this raises
    # AttributeError on `recordings.recordings` below — presumably the caller
    # always supplies it; confirm against the replay runner.
    recordings = config.get('_adk_replay_recordings')
    self._user_message_index = config.get('user_message_index')
    self._agent_name = config.get('agent_name')
    self._replay_index = config.get('current_replay_index')
    # Pre-filter LLM recordings for this agent and message index.
    self._agent_llm_recordings = [
        recording.llm_recording
        for recording in recordings.recordings
        if recording.agent_name == self._agent_name
        and recording.user_message_index == self._user_message_index
        and recording.llm_recording
    ]

  async def generate_content_async(
      self, llm_request: LlmRequest, stream: bool = False
  ) -> AsyncGenerator[LlmResponse, None]:
    """Replay LLM response from recordings instead of making real call.

    Args:
      llm_request: The request the runtime is about to send; verified against
        the recorded request at the same replay index.
      stream: Unused in replay; accepted for interface compatibility.

    Yields:
      The recorded `LlmResponse` objects, in recorded order.

    Raises:
      ReplayVerificationError: If more requests arrive than were recorded, or
        if the current request does not match the recorded one.
    """
    logger.debug(
        'Replaying LLM response for agent %s (index %d)',
        self._agent_name,
        self._replay_index,
    )

    if self._replay_index >= len(self._agent_llm_recordings):
      raise ReplayVerificationError(
          'Runtime sent more LLM requests than expected for agent'
          f" '{self._agent_name}' at user_message_index"
          f' {self._user_message_index}. Expected'
          f' {len(self._agent_llm_recordings)}, but got request at index'
          f' {self._replay_index}'
      )

    recording = self._agent_llm_recordings[self._replay_index]

    # Verify request matches
    self._verify_llm_request_match(
        recording.llm_request, llm_request, self._replay_index
    )

    for response in recording.llm_responses:
      yield response

  def _verify_llm_request_match(
      self,
      recorded_request: LlmRequest,
      current_request: LlmRequest,
      replay_index: int,
  ) -> None:
    """Verify that the current LLM request exactly matches the recorded one.

    Args:
      recorded_request: The request captured during recording.
      current_request: The request produced by the current run.
      replay_index: Index of this request within the replay sequence, used in
        the error message.

    Raises:
      ReplayVerificationError: If the two requests differ outside the
        excluded fields.
    """
    # Comprehensive exclude dict for all fields that can differ between runs
    excluded_fields = {
        'live_connect_config': True,
        'config': {  # some config fields can vary per run
            'http_options': True,
            'labels': True,
        },
    }

    # Compare using model dumps with nested exclude dict
    recorded_dict = recorded_request.model_dump(
        exclude_none=True, exclude=excluded_fields, exclude_defaults=True
    )
    current_dict = current_request.model_dump(
        exclude_none=True, exclude=excluded_fields, exclude_defaults=True
    )

    if recorded_dict != current_dict:
      raise ReplayVerificationError(
          f"""LLM request mismatch in turn {self._user_message_index} for agent '{self._agent_name}' (index {replay_index}):
recorded: {recorded_dict}
current: {current_dict}"""
      )

src/google/adk/cli/conformance/_generate_markdown_utils.py

Lines changed: 76 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
def generate_markdown_report(
3131
version_data: dict[str, Any],
32-
summary: _ConformanceTestSummary,
32+
summaries: list[_ConformanceTestSummary],
3333
report_dir: Optional[str],
3434
) -> None:
3535
"""Generates a Markdown report of the test results."""
@@ -44,46 +44,94 @@ def generate_markdown_report(
4444
report_path = Path(report_dir) / report_name
4545
report_path.parent.mkdir(parents=True, exist_ok=True)
4646

47+
# Collect all test results
48+
test_results = {}
49+
test_descriptions = {}
50+
streaming_modes = []
51+
52+
for summary in summaries:
53+
mode_name = (
54+
str(summary.streaming_mode.value)
55+
if summary.streaming_mode.value is not None
56+
else "none"
57+
)
58+
streaming_modes.append(mode_name)
59+
for result in summary.results:
60+
key = (result.category, result.name)
61+
if key not in test_results:
62+
test_results[key] = {}
63+
test_results[key][mode_name] = result
64+
if result.description:
65+
test_descriptions[key] = result.description
66+
67+
streaming_modes.sort()
68+
4769
with open(report_path, "w") as f:
4870
f.write("# ADK Python Conformance Test Report\n\n")
49-
50-
# Summary
5171
f.write("## Summary\n\n")
5272
f.write(f"- **ADK Version**: {server_version}\n")
53-
f.write(f"- **Language**: {language} {language_version}\n")
54-
f.write(f"- **Total Tests**: {summary.total_tests}\n")
55-
f.write(f"- **Passed**: {summary.passed_tests}\n")
56-
f.write(f"- **Failed**: {summary.failed_tests}\n")
57-
f.write(f"- **Success Rate**: {summary.success_rate:.1f}%\n\n")
73+
f.write(f"- **Language**: {language} {language_version}\n\n")
5874

59-
# Table
60-
f.write("## Test Results\n\n")
61-
f.write("| Status | Category | Test Name | Description |\n")
62-
f.write("| :--- | :--- | :--- | :--- |\n")
75+
f.write(
76+
"| Streaming Mode | Total Tests | Passed | Failed | Success Rate |\n"
77+
)
78+
f.write("| :--- | :--- | :--- | :--- | :--- |\n")
6379

64-
for result in summary.results:
65-
status_icon = "✅ PASS" if result.success else "❌ FAIL"
66-
description = (
67-
result.description.replace("\n", " ") if result.description else ""
80+
for summary in summaries:
81+
mode_name = (
82+
str(summary.streaming_mode.value)
83+
if summary.streaming_mode.value is not None
84+
else "none"
6885
)
6986
f.write(
70-
f"| {status_icon} | {result.category} | {result.name} |"
71-
f" {description} |\n"
87+
f"| {mode_name} | {summary.total_tests} |"
88+
f" {summary.passed_tests} | {summary.failed_tests} |"
89+
f" {summary.success_rate:.1f}% |\n"
90+
)
91+
f.write("\n")
92+
93+
# Table
94+
f.write("## Test Results\n\n")
95+
headers = ["Category", "Test Name", "Description"] + streaming_modes
96+
f.write("| " + " | ".join(headers) + " |\n")
97+
f.write("| " + " | ".join([":---"] * len(headers)) + " |\n")
98+
99+
sorted_keys = sorted(test_results.keys())
100+
for category, name in sorted_keys:
101+
description = test_descriptions.get((category, name), "").replace(
102+
"\n", " "
72103
)
104+
row = [category, name, description]
105+
for mode in streaming_modes:
106+
result = test_results[(category, name)].get(mode)
107+
if result:
108+
status_icon = "✅ PASS" if result.success else "❌ FAIL"
109+
else:
110+
status_icon = "N/A"
111+
row.append(status_icon)
112+
f.write("| " + " | ".join(row) + " |\n")
73113

74114
f.write("\n")
75115

76116
# Failed Tests Details
77-
if summary.failed_tests > 0:
117+
has_failures = any(s.failed_tests > 0 for s in summaries)
118+
if has_failures:
78119
f.write("## Failed Tests Details\n\n")
79-
for result in summary.results:
80-
if not result.success:
81-
f.write(f"### {result.category}/{result.name}\n\n")
82-
if result.description:
83-
f.write(f"**Description**: {result.description}\n\n")
84-
f.write("**Error**:\n")
85-
f.write("```\n")
86-
f.write(f"{result.error_message}\n")
87-
f.write("```\n\n")
120+
for summary in summaries:
121+
if summary.failed_tests > 0:
122+
mode_name = (
123+
str(summary.streaming_mode.value)
124+
if summary.streaming_mode.value is not None
125+
else "none"
126+
)
127+
for result in summary.results:
128+
if not result.success:
129+
f.write(f"### {result.category}/{result.name} ({mode_name})\n\n")
130+
if result.description:
131+
f.write(f"**Description**: {result.description}\n\n")
132+
f.write("**Error**:\n")
133+
f.write("```\n")
134+
f.write(f"{result.error_message}\n")
135+
f.write("```\n\n")
88136

89137
click.secho(f"\nReport generated at: {report_path.resolve()}", fg="blue")

src/google/adk/cli/conformance/_generated_file_utils.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import click
2424
import yaml
2525

26+
from ...agents.run_config import StreamingMode
2627
from ...sessions.session import Session
2728
from .test_case import TestSpec
2829

@@ -35,9 +36,17 @@ def load_test_case(test_case_dir: Path) -> TestSpec:
3536
return TestSpec.model_validate(data)
3637

3738

38-
def load_recorded_session(test_case_dir: Path) -> Optional[Session]:
39-
"""Load recorded session data from generated-session.yaml file."""
40-
session_file = test_case_dir / "generated-session.yaml"
39+
def load_recorded_session(
40+
test_case_dir: Path, streaming_mode: StreamingMode
41+
) -> Optional[Session]:
42+
"""Load recorded session data from YAML file."""
43+
if streaming_mode == StreamingMode.SSE:
44+
session_file = test_case_dir / "generated-session-sse.yaml"
45+
elif streaming_mode == StreamingMode.NONE:
46+
session_file = test_case_dir / "generated-session.yaml"
47+
else:
48+
raise ValueError(f"Unsupported streaming mode: {streaming_mode}")
49+
4150
if not session_file.exists():
4251
return None
4352

0 commit comments

Comments
 (0)