Skip to content

Commit 27c0172

Browse files
Doris26copybara-github
authored andcommitted
feat: Add support for configuring agent callbacks in YAML
This change enables developers to define various lifecycle callbacks for `LlmAgent` directly within the agent's YAML configuration file. It introduces new sections like `before_agent_callbacks`, `after_agent_callbacks`, `before_model_callbacks`, `after_model_callbacks`, `before_tool_callbacks`, and `after_tool_callbacks` in the YAML schema. The `LlmAgentConfig` and `LlmAgent` classes are updated to parse these sections, and the `ComponentRegistry` is enhanced to resolve the named callback implementations provided in the application. This allows for more flexible and declarative configuration of agent behavior. PiperOrigin-RevId: 805098543
1 parent bd2553e commit 27c0172

4 files changed

Lines changed: 431 additions & 1 deletion

File tree

core/src/main/java/com/google/adk/agents/LlmAgent.java

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,9 @@ public static LlmAgent fromConfig(LlmAgentConfig config, String configAbsPath)
976976
builder.generateContentConfig(config.generateContentConfig());
977977
}
978978

979+
// Resolve callbacks if configured
980+
setCallbacksFromConfig(config, builder);
981+
979982
// Build and return the agent
980983
LlmAgent agent = builder.build();
981984
logger.info(
@@ -986,6 +989,93 @@ public static LlmAgent fromConfig(LlmAgentConfig config, String configAbsPath)
986989
return agent;
987990
}
988991

992+
private static void setCallbacksFromConfig(LlmAgentConfig config, Builder builder)
993+
throws ConfigurationException {
994+
var beforeAgentCallbacks = config.beforeAgentCallbacks();
995+
if (beforeAgentCallbacks != null) {
996+
ImmutableList.Builder<Callbacks.BeforeAgentCallbackBase> list = ImmutableList.builder();
997+
for (LlmAgentConfig.CallbackRef ref : beforeAgentCallbacks) {
998+
var callback = ComponentRegistry.resolveBeforeAgentCallback(ref.name());
999+
if (callback.isPresent()) {
1000+
list.add(callback.get());
1001+
continue;
1002+
}
1003+
throw new ConfigurationException("Invalid before_agent_callback: " + ref.name());
1004+
}
1005+
builder.beforeAgentCallback(list.build());
1006+
}
1007+
1008+
var afterAgentCallbacks = config.afterAgentCallbacks();
1009+
if (afterAgentCallbacks != null) {
1010+
ImmutableList.Builder<Callbacks.AfterAgentCallbackBase> list = ImmutableList.builder();
1011+
for (LlmAgentConfig.CallbackRef ref : afterAgentCallbacks) {
1012+
var callback = ComponentRegistry.resolveAfterAgentCallback(ref.name());
1013+
if (callback.isPresent()) {
1014+
list.add(callback.get());
1015+
continue;
1016+
}
1017+
throw new ConfigurationException("Invalid after_agent_callback: " + ref.name());
1018+
}
1019+
builder.afterAgentCallback(list.build());
1020+
}
1021+
1022+
var beforeModelCallbacks = config.beforeModelCallbacks();
1023+
if (beforeModelCallbacks != null) {
1024+
ImmutableList.Builder<Callbacks.BeforeModelCallbackBase> list = ImmutableList.builder();
1025+
for (LlmAgentConfig.CallbackRef ref : beforeModelCallbacks) {
1026+
var callback = ComponentRegistry.resolveBeforeModelCallback(ref.name());
1027+
if (callback.isPresent()) {
1028+
list.add(callback.get());
1029+
continue;
1030+
}
1031+
throw new ConfigurationException("Invalid before_model_callback: " + ref.name());
1032+
}
1033+
builder.beforeModelCallback(list.build());
1034+
}
1035+
1036+
var afterModelCallbacks = config.afterModelCallbacks();
1037+
if (afterModelCallbacks != null) {
1038+
ImmutableList.Builder<Callbacks.AfterModelCallbackBase> list = ImmutableList.builder();
1039+
for (LlmAgentConfig.CallbackRef ref : afterModelCallbacks) {
1040+
var callback = ComponentRegistry.resolveAfterModelCallback(ref.name());
1041+
if (callback.isPresent()) {
1042+
list.add(callback.get());
1043+
continue;
1044+
}
1045+
throw new ConfigurationException("Invalid after_model_callback: " + ref.name());
1046+
}
1047+
builder.afterModelCallback(list.build());
1048+
}
1049+
1050+
var beforeToolCallbacks = config.beforeToolCallbacks();
1051+
if (beforeToolCallbacks != null) {
1052+
ImmutableList.Builder<Callbacks.BeforeToolCallbackBase> list = ImmutableList.builder();
1053+
for (LlmAgentConfig.CallbackRef ref : beforeToolCallbacks) {
1054+
var callback = ComponentRegistry.resolveBeforeToolCallback(ref.name());
1055+
if (callback.isPresent()) {
1056+
list.add(callback.get());
1057+
continue;
1058+
}
1059+
throw new ConfigurationException("Invalid before_tool_callback: " + ref.name());
1060+
}
1061+
builder.beforeToolCallback(list.build());
1062+
}
1063+
1064+
var afterToolCallbacks = config.afterToolCallbacks();
1065+
if (afterToolCallbacks != null) {
1066+
ImmutableList.Builder<Callbacks.AfterToolCallbackBase> list = ImmutableList.builder();
1067+
for (LlmAgentConfig.CallbackRef ref : afterToolCallbacks) {
1068+
var callback = ComponentRegistry.resolveAfterToolCallback(ref.name());
1069+
if (callback.isPresent()) {
1070+
list.add(callback.get());
1071+
continue;
1072+
}
1073+
throw new ConfigurationException("Invalid after_tool_callback: " + ref.name());
1074+
}
1075+
builder.afterToolCallback(list.build());
1076+
}
1077+
}
1078+
9891079
/**
9901080
* Resolves a list of tool configurations into both {@link BaseTool} and {@link BaseToolset}
9911081
* instances.

core/src/main/java/com/google/adk/agents/LlmAgentConfig.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,33 @@ public class LlmAgentConfig extends BaseAgentConfig {
3636
private IncludeContents includeContents;
3737
private GenerateContentConfig generateContentConfig;
3838

39+
// Callback configuration (names resolved via ComponentRegistry)
40+
private List<CallbackRef> beforeAgentCallbacks;
41+
private List<CallbackRef> afterAgentCallbacks;
42+
private List<CallbackRef> beforeModelCallbacks;
43+
private List<CallbackRef> afterModelCallbacks;
44+
private List<CallbackRef> beforeToolCallbacks;
45+
private List<CallbackRef> afterToolCallbacks;
46+
47+
/** Reference to a callback stored in the ComponentRegistry. */
48+
public static class CallbackRef {
49+
private String name;
50+
51+
public CallbackRef() {}
52+
53+
public CallbackRef(String name) {
54+
this.name = name;
55+
}
56+
57+
public String name() {
58+
return name;
59+
}
60+
61+
public void setName(String name) {
62+
this.name = name;
63+
}
64+
}
65+
3966
public LlmAgentConfig() {
4067
super();
4168
setAgentClass("LlmAgent");
@@ -105,4 +132,52 @@ public GenerateContentConfig generateContentConfig() {
105132
public void setGenerateContentConfig(GenerateContentConfig generateContentConfig) {
106133
this.generateContentConfig = generateContentConfig;
107134
}
135+
136+
public List<CallbackRef> beforeAgentCallbacks() {
137+
return beforeAgentCallbacks;
138+
}
139+
140+
public void setBeforeAgentCallbacks(List<CallbackRef> beforeAgentCallbacks) {
141+
this.beforeAgentCallbacks = beforeAgentCallbacks;
142+
}
143+
144+
public List<CallbackRef> afterAgentCallbacks() {
145+
return afterAgentCallbacks;
146+
}
147+
148+
public void setAfterAgentCallbacks(List<CallbackRef> afterAgentCallbacks) {
149+
this.afterAgentCallbacks = afterAgentCallbacks;
150+
}
151+
152+
public List<CallbackRef> beforeModelCallbacks() {
153+
return beforeModelCallbacks;
154+
}
155+
156+
public void setBeforeModelCallbacks(List<CallbackRef> beforeModelCallbacks) {
157+
this.beforeModelCallbacks = beforeModelCallbacks;
158+
}
159+
160+
public List<CallbackRef> afterModelCallbacks() {
161+
return afterModelCallbacks;
162+
}
163+
164+
public void setAfterModelCallbacks(List<CallbackRef> afterModelCallbacks) {
165+
this.afterModelCallbacks = afterModelCallbacks;
166+
}
167+
168+
public List<CallbackRef> beforeToolCallbacks() {
169+
return beforeToolCallbacks;
170+
}
171+
172+
public void setBeforeToolCallbacks(List<CallbackRef> beforeToolCallbacks) {
173+
this.beforeToolCallbacks = beforeToolCallbacks;
174+
}
175+
176+
public List<CallbackRef> afterToolCallbacks() {
177+
return afterToolCallbacks;
178+
}
179+
180+
public void setAfterToolCallbacks(List<CallbackRef> afterToolCallbacks) {
181+
this.afterToolCallbacks = afterToolCallbacks;
182+
}
108183
}

core/src/main/java/com/google/adk/utils/ComponentRegistry.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import static com.google.common.collect.ImmutableSet.toImmutableSet;
2121

2222
import com.google.adk.agents.BaseAgent;
23+
import com.google.adk.agents.Callbacks;
2324
import com.google.adk.agents.LlmAgent;
2425
import com.google.adk.agents.LoopAgent;
2526
import com.google.adk.agents.ParallelAgent;
@@ -153,7 +154,7 @@ private void registerAdkToolsetClass(@Nonnull Class<? extends BaseToolset> tools
153154
* @throws IllegalArgumentException if name is null or empty, or if value is null
154155
*/
155156
public void register(String name, Object value) {
156-
if (name == null || name.trim().isEmpty()) {
157+
if (isNullOrEmpty(name) || name.trim().isEmpty()) {
157158
throw new IllegalArgumentException("Name cannot be null or empty");
158159
}
159160
if (value == null) {
@@ -461,4 +462,36 @@ public Set<String> getToolNamesWithPrefix(String prefix) {
461462
.filter(name -> name.startsWith(prefix))
462463
.collect(toImmutableSet());
463464
}
465+
466+
public static Optional<Callbacks.BeforeAgentCallback> resolveBeforeAgentCallback(String name) {
467+
return resolveCallback(name, Callbacks.BeforeAgentCallback.class);
468+
}
469+
470+
public static Optional<Callbacks.AfterAgentCallback> resolveAfterAgentCallback(String name) {
471+
return resolveCallback(name, Callbacks.AfterAgentCallback.class);
472+
}
473+
474+
public static Optional<Callbacks.BeforeModelCallback> resolveBeforeModelCallback(String name) {
475+
return resolveCallback(name, Callbacks.BeforeModelCallback.class);
476+
}
477+
478+
public static Optional<Callbacks.AfterModelCallback> resolveAfterModelCallback(String name) {
479+
return resolveCallback(name, Callbacks.AfterModelCallback.class);
480+
}
481+
482+
public static Optional<Callbacks.BeforeToolCallback> resolveBeforeToolCallback(String name) {
483+
return resolveCallback(name, Callbacks.BeforeToolCallback.class);
484+
}
485+
486+
public static Optional<Callbacks.AfterToolCallback> resolveAfterToolCallback(String name) {
487+
return resolveCallback(name, Callbacks.AfterToolCallback.class);
488+
}
489+
490+
private static <T> Optional<T> resolveCallback(String name, Class<T> type) {
491+
if (isNullOrEmpty(name)) {
492+
return Optional.empty();
493+
}
494+
ComponentRegistry registry = getInstance();
495+
return registry.get(name, type);
496+
}
464497
}

0 commit comments

Comments
 (0)