|
2 | 2 |
|
3 | 3 | import static com.google.common.truth.Truth.assertThat; |
4 | 4 |
|
| 5 | +import com.fasterxml.jackson.core.type.TypeReference; |
| 6 | +import com.fasterxml.jackson.databind.ObjectMapper; |
5 | 7 | import com.google.adk.agents.InvocationContext; |
6 | 8 | import com.google.adk.agents.LlmAgent; |
7 | 9 | import com.google.adk.models.Gemini; |
8 | 10 | import com.google.adk.models.LlmRequest; |
9 | 11 | import com.google.adk.sessions.InMemorySessionService; |
10 | 12 | import com.google.common.collect.ImmutableList; |
| 13 | +import com.google.common.collect.ImmutableMap; |
11 | 14 | import com.google.genai.types.FunctionDeclaration; |
12 | 15 | import com.google.genai.types.GenerateContentConfig; |
13 | 16 | import com.google.genai.types.GoogleMaps; |
|
17 | 20 | import com.google.genai.types.UrlContext; |
18 | 21 | import io.reactivex.rxjava3.core.Completable; |
19 | 22 | import io.reactivex.rxjava3.core.Single; |
| 23 | +import io.reactivex.rxjava3.observers.TestObserver; |
20 | 24 | import java.util.Map; |
21 | 25 | import java.util.Optional; |
22 | 26 | import org.junit.Test; |
|
27 | 31 | @RunWith(JUnit4.class) |
28 | 32 | public final class BaseToolTest { |
29 | 33 |
|
| 34 | + private final BaseTool doublingBaseTool = |
| 35 | + new BaseTool("doubling-test-tool", "returns doubled args") { |
| 36 | + @Override |
| 37 | + public Single<Map<String, Object>> runAsync( |
| 38 | + Map<String, Object> args, ToolContext toolContext) { |
| 39 | + String sArg = (String) args.get("s"); |
| 40 | + Integer iArg = (Integer) args.get("i"); |
| 41 | + return Single.just( |
| 42 | + ImmutableMap.<String, Object>of( |
| 43 | + "s", sArg + sArg, |
| 44 | + "i", iArg + iArg)); |
| 45 | + } |
| 46 | + }; |
| 47 | + |
30 | 48 | @Test |
31 | 49 | public void processLlmRequestNoDeclarationReturnsSameRequest() { |
32 | 50 | BaseTool tool = |
@@ -247,4 +265,94 @@ public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() { |
247 | 265 | assertThat(updatedLlmRequest.config().get().tools().get()) |
248 | 266 | .containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); |
249 | 267 | } |
| 268 | + |
| 269 | + @Test |
| 270 | + public void runAsync_withTypeReference_convertsArguments() throws Exception { |
| 271 | + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); |
| 272 | + |
| 273 | + Single<TestToolArgs> out = |
| 274 | + doublingBaseTool.runAsync( |
| 275 | + testToolArgs, /* toolContext= */ null, new TypeReference<TestToolArgs>() {}); |
| 276 | + TestObserver<TestToolArgs> testObserver = out.test(); |
| 277 | + |
| 278 | + testObserver.assertComplete(); |
| 279 | + TestToolArgs expected = new TestToolArgs(84, "foofoo"); |
| 280 | + testObserver.assertValue(expected); |
| 281 | + } |
| 282 | + |
| 283 | + @Test |
| 284 | + public void runAsync_withClass_convertsArguments() throws Exception { |
| 285 | + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); |
| 286 | + |
| 287 | + Single<TestToolArgs> out = |
| 288 | + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, TestToolArgs.class); |
| 289 | + TestObserver<TestToolArgs> testObserver = out.test(); |
| 290 | + |
| 291 | + testObserver.assertComplete(); |
| 292 | + TestToolArgs expected = new TestToolArgs(42, "barbar"); |
| 293 | + testObserver.assertValue(expected); |
| 294 | + } |
| 295 | + |
| 296 | + @Test |
| 297 | + public void runAsync_withObjectOnly_convertsArguments() throws Exception { |
| 298 | + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); |
| 299 | + |
| 300 | + Single<Map<String, Object>> out = |
| 301 | + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null); |
| 302 | + TestObserver<Map<String, Object>> testObserver = out.test(); |
| 303 | + |
| 304 | + testObserver.assertComplete(); |
| 305 | + ImmutableMap<String, Object> expected = ImmutableMap.of("i", 22, "s", "bazbaz"); |
| 306 | + testObserver.assertValue(expected); |
| 307 | + } |
| 308 | + |
| 309 | + @Test |
| 310 | + public void runAsync_withObjectMapperAndObjectOnly_convertsArguments() throws Exception { |
| 311 | + TestToolArgs testToolArgs = new TestToolArgs(11, "baz"); |
| 312 | + ObjectMapper objectMapper = new ObjectMapper(); |
| 313 | + |
| 314 | + Single<Map<String, Object>> out = |
| 315 | + doublingBaseTool.runAsync(testToolArgs, /* toolContext= */ null, objectMapper); |
| 316 | + TestObserver<Map<String, Object>> testObserver = out.test(); |
| 317 | + |
| 318 | + testObserver.assertComplete(); |
| 319 | + ImmutableMap<String, Object> expected = ImmutableMap.of("i", 22, "s", "bazbaz"); |
| 320 | + testObserver.assertValue(expected); |
| 321 | + } |
| 322 | + |
| 323 | + @Test |
| 324 | + public void runAsync_withTypeReferenceAndObjectMapper_convertsArguments() throws Exception { |
| 325 | + TestToolArgs testToolArgs = new TestToolArgs(42, "foo"); |
| 326 | + ObjectMapper objectMapper = new ObjectMapper(); |
| 327 | + |
| 328 | + Single<TestToolArgs> out = |
| 329 | + doublingBaseTool.runAsync( |
| 330 | + testToolArgs, |
| 331 | + /* toolContext= */ null, |
| 332 | + objectMapper, |
| 333 | + new TypeReference<TestToolArgs>() {}); |
| 334 | + |
| 335 | + TestObserver<TestToolArgs> testObserver = out.test(); |
| 336 | + |
| 337 | + testObserver.assertComplete(); |
| 338 | + TestToolArgs expected = new TestToolArgs(84, "foofoo"); |
| 339 | + testObserver.assertValue(expected); |
| 340 | + } |
| 341 | + |
| 342 | + @Test |
| 343 | + public void runAsync_withClassAndObjectMapper_convertsArguments() throws Exception { |
| 344 | + TestToolArgs testToolArgs = new TestToolArgs(21, "bar"); |
| 345 | + ObjectMapper objectMapper = new ObjectMapper(); |
| 346 | + |
| 347 | + Single<TestToolArgs> out = |
| 348 | + doublingBaseTool.runAsync( |
| 349 | + testToolArgs, /* toolContext= */ null, objectMapper, TestToolArgs.class); |
| 350 | + TestObserver<TestToolArgs> testObserver = out.test(); |
| 351 | + |
| 352 | + testObserver.assertComplete(); |
| 353 | + TestToolArgs expected = new TestToolArgs(42, "barbar"); |
| 354 | + testObserver.assertValue(expected); |
| 355 | + } |
| 356 | + |
| 357 | + public record TestToolArgs(int i, String s) {} |
250 | 358 | } |
0 commit comments