Skip to content

Commit 1909869

Browse files
Poggeccicopybara-github
authored andcommitted
Add endInvocation field to Event Actions to facilitate interrupting the agent loop after a tool call
PiperOrigin-RevId: 773760597
1 parent ec31bac commit 1909869

3 files changed

Lines changed: 172 additions & 11 deletions

File tree

core/src/main/java/com/google/adk/events/EventActions.java

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ public class EventActions {
3737
private Optional<Boolean> escalate = Optional.empty();
3838
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs =
3939
new ConcurrentHashMap<>();
40+
private Optional<Boolean> endInvocation = Optional.empty();
4041

4142
/** Default constructor for Jackson. */
4243
public EventActions() {}
@@ -112,6 +113,19 @@ public void setRequestedAuthConfigs(
112113
this.requestedAuthConfigs = requestedAuthConfigs;
113114
}
114115

116+
@JsonProperty("endInvocation")
117+
public Optional<Boolean> endInvocation() {
118+
return endInvocation;
119+
}
120+
121+
public void setEndInvocation(Optional<Boolean> endInvocation) {
122+
this.endInvocation = endInvocation;
123+
}
124+
125+
public void setEndInvocation(boolean endInvocation) {
126+
this.endInvocation = Optional.of(endInvocation);
127+
}
128+
115129
public static Builder builder() {
116130
return new Builder();
117131
}
@@ -133,7 +147,8 @@ public boolean equals(Object o) {
133147
&& Objects.equals(artifactDelta, that.artifactDelta)
134148
&& Objects.equals(transferToAgent, that.transferToAgent)
135149
&& Objects.equals(escalate, that.escalate)
136-
&& Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs);
150+
&& Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs)
151+
&& Objects.equals(endInvocation, that.endInvocation);
137152
}
138153

139154
@Override
@@ -144,7 +159,8 @@ public int hashCode() {
144159
artifactDelta,
145160
transferToAgent,
146161
escalate,
147-
requestedAuthConfigs);
162+
requestedAuthConfigs,
163+
endInvocation);
148164
}
149165

150166
/** Builder for {@link EventActions}. */
@@ -156,6 +172,7 @@ public static class Builder {
156172
private Optional<Boolean> escalate = Optional.empty();
157173
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs =
158174
new ConcurrentHashMap<>();
175+
private Optional<Boolean> endInvocation = Optional.empty();
159176

160177
public Builder() {}
161178

@@ -166,6 +183,7 @@ private Builder(EventActions eventActions) {
166183
this.transferToAgent = eventActions.transferToAgent();
167184
this.escalate = eventActions.escalate();
168185
this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs());
186+
this.endInvocation = eventActions.endInvocation();
169187
}
170188

171189
@CanIgnoreReturnValue
@@ -211,6 +229,13 @@ public Builder requestedAuthConfigs(
211229
return this;
212230
}
213231

232+
@CanIgnoreReturnValue
233+
@JsonProperty("endInvocation")
234+
public Builder endInvocation(boolean endInvocation) {
235+
this.endInvocation = Optional.of(endInvocation);
236+
return this;
237+
}
238+
214239
@CanIgnoreReturnValue
215240
public Builder merge(EventActions other) {
216241
if (other.skipSummarization().isPresent()) {
@@ -231,6 +256,9 @@ public Builder merge(EventActions other) {
231256
if (other.requestedAuthConfigs() != null) {
232257
this.requestedAuthConfigs.putAll(other.requestedAuthConfigs());
233258
}
259+
if (other.endInvocation().isPresent()) {
260+
this.endInvocation = other.endInvocation();
261+
}
234262
return this;
235263
}
236264

@@ -242,6 +270,7 @@ public EventActions build() {
242270
eventActions.setTransferToAgent(this.transferToAgent);
243271
eventActions.setEscalate(this.escalate);
244272
eventActions.setRequestedAuthConfigs(this.requestedAuthConfigs);
273+
eventActions.setEndInvocation(this.endInvocation);
245274
return eventActions;
246275
}
247276
}

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -383,9 +383,12 @@ public Flowable<Event> run(InvocationContext invocationContext) {
383383
.toList()
384384
.flatMapPublisher(
385385
eventList -> {
386-
if (eventList.isEmpty() || Iterables.getLast(eventList).finalResponse()) {
386+
if (eventList.isEmpty()
387+
|| Iterables.getLast(eventList).finalResponse()
388+
|| Iterables.getLast(eventList).actions().endInvocation().orElse(false)) {
387389
logger.debug(
388-
"Ending flow execution based on final response or empty event list.");
390+
"Ending flow execution based on final response, endInvocation action or"
391+
+ " empty event list.");
389392
return Flowable.empty();
390393
} else {
391394
logger.debug("Continuing to next step of the flow.");
@@ -524,18 +527,21 @@ public void onError(Throwable e) {
524527
.content(event.content().get());
525528
}
526529
if (functionResponses.stream()
527-
.anyMatch(
528-
functionResponse ->
529-
functionResponse
530-
.name()
531-
.orElse("")
532-
.equals("transferToAgent"))) {
530+
.anyMatch(
531+
functionResponse ->
532+
functionResponse
533+
.name()
534+
.orElse("")
535+
.equals("transferToAgent"))
536+
|| event.actions().endInvocation().orElse(false)) {
533537
sendTask.dispose();
534538
connection.close();
535539
}
536540
});
537541

538-
return receiveFlow.startWithIterable(preResult.events());
542+
return receiveFlow
543+
.takeWhile(event -> !event.actions().endInvocation().orElse(false))
544+
.startWithIterable(preResult.events());
539545
});
540546
}
541547

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
/*
2+
* Copyright 2025 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.adk.flows.llmflows;
18+
19+
import static com.google.adk.testing.TestUtils.createInvocationContext;
20+
import static com.google.adk.testing.TestUtils.createLlmResponse;
21+
import static com.google.adk.testing.TestUtils.createTestAgentBuilder;
22+
import static com.google.adk.testing.TestUtils.createTestLlm;
23+
import static com.google.adk.testing.TestUtils.simplifyEvents;
24+
import static com.google.common.truth.Truth.assertThat;
25+
26+
import com.google.adk.agents.InvocationContext;
27+
import com.google.adk.agents.LlmAgent;
28+
import com.google.adk.events.Event;
29+
import com.google.adk.runner.InMemoryRunner;
30+
import com.google.adk.runner.Runner;
31+
import com.google.adk.sessions.Session;
32+
import com.google.adk.tools.BaseTool;
33+
import com.google.adk.tools.ToolContext;
34+
import com.google.common.collect.ImmutableList;
35+
import com.google.common.collect.ImmutableMap;
36+
import com.google.genai.types.Content;
37+
import com.google.genai.types.FunctionDeclaration;
38+
import com.google.genai.types.Part;
39+
import com.google.genai.types.Schema;
40+
import io.reactivex.rxjava3.core.Flowable;
41+
import io.reactivex.rxjava3.core.Single;
42+
import java.util.ArrayList;
43+
import java.util.List;
44+
import java.util.Map;
45+
import java.util.Optional;
46+
import org.junit.Test;
47+
import org.junit.runner.RunWith;
48+
import org.junit.runners.JUnit4;
49+
50+
@RunWith(JUnit4.class)
51+
public final class EndInvocationActionTest {
52+
53+
private static class EndInvocationTool extends BaseTool {
54+
public EndInvocationTool() {
55+
super("end_invocation", "Ends the current invocation.");
56+
}
57+
58+
@Override
59+
public Optional<FunctionDeclaration> declaration() {
60+
return Optional.of(
61+
FunctionDeclaration.builder()
62+
.name(name())
63+
.description(description())
64+
.parameters(Schema.builder().type("OBJECT").build()) // No parameters needed
65+
.build());
66+
}
67+
68+
@Override
69+
public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContext toolContext) {
70+
toolContext.setActions(toolContext.actions().toBuilder().endInvocation(true).build());
71+
return Single.just(ImmutableMap.of());
72+
}
73+
}
74+
75+
@Test
76+
public void endInvocationTool_stopsFlow() {
77+
Content endInvocationCallContent =
78+
Content.fromParts(Part.fromFunctionCall("end_invocation", ImmutableMap.of()));
79+
Content response1 = Content.fromParts(Part.fromText("response1"));
80+
Content response2 = Content.fromParts(Part.fromText("response2"));
81+
82+
var testLlm =
83+
createTestLlm(
84+
Flowable.just(createLlmResponse(endInvocationCallContent)),
85+
Flowable.just(createLlmResponse(response1)),
86+
Flowable.just(createLlmResponse(response2)));
87+
88+
LlmAgent rootAgent =
89+
createTestAgentBuilder(testLlm)
90+
.name("root_agent")
91+
.tools(ImmutableList.of(new EndInvocationTool()))
92+
.build();
93+
InvocationContext invocationContext = createInvocationContext(rootAgent);
94+
95+
Runner runner = getRunnerAndCreateSession(rootAgent, invocationContext.session());
96+
97+
List<Event> actualEvents = new ArrayList<>();
98+
runRunner(runner, invocationContext, actualEvents);
99+
100+
assertThat(simplifyEvents(actualEvents))
101+
.containsExactly(
102+
"root_agent: FunctionCall(name=end_invocation, args={})",
103+
"root_agent: FunctionResponse(name=end_invocation, response={})")
104+
.inOrder();
105+
}
106+
107+
private Runner getRunnerAndCreateSession(LlmAgent agent, Session session) {
108+
Runner runner = new InMemoryRunner(agent, session.appName());
109+
110+
var unused =
111+
runner
112+
.sessionService()
113+
.createSession(session.appName(), session.userId(), session.state(), session.id())
114+
.blockingGet();
115+
116+
return runner;
117+
}
118+
119+
private void runRunner(
120+
Runner runner, InvocationContext invocationContext, List<Event> actualEvents) {
121+
Session session = invocationContext.session();
122+
runner
123+
.runAsync(session.userId(), session.id(), invocationContext.userContent().orElse(null))
124+
.blockingForEach(actualEvents::add);
125+
}
126+
}

0 commit comments

Comments
 (0)