Skip to content

Commit 342b59d

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: propagate model_version and other metadata in streaming responses
PiperOrigin-RevId: 928881287
1 parent fd0a11d commit 342b59d

2 files changed

Lines changed: 103 additions & 0 deletions

File tree

src/google/adk/utils/streaming_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,10 @@ async def process_response(
336336
yield LlmResponse(
337337
content=types.ModelContent(parts=parts),
338338
usage_metadata=llm_response.usage_metadata,
339+
grounding_metadata=llm_response.grounding_metadata,
340+
citation_metadata=llm_response.citation_metadata,
341+
finish_reason=llm_response.finish_reason,
342+
model_version=llm_response.model_version,
339343
)
340344
self._thought_text = ''
341345
self._text = ''
@@ -386,6 +390,7 @@ def close(self) -> Optional[LlmResponse]:
386390
usage_metadata=self._usage_metadata,
387391
finish_reason=finish_reason,
388392
partial=False,
393+
model_version=self._response.model_version,
389394
)
390395

391396
# ========== Non-Progressive SSE Streaming (old behavior) ==========
@@ -405,4 +410,5 @@ def close(self) -> Optional[LlmResponse]:
405410
usage_metadata=self._usage_metadata,
406411
finish_reason=finish_reason,
407412
partial=False,
413+
model_version=self._response.model_version,
408414
)

tests/unittests/utils/test_streaming_utils.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,103 @@ async def run_test():
386386
else:
387387
await run_test()
388388

389+
@pytest.mark.asyncio
390+
@pytest.mark.parametrize("use_progressive_sse", [False, True])
391+
async def test_close_propagates_model_version(self, use_progressive_sse):
392+
"""close() should carry model_version into the aggregated response."""
393+
aggregator = streaming_utils.StreamingResponseAggregator()
394+
response1 = types.GenerateContentResponse(
395+
candidates=[
396+
types.Candidate(
397+
content=types.Content(parts=[types.Part(text="Hello ")]),
398+
)
399+
],
400+
model_version="gemini-test-1.0",
401+
)
402+
response2 = types.GenerateContentResponse(
403+
candidates=[
404+
types.Candidate(
405+
content=types.Content(parts=[types.Part(text="World!")]),
406+
finish_reason=types.FinishReason.STOP,
407+
)
408+
],
409+
model_version="gemini-test-1.0",
410+
)
411+
412+
async def run_test():
413+
async for _ in aggregator.process_response(response1):
414+
pass
415+
async for _ in aggregator.process_response(response2):
416+
pass
417+
418+
closed_response = aggregator.close()
419+
assert closed_response is not None
420+
assert closed_response.model_version == "gemini-test-1.0"
421+
422+
if use_progressive_sse:
423+
with temporary_feature_override(
424+
FeatureName.PROGRESSIVE_SSE_STREAMING, True
425+
):
426+
await run_test()
427+
else:
428+
await run_test()
429+
430+
@pytest.mark.asyncio
431+
async def test_non_progressive_merged_yield_propagates_model_version(self):
432+
"""The mid-stream merged text event should carry model_version forward.
433+
434+
In non-progressive mode, when a new non-text response arrives after buffered
435+
text, the aggregator yields a synthesized merged-text LlmResponse before
436+
yielding the current partial. That merged event must preserve fields from
437+
the source response (model_version, grounding_metadata, citation_metadata,
438+
finish_reason).
439+
"""
440+
# PROGRESSIVE_SSE_STREAMING defaults to on; explicitly disable it to
441+
# exercise the non-progressive merged-yield code path under test.
442+
with temporary_feature_override(
443+
FeatureName.PROGRESSIVE_SSE_STREAMING, False
444+
):
445+
aggregator = streaming_utils.StreamingResponseAggregator()
446+
# First: buffer some text.
447+
response1 = types.GenerateContentResponse(
448+
candidates=[
449+
types.Candidate(
450+
content=types.Content(
451+
parts=[types.Part(text="Hello World!")]
452+
),
453+
)
454+
],
455+
model_version="gemini-test-2.0",
456+
)
457+
# Second: a response without text triggers the merged yield path.
458+
response2 = types.GenerateContentResponse(
459+
candidates=[
460+
types.Candidate(
461+
content=types.Content(parts=[]),
462+
finish_reason=types.FinishReason.STOP,
463+
)
464+
],
465+
model_version="gemini-test-2.0",
466+
)
467+
468+
results = []
469+
async for r in aggregator.process_response(response1):
470+
results.append(r)
471+
async for r in aggregator.process_response(response2):
472+
results.append(r)
473+
474+
# The synthesized merged-text event should carry model_version.
475+
merged_events = [
476+
r
477+
for r in results
478+
if r.content
479+
and r.content.parts
480+
and r.content.parts[0].text == "Hello World!"
481+
and not r.partial
482+
]
483+
assert merged_events, "expected a merged non-partial text event"
484+
assert merged_events[0].model_version == "gemini-test-2.0"
485+
389486

390487
class TestFunctionCallIdGeneration:
391488
"""Tests for function call ID generation in streaming mode.

0 commit comments

Comments
 (0)