diff --git a/.release-please-manifest.json b/.release-please-manifest.json index b5b966fb1..77c0d292c 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.6.0" + ".": "0.7.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index f360e9399..b10bad3ce 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,33 @@ # Changelog +## [0.7.0](https://github.com/google/adk-java/compare/v0.6.0...v0.7.0) (2026-02-27) + + +### Features + +* Add ComputerUse tool ([d733a48](https://github.com/google/adk-java/commit/d733a480a7a787cb7c32fd3470ab978ca3eb574c)) +* add the AgentExecutor config ([e0f7137](https://github.com/google/adk-java/commit/e0f7137253c9bd929fe3ea899e32f4b61f994986)) +* drop gemini-1 support in GoogleSearchTool ([15255b4](https://github.com/google/adk-java/commit/15255b48285819c7d3aedb4470e91f37d1bcfaf4)) +* Extend url_context support to Gemini 3 in Java ADK ([2c9d4dd](https://github.com/google/adk-java/commit/2c9d4dd5eafe8efe3a2fb099b58e2d0f1d9cad98)) +* Extend url_context support to Gemini 3 in Java ADK ([5f5869f](https://github.com/google/adk-java/commit/5f5869f67200831dcbb7ac10ad0d7f44410bc096)) +* Handle final and error TaskStatusUpdateEvents ([746e857](https://github.com/google/adk-java/commit/746e857d97c6f356ffe5c20be0ccae85d5a8f989)) +* remove model restrictions in BuiltInCodeExecutionTool ([1a593a9](https://github.com/google/adk-java/commit/1a593a996607904eed24b64bc63eecd7708710af)) +* Update AgentExecutor so it builds new runner on execute and there is no need to pass the runner instance ([7218295](https://github.com/google/adk-java/commit/72182958586e59ccb3d7490cd207ec2837c5b577)) + + +### Bug Fixes + +* change Session events list to a threadsafe implementation by default ([0b5ac92](https://github.com/google/adk-java/commit/0b5ac9214926200c3d65d64d8c10489847c29291)) +* deep-merge stateDelta maps when merging EventActions ([ff07474](https://github.com/google/adk-java/commit/ff07474035baec910f0c3fa83b7b1646d8409ffd)) +* drop explicit gemini-1 model version check in GoogleMapsTool ([7953503](https://github.com/google/adk-java/commit/7953503e61c547e40a1e1abbece73a99910766c1)) +* LlmAgent model name resolution and improve Gemini-3 model detection logic ([313ce85](https://github.com/google/adk-java/commit/313ce8590982346bb8ac631b4bf88da76fb849a4)) +* make a mutable copy of function args for the beforeToolCallback invocations ([64d3a77](https://github.com/google/adk-java/commit/64d3a775d68610d20c084678ffdc559cd467e627)) + + +### Documentation + +* Update a parameter name in a comment ([5262d4a](https://github.com/google/adk-java/commit/5262d4ae3eca533e1a695e6e2e71c5845055ed5d)) + ## [0.6.0](https://github.com/google/adk-java/compare/v0.5.0...v0.6.0) (2026-02-19) diff --git a/README.md b/README.md index 691b62f5f..d0471c1bf 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,13 @@ If you're using Maven, add the following to your dependencies: com.google.adk google-adk - 0.6.0 + 0.7.0 com.google.adk google-adk-dev - 0.6.0 + 0.7.0 ``` diff --git a/a2a/pom.xml b/a2a/pom.xml index 485eed617..dc6d9b8f7 100644 --- a/a2a/pom.xml +++ b/a2a/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 google-adk-a2a @@ -23,7 +23,6 @@ 3.1.5 1.0.0 2.0.17 - 2.38.0 1.4.4 4.13.2 @@ -42,7 +41,6 @@ com.google.errorprone error_prone_annotations - ${errorprone.version} com.fasterxml.jackson.core diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java index c6ef06400..bc0620f83 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/PartConverter.java @@ -47,6 +47,13 @@ public final class PartConverter { "code_execution_result"; public static final String A2A_DATA_PART_METADATA_TYPE_EXECUTABLE_CODE = "executable_code"; + public static Optional toTextPart(io.a2a.spec.Part part) { + if (part instanceof TextPart textPart) { + return Optional.of(textPart); + } + return Optional.empty(); + } + /** Convert an A2A JSON part into a Google GenAI part representation. */ public static Optional toGenaiPart(io.a2a.spec.Part a2aPart) { if (a2aPart == null) { diff --git a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java index 61ab84c90..0a272b72d 100644 --- a/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java +++ b/a2a/src/main/java/com/google/adk/a2a/converters/ResponseConverter.java @@ -197,13 +197,28 @@ private static Optional handleTaskUpdate( } if (updateEvent instanceof TaskStatusUpdateEvent statusEvent) { var status = statusEvent.getStatus(); - return Optional.ofNullable(status.message()) - .map( - value -> - messageToEvent( - value, - context, - PENDING_STATES.contains(event.getTask().getStatus().state()))); + var taskState = event.getTask().getStatus().state(); + + Optional messageEvent = + Optional.ofNullable(status.message()) + .map( + value -> { + if (taskState == TaskState.FAILED) { + return messageToFailedEvent(value, context); + } + return messageToEvent(value, context, PENDING_STATES.contains(taskState)); + }); + + if (statusEvent.isFinal()) { + return messageEvent + .map(Event::toBuilder) + .or(() -> Optional.of(remoteAgentEventBuilder(context))) + .map(builder -> builder.turnComplete(true)) + .map(builder -> builder.partial(false)) + .map(Event.Builder::build); + } else { + return messageEvent; + } } throw new IllegalArgumentException( "Unsupported TaskUpdateEvent type: " + updateEvent.getClass()); @@ -216,6 +231,16 @@ public static Event messageToEvent(Message message, InvocationContext invocation .build(); } + /** Converts an A2A message for a failed task to ADK event filling in the error message. */ + public static Event messageToFailedEvent(Message message, InvocationContext invocationContext) { + Event.Builder builder = remoteAgentEventBuilder(invocationContext); + Optional.ofNullable(Iterables.getFirst(message.getParts(), null)) + .flatMap(PartConverter::toTextPart) + .ifPresent(textPart -> builder.errorMessage(textPart.getText())); + + return builder.build(); + } + /** * Converts an A2A message back to ADK events. For streaming task in pending state it sets the * thought field to true, to mark them as thought updates. diff --git a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java similarity index 60% rename from a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java rename to a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java index 6df01694a..0c12727aa 100644 --- a/a2a/src/main/java/com/google/adk/a2a/AgentExecutor.java +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutor.java @@ -1,9 +1,15 @@ -package com.google.adk.a2a; +package com.google.adk.a2a.executor; + +import static java.util.Objects.requireNonNull; import com.google.adk.a2a.converters.EventConverter; import com.google.adk.a2a.converters.PartConverter; -import com.google.adk.agents.RunConfig; +import com.google.adk.agents.BaseAgent; +import com.google.adk.apps.App; +import com.google.adk.artifacts.BaseArtifactService; import com.google.adk.events.Event; +import com.google.adk.memory.BaseMemoryService; +import com.google.adk.plugins.Plugin; import com.google.adk.runner.Runner; import com.google.adk.sessions.BaseSessionService; import com.google.adk.sessions.Session; @@ -21,6 +27,7 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.disposables.CompositeDisposable; import io.reactivex.rxjava3.disposables.Disposable; +import java.util.List; import java.util.Map; import java.util.Optional; import java.util.UUID; @@ -38,32 +45,108 @@ public class AgentExecutor implements io.a2a.server.agentexecution.AgentExecutor private static final Logger logger = LoggerFactory.getLogger(AgentExecutor.class); private static final String USER_ID_PREFIX = "A2A_USER_"; - private static final RunConfig DEFAULT_RUN_CONFIG = - RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); - private final Runner runner; private final Map activeTasks = new ConcurrentHashMap<>(); + private final Runner.Builder runnerBuilder; + private final AgentExecutorConfig agentExecutorConfig; + + private AgentExecutor( + App app, + BaseAgent agent, + String appName, + BaseArtifactService artifactService, + BaseSessionService sessionService, + BaseMemoryService memoryService, + List plugins, + AgentExecutorConfig agentExecutorConfig) { + requireNonNull(agentExecutorConfig); + this.agentExecutorConfig = agentExecutorConfig; - private AgentExecutor(Runner runner) { - this.runner = runner; + this.runnerBuilder = + Runner.builder() + .agent(agent) + .appName(appName) + .artifactService(artifactService) + .sessionService(sessionService) + .memoryService(memoryService) + .plugins(plugins); + if (app != null) { + this.runnerBuilder.app(app); + } + // Check that the runner is configured correctly and can be built. + var unused = runnerBuilder.build(); } /** Builder for {@link AgentExecutor}. */ public static class Builder { - private Runner runner; + private App app; + private BaseAgent agent; + private String appName; + private BaseArtifactService artifactService; + private BaseSessionService sessionService; + private BaseMemoryService memoryService; + private List plugins = ImmutableList.of(); + private AgentExecutorConfig agentExecutorConfig; + + @CanIgnoreReturnValue + public Builder agentExecutorConfig(AgentExecutorConfig agentExecutorConfig) { + this.agentExecutorConfig = agentExecutorConfig; + return this; + } + + @CanIgnoreReturnValue + public Builder app(App app) { + this.app = app; + return this; + } + + @CanIgnoreReturnValue + public Builder agent(BaseAgent agent) { + this.agent = agent; + return this; + } + + @CanIgnoreReturnValue + public Builder appName(String appName) { + this.appName = appName; + return this; + } + + @CanIgnoreReturnValue + public Builder artifactService(BaseArtifactService artifactService) { + this.artifactService = artifactService; + return this; + } + + @CanIgnoreReturnValue + public Builder sessionService(BaseSessionService sessionService) { + this.sessionService = sessionService; + return this; + } + + @CanIgnoreReturnValue + public Builder memoryService(BaseMemoryService memoryService) { + this.memoryService = memoryService; + return this; + } @CanIgnoreReturnValue - public Builder runner(Runner runner) { - this.runner = runner; + public Builder plugins(List plugins) { + this.plugins = plugins; return this; } @CanIgnoreReturnValue public AgentExecutor build() { - if (runner == null) { - throw new IllegalStateException("Runner must be provided."); - } - return new AgentExecutor(runner); + return new AgentExecutor( + app, + agent, + appName, + artifactService, + sessionService, + memoryService, + plugins, + agentExecutorConfig); } } @@ -96,13 +179,15 @@ public void execute(RequestContext ctx, EventQueue eventQueue) { EventProcessor p = new EventProcessor(); Content content = PartConverter.messageToContent(message); + Runner runner = runnerBuilder.build(); taskDisposables.add( - prepareSession(ctx, runner.sessionService()) + prepareSession(ctx, runner.appName(), runner.sessionService()) .flatMapPublisher( session -> { updater.startWork(); - return runner.runAsync(getUserId(ctx), session.id(), content, DEFAULT_RUN_CONFIG); + return runner.runAsync( + getUserId(ctx), session.id(), content, agentExecutorConfig.runConfig()); }) .subscribe( event -> { @@ -130,13 +215,14 @@ private String getUserId(RequestContext ctx) { return USER_ID_PREFIX + ctx.getContextId(); } - private Maybe prepareSession(RequestContext ctx, BaseSessionService service) { + private Maybe prepareSession( + RequestContext ctx, String appName, BaseSessionService service) { return service - .getSession(runner.appName(), getUserId(ctx), ctx.getContextId(), Optional.empty()) + .getSession(appName, getUserId(ctx), ctx.getContextId(), Optional.empty()) .switchIfEmpty( Maybe.defer( () -> { - return service.createSession(runner.appName(), getUserId(ctx)).toMaybe(); + return service.createSession(appName, getUserId(ctx)).toMaybe(); })); } diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java new file mode 100644 index 000000000..9b1ed808b --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/executor/AgentExecutorConfig.java @@ -0,0 +1,53 @@ +package com.google.adk.a2a.executor; + +import com.google.adk.a2a.executor.Callbacks.AfterEventCallback; +import com.google.adk.a2a.executor.Callbacks.AfterExecuteCallback; +import com.google.adk.a2a.executor.Callbacks.BeforeExecuteCallback; +import com.google.adk.agents.RunConfig; +import com.google.auto.value.AutoValue; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import org.jspecify.annotations.Nullable; + +/** Configuration for the {@link AgentExecutor}. */ +@AutoValue +public abstract class AgentExecutorConfig { + + private static final RunConfig DEFAULT_RUN_CONFIG = + RunConfig.builder().setStreamingMode(RunConfig.StreamingMode.NONE).setMaxLlmCalls(20).build(); + + public abstract RunConfig runConfig(); + + public abstract @Nullable BeforeExecuteCallback beforeExecuteCallback(); + + public abstract @Nullable AfterExecuteCallback afterExecuteCallback(); + + public abstract @Nullable AfterEventCallback afterEventCallback(); + + public abstract Builder toBuilder(); + + public static Builder builder() { + return new AutoValue_AgentExecutorConfig.Builder().runConfig(DEFAULT_RUN_CONFIG); + } + + /** Builder for {@link AgentExecutorConfig}. */ + @AutoValue.Builder + public abstract static class Builder { + @CanIgnoreReturnValue + public abstract Builder runConfig(RunConfig runConfig); + + @CanIgnoreReturnValue + public abstract Builder beforeExecuteCallback(BeforeExecuteCallback beforeExecuteCallback); + + @CanIgnoreReturnValue + public abstract Builder afterExecuteCallback(AfterExecuteCallback afterExecuteCallback); + + @CanIgnoreReturnValue + public abstract Builder afterEventCallback(AfterEventCallback afterEventCallback); + + abstract AgentExecutorConfig autoBuild(); + + public AgentExecutorConfig build() { + return autoBuild(); + } + } +} diff --git a/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java b/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java new file mode 100644 index 000000000..666f1d8a0 --- /dev/null +++ b/a2a/src/main/java/com/google/adk/a2a/executor/Callbacks.java @@ -0,0 +1,68 @@ +package com.google.adk.a2a.executor; + +import com.google.adk.events.Event; +import io.a2a.server.agentexecution.RequestContext; +import io.a2a.spec.TaskArtifactUpdateEvent; +import io.a2a.spec.TaskStatusUpdateEvent; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; + +/** Functional interfaces for agent executor lifecycle callbacks. */ +public final class Callbacks { + + private Callbacks() {} + + interface BeforeExecuteCallbackBase {} + + /** Async callback interface for actions to be performed before an execution is started. */ + @FunctionalInterface + public interface BeforeExecuteCallback extends BeforeExecuteCallbackBase { + /** + * Callback which will be called before an execution is started. It can be used to instrument a + * context or prevent the execution by returning an error. + * + * @param ctx the request context + * @return a {@link Single} that completes with a boolean indicating whether the execution + * should be prevented + */ + Single call(RequestContext ctx); + } + + interface AfterExecuteCallbackBase {} + + /** + * Async callback interface for actions to be performed after an execution is completed or failed. + */ + @FunctionalInterface + public interface AfterExecuteCallback extends AfterExecuteCallbackBase { + /** + * Callback which will be called after an execution resolved into a completed or failed task. + * This gives an opportunity to enrich the event with additional metadata or log it. + * + * @param ctx the request context + * @param finalUpdateEvent the final update event + * @return a {@link Maybe} that completes when the callback is done + */ + Maybe call(RequestContext ctx, TaskStatusUpdateEvent finalUpdateEvent); + } + + interface AfterEventCallbackBase {} + + /** Async callback interface for actions to be performed after an event is processed. */ + @FunctionalInterface + public interface AfterEventCallback extends AfterEventCallbackBase { + /** + * Callback which will be called after an ADK event is successfully converted to an A2A event. + * This gives an opportunity to enrich the event with additional metadata or abort the execution + * by returning an error. The callback is not invoked for errors originating from ADK or event + * processing. + * + * @param ctx the request context + * @param processedEvent the processed task artifact update event + * @param event the ADK event + * @return a {@link Maybe} that completes when the callback is done + */ + Maybe call( + RequestContext ctx, TaskArtifactUpdateEvent processedEvent, Event event); + } +} diff --git a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java index 1a4873a85..d196d2f6d 100644 --- a/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java +++ b/a2a/src/test/java/com/google/adk/a2a/converters/ResponseConverterTest.java @@ -62,6 +62,10 @@ private Task.Builder testTask() { return new Task.Builder().id("task-1").contextId("context-1"); } + private static TaskStatusUpdateEvent.Builder testTaskStatusUpdateEvent() { + return new TaskStatusUpdateEvent.Builder().taskId("task-1").contextId("context-1"); + } + @Test public void eventsToMessage_withNullEvents_returnsEmptyAgentMessage() { Message message = ResponseConverter.eventsToMessage(null, "context-1", "task-1"); @@ -330,6 +334,72 @@ public void clientEventToEvent_withTaskArtifactUpdateEvent_withLastChunkFalse_re assertThat(optionalEvent).isEmpty(); } + @Test + public void clientEventToEvent_withFinalTaskStatusUpdateEvent_withMessage_returnsEvent() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Final status message"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.COMPLETED, statusMessage, null); + TaskStatusUpdateEvent updateEvent = + testTaskStatusUpdateEvent().isFinal(true).status(status).build(); + + TaskUpdateEvent event = new TaskUpdateEvent(testTask().status(status).build(), updateEvent); + + Optional optionalEvent = ResponseConverter.clientEventToEvent(event, invocationContext); + assertThat(optionalEvent).isPresent(); + Event resultEvent = optionalEvent.get(); + assertThat(resultEvent.content().get().parts().get().get(0).text()) + .hasValue("Final status message"); + assertThat(resultEvent.content().get().parts().get().get(0).thought()).hasValue(false); + assertThat(resultEvent.partial().orElse(false)).isFalse(); + assertThat(resultEvent.turnComplete()).hasValue(true); + } + + @Test + public void clientEventToEvent_withFinalTaskStatusUpdateEvent_withoutMessage_returnsEvent() { + TaskStatus status = new TaskStatus(TaskState.COMPLETED, null, null); + TaskStatusUpdateEvent updateEvent = + new TaskStatusUpdateEvent("task-id-1", status, "context-1", true, null); + TaskUpdateEvent event = new TaskUpdateEvent(testTask().status(status).build(), updateEvent); + + Optional optionalEvent = ResponseConverter.clientEventToEvent(event, invocationContext); + assertThat(optionalEvent).isPresent(); + Event resultEvent = optionalEvent.get(); + assertThat(resultEvent.turnComplete()).hasValue(true); + } + + @Test + public void clientEventToEvent_withNonFinalTaskStatusUpdateEvent_withoutMessage_returnsEmpty() { + TaskStatus status = new TaskStatus(TaskState.WORKING, null, null); + TaskStatusUpdateEvent updateEvent = + new TaskStatusUpdateEvent("task-id-1", status, "context-1", false, null); + TaskUpdateEvent event = new TaskUpdateEvent(testTask().status(status).build(), updateEvent); + + Optional optionalEvent = ResponseConverter.clientEventToEvent(event, invocationContext); + assertThat(optionalEvent).isEmpty(); + } + + @Test + public void clientEventToEvent_withFailedTaskStatusUpdateEvent_returnsErrorEvent() { + Message statusMessage = + new Message.Builder() + .role(Message.Role.AGENT) + .parts(ImmutableList.of(new TextPart("Task failed"))) + .build(); + TaskStatus status = new TaskStatus(TaskState.FAILED, statusMessage, null); + TaskStatusUpdateEvent updateEvent = + new TaskStatusUpdateEvent("task-id-1", status, "context-1", true, null); + TaskUpdateEvent event = new TaskUpdateEvent(testTask().status(status).build(), updateEvent); + + Optional optionalEvent = ResponseConverter.clientEventToEvent(event, invocationContext); + assertThat(optionalEvent).isPresent(); + Event resultEvent = optionalEvent.get(); + assertThat(resultEvent.errorMessage()).hasValue("Task failed"); + assertThat(resultEvent.turnComplete()).hasValue(true); + } + private static final class TestAgent extends BaseAgent { TestAgent() { super("test_agent", "test", ImmutableList.of(), null, null); diff --git a/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java new file mode 100644 index 000000000..350bd6f16 --- /dev/null +++ b/a2a/src/test/java/com/google/adk/a2a/executor/AgentExecutorTest.java @@ -0,0 +1,98 @@ +package com.google.adk.a2a.executor; + +import static org.junit.Assert.assertThrows; + +import com.google.adk.agents.BaseAgent; +import com.google.adk.agents.InvocationContext; +import com.google.adk.apps.App; +import com.google.adk.artifacts.InMemoryArtifactService; +import com.google.adk.events.Event; +import com.google.adk.sessions.InMemorySessionService; +import com.google.common.collect.ImmutableList; +import io.reactivex.rxjava3.core.Flowable; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class AgentExecutorTest { + + private TestAgent testAgent; + + @Before + public void setUp() { + testAgent = new TestAgent(); + } + + @Test + public void createAgentExecutor_noAgent_succeeds() { + var unused = + new AgentExecutor.Builder() + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .agentExecutorConfig(AgentExecutorConfig.builder().build()) + .build(); + } + + @Test + public void createAgentExecutor_withAgentAndApp_throwsException() { + assertThrows( + IllegalStateException.class, + () -> { + new AgentExecutor.Builder() + .agent(testAgent) + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .agentExecutorConfig(AgentExecutorConfig.builder().build()) + .artifactService(new InMemoryArtifactService()) + .build(); + }); + } + + @Test + public void createAgentExecutor_withEmptyAgentAndApp_throwsException() { + assertThrows( + IllegalStateException.class, + () -> { + new AgentExecutor.Builder() + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .agentExecutorConfig(AgentExecutorConfig.builder().build()) + .build(); + }); + } + + @Test + public void createAgentExecutor_noAgentExecutorConfig_throwsException() { + assertThrows( + NullPointerException.class, + () -> { + new AgentExecutor.Builder() + .app(App.builder().name("test_app").rootAgent(testAgent).build()) + .sessionService(new InMemorySessionService()) + .artifactService(new InMemoryArtifactService()) + .build(); + }); + } + + private static final class TestAgent extends BaseAgent { + private final Flowable eventsToEmit = Flowable.empty(); + + TestAgent() { + // BaseAgent constructor: name, description, examples, tools, model + super("test_agent", "test", ImmutableList.of(), null, null); + } + + @Override + protected Flowable runAsyncImpl(InvocationContext invocationContext) { + return eventsToEmit; + } + + @Override + protected Flowable runLiveImpl(InvocationContext invocationContext) { + return eventsToEmit; + } + } +} diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index 9bcc29724..461c15db9 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 ../../pom.xml diff --git a/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java b/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java index d3295a6ef..db236fc88 100644 --- a/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java +++ b/contrib/firestore-session-service/src/main/java/com/google/adk/sessions/FirestoreSessionService.java @@ -50,6 +50,7 @@ import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicBoolean; import java.util.regex.Matcher; +import javax.annotation.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -88,7 +89,20 @@ private CollectionReference getSessionsCollection(String userId) { /** Creates a new session in Firestore. */ @Override public Single createSession( - String appName, String userId, ConcurrentMap state, String sessionId) { + String appName, + String userId, + @Nullable ConcurrentMap state, + @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + /** Creates a new session in Firestore. */ + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { return Single.fromCallable( () -> { Objects.requireNonNull(appName, "appName cannot be null"); diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index 538d3009e..fd029e19d 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 ../../pom.xml diff --git a/contrib/samples/a2a_basic/pom.xml b/contrib/samples/a2a_basic/pom.xml index 2711a8bf5..698b56372 100644 --- a/contrib/samples/a2a_basic/pom.xml +++ b/contrib/samples/a2a_basic/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.6.0 + 0.7.0 .. diff --git a/contrib/samples/configagent/pom.xml b/contrib/samples/configagent/pom.xml index ff49e428c..a267c9df1 100644 --- a/contrib/samples/configagent/pom.xml +++ b/contrib/samples/configagent/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 0.6.0 + 0.7.0 .. diff --git a/contrib/samples/helloworld/pom.xml b/contrib/samples/helloworld/pom.xml index 188596f2b..a14d91fbe 100644 --- a/contrib/samples/helloworld/pom.xml +++ b/contrib/samples/helloworld/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-samples - 0.6.0 + 0.7.0 .. @@ -36,7 +36,7 @@ UTF-8 17 - 1.11.0 + 1.11.1 com.example.helloworld.HelloWorldRun ${project.version} diff --git a/contrib/samples/mcpfilesystem/pom.xml b/contrib/samples/mcpfilesystem/pom.xml index 0d9046917..5cda3dc6f 100644 --- a/contrib/samples/mcpfilesystem/pom.xml +++ b/contrib/samples/mcpfilesystem/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 ../../.. @@ -36,7 +36,7 @@ UTF-8 17 - 1.11.0 + 1.11.1 com.example.mcpfilesystem.McpFilesystemRun ${project.parent.version} diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index f009dc575..0f1aed85c 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 ../.. diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index dd465e6d1..5481a0458 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 24b3604ef..822722e82 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 google-adk @@ -201,6 +201,59 @@ maven-compiler-plugin + + maven-surefire-plugin + + + basic + + test + + + + apigee-llm + + test + + + ApigeeLlmTest + + + api-key + false + + + + + apigee-llm-vertex-ai + + test + + + ApigeeLlmTest#generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi + + api-key + + true + + + + + apigee-llm-proxy-url + + test + + + ApigeeLlmTest#build_withoutProxyUrl_readsFromEnvironment + + api-key + + proxy-url + + + + + diff --git a/core/src/main/java/com/google/adk/Version.java b/core/src/main/java/com/google/adk/Version.java index 26577f792..10219a31b 100644 --- a/core/src/main/java/com/google/adk/Version.java +++ b/core/src/main/java/com/google/adk/Version.java @@ -22,7 +22,7 @@ */ public final class Version { // Don't touch this, release-please should keep it up to date. - public static final String JAVA_ADK_VERSION = "0.6.0"; // x-release-please-released-version + public static final String JAVA_ADK_VERSION = "0.7.0"; // x-release-please-released-version private Version() {} } diff --git a/core/src/main/java/com/google/adk/agents/CallbackContext.java b/core/src/main/java/com/google/adk/agents/CallbackContext.java index 49298451b..a29783769 100644 --- a/core/src/main/java/com/google/adk/agents/CallbackContext.java +++ b/core/src/main/java/com/google/adk/agents/CallbackContext.java @@ -92,14 +92,20 @@ public Single> listArtifacts() { .map(ListArtifactsResponse::filenames); } + /** Loads the latest version of an artifact from the service. */ + public Maybe loadArtifact(String filename) { + return loadArtifact(filename, Optional.empty()); + } + + /** Loads a specific version of an artifact from the service. */ + public Maybe loadArtifact(String filename, int version) { + return loadArtifact(filename, Optional.of(version)); + } + /** - * Loads an artifact from the artifact service associated with the current session. - * - * @param filename Artifact file name. - * @param version Artifact version (optional). - * @return loaded part, or empty if not found. - * @throws IllegalStateException if the artifact service is not initialized. + * @deprecated Use {@link #loadArtifact(String)} or {@link #loadArtifact(String, int)} instead. */ + @Deprecated public Maybe loadArtifact(String filename, Optional version) { if (invocationContext.artifactService() == null) { throw new IllegalStateException("Artifact service is not initialized."); diff --git a/core/src/main/java/com/google/adk/agents/LlmAgent.java b/core/src/main/java/com/google/adk/agents/LlmAgent.java index 1893fb162..ee4e6ab4c 100644 --- a/core/src/main/java/com/google/adk/agents/LlmAgent.java +++ b/core/src/main/java/com/google/adk/agents/LlmAgent.java @@ -755,21 +755,36 @@ public Single> canonicalGlobalInstruction(ReadonlyCon throw new IllegalStateException("Unknown Instruction subtype: " + globalInstruction.getClass()); } + /** + * @deprecated Use {@link #canonicalTools(ReadonlyContext)} instead. + */ + @Deprecated + public Flowable canonicalTools(Optional context) { + return canonicalTools(context.orElse(null)); + } + /** * Constructs the list of tools for this agent based on the {@link #tools} field. * - *

This method is only for use by Agent Development Kit. + * @return The resolved list of tools as a {@link Single} wrapped list of {@link BaseTool}. + */ + public Flowable canonicalTools() { + return canonicalTools((ReadonlyContext) null); + } + + /** + * Constructs the list of tools for this agent based on the {@link #tools} field. * * @param context The context to retrieve the session state. * @return The resolved list of tools as a {@link Single} wrapped list of {@link BaseTool}. */ - public Flowable canonicalTools(Optional context) { + public Flowable canonicalTools(@Nullable ReadonlyContext context) { List> toolFlowables = new ArrayList<>(); for (Object toolOrToolset : toolsUnion) { if (toolOrToolset instanceof BaseTool baseTool) { toolFlowables.add(Flowable.just(baseTool)); } else if (toolOrToolset instanceof BaseToolset baseToolset) { - toolFlowables.add(baseToolset.getTools(context.orElse(null))); + toolFlowables.add(baseToolset.getTools(context)); } else { throw new IllegalArgumentException( "Object in tools list is not of a supported type: " @@ -779,16 +794,6 @@ public Flowable canonicalTools(Optional context) { return Flowable.concat(toolFlowables); } - /** Overload of canonicalTools that defaults to an empty context. */ - public Flowable canonicalTools() { - return canonicalTools(Optional.empty()); - } - - /** Convenience overload of canonicalTools that accepts a non-optional ReadonlyContext. */ - public Flowable canonicalTools(ReadonlyContext context) { - return canonicalTools(Optional.ofNullable(context)); - } - public Instruction instruction() { return instruction; } @@ -965,7 +970,10 @@ private Model resolveModelInternal() { Model currentModel = this.model.get(); if (currentModel.model().isPresent()) { - return currentModel; + String modelName = currentModel.model().get().model(); + BaseLlm resolvedLlm = currentModel.model().get(); + + return Model.builder().modelName(modelName).model(resolvedLlm).build(); } if (currentModel.modelName().isPresent()) { diff --git a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java index b6a3cee23..32ef9ff4d 100644 --- a/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/BaseArtifactService.java @@ -55,22 +55,26 @@ Single saveArtifact( default Single saveAndReloadArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { return saveArtifact(appName, userId, sessionId, filename, artifact) - .flatMap( - version -> - loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .toSingle()); + .flatMap(version -> loadArtifact(appName, userId, sessionId, filename, version).toSingle()); + } + + /** Loads the latest version of an artifact from the service. */ + default Maybe loadArtifact( + String appName, String userId, String sessionId, String filename) { + return loadArtifact(appName, userId, sessionId, filename, Optional.empty()); + } + + /** Loads a specific version of an artifact from the service. */ + default Maybe loadArtifact( + String appName, String userId, String sessionId, String filename, int version) { + return loadArtifact(appName, userId, sessionId, filename, Optional.of(version)); } /** - * Gets an artifact. - * - * @param appName the app name - * @param userId the user ID - * @param sessionId the session ID - * @param filename the filename - * @param version Optional version number. If null, loads the latest version. - * @return the artifact or empty if not found + * @deprecated Use {@link #loadArtifact(String, String, String, String)} or {@link + * #loadArtifact(String, String, String, String, int)} instead. */ + @Deprecated Maybe loadArtifact( String appName, String userId, String sessionId, String filename, Optional version); diff --git a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java index 5808f7083..8c8ec2af8 100644 --- a/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java +++ b/core/src/main/java/com/google/adk/artifacts/InMemoryArtifactService.java @@ -129,10 +129,7 @@ public Single> listVersions( public Single saveAndReloadArtifact( String appName, String userId, String sessionId, String filename, Part artifact) { return saveArtifact(appName, userId, sessionId, filename, artifact) - .flatMap( - version -> - loadArtifact(appName, userId, sessionId, filename, Optional.of(version)) - .toSingle()); + .flatMap(version -> loadArtifact(appName, userId, sessionId, filename, version).toSingle()); } private Map> getArtifactsMap(String appName, String userId, String sessionId) { diff --git a/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java b/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java index 1d1202ead..4e75dab75 100644 --- a/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java +++ b/core/src/main/java/com/google/adk/codeexecutors/ContainerCodeExecutor.java @@ -17,6 +17,8 @@ package com.google.adk.codeexecutors; +import static java.util.Objects.requireNonNullElse; + import com.github.dockerjava.api.DockerClient; import com.github.dockerjava.api.command.ExecCreateCmdResponse; import com.github.dockerjava.api.model.Container; @@ -32,7 +34,6 @@ import java.io.UncheckedIOException; import java.nio.charset.StandardCharsets; import java.nio.file.Paths; -import java.util.Optional; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -41,37 +42,68 @@ public class ContainerCodeExecutor extends BaseCodeExecutor { private static final Logger logger = LoggerFactory.getLogger(ContainerCodeExecutor.class); private static final String DEFAULT_IMAGE_TAG = "adk-code-executor:latest"; - private final Optional baseUrl; + private final String baseUrl; private final String image; - private final Optional dockerPath; + private final String dockerPath; private final DockerClient dockerClient; private Container container; /** - * Initializes the ContainerCodeExecutor. + * Creates a ContainerCodeExecutor from an image. + * + * @param baseUrl The base url of the user hosted Docker client. + * @param image The tag of the predefined image or custom image to run on the container. + */ + public static ContainerCodeExecutor fromImage(String baseUrl, String image) { + return new ContainerCodeExecutor(baseUrl, image, null); + } + + /** + * Creates a ContainerCodeExecutor from an image. + * + * @param image The tag of the predefined image or custom image to run on the container. + */ + public static ContainerCodeExecutor fromImage(String image) { + return new ContainerCodeExecutor(null, image, null); + } + + /** + * Creates a ContainerCodeExecutor from a Dockerfile path. + * + * @param baseUrl The base url of the user hosted Docker client. + * @param dockerPath The path to the directory containing the Dockerfile. + */ + public static ContainerCodeExecutor fromDockerPath(String baseUrl, String dockerPath) { + return new ContainerCodeExecutor(baseUrl, null, dockerPath); + } + + /** + * Creates a ContainerCodeExecutor from a Dockerfile path. + * + * @param dockerPath The path to the directory containing the Dockerfile. + */ + public static ContainerCodeExecutor fromDockerPath(String dockerPath) { + return new ContainerCodeExecutor(null, null, dockerPath); + } + + /** + * Initializes the ContainerCodeExecutor. Either dockerPath or image must be set. * - * @param baseUrl Optional. The base url of the user hosted Docker client. - * @param image The tag of the predefined image or custom image to run on the container. Either - * dockerPath or image must be set. - * @param dockerPath The path to the directory containing the Dockerfile. If set, build the image - * from the dockerfile path instead of using the predefined image. Either dockerPath or image - * must be set. + * @deprecated Use one of the static factory methods instead. */ - public ContainerCodeExecutor( - Optional baseUrl, Optional image, Optional dockerPath) { - if (image.isEmpty() && dockerPath.isEmpty()) { + @Deprecated + public ContainerCodeExecutor(String baseUrl, String image, String dockerPath) { + if (image == null && dockerPath == null) { throw new IllegalArgumentException( "Either image or dockerPath must be set for ContainerCodeExecutor."); } this.baseUrl = baseUrl; - this.image = image.orElse(DEFAULT_IMAGE_TAG); - this.dockerPath = dockerPath.map(p -> Paths.get(p).toAbsolutePath().toString()); + this.image = requireNonNullElse(image, DEFAULT_IMAGE_TAG); + this.dockerPath = dockerPath == null ? null : Paths.get(dockerPath).toAbsolutePath().toString(); - if (baseUrl.isPresent()) { + if (baseUrl != null) { var config = - DefaultDockerClientConfig.createDefaultConfigBuilder() - .withDockerHost(baseUrl.get()) - .build(); + DefaultDockerClientConfig.createDefaultConfigBuilder().withDockerHost(baseUrl).build(); this.dockerClient = DockerClientBuilder.getInstance(config).build(); } else { this.dockerClient = DockerClientBuilder.getInstance().build(); @@ -121,12 +153,12 @@ public CodeExecutionResult executeCode( } private void buildDockerImage() { - if (dockerPath.isEmpty()) { + if (dockerPath == null) { throw new IllegalStateException("Docker path is not set."); } - File dockerfile = new File(dockerPath.get()); + File dockerfile = new File(dockerPath); if (!dockerfile.exists()) { - throw new UncheckedIOException(new IOException("Invalid Docker path: " + dockerPath.get())); + throw new UncheckedIOException(new IOException("Invalid Docker path: " + dockerPath)); } logger.info("Building Docker image..."); @@ -158,7 +190,7 @@ private void initContainer() { if (dockerClient == null) { throw new IllegalStateException("Docker client is not initialized."); } - if (dockerPath.isPresent()) { + if (dockerPath != null) { buildDockerImage(); } else { // If a dockerPath is not provided, always pull the image to ensure it's up-to-date. diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 6d8c698dd..bf25acfc7 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -22,6 +22,7 @@ import com.google.adk.sessions.State; import com.google.errorprone.annotations.CanIgnoreReturnValue; import java.util.HashSet; +import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; @@ -383,7 +384,7 @@ public Builder compaction(EventCompaction value) { @CanIgnoreReturnValue public Builder merge(EventActions other) { other.skipSummarization().ifPresent(this::skipSummarization); - this.stateDelta.putAll(other.stateDelta()); + other.stateDelta().forEach((key, value) -> stateDelta.merge(key, value, Builder::deepMerge)); this.artifactDelta.putAll(other.artifactDelta()); this.deletedArtifactIds.addAll(other.deletedArtifactIds()); other.transferToAgent().ifPresent(this::transferToAgent); @@ -395,6 +396,34 @@ public Builder merge(EventActions other) { return this; } + private static Object deepMerge(Object target, Object source) { + if (!(target instanceof Map) || !(source instanceof Map)) { + // If one of them is not a map, the source value overwrites the target. + return source; + } + + Map targetMap = (Map) target; + Map sourceMap = (Map) source; + + if (!targetMap.isEmpty() && !sourceMap.isEmpty()) { + Object targetKey = targetMap.keySet().iterator().next(); + Object sourceKey = sourceMap.keySet().iterator().next(); + if (targetKey != null + && sourceKey != null + && !targetKey.getClass().equals(sourceKey.getClass())) { + throw new IllegalArgumentException( + String.format( + "Cannot merge maps with different key types: %s vs %s", + targetKey.getClass().getName(), sourceKey.getClass().getName())); + } + } + + // Create a new map to prevent UnsupportedOperationException from immutable maps + Map mergedMap = new ConcurrentHashMap<>(targetMap); + sourceMap.forEach((key, value) -> mergedMap.merge(key, value, Builder::deepMerge)); + return mergedMap; + } + public EventActions build() { return new EventActions(this); } diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java index f45461626..0f2e2d166 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Contents.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Contents.java @@ -559,7 +559,7 @@ private static List rearrangeEventsForAsyncFunctionResponsesInHistory( // Gemini 3 requires function calls to be grouped first and only then function responses: // FC1 FC2 FR1 FR2 - boolean shouldBufferResponseEvents = modelName.startsWith("gemini-3-"); + boolean shouldBufferResponseEvents = modelName.contains("gemini-3-"); for (int i = 0; i < events.size(); i++) { Event event = events.get(i); diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 82813defa..269764046 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -253,7 +253,8 @@ private static Function> getFunctionCallMapper( functionCall.id().map(toolConfirmations::get).orElse(null)) .build(); - Map functionArgs = functionCall.args().orElse(new HashMap<>()); + Map functionArgs = + functionCall.args().map(HashMap::new).orElse(new HashMap<>()); Maybe> maybeFunctionResult = maybeInvokeBeforeToolCall(invocationContext, tool, functionArgs, toolContext) @@ -482,12 +483,8 @@ private static Maybe> maybeInvokeBeforeToolCall( if (invocationContext.agent() instanceof LlmAgent) { LlmAgent agent = (LlmAgent) invocationContext.agent(); - HashMap mutableFunctionArgs = new HashMap<>(functionArgs); - Maybe> pluginResult = - invocationContext - .pluginManager() - .beforeToolCallback(tool, mutableFunctionArgs, toolContext); + invocationContext.pluginManager().beforeToolCallback(tool, functionArgs, toolContext); List callbacks = agent.canonicalBeforeToolCallbacks(); if (callbacks.isEmpty()) { @@ -500,8 +497,7 @@ private static Maybe> maybeInvokeBeforeToolCall( Flowable.fromIterable(callbacks) .concatMapMaybe( callback -> - callback.call( - invocationContext, tool, mutableFunctionArgs, toolContext)) + callback.call(invocationContext, tool, functionArgs, toolContext)) .firstElement()); return pluginResult.switchIfEmpty(callbackResult); diff --git a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java index 540153460..94e8cd7ba 100644 --- a/core/src/main/java/com/google/adk/sessions/BaseSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/BaseSessionService.java @@ -23,8 +23,10 @@ import io.reactivex.rxjava3.core.Maybe; import io.reactivex.rxjava3.core.Single; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import javax.annotation.Nullable; @@ -47,13 +49,35 @@ public interface BaseSessionService { * service should generate a unique ID. * @return The newly created {@link Session} instance. * @throws SessionException if creation fails. + * @deprecated Use {@link #createSession(String, String, Map, String)} instead. */ + @Deprecated Single createSession( String appName, String userId, @Nullable ConcurrentMap state, @Nullable String sessionId); + /** + * Creates a new session with the specified parameters. + * + * @param appName The name of the application associated with the session. + * @param userId The identifier for the user associated with the session. + * @param state An optional map representing the initial state of the session. Can be null or + * empty. + * @param sessionId An optional client-provided identifier for the session. If empty or null, the + * service should generate a unique ID. + * @return The newly created {@link Session} instance. + * @throws SessionException if creation fails. + */ + default Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { + return createSession(appName, userId, ensureConcurrentMap(state), sessionId); + } + /** * Creates a new session with the specified application name and user ID, using a default state * (null) and allowing the service to generate a unique session ID. @@ -165,9 +189,9 @@ default Single appendEvent(Session session, Event event) { EventActions actions = event.actions(); if (actions != null) { - ConcurrentMap stateDelta = actions.stateDelta(); + Map stateDelta = actions.stateDelta(); if (stateDelta != null && !stateDelta.isEmpty()) { - ConcurrentMap sessionState = session.state(); + Map sessionState = session.state(); if (sessionState != null) { stateDelta.forEach( (key, value) -> { @@ -190,4 +214,21 @@ default Single appendEvent(Session session, Event event) { return Single.just(event); } + + /** + * Ensures the given {@link Map} is a {@link ConcurrentMap}. If the input is null, returns null. + * If the input is already a {@link ConcurrentMap}, it is cast and returned. Otherwise, a new + * {@link ConcurrentHashMap} is created from the input map. + */ + @Nullable + private static ConcurrentMap ensureConcurrentMap( + @Nullable Map state) { + if (state == null) { + return null; + } + if (state instanceof ConcurrentMap concurrentMap) { + return concurrentMap; + } + return new ConcurrentHashMap<>(state); + } } diff --git a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java index 060fcaf60..b2a584b11 100644 --- a/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java +++ b/core/src/main/java/com/google/adk/sessions/InMemorySessionService.java @@ -71,6 +71,15 @@ public Single createSession( String userId, @Nullable ConcurrentMap state, @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { Objects.requireNonNull(appName, "appName cannot be null"); Objects.requireNonNull(userId, "userId cannot be null"); @@ -83,7 +92,6 @@ public Single createSession( // Ensure state map and events list are mutable for the new session ConcurrentMap initialState = (state == null) ? new ConcurrentHashMap<>() : new ConcurrentHashMap<>(state); - List initialEvents = new ArrayList<>(); // Assuming Session constructor or setters allow setting these mutable collections Session newSession = @@ -91,7 +99,6 @@ public Single createSession( .appName(appName) .userId(userId) .state(initialState) - .events(initialEvents) .lastUpdateTime(Instant.now()) .build(); diff --git a/core/src/main/java/com/google/adk/sessions/Session.java b/core/src/main/java/com/google/adk/sessions/Session.java index 3bf27b55e..877a95220 100644 --- a/core/src/main/java/com/google/adk/sessions/Session.java +++ b/core/src/main/java/com/google/adk/sessions/Session.java @@ -25,6 +25,7 @@ import java.time.Duration; import java.time.Instant; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -54,7 +55,7 @@ public static final class Builder { private String appName; private String userId; private State state = new State(new ConcurrentHashMap<>()); - private List events = new ArrayList<>(); + private List events = Collections.synchronizedList(new ArrayList<>()); private Instant lastUpdateTime = Instant.EPOCH; public Builder(String id) { @@ -101,7 +102,7 @@ public Builder userId(String userId) { @CanIgnoreReturnValue @JsonProperty("events") public Builder events(List events) { - this.events = events; + this.events = Collections.synchronizedList(events); return this; } diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java index d35bbccae..718738b92 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiClient.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiClient.java @@ -14,10 +14,10 @@ import io.reactivex.rxjava3.core.Single; import java.io.IOException; import java.io.UncheckedIOException; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeoutException; import javax.annotation.Nullable; import okhttp3.ResponseBody; @@ -51,8 +51,8 @@ final class VertexAiClient { } Maybe createSession( - String reasoningEngineId, String userId, ConcurrentMap state) { - ConcurrentHashMap sessionJsonMap = new ConcurrentHashMap<>(); + String reasoningEngineId, String userId, Map state) { + Map sessionJsonMap = new HashMap<>(); sessionJsonMap.put("userId", userId); if (state != null) { sessionJsonMap.put("sessionState", state); diff --git a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java index 7878daf22..2fff7a752 100644 --- a/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java +++ b/core/src/main/java/com/google/adk/sessions/VertexAiSessionService.java @@ -76,6 +76,15 @@ public Single createSession( String userId, @Nullable ConcurrentMap state, @Nullable String sessionId) { + return createSession(appName, userId, (Map) state, sessionId); + } + + @Override + public Single createSession( + String appName, + String userId, + @Nullable Map state, + @Nullable String sessionId) { String reasoningEngineId = parseReasoningEngineId(appName); return client diff --git a/core/src/main/java/com/google/adk/tools/BaseToolset.java b/core/src/main/java/com/google/adk/tools/BaseToolset.java index 4d3482c57..c8ed6df4e 100644 --- a/core/src/main/java/com/google/adk/tools/BaseToolset.java +++ b/core/src/main/java/com/google/adk/tools/BaseToolset.java @@ -20,6 +20,7 @@ import io.reactivex.rxjava3.core.Flowable; import java.util.List; import java.util.Optional; +import javax.annotation.Nullable; /** Base interface for toolsets. */ public interface BaseToolset extends AutoCloseable { @@ -43,28 +44,35 @@ public interface BaseToolset extends AutoCloseable { void close() throws Exception; /** - * Helper method to be used by implementers that returns true if the given tool is in the provided - * list of tools of if testing against the given ToolPredicate returns true (otherwise false). + * Checks if a tool should be selected based on a filter. * * @param tool The tool to check. - * @param toolFilter An Optional containing either a ToolPredicate or a List of tool names. - * @param readonlyContext The current context. - * @return true if the tool is selected. + * @param toolFilter A ToolPredicate, a List of tool names, or null. + * @param readonlyContext The context for checking the tool, or null. */ default boolean isToolSelected( - BaseTool tool, Optional toolFilter, Optional readonlyContext) { - if (toolFilter.isEmpty()) { + BaseTool tool, @Nullable Object toolFilter, @Nullable ReadonlyContext readonlyContext) { + if (toolFilter == null) { return true; } - Object filter = toolFilter.get(); - if (filter instanceof ToolPredicate toolPredicate) { + + if (toolFilter instanceof ToolPredicate toolPredicate) { return toolPredicate.test(tool, readonlyContext); } - if (filter instanceof List) { - @SuppressWarnings("unchecked") - List toolNames = (List) filter; + + if (toolFilter instanceof List toolNames) { return toolNames.contains(tool.name()); } + return false; } + + /** + * @deprecated Use {@link #isToolSelected(BaseTool, Object, ReadonlyContext)} instead. + */ + @Deprecated + default boolean isToolSelected( + BaseTool tool, Optional toolFilter, Optional readonlyContext) { + return isToolSelected(tool, toolFilter.orElse(null), readonlyContext.orElse(null)); + } } diff --git a/core/src/main/java/com/google/adk/tools/BuiltInCodeExecutionTool.java b/core/src/main/java/com/google/adk/tools/BuiltInCodeExecutionTool.java index 060b3ffb8..ad97b96a6 100644 --- a/core/src/main/java/com/google/adk/tools/BuiltInCodeExecutionTool.java +++ b/core/src/main/java/com/google/adk/tools/BuiltInCodeExecutionTool.java @@ -16,13 +16,19 @@ package com.google.adk.tools; +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.BaseLlm; import com.google.adk.models.LlmRequest; +import com.google.adk.utils.ModelNameUtils; import com.google.common.collect.ImmutableList; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.Tool; import com.google.genai.types.ToolCodeExecution; import io.reactivex.rxjava3.core.Completable; import java.util.List; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * A built-in code execution tool that is automatically invoked by Gemini 2 models. @@ -32,6 +38,7 @@ */ public final class BuiltInCodeExecutionTool extends BaseTool { public static final BuiltInCodeExecutionTool INSTANCE = new BuiltInCodeExecutionTool(); + private static final Logger LOG = LoggerFactory.getLogger(BuiltInCodeExecutionTool.class); public BuiltInCodeExecutionTool() { super("code_execution", "code_execution"); @@ -41,10 +48,28 @@ public BuiltInCodeExecutionTool() { public Completable processLlmRequest( LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { - String model = llmRequestBuilder.build().model().get(); - if (model.isEmpty() || !model.startsWith("gemini-2")) { - return Completable.error( - new IllegalArgumentException("Code execution tool is not supported for model " + model)); + Optional model = + Optional.ofNullable(toolContext) + .flatMap(tCtx -> Optional.ofNullable(tCtx.invocationContext())) + .flatMap( + iCtx -> { + if (iCtx.agent() instanceof LlmAgent llmAgent) { + return Optional.of(llmAgent); + } else { + return Optional.empty(); + } + }) + .flatMap(llmAgent -> llmAgent.resolvedModel().model()); + + String modelName = llmRequestBuilder.build().model().get(); + if (!ModelNameUtils.isGeminiModel(modelName) + || model.filter(ModelNameUtils::isInstanceOfGemini).isEmpty()) { + // model name is not a gemini model, or the model isn't an instance of Gemini class (eg. + // LangChain case). + LOG.warn( + "Code execution tool is not supported for model: {} ({}).", + modelName, + model.map(Object::getClass).map(Class::toString).orElse("")); } GenerateContentConfig.Builder configBuilder = llmRequestBuilder diff --git a/core/src/main/java/com/google/adk/tools/GoogleMapsTool.java b/core/src/main/java/com/google/adk/tools/GoogleMapsTool.java index 8689849c2..12ec27169 100644 --- a/core/src/main/java/com/google/adk/tools/GoogleMapsTool.java +++ b/core/src/main/java/com/google/adk/tools/GoogleMapsTool.java @@ -79,15 +79,8 @@ public Completable processLlmRequest( List existingTools = configBuilder.build().tools().orElse(ImmutableList.of()); ImmutableList.Builder updatedToolsBuilder = ImmutableList.builder(); updatedToolsBuilder.addAll(existingTools); - - String model = llmRequestBuilder.build().model().orElse(null); - if (model != null && !model.startsWith("gemini-1")) { - updatedToolsBuilder.add(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); - configBuilder.tools(updatedToolsBuilder.build()); - } else { - return Completable.error( - new IllegalArgumentException("Google Maps tool is not supported for model " + model)); - } + updatedToolsBuilder.add(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); + configBuilder.tools(updatedToolsBuilder.build()); llmRequestBuilder.config(configBuilder.build()); return Completable.complete(); diff --git a/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java b/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java index 6f89754cf..b4f298c21 100644 --- a/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java +++ b/core/src/main/java/com/google/adk/tools/GoogleSearchTool.java @@ -20,12 +20,9 @@ import com.google.common.collect.ImmutableList; import com.google.genai.types.GenerateContentConfig; import com.google.genai.types.GoogleSearch; -import com.google.genai.types.GoogleSearchRetrieval; import com.google.genai.types.Tool; import io.reactivex.rxjava3.core.Completable; import java.util.List; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; /** * A built-in tool that is automatically invoked by Gemini 2 and 3 models to retrieve search results @@ -43,7 +40,6 @@ * } */ public final class GoogleSearchTool extends BaseTool { - private static final Logger logger = LoggerFactory.getLogger(GoogleSearchTool.class); public static final GoogleSearchTool INSTANCE = new GoogleSearchTool(); public GoogleSearchTool() { @@ -66,17 +62,7 @@ public Completable processLlmRequest( updatedToolsBuilder.addAll(existingTools); String model = llmRequestBuilder.build().model().get(); - if (model != null && model.startsWith("gemini-1")) { - if (!updatedToolsBuilder.build().isEmpty()) { - logger.error("Tools already present: {}", configBuilder.build().tools().get()); - return Completable.error( - new IllegalArgumentException( - "Google search tool cannot be used with other tools in Gemini 1.x.")); - } - updatedToolsBuilder.add( - Tool.builder().googleSearchRetrieval(GoogleSearchRetrieval.builder().build()).build()); - configBuilder.tools(updatedToolsBuilder.build()); - } else if (model != null && (model.startsWith("gemini-2") || model.startsWith("gemini-3"))) { + if (model != null && (model.startsWith("gemini-2") || model.startsWith("gemini-3"))) { updatedToolsBuilder.add(Tool.builder().googleSearch(GoogleSearch.builder().build()).build()); configBuilder.tools(updatedToolsBuilder.build()); diff --git a/core/src/main/java/com/google/adk/tools/ToolPredicate.java b/core/src/main/java/com/google/adk/tools/ToolPredicate.java index 86d739e70..6adf53c18 100644 --- a/core/src/main/java/com/google/adk/tools/ToolPredicate.java +++ b/core/src/main/java/com/google/adk/tools/ToolPredicate.java @@ -18,6 +18,7 @@ import com.google.adk.agents.ReadonlyContext; import java.util.Optional; +import javax.annotation.Nullable; /** * Functional interface to decide whether a tool should be exposed to the LLM based on the current @@ -31,6 +32,19 @@ public interface ToolPredicate { * @param tool The tool to check. * @param readonlyContext The current context. * @return true if the tool should be selected, false otherwise. + * @deprecated Use {@link #test(BaseTool, ReadonlyContext)} instead. */ + @Deprecated boolean test(BaseTool tool, Optional readonlyContext); + + /** + * Decides if the given tool is selected. + * + * @param tool The tool to check. + * @param readonlyContext The current context. + * @return true if the tool should be selected, false otherwise. + */ + default boolean test(BaseTool tool, @Nullable ReadonlyContext readonlyContext) { + return test(tool, Optional.ofNullable(readonlyContext)); + } } diff --git a/core/src/main/java/com/google/adk/tools/UrlContextTool.java b/core/src/main/java/com/google/adk/tools/UrlContextTool.java index 5fe072d76..fe7f9c77e 100644 --- a/core/src/main/java/com/google/adk/tools/UrlContextTool.java +++ b/core/src/main/java/com/google/adk/tools/UrlContextTool.java @@ -25,8 +25,8 @@ import java.util.List; /** - * A built-in tool that is automatically invoked by Gemini 2 models to retrieve information from the - * given URLs. + * A built-in tool that is automatically invoked by Gemini 2 and 3 models to retrieve information + * from the given URLs. * *

This tool operates internally within the model and does not require or perform local code * execution. @@ -62,7 +62,7 @@ public Completable processLlmRequest( updatedToolsBuilder.addAll(existingTools); String model = llmRequestBuilder.build().model().get(); - if (model != null && model.startsWith("gemini-2")) { + if (model != null && (model.startsWith("gemini-2") || model.startsWith("gemini-3"))) { updatedToolsBuilder.add(Tool.builder().urlContext(UrlContext.builder().build()).build()); configBuilder.tools(updatedToolsBuilder.build()); } else { diff --git a/core/src/main/java/com/google/adk/tools/computeruse/BaseComputer.java b/core/src/main/java/com/google/adk/tools/computeruse/BaseComputer.java new file mode 100644 index 000000000..3ddb91963 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/BaseComputer.java @@ -0,0 +1,99 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import com.google.adk.tools.Annotations.Schema; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; +import java.time.Duration; +import java.util.List; + +/** + * Defines an interface for computer environments. + * + *

This interface defines the standard methods for controlling computer environments, including + * web browsers and other interactive systems. + */ +public interface BaseComputer { + + /** Returns the screen size of the environment. */ + Single screenSize(); + + /** Opens the web browser. */ + Single openWebBrowser(); + + /** Clicks at a specific x, y coordinate on the webpage. */ + Single clickAt(@Schema(name = "x") int x, @Schema(name = "y") int y); + + /** Hovers at a specific x, y coordinate on the webpage. */ + Single hoverAt(@Schema(name = "x") int x, @Schema(name = "y") int y); + + /** Types text at a specific x, y coordinate. */ + Single typeTextAt( + @Schema(name = "x") int x, + @Schema(name = "y") int y, + @Schema(name = "text") String text, + @Schema(name = "press_enter", optional = true) Boolean pressEnter, + @Schema(name = "clear_before_typing", optional = true) Boolean clearBeforeTyping); + + /** Scrolls the entire webpage in a direction. */ + Single scrollDocument(@Schema(name = "direction") String direction); + + /** Scrolls at a specific x, y coordinate by magnitude. */ + Single scrollAt( + @Schema(name = "x") int x, + @Schema(name = "y") int y, + @Schema(name = "direction") String direction, + @Schema(name = "magnitude") int magnitude); + + /** Waits for specified duration. */ + Single wait(@Schema(name = "duration") Duration duration); + + /** Navigates back. */ + Single goBack(); + + /** Navigates forward. */ + Single goForward(); + + /** Jumps to search. */ + Single search(); + + /** Navigates to URL. */ + Single navigate(@Schema(name = "url") String url); + + /** Presses key combination. */ + Single keyCombination(@Schema(name = "keys") List keys); + + /** Drag and drop. */ + Single dragAndDrop( + @Schema(name = "x") int x, + @Schema(name = "y") int y, + @Schema(name = "destination_x") int destinationX, + @Schema(name = "destination_y") int destinationY); + + /** Returns current state. */ + Single currentState(); + + /** Initialize the computer. */ + Completable initialize(); + + /** Cleanup resources. */ + Completable close(); + + /** Returns the environment. */ + Single environment(); +} diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerEnvironment.java b/core/src/main/java/com/google/adk/tools/computeruse/ComputerEnvironment.java new file mode 100644 index 000000000..2c897c794 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerEnvironment.java @@ -0,0 +1,23 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +/** Enum for computer environments. */ +public enum ComputerEnvironment { + ENVIRONMENT_UNSPECIFIED, + ENVIRONMENT_BROWSER +} diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java b/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java new file mode 100644 index 000000000..4f3be46c2 --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerState.java @@ -0,0 +1,108 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import java.util.Arrays; +import java.util.Objects; +import java.util.Optional; + +/** + * Represents the current state of the computer environment. + * + *

Attributes: screenshot: The screenshot in PNG format as bytes. url: The current URL of the + * webpage being displayed. + */ +public final class ComputerState { + private final byte[] screenshot; + private final Optional url; + + @JsonCreator + private ComputerState( + @JsonProperty("screenshot") byte[] screenshot, @JsonProperty("url") Optional url) { + this.screenshot = screenshot.clone(); + this.url = url; + } + + @JsonProperty("screenshot") + public byte[] screenshot() { + return screenshot.clone(); + } + + @JsonProperty("url") + public Optional url() { + return url; + } + + public static Builder builder() { + return new Builder(); + } + + /** Builder for {@link ComputerState}. */ + public static final class Builder { + private byte[] screenshot; + private Optional url = Optional.empty(); + + @CanIgnoreReturnValue + public Builder screenshot(byte[] screenshot) { + this.screenshot = screenshot.clone(); + return this; + } + + @CanIgnoreReturnValue + public Builder url(Optional url) { + this.url = url; + return this; + } + + @CanIgnoreReturnValue + public Builder url(String url) { + this.url = Optional.ofNullable(url); + return this; + } + + public ComputerState build() { + return new ComputerState(screenshot, url); + } + } + + public static ComputerState create(byte[] screenshot, String url) { + return builder().screenshot(screenshot).url(url).build(); + } + + public static ComputerState create(byte[] screenshot) { + return builder().screenshot(screenshot).build(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof ComputerState that)) { + return false; + } + return Objects.deepEquals(screenshot, that.screenshot) && Objects.equals(url, that.url); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(screenshot), url); + } +} diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseTool.java b/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseTool.java new file mode 100644 index 000000000..cedf7f35c --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseTool.java @@ -0,0 +1,125 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import static java.lang.String.format; + +import com.google.adk.tools.FunctionTool; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Single; +import java.lang.reflect.Method; +import java.util.Base64; +import java.util.HashMap; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A tool that wraps computer control functions for use with LLMs. + * + *

This tool automatically normalizes coordinates from a virtual coordinate space (by default + * 1000x1000) to the actual screen size. + */ +public class ComputerUseTool extends FunctionTool { + + private static final Logger logger = LoggerFactory.getLogger(ComputerUseTool.class); + + private final int[] screenSize; + private final int[] coordinateSpace; + + public ComputerUseTool(Object instance, Method func, int[] screenSize, int[] virtualScreenSize) { + super(instance, func, /* isLongRunning= */ false); + this.screenSize = screenSize; + this.coordinateSpace = virtualScreenSize; + } + + private int normalize(Object object, String coordinateName, int index) { + if (!(object instanceof Number number)) { + throw new IllegalArgumentException(format("%s coordinate must be numeric", coordinateName)); + } + double coordinate = number.doubleValue(); + int normalized = (int) (coordinate / coordinateSpace[index] * screenSize[index]); + // Clamp to screen bounds + int clamped = Math.max(0, Math.min(normalized, screenSize[index] - 1)); + logger.atDebug().log( + format( + "%s: %.2f, normalized %s: %d, screen %s size: %d, coordinate-space %s size: %d, " + + "clamped %s: %d", + coordinateName, + coordinate, + coordinateName, + normalized, + coordinateName, + screenSize[index], + coordinateName, + coordinateSpace[index], + coordinateName, + clamped)); + return clamped; + } + + private int normalizeX(Object xObj) { + return normalize(xObj, "x", 0); + } + + private int normalizeY(Object yObj) { + return normalize(yObj, "y", 1); + } + + @Override + public Single> runAsync(Map args, ToolContext toolContext) { + Map normalizedArgs = new HashMap<>(args); + + if (args.containsKey("x")) { + normalizedArgs.put("x", normalizeX(args.get("x"))); + } + if (args.containsKey("y")) { + normalizedArgs.put("y", normalizeY(args.get("y"))); + } + if (args.containsKey("destination_x")) { + normalizedArgs.put("destination_x", normalizeX(args.get("destination_x"))); + } + if (args.containsKey("destination_y")) { + normalizedArgs.put("destination_y", normalizeY(args.get("destination_y"))); + } + + return super.runAsync(normalizedArgs, toolContext) + .map( + result -> { + // If the underlying tool method returned a structure containing a "screenshot" field + // (e.g., a ComputerState object), FunctionTool.runAsync will have converted it to a + // Map. This post-processing step transforms the byte array "screenshot" field into + // an "image" map with a mimetype and Base64 encoded data, as expected by some + // consuming systems. + if (result.containsKey("screenshot") && result.get("screenshot") instanceof byte[]) { + byte[] screenshot = (byte[]) result.get("screenshot"); + ImmutableMap imageMap = + ImmutableMap.of( + "mimetype", + "image/png", + "data", + Base64.getEncoder().encodeToString(screenshot)); + Map finalResult = new HashMap<>(result); + finalResult.remove("screenshot"); + finalResult.put("image", imageMap); + return finalResult; + } + return result; + }); + } +} diff --git a/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseToolset.java b/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseToolset.java new file mode 100644 index 000000000..6984f02fd --- /dev/null +++ b/core/src/main/java/com/google/adk/tools/computeruse/ComputerUseToolset.java @@ -0,0 +1,181 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.collect.ImmutableList.toImmutableList; + +import com.google.adk.agents.ReadonlyContext; +import com.google.adk.models.LlmRequest; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.BaseToolset; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.genai.types.ComputerUse; +import com.google.genai.types.Environment; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Tool; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Flowable; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.List; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A toolset that provides computer use capabilities. + * + *

It automatically discovers and wraps methods from a {@link BaseComputer} implementation. + */ +public class ComputerUseToolset implements BaseToolset { + + private static final Logger logger = LoggerFactory.getLogger(ComputerUseToolset.class); + + private static final ImmutableSet EXCLUDED_METHODS = + ImmutableSet.of( + "screenSize", + "environment", + "close", + "initialize", + "currentState", + "getClass", + "equals", + "hashCode", + "toString", + "wait", + "notify", + "notifyAll"); + + private final BaseComputer computer; + private final int[] virtualScreenSize; + private List tools; + private boolean initialized = false; + + public ComputerUseToolset(BaseComputer computer) { + this(computer, new int[] {1000, 1000}); + } + + public ComputerUseToolset(BaseComputer computer, int[] virtualScreenSize) { + this.computer = computer; + this.virtualScreenSize = virtualScreenSize; + } + + private synchronized Completable ensureInitialized() { + if (initialized) { + return Completable.complete(); + } + return computer + .initialize() + .doOnComplete( + () -> { + initialized = true; + }); + } + + @Override + public Flowable getTools(ReadonlyContext readonlyContext) { + return ensureInitialized() + .andThen(computer.screenSize()) + .flatMapPublisher( + actualScreenSize -> { + if (tools == null) { + tools = new ArrayList<>(); + for (Method method : BaseComputer.class.getMethods()) { + if (!EXCLUDED_METHODS.contains(method.getName())) { + tools.add( + new ComputerUseTool(computer, method, actualScreenSize, virtualScreenSize)); + } + } + } + return Flowable.fromIterable(tools); + }); + } + + @Override + public void close() throws Exception { + computer.close().blockingAwait(); + } + + /** Adds computer use configuration to the LLM request. */ + public Completable processLlmRequest( + LlmRequest.Builder llmRequestBuilder, ToolContext toolContext) { + return getTools(null) // Fetch tools to ensure they are added to the list + .toList() + .flatMapCompletable( + tools -> { + return Completable.concat( + tools.stream() + .map(t -> t.processLlmRequest(llmRequestBuilder, toolContext)) + .collect(toImmutableList())) + .andThen( + computer + .environment() + .flatMapCompletable( + env -> { + configureComputerUseIfNeeded(llmRequestBuilder, env); + return Completable.complete(); + })); + }); + } + + /** + * Returns the {@link Environment.Known} enum for the given {@link ComputerEnvironment}. If the + * computer environment is not found or not supported, defaults to {@link + * Environment.Known.ENVIRONMENT_BROWSER}. + * + * @param computerEnvironment The {@link ComputerEnvironment} to convert. + * @return The corresponding {@link Environment.Known} enum. + */ + private static Environment.Known getEnvironment(ComputerEnvironment computerEnvironment) { + try { + return Environment.Known.valueOf(computerEnvironment.name()); + } catch (IllegalArgumentException e) { + return Environment.Known.ENVIRONMENT_BROWSER; + } + } + + /** + * Configures the computer use tool in the LLM request if it is not already configured. + * + * @param computerEnvironment The environment to configure the computer use tool for. + * @param llmRequestBuilder The LLM request builder to add the computer use tool to. + */ + private static void configureComputerUseIfNeeded( + LlmRequest.Builder llmRequestBuilder, ComputerEnvironment computerEnvironment) { + // Get the current config from the LLM request + GenerateContentConfig config = + llmRequestBuilder.config().orElse(GenerateContentConfig.builder().build()); + + // Check if computer use is already configured + if (config.tools().orElse(ImmutableList.of()).stream() + .anyMatch(t -> t.computerUse().isPresent())) { + logger.debug("Computer use already configured"); + return; + } + + // Configure the computer + Environment.Known knownEnv = getEnvironment(computerEnvironment); + Tool computerUseTool = + Tool.builder().computerUse(ComputerUse.builder().environment(knownEnv).build()).build(); + // Add the computer use tool to the list of tools in the config + List currentTools = new ArrayList<>(config.tools().orElse(ImmutableList.of())); + currentTools.add(computerUseTool); + llmRequestBuilder.config(config.toBuilder().tools(ImmutableList.copyOf(currentTools)).build()); + logger.debug("Added computer use tool with environment: {}", knownEnv); + } +} diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java index 45a2fe333..73af9cc6a 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpAsyncToolset.java @@ -170,9 +170,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { .map( tools -> tools.stream() - .filter( - tool -> - isToolSelected(tool, toolFilter, Optional.ofNullable(readonlyContext))) + .filter(tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext)) .toList()) .onErrorResumeNext( err -> { diff --git a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java index 3bf8f39d0..207243ceb 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java +++ b/core/src/main/java/com/google/adk/tools/mcp/McpToolset.java @@ -216,9 +216,7 @@ public Flowable getTools(ReadonlyContext readonlyContext) { new McpTool( tool, this.mcpSession, this.mcpSessionManager, this.objectMapper)) .filter( - tool -> - isToolSelected( - tool, toolFilter, Optional.ofNullable(readonlyContext)))); + tool -> isToolSelected(tool, toolFilter.orElse(null), readonlyContext))); }) .retryWhen( errorObservable -> diff --git a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java index 9995f18b2..c46f6e3a8 100644 --- a/core/src/main/java/com/google/adk/utils/ModelNameUtils.java +++ b/core/src/main/java/com/google/adk/utils/ModelNameUtils.java @@ -16,16 +16,24 @@ package com.google.adk.utils; +import com.google.common.base.Strings; +import java.util.Objects; import java.util.regex.Matcher; import java.util.regex.Pattern; public final class ModelNameUtils { + private static final String GEMINI_PREFIX = "gemini-"; private static final Pattern GEMINI_2_PATTERN = Pattern.compile("^gemini-2\\..*"); + private static final String GEMINI_CLASS = "com.google.adk.models.Gemini"; private static final Pattern PATH_PATTERN = Pattern.compile("^projects/[^/]+/locations/[^/]+/publishers/[^/]+/models/(.+)$"); private static final Pattern APIGEE_PATTERN = Pattern.compile("^apigee/(?:[^/]+/)?(?:[^/]+/)?(.+)$"); + public static boolean isGeminiModel(String modelString) { + return extractModelName(Strings.nullToEmpty(modelString)).startsWith(GEMINI_PREFIX); + } + public static boolean isGemini2Model(String modelString) { if (modelString == null) { return false; @@ -34,6 +42,29 @@ public static boolean isGemini2Model(String modelString) { return GEMINI_2_PATTERN.matcher(modelName).matches(); } + /** + * Checks whether an object is an instance of {@link com.google.adk.models.Gemini}, by searching + * through its class hierarchy for a class whose name equals the hardcoded String name of Gemini + * class. + * + *

This method can be used where the "real" instanceof check is not possible because the Gemini + * type is not known at compile time. + * + * @param o The object to check. + * @return true if object's class is {@link com.google.adk.models.Gemini}, false otherwise. + */ + public static boolean isInstanceOfGemini(Object o) { + if (o == null) { + return false; + } + for (Class clazz = o.getClass(); clazz != null; clazz = clazz.getSuperclass()) { + if (Objects.equals(clazz.getName(), GEMINI_CLASS)) { + return true; + } + } + return false; + } + /** * Extract the actual model name from either simple or path-based format. * diff --git a/core/src/test/java/com/google/adk/agents/CallbacksTest.java b/core/src/test/java/com/google/adk/agents/CallbacksTest.java index 11087e6d6..8325d346e 100644 --- a/core/src/test/java/com/google/adk/agents/CallbacksTest.java +++ b/core/src/test/java/com/google/adk/agents/CallbacksTest.java @@ -1172,10 +1172,51 @@ public Maybe> beforeToolCallback( event, ImmutableMap.of("echo_tool", new TestUtils.FailingEchoTool())) .blockingGet(); - assertThat(getFunctionResponse(functionResponseEvent)).isEqualTo(responseFromAgentCb); } + @Test + public void handleFunctionCalls_withBeforeToolCallback_modifiesArgs() { + ImmutableMap originalArgs = ImmutableMap.of("arg1", "val1"); + ImmutableMap modifiedArgs = ImmutableMap.of("arg1", "val1", "arg2", "val2"); + + Callbacks.BeforeToolCallbackSync cb1 = + (invocationContext, tool, input, toolContext) -> { + input.put("arg2", "val2"); + return Optional.empty(); + }; + + TestUtils.EchoTool echoTool = new TestUtils.EchoTool(); + + InvocationContext invocationContext = + createInvocationContext( + createTestAgentBuilder(createTestLlm(LlmResponse.builder().build())) + .beforeToolCallbackSync(cb1) + .build()); + + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.fromText("..."), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("fc_id") + .name("echo_tool") + .args(originalArgs) + .build()) + .build())) + .build(); + + Event functionResponseEvent = + Functions.handleFunctionCalls( + invocationContext, event, ImmutableMap.of("echo_tool", echoTool)) + .blockingGet(); + + assertThat(getFunctionResponse(functionResponseEvent)).containsExactly("result", modifiedArgs); + } + @Test public void agentRunAsync_withToolCallbacks_inspectsArgsAndReturnsResponse() { TestUtils.EchoTool echoTool = new TestUtils.EchoTool(); diff --git a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java index ce8be8dfb..594e47fd8 100644 --- a/core/src/test/java/com/google/adk/agents/LlmAgentTest.java +++ b/core/src/test/java/com/google/adk/agents/LlmAgentTest.java @@ -376,6 +376,17 @@ public void resolveModel_withModelName_resolvesFromRegistry() { assertThat(resolvedModel.model()).hasValue(testLlm); } + @Test + public void resolveModel_withModel_usesProvidedModel() { + TestLlm testLlm = createTestLlm(LlmResponse.builder().build()); + LlmAgent testAgent = createTestAgent(testLlm); + + Model resolvedModel = testAgent.resolvedModel(); + + assertThat(resolvedModel.model()).hasValue(testLlm); + assertThat(resolvedModel.modelName()).hasValue(testLlm.model()); + } + @Test public void canonicalCallbacks_returnsEmptyListWhenNull() { TestLlm testLlm = createTestLlm(LlmResponse.builder().build()); diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index 94cd399df..28123bab8 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -17,12 +17,14 @@ package com.google.adk.events; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.adk.sessions.State; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; +import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import org.junit.Test; import org.junit.runner.RunWith; @@ -130,4 +132,37 @@ public void jsonSerialization_works() throws Exception { assertThat(deserialized).isEqualTo(eventActions); assertThat(deserialized.deletedArtifactIds()).containsExactly("d1", "d2"); } + + @Test + @SuppressWarnings("unchecked") // the nested map is known to be Map + public void merge_deeplyMergesStateDelta() { + EventActions eventActions1 = EventActions.builder().build(); + eventActions1.stateDelta().put("a", 1); + eventActions1.stateDelta().put("b", ImmutableMap.of("nested1", 10, "nested2", 20)); + eventActions1.stateDelta().put("c", 100); + EventActions eventActions2 = EventActions.builder().build(); + eventActions2.stateDelta().put("a", 2); + eventActions2.stateDelta().put("b", ImmutableMap.of("nested2", 22, "nested3", 30)); + eventActions2.stateDelta().put("d", 200); + + EventActions merged = eventActions1.toBuilder().merge(eventActions2).build(); + + assertThat(merged.stateDelta().keySet()).containsExactly("a", "b", "c", "d"); + assertThat(merged.stateDelta()).containsEntry("a", 2); + assertThat((Map) merged.stateDelta().get("b")) + .containsExactly("nested1", 10, "nested2", 22, "nested3", 30); + assertThat(merged.stateDelta()).containsEntry("c", 100); + assertThat(merged.stateDelta()).containsEntry("d", 200); + } + + @Test + public void merge_failsOnMismatchedKeyTypesNestedInStateDelta() { + EventActions eventActions1 = EventActions.builder().build(); + eventActions1.stateDelta().put("nested", ImmutableMap.of("a", 1)); + EventActions eventActions2 = EventActions.builder().build(); + eventActions2.stateDelta().put("nested", ImmutableMap.of(1, 2)); + + assertThrows( + IllegalArgumentException.class, () -> eventActions1.toBuilder().merge(eventActions2)); + } } diff --git a/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java b/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java index 65364e7b4..6ba2832c0 100644 --- a/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java +++ b/core/src/test/java/com/google/adk/models/ApigeeLlmTest.java @@ -31,6 +31,7 @@ import com.google.genai.types.Part; import io.reactivex.rxjava3.core.Flowable; import java.util.Map; +import java.util.Objects; import org.junit.Before; import org.junit.Rule; import org.junit.Test; @@ -57,11 +58,11 @@ public void checkApiKey() { @Test public void build_withValidModelStrings_succeeds() { String[] validModelStrings = { - "apigee/gemini-1.5-flash", - "apigee/v1/gemini-1.5-flash", - "apigee/vertex_ai/gemini-1.5-flash", - "apigee/gemini/v1/gemini-1.5-flash", - "apigee/vertex_ai/v1beta/gemini-1.5-flash" + "apigee/whatever-model", + "apigee/v1/whatever-model", + "apigee/vertex_ai/whatever-model", + "apigee/gemini/v1/whatever-model", + "apigee/vertex_ai/v1beta/whatever-model" }; for (String modelName : validModelStrings) { @@ -93,18 +94,18 @@ public void build_withInvalidModelStrings_throwsException() { public void generateContent_stripsApigeePrefixAndSendsToDelegate() { when(mockGeminiDelegate.generateContent(any(), anyBoolean())).thenReturn(Flowable.empty()); - ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/gemini-1.5-flash", mockGeminiDelegate); + ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/whatever-model", mockGeminiDelegate); LlmRequest request = LlmRequest.builder() - .model("apigee/gemini/v1/gemini-1.5-flash") + .model("apigee/gemini/v1/whatever-model") .contents(ImmutableList.of(Content.builder().parts(Part.fromText("hi")).build())) .build(); llm.generateContent(request, true).test().assertNoErrors(); ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(LlmRequest.class); verify(mockGeminiDelegate).generateContent(requestCaptor.capture(), eq(true)); - assertThat(requestCaptor.getValue().model()).hasValue("gemini-1.5-flash"); + assertThat(requestCaptor.getValue().model()).hasValue("whatever-model"); } // Add a test to verify the vertexAI flag is set correctly. @@ -112,7 +113,7 @@ public void generateContent_stripsApigeePrefixAndSendsToDelegate() { public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() { ApigeeLlm llm = ApigeeLlm.builder() - .modelName("apigee/vertex_ai/gemini-1.5-flash") + .modelName("apigee/vertex_ai/whatever-model") .proxyUrl(PROXY_URL) .build(); assertThat(llm.getApiClient().vertexAI()).isTrue(); @@ -122,8 +123,10 @@ public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() { public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() { ApigeeLlm llm = - ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").proxyUrl(PROXY_URL).build(); - if (System.getenv("GOOGLE_GENAI_USE_VERTEXAI") != null) { + ApigeeLlm.builder().modelName("apigee/whatever-model").proxyUrl(PROXY_URL).build(); + String useVertexAi = System.getenv("GOOGLE_GENAI_USE_VERTEXAI"); + + if (Objects.equals(useVertexAi, "true") || Objects.equals(useVertexAi, "1")) { assertThat(llm.getApiClient().vertexAI()).isTrue(); } else { assertThat(llm.getApiClient().vertexAI()).isFalse(); @@ -133,7 +136,7 @@ public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() { @Test public void generateContent_setsVertexAiFlagCorrectly_withGemini() { ApigeeLlm llm = - ApigeeLlm.builder().modelName("apigee/gemini/gemini-1.5-flash").proxyUrl(PROXY_URL).build(); + ApigeeLlm.builder().modelName("apigee/gemini/whatever-model").proxyUrl(PROXY_URL).build(); assertThat(llm.getApiClient().vertexAI()).isFalse(); } @@ -142,11 +145,11 @@ public void generateContent_setsVertexAiFlagCorrectly_withGemini() { public void generateContent_setsApiVersionCorrectly() { ImmutableMap modelToApiVersion = ImmutableMap.of( - "apigee/gemini-1.5-flash", "", - "apigee/v1/gemini-1.5-flash", "v1", - "apigee/vertex_ai/gemini-1.5-flash", "", - "apigee/gemini/v1/gemini-1.5-flash", "v1", - "apigee/vertex_ai/v1beta/gemini-1.5-flash", "v1beta"); + "apigee/whatever-model", "", + "apigee/v1/whatever-model", "v1", + "apigee/vertex_ai/whatever-model", "", + "apigee/gemini/v1/whatever-model", "v1", + "apigee/vertex_ai/v1beta/whatever-model", "v1beta"); for (Map.Entry entry : modelToApiVersion.entrySet()) { String modelName = entry.getKey(); @@ -165,7 +168,7 @@ public void build_withCustomHeaders_setsHeadersInHttpOptions() { ImmutableMap customHeaders = ImmutableMap.of("X-Test-Header", "TestValue"); ApigeeLlm llm = ApigeeLlm.builder() - .modelName("apigee/gemini-1.5-flash") + .modelName("apigee/whatever-model") .proxyUrl(PROXY_URL) .customHeaders(customHeaders) .build(); @@ -192,14 +195,14 @@ public void build_withTrailingSlashInModel_parsesVersionAndModelId() { public void build_withoutProxyUrl_readsFromEnvironment() { String envProxyUrl = System.getenv("APIGEE_PROXY_URL"); if (envProxyUrl != null) { - ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build(); + ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/whatever-model").build(); assertThat(llm.getHttpOptions().baseUrl()).hasValue(envProxyUrl); } else { assertThrows( IllegalArgumentException.class, - () -> ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build()); + () -> ApigeeLlm.builder().modelName("apigee/whatever-model").build()); ApigeeLlm llm = - ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/gemini-1.5-flash").build(); + ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/whatever-model").build(); assertThat(llm.getHttpOptions().baseUrl()).hasValue(PROXY_URL); } } diff --git a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java index 6223dd2f0..41e156ffd 100644 --- a/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/InMemorySessionServiceTest.java @@ -20,6 +20,7 @@ import com.google.adk.events.Event; import com.google.adk.events.EventActions; import io.reactivex.rxjava3.core.Single; +import java.util.HashMap; import java.util.Optional; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -84,7 +85,7 @@ public void lifecycle_listSessions() { Session session = sessionService - .createSession("app-name", "user-id", new ConcurrentHashMap<>(), "session-1") + .createSession("app-name", "user-id", new HashMap<>(), "session-1") .blockingGet(); ConcurrentMap stateDelta = new ConcurrentHashMap<>(); @@ -130,9 +131,7 @@ public void lifecycle_deleteSession() { public void appendEvent_updatesSessionState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); ConcurrentMap stateDelta = new ConcurrentHashMap<>(); stateDelta.put("sessionKey", "sessionValue"); @@ -167,9 +166,7 @@ public void appendEvent_updatesSessionState() { public void appendEvent_removesState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); ConcurrentMap stateDeltaAdd = new ConcurrentHashMap<>(); stateDeltaAdd.put("sessionKey", "sessionValue"); @@ -221,9 +218,7 @@ public void appendEvent_removesState() { public void sequentialAgents_shareTempState() { InMemorySessionService sessionService = new InMemorySessionService(); Session session = - sessionService - .createSession("app", "user", new ConcurrentHashMap<>(), "session1") - .blockingGet(); + sessionService.createSession("app", "user", new HashMap<>(), "session1").blockingGet(); // Agent 1 writes to temp state ConcurrentMap stateDelta1 = new ConcurrentHashMap<>(); diff --git a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java index 36eab1d16..def4faf4c 100644 --- a/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java +++ b/core/src/test/java/com/google/adk/sessions/VertexAiSessionServiceTest.java @@ -167,8 +167,7 @@ public void setUp() throws Exception { @Test public void createSession_success() throws Exception { - ConcurrentMap sessionStateMap = - new ConcurrentHashMap<>(ImmutableMap.of("new_key", "new_value")); + Map sessionStateMap = new HashMap<>(ImmutableMap.of("new_key", "new_value")); Single sessionSingle = vertexAiSessionService.createSession("123", "test_user", sessionStateMap, null); Session createdSession = sessionSingle.blockingGet(); @@ -190,8 +189,7 @@ public void createSession_success() throws Exception { @Test public void createSession_getSession_success() throws Exception { - ConcurrentMap sessionStateMap = - new ConcurrentHashMap<>(ImmutableMap.of("new_key", "new_value")); + Map sessionStateMap = new HashMap<>(ImmutableMap.of("new_key", "new_value")); Single sessionSingle = vertexAiSessionService.createSession("789", "test_user", sessionStateMap, null); Session createdSession = sessionSingle.blockingGet(); @@ -252,8 +250,7 @@ public void getAndDeleteSession_success() throws Exception { @Test public void createSessionAndGetSession_success() throws Exception { - ConcurrentMap sessionStateMap = - new ConcurrentHashMap<>(ImmutableMap.of("key", "value")); + Map sessionStateMap = new HashMap<>(ImmutableMap.of("key", "value")); Single sessionSingle = vertexAiSessionService.createSession("123", "user", sessionStateMap, null); Session createdSession = sessionSingle.blockingGet(); @@ -341,8 +338,8 @@ public void listEmptySession_success() { @Test public void appendEvent_withStateRemoved_updatesSessionState() { String userId = "userB"; - ConcurrentMap initialState = - new ConcurrentHashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2")); + Map initialState = + new HashMap<>(ImmutableMap.of("key1", "value1", "key2", "value2")); Session session = vertexAiSessionService.createSession("987", userId, initialState, null).blockingGet(); diff --git a/core/src/test/java/com/google/adk/tools/BaseToolTest.java b/core/src/test/java/com/google/adk/tools/BaseToolTest.java index dde1d73ea..2a07e7a44 100644 --- a/core/src/test/java/com/google/adk/tools/BaseToolTest.java +++ b/core/src/test/java/com/google/adk/tools/BaseToolTest.java @@ -2,12 +2,16 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.Gemini; import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.InMemorySessionService; import com.google.common.collect.ImmutableList; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.GoogleMaps; import com.google.genai.types.GoogleSearch; -import com.google.genai.types.GoogleSearchRetrieval; import com.google.genai.types.Tool; import com.google.genai.types.ToolCodeExecution; import com.google.genai.types.UrlContext; @@ -142,25 +146,6 @@ public void processLlmRequestWithGoogleSearchToolAddsToolToConfig() { Tool.builder().googleSearch(GoogleSearch.builder().build()).build()); } - @Test - public void processLlmRequestWithGoogleSearchRetrievalToolAddsToolToConfig() { - GoogleSearchTool googleSearchTool = new GoogleSearchTool(); - LlmRequest llmRequest = - LlmRequest.builder() - .config(GenerateContentConfig.builder().build()) - .model("gemini-1") - .build(); - LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); - Completable unused = - googleSearchTool.processLlmRequest(llmRequestBuilder, /* toolContext= */ null); - LlmRequest updatedLlmRequest = llmRequestBuilder.build(); - assertThat(updatedLlmRequest.config()).isPresent(); - assertThat(updatedLlmRequest.config().get().tools()).isPresent(); - assertThat(updatedLlmRequest.config().get().tools().get()) - .containsExactly( - Tool.builder().googleSearchRetrieval(GoogleSearchRetrieval.builder().build()).build()); - } - @Test public void processLlmRequestWithUrlContextToolAddsToolToConfig() { FunctionDeclaration functionDeclaration = @@ -190,13 +175,27 @@ public void processLlmRequestWithUrlContextToolAddsToolToConfig() { Tool.builder().urlContext(UrlContext.builder().build()).build()); } + private static InvocationContext.Builder testInvocationContext() { + InvocationContext.Builder builder = InvocationContext.builder(); + builder.agent(testAgent().build()); + InMemorySessionService inMemorySessionService = new InMemorySessionService(); + builder.sessionService(inMemorySessionService); + builder.session(inMemorySessionService.createSession("test-app", "test-user-id").blockingGet()); + return builder; + } + + private static LlmAgent.Builder testAgent() { + return LlmAgent.builder().name("test-agent"); + } + @Test - public void processLlmRequestWithBuiltInCodeExecutionToolAddsToolToConfig() { + public void + processLlmRequestWithBuiltInCodeExecutionToolAndNonGeminiModelAndNullContextAddsToolToConfig() { BuiltInCodeExecutionTool builtInCodeExecutionTool = new BuiltInCodeExecutionTool(); LlmRequest llmRequest = LlmRequest.builder() .config(GenerateContentConfig.builder().build()) - .model("gemini-2") + .model("text-bison") .build(); LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); Completable unused = @@ -207,4 +206,45 @@ public void processLlmRequestWithBuiltInCodeExecutionToolAddsToolToConfig() { assertThat(updatedLlmRequest.config().get().tools().get()) .containsExactly(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build()); } + + @Test + public void processLlmRequestWithBuiltInCodeExecutionToolAndGemini2ModelAddsToolToConfig() { + BuiltInCodeExecutionTool builtInCodeExecutionTool = new BuiltInCodeExecutionTool(); + LlmRequest llmRequest = + LlmRequest.builder() + .config(GenerateContentConfig.builder().build()) + .model("gemini-2") + .build(); + LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); + ToolContext toolContext = + ToolContext.builder( + testInvocationContext() + .agent(testAgent().model(new Gemini("gemini-2", "")).build()) + .build()) + .build(); + Completable unused = builtInCodeExecutionTool.processLlmRequest(llmRequestBuilder, toolContext); + LlmRequest updatedLlmRequest = llmRequestBuilder.build(); + assertThat(updatedLlmRequest.config()).isPresent(); + assertThat(updatedLlmRequest.config().get().tools()).isPresent(); + assertThat(updatedLlmRequest.config().get().tools().get()) + .containsExactly(Tool.builder().codeExecution(ToolCodeExecution.builder().build()).build()); + } + + @Test + public void processLlmRequestWithGoogleMapsToolAddsToolToConfig() { + GoogleMapsTool googleMapsTool = new GoogleMapsTool(); + LlmRequest llmRequest = + LlmRequest.builder() + .config(GenerateContentConfig.builder().build()) + .model("gemini-2") + .build(); + LlmRequest.Builder llmRequestBuilder = llmRequest.toBuilder(); + Completable unused = + googleMapsTool.processLlmRequest(llmRequestBuilder, /* toolContext= */ null); + LlmRequest updatedLlmRequest = llmRequestBuilder.build(); + assertThat(updatedLlmRequest.config()).isPresent(); + assertThat(updatedLlmRequest.config().get().tools()).isPresent(); + assertThat(updatedLlmRequest.config().get().tools().get()) + .containsExactly(Tool.builder().googleMaps(GoogleMaps.builder().build()).build()); + } } diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerEnvironmentTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerEnvironmentTest.java new file mode 100644 index 000000000..ed22819ec --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerEnvironmentTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ComputerEnvironment}. */ +@RunWith(JUnit4.class) +public final class ComputerEnvironmentTest { + + @Test + public void testEnumValues() { + assertThat(ComputerEnvironment.values()) + .asList() + .containsAtLeast( + ComputerEnvironment.ENVIRONMENT_UNSPECIFIED, ComputerEnvironment.ENVIRONMENT_BROWSER); + } +} diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerStateTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerStateTest.java new file mode 100644 index 000000000..736f9be0e --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerStateTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.truth.Truth.assertThat; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ComputerState}. */ +@RunWith(JUnit4.class) +public final class ComputerStateTest { + + @Test + public void testBuilder() { + byte[] screenshot = new byte[] {1, 2, 3}; + String url = "https://google.com"; + ComputerState state = ComputerState.builder().screenshot(screenshot).url(url).build(); + + assertThat(state.screenshot()).isEqualTo(screenshot); + assertThat(state.url()).hasValue(url); + } + + @Test + public void testBuilder_noUrl() { + byte[] screenshot = new byte[] {1, 2, 3}; + ComputerState state = ComputerState.builder().screenshot(screenshot).build(); + + assertThat(state.screenshot()).isEqualTo(screenshot); + assertThat(state.url()).isEmpty(); + } + + @Test + public void testEqualsAndHashCode() { + byte[] screenshot1 = new byte[] {1, 2, 3}; + byte[] screenshot2 = new byte[] {1, 2, 3}; + byte[] screenshot3 = new byte[] {4, 5, 6}; + + ComputerState state1 = ComputerState.builder().screenshot(screenshot1).url("url1").build(); + ComputerState state2 = ComputerState.builder().screenshot(screenshot2).url("url1").build(); + ComputerState state3 = ComputerState.builder().screenshot(screenshot3).url("url1").build(); + ComputerState state4 = ComputerState.builder().screenshot(screenshot1).url("url2").build(); + + assertThat(state1).isEqualTo(state2); + assertThat(state1.hashCode()).isEqualTo(state2.hashCode()); + + assertThat(state1).isNotEqualTo(state3); + assertThat(state1).isNotEqualTo(state4); + } + + @Test + public void testScreenshotImmutability() { + byte[] screenshot = new byte[] {1, 2, 3}; + ComputerState state = ComputerState.builder().screenshot(screenshot).build(); + + // Modify original array + screenshot[0] = 9; + assertThat(state.screenshot()[0]).isEqualTo(1); + + // Modify returned array + state.screenshot()[0] = 9; + assertThat(state.screenshot()[0]).isEqualTo(1); + } +} diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java new file mode 100644 index 000000000..20fb146cf --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolTest.java @@ -0,0 +1,258 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; +import com.google.adk.tools.Annotations.Schema; +import com.google.adk.tools.ToolContext; +import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Single; +import java.lang.reflect.Method; +import java.util.Base64; +import java.util.Map; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ComputerUseTool}. */ +@RunWith(JUnit4.class) +public final class ComputerUseToolTest { + + private LlmAgent agent; + private InMemorySessionService sessionService; + private ToolContext toolContext; + private ComputerMock computerMock; + + @Before + public void setUp() { + agent = LlmAgent.builder().name("test-agent").build(); + sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .agent(agent) + .session(session) + .sessionService(sessionService) + .invocationId("invocation-id") + .build(); + toolContext = ToolContext.builder(invocationContext).functionCallId("functionCallId").build(); + computerMock = new ComputerMock(); + } + + @Test + public void testNormalizeX() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + assertThat(tool.runAsync(ImmutableMap.of("x", 0, "y", 0), toolContext).blockingGet()) + .isNotNull(); + assertThat(computerMock.lastX).isEqualTo(0); + + assertThat(tool.runAsync(ImmutableMap.of("x", 500, "y", 300), toolContext).blockingGet()) + .isNotNull(); + assertThat(computerMock.lastX).isEqualTo(960); // 500/1000 * 1920 + + assertThat(tool.runAsync(ImmutableMap.of("x", 1000, "y", 300), toolContext).blockingGet()) + .isNotNull(); + assertThat(computerMock.lastX).isEqualTo(1919); // Clamped + } + + @Test + public void testNormalizeY() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + assertThat(tool.runAsync(ImmutableMap.of("x", 0, "y", 500), toolContext).blockingGet()) + .isNotNull(); + assertThat(computerMock.lastY).isEqualTo(540); // 500/1000 * 1080 + } + + @Test + public void testNormalizeWithCustomVirtualScreenSize() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {2000, 2000}); + + assertThat(tool.runAsync(ImmutableMap.of("x", 1000, "y", 1000), toolContext).blockingGet()) + .isNotNull(); + assertThat(computerMock.lastX).isEqualTo(960); // 1000/2000 * 1920 + assertThat(computerMock.lastY).isEqualTo(540); // 1000/2000 * 1080 + } + + @Test + public void testNormalizeDragAndDrop() throws NoSuchMethodException { + Method method = + ComputerMock.class.getMethod("dragAndDrop", int.class, int.class, int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + Map result = + tool.runAsync( + ImmutableMap.of("x", 100, "y", 200, "destination_x", 800, "destination_y", 600), + toolContext) + .blockingGet(); + assertThat(result).isNotNull(); + + assertThat(computerMock.lastX).isEqualTo(192); + assertThat(computerMock.lastY).isEqualTo(216); + assertThat(computerMock.lastDestX).isEqualTo(1536); + assertThat(computerMock.lastDestY).isEqualTo(648); + } + + @Test + public void testResultFormatting() throws NoSuchMethodException { + byte[] screenshot = new byte[] {1, 2, 3}; + computerMock.nextState = + ComputerState.builder() + .screenshot(screenshot) + .url(Optional.of("https://example.com")) + .build(); + + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + Map result = + tool.runAsync(ImmutableMap.of("x", 500, "y", 500), toolContext).blockingGet(); + assertThat(result).containsKey("image"); + Object imageData = result.get("image"); + assertThat(imageData).isInstanceOf(Map.class); + ((Map) imageData) + .forEach( + (key, value) -> { + assertThat(key).isInstanceOf(String.class); + assertThat(value).isInstanceOf(String.class); + }); + @SuppressWarnings("unchecked") // The types of the key and value are checked above. + Map imageMap = (Map) imageData; + assertThat(imageMap.get("mimetype")).isEqualTo("image/png"); + assertThat(imageMap.get("data")).isEqualTo(Base64.getEncoder().encodeToString(screenshot)); + assertThat(result.get("url")).isEqualTo("https://example.com"); + assertThat(result).containsKey("image"); + assertThat(result).doesNotContainKey("screenshot"); + } + + @Test + public void testResultFormatting_noScreenshot() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("noScreenshot"); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result).doesNotContainKey("image"); + assertThat(result.get("url")).isEqualTo("https://example.com"); + } + + @Test + public void testResultFormatting_nonByteArrayScreenshot() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("nonByteArrayScreenshot"); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + Map result = tool.runAsync(ImmutableMap.of(), toolContext).blockingGet(); + assertThat(result).doesNotContainKey("image"); + assertThat(result.get("screenshot")).isEqualTo("not-a-byte-array"); + } + + @Test + public void testNormalizeWithInvalidInputs() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + assertThrows( + IllegalArgumentException.class, + () -> tool.runAsync(ImmutableMap.of("x", "invalid", "y", 500), toolContext).blockingGet()); + } + + @Test + public void testRunAsyncWithNoCoordinates() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + // Arguments without x, y, etc. should be passed as is. + ImmutableMap args = ImmutableMap.of("other", "value"); + var unused = tool.runAsync(args, toolContext).blockingGet(); + assertThat(computerMock.lastX).isEqualTo(0); + assertThat(computerMock.lastY).isEqualTo(0); + } + + @Test + public void testCoordinateClamping() throws NoSuchMethodException { + Method method = ComputerMock.class.getMethod("clickAt", int.class, int.class); + ComputerUseTool tool = + new ComputerUseTool(computerMock, method, new int[] {1920, 1080}, new int[] {1000, 1000}); + + // Test clamping to 0 + var unused1 = tool.runAsync(ImmutableMap.of("x", -100, "y", -50), toolContext).blockingGet(); + assertThat(computerMock.lastX).isEqualTo(0); + assertThat(computerMock.lastY).isEqualTo(0); + + // Test clamping to max + var unused2 = tool.runAsync(ImmutableMap.of("x", 2000, "y", 1500), toolContext).blockingGet(); + assertThat(computerMock.lastX).isEqualTo(1919); + assertThat(computerMock.lastY).isEqualTo(1079); + } + + /** A mock class for Computer actions. */ + public static class ComputerMock { + public int lastX; + public int lastY; + public int lastDestX; + public int lastDestY; + public ComputerState nextState = + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build(); + + public Single clickAt(@Schema(name = "x") int x, @Schema(name = "y") int y) { + this.lastX = x; + this.lastY = y; + return Single.just(nextState); + } + + public Single dragAndDrop( + @Schema(name = "x") int x, + @Schema(name = "y") int y, + @Schema(name = "destination_x") int destinationX, + @Schema(name = "destination_y") int destinationY) { + this.lastX = x; + this.lastY = y; + this.lastDestX = destinationX; + this.lastDestY = destinationY; + return Single.just(nextState); + } + + public Single> noScreenshot() { + return Single.just(ImmutableMap.of("url", "https://example.com")); + } + + public Single> nonByteArrayScreenshot() { + return Single.just(ImmutableMap.of("screenshot", "not-a-byte-array")); + } + } +} diff --git a/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java new file mode 100644 index 000000000..1ed49419e --- /dev/null +++ b/core/src/test/java/com/google/adk/tools/computeruse/ComputerUseToolsetTest.java @@ -0,0 +1,264 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.tools.computeruse; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.LlmAgent; +import com.google.adk.models.LlmRequest; +import com.google.adk.sessions.InMemorySessionService; +import com.google.adk.sessions.Session; +import com.google.adk.tools.BaseTool; +import com.google.adk.tools.ToolContext; +import com.google.genai.types.Environment; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Tool; +import io.reactivex.rxjava3.core.Completable; +import io.reactivex.rxjava3.core.Single; +import java.time.Duration; +import java.util.List; +import java.util.Optional; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link ComputerUseToolset}. */ +@RunWith(JUnit4.class) +public final class ComputerUseToolsetTest { + + private LlmAgent agent; + private InMemorySessionService sessionService; + private ToolContext toolContext; + private MockComputer mockComputer; + private ComputerUseToolset toolset; + + @Before + public void setUp() { + agent = LlmAgent.builder().name("test-agent").build(); + sessionService = new InMemorySessionService(); + Session session = + sessionService.createSession("test-app", "test-user", null, "test-session").blockingGet(); + InvocationContext invocationContext = + InvocationContext.builder() + .agent(agent) + .session(session) + .sessionService(sessionService) + .invocationId("invocation-id") + .build(); + toolContext = ToolContext.builder(invocationContext).functionCallId("functionCallId").build(); + + mockComputer = new MockComputer(); + toolset = new ComputerUseToolset(mockComputer); + } + + @Test + public void testGetTools() { + List tools = toolset.getTools(null).toList().blockingGet(); + + assertThat(mockComputer.initializeCallCount).isEqualTo(1); + assertThat(tools).isNotEmpty(); + + // Verify method filtering + assertThat(tools.stream().anyMatch(t -> t.name().equals("clickAt"))).isTrue(); + assertThat(tools.stream().noneMatch(t -> t.name().equals("screenSize"))).isTrue(); + assertThat(tools.stream().noneMatch(t -> t.name().equals("environment"))).isTrue(); + } + + @Test + public void testEnsureInitializedOnlyCalledOnce() { + var unused1 = toolset.getTools(null).toList().blockingGet(); + var unused2 = toolset.getTools(null).toList().blockingGet(); + + assertThat(mockComputer.initializeCallCount).isEqualTo(1); + } + + @Test + public void testGetTools_cachesTools() { + List tools1 = toolset.getTools(null).toList().blockingGet(); + List tools2 = toolset.getTools(null).toList().blockingGet(); + + assertThat(tools1).hasSize(tools2.size()); + for (int i = 0; i < tools1.size(); i++) { + assertThat(tools1.get(i)).isSameInstanceAs(tools2.get(i)); + } + } + + @Test + public void testProcessLlmRequest() { + LlmRequest.Builder builder = + LlmRequest.builder().model("test-model").config(GenerateContentConfig.builder().build()); + + toolset.processLlmRequest(builder, toolContext).blockingAwait(); + + LlmRequest request = builder.build(); + assertThat(request.config()).isPresent(); + GenerateContentConfig config = request.config().get(); + + assertThat(config.tools()).isPresent(); + List tools = config.tools().get(); + + // Find the computer use tool + Optional computerUseTool = + tools.stream().filter(t -> t.computerUse().isPresent()).findFirst(); + assertThat(computerUseTool).isPresent(); + assertThat(computerUseTool.get().computerUse().get().environment().get().knownEnum()) + .isEqualTo(Environment.Known.ENVIRONMENT_BROWSER); + + // Verify computer actions were added as function declarations + Optional functionTool = + tools.stream().filter(t -> t.functionDeclarations().isPresent()).findFirst(); + assertThat(functionTool).isPresent(); + assertThat( + functionTool.get().functionDeclarations().get().stream() + .anyMatch(fd -> fd.name().orElse("").equals("clickAt"))) + .isTrue(); + } + + @Test + public void testProcessLlmRequest_withComputerError() { + mockComputer.nextError = new RuntimeException("Computer failure"); + LlmRequest.Builder builder = + LlmRequest.builder().model("test-model").config(GenerateContentConfig.builder().build()); + + assertThrows( + RuntimeException.class, + () -> toolset.processLlmRequest(builder, toolContext).blockingAwait()); + } + + private static class MockComputer implements BaseComputer { + int initializeCallCount = 0; + Throwable nextError = null; + + @Override + public Completable initialize() { + if (nextError != null) { + return Completable.error(nextError); + } + this.initializeCallCount++; + return Completable.complete(); + } + + @Override + public Single screenSize() { + if (nextError != null) { + return Single.error(nextError); + } + return Single.just(new int[] {1920, 1080}); + } + + @Override + public Single environment() { + if (nextError != null) { + return Single.error(nextError); + } + return Single.just(ComputerEnvironment.ENVIRONMENT_BROWSER); + } + + @Override + public Single openWebBrowser() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single clickAt(int x, int y) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single hoverAt(int x, int y) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single typeTextAt( + int x, int y, String text, Boolean pressEnter, Boolean clearBeforeTyping) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single scrollDocument(String direction) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single scrollAt(int x, int y, String direction, int magnitude) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single wait(Duration duration) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single goBack() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single goForward() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single search() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single navigate(String url) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.of(url)).build()); + } + + @Override + public Single keyCombination(List keys) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single dragAndDrop(int x, int y, int destinationX, int destinationY) { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Single currentState() { + return Single.just( + ComputerState.builder().screenshot(new byte[0]).url(Optional.empty()).build()); + } + + @Override + public Completable close() { + return Completable.complete(); + } + } +} diff --git a/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java b/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java index 37853c477..20dda7034 100644 --- a/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java +++ b/core/src/test/java/com/google/adk/utils/ModelNameUtilsTest.java @@ -2,6 +2,7 @@ import static com.google.common.truth.Truth.assertThat; +import com.google.adk.models.Gemini; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -69,4 +70,103 @@ public void isGemini2Model_withApigeeProviderV1BetaGemini2Model_returnsTrue() { public void isGemini2Model_withNullModel_returnsFalse() { assertThat(ModelNameUtils.isGemini2Model(null)).isFalse(); } + + @Test + public void isGeminiModel_withGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withNonGeminiModel_returnsFalse() { + assertThat(ModelNameUtils.isGeminiModel("text-bison")).isFalse(); + } + + @Test + public void isGeminiModel_withPathBasedGeminiModel_returnsTrue() { + assertThat( + ModelNameUtils.isGeminiModel( + "projects/test-project/locations/us-central1/publishers/google/models/gemini-1.5-pro")) + .isTrue(); + } + + @Test + public void isGeminiModel_withPathBasedNonGeminiModel_returnsFalse() { + assertThat( + ModelNameUtils.isGeminiModel( + "projects/test-project/locations/us-central1/publishers/google/models/text-bison")) + .isFalse(); + } + + @Test + public void isGeminiModel_withApigeeGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeV1GeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/v1/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeProviderGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/gemini/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeProviderVertexGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/vertex_ai/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeProviderV1GeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/gemini/v1/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withApigeeProviderV1BetaGeminiModel_returnsTrue() { + assertThat(ModelNameUtils.isGeminiModel("apigee/vertex_ai/v1beta/gemini-1.5-flash")).isTrue(); + } + + @Test + public void isGeminiModel_withNullModel_returnsFalse() { + assertThat(ModelNameUtils.isGeminiModel(null)).isFalse(); + } + + @Test + public void isGeminiModel_withEmptyModel_returnsFalse() { + assertThat(ModelNameUtils.isGeminiModel("")).isFalse(); + } + + @Test + public void isInstanceOfGemini_withGeminiInstance_returnsTrue() { + assertThat(ModelNameUtils.isInstanceOfGemini(new Gemini("", ""))).isTrue(); + } + + @Test + public void isInstanceOfGemini_withNonGeminiInstance_returnsFalse() { + assertThat(ModelNameUtils.isInstanceOfGemini(new Object())).isFalse(); + } + + @Test + public void isInstanceOfGemini_withNullInstance_returnsFalse() { + assertThat(ModelNameUtils.isInstanceOfGemini(null)).isFalse(); + } + + private static class GeminiSubclass extends Gemini { + GeminiSubclass() { + super("test-model", "test-api-key"); + } + } + + private static class GeminiSubclassSubclass extends GeminiSubclass {} + + @Test + public void isInstanceOfGemini_withGeminiSubclassInstance_returnsTrue() { + assertThat(ModelNameUtils.isInstanceOfGemini(new GeminiSubclass())).isTrue(); + } + + @Test + public void isInstanceOfGemini_withSubclassOfGeminiSubclassInstance_returnsTrue() { + assertThat(ModelNameUtils.isInstanceOfGemini(new GeminiSubclassSubclass())).isTrue(); + } } diff --git a/dev/pom.xml b/dev/pom.xml index 73b513040..cd8d0c80a 100644 --- a/dev/pom.xml +++ b/dev/pom.xml @@ -18,7 +18,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 google-adk-dev diff --git a/maven_plugin/examples/custom_tools/pom.xml b/maven_plugin/examples/custom_tools/pom.xml index 66f65c65f..22a3b353a 100644 --- a/maven_plugin/examples/custom_tools/pom.xml +++ b/maven_plugin/examples/custom_tools/pom.xml @@ -4,7 +4,7 @@ com.example custom-tools-example - 0.6.0 + 0.7.0 jar ADK Custom Tools Example diff --git a/maven_plugin/examples/simple-agent/pom.xml b/maven_plugin/examples/simple-agent/pom.xml index 365ebdd55..d1b62b667 100644 --- a/maven_plugin/examples/simple-agent/pom.xml +++ b/maven_plugin/examples/simple-agent/pom.xml @@ -4,7 +4,7 @@ com.example simple-adk-agent - 0.6.0 + 0.7.0 jar Simple ADK Agent Example diff --git a/maven_plugin/pom.xml b/maven_plugin/pom.xml index 58c1a6c8b..2800041e6 100644 --- a/maven_plugin/pom.xml +++ b/maven_plugin/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 ../pom.xml diff --git a/pom.xml b/pom.xml index e4f7f6022..3ef6fd8f1 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 pom Google Agent Development Kit Maven Parent POM @@ -42,13 +42,13 @@ ${java.version} UTF-8 - 1.11.0 + 1.11.1 3.4.1 - 1.49.0 + 1.59.0 0.14.0 - 2.38.0 - 1.32.0 - 4.32.0 + 2.47.0 + 1.41.0 + 4.33.5 5.11.4 5.20.0 1.6.0 @@ -58,17 +58,17 @@ 0.18.1 3.41.0 3.9.0 - 1.8.0 + 1.11.0 2.0.17 - 1.4.4 + 1.4.5 1.0.0 3.1.5 3.7.0 2.35.1 - 3.27.3 + 3.27.7 1.4.0 3.9.0 - 5.4.3 + 5.6 @@ -77,7 +77,7 @@ com.google.cloud libraries-bom - 26.53.0 + 26.76.0 pom import diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index fa84cb88e..c7edd9b37 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 ../../pom.xml diff --git a/tutorials/live-audio-single-agent/pom.xml b/tutorials/live-audio-single-agent/pom.xml index 944784aee..4efba4525 100644 --- a/tutorials/live-audio-single-agent/pom.xml +++ b/tutorials/live-audio-single-agent/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 0.6.0 + 0.7.0 ../../pom.xml