Skip to content

Commit 8190ed3

Browse files
google-genai-botcopybara-github
authored andcommitted
fix: Fixing a problem with serializing sessions that broke integration with Vertex AI Session Service
PiperOrigin-RevId: 866447373
1 parent 3dab101 commit 8190ed3

2 files changed

Lines changed: 258 additions & 81 deletions

File tree

core/src/main/java/com/google/adk/sessions/SessionJsonConverter.java

Lines changed: 98 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import com.fasterxml.jackson.core.JsonProcessingException;
2020
import com.fasterxml.jackson.databind.ObjectMapper;
2121
import com.google.adk.JsonBaseModel;
22-
import com.google.adk.agents.BaseAgentState;
2322
import com.google.adk.events.Event;
2423
import com.google.adk.events.EventActions;
2524
import com.google.adk.events.ToolConfirmation;
@@ -28,10 +27,12 @@
2827
import com.google.common.collect.Iterables;
2928
import com.google.genai.types.Content;
3029
import com.google.genai.types.FinishReason;
30+
import com.google.genai.types.GenerateContentResponseUsageMetadata;
3131
import com.google.genai.types.GroundingMetadata;
3232
import com.google.genai.types.Part;
3333
import java.io.UncheckedIOException;
3434
import java.time.Instant;
35+
import java.util.Collection;
3536
import java.util.HashMap;
3637
import java.util.HashSet;
3738
import java.util.List;
@@ -57,63 +58,64 @@ private SessionJsonConverter() {}
5758
* @throws UncheckedIOException if serialization fails.
5859
*/
5960
static String convertEventToJson(Event event) {
60-
Map<String, Object> metadataJson = new HashMap<>();
61-
metadataJson.put("partial", event.partial());
62-
metadataJson.put("turnComplete", event.turnComplete());
63-
metadataJson.put("interrupted", event.interrupted());
64-
metadataJson.put("branch", event.branch().orElse(null));
65-
metadataJson.put(
66-
"long_running_tool_ids",
67-
event.longRunningToolIds() != null ? event.longRunningToolIds().orElse(null) : null);
68-
if (event.groundingMetadata() != null) {
69-
metadataJson.put("grounding_metadata", event.groundingMetadata());
70-
}
61+
return convertEventToJson(event, false);
62+
}
7163

64+
/**
65+
* Converts an {@link Event} to its JSON string representation for API transmission.
66+
*
67+
* @param useIsoString if true, use ISO-8601 string for timestamp; otherwise use object format.
68+
* @return JSON string of the event.
69+
* @throws UncheckedIOException if serialization fails.
70+
*/
71+
static String convertEventToJson(Event event, boolean useIsoString) {
72+
Map<String, Object> metadataJson = new HashMap<>();
73+
event.partial().ifPresent(v -> metadataJson.put("partial", v));
74+
event.turnComplete().ifPresent(v -> metadataJson.put("turnComplete", v));
75+
event.interrupted().ifPresent(v -> metadataJson.put("interrupted", v));
76+
event.branch().ifPresent(v -> metadataJson.put("branch", v));
77+
putIfNotEmpty(metadataJson, "longRunningToolIds", event.longRunningToolIds());
78+
event.groundingMetadata().ifPresent(v -> metadataJson.put("groundingMetadata", v));
79+
event.usageMetadata().ifPresent(v -> metadataJson.put("usageMetadata", v));
7280
Map<String, Object> eventJson = new HashMap<>();
7381
eventJson.put("author", event.author());
7482
eventJson.put("invocationId", event.invocationId());
75-
eventJson.put(
76-
"timestamp",
77-
new HashMap<>(
78-
ImmutableMap.of(
79-
"seconds",
80-
event.timestamp() / 1000,
81-
"nanos",
82-
(event.timestamp() % 1000) * 1000000)));
83-
if (event.errorCode().isPresent()) {
84-
eventJson.put("errorCode", event.errorCode());
85-
}
86-
if (event.errorMessage().isPresent()) {
87-
eventJson.put("errorMessage", event.errorMessage());
83+
if (useIsoString) {
84+
eventJson.put("timestamp", Instant.ofEpochMilli(event.timestamp()).toString());
85+
} else {
86+
eventJson.put(
87+
"timestamp",
88+
new HashMap<>(
89+
ImmutableMap.of(
90+
"seconds",
91+
event.timestamp() / 1000,
92+
"nanos",
93+
(event.timestamp() % 1000) * 1000000)));
8894
}
95+
event.errorCode().ifPresent(errorCode -> eventJson.put("errorCode", errorCode));
96+
event.errorMessage().ifPresent(errorMessage -> eventJson.put("errorMessage", errorMessage));
8997
eventJson.put("eventMetadata", metadataJson);
9098

9199
if (event.actions() != null) {
92100
Map<String, Object> actionsJson = new HashMap<>();
93-
actionsJson.put("skipSummarization", event.actions().skipSummarization());
94-
actionsJson.put("stateDelta", stateDeltaToJson(event.actions().stateDelta()));
95-
actionsJson.put("artifactDelta", event.actions().artifactDelta());
96-
actionsJson.put("transferAgent", event.actions().transferToAgent());
97-
actionsJson.put("escalate", event.actions().escalate());
98-
actionsJson.put("endInvocation", event.actions().endInvocation());
99-
actionsJson.put("requestedAuthConfigs", event.actions().requestedAuthConfigs());
100-
actionsJson.put("requestedToolConfirmations", event.actions().requestedToolConfirmations());
101-
actionsJson.put("compaction", event.actions().compaction());
102-
if (!event.actions().agentState().isEmpty()) {
103-
actionsJson.put("agentState", event.actions().agentState());
104-
}
105-
actionsJson.put("rewindBeforeInvocationId", event.actions().rewindBeforeInvocationId());
101+
EventActions actions = event.actions();
102+
actions.skipSummarization().ifPresent(v -> actionsJson.put("skipSummarization", v));
103+
actionsJson.put("stateDelta", stateDeltaToJson(actions.stateDelta()));
104+
putIfNotEmpty(actionsJson, "artifactDelta", actions.artifactDelta());
105+
actions
106+
.transferToAgent()
107+
.ifPresent(
108+
v -> {
109+
actionsJson.put("transferAgent", v);
110+
});
111+
actions.escalate().ifPresent(v -> actionsJson.put("escalate", v));
112+
actions.endInvocation().ifPresent(v -> actionsJson.put("endOfAgent", v));
113+
putIfNotEmpty(actionsJson, "requestedAuthConfigs", actions.requestedAuthConfigs());
114+
putIfNotEmpty(
115+
actionsJson, "requestedToolConfirmations", actions.requestedToolConfirmations());
106116
eventJson.put("actions", actionsJson);
107117
}
108-
if (event.content().isPresent()) {
109-
eventJson.put("content", SessionUtils.encodeContent(event.content().get()));
110-
}
111-
if (event.errorCode().isPresent()) {
112-
eventJson.put("errorCode", event.errorCode().get());
113-
}
114-
if (event.errorMessage().isPresent()) {
115-
eventJson.put("errorMessage", event.errorMessage().get());
116-
}
118+
event.content().ifPresent(c -> eventJson.put("content", SessionUtils.encodeContent(c)));
117119
try {
118120
return objectMapper.writeValueAsString(eventJson);
119121
} catch (JsonProcessingException e) {
@@ -156,19 +158,31 @@ private static Content convertMapToContent(Object rawContentValue) {
156158
@SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema.
157159
static Event fromApiEvent(Map<String, Object> apiEvent) {
158160
EventActions.Builder eventActionsBuilder = EventActions.builder();
159-
if (apiEvent.get("actions") != null) {
160-
Map<String, Object> actionsMap = (Map<String, Object>) apiEvent.get("actions");
161-
if (actionsMap.get("skipSummarization") != null) {
162-
eventActionsBuilder.skipSummarization((Boolean) actionsMap.get("skipSummarization"));
161+
Map<String, Object> actionsMap = (Map<String, Object>) apiEvent.get("actions");
162+
if (actionsMap != null) {
163+
Boolean skipSummarization = (Boolean) actionsMap.get("skipSummarization");
164+
if (skipSummarization != null) {
165+
eventActionsBuilder.skipSummarization(skipSummarization);
163166
}
164167
eventActionsBuilder.stateDelta(stateDeltaFromJson(actionsMap.get("stateDelta")));
168+
Object artifactDelta = actionsMap.get("artifactDelta");
165169
eventActionsBuilder.artifactDelta(
166-
actionsMap.get("artifactDelta") != null
167-
? convertToArtifactDeltaMap(actionsMap.get("artifactDelta"))
170+
artifactDelta != null
171+
? convertToArtifactDeltaMap(artifactDelta)
168172
: new ConcurrentHashMap<>());
169-
eventActionsBuilder.transferToAgent((String) actionsMap.get("transferAgent"));
170-
if (actionsMap.get("escalate") != null) {
171-
eventActionsBuilder.escalate((Boolean) actionsMap.get("escalate"));
173+
String transferAgent = (String) actionsMap.get("transferAgent");
174+
if (transferAgent == null) {
175+
transferAgent = (String) actionsMap.get("transferToAgent");
176+
}
177+
eventActionsBuilder.transferToAgent(transferAgent);
178+
Boolean escalate = (Boolean) actionsMap.get("escalate");
179+
if (escalate != null) {
180+
eventActionsBuilder.escalate(escalate);
181+
}
182+
Boolean endOfAgent = (Boolean) actionsMap.get("endOfAgent");
183+
if (endOfAgent != null) {
184+
eventActionsBuilder.endOfAgent(endOfAgent);
185+
eventActionsBuilder.endInvocation(endOfAgent);
172186
}
173187
eventActionsBuilder.requestedAuthConfigs(
174188
Optional.ofNullable(actionsMap.get("requestedAuthConfigs"))
@@ -178,13 +192,6 @@ static Event fromApiEvent(Map<String, Object> apiEvent) {
178192
Optional.ofNullable(actionsMap.get("requestedToolConfirmations"))
179193
.map(SessionJsonConverter::asConcurrentMapOfToolConfirmations)
180194
.orElse(new ConcurrentHashMap<>()));
181-
if (actionsMap.get("agentState") != null) {
182-
eventActionsBuilder.agentState(asConcurrentMapOfAgentState(actionsMap.get("agentState")));
183-
}
184-
if (actionsMap.get("rewindBeforeInvocationId") != null) {
185-
eventActionsBuilder.rewindBeforeInvocationId(
186-
(String) actionsMap.get("rewindBeforeInvocationId"));
187-
}
188195
}
189196

190197
Event event =
@@ -204,11 +211,9 @@ static Event fromApiEvent(Map<String, Object> apiEvent) {
204211
.map(value -> new FinishReason((String) value)))
205212
.errorMessage(
206213
Optional.ofNullable(apiEvent.get("errorMessage")).map(value -> (String) value))
207-
.branch(Optional.ofNullable(apiEvent.get("branch")).map(value -> (String) value))
208214
.build();
209-
// TODO(b/414263934): Add Event branch and grounding metadata for python parity.
210-
if (apiEvent.get("eventMetadata") != null) {
211-
Map<String, Object> eventMetadata = (Map<String, Object>) apiEvent.get("eventMetadata");
215+
Map<String, Object> eventMetadata = (Map<String, Object>) apiEvent.get("eventMetadata");
216+
if (eventMetadata != null) {
212217
List<String> longRunningToolIdsList = (List<String>) eventMetadata.get("longRunningToolIds");
213218

214219
GroundingMetadata groundingMetadata = null;
@@ -217,6 +222,12 @@ static Event fromApiEvent(Map<String, Object> apiEvent) {
217222
groundingMetadata =
218223
objectMapper.convertValue(rawGroundingMetadata, GroundingMetadata.class);
219224
}
225+
GenerateContentResponseUsageMetadata usageMetadata = null;
226+
Object rawUsageMetadata = eventMetadata.get("usageMetadata");
227+
if (rawUsageMetadata != null) {
228+
usageMetadata =
229+
objectMapper.convertValue(rawUsageMetadata, GenerateContentResponseUsageMetadata.class);
230+
}
220231

221232
event =
222233
event.toBuilder()
@@ -227,6 +238,7 @@ static Event fromApiEvent(Map<String, Object> apiEvent) {
227238
Optional.ofNullable((Boolean) eventMetadata.get("interrupted")).orElse(false))
228239
.branch(Optional.ofNullable((String) eventMetadata.get("branch")))
229240
.groundingMetadata(groundingMetadata)
241+
.usageMetadata(usageMetadata)
230242
.longRunningToolIds(
231243
longRunningToolIdsList != null ? new HashSet<>(longRunningToolIdsList) : null)
232244
.build();
@@ -285,7 +297,7 @@ private static Instant convertToInstant(Object timestampObj) {
285297
* @param artifactDeltaObj The raw object from which to parse the artifact delta.
286298
* @return A {@link ConcurrentMap} representing the artifact delta.
287299
*/
288-
@SuppressWarnings("unchecked") // Safe because we check instanceof Map before casting.
300+
@SuppressWarnings("unchecked")
289301
private static ConcurrentMap<String, Part> convertToArtifactDeltaMap(Object artifactDeltaObj) {
290302
if (!(artifactDeltaObj instanceof Map)) {
291303
return new ConcurrentHashMap<>();
@@ -319,19 +331,6 @@ private static ConcurrentMap<String, Part> convertToArtifactDeltaMap(Object arti
319331
ConcurrentHashMap::putAll);
320332
}
321333

322-
@SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema.
323-
private static ConcurrentMap<String, BaseAgentState> asConcurrentMapOfAgentState(Object value) {
324-
return ((Map<String, Object>) value)
325-
.entrySet().stream()
326-
.collect(
327-
ConcurrentHashMap::new,
328-
(map, entry) ->
329-
map.put(
330-
entry.getKey(),
331-
objectMapper.convertValue(entry.getValue(), BaseAgentState.class)),
332-
ConcurrentHashMap::putAll);
333-
}
334-
335334
@SuppressWarnings("unchecked") // Parsing raw Map from JSON following a known schema.
336335
private static ConcurrentMap<String, ToolConfirmation> asConcurrentMapOfToolConfirmations(
337336
Object value) {
@@ -345,4 +344,22 @@ private static ConcurrentMap<String, ToolConfirmation> asConcurrentMapOfToolConf
345344
objectMapper.convertValue(entry.getValue(), ToolConfirmation.class)),
346345
ConcurrentHashMap::putAll);
347346
}
347+
348+
private static void putIfNotEmpty(Map<String, Object> map, String key, Map<?, ?> values) {
349+
if (values != null && !values.isEmpty()) {
350+
map.put(key, values);
351+
}
352+
}
353+
354+
private static void putIfNotEmpty(
355+
Map<String, Object> map, String key, Optional<? extends Collection<?>> values) {
356+
values.ifPresent(v -> putIfNotEmpty(map, key, v));
357+
}
358+
359+
private static void putIfNotEmpty(
360+
Map<String, Object> map, String key, @Nullable Collection<?> values) {
361+
if (values != null && !values.isEmpty()) {
362+
map.put(key, values);
363+
}
364+
}
348365
}

0 commit comments

Comments
 (0)