diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index db897637c3..186dd64ae1 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -507,14 +507,19 @@ async def run_live( attempt += 1 if not llm_request.live_connect_config: llm_request.live_connect_config = types.LiveConnectConfig() - if not llm_request.live_connect_config.session_resumption: + session_resumption = ( + llm_request.live_connect_config.session_resumption + ) + if not session_resumption: + session_resumption = types.SessionResumptionConfig() llm_request.live_connect_config.session_resumption = ( - types.SessionResumptionConfig() + session_resumption ) - llm_request.live_connect_config.session_resumption.handle = ( + session_resumption.handle = ( invocation_context.live_session_resumption_handle ) - llm_request.live_connect_config.session_resumption.transparent = True + if session_resumption.transparent is None: + session_resumption.transparent = True logger.info( 'Establishing live connection for agent: %s', diff --git a/tests/unittests/flows/llm_flows/test_base_llm_flow.py b/tests/unittests/flows/llm_flows/test_base_llm_flow.py index ce2e83b6f7..a16c00dcd1 100644 --- a/tests/unittests/flows/llm_flows/test_base_llm_flow.py +++ b/tests/unittests/flows/llm_flows/test_base_llm_flow.py @@ -623,6 +623,78 @@ async def mock_receive_2(): assert invocation_context.live_session_resumption_handle == 'test_handle' +@pytest.mark.asyncio +async def test_run_live_reconnect_preserves_nontransparent_resumption(): + """Test that reconnect does not force transparent resumption.""" + from google.adk.agents.live_request_queue import LiveRequestQueue + from websockets.exceptions import ConnectionClosed + + real_model = Gemini() + mock_connection = mock.AsyncMock() + + async def mock_receive(): + yield LlmResponse( + live_session_resumption_update=types.LiveServerSessionResumptionUpdate( + new_handle='test_handle' + ) + ) + raise ConnectionClosed(None, None) + + mock_connection.receive = mock.Mock(side_effect=mock_receive) + + agent = Agent(name='test_agent', model=real_model) + invocation_context = await testing_utils.create_invocation_context( + agent=agent + ) + invocation_context.live_request_queue = LiveRequestQueue() + + flow = BaseLlmFlowForTesting() + + async def mock_preprocess(ctx, req): + req.live_connect_config.session_resumption = types.SessionResumptionConfig( + transparent=False + ) + if False: + yield + + with mock.patch.object( + flow, '_preprocess_async', side_effect=mock_preprocess + ): + with mock.patch.object(flow, '_send_to_model', new_callable=AsyncMock): + mock_connection_2 = mock.AsyncMock() + + class StopError(Exception): + pass + + async def mock_receive_2(): + yield LlmResponse( + content=types.Content(parts=[types.Part.from_text(text='hi')]) + ) + raise StopError('stop') + + mock_connection_2.receive = mock.Mock(side_effect=mock_receive_2) + + mock_aenter = mock.AsyncMock() + mock_aenter.side_effect = [mock_connection, mock_connection_2] + + with mock.patch( + 'google.adk.models.google_llm.Gemini.connect' + ) as mock_connect: + mock_connect.return_value.__aenter__ = mock_aenter + + try: + async for _ in flow.run_live(invocation_context): + pass + except StopError: + pass + + reconnect_request = mock_connect.call_args_list[1].args[0] + assert ( + reconnect_request.live_connect_config.session_resumption.transparent + is False + ) + + @pytest.mark.asyncio async def test_run_live_skips_send_history_on_resumption(): """Test that run_live skips send_history when resuming a session."""