Skip to content

Commit 2de03a8

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: adding resume / event management primitives
This is a step towards implementing pause/resume/rewind. This change introduce several features related to resumability and event management within the Google ADK core. Here's a summary of the changes: 1. **`InvocationContext.java`**: * A new public method `resumabilityConfig()` is added to provide access to the invocation's `ResumabilityConfig`. * A method `populateAgentStates(ImmutableList<Event> events)` is introduced to initialize or update the `agentStates` and `endOfAgents` maps within the `InvocationContext` by processing events associated with the current invocation ID. 2. **`EventActions.java`**: * The `EventActions` class now extends `JsonBaseModel`. * A new field `deletedArtifactIds` (a `Set<String>`) is added to track artifacts that should be deleted. This field is included in JSON serialization/deserialization, equality checks, and the `EventActions.Builder`'s merge logic. 3. **`Event.java`**: * The `finalResponse()` logic is updated. Previously, an event with `longRunningToolIds` was always considered a final response. This check has been removed, meaning the presence of `longRunningToolIds` alone no longer makes an event a `finalResponse`. PiperOrigin-RevId: 866611243
1 parent ed736cd commit 2de03a8

6 files changed

Lines changed: 198 additions & 5 deletions

File tree

core/src/main/java/com/google/adk/agents/InvocationContext.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
import com.google.errorprone.annotations.InlineMe;
3232
import com.google.genai.types.Content;
3333
import com.google.genai.types.FunctionCall;
34+
import java.util.List;
3435
import java.util.Map;
3536
import java.util.Objects;
3637
import java.util.Optional;
@@ -369,6 +370,31 @@ public boolean isResumable() {
369370
return resumabilityConfig.isResumable();
370371
}
371372

373+
/** Returns ResumabilityConfig for this invocation. */
374+
public ResumabilityConfig resumabilityConfig() {
375+
return resumabilityConfig;
376+
}
377+
378+
/**
379+
* Populates agentStates and endOfAgents maps by reading session events for this invocation id.
380+
*/
381+
public void populateAgentStates(List<Event> events) {
382+
events.stream()
383+
.filter(event -> invocationId().equals(event.invocationId()))
384+
.forEach(
385+
event -> {
386+
if (event.actions() != null) {
387+
if (event.actions().agentState() != null
388+
&& !event.actions().agentState().isEmpty()) {
389+
agentStates.putAll(event.actions().agentState());
390+
}
391+
if (event.actions().endOfAgent()) {
392+
endOfAgents.put(event.author(), true);
393+
}
394+
}
395+
});
396+
}
397+
372398
/** Returns the events compaction configuration for the current agent run. */
373399
public Optional<EventsCompactionConfig> eventsCompactionConfig() {
374400
return Optional.ofNullable(eventsCompactionConfig);

core/src/main/java/com/google/adk/events/Event.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -294,8 +294,7 @@ public final boolean hasTrailingCodeExecutionResult() {
294294
/** Returns true if this is a final response. */
295295
@JsonIgnore
296296
public final boolean finalResponse() {
297-
if (actions().skipSummarization().orElse(false)
298-
|| (longRunningToolIds().isPresent() && !longRunningToolIds().get().isEmpty())) {
297+
if (actions().skipSummarization().orElse(false)) {
299298
return true;
300299
}
301300
return functionCalls().isEmpty()

core/src/main/java/com/google/adk/events/EventActions.java

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,28 @@
1818
import com.fasterxml.jackson.annotation.JsonInclude;
1919
import com.fasterxml.jackson.annotation.JsonProperty;
2020
import com.fasterxml.jackson.databind.annotation.JsonDeserialize;
21+
import com.google.adk.JsonBaseModel;
2122
import com.google.adk.agents.BaseAgentState;
2223
import com.google.adk.sessions.State;
2324
import com.google.errorprone.annotations.CanIgnoreReturnValue;
2425
import com.google.genai.types.Part;
26+
import java.util.HashSet;
2527
import java.util.Objects;
2628
import java.util.Optional;
29+
import java.util.Set;
2730
import java.util.concurrent.ConcurrentHashMap;
2831
import java.util.concurrent.ConcurrentMap;
2932
import javax.annotation.Nullable;
3033

3134
/** Represents the actions attached to an event. */
3235
// TODO - b/414081262 make json wire camelCase
3336
@JsonDeserialize(builder = EventActions.Builder.class)
34-
public class EventActions {
37+
public class EventActions extends JsonBaseModel {
3538

3639
private Optional<Boolean> skipSummarization;
3740
private ConcurrentMap<String, Object> stateDelta;
3841
private ConcurrentMap<String, Part> artifactDelta;
42+
private Set<String> deletedArtifactIds;
3943
private Optional<String> transferToAgent;
4044
private Optional<Boolean> escalate;
4145
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
@@ -51,6 +55,7 @@ public EventActions() {
5155
this.skipSummarization = Optional.empty();
5256
this.stateDelta = new ConcurrentHashMap<>();
5357
this.artifactDelta = new ConcurrentHashMap<>();
58+
this.deletedArtifactIds = new HashSet<>();
5459
this.transferToAgent = Optional.empty();
5560
this.escalate = Optional.empty();
5661
this.requestedAuthConfigs = new ConcurrentHashMap<>();
@@ -66,6 +71,7 @@ private EventActions(Builder builder) {
6671
this.skipSummarization = builder.skipSummarization;
6772
this.stateDelta = builder.stateDelta;
6873
this.artifactDelta = builder.artifactDelta;
74+
this.deletedArtifactIds = builder.deletedArtifactIds;
6975
this.transferToAgent = builder.transferToAgent;
7076
this.escalate = builder.escalate;
7177
this.requestedAuthConfigs = builder.requestedAuthConfigs;
@@ -122,6 +128,16 @@ public void setArtifactDelta(ConcurrentMap<String, Part> artifactDelta) {
122128
this.artifactDelta = artifactDelta;
123129
}
124130

131+
@JsonProperty("deletedArtifactIds")
132+
@JsonInclude(JsonInclude.Include.NON_EMPTY)
133+
public Set<String> deletedArtifactIds() {
134+
return deletedArtifactIds;
135+
}
136+
137+
public void setDeletedArtifactIds(Set<String> deletedArtifactIds) {
138+
this.deletedArtifactIds = deletedArtifactIds;
139+
}
140+
125141
@JsonProperty("transferToAgent")
126142
public Optional<String> transferToAgent() {
127143
return transferToAgent;
@@ -238,6 +254,7 @@ public boolean equals(Object o) {
238254
return Objects.equals(skipSummarization, that.skipSummarization)
239255
&& Objects.equals(stateDelta, that.stateDelta)
240256
&& Objects.equals(artifactDelta, that.artifactDelta)
257+
&& Objects.equals(deletedArtifactIds, that.deletedArtifactIds)
241258
&& Objects.equals(transferToAgent, that.transferToAgent)
242259
&& Objects.equals(escalate, that.escalate)
243260
&& Objects.equals(requestedAuthConfigs, that.requestedAuthConfigs)
@@ -255,6 +272,7 @@ public int hashCode() {
255272
skipSummarization,
256273
stateDelta,
257274
artifactDelta,
275+
deletedArtifactIds,
258276
transferToAgent,
259277
escalate,
260278
requestedAuthConfigs,
@@ -271,6 +289,7 @@ public static class Builder {
271289
private Optional<Boolean> skipSummarization;
272290
private ConcurrentMap<String, Object> stateDelta;
273291
private ConcurrentMap<String, Part> artifactDelta;
292+
private Set<String> deletedArtifactIds;
274293
private Optional<String> transferToAgent;
275294
private Optional<Boolean> escalate;
276295
private ConcurrentMap<String, ConcurrentMap<String, Object>> requestedAuthConfigs;
@@ -285,6 +304,7 @@ public Builder() {
285304
this.skipSummarization = Optional.empty();
286305
this.stateDelta = new ConcurrentHashMap<>();
287306
this.artifactDelta = new ConcurrentHashMap<>();
307+
this.deletedArtifactIds = new HashSet<>();
288308
this.transferToAgent = Optional.empty();
289309
this.escalate = Optional.empty();
290310
this.requestedAuthConfigs = new ConcurrentHashMap<>();
@@ -299,6 +319,7 @@ private Builder(EventActions eventActions) {
299319
this.skipSummarization = eventActions.skipSummarization();
300320
this.stateDelta = new ConcurrentHashMap<>(eventActions.stateDelta());
301321
this.artifactDelta = new ConcurrentHashMap<>(eventActions.artifactDelta());
322+
this.deletedArtifactIds = new HashSet<>(eventActions.deletedArtifactIds());
302323
this.transferToAgent = eventActions.transferToAgent();
303324
this.escalate = eventActions.escalate();
304325
this.requestedAuthConfigs = new ConcurrentHashMap<>(eventActions.requestedAuthConfigs());
@@ -332,6 +353,13 @@ public Builder artifactDelta(ConcurrentMap<String, Part> value) {
332353
return this;
333354
}
334355

356+
@CanIgnoreReturnValue
357+
@JsonProperty("deletedArtifactIds")
358+
public Builder deletedArtifactIds(Set<String> value) {
359+
this.deletedArtifactIds = value;
360+
return this;
361+
}
362+
335363
@CanIgnoreReturnValue
336364
@JsonProperty("transferToAgent")
337365
public Builder transferToAgent(String agentId) {
@@ -401,6 +429,7 @@ public Builder merge(EventActions other) {
401429
other.skipSummarization().ifPresent(this::skipSummarization);
402430
this.stateDelta.putAll(other.stateDelta());
403431
this.artifactDelta.putAll(other.artifactDelta());
432+
this.deletedArtifactIds.addAll(other.deletedArtifactIds());
404433
other.transferToAgent().ifPresent(this::transferToAgent);
405434
other.escalate().ifPresent(this::escalate);
406435
this.requestedAuthConfigs.putAll(other.requestedAuthConfigs());

core/src/test/java/com/google/adk/agents/InvocationContextTest.java

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.google.adk.apps.ResumabilityConfig;
2323
import com.google.adk.artifacts.BaseArtifactService;
2424
import com.google.adk.events.Event;
25+
import com.google.adk.events.EventActions;
2526
import com.google.adk.memory.BaseMemoryService;
2627
import com.google.adk.models.LlmCallsLimitExceededException;
2728
import com.google.adk.plugins.PluginManager;
@@ -150,7 +151,7 @@ public void testBuildWithLiveRequestQueue() {
150151
}
151152

152153
@Test
153-
public void testCopyOf() {
154+
public void testToBuilder() {
154155
InvocationContext originalContext =
155156
InvocationContext.builder()
156157
.sessionService(mockSessionService)
@@ -933,4 +934,56 @@ public void testDeprecatedConstructor_11params() {
933934
assertThat(context.runConfig()).isEqualTo(runConfig);
934935
assertThat(context.endInvocation()).isTrue();
935936
}
937+
938+
@Test
939+
public void populateAgentStates_populatesAgentStatesAndEndOfAgents() {
940+
InvocationContext context =
941+
InvocationContext.builder()
942+
.sessionService(mockSessionService)
943+
.artifactService(mockArtifactService)
944+
.agent(mockAgent)
945+
.session(session)
946+
.invocationId(testInvocationId)
947+
.build();
948+
949+
BaseAgentState agent1State = mock(BaseAgentState.class);
950+
ConcurrentHashMap<String, BaseAgentState> agent1StateMap = new ConcurrentHashMap<>();
951+
agent1StateMap.put("agent1", agent1State);
952+
Event event1 =
953+
Event.builder()
954+
.invocationId(testInvocationId)
955+
.author("agent1")
956+
.actions(EventActions.builder().agentState(agent1StateMap).endOfAgent(true).build())
957+
.build();
958+
Event event2 =
959+
Event.builder()
960+
.invocationId("other-invocation-id")
961+
.author("agent2")
962+
.actions(EventActions.builder().endOfAgent(true).build())
963+
.build();
964+
Event event3 =
965+
Event.builder()
966+
.invocationId(testInvocationId)
967+
.author("agent3")
968+
.actions(EventActions.builder().endOfAgent(false).build())
969+
.build();
970+
BaseAgentState agent4State = mock(BaseAgentState.class);
971+
ConcurrentHashMap<String, BaseAgentState> agent4StateMap = new ConcurrentHashMap<>();
972+
agent4StateMap.put("agent4", agent4State);
973+
Event event4 =
974+
Event.builder()
975+
.invocationId(testInvocationId)
976+
.author("agent4")
977+
.actions(EventActions.builder().agentState(agent4StateMap).endOfAgent(false).build())
978+
.build();
979+
Event event5 = Event.builder().invocationId(testInvocationId).author("agent5").build();
980+
981+
context.populateAgentStates(ImmutableList.of(event1, event2, event3, event4, event5));
982+
983+
assertThat(context.agentStates()).hasSize(2);
984+
assertThat(context.agentStates()).containsEntry("agent1", agent1State);
985+
assertThat(context.agentStates()).containsEntry("agent4", agent4State);
986+
assertThat(context.endOfAgents()).hasSize(1);
987+
assertThat(context.endOfAgents()).containsEntry("agent1", true);
988+
}
936989
}

core/src/test/java/com/google/adk/events/EventActionsTest.java

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import com.google.adk.sessions.State;
2222
import com.google.common.collect.ImmutableMap;
23+
import com.google.common.collect.ImmutableSet;
2324
import com.google.genai.types.Content;
2425
import com.google.genai.types.Part;
2526
import java.util.concurrent.ConcurrentHashMap;
@@ -44,7 +45,11 @@ public final class EventActionsTest {
4445
@Test
4546
public void toBuilder_createsBuilderWithSameValues() {
4647
EventActions eventActionsWithSkipSummarization =
47-
EventActions.builder().skipSummarization(true).compaction(COMPACTION).build();
48+
EventActions.builder()
49+
.skipSummarization(true)
50+
.compaction(COMPACTION)
51+
.deletedArtifactIds(ImmutableSet.of("d1"))
52+
.build();
4853

4954
EventActions eventActionsAfterRebuild = eventActionsWithSkipSummarization.toBuilder().build();
5055

@@ -59,6 +64,7 @@ public void merge_mergesAllFields() {
5964
.skipSummarization(true)
6065
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1")))
6166
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact1", PART)))
67+
.deletedArtifactIds(ImmutableSet.of("deleted1"))
6268
.requestedAuthConfigs(
6369
new ConcurrentHashMap<>(
6470
ImmutableMap.of("config1", new ConcurrentHashMap<>(ImmutableMap.of("k", "v")))))
@@ -70,6 +76,7 @@ public void merge_mergesAllFields() {
7076
EventActions.builder()
7177
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("key2", "value2")))
7278
.artifactDelta(new ConcurrentHashMap<>(ImmutableMap.of("artifact2", PART)))
79+
.deletedArtifactIds(ImmutableSet.of("deleted2"))
7380
.transferToAgent("agentId")
7481
.escalate(true)
7582
.requestedAuthConfigs(
@@ -85,6 +92,7 @@ public void merge_mergesAllFields() {
8592
assertThat(merged.skipSummarization()).hasValue(true);
8693
assertThat(merged.stateDelta()).containsExactly("key1", "value1", "key2", "value2");
8794
assertThat(merged.artifactDelta()).containsExactly("artifact1", PART, "artifact2", PART);
95+
assertThat(merged.deletedArtifactIds()).containsExactly("deleted1", "deleted2");
8896
assertThat(merged.transferToAgent()).hasValue("agentId");
8997
assertThat(merged.escalate()).hasValue(true);
9098
assertThat(merged.requestedAuthConfigs())
@@ -107,4 +115,19 @@ public void removeStateByKey_marksKeyAsRemoved() {
107115

108116
assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED);
109117
}
118+
119+
@Test
120+
public void jsonSerialization_works() throws Exception {
121+
EventActions eventActions =
122+
EventActions.builder()
123+
.deletedArtifactIds(ImmutableSet.of("d1", "d2"))
124+
.stateDelta(new ConcurrentHashMap<>(ImmutableMap.of("k", "v")))
125+
.build();
126+
127+
String json = eventActions.toJson();
128+
EventActions deserialized = EventActions.fromJsonString(json, EventActions.class);
129+
130+
assertThat(deserialized).isEqualTo(eventActions);
131+
assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2");
132+
}
110133
}

core/src/test/java/com/google/adk/events/EventTest.java

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,4 +191,67 @@ public void event_json_serialization_works() throws Exception {
191191
Event deserializedEvent = Event.fromJson(json);
192192
assertThat(deserializedEvent).isEqualTo(EVENT);
193193
}
194+
195+
@Test
196+
public void finalResponse_returnsTrueIfNoToolCalls() {
197+
Event event =
198+
Event.builder()
199+
.id("e1")
200+
.invocationId("i1")
201+
.author("agent")
202+
.content(Content.fromParts(Part.fromText("hello")))
203+
.build();
204+
assertThat(event.finalResponse()).isTrue();
205+
}
206+
207+
@Test
208+
public void finalResponse_returnsFalseIfToolCalls() {
209+
Event event =
210+
Event.builder()
211+
.id("e1")
212+
.invocationId("i1")
213+
.author("agent")
214+
.content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v"))))
215+
.build();
216+
assertThat(event.finalResponse()).isFalse();
217+
}
218+
219+
@Test
220+
public void finalResponse_isTrueForEventWithTextContent() {
221+
Event event =
222+
Event.builder()
223+
.id("e1")
224+
.invocationId("i1")
225+
.author("agent")
226+
.content(Content.fromParts(Part.fromText("hello")))
227+
.longRunningToolIds(ImmutableSet.of("tool1"))
228+
.build();
229+
assertThat(event.finalResponse()).isTrue();
230+
}
231+
232+
@Test
233+
public void finalResponse_isFalseForEventWithToolCallAndLongRunningToolId() {
234+
Event event =
235+
Event.builder()
236+
.id("e1")
237+
.invocationId("i1")
238+
.author("agent")
239+
.content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v"))))
240+
.longRunningToolIds(ImmutableSet.of("tool1"))
241+
.build();
242+
assertThat(event.finalResponse()).isFalse();
243+
}
244+
245+
@Test
246+
public void finalResponse_returnsTrueIfSkipSummarization() {
247+
Event event =
248+
Event.builder()
249+
.id("e1")
250+
.invocationId("i1")
251+
.author("agent")
252+
.content(Content.fromParts(Part.fromFunctionCall("tool", ImmutableMap.of("k", "v"))))
253+
.actions(EventActions.builder().skipSummarization(true).build())
254+
.build();
255+
assertThat(event.finalResponse()).isTrue();
256+
}
194257
}

0 commit comments

Comments
 (0)