Skip to content

Commit b0c3bdd

Browse files
committed
feat(run_config): add model_input_context for transient context in LLM requests
This update introduces a new attribute, model_input_context, to the RunConfig class, allowing callers to provide transient context for each invocation without altering the conversation history. Additionally, the LLM request processing has been updated to incorporate this context appropriately. Unit tests have been added to verify the correct behavior of this feature.
1 parent 03d6208 commit b0c3bdd

4 files changed

Lines changed: 194 additions & 0 deletions

File tree

src/google/adk/agents/run_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,14 @@ class RunConfig(BaseModel):
344344
)
345345
"""
346346

347+
model_input_context: Optional[list[types.Content]] = None
348+
"""Transient context to include in the model input for this invocation.
349+
350+
The Runner does not persist these contents to the session. They are only
351+
added to the LLM request assembled for the current invocation, which lets
352+
callers provide per-turn context without changing the conversation history.
353+
"""
354+
347355
@model_validator(mode='before')
348356
@classmethod
349357
def check_for_deprecated_save_live_audio(cls, data: Any) -> Any:

src/google/adk/flows/llm_flows/contents.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,16 @@ async def run_async(
8585
preserve_function_call_ids=preserve_function_call_ids,
8686
)
8787

88+
if (
89+
invocation_context.run_config
90+
and invocation_context.run_config.model_input_context
91+
):
92+
_add_model_input_context_to_user_content(
93+
invocation_context,
94+
llm_request,
95+
copy.deepcopy(invocation_context.run_config.model_input_context),
96+
)
97+
8898
# Add instruction-related contents to proper position in conversation
8999
await _add_instructions_to_user_content(
90100
invocation_context, llm_request, instruction_related_contents
@@ -845,6 +855,26 @@ def _content_contains_function_response(content: types.Content) -> bool:
845855
return False
846856

847857

858+
def _add_model_input_context_to_user_content(
859+
invocation_context: InvocationContext,
860+
llm_request: LlmRequest,
861+
model_input_context: list[types.Content],
862+
) -> None:
863+
"""Insert transient model input context before the invocation user content."""
864+
if not model_input_context:
865+
return
866+
867+
insert_index = 0
868+
user_content = invocation_context.user_content
869+
if user_content:
870+
for i in range(len(llm_request.contents) - 1, -1, -1):
871+
if llm_request.contents[i] == user_content:
872+
insert_index = i
873+
break
874+
875+
llm_request.contents[insert_index:insert_index] = model_input_context
876+
877+
848878
async def _add_instructions_to_user_content(
849879
invocation_context: InvocationContext,
850880
llm_request: LlmRequest,

tests/unittests/agents/test_llm_agent_include_contents.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Unit tests for LlmAgent include_contents field behavior."""
1616

1717
from google.adk.agents.llm_agent import LlmAgent
18+
from google.adk.agents.run_config import RunConfig
1819
from google.adk.agents.sequential_agent import SequentialAgent
1920
from google.genai import types
2021
import pytest
@@ -189,6 +190,153 @@ def simple_tool(message: str) -> dict:
189190
assert len(mock_model.requests[0].config.tools) > 0
190191

191192

193+
def test_model_input_context_is_sent_to_model_without_persisting_to_session():
194+
mock_model = testing_utils.MockModel.create(responses=["Answer"])
195+
agent = LlmAgent(name="test_agent", model=mock_model)
196+
runner = testing_utils.InMemoryRunner(agent)
197+
session = runner.session
198+
199+
list(
200+
runner.runner.run(
201+
user_id=session.user_id,
202+
session_id=session.id,
203+
new_message=testing_utils.get_user_content("Question"),
204+
run_config=RunConfig(
205+
model_input_context=[
206+
types.UserContent("Relevant context for this turn")
207+
]
208+
),
209+
)
210+
)
211+
212+
assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [
213+
("user", "Relevant context for this turn"),
214+
("user", "Question"),
215+
]
216+
assert testing_utils.simplify_events(runner.session.events) == [
217+
("user", "Question"),
218+
("test_agent", "Answer"),
219+
]
220+
221+
222+
def test_model_input_context_stays_before_user_message_after_tool_call():
223+
def simple_tool(message: str) -> dict:
224+
return {"result": f"Tool processed: {message}"}
225+
226+
mock_model = testing_utils.MockModel.create(
227+
responses=[
228+
types.Part.from_function_call(
229+
name="simple_tool", args={"message": "payload"}
230+
),
231+
"Answer",
232+
]
233+
)
234+
agent = LlmAgent(name="test_agent", model=mock_model, tools=[simple_tool])
235+
runner = testing_utils.InMemoryRunner(agent)
236+
session = runner.session
237+
238+
list(
239+
runner.runner.run(
240+
user_id=session.user_id,
241+
session_id=session.id,
242+
new_message=testing_utils.get_user_content("Question"),
243+
run_config=RunConfig(
244+
model_input_context=[
245+
types.UserContent("Relevant context for this turn")
246+
]
247+
),
248+
)
249+
)
250+
251+
assert testing_utils.simplify_contents(mock_model.requests[0].contents) == [
252+
("user", "Relevant context for this turn"),
253+
("user", "Question"),
254+
]
255+
assert testing_utils.simplify_contents(mock_model.requests[1].contents) == [
256+
("user", "Relevant context for this turn"),
257+
("user", "Question"),
258+
(
259+
"model",
260+
types.Part.from_function_call(
261+
name="simple_tool", args={"message": "payload"}
262+
),
263+
),
264+
(
265+
"user",
266+
types.Part.from_function_response(
267+
name="simple_tool",
268+
response={"result": "Tool processed: payload"},
269+
),
270+
),
271+
]
272+
assert testing_utils.simplify_events(runner.session.events) == [
273+
("user", "Question"),
274+
(
275+
"test_agent",
276+
types.Part.from_function_call(
277+
name="simple_tool", args={"message": "payload"}
278+
),
279+
),
280+
(
281+
"test_agent",
282+
types.Part.from_function_response(
283+
name="simple_tool",
284+
response={"result": "Tool processed: payload"},
285+
),
286+
),
287+
("test_agent", "Answer"),
288+
]
289+
290+
291+
def test_model_input_context_with_include_contents_none_sub_agent():
292+
agent1_model = testing_utils.MockModel.create(
293+
responses=["Agent1 response: XYZ"]
294+
)
295+
agent1 = LlmAgent(name="agent1", model=agent1_model)
296+
297+
agent2_model = testing_utils.MockModel.create(
298+
responses=["Agent2 final response"]
299+
)
300+
agent2 = LlmAgent(
301+
name="agent2",
302+
model=agent2_model,
303+
include_contents="none",
304+
)
305+
sequential_agent = SequentialAgent(
306+
name="sequential_test_agent", sub_agents=[agent1, agent2]
307+
)
308+
runner = testing_utils.InMemoryRunner(sequential_agent)
309+
session = runner.session
310+
311+
list(
312+
runner.runner.run(
313+
user_id=session.user_id,
314+
session_id=session.id,
315+
new_message=testing_utils.get_user_content("Original user request"),
316+
run_config=RunConfig(
317+
model_input_context=[
318+
types.UserContent("Relevant context for this turn")
319+
]
320+
),
321+
)
322+
)
323+
324+
assert testing_utils.simplify_contents(agent1_model.requests[0].contents) == [
325+
("user", "Relevant context for this turn"),
326+
("user", "Original user request"),
327+
]
328+
assert testing_utils.simplify_contents(agent2_model.requests[0].contents) == [
329+
("user", "Relevant context for this turn"),
330+
(
331+
"user",
332+
[
333+
types.Part(text="For context:"),
334+
types.Part(text="[agent1] said: Agent1 response: XYZ"),
335+
],
336+
),
337+
]
338+
339+
192340
@pytest.mark.asyncio
193341
async def test_include_contents_none_sequential_agents():
194342
"""Test include_contents='none' with sequential agents."""

tests/unittests/agents/test_run_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,3 +97,11 @@ def test_avatar_config_with_name():
9797
assert run_config.avatar_config == avatar_config
9898
assert run_config.avatar_config.avatar_name == "test_avatar"
9999
assert run_config.avatar_config.customized_avatar is None
100+
101+
102+
def test_model_input_context_accepts_transient_contents():
103+
context_content = types.UserContent("Relevant context for this turn")
104+
105+
run_config = RunConfig(model_input_context=[context_content])
106+
107+
assert run_config.model_input_context == [context_content]

0 commit comments

Comments
 (0)