diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index 65c4156744..a1575bdce6 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -40,6 +40,7 @@ from .llm_response import LlmResponse if TYPE_CHECKING: + from google.auth.credentials import Credentials from google.genai import Client from .llm_request import LlmRequest @@ -92,6 +93,7 @@ def __init__( custom_headers: dict[str, str] | None = None, retry_options: Optional[types.HttpRetryOptions] = None, api_type: ApiType | str = ApiType.UNKNOWN, + credentials: Credentials | None = None, ): """Initializes the Apigee LLM backend. @@ -123,6 +125,11 @@ def __init__( authorization headers in Vertex AI and Gemini API calls. retry_options: Allow google-genai to retry failed responses. api_type: The type of API to use. One of `ApiType` or string. + credentials: Optional google-auth credentials passed through to the + underlying `genai.Client`. Use this when the Apigee proxy requires + additional OAuth scopes (e.g., `userinfo.email` for tokeninfo-based + caller identification). When omitted, the default `genai.Client` + authentication flow is used. """ # fmt: skip super().__init__(model=model, retry_options=retry_options) @@ -165,6 +172,7 @@ def __init__( ) self._custom_headers = custom_headers or {} self._user_agent = f'google-adk/{adk_version.__version__}' + self._credentials = credentials @classmethod @override @@ -239,6 +247,8 @@ def api_client(self) -> Client: if self._isvertexai: kwargs_for_client['project'] = self._project kwargs_for_client['location'] = self._location + if self._credentials is not None: + kwargs_for_client['credentials'] = self._credentials return Client( http_options=http_options, diff --git a/src/google/adk/sessions/in_memory_session_service.py b/src/google/adk/sessions/in_memory_session_service.py index 1bef516086..b8f6cfab46 100644 --- a/src/google/adk/sessions/in_memory_session_service.py +++ b/src/google/adk/sessions/in_memory_session_service.py @@ -342,8 +342,9 @@ def _warning(message: str) -> None: # Update the storage session storage_session = self.sessions[app_name][user_id].get(session_id) - storage_session.events.append(event) - storage_session.last_update_time = event.timestamp + if storage_session is not session: + storage_session.events.append(event) + storage_session.last_update_time = event.timestamp if event.actions and event.actions.state_delta: state_deltas = _session_util.extract_state_delta( diff --git a/tests/unittests/models/test_apigee_llm.py b/tests/unittests/models/test_apigee_llm.py index ecbb61d18f..1e371e8aa1 100644 --- a/tests/unittests/models/test_apigee_llm.py +++ b/tests/unittests/models/test_apigee_llm.py @@ -651,6 +651,71 @@ def test_parse_response_usage_metadata(): assert llm_response.usage_metadata.thoughts_token_count == 4 +@pytest.mark.asyncio +@mock.patch('google.genai.Client') +async def test_api_client_passes_credentials_when_provided( + mock_client_constructor, llm_request +): + """Tests that credentials passed to __init__ are forwarded to genai.Client.""" + mock_credentials = mock.Mock() + + mock_client_instance = mock.Mock() + mock_client_instance.aio.models.generate_content = AsyncMock( + return_value=types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + parts=[Part.from_text(text='Test response')], + role='model', + ) + ) + ] + ) + ) + mock_client_constructor.return_value = mock_client_instance + + apigee_llm = ApigeeLlm( + model=APIGEE_GEMINI_MODEL_ID, + proxy_url=PROXY_URL, + credentials=mock_credentials, + ) + _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)] + + _, kwargs = mock_client_constructor.call_args + assert kwargs['credentials'] is mock_credentials + + +@pytest.mark.asyncio +@mock.patch('google.genai.Client') +async def test_api_client_omits_credentials_when_not_provided( + mock_client_constructor, llm_request +): + """Tests that credentials kwarg is not forwarded when not supplied.""" + mock_client_instance = mock.Mock() + mock_client_instance.aio.models.generate_content = AsyncMock( + return_value=types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + parts=[Part.from_text(text='Test response')], + role='model', + ) + ) + ] + ) + ) + mock_client_constructor.return_value = mock_client_instance + + apigee_llm = ApigeeLlm( + model=APIGEE_GEMINI_MODEL_ID, + proxy_url=PROXY_URL, + ) + _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)] + + _, kwargs = mock_client_constructor.call_args + assert 'credentials' not in kwargs + + def test_parse_response_with_refusal(): """Tests that CompletionsHTTPClient parses refusal correctly.""" client = CompletionsHTTPClient(base_url='http://test') diff --git a/tests/unittests/sessions/test_session_service.py b/tests/unittests/sessions/test_session_service.py index 2d7d89f15f..02f5159a45 100644 --- a/tests/unittests/sessions/test_session_service.py +++ b/tests/unittests/sessions/test_session_service.py @@ -1013,6 +1013,30 @@ async def test_append_event_allows_markerless_current_session(): await service.close() +@pytest.mark.asyncio +async def test_append_event_when_session_is_same_ref_as_storage_session(): + """Tests that appending an event to a session only appends it once if the user-passed session and the underlying storage session are the same object.""" + service = InMemorySessionService() + app_name = 'my_app' + user_id = 'test_user' + + # Create a session + session = await service.create_session(app_name=app_name, user_id=user_id) + + # Get the actual storage event object from the underlying storage + storage_session = service.sessions[app_name][user_id][session.id] + + # Append the event to the storage session directly + event = Event(invocation_id='inv1', author='user') + await service.append_event(session=storage_session, event=event) + + # Verify that the storage session has only one event + final_session = await service.get_session( + app_name=app_name, user_id=user_id, session_id=session.id + ) + assert len(final_session.events) == 1 + + @pytest.mark.asyncio async def test_get_session_with_config(session_service): app_name = 'my_app'