Skip to content

Commit b8cb7e2

Browse files
tilgalascopybara-github
authored andcommitted
feat: add type-safe runAsync methods to BaseTool
PiperOrigin-RevId: 884493553
1 parent 567fdf0 commit b8cb7e2

2 files changed

Lines changed: 189 additions & 0 deletions

File tree

core/src/main/java/com/google/adk/tools/BaseTool.java

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import com.fasterxml.jackson.annotation.JsonAnySetter;
2323
import com.fasterxml.jackson.annotation.JsonIgnore;
2424
import com.fasterxml.jackson.core.type.TypeReference;
25+
import com.fasterxml.jackson.databind.ObjectMapper;
2526
import com.google.adk.JsonBaseModel;
2627
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
2728
import com.google.adk.models.LlmRequest;
@@ -38,6 +39,7 @@
3839
import java.util.HashMap;
3940
import java.util.Map;
4041
import java.util.Optional;
42+
import java.util.function.Function;
4143
import javax.annotation.Nonnull;
4244
import org.jspecify.annotations.Nullable;
4345
import org.slf4j.Logger;
@@ -93,6 +95,85 @@ public Single<Map<String, Object>> runAsync(Map<String, Object> args, ToolContex
9395
throw new UnsupportedOperationException("This method is not implemented.");
9496
}
9597

98+
/**
99+
* Calls a tool with generic arguments and returns a map of results. The args type {@code T} need
100+
* to be serializable with {@link JsonBaseModel#getMapper()}
101+
*/
102+
public final <T> Single<Map<String, Object>> runAsync(T args, ToolContext toolContext) {
103+
return runAsync(args, toolContext, JsonBaseModel.getMapper());
104+
}
105+
106+
/**
107+
* Calls a tool with generic arguments using a custom {@link ObjectMapper} and returns a map of
108+
* results. The args type {@code T} needs to be serializable with the provided {@link
109+
* ObjectMapper}.
110+
*/
111+
public final <T> Single<Map<String, Object>> runAsync(
112+
T args, ToolContext toolContext, ObjectMapper objectMapper) {
113+
return runAsync(args, toolContext, objectMapper, output -> output);
114+
}
115+
116+
/**
117+
* Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results
118+
* converted to a specified class. The input type {@code I} needs to be serializable and the
119+
* output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}.
120+
*/
121+
public final <I, O> Single<O> runAsync(
122+
I args, ToolContext toolContext, ObjectMapper objectMapper, Class<? extends O> oClass) {
123+
return runAsync(
124+
args, toolContext, objectMapper, output -> objectMapper.convertValue(output, oClass));
125+
}
126+
127+
/**
128+
* Calls a tool with generic arguments and a custom {@link ObjectMapper}, returning the results
129+
* converted to a specified type reference. The input type {@code I} needs to be serializable and
130+
* the output type {@code O} needs to be deserializable with the provided {@link ObjectMapper}.
131+
*/
132+
public final <I, O> Single<O> runAsync(
133+
I args,
134+
ToolContext toolContext,
135+
ObjectMapper objectMapper,
136+
TypeReference<? extends O> typeReference) {
137+
return runAsync(
138+
args,
139+
toolContext,
140+
objectMapper,
141+
output -> objectMapper.convertValue(output, typeReference));
142+
}
143+
144+
/**
145+
* Calls a tool with generic arguments, returning the results converted to a specified class. The
146+
* input type {@code I} needs to be serializable and the output type {@code O} needs to be
147+
* deserializable with {@link JsonBaseModel#getMapper()}
148+
*/
149+
public final <I, O> Single<O> runAsync(
150+
I args, ToolContext toolContext, Class<? extends O> oClass) {
151+
return runAsync(args, toolContext, JsonBaseModel.getMapper(), oClass);
152+
}
153+
154+
/**
155+
* Calls a tool with generic arguments, returning the results converted to a specified type
156+
* reference. The input type needs to be serializable and the output type needs to be
157+
* deserializable with {@link JsonBaseModel#getMapper()}
158+
*/
159+
public final <I, O> Single<O> runAsync(
160+
I args, ToolContext toolContext, TypeReference<? extends O> typeReference) {
161+
return runAsync(args, toolContext, JsonBaseModel.getMapper(), typeReference);
162+
}
163+
164+
private <I, O> Single<O> runAsync(
165+
I args,
166+
ToolContext toolContext,
167+
ObjectMapper objectMapper,
168+
Function<? super Map<String, Object>, ? extends O> deserializer) {
169+
return Single.defer(
170+
() ->
171+
Single.just(
172+
objectMapper.convertValue(args, new TypeReference<Map<String, Object>>() {})))
173+
.flatMap(argsMap -> runAsync(argsMap, toolContext))
174+
.map(deserializer::apply);
175+
}
176+
96177
/**
97178
* Processes the outgoing {@link LlmRequest.Builder}.
98179
*

core/src/test/java/com/google/adk/tools/BaseToolTest.java

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22

33
import static com.google.common.truth.Truth.assertThat;
44

5+
import com.fasterxml.jackson.core.type.TypeReference;
6+
import com.fasterxml.jackson.databind.ObjectMapper;
57
import com.google.adk.agents.InvocationContext;
68
import com.google.adk.agents.LlmAgent;
79
import com.google.adk.models.Gemini;
810
import com.google.adk.models.LlmRequest;
911
import com.google.adk.sessions.InMemorySessionService;
1012
import com.google.common.collect.ImmutableList;
13+
import com.google.common.collect.ImmutableMap;
1114
import com.google.genai.types.FunctionDeclaration;
1215
import com.google.genai.types.GenerateContentConfig;
1316
import com.google.genai.types.GoogleMaps;
@@ -17,6 +20,7 @@
1720
import com.google.genai.types.UrlContext;
1821
import io.reactivex.rxjava3.core.Completable;
1922
import io.reactivex.rxjava3.core.Single;
23+
import io.reactivex.rxjava3.observers.TestObserver;
2024
import java.util.Map;
2125
import java.util.Optional;
2226
import org.junit.Test;
@@ -27,6 +31,20 @@
2731
@RunWith(JUnit4.class)
2832
public final class BaseToolTest {
2933

34+
private final BaseTool doublingBaseTool =
35+
new BaseTool("doubling-test-tool", "returns doubled args") {
36+
@Override
37+
public Single<Map<String, Object>> runAsync(
38+
Map<String, Object> args, ToolContext toolContext) {
39+
String sArg = (String) args.get("s");
40+
Integer iArg = (Integer) args.get("i");
41+
return Single.just(
42+
ImmutableMap.<String, Object>of(
43+
"s", sArg + sArg,
44+
"i", iArg + iArg));
45+
}
46+
};
47+
3048
@Test
3149
public void processLlmRequestNoDeclarationReturnsSameRequest() {
3250
BaseTool tool =
@@ -247,4 +265,94 @@ public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() {
247265
assertThat(updatedLlmRequest.config().get().tools().get())
248266
.containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build());
249267
}
268+
269+
@Test
270+
public void runAsync_withTypeReference_convertsArguments() throws Exception {
271+
TestToolArgs testToolArgs = new TestToolArgs(42, "foo");
272+
273+
Single<TestToolArgs> out =
274+
doublingBaseTool.runAsync(
275+
testToolArgs, /* toolContext= */ null, new TypeReference<TestToolArgs>() {});
276+
TestObserver<TestToolArgs> testObserver = out.test();
277+
278+
testObserver.assertComplete();
279+
TestToolArgs expected = new TestToolArgs(84, "foofoo");
280+
testObserver.assertValue(expected);
281+
}
282+
283+
@Test
284+
public void runAsync_withClass_convertsArguments() throws Exception {
285+
TestToolArgs testToolArgs = new TestToolArgs(21, "bar");
286+
287+
Single<TestToolArgs> out =
288+
doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, TestToolArgs.class);
289+
TestObserver<TestToolArgs> testObserver = out.test();
290+
291+
testObserver.assertComplete();
292+
TestToolArgs expected = new TestToolArgs(42, "barbar");
293+
testObserver.assertValue(expected);
294+
}
295+
296+
@Test
297+
public void runAsync_withObjectOnly_convertsArguments() throws Exception {
298+
TestToolArgs testToolArgs = new TestToolArgs(11, "baz");
299+
300+
Single<Map<String, Object>> out =
301+
doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null);
302+
TestObserver<Map<String, Object>> testObserver = out.test();
303+
304+
testObserver.assertComplete();
305+
ImmutableMap<String, Object> expected = ImmutableMap.of("i", 22, "s", "bazbaz");
306+
testObserver.assertValue(expected);
307+
}
308+
309+
@Test
310+
public void runAsync_withObjectMapperAndObjectOnly_convertsArguments() throws Exception {
311+
TestToolArgs testToolArgs = new TestToolArgs(11, "baz");
312+
ObjectMapper objectMapper = new ObjectMapper();
313+
314+
Single<Map<String, Object>> out =
315+
doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, objectMapper);
316+
TestObserver<Map<String, Object>> testObserver = out.test();
317+
318+
testObserver.assertComplete();
319+
ImmutableMap<String, Object> expected = ImmutableMap.of("i", 22, "s", "bazbaz");
320+
testObserver.assertValue(expected);
321+
}
322+
323+
@Test
324+
public void runAsync_withTypeReferenceAndObjectMapper_convertsArguments() throws Exception {
325+
TestToolArgs testToolArgs = new TestToolArgs(42, "foo");
326+
ObjectMapper objectMapper = new ObjectMapper();
327+
328+
Single<TestToolArgs> out =
329+
doublingBaseTool.runAsync(
330+
testToolArgs,
331+
/* toolContext= */ null,
332+
objectMapper,
333+
new TypeReference<TestToolArgs>() {});
334+
335+
TestObserver<TestToolArgs> testObserver = out.test();
336+
337+
testObserver.assertComplete();
338+
TestToolArgs expected = new TestToolArgs(84, "foofoo");
339+
testObserver.assertValue(expected);
340+
}
341+
342+
@Test
343+
public void runAsync_withClassAndObjectMapper_convertsArguments() throws Exception {
344+
TestToolArgs testToolArgs = new TestToolArgs(21, "bar");
345+
ObjectMapper objectMapper = new ObjectMapper();
346+
347+
Single<TestToolArgs> out =
348+
doublingBaseTool.runAsync(
349+
testToolArgs, /* toolContext= */ null, objectMapper, TestToolArgs.class);
350+
TestObserver<TestToolArgs> testObserver = out.test();
351+
352+
testObserver.assertComplete();
353+
TestToolArgs expected = new TestToolArgs(42, "barbar");
354+
testObserver.assertValue(expected);
355+
}
356+
357+
public record TestToolArgs(int i, String s) {}
250358
}

0 commit comments

Comments
 (0)