Skip to content

Commit 1cf889d

Browse files
shukladivyanshcopybara-github
authored andcommitted
ADK changes
PiperOrigin-RevId: 763947585
1 parent f213016 commit 1cf889d

5 files changed

Lines changed: 239 additions & 25 deletions

File tree

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 77 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
import java.util.Optional;
5454
import java.util.concurrent.Executor;
5555
import java.util.stream.Collectors;
56+
import javax.annotation.Nullable;
5657
import org.slf4j.Logger;
5758
import org.slf4j.LoggerFactory;
5859

@@ -83,8 +84,8 @@ public enum IncludeContents {
8384
private final boolean disallowTransferToPeers;
8485
private final Optional<List<BeforeModelCallback>> beforeModelCallback;
8586
private final Optional<List<AfterModelCallback>> afterModelCallback;
86-
private final Optional<BeforeToolCallback> beforeToolCallback;
87-
private final Optional<AfterToolCallback> afterToolCallback;
87+
private final Optional<List<BeforeToolCallback>> beforeToolCallback;
88+
private final Optional<List<AfterToolCallback>> afterToolCallback;
8889
private final Optional<Schema> inputSchema;
8990
private final Optional<Schema> outputSchema;
9091
private final Optional<Executor> executor;
@@ -152,8 +153,8 @@ public static class Builder {
152153
private List<AfterModelCallback> afterModelCallback;
153154
private List<BeforeAgentCallback> beforeAgentCallback;
154155
private List<AfterAgentCallback> afterAgentCallback;
155-
private BeforeToolCallback beforeToolCallback;
156-
private AfterToolCallback afterToolCallback;
156+
private ImmutableList<BeforeToolCallback> beforeToolCallback;
157+
private ImmutableList<AfterToolCallback> afterToolCallback;
157158
private Schema inputSchema;
158159
private Schema outputSchema;
159160
private Executor executor;
@@ -411,32 +412,93 @@ public Builder afterAgentCallbackSync(AfterAgentCallbackSync afterAgentCallbackS
411412

412413
@CanIgnoreReturnValue
413414
public Builder beforeToolCallback(BeforeToolCallback beforeToolCallback) {
414-
this.beforeToolCallback = beforeToolCallback;
415+
this.beforeToolCallback = ImmutableList.of(beforeToolCallback);
416+
return this;
417+
}
418+
419+
// TODO: b/416794047 - Use a unified interface for callback instead of using
420+
// Object.
421+
@CanIgnoreReturnValue
422+
public Builder beforeToolCallback(@Nullable List<Object> beforeToolCallbacks) {
423+
if (beforeToolCallbacks == null) {
424+
this.beforeToolCallback = null;
425+
} else if (beforeToolCallbacks.isEmpty()) {
426+
this.beforeToolCallback = ImmutableList.of();
427+
} else {
428+
ImmutableList.Builder<BeforeToolCallback> builder = ImmutableList.builder();
429+
for (Object callback : beforeToolCallbacks) {
430+
if (callback instanceof BeforeToolCallback beforeToolCallbackInstance) {
431+
builder.add(beforeToolCallbackInstance);
432+
} else if (callback instanceof BeforeToolCallbackSync beforeToolCallbackSyncInstance) {
433+
builder.add(
434+
(invocationContext, baseTool, input, toolContext) ->
435+
Maybe.fromOptional(
436+
beforeToolCallbackSyncInstance.call(
437+
invocationContext, baseTool, input, toolContext)));
438+
} else {
439+
logger.warn(
440+
"Invalid beforeToolCallback callback type: {}. Ignoring this callback.",
441+
callback.getClass().getName());
442+
}
443+
}
444+
this.beforeToolCallback = builder.build();
445+
}
415446
return this;
416447
}
417448

418449
@CanIgnoreReturnValue
419450
public Builder beforeToolCallbackSync(BeforeToolCallbackSync beforeToolCallbackSync) {
420451
this.beforeToolCallback =
421-
(invocationContext, baseTool, input, toolContext) ->
422-
Maybe.fromOptional(
423-
beforeToolCallbackSync.call(invocationContext, baseTool, input, toolContext));
452+
ImmutableList.of(
453+
(invocationContext, baseTool, input, toolContext) ->
454+
Maybe.fromOptional(
455+
beforeToolCallbackSync.call(
456+
invocationContext, baseTool, input, toolContext)));
424457
return this;
425458
}
426459

427460
@CanIgnoreReturnValue
428461
public Builder afterToolCallback(AfterToolCallback afterToolCallback) {
429-
this.afterToolCallback = afterToolCallback;
462+
this.afterToolCallback = ImmutableList.of(afterToolCallback);
463+
return this;
464+
}
465+
466+
@CanIgnoreReturnValue
467+
public Builder afterToolCallback(@Nullable List<Object> afterToolCallbacks) {
468+
if (afterToolCallbacks == null) {
469+
this.afterToolCallback = null;
470+
} else if (afterToolCallbacks.isEmpty()) {
471+
this.afterToolCallback = ImmutableList.of();
472+
} else {
473+
ImmutableList.Builder<AfterToolCallback> builder = ImmutableList.builder();
474+
for (Object callback : afterToolCallbacks) {
475+
if (callback instanceof AfterToolCallback afterToolCallbackInstance) {
476+
builder.add(afterToolCallbackInstance);
477+
} else if (callback instanceof AfterToolCallbackSync afterToolCallbackSyncInstance) {
478+
builder.add(
479+
(invocationContext, baseTool, input, toolContext, response) ->
480+
Maybe.fromOptional(
481+
afterToolCallbackSyncInstance.call(
482+
invocationContext, baseTool, input, toolContext, response)));
483+
} else {
484+
logger.warn(
485+
"Invalid afterToolCallback callback type: {}. Ignoring this callback.",
486+
callback.getClass().getName());
487+
}
488+
}
489+
this.afterToolCallback = builder.build();
490+
}
430491
return this;
431492
}
432493

433494
@CanIgnoreReturnValue
434495
public Builder afterToolCallbackSync(AfterToolCallbackSync afterToolCallbackSync) {
435496
this.afterToolCallback =
436-
(invocationContext, baseTool, input, toolContext, response) ->
437-
Maybe.fromOptional(
438-
afterToolCallbackSync.call(
439-
invocationContext, baseTool, input, toolContext, response));
497+
ImmutableList.of(
498+
(invocationContext, baseTool, input, toolContext, response) ->
499+
Maybe.fromOptional(
500+
afterToolCallbackSync.call(
501+
invocationContext, baseTool, input, toolContext, response)));
440502
return this;
441503
}
442504

@@ -606,11 +668,11 @@ public Optional<List<AfterModelCallback>> afterModelCallback() {
606668
return afterModelCallback;
607669
}
608670

609-
public Optional<BeforeToolCallback> beforeToolCallback() {
671+
public Optional<List<BeforeToolCallback>> beforeToolCallback() {
610672
return beforeToolCallback;
611673
}
612674

613-
public Optional<AfterToolCallback> afterToolCallback() {
675+
public Optional<List<AfterToolCallback>> afterToolCallback() {
614676
return afterToolCallback;
615677
}
616678

core/src/main/java/com/google/adk/flows/llmflows/Functions.java

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
package com.google.adk.flows.llmflows;
1818

1919
import com.google.adk.Telemetry;
20+
import com.google.adk.agents.Callbacks.AfterToolCallback;
21+
import com.google.adk.agents.Callbacks.BeforeToolCallback;
2022
import com.google.adk.agents.InvocationContext;
2123
import com.google.adk.agents.LlmAgent;
2224
import com.google.adk.events.Event;
@@ -31,6 +33,7 @@
3133
import io.opentelemetry.api.trace.Span;
3234
import io.opentelemetry.api.trace.Tracer;
3335
import io.opentelemetry.context.Scope;
36+
import io.reactivex.rxjava3.core.Flowable;
3437
import io.reactivex.rxjava3.core.Maybe;
3538
import java.util.ArrayList;
3639
import java.util.Collections;
@@ -242,10 +245,17 @@ private static Maybe<Map<String, Object>> maybeInvokeBeforeToolCall(
242245
ToolContext toolContext) {
243246
if (invocationContext.agent() instanceof LlmAgent) {
244247
LlmAgent agent = (LlmAgent) invocationContext.agent();
245-
return agent
246-
.beforeToolCallback()
247-
.map(callback -> callback.call(invocationContext, tool, functionArgs, toolContext))
248-
.orElse(Maybe.empty());
248+
249+
Optional<List<BeforeToolCallback>> callbacksOpt = agent.beforeToolCallback();
250+
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
251+
return Maybe.empty();
252+
}
253+
List<BeforeToolCallback> callbacks = callbacksOpt.get();
254+
255+
return Flowable.fromIterable(callbacks)
256+
.concatMapMaybe(
257+
callback -> callback.call(invocationContext, tool, functionArgs, toolContext))
258+
.firstElement();
249259
}
250260
return Maybe.empty();
251261
}
@@ -258,12 +268,17 @@ private static Maybe<Map<String, Object>> maybeInvokeAfterToolCall(
258268
Map<String, Object> functionResult) {
259269
if (invocationContext.agent() instanceof LlmAgent) {
260270
LlmAgent agent = (LlmAgent) invocationContext.agent();
261-
return agent
262-
.afterToolCallback()
263-
.map(
271+
Optional<List<AfterToolCallback>> callbacksOpt = agent.afterToolCallback();
272+
if (callbacksOpt.isEmpty() || callbacksOpt.get().isEmpty()) {
273+
return Maybe.empty();
274+
}
275+
List<AfterToolCallback> callbacks = callbacksOpt.get();
276+
277+
return Flowable.fromIterable(callbacks)
278+
.concatMapMaybe(
264279
callback ->
265280
callback.call(invocationContext, tool, functionArgs, toolContext, functionResult))
266-
.orElse(Maybe.empty());
281+
.firstElement();
267282
}
268283
return Maybe.empty();
269284
}

core/src/test/java/com/google/adk/agents/CallbacksTest.java

Lines changed: 133 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,6 @@ private static LlmResponse addPartToResponse(LlmResponse response, Part part) {
540540
.build();
541541
}
542542

543-
// Tool callback tests moved from FunctionsTest
544543
@Test
545544
public void handleFunctionCalls_withBeforeToolCallback_returnsBeforeToolCallbackResult() {
546545
ImmutableMap<String, Object> beforeToolCallbackResult =
@@ -944,4 +943,137 @@ public void handleFunctionCalls_withAfterToolCallbackSyncThatReturnsNull_returns
944943
.build())
945944
.build());
946945
}
946+
947+
@Test
948+
public void handleFunctionCalls_withChainedToolCallbacks_overridesResultAndPassesContext() {
949+
ImmutableMap<String, Object> originalToolInputArgs =
950+
ImmutableMap.of("input_key", "input_value");
951+
ImmutableMap<String, Object> stateAddedByBc2 =
952+
ImmutableMap.of("bc2_state_key", "bc2_state_value");
953+
ImmutableMap<String, Object> responseFromAc2 =
954+
ImmutableMap.of("ac2_response_key", "ac2_response_value");
955+
956+
Callbacks.BeforeToolCallbackSync bc1 =
957+
(invCtx, toolName, args, currentToolCtx) -> Optional.empty();
958+
959+
Callbacks.BeforeToolCallbackSync bc2 =
960+
(invCtx, toolName, args, currentToolCtx) -> {
961+
currentToolCtx.state().putAll(stateAddedByBc2);
962+
return Optional.empty();
963+
};
964+
965+
TestUtils.EchoTool echoTool = new TestUtils.EchoTool();
966+
967+
Callbacks.AfterToolCallbackSync ac1 =
968+
(invCtx, toolName, args, currentToolCtx, responseFromTool) -> Optional.empty();
969+
970+
Callbacks.AfterToolCallbackSync ac2 =
971+
(invCtx, toolName, args, currentToolCtx, responseFromTool) -> Optional.of(responseFromAc2);
972+
973+
InvocationContext invocationContext =
974+
createInvocationContext(
975+
createTestAgentBuilder(createTestLlm(LlmResponse.builder().build()))
976+
.beforeToolCallback(ImmutableList.<Object>of(bc1, bc2))
977+
.afterToolCallback(ImmutableList.<Object>of(ac1, ac2))
978+
.build());
979+
980+
Event eventWithFunctionCall =
981+
createEvent("event").toBuilder()
982+
.content(createFunctionCallContent("fc_id_minimal", "echo_tool", originalToolInputArgs))
983+
.build();
984+
985+
Event functionResponseEvent =
986+
Functions.handleFunctionCalls(
987+
invocationContext, eventWithFunctionCall, ImmutableMap.of("echo_tool", echoTool))
988+
.blockingGet();
989+
990+
assertThat(getFunctionResponse(functionResponseEvent)).isEqualTo(responseFromAc2);
991+
assertThat(invocationContext.session().state()).containsExactlyEntriesIn(stateAddedByBc2);
992+
}
993+
994+
@Test
995+
public void agentRunAsync_withToolCallbacks_inspectsArgsAndReturnsResponse() {
996+
TestUtils.EchoTool echoTool = new TestUtils.EchoTool();
997+
String toolName = echoTool.declaration().get().name().get();
998+
ImmutableMap<String, Object> functionArgs = ImmutableMap.of("message", "hello");
999+
1000+
Content llmFunctionCallContent =
1001+
Content.builder()
1002+
.role("model")
1003+
.parts(ImmutableList.of(Part.fromFunctionCall(toolName, functionArgs)))
1004+
.build();
1005+
Content llmTextContent =
1006+
Content.builder().role("model").parts(ImmutableList.of(Part.fromText("hi there"))).build();
1007+
TestLlm testLlm =
1008+
createTestLlm(createLlmResponse(llmFunctionCallContent), createLlmResponse(llmTextContent));
1009+
1010+
ImmutableMap<String, Object> responseFromAfterToolCallback =
1011+
ImmutableMap.of("final_wrapper", "wrapped_value_from_after_callback");
1012+
1013+
Callbacks.BeforeToolCallback beforeToolCb =
1014+
(invCtx, tName, args, toolCtx) -> {
1015+
assertThat(args).isEqualTo(functionArgs);
1016+
return Maybe.empty();
1017+
};
1018+
1019+
Callbacks.AfterToolCallback afterToolCb =
1020+
(invCtx, tName, args, toolCtx, toolResponse) -> {
1021+
assertThat(args).isEqualTo(functionArgs);
1022+
assertThat(toolResponse).isEqualTo(ImmutableMap.of("result", functionArgs));
1023+
return Maybe.just(responseFromAfterToolCallback);
1024+
};
1025+
1026+
LlmAgent agent =
1027+
createTestAgentBuilder(testLlm)
1028+
.tools(ImmutableList.of(echoTool))
1029+
.beforeToolCallback(beforeToolCb)
1030+
.afterToolCallback(afterToolCb)
1031+
.build();
1032+
1033+
InvocationContext invocationContext = createInvocationContext(agent);
1034+
1035+
List<Event> events = agent.runAsync(invocationContext).toList().blockingGet();
1036+
1037+
assertThat(testLlm.getRequests()).hasSize(2);
1038+
assertThat(events).hasSize(3);
1039+
1040+
var functionCall = getFunctionCall(events.get(0));
1041+
assertThat(functionCall.args().get()).isEqualTo(functionArgs);
1042+
assertThat(functionCall.name()).hasValue(toolName);
1043+
1044+
var functionResponse = getFunctionResponse(events.get(1));
1045+
assertThat(functionResponse).isEqualTo(responseFromAfterToolCallback);
1046+
1047+
assertThat(events.get(2).content()).hasValue(llmTextContent);
1048+
}
1049+
1050+
private static Content createFunctionCallContent(
1051+
String functionCallId, String toolName, Map<String, Object> args) {
1052+
return Content.builder()
1053+
.role("model")
1054+
.parts(
1055+
ImmutableList.of(
1056+
Part.builder()
1057+
.functionCall(
1058+
FunctionCall.builder().name(toolName).id(functionCallId).args(args).build())
1059+
.build()))
1060+
.build();
1061+
}
1062+
1063+
private static Map<String, Object> getFunctionResponse(Event functionResponseEvent) {
1064+
return functionResponseEvent
1065+
.content()
1066+
.get()
1067+
.parts()
1068+
.get()
1069+
.get(0)
1070+
.functionResponse()
1071+
.get()
1072+
.response()
1073+
.get();
1074+
}
1075+
1076+
private static FunctionCall getFunctionCall(Event functionCallEvent) {
1077+
return functionCallEvent.content().get().parts().get().get(0).functionCall().get();
1078+
}
9471079
}

core/src/test/java/com/google/adk/flows/llmflows/BasicTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import com.google.adk.events.Event;
3030
import com.google.adk.flows.llmflows.RequestProcessor.RequestProcessingResult;
3131
import com.google.adk.models.LlmRequest;
32+
import com.google.adk.models.LlmResponse;
3233
import com.google.adk.testing.TestLlm;
3334
import com.google.common.collect.ImmutableList;
3435
import com.google.genai.types.AudioTranscriptionConfig;
@@ -69,7 +70,7 @@ public final class BasicTest {
6970
@Before
7071
public void setUp() {
7172
basicProcessor = new Basic();
72-
testLlm = createTestLlm();
73+
testLlm = createTestLlm(new LlmResponse[] {});
7374
testAgent = createTestAgent(testLlm);
7475
testContext = createInvocationContext(testAgent);
7576
initialRequest = LlmRequest.builder().build();

core/src/test/java/com/google/adk/testing/TestUtils.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ public static TestLlm createTestLlm(Flowable<LlmResponse>... responses) {
184184
return createTestLlm(Arrays.asList(responses).iterator()::next);
185185
}
186186

187+
public static TestLlm createTestLlm(LlmResponse... responses) {
188+
return new TestLlm(Arrays.asList(responses));
189+
}
190+
187191
public static TestLlm createTestLlm(Supplier<Flowable<LlmResponse>> responsesSupplier) {
188192
return new TestLlm(responsesSupplier);
189193
}

0 commit comments

Comments
 (0)