Skip to content

Commit 1a593a9

Browse files
tilgalascopybara-github
authored andcommitted
feat: remove model restrictions in BuiltInCodeExecutionTool
PiperOrigin-RevId: 875242195
1 parent 43e042a commit 1a593a9

4 files changed

Lines changed: 203 additions & 6 deletions

File tree

core/src/main/java/com/google/adk/tools/BuiltInCodeExecutionTool.java

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@
1616

1717
package com.google.adk.tools;
1818

19+
import com.google.adk.agents.LlmAgent;
20+
import com.google.adk.models.BaseLlm;
1921
import com.google.adk.models.LlmRequest;
22+
import com.google.adk.utils.ModelNameUtils;
2023
import com.google.common.collect.ImmutableList;
2124
import com.google.genai.types.GenerateContentConfig;
2225
import com.google.genai.types.Tool;
2326
import com.google.genai.types.ToolCodeExecution;
2427
import io.reactivex.rxjava3.core.Completable;
2528
import java.util.List;
29+
import java.util.Optional;
30+
import org.slf4j.Logger;
31+
import org.slf4j.LoggerFactory;
2632

2733
/**
2834
* A built-in code execution tool that is automatically invoked by Gemini 2 models.
@@ -32,6 +38,7 @@
3238
*/
3339
public final class BuiltInCodeExecutionTool extends BaseTool {
3440
public static final BuiltInCodeExecutionTool INSTANCE = new BuiltInCodeExecutionTool();
41+
private static final Logger LOG = LoggerFactory.getLogger(BuiltInCodeExecutionTool.class);
3542

3643
public BuiltInCodeExecutionTool() {
3744
super("code_execution", "code_execution");
@@ -41,10 +48,28 @@ public BuiltInCodeExecutionTool() {
4148
public Completable processLlmRequest(
4249
LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) {
4350

44-
String model = llmRequestBuilder.build().model().get();
45-
if (model.isEmpty() || !model.startsWith("gemini-2")) {
46-
return Completable.error(
47-
new IllegalArgumentException("Code execution tool is not supported for model " + model));
51+
Optional<BaseLlm> model =
52+
Optional.ofNullable(toolContext)
53+
.flatMap(tCtx -> Optional.ofNullable(tCtx.invocationContext()))
54+
.flatMap(
55+
iCtx -> {
56+
if (iCtx.agent() instanceof LlmAgent llmAgent) {
57+
return Optional.of(llmAgent);
58+
} else {
59+
return Optional.empty();
60+
}
61+
})
62+
.flatMap(llmAgent -> llmAgent.resolvedModel().model());
63+
64+
String modelName = llmRequestBuilder.build().model().get();
65+
if (!ModelNameUtils.isGeminiModel(modelName)
66+
|| model.filter(ModelNameUtils::isInstanceOfGemini).isEmpty()) {
67+
// model name is not a gemini model, or the model isn't an instance of Gemini class (eg.
68+
// LangChain case).
69+
LOG.warn(
70+
"Code execution tool is not supported for model: {} ({}).",
71+
modelName,
72+
model.map(Object::getClass).map(Class::toString).orElse("<unknown class>"));
4873
}
4974
GenerateContentConfig.Builder configBuilder =
5075
llmRequestBuilder

core/src/main/java/com/google/adk/utils/ModelNameUtils.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,24 @@
1616

1717
package com.google.adk.utils;
1818

19+
import com.google.common.base.Strings;
20+
import java.util.Objects;
1921
import java.util.regex.Matcher;
2022
import java.util.regex.Pattern;
2123

2224
public final class ModelNameUtils {
25+
private static final String GEMINI_PREFIX = "gemini-";
2326
private static final Pattern GEMINI_2_PATTERN = Pattern.compile("^gemini-2\\..*");
27+
private static final String GEMINI_CLASS = "com.google.adk.models.Gemini";
2428
private static final Pattern PATH_PATTERN =
2529
Pattern.compile("^projects/[^/]+/locations/[^/]+/publishers/[^/]+/models/(.+)$");
2630
private static final Pattern APIGEE_PATTERN =
2731
Pattern.compile("^apigee/(?:[^/]+/)?(?:[^/]+/)?(.+)$");
2832

33+
public static boolean isGeminiModel(String modelString) {
34+
return extractModelName(Strings.nullToEmpty(modelString)).startsWith(GEMINI_PREFIX);
35+
}
36+
2937
public static boolean isGemini2Model(String modelString) {
3038
if (modelString == null) {
3139
return false;
@@ -34,6 +42,29 @@ public static boolean isGemini2Model(String modelString) {
3442
return GEMINI_2_PATTERN.matcher(modelName).matches();
3543
}
3644

45+
/**
46+
* Checks whether an object is an instance of {@link com.google.adk.models.Gemini}, by searching
47+
* through its class hierarchy for a class whose name equals the hardcoded String name of Gemini
48+
* class.
49+
*
50+
* <p>This method can be used where the "real" instanceof check is not possible because the Gemini
51+
* type is not known at compile time.
52+
*
53+
* @param o The object to check.
54+
* @return true if object's class is {@link com.google.adk.models.Gemini}, false otherwise.
55+
*/
56+
public static boolean isInstanceOfGemini(Object o) {
57+
if (o == null) {
58+
return false;
59+
}
60+
for (Class<?> clazz = o.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
61+
if (Objects.equals(clazz.getName(), GEMINI_CLASS)) {
62+
return true;
63+
}
64+
}
65+
return false;
66+
}
67+
3768
/**
3869
* Extract the actual model name from either simple or path-based format.
3970
*

core/src/test/java/com/google/adk/tools/BaseToolTest.java

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,11 @@
22

33
import static com.google.common.truth.Truth.assertThat;
44

5+
import com.google.adk.agents.InvocationContext;
6+
import com.google.adk.agents.LlmAgent;
7+
import com.google.adk.models.Gemini;
58
import com.google.adk.models.LlmRequest;
9+
import com.google.adk.sessions.InMemorySessionService;
610
import com.google.common.collect.ImmutableList;
711
import com.google.genai.types.FunctionDeclaration;
812
import com.google.genai.types.GenerateContentConfig;
@@ -171,13 +175,27 @@ public void processLlmRequestWithUrlContextToolAddsToolToConfig() {
171175
Tool.builder().urlContext(UrlContext.builder().build()).build());
172176
}
173177

178+
private static InvocationContext.Builder testInvocationContext() {
179+
InvocationContext.Builder builder = InvocationContext.builder();
180+
builder.agent(testAgent().build());
181+
InMemorySessionService inMemorySessionService = new InMemorySessionService();
182+
builder.sessionService(inMemorySessionService);
183+
builder.session(inMemorySessionService.createSession("test-app", "test-user-id").blockingGet());
184+
return builder;
185+
}
186+
187+
private static LlmAgent.Builder testAgent() {
188+
return LlmAgent.builder().name("test-agent");
189+
}
190+
174191
@Test
175-
public void processLlmRequestWithBuiltInCodeExecutionToolAddsToolToConfig() {
192+
public void
193+
processLlmRequestWithBuiltInCodeExecutionToolAndNonGeminiModelAndNullContextAddsToolToConfig() {
176194
BuiltInCodeExecutionTool builtInCodeExecutionTool = new BuiltInCodeExecutionTool();
177195
LlmRequest llmRequest =
178196
LlmRequest.builder()
179197
.config(GenerateContentConfig.builder().build())
180-
.model("gemini-2")
198+
.model("text-bison")
181199
.build();
182200
LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder();
183201
Completable unused =
@@ -189,6 +207,29 @@ public void processLlmRequestWithBuiltInCodeExecutionToolAddsToolToConfig() {
189207
.containsExactly(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build());
190208
}
191209

210+
@Test
211+
public void processLlmRequestWithBuiltInCodeExecutionToolAndGemini2ModelAddsToolToConfig() {
212+
BuiltInCodeExecutionTool builtInCodeExecutionTool = new BuiltInCodeExecutionTool();
213+
LlmRequest llmRequest =
214+
LlmRequest.builder()
215+
.config(GenerateContentConfig.builder().build())
216+
.model("gemini-2")
217+
.build();
218+
LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder();
219+
ToolContext toolContext =
220+
ToolContext.builder(
221+
testInvocationContext()
222+
.agent(testAgent().model(new Gemini("gemini-2", "")).build())
223+
.build())
224+
.build();
225+
Completable unused = builtInCodeExecutionTool.processLlmRequest(llmRequestBuilder, toolContext);
226+
LlmRequest updatedLlmRequest = llmRequestBuilder.build();
227+
assertThat(updatedLlmRequest.config()).isPresent();
228+
assertThat(updatedLlmRequest.config().get().tools()).isPresent();
229+
assertThat(updatedLlmRequest.config().get().tools().get())
230+
.containsExactly(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build());
231+
}
232+
192233
@Test
193234
public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() {
194235
GoogleMapsTool googleMapsTool = new GoogleMapsTool();

core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import static com.google.common.truth.Truth.assertThat;
44

5+
import com.google.adk.models.Gemini;
56
import org.junit.Test;
67
import org.junit.runner.RunWith;
78
import org.junit.runners.JUnit4;
@@ -69,4 +70,103 @@ public void isGemini2Model_withApigeeProviderV1BetaGemini2Model_returnsTrue() {
6970
public void isGemini2Model_withNullModel_returnsFalse() {
7071
assertThat(ModelNameUtils.isGemini2Model(null)).isFalse();
7172
}
73+
74+
@Test
75+
public void isGeminiModel_withGeminiModel_returnsTrue() {
76+
assertThat(ModelNameUtils.isGeminiModel("gemini-1.5-flash")).isTrue();
77+
}
78+
79+
@Test
80+
public void isGeminiModel_withNonGeminiModel_returnsFalse() {
81+
assertThat(ModelNameUtils.isGeminiModel("text-bison")).isFalse();
82+
}
83+
84+
@Test
85+
public void isGeminiModel_withPathBasedGeminiModel_returnsTrue() {
86+
assertThat(
87+
ModelNameUtils.isGeminiModel(
88+
"projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro"))
89+
.isTrue();
90+
}
91+
92+
@Test
93+
public void isGeminiModel_withPathBasedNonGeminiModel_returnsFalse() {
94+
assertThat(
95+
ModelNameUtils.isGeminiModel(
96+
"projects/test-project/locations/us-central1/publishers/google/models/text-bison"))
97+
.isFalse();
98+
}
99+
100+
@Test
101+
public void isGeminiModel_withApigeeGeminiModel_returnsTrue() {
102+
assertThat(ModelNameUtils.isGeminiModel("apigee/gemini-1.5-flash")).isTrue();
103+
}
104+
105+
@Test
106+
public void isGeminiModel_withApigeeV1GeminiModel_returnsTrue() {
107+
assertThat(ModelNameUtils.isGeminiModel("apigee/v1/gemini-1.5-flash")).isTrue();
108+
}
109+
110+
@Test
111+
public void isGeminiModel_withApigeeProviderGeminiModel_returnsTrue() {
112+
assertThat(ModelNameUtils.isGeminiModel("apigee/gemini/gemini-1.5-flash")).isTrue();
113+
}
114+
115+
@Test
116+
public void isGeminiModel_withApigeeProviderVertexGeminiModel_returnsTrue() {
117+
assertThat(ModelNameUtils.isGeminiModel("apigee/vertex_ai/gemini-1.5-flash")).isTrue();
118+
}
119+
120+
@Test
121+
public void isGeminiModel_withApigeeProviderV1GeminiModel_returnsTrue() {
122+
assertThat(ModelNameUtils.isGeminiModel("apigee/gemini/v1/gemini-1.5-flash")).isTrue();
123+
}
124+
125+
@Test
126+
public void isGeminiModel_withApigeeProviderV1BetaGeminiModel_returnsTrue() {
127+
assertThat(ModelNameUtils.isGeminiModel("apigee/vertex_ai/v1beta/gemini-1.5-flash")).isTrue();
128+
}
129+
130+
@Test
131+
public void isGeminiModel_withNullModel_returnsFalse() {
132+
assertThat(ModelNameUtils.isGeminiModel(null)).isFalse();
133+
}
134+
135+
@Test
136+
public void isGeminiModel_withEmptyModel_returnsFalse() {
137+
assertThat(ModelNameUtils.isGeminiModel("")).isFalse();
138+
}
139+
140+
@Test
141+
public void isInstanceOfGemini_withGeminiInstance_returnsTrue() {
142+
assertThat(ModelNameUtils.isInstanceOfGemini(new Gemini("", ""))).isTrue();
143+
}
144+
145+
@Test
146+
public void isInstanceOfGemini_withNonGeminiInstance_returnsFalse() {
147+
assertThat(ModelNameUtils.isInstanceOfGemini(new Object())).isFalse();
148+
}
149+
150+
@Test
151+
public void isInstanceOfGemini_withNullInstance_returnsFalse() {
152+
assertThat(ModelNameUtils.isInstanceOfGemini(null)).isFalse();
153+
}
154+
155+
private static class GeminiSubclass extends Gemini {
156+
GeminiSubclass() {
157+
super("test-model", "test-api-key");
158+
}
159+
}
160+
161+
private static class GeminiSubclassSubclass extends GeminiSubclass {}
162+
163+
@Test
164+
public void isInstanceOfGemini_withGeminiSubclassInstance_returnsTrue() {
165+
assertThat(ModelNameUtils.isInstanceOfGemini(new GeminiSubclass())).isTrue();
166+
}
167+
168+
@Test
169+
public void isInstanceOfGemini_withSubclassOfGeminiSubclassInstance_returnsTrue() {
170+
assertThat(ModelNameUtils.isInstanceOfGemini(new GeminiSubclassSubclass())).isTrue();
171+
}
72172
}

0 commit comments

Comments
 (0)