diff --git a/.release-please-manifest.json b/.release-please-manifest.json index 09a252282..6a787c5c9 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "1.2.0" + ".": "1.3.0" } diff --git a/CHANGELOG.md b/CHANGELOG.md index c9ae0d827..2d258953d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,22 @@ # Changelog +## [1.3.0](https://github.com/google/adk-java/compare/v1.2.0...v1.3.0) (2026-05-13) + + +### Features + +* Add ChatCompletionsHTTPClient and support for non-streaming requests ([9529c1a](https://github.com/google/adk-java/commit/9529c1aeecb324e1c00c6bd105df2a0e9f67ed26)) +* Add conversion from LlmRequest to ChatCompletionsRequest ([d37f6ee](https://github.com/google/adk-java/commit/d37f6ee6d8ec036154593b734f1a3b080847cfea)) +* Add SkillSource interface and implementations for loading skills ([509c4aa](https://github.com/google/adk-java/commit/509c4aa75fdc752c2758a1761cbd8946075b310c)) +* Add support for refusal content using "[[REFUSAL]]:" prefix ([e9184c9](https://github.com/google/adk-java/commit/e9184c9846d97f65907667aa2a6bbac1f65fed64)) +* Refactor BigQueryAgentAnalyticsPlugin for async in preparation for GCS offloading ([d837ef0](https://github.com/google/adk-java/commit/d837ef0164cedd284af6caee84911569109ab7e3)) + + +### Bug Fixes + +* Account for nulls in EventActions and State ([582cf7c](https://github.com/google/adk-java/commit/582cf7c2b6534afaf5edfa501391191478d8d8ea)) +* upgrade Mockito and JaCoCo for Java 25 compatibility ([8574fc5](https://github.com/google/adk-java/commit/8574fc5bb6ac7edae99306b06c0a610f7da60048)) + ## [1.2.0](https://github.com/google/adk-java/compare/v1.1.0...v1.2.0) (2026-04-24) diff --git a/README.md b/README.md index 107a6967b..9613078dd 100644 --- a/README.md +++ b/README.md @@ -50,13 +50,13 @@ If you're using Maven, add the following to your dependencies: com.google.adk google-adk - 1.2.0 + 1.3.0 com.google.adk google-adk-dev - 1.2.0 + 1.3.0 ``` diff --git a/a2a/pom.xml b/a2a/pom.xml index 1d5cf5a90..70e24e023 100644 --- a/a2a/pom.xml +++ b/a2a/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 google-adk-a2a diff --git a/contrib/firestore-session-service/pom.xml b/contrib/firestore-session-service/pom.xml index 5864a6d4f..a801efc93 100644 --- a/contrib/firestore-session-service/pom.xml +++ b/contrib/firestore-session-service/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 ../../pom.xml diff --git a/contrib/langchain4j/pom.xml b/contrib/langchain4j/pom.xml index d5cf4dc63..9c998452c 100644 --- a/contrib/langchain4j/pom.xml +++ b/contrib/langchain4j/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 ../../pom.xml diff --git a/contrib/planners/pom.xml b/contrib/planners/pom.xml index 50cb91bc9..86cb0f43c 100644 --- a/contrib/planners/pom.xml +++ b/contrib/planners/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 ../../pom.xml diff --git a/contrib/samples/a2a_basic/pom.xml b/contrib/samples/a2a_basic/pom.xml index e12ca09a1..e511b9145 100644 --- a/contrib/samples/a2a_basic/pom.xml +++ b/contrib/samples/a2a_basic/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 1.2.0 + 1.3.0 .. diff --git a/contrib/samples/a2a_server/pom.xml b/contrib/samples/a2a_server/pom.xml index 6a7e87ef4..6cc4deb4c 100644 --- a/contrib/samples/a2a_server/pom.xml +++ b/contrib/samples/a2a_server/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 1.2.0 + 1.3.0 .. diff --git a/contrib/samples/configagent/pom.xml b/contrib/samples/configagent/pom.xml index db7bde0c5..463b82379 100644 --- a/contrib/samples/configagent/pom.xml +++ b/contrib/samples/configagent/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-samples - 1.2.0 + 1.3.0 .. diff --git a/contrib/samples/helloworld/pom.xml b/contrib/samples/helloworld/pom.xml index 1ff79260f..eabbd547f 100644 --- a/contrib/samples/helloworld/pom.xml +++ b/contrib/samples/helloworld/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-samples - 1.2.0 + 1.3.0 .. diff --git a/contrib/samples/mcpfilesystem/pom.xml b/contrib/samples/mcpfilesystem/pom.xml index ce6c2afc8..c3d1a8904 100644 --- a/contrib/samples/mcpfilesystem/pom.xml +++ b/contrib/samples/mcpfilesystem/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 ../../.. diff --git a/contrib/samples/pom.xml b/contrib/samples/pom.xml index 8978fa2c4..a29ef41f9 100644 --- a/contrib/samples/pom.xml +++ b/contrib/samples/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 ../.. diff --git a/contrib/spring-ai/pom.xml b/contrib/spring-ai/pom.xml index a64c22793..3abaff88b 100644 --- a/contrib/spring-ai/pom.xml +++ b/contrib/spring-ai/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 30b8760a8..a51b99b1f 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 google-adk diff --git a/core/src/main/java/com/google/adk/Version.java b/core/src/main/java/com/google/adk/Version.java index 2816d6763..d440f4d9e 100644 --- a/core/src/main/java/com/google/adk/Version.java +++ b/core/src/main/java/com/google/adk/Version.java @@ -22,7 +22,7 @@ */ public final class Version { // Don't touch this, release-please should keep it up to date. - public static final String JAVA_ADK_VERSION = "1.2.0"; // x-release-please-released-version + public static final String JAVA_ADK_VERSION = "1.3.0"; // x-release-please-released-version private Version() {} } diff --git a/core/src/main/java/com/google/adk/events/EventActions.java b/core/src/main/java/com/google/adk/events/EventActions.java index 9c977240c..c3c921be5 100644 --- a/core/src/main/java/com/google/adk/events/EventActions.java +++ b/core/src/main/java/com/google/adk/events/EventActions.java @@ -289,10 +289,16 @@ public Builder skipSummarization(@Nullable Boolean skipSummarization) { @CanIgnoreReturnValue @JsonProperty("stateDelta") public Builder stateDelta(@Nullable Map value) { - if (value == null) { - this.stateDelta = new ConcurrentHashMap<>(); - } else { - this.stateDelta = new ConcurrentHashMap<>(value); + this.stateDelta = new ConcurrentHashMap<>(); + if (value != null) { + // Convert null values to State.REMOVED to avoid NPEs. + value + .entrySet() + .forEach( + entry -> { + stateDelta.put( + entry.getKey(), Optional.ofNullable(entry.getValue()).orElse(State.REMOVED)); + }); } return this; } diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java index e26546313..1ed997824 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java @@ -25,6 +25,7 @@ import com.google.genai.types.Part; import java.util.Base64; import java.util.Map; +import java.util.Objects; import org.jspecify.annotations.Nullable; /** Shared models for Chat Completions Request and Response. */ @@ -45,6 +46,50 @@ private ChatCompletionsCommon() {} public static final String METADATA_KEY_SYSTEM_FINGERPRINT = "system_fingerprint"; public static final String METADATA_KEY_SERVICE_TIER = "service_tier"; + /** + * Prefix used to mark refusal content in a text Part, since there is no dedicated field for + * refusal content in the Gemini API. + */ + static final String REFUSAL_PREFIX = "[[REFUSAL]]: "; + + /** + * Result of splitting a text part into its non-refusal content and refusal content. Either + * component may be {@code null} when absent. + */ + record RefusalSplit(@Nullable String content, @Nullable String refusal) {} + + /** + * Splits a text Part value into a content portion and a refusal portion based on the {@link + * #REFUSAL_PREFIX} sentinel: + * + * + * + * @param text the raw text from a {@link Part#text()}. + * @return a {@link RefusalSplit} with the content and refusal portions. + */ + static RefusalSplit parseRefusalPrefix(String text) { + Objects.requireNonNull(text, "text cannot be null"); + if (text.startsWith(REFUSAL_PREFIX)) { + return new RefusalSplit(null, text.substring(REFUSAL_PREFIX.length())); + } + String separator = "\n" + REFUSAL_PREFIX; + int index = text.indexOf(separator); + if (index >= 0) { + String before = text.substring(0, index); + String after = text.substring(index + separator.length()); + return new RefusalSplit(before.isEmpty() ? null : before, after); + } + return new RefusalSplit(text, null); + } + /** * See * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message_tool_call%20%3E%20(schema) diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsHttpClient.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsHttpClient.java new file mode 100644 index 000000000..5b2b03a33 --- /dev/null +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsHttpClient.java @@ -0,0 +1,256 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.chat; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.HttpOptions; +import io.reactivex.rxjava3.core.BackpressureStrategy; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.FlowableEmitter; +import java.io.IOException; +import java.time.Duration; +import java.util.Map; +import java.util.Objects; +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.HttpUrl; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Request; +import okhttp3.RequestBody; +import okhttp3.Response; +import okhttp3.ResponseBody; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An HTTP client for interacting with OpenAI-compatible chat completions endpoints. + * + *

Supports both non-streaming responses (single {@link LlmResponse} emission) and streaming + * Server-Sent Events (SSE) responses (multiple incremental {@link LlmResponse} emissions). See the + * OpenAI Chat Completions API + * reference for the wire protocol. + */ +public class ChatCompletionsHttpClient { + private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsHttpClient.class); + private static final ObjectMapper objectMapper = JsonBaseModel.getMapper(); + + private static final MediaType JSON = MediaType.get("application/json; charset=utf-8"); + + /** + * Default OkHttp call timeout used when the caller does not supply an {@link HttpOptions} + * timeout. Five minutes is long enough for most non-streaming completions and short enough to + * prevent indefinite hangs in the common case where the caller does not configure timeouts. + * Callers who need infinite (e.g. long batch jobs or open streams) can opt in by passing an + * {@link HttpOptions} with {@code timeout() == 0}. + */ + private static final Duration DEFAULT_CALL_TIMEOUT = Duration.ofMinutes(5); + + /** + * Shared OkHttpClient instance whose connection pool and thread dispatcher are reused across all + * {@link ChatCompletionsHttpClient} instances. Each instance forks this client via {@link + * OkHttpClient#newBuilder()} to apply per-instance timeouts without leaking pools. + */ + private static final OkHttpClient SHARED_POOL_CLIENT = new OkHttpClient(); + + private final OkHttpClient client; + private final HttpUrl completionsUrl; + private final ImmutableMap headers; + + /** + * Constructs a new {@link ChatCompletionsHttpClient} that facilitates API interaction with the + * standard {@code /chat/completions} REST endpoint. + * + *

All configuration is sourced from the supplied {@link HttpOptions}: + * + *

+ * + *

Example: + * + *

{@code
+   * HttpOptions options =
+   *     HttpOptions.builder()
+   *         .baseUrl("https://example.com/v1/")
+   *         .headers(ImmutableMap.of("Authorization", "Bearer my-token"))
+   *         .timeout(30_000)
+   *         .build();
+   * ChatCompletionsHttpClient client = new ChatCompletionsHttpClient(options);
+   * }
+ * + * @param httpOptions HTTP configuration. Must not be {@code null}, and {@link + * HttpOptions#baseUrl()} must be present and parseable as an HTTP(S) URL. + * @throws IllegalArgumentException if {@code httpOptions.baseUrl()} is missing or is not a valid + * HTTP(S) URL. + */ + public ChatCompletionsHttpClient(HttpOptions httpOptions) { + Objects.requireNonNull(httpOptions, "httpOptions cannot be null"); + String baseUrl = + httpOptions + .baseUrl() + .orElseThrow(() -> new IllegalArgumentException("httpOptions.baseUrl() must be set")); + HttpUrl parsedBaseUrl = HttpUrl.parse(baseUrl); + if (parsedBaseUrl == null) { + throw new IllegalArgumentException( + "httpOptions.baseUrl() is not a valid HTTP(S) URL: " + baseUrl); + } + // Pre-build the completions URL once. HttpUrl.addPathSegment handles trailing slashes, + // percent-encoding, and existing path components on baseUrl deterministically. + this.completionsUrl = + parsedBaseUrl.newBuilder().addPathSegment("chat").addPathSegment("completions").build(); + // Defensive copy of caller-supplied headers; absent is treated as no extra headers. + this.headers = + httpOptions + .headers() + .>map(ImmutableMap::copyOf) + .orElse(ImmutableMap.of()); + + // Apply custom timeouts per instance. All internal timeouts are bounded by callTimeout. + OkHttpClient.Builder builder = SHARED_POOL_CLIENT.newBuilder(); + builder.connectTimeout(Duration.ZERO); + builder.readTimeout(Duration.ZERO); + builder.writeTimeout(Duration.ZERO); + builder.callTimeout(resolveCallTimeout(httpOptions)); + this.client = builder.build(); + } + + /** Resolves the call timeout from HttpOptions. */ + private static Duration resolveCallTimeout(HttpOptions httpOptions) { + if (httpOptions.timeout().isEmpty()) { + return DEFAULT_CALL_TIMEOUT; + } + long timeoutMs = httpOptions.timeout().get(); + // 0 is treated as no timeout (Duration.ZERO). + return timeoutMs == 0L ? Duration.ZERO : Duration.ofMillis(timeoutMs); + } + + /** + * Generates a conversational response from the chat completions endpoint based on the provided + * messages. This encapsulates building the HTTP payload, sending the request to the completions + * endpoint, and initiating the handling of complete calls. + * + * @param llmRequest The request containing the model, configuration, and sequence of messages. + * @param stream Whether to request a streaming response. + * @return A {@link Flowable} emitting the discrete (or combined) {@link LlmResponse} objects. + */ + public Flowable complete(LlmRequest llmRequest, boolean stream) { + return Flowable.defer( + () -> { + ChatCompletionsRequest dtoRequest = + ChatCompletionsRequest.fromLlmRequest(llmRequest, stream); + String jsonPayload = objectMapper.writeValueAsString(dtoRequest); + logger.trace( + "Chat Completion Request: model={}, stream={}, messagesCount={}", + dtoRequest.model, + dtoRequest.stream, + dtoRequest.messages != null ? dtoRequest.messages.size() : 0); + + Request.Builder requestBuilder = + new Request.Builder().url(completionsUrl).post(RequestBody.create(jsonPayload, JSON)); + + for (Map.Entry entry : headers.entrySet()) { + requestBuilder.addHeader(entry.getKey(), entry.getValue()); + } + // Defensively force Content-Type to JSON by replacing instead of appending. + requestBuilder.header("Content-Type", JSON.toString()); + + Request request = requestBuilder.build(); + if (stream) { + return createStreamingFlowable(request); + } else { + return createNonStreamingFlowable(request); + } + }); + } + + /** Placeholder for streaming responses. Errors with {@link UnsupportedOperationException}. */ + @SuppressWarnings("UnusedVariable") + private Flowable createStreamingFlowable(Request request) { + return Flowable.error( + new UnsupportedOperationException("Streaming is not yet implemented in this client.")); + } + + /** + * Wraps an OkHttp {@link Callback} in a reactive {@link Flowable} for single-turn, non-streaming + * responses. + */ + private Flowable createNonStreamingFlowable(Request request) { + return Flowable.create( + emitter -> { + Call call = client.newCall(request); + emitter.setCancellable(call::cancel); + call.enqueue(new NonStreamingCallback(emitter)); + }, + BackpressureStrategy.BUFFER); + } + + /** + * Handles OkHttp failure and success callbacks, pushing {@link LlmResponse} results to the given + * emitter. + */ + private static final class NonStreamingCallback implements Callback { + private final FlowableEmitter emitter; + + NonStreamingCallback(FlowableEmitter emitter) { + this.emitter = emitter; + } + + @Override + public void onFailure(Call call, IOException e) { + emitter.tryOnError(e); + } + + @Override + public void onResponse(Call call, Response response) { + try (ResponseBody body = response.body()) { + if (!response.isSuccessful()) { + String bodyStr = body != null ? body.string() : ""; + emitter.tryOnError( + new IOException("Unexpected code " + response + " - body: " + bodyStr)); + return; + } + if (body == null) { + emitter.tryOnError(new IOException("Empty response body")); + return; + } + + String jsonResponse = body.string(); + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(jsonResponse, ChatCompletionsResponse.ChatCompletion.class); + emitter.onNext(completion.toLlmResponse()); + emitter.onComplete(); + } catch (Exception e) { + emitter.tryOnError(e); + } + } + } +} diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java index 4b6747fb1..523c04a5a 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java @@ -21,18 +21,37 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.annotation.JsonValue; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; +import com.google.adk.models.LlmRequest; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Content; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import java.util.ArrayList; +import java.util.Base64; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Data Transfer Objects for Chat Completion API requests. * + *

Can be used to translate from a {@link LlmRequest} into a {@link ChatCompletionsRequest} using + * {@link #fromLlmRequest(LlmRequest, boolean)}. + * *

See * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create */ @JsonIgnoreProperties(ignoreUnknown = true) @JsonInclude(JsonInclude.Include.NON_NULL) -final class ChatCompletionsRequest { +public final class ChatCompletionsRequest { /** * See @@ -249,6 +268,321 @@ final class ChatCompletionsRequest { @JsonProperty("extra_body") public Map extraBody; + private static final Logger logger = LoggerFactory.getLogger(ChatCompletionsRequest.class); + private static final ObjectMapper objectMapper = JsonBaseModel.getMapper(); + + /** + * Converts a standard {@link LlmRequest} into a {@link ChatCompletionsRequest} for + * /chat/completions compatible endpoints. + * + * @param llmRequest The internal source request containing contents, configuration, and tool + * definitions. + * @param responseStreaming True if the request asks for a streaming response. + * @return A populated ChatCompletionsRequest ready for JSON serialization. + */ + public static ChatCompletionsRequest fromLlmRequest( + LlmRequest llmRequest, boolean responseStreaming) { + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = llmRequest.model().orElse(""); + request.stream = responseStreaming; + if (responseStreaming) { + StreamOptions options = new StreamOptions(); + options.includeUsage = true; + request.streamOptions = options; + } + + boolean isOSeries = request.model.matches("^o\\d+(?:-.*)?$"); + + List messages = new ArrayList<>(); + + llmRequest + .config() + .flatMap(config -> processSystemInstruction(config, isOSeries)) + .ifPresent(messages::add); + + for (Content content : llmRequest.contents()) { + messages.addAll(processContent(content)); + } + + request.messages = ImmutableList.copyOf(messages); + + llmRequest + .config() + .ifPresent( + config -> { + handleConfigOptions(config, request); + handleTools(config, request); + }); + + return request; + } + + /** + * Processes the system instruction configuration and returns a mapped Message if present. + * + * @param config The content generation configuration that may contain a system instruction. + * @param isOSeries True if the target model belongs to the OpenAI o-series (e.g., o1, o3), which + * requires the "developer" role instead of the standard "system" role. + * @return An Optional containing the mapped instruction, or empty if none exists. + */ + private static Optional processSystemInstruction( + GenerateContentConfig config, boolean isOSeries) { + if (config.systemInstruction().isPresent()) { + Message systemMsg = new Message(); + systemMsg.role = isOSeries ? "developer" : "system"; + systemMsg.content = new MessageContent(config.systemInstruction().get().text()); + return Optional.of(systemMsg); + } + return Optional.empty(); + } + + /** + * Processes incoming content and returns a list of messages resulting from it. + * + * @param content The incoming content containing parts to map. + * @return A list of mapped messages. + */ + private static List processContent(Content content) { + Message msg = new Message(); + String role = content.role().orElse("user"); + msg.role = role.equals("model") ? "assistant" : role; + + List contentParts = new ArrayList<>(); + List toolCalls = new ArrayList<>(); + List toolResponses = new ArrayList<>(); + List refusals = new ArrayList<>(); + + content + .parts() + .ifPresent( + parts -> { + for (Part part : parts) { + if (part.text().isPresent()) { + // Text Parts may carry refusal content prefixed with REFUSAL_PREFIX. + ChatCompletionsCommon.RefusalSplit split = + ChatCompletionsCommon.parseRefusalPrefix(part.text().get()); + if (split.content() != null) { + ContentPart textPart = new ContentPart(); + textPart.type = "text"; + textPart.text = split.content(); + contentParts.add(textPart); + } + if (split.refusal() != null) { + refusals.add(split.refusal()); + } + } else if (part.inlineData().isPresent()) { + contentParts.add(processInlineDataPart(part)); + } else if (part.fileData().isPresent()) { + contentParts.add(processFileDataPart(part)); + } else if (part.functionCall().isPresent()) { + toolCalls.add(processFunctionCallPart(part)); + } else if (part.functionResponse().isPresent()) { + toolResponses.add(processFunctionResponsePart(part)); + } else if (part.executableCode().isPresent()) { + logger.warn("Executable code is not supported in Chat Completion conversion"); + } else if (part.codeExecutionResult().isPresent()) { + logger.warn( + "Code execution result is not supported in Chat Completion conversion"); + } + } + }); + + if (!toolResponses.isEmpty()) { + return toolResponses; + } else { + if (!toolCalls.isEmpty()) { + msg.toolCalls = ImmutableList.copyOf(toolCalls); + } + if (!refusals.isEmpty()) { + msg.refusal = String.join("\n", refusals); + } + if (!contentParts.isEmpty()) { + if (contentParts.size() == 1 && Objects.equals(contentParts.get(0).type, "text")) { + msg.content = new MessageContent(contentParts.get(0).text); + } else { + msg.content = new MessageContent(ImmutableList.copyOf(contentParts)); + } + } + List messages = new ArrayList<>(); + messages.add(msg); + return messages; + } + } + + /** + * Processes an inline data part and returns a mapped ContentPart. + * + * @param part The input part containing base64 inline data. + * @return The mapped inline data part. + */ + private static ContentPart processInlineDataPart(Part part) { + ContentPart imgPart = new ContentPart(); + imgPart.type = "image_url"; + ImageUrl imageUrl = new ImageUrl(); + imageUrl.url = + "data:" + + part.inlineData().get().mimeType().orElse("image/jpeg") + + ";base64," + + Base64.getEncoder().encodeToString(part.inlineData().get().data().get()); + imgPart.imageUrl = imageUrl; + return imgPart; + } + + /** + * Processes a file data part and returns a mapped ContentPart. + * + * @param part The input part referencing a stored file via URI. + * @return The mapped file data part. + */ + private static ContentPart processFileDataPart(Part part) { + ContentPart imgPart = new ContentPart(); + imgPart.type = "image_url"; + ImageUrl imageUrl = new ImageUrl(); + imageUrl.url = part.fileData().get().fileUri().orElse(""); + imgPart.imageUrl = imageUrl; + return imgPart; + } + + /** + * Processes a function call part and returns a mapped ToolCall. + * + * @param part The input part containing a requested function call or invocation. + * @return The mapped function call tool call. + */ + private static ChatCompletionsCommon.ToolCall processFunctionCallPart(Part part) { + com.google.genai.types.FunctionCall fc = part.functionCall().get(); + ChatCompletionsCommon.ToolCall toolCall = new ChatCompletionsCommon.ToolCall(); + toolCall.id = fc.id().orElse("call_" + fc.name().orElse("unknown")); + toolCall.type = "function"; + ChatCompletionsCommon.Function function = new ChatCompletionsCommon.Function(); + function.name = fc.name().orElse(""); + if (fc.args().isPresent()) { + try { + function.arguments = objectMapper.writeValueAsString(fc.args().get()); + } catch (Exception e) { + logger.warn("Failed to serialize function arguments", e); + } + } + toolCall.function = function; + return toolCall; + } + + /** + * Processes a function response part and returns a mapped Message. + * + * @param part The input part containing the execution results of a function. + * @return The mapped tool response message. + */ + private static Message processFunctionResponsePart(Part part) { + FunctionResponse fr = part.functionResponse().get(); + Message toolResp = new Message(); + toolResp.role = "tool"; + toolResp.toolCallId = fr.id().orElse(""); + if (fr.response().isPresent()) { + try { + toolResp.content = new MessageContent(objectMapper.writeValueAsString(fr.response().get())); + } catch (Exception e) { + logger.warn("Failed to serialize tool response", e); + } + } + return toolResp; + } + + /** + * Updates the request based on the provided configuration options. + * + * @param config The content generation configuration containing parameters such as temperature. + * @param request The chat completions request to populate with matching options. + */ + private static void handleConfigOptions( + GenerateContentConfig config, ChatCompletionsRequest request) { + config.temperature().ifPresent(v -> request.temperature = v.doubleValue()); + config.topP().ifPresent(v -> request.topP = v.doubleValue()); + config + .maxOutputTokens() + .ifPresent( + v -> { + request.maxCompletionTokens = Math.toIntExact(v); + }); + config.stopSequences().ifPresent(v -> request.stop = new StopCondition(v)); + config.candidateCount().ifPresent(v -> request.n = Math.toIntExact(v)); + config.presencePenalty().ifPresent(v -> request.presencePenalty = v.doubleValue()); + config.frequencyPenalty().ifPresent(v -> request.frequencyPenalty = v.doubleValue()); + config.seed().ifPresent(v -> request.seed = v.longValue()); + + if (config.responseJsonSchema().isPresent()) { + ResponseFormatJsonSchema format = new ResponseFormatJsonSchema(); + ResponseFormatJsonSchema.JsonSchema schema = new ResponseFormatJsonSchema.JsonSchema(); + schema.name = "response_schema"; + schema.schema = + objectMapper.convertValue( + config.responseJsonSchema().get(), new TypeReference>() {}); + schema.strict = true; + format.jsonSchema = schema; + request.responseFormat = format; + } else if (config.responseMimeType().isPresent() + && config.responseMimeType().get().equals("application/json")) { + request.responseFormat = new ResponseFormatJsonObject(); + } + + if (config.responseLogprobs().isPresent() && config.responseLogprobs().get()) { + request.logprobs = true; + config.logprobs().ifPresent(v -> request.topLogprobs = Math.toIntExact(v)); + } + } + + /** + * Updates the request tools list based on the provided tools configuration. + * + * @param config The content generation configuration defining available tools. + * @param request The chat completions request to populate with mapped tool definitions. + */ + private static void handleTools(GenerateContentConfig config, ChatCompletionsRequest request) { + if (config.tools().isPresent()) { + List tools = new ArrayList<>(); + for (com.google.genai.types.Tool t : config.tools().get()) { + if (t.functionDeclarations().isPresent()) { + for (FunctionDeclaration fd : t.functionDeclarations().get()) { + Tool tool = new Tool(); + tool.type = "function"; + FunctionDefinition def = new FunctionDefinition(); + def.name = fd.name().orElse(""); + def.description = fd.description().orElse(""); + fd.parameters() + .ifPresent( + params -> + def.parameters = + objectMapper.convertValue( + params, new TypeReference>() {})); + tool.function = def; + tools.add(tool); + } + } + } + if (!tools.isEmpty()) { + request.tools = ImmutableList.copyOf(tools); + if (config.toolConfig().isPresent() + && config.toolConfig().get().functionCallingConfig().isPresent()) { + config + .toolConfig() + .get() + .functionCallingConfig() + .get() + .mode() + .ifPresent( + mode -> { + switch (mode.knownEnum()) { + case ANY -> request.toolChoice = new ToolChoiceMode("required"); + case NONE -> request.toolChoice = new ToolChoiceMode("none"); + case AUTO -> request.toolChoice = new ToolChoiceMode("auto"); + default -> {} + } + }); + } + } + } + } + /** * A catch-all class for message parameters. See * https://developers.openai.com/api/reference/resources/chat/subresources/completions/methods/create#(resource)%20chat.completions%20%3E%20(method)%20create%20%3E%20(params)%200.non_streaming%20%3E%20(param)%20messages%20%3E%20(schema) diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java index 9645016a9..61e7e8358 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java @@ -50,7 +50,7 @@ public final class ChatCompletionsResponse { private ChatCompletionsResponse() {} - static @Nullable FinishReason mapFinishReason(String reason) { + static @Nullable FinishReason mapFinishReason(@Nullable String reason) { if (reason == null) { return null; } @@ -62,7 +62,7 @@ private ChatCompletionsResponse() {} }; } - static @Nullable GenerateContentResponseUsageMetadata mapUsage(Usage usage) { + static @Nullable GenerateContentResponseUsageMetadata mapUsage(@Nullable Usage usage) { if (usage == null) { return null; } @@ -180,7 +180,7 @@ private ImmutableList mapMessageToParts(Message message) { parts.add(Part.fromText(message.content)); } if (message.refusal != null) { - parts.add(Part.fromText(message.refusal)); + parts.add(Part.fromText(ChatCompletionsCommon.REFUSAL_PREFIX + message.refusal)); } if (message.toolCalls != null) { parts.addAll(mapToolCallsToParts(message.toolCalls)); @@ -188,8 +188,15 @@ private ImmutableList mapMessageToParts(Message message) { return parts.build(); } + /** + * Maps a list of tool calls to a list of {@link Part} objects. + * + * @param toolCalls the list of tool calls to map (non-null). + * @return a list of parts containing converted tool calls. + */ private ImmutableList mapToolCallsToParts( List toolCalls) { + ImmutableList.Builder parts = ImmutableList.builder(); for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) { Part part = toolCall.toPart(); diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java index 5f8222e70..59e09c8a7 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPlugin.java @@ -30,7 +30,6 @@ import com.google.adk.models.LlmRequest; import com.google.adk.models.LlmResponse; import com.google.adk.plugins.BasePlugin; -import com.google.adk.plugins.agentanalytics.JsonFormatter.ParsedContent; import com.google.adk.plugins.agentanalytics.JsonFormatter.TruncationResult; import com.google.adk.plugins.agentanalytics.TraceManager.RecordData; import com.google.adk.plugins.agentanalytics.TraceManager.SpanIds; @@ -65,6 +64,7 @@ import java.util.HashMap; import java.util.Map; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.logging.Level; import java.util.logging.Logger; import org.jspecify.annotations.Nullable; @@ -184,37 +184,36 @@ private void processBigQueryException(BigQueryException e, String logMessage) { } } - private void logEvent( + private Completable logEvent( String eventType, InvocationContext invocationContext, - Object content, + @Nullable Object content, Optional eventData) { - logEvent(eventType, invocationContext, content, false, eventData); + return logEvent(eventType, invocationContext, content, false, eventData); } - private void logEvent( + private Completable logEvent( String eventType, InvocationContext invocationContext, - Object content, + @Nullable Object content, boolean isContentTruncated, Optional eventData) { if (!config.enabled()) { - return; + return Completable.complete(); } if (!config.eventAllowlist().isEmpty() && !config.eventAllowlist().contains(eventType)) { - return; + return Completable.complete(); } if (config.eventDenylist().contains(eventType)) { - return; + return Completable.complete(); } if (state.isProcessed(invocationContext.invocationId())) { - return; + return Completable.complete(); } if (config.contentFormatter() != null && content != null) { try { content = config.contentFormatter().apply(content, eventType); } catch (RuntimeException e) { - logger.log( Level.WARNING, "Failed to format content for invocation ID: " + invocationContext.invocationId(), @@ -222,8 +221,9 @@ private void logEvent( content = null; // Fail-closed to avoid leaking unmasked sensitive data } } - String invocationId = invocationContext.invocationId(); - BatchProcessor processor = state.getBatchProcessor(invocationId); + + // Resolve IDs before going async + ResolvedTraceIds traceIds = getResolvedTraceIds(invocationContext, eventData); // Ensure table exists before logging. ensureTableExistsOnce(); // Log common fields @@ -234,13 +234,9 @@ private void logEvent( row.put("session_id", invocationContext.session().id()); row.put("invocation_id", invocationContext.invocationId()); row.put("user_id", invocationContext.userId()); - // Parse and log content - if (content != null) { - ParsedContent parsedContent = JsonFormatter.parse(content, config.maxContentLength()); - row.put("content_parts", parsedContent.parts()); - row.put("content", parsedContent.content()); - row.put("is_truncated", isContentTruncated || parsedContent.isTruncated()); - } + row.put("trace_id", traceIds.traceId()); + row.put("span_id", traceIds.spanId()); + row.put("parent_span_id", traceIds.parentSpanId()); EventData data = eventData.orElse(EventData.builder().build()); row.put("status", data.status()); @@ -252,12 +248,48 @@ private void logEvent( } row.put("attributes", convertToJsonNode(getAttributes(data, invocationContext))); - addTraceDetails(row, invocationContext, eventData); - processor.append(row); + CompletableFuture parseFuture; + if (content != null) { + parseFuture = + state + .getParser() + .parse(content) + .thenAccept( + parsedContent -> { + row.put( + "content_parts", + config.logMultiModalContent() ? parsedContent.parts() : ImmutableList.of()); + row.put("content", parsedContent.content()); + row.put("is_truncated", isContentTruncated || parsedContent.isTruncated()); + }) + .exceptionally( + ex -> { + logger.log( + Level.WARNING, + "Failed to parse content for invocation ID: " + + invocationContext.invocationId(), + ex); + row.put("content", "Failed to parse content."); + row.put("content_parts", ImmutableList.of()); + row.put("is_truncated", true); + return null; + }); + } else { + parseFuture = CompletableFuture.completedFuture(null); + } + + CompletableFuture appendFuture = + parseFuture.thenRun( + () -> { + BatchProcessor processor = state.getBatchProcessor(invocationContext.invocationId()); + processor.append(row); + }); + state.addPendingTask(invocationContext.invocationId(), appendFuture); + return Completable.complete(); } - private void addTraceDetails( - Map row, InvocationContext invocationContext, Optional eventData) { + private ResolvedTraceIds getResolvedTraceIds( + InvocationContext invocationContext, Optional eventData) { TraceManager traceManager = state.getTraceManager(invocationContext.invocationId()); String traceId = eventData @@ -266,17 +298,17 @@ private void addTraceDetails( Optional ambientSpanIds = traceManager.getAmbientSpanAndParent(); SpanIds spanIds = ambientSpanIds.orElse(traceManager.getCurrentSpanAndParent()); - row.put("trace_id", traceId); - row.put( - "span_id", - eventData.flatMap(EventData::spanIdOverride).orElse(spanIds.spanId().orElse(null))); - row.put( - "parent_span_id", + return new ResolvedTraceIds( + traceId, + eventData.flatMap(EventData::spanIdOverride).orElse(spanIds.spanId().orElse(null)), eventData .flatMap(EventData::parentSpanIdOverride) .orElse(spanIds.parentSpanId().orElse(null))); } + private record ResolvedTraceIds( + String traceId, @Nullable String spanId, @Nullable String parentSpanId) {} + private @Nullable Map extractLatency(EventData eventData) { Map latencyMap = new HashMap<>(); eventData.latency().ifPresent(v -> latencyMap.put("total_ms", v.toMillis())); @@ -331,8 +363,7 @@ private Map getAttributes( @Override public Completable close() { - state.close(); - return Completable.complete(); + return state.close(); } @VisibleForTesting @@ -372,159 +403,139 @@ private Optional getCompletedEventData(InvocationContext invocationCo @Override public Maybe onUserMessageCallback( InvocationContext invocationContext, Content userMessage) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(invocationContext.invocationId())) { - return; - } - state - .getTraceManager(invocationContext.invocationId()) - .ensureInvocationSpan(invocationContext); - logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty()); - if (userMessage.parts().isPresent()) { - for (Part part : userMessage.parts().get()) { - if (part.functionCall().isPresent() - && HITL_EVENT_TYPES.containsKey(part.functionCall().get().name().orElse(""))) { - String hitlEvent = HITL_EVENT_TYPES.get(part.functionCall().get().name().get()); - TruncationResult truncatedResult = smartTruncate(part, config.maxContentLength()); - logEvent( - hitlEvent + "_COMPLETED", - invocationContext, - ImmutableMap.of( - "tool", - part.functionCall().get().name().get(), - "result", - truncatedResult.node()), - truncatedResult.isTruncated(), - Optional.empty()); - } - } - } - }); + if (state.isProcessed(invocationContext.invocationId())) { + return Maybe.empty(); + } + state.getTraceManager(invocationContext.invocationId()).ensureInvocationSpan(invocationContext); + Completable logCompletable = + logEvent("USER_MESSAGE_RECEIVED", invocationContext, userMessage, Optional.empty()); + + if (userMessage.parts().isPresent()) { + for (Part part : userMessage.parts().get()) { + if (part.functionCall().isPresent() + && HITL_EVENT_TYPES.containsKey(part.functionCall().get().name().orElse(""))) { + String hitlEvent = HITL_EVENT_TYPES.get(part.functionCall().get().name().get()); + TruncationResult truncatedResult = smartTruncate(part, config.maxContentLength()); + logCompletable = + logCompletable.andThen( + logEvent( + hitlEvent + "_COMPLETED", + invocationContext, + ImmutableMap.of( + "tool", + part.functionCall().get().name().get(), + "result", + truncatedResult.node()), + truncatedResult.isTruncated(), + Optional.empty())); + } + } + } + return logCompletable.andThen(Maybe.empty()); } @Override public Maybe onEventCallback(InvocationContext invocationContext, Event event) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(invocationContext.invocationId())) { - return; - } - EventData.Builder eventDataBuilder = - EventData.builder() - .setExtraAttributes( - ImmutableMap.builder() - .put("state_delta", event.actions().stateDelta()) - .put("author", event.author()) - .buildOrThrow()); - logEvent( - "STATE_DELTA", - invocationContext, - event.content().orElse(null), - Optional.of(eventDataBuilder.build())); - - if (event.content().isPresent() && event.content().get().parts().isPresent()) { - for (Part part : event.content().get().parts().get()) { - if (part.functionCall().isPresent() - && HITL_EVENT_TYPES.containsKey(part.functionCall().get().name().orElse(""))) { - String hitlEvent = HITL_EVENT_TYPES.get(part.functionCall().get().name().get()); - TruncationResult truncatedResult = - smartTruncate(part.functionCall().get().args(), config.maxContentLength()); - logEvent( - hitlEvent + "_COMPLETED", - invocationContext, - ImmutableMap.of( - "tool", - part.functionCall().get().name().get(), - "args", - truncatedResult.node()), - truncatedResult.isTruncated(), - Optional.empty()); - } - if (part.functionResponse().isPresent() - && HITL_EVENT_TYPES.containsKey( - part.functionResponse().get().name().orElse(""))) { - String hitlEvent = HITL_EVENT_TYPES.get(part.functionResponse().get().name().get()); - TruncationResult truncatedResult = - smartTruncate( - part.functionResponse().get().response(), config.maxContentLength()); - logEvent( - hitlEvent + "_COMPLETED", - invocationContext, - ImmutableMap.of( - "tool", - part.functionResponse().get().name().get(), - "response", - truncatedResult.node()), - truncatedResult.isTruncated(), - Optional.empty()); - } - } - } - }); + if (state.isProcessed(invocationContext.invocationId())) { + return Maybe.empty(); + } + EventData.Builder eventDataBuilder = + EventData.builder() + .setExtraAttributes( + ImmutableMap.builder() + .put("state_delta", event.actions().stateDelta()) + .put("author", event.author()) + .buildOrThrow()); + Completable logCompletable = + logEvent( + "STATE_DELTA", + invocationContext, + event.content().orElse(null), + Optional.of(eventDataBuilder.build())); + + if (event.content().isPresent() && event.content().get().parts().isPresent()) { + for (Part part : event.content().get().parts().get()) { + if (part.functionCall().isPresent() + && HITL_EVENT_TYPES.containsKey(part.functionCall().get().name().orElse(""))) { + String hitlEvent = HITL_EVENT_TYPES.get(part.functionCall().get().name().get()); + TruncationResult truncatedResult = + smartTruncate(part.functionCall().get().args(), config.maxContentLength()); + logCompletable = + logCompletable.andThen( + logEvent( + hitlEvent + "_COMPLETED", + invocationContext, + ImmutableMap.of( + "tool", + part.functionCall().get().name().get(), + "args", + truncatedResult.node()), + truncatedResult.isTruncated(), + Optional.empty())); + } + if (part.functionResponse().isPresent() + && HITL_EVENT_TYPES.containsKey(part.functionResponse().get().name().orElse(""))) { + String hitlEvent = HITL_EVENT_TYPES.get(part.functionResponse().get().name().get()); + TruncationResult truncatedResult = + smartTruncate(part.functionResponse().get().response(), config.maxContentLength()); + logCompletable = + logCompletable.andThen( + logEvent( + hitlEvent + "_COMPLETED", + invocationContext, + ImmutableMap.of( + "tool", + part.functionResponse().get().name().get(), + "response", + truncatedResult.node()), + truncatedResult.isTruncated(), + Optional.empty())); + } + } + } + return logCompletable.andThen(Maybe.empty()); } @Override public Maybe beforeRunCallback(InvocationContext invocationContext) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(invocationContext.invocationId())) { - return; - } - state - .getTraceManager(invocationContext.invocationId()) - .ensureInvocationSpan(invocationContext); - logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty()); - }); + if (state.isProcessed(invocationContext.invocationId())) { + return Maybe.empty(); + } + state.getTraceManager(invocationContext.invocationId()).ensureInvocationSpan(invocationContext); + return logEvent("INVOCATION_STARTING", invocationContext, null, Optional.empty()) + .andThen(Maybe.empty()); } @Override public Completable afterRunCallback(InvocationContext invocationContext) { - return Completable.fromAction( - () -> { - logEvent( - "INVOCATION_COMPLETED", - invocationContext, - null, - getCompletedEventData(invocationContext)); - // Mark invocation ID as processed to avoid memory leaks. - state.markProcessed(invocationContext.invocationId()); - BatchProcessor processor = state.removeProcessor(invocationContext.invocationId()); - if (processor != null) { - processor.flush(); - processor.close(); - } - TraceManager traceManager = state.removeTraceManager(invocationContext.invocationId()); - if (traceManager != null) { - traceManager.clearStack(); - } - }); + return logEvent( + "INVOCATION_COMPLETED", + invocationContext, + null, + getCompletedEventData(invocationContext)) + .andThen(state.ensureInvocationCompleted(invocationContext.invocationId())); } @Override public Maybe beforeAgentCallback(BaseAgent agent, CallbackContext callbackContext) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(callbackContext.invocationContext().invocationId())) { - return; - } - state - .getTraceManager(callbackContext.invocationContext().invocationId()) - .pushSpan("agent:" + agent.name()); - logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty()); - }); + if (state.isProcessed(callbackContext.invocationContext().invocationId())) { + return Maybe.empty(); + } + state + .getTraceManager(callbackContext.invocationContext().invocationId()) + .pushSpan("agent:" + agent.name()); + return logEvent("AGENT_STARTING", callbackContext.invocationContext(), null, Optional.empty()) + .andThen(Maybe.empty()); } @Override public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callbackContext) { - return Maybe.fromAction( - () -> { - logEvent( - "AGENT_COMPLETED", - callbackContext.invocationContext(), - null, - getCompletedEventData(callbackContext.invocationContext())); - }); + return logEvent( + "AGENT_COMPLETED", + callbackContext.invocationContext(), + null, + getCompletedEventData(callbackContext.invocationContext())) + .andThen(Maybe.empty()); } /** @@ -538,228 +549,204 @@ public Maybe afterAgentCallback(BaseAgent agent, CallbackContext callba @Override public Maybe beforeModelCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(callbackContext.invocationContext().invocationId())) { - return; - } - Map attributes = new HashMap<>(); - Map llmConfig = new HashMap<>(); - LlmRequest req = llmRequest.build(); - if (req.config().isPresent()) { - if (req.config().get().temperature().isPresent()) { - llmConfig.put("temperature", req.config().get().temperature().get()); - } - if (req.config().get().topP().isPresent()) { - llmConfig.put("top_p", req.config().get().topP().get()); - } - if (req.config().get().topK().isPresent()) { - llmConfig.put("top_k", req.config().get().topK().get()); - } - if (req.config().get().candidateCount().isPresent()) { - llmConfig.put("candidate_count", req.config().get().candidateCount().get()); - } - if (req.config().get().maxOutputTokens().isPresent()) { - llmConfig.put("max_output_tokens", req.config().get().maxOutputTokens().get()); - } - if (req.config().get().stopSequences().isPresent()) { - llmConfig.put("stop_sequences", req.config().get().stopSequences().get()); - } - if (req.config().get().presencePenalty().isPresent()) { - llmConfig.put("presence_penalty", req.config().get().presencePenalty().get()); - } - if (req.config().get().frequencyPenalty().isPresent()) { - llmConfig.put("frequency_penalty", req.config().get().frequencyPenalty().get()); - } - if (req.config().get().responseMimeType().isPresent()) { - llmConfig.put("response_mime_type", req.config().get().responseMimeType().get()); - } - if (req.config().get().responseSchema().isPresent()) { - llmConfig.put("response_schema", req.config().get().responseSchema().get()); - } - if (req.config().get().seed().isPresent()) { - llmConfig.put("seed", req.config().get().seed().get()); - } - if (req.config().get().responseLogprobs().isPresent()) { - llmConfig.put("response_logprobs", req.config().get().responseLogprobs().get()); - } - if (req.config().get().logprobs().isPresent()) { - llmConfig.put("logprobs", req.config().get().logprobs().get()); - } - // Put labels in attributes instead of LLM config. - if (req.config().get().labels().isPresent()) { - attributes.put("labels", req.config().get().labels().get()); - } - } - if (!llmConfig.isEmpty()) { - attributes.put("llm_config", llmConfig); - } - if (!req.tools().isEmpty()) { - attributes.put("tools", req.tools().keySet()); - } - EventData eventData = - EventData.builder() - .setModel(req.model().orElse("")) - .setExtraAttributes(attributes) - .build(); - state - .getTraceManager(callbackContext.invocationContext().invocationId()) - .pushSpan("llm_request"); - logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData)); - }); + if (state.isProcessed(callbackContext.invocationContext().invocationId())) { + return Maybe.empty(); + } + Map attributes = new HashMap<>(); + Map llmConfig = new HashMap<>(); + LlmRequest req = llmRequest.build(); + if (req.config().isPresent()) { + if (req.config().get().temperature().isPresent()) { + llmConfig.put("temperature", req.config().get().temperature().get()); + } + if (req.config().get().topP().isPresent()) { + llmConfig.put("top_p", req.config().get().topP().get()); + } + if (req.config().get().topK().isPresent()) { + llmConfig.put("top_k", req.config().get().topK().get()); + } + if (req.config().get().candidateCount().isPresent()) { + llmConfig.put("candidate_count", req.config().get().candidateCount().get()); + } + if (req.config().get().maxOutputTokens().isPresent()) { + llmConfig.put("max_output_tokens", req.config().get().maxOutputTokens().get()); + } + if (req.config().get().stopSequences().isPresent()) { + llmConfig.put("stop_sequences", req.config().get().stopSequences().get()); + } + if (req.config().get().presencePenalty().isPresent()) { + llmConfig.put("presence_penalty", req.config().get().presencePenalty().get()); + } + if (req.config().get().frequencyPenalty().isPresent()) { + llmConfig.put("frequency_penalty", req.config().get().frequencyPenalty().get()); + } + if (req.config().get().responseMimeType().isPresent()) { + llmConfig.put("response_mime_type", req.config().get().responseMimeType().get()); + } + if (req.config().get().responseSchema().isPresent()) { + llmConfig.put("response_schema", req.config().get().responseSchema().get()); + } + if (req.config().get().seed().isPresent()) { + llmConfig.put("seed", req.config().get().seed().get()); + } + if (req.config().get().responseLogprobs().isPresent()) { + llmConfig.put("response_logprobs", req.config().get().responseLogprobs().get()); + } + if (req.config().get().logprobs().isPresent()) { + llmConfig.put("logprobs", req.config().get().logprobs().get()); + } + // Put labels in attributes instead of LLM config. + if (req.config().get().labels().isPresent()) { + attributes.put("labels", req.config().get().labels().get()); + } + } + if (!llmConfig.isEmpty()) { + attributes.put("llm_config", llmConfig); + } + if (!req.tools().isEmpty()) { + attributes.put("tools", req.tools().keySet()); + } + EventData eventData = + EventData.builder().setModel(req.model().orElse("")).setExtraAttributes(attributes).build(); + state + .getTraceManager(callbackContext.invocationContext().invocationId()) + .pushSpan("llm_request"); + return logEvent("LLM_REQUEST", callbackContext.invocationContext(), req, Optional.of(eventData)) + .andThen(Maybe.empty()); } @Override public Maybe afterModelCallback( CallbackContext callbackContext, LlmResponse llmResponse) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(callbackContext.invocationContext().invocationId())) { - return; - } - TraceManager traceManager = - state.getTraceManager(callbackContext.invocationContext().invocationId()); - // TODO(b/495809488): Add formatting of the content - ParsedContent parsedContent = - JsonFormatter.parse(llmResponse.content().orElse(null), config.maxContentLength()); - - Map usageDict = new HashMap<>(); - llmResponse - .usageMetadata() - .ifPresent( - usage -> { - usage.promptTokenCount().ifPresent(c -> usageDict.put("prompt", c)); - usage.candidatesTokenCount().ifPresent(c -> usageDict.put("completion", c)); - usage.totalTokenCount().ifPresent(c -> usageDict.put("total", c)); - }); + if (state.isProcessed(callbackContext.invocationContext().invocationId())) { + return Maybe.empty(); + } + TraceManager traceManager = + state.getTraceManager(callbackContext.invocationContext().invocationId()); + + Map usageDict = new HashMap<>(); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + usage.promptTokenCount().ifPresent(c -> usageDict.put("prompt", c)); + usage.candidatesTokenCount().ifPresent(c -> usageDict.put("completion", c)); + usage.totalTokenCount().ifPresent(c -> usageDict.put("total", c)); + }); + + InvocationContext invocationContext = callbackContext.invocationContext(); + Optional spanId = traceManager.getCurrentSpanId(); + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + String parentSpanId = spanIds.parentSpanId().orElse(null); + + boolean isPopped = false; + Duration duration = Duration.ZERO; + Duration ttft = null; + Optional startTime = Optional.empty(); + Optional firstTokenTime = Optional.empty(); + + if (spanId.isPresent()) { + traceManager.recordFirstToken(spanId.get()); + startTime = traceManager.getStartTime(spanId.get()); + firstTokenTime = traceManager.getFirstTokenTime(spanId.get()); + if (startTime.isPresent() && firstTokenTime.isPresent()) { + ttft = Duration.between(startTime.get(), firstTokenTime.get()); + } + } + + if (llmResponse.partial().orElse(false)) { + // Streaming chunk - do NOT pop span yet + if (startTime.isPresent()) { + duration = Duration.between(startTime.get(), Instant.now()); + } + } else { + // Final response - pop span + Optional popped = traceManager.popSpan(); + if (popped.isPresent()) { + spanId = Optional.of(popped.get().spanId()); + duration = popped.get().duration(); + isPopped = true; + } + } + + boolean hasAmbient = traceManager.hasAmbientSpan(); + boolean useOverride = isPopped && !hasAmbient; + + EventData.Builder eventDataBuilder = EventData.builder(); + if (!duration.isZero()) { + eventDataBuilder.setLatency(duration); + } + if (ttft != null) { + eventDataBuilder.setTimeToFirstToken(ttft); + } + llmResponse.modelVersion().ifPresent(eventDataBuilder::setModelVersion); + + if (!usageDict.isEmpty()) { + eventDataBuilder.setUsageMetadata(usageDict); + } + + if (useOverride) { + if (spanId.isPresent()) { + eventDataBuilder.setSpanIdOverride(spanId.get()); + } + if (parentSpanId != null) { + eventDataBuilder.setParentSpanIdOverride(parentSpanId); + } + } - Map contentMap = new HashMap<>(); - if (parsedContent.content() != null && !parsedContent.content().isNull()) { - contentMap.put("response", parsedContent.content()); - } - if (!usageDict.isEmpty()) { - contentMap.put("usage", usageDict); - } - - InvocationContext invocationContext = callbackContext.invocationContext(); - Optional spanId = traceManager.getCurrentSpanId(); - SpanIds spanIds = traceManager.getCurrentSpanAndParent(); - String parentSpanId = spanIds.parentSpanId().orElse(null); - - boolean isPopped = false; - Duration duration = Duration.ZERO; - Duration ttft = null; - Optional startTime = Optional.empty(); - Optional firstTokenTime = Optional.empty(); - - if (spanId.isPresent()) { - traceManager.recordFirstToken(spanId.get()); - startTime = traceManager.getStartTime(spanId.get()); - firstTokenTime = traceManager.getFirstTokenTime(spanId.get()); - if (startTime.isPresent() && firstTokenTime.isPresent()) { - ttft = Duration.between(startTime.get(), firstTokenTime.get()); - } - } - - if (llmResponse.partial().orElse(false)) { - // Streaming chunk - do NOT pop span yet - if (startTime.isPresent()) { - duration = Duration.between(startTime.get(), Instant.now()); - } - } else { - // Final response - pop span - Optional popped = traceManager.popSpan(); - if (popped.isPresent()) { - spanId = Optional.of(popped.get().spanId()); - duration = popped.get().duration(); - isPopped = true; - } - } - - boolean hasAmbient = traceManager.hasAmbientSpan(); - boolean useOverride = isPopped && !hasAmbient; - - EventData.Builder eventDataBuilder = EventData.builder(); - if (!duration.isZero()) { - eventDataBuilder.setLatency(duration); - } - if (ttft != null) { - eventDataBuilder.setTimeToFirstToken(ttft); - } - llmResponse.modelVersion().ifPresent(eventDataBuilder::setModelVersion); - - if (!usageDict.isEmpty()) { - eventDataBuilder.setUsageMetadata(usageDict); - } - - if (useOverride) { - if (spanId.isPresent()) { - eventDataBuilder.setSpanIdOverride(spanId.get()); - } - if (parentSpanId != null) { - eventDataBuilder.setParentSpanIdOverride(parentSpanId); - } - } - - logEvent( - "LLM_RESPONSE", - invocationContext, - contentMap.isEmpty() ? null : contentMap, - parsedContent.isTruncated(), - Optional.of(eventDataBuilder.build())); - }); + return logEvent( + "LLM_RESPONSE", + invocationContext, + llmResponse, + false, + Optional.of(eventDataBuilder.build())) + .andThen(Maybe.empty()); } @Override public Maybe onModelErrorCallback( CallbackContext callbackContext, LlmRequest.Builder llmRequest, Throwable error) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(callbackContext.invocationContext().invocationId())) { - return; - } - TraceManager traceManager = - state.getTraceManager(callbackContext.invocationContext().invocationId()); - InvocationContext invocationContext = callbackContext.invocationContext(); - Optional popped = traceManager.popSpan(); - String spanId = popped.map(RecordData::spanId).orElse(null); - - SpanIds spanIds = traceManager.getCurrentSpanAndParent(); - String parentSpanId = spanIds.spanId().orElse(null); - - boolean hasAmbient = traceManager.hasAmbientSpan(); - EventData.Builder eventDataBuilder = - EventData.builder().setStatus("ERROR").setErrorMessage(error.getMessage()); - if (popped.isPresent()) { - eventDataBuilder.setLatency(popped.get().duration()); - } - if (!hasAmbient) { - if (spanId != null) { - eventDataBuilder.setSpanIdOverride(spanId); - } - if (parentSpanId != null) { - eventDataBuilder.setParentSpanIdOverride(parentSpanId); - } - } - logEvent("LLM_ERROR", invocationContext, null, Optional.of(eventDataBuilder.build())); - }); + if (state.isProcessed(callbackContext.invocationContext().invocationId())) { + return Maybe.empty(); + } + TraceManager traceManager = + state.getTraceManager(callbackContext.invocationContext().invocationId()); + InvocationContext invocationContext = callbackContext.invocationContext(); + Optional popped = traceManager.popSpan(); + String spanId = popped.map(RecordData::spanId).orElse(null); + + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + String parentSpanId = spanIds.spanId().orElse(null); + + boolean hasAmbient = traceManager.hasAmbientSpan(); + EventData.Builder eventDataBuilder = + EventData.builder().setStatus("ERROR").setErrorMessage(error.getMessage()); + if (popped.isPresent()) { + eventDataBuilder.setLatency(popped.get().duration()); + } + if (!hasAmbient) { + if (spanId != null) { + eventDataBuilder.setSpanIdOverride(spanId); + } + if (parentSpanId != null) { + eventDataBuilder.setParentSpanIdOverride(parentSpanId); + } + } + return logEvent("LLM_ERROR", invocationContext, null, Optional.of(eventDataBuilder.build())) + .andThen(Maybe.empty()); } @Override public Maybe> beforeToolCallback( BaseTool tool, Map toolArgs, ToolContext toolContext) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(toolContext.invocationContext().invocationId())) { - return; - } - TruncationResult res = smartTruncate(toolArgs, config.maxContentLength()); - ImmutableMap contentMap = - ImmutableMap.of( - "tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", res.node()); - state.getTraceManager(toolContext.invocationContext().invocationId()).pushSpan("tool"); - logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty()); - }); + if (state.isProcessed(toolContext.invocationContext().invocationId())) { + return Maybe.empty(); + } + ImmutableMap contentMap = + ImmutableMap.of("tool_origin", getToolOrigin(tool), "tool", tool.name(), "args", toolArgs); + state.getTraceManager(toolContext.invocationContext().invocationId()).pushSpan("tool"); + return logEvent("TOOL_STARTING", toolContext.invocationContext(), contentMap, Optional.empty()) + .andThen(Maybe.empty()); } @Override @@ -768,86 +755,87 @@ public Maybe> afterToolCallback( Map toolArgs, ToolContext toolContext, Map result) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(toolContext.invocationContext().invocationId())) { - return; - } - state - .getTraceManager(toolContext.invocationContext().invocationId()) - .ensureInvocationSpan(toolContext.invocationContext()); - TraceManager traceManager = - state.getTraceManager(toolContext.invocationContext().invocationId()); - Optional popped = traceManager.popSpan(); - TruncationResult truncationResult = smartTruncate(result, config.maxContentLength()); - ImmutableMap contentMap = - ImmutableMap.of( - "tool", - tool.name(), - "result", - truncationResult.node(), - "tool_origin", - getToolOrigin(tool)); - - SpanIds spanIds = traceManager.getCurrentSpanAndParent(); - boolean hasAmbient = traceManager.hasAmbientSpan(); - - EventData.Builder eventDataBuilder = EventData.builder(); - if (popped.isPresent()) { - eventDataBuilder.setLatency(popped.get().duration()); - } - if (!hasAmbient) { - popped.ifPresent(p -> eventDataBuilder.setSpanIdOverride(p.spanId())); - spanIds.spanId().ifPresent(eventDataBuilder::setParentSpanIdOverride); - } - - logEvent( - "TOOL_COMPLETED", - toolContext.invocationContext(), - contentMap, - truncationResult.isTruncated(), - Optional.of(eventDataBuilder.build())); - }); + if (state.isProcessed(toolContext.invocationContext().invocationId())) { + return Maybe.empty(); + } + state + .getTraceManager(toolContext.invocationContext().invocationId()) + .ensureInvocationSpan(toolContext.invocationContext()); + TraceManager traceManager = + state.getTraceManager(toolContext.invocationContext().invocationId()); + Optional popped = traceManager.popSpan(); + TruncationResult truncationResult = smartTruncate(result, config.maxContentLength()); + ImmutableMap contentMap = + ImmutableMap.of( + "tool", + tool.name(), + "result", + truncationResult.node(), + "tool_origin", + getToolOrigin(tool)); + + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + boolean hasAmbient = traceManager.hasAmbientSpan(); + + EventData.Builder eventDataBuilder = EventData.builder(); + if (popped.isPresent()) { + eventDataBuilder.setLatency(popped.get().duration()); + } + if (!hasAmbient) { + popped.ifPresent(p -> eventDataBuilder.setSpanIdOverride(p.spanId())); + spanIds.spanId().ifPresent(eventDataBuilder::setParentSpanIdOverride); + } + + return logEvent( + "TOOL_COMPLETED", + toolContext.invocationContext(), + contentMap, + truncationResult.isTruncated(), + Optional.of(eventDataBuilder.build())) + .andThen(Maybe.empty()); } @Override public Maybe> onToolErrorCallback( BaseTool tool, Map toolArgs, ToolContext toolContext, Throwable error) { - return Maybe.fromAction( - () -> { - if (state.isProcessed(toolContext.invocationContext().invocationId())) { - return; - } - TraceManager traceManager = - state.getTraceManager(toolContext.invocationContext().invocationId()); - Optional popped = traceManager.popSpan(); - TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength()); - String toolOrigin = getToolOrigin(tool); - ImmutableMap contentMap = - ImmutableMap.of( - "tool", tool.name(), "args", truncationResult.node(), "tool_origin", toolOrigin); - - SpanIds spanIds = traceManager.getCurrentSpanAndParent(); - boolean hasAmbient = traceManager.hasAmbientSpan(); - - EventData.Builder eventDataBuilder = - EventData.builder().setStatus("ERROR").setErrorMessage(error.getMessage()); - - if (popped.isPresent()) { - eventDataBuilder.setLatency(popped.get().duration()); - } - if (!hasAmbient) { - popped.ifPresent(p -> eventDataBuilder.setSpanIdOverride(p.spanId())); - spanIds.spanId().ifPresent(eventDataBuilder::setParentSpanIdOverride); - } - - logEvent( - "TOOL_ERROR", - toolContext.invocationContext(), - contentMap, - truncationResult.isTruncated(), - Optional.of(eventDataBuilder.build())); - }); + if (state.isProcessed(toolContext.invocationContext().invocationId())) { + return Maybe.empty(); + } + state + .getTraceManager(toolContext.invocationContext().invocationId()) + .ensureInvocationSpan(toolContext.invocationContext()); + TraceManager traceManager = + state.getTraceManager(toolContext.invocationContext().invocationId()); + Optional popped = traceManager.popSpan(); + + TruncationResult truncationResult = smartTruncate(toolArgs, config.maxContentLength()); + ImmutableMap contentMap = + ImmutableMap.builder() + .put("tool", tool.name()) + .put("args", truncationResult.node()) + .put("tool_origin", getToolOrigin(tool)) + .buildOrThrow(); + + SpanIds spanIds = traceManager.getCurrentSpanAndParent(); + boolean hasAmbient = traceManager.hasAmbientSpan(); + + EventData.Builder eventDataBuilder = + EventData.builder().setStatus("ERROR").setErrorMessage(error.getMessage()); + if (popped.isPresent()) { + eventDataBuilder.setLatency(popped.get().duration()); + } + if (!hasAmbient) { + popped.ifPresent(p -> eventDataBuilder.setSpanIdOverride(p.spanId())); + spanIds.spanId().ifPresent(eventDataBuilder::setParentSpanIdOverride); + } + + return logEvent( + "TOOL_ERROR", + toolContext.invocationContext(), + contentMap, + truncationResult.isTruncated(), + Optional.of(eventDataBuilder.build())) + .andThen(Maybe.empty()); } private String getToolOrigin(BaseTool tool) { diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java index b35e7c51d..92a35b7d7 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/BigQueryLoggerConfig.java @@ -59,7 +59,6 @@ public abstract class BigQueryLoggerConfig { public abstract ImmutableList clusteringFields(); // Whether to log multi-modal content. - // TODO(b/491852782): Implement logging of multi-modal content. public abstract boolean logMultiModalContent(); // Retry configuration for BigQuery writes. @@ -96,7 +95,7 @@ public abstract class BigQueryLoggerConfig { // GCS bucket name to store multi-modal content. public abstract String gcsBucketName(); - // TODO(b/491852782): Implement connection id. + // Optional BigQuery connection ID for ObjectRef columns public abstract Optional connectionId(); // Toggle for session metadata (e.g. gchat thread-id). @@ -118,8 +117,7 @@ public abstract class BigQueryLoggerConfig { // Default "v" produces views like ``v_llm_request``. public abstract String viewPrefix(); - @Nullable - public abstract Credentials credentials(); + public abstract @Nullable Credentials credentials(); public abstract Builder toBuilder(); diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java index 26f436f29..34430b861 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/JsonFormatter.java @@ -16,29 +16,20 @@ package com.google.adk.plugins.agentanalytics; -import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; -import com.google.adk.models.LlmRequest; import com.google.auto.value.AutoValue; -import com.google.common.collect.ImmutableList; -import com.google.genai.types.Blob; -import com.google.genai.types.Content; -import com.google.genai.types.FileData; -import com.google.genai.types.FunctionCall; -import com.google.genai.types.Part; -import java.util.ArrayList; -import java.util.List; +import com.google.common.base.Utf8; import java.util.Map; -import java.util.Optional; import java.util.Set; import org.jspecify.annotations.Nullable; /** Utility for parsing, formatting and truncating content for BigQuery logging. */ final class JsonFormatter { - private static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); + static final ObjectMapper mapper = new ObjectMapper().findAndRegisterModules(); + static final String TRUNCATION_SUFFIX = "...[truncated]"; @AutoValue abstract static class TruncationResult { @@ -51,254 +42,6 @@ static TruncationResult create(JsonNode node, boolean isTruncated) { } } - @AutoValue - abstract static class ParsedContent { - abstract ImmutableList parts(); - - abstract JsonNode content(); - - abstract boolean isTruncated(); - - static ParsedContent create( - ImmutableList parts, JsonNode content, boolean isTruncated) { - return new AutoValue_JsonFormatter_ParsedContent(parts, content, isTruncated); - } - } - - @AutoValue - abstract static class ParsedContentObject { - abstract ArrayNode parts(); - - abstract String summary(); - - abstract boolean isTruncated(); - - static ParsedContentObject create(ArrayNode parts, String summary, boolean isTruncated) { - return new AutoValue_JsonFormatter_ParsedContentObject(parts, summary, isTruncated); - } - } - - @AutoValue - abstract static class ContentPart { - @JsonProperty("part_index") - abstract int partIndex(); - - @JsonProperty("mime_type") - abstract @Nullable String mimeType(); - - @JsonProperty("uri") - abstract @Nullable String uri(); - - @JsonProperty("text") - abstract @Nullable String text(); - - @JsonProperty("part_attributes") - abstract String partAttributes(); - - @JsonProperty("storage_mode") - abstract String storageMode(); - - @JsonProperty("object_ref") - abstract @Nullable String objectRef(); - - static Builder builder() { - return new AutoValue_JsonFormatter_ContentPart.Builder(); - } - - @AutoValue.Builder - abstract static class Builder { - abstract Builder setPartIndex(int value); - - abstract Builder setMimeType(@Nullable String value); - - abstract Builder setUri(@Nullable String value); - - abstract Builder setText(@Nullable String value); - - abstract Builder setPartAttributes(String value); - - abstract Builder setStorageMode(String value); - - abstract Builder setObjectRef(@Nullable String value); - - abstract ContentPart build(); - } - } - - /** - * Parses content into JSON payload and content parts, matching Python implementation. - * - * @param content the content to parse - * @param maxLength the maximum length for text fields - * @return a ParsedContent object - */ - static ParsedContent parse(Object content, int maxLength) { - JsonNode contentNode = mapper.nullNode(); - ArrayNode contentParts = mapper.createArrayNode(); - boolean isTruncated = false; - - if (content instanceof LlmRequest llmRequest) { - ObjectNode jsonPayload = mapper.createObjectNode(); - // Handle prompt - ArrayNode messages = mapper.createArrayNode(); - List contents = llmRequest.contents(); - for (Content c : contents) { - String role = c.role().orElse("unknown"); - ParsedContentObject parsedContentObject = parseContentObject(c, maxLength); - isTruncated = isTruncated || parsedContentObject.isTruncated(); - contentParts.addAll(parsedContentObject.parts()); - - ObjectNode message = mapper.createObjectNode(); - message.put("role", role); - message.put("content", parsedContentObject.summary()); - messages.add(message); - } - if (!messages.isEmpty()) { - jsonPayload.set("prompt", messages); - } - // Handle system instruction - if (llmRequest.config().isPresent() - && llmRequest.config().get().systemInstruction().isPresent()) { - Content systemInstruction = llmRequest.config().get().systemInstruction().get(); - ParsedContentObject parsedSystemInstruction = - parseContentObject(systemInstruction, maxLength); - isTruncated = isTruncated || parsedSystemInstruction.isTruncated(); - contentParts.addAll(parsedSystemInstruction.parts()); - jsonPayload.put("system_prompt", parsedSystemInstruction.summary()); - } - contentNode = jsonPayload; - } else if (content instanceof Content || content instanceof Part) { - ParsedContentObject parsedContentObject = parseContentObject(content, maxLength); - ObjectNode summaryNode = mapper.createObjectNode(); - summaryNode.put("text_summary", parsedContentObject.summary()); - return ParsedContent.create( - ImmutableList.copyOf(parsedContentObject.parts()), - summaryNode, - parsedContentObject.isTruncated()); - } else if (content instanceof String s) { - TruncationResult result = truncateWithStatus(s, maxLength); - contentNode = result.node(); - isTruncated = result.isTruncated(); - } else { - TruncationResult result = smartTruncate(content, maxLength); - contentNode = result.node(); - isTruncated = result.isTruncated(); - } - return ParsedContent.create(ImmutableList.copyOf(contentParts), contentNode, isTruncated); - } - - /** - * Parses a Content or Part object into summary text and content parts. - * - * @param content the Content or Part object to parse - * @param maxLength the maximum length of text fields before truncation - * @return a ParsedContentObject containing parts, summary, and truncation flag - */ - private static ParsedContentObject parseContentObject(Object content, int maxLength) { - ArrayNode contentParts = mapper.createArrayNode(); - boolean isTruncated = false; - List summaryText = new ArrayList<>(); - - List parts; - if (content instanceof Content c) { - parts = c.parts().orElse(ImmutableList.of()); - } else if (content instanceof Part p) { - parts = ImmutableList.of(p); - } else { - return ParsedContentObject.create(contentParts, "", false); - } - - for (int i = 0; i < parts.size(); i++) { - Part part = parts.get(i); - ContentPart.Builder partBuilder = - ContentPart.builder() - .setPartIndex(i) - .setMimeType("text/plain") - .setUri(null) - .setText(null) - .setPartAttributes("{}") - .setStorageMode("INLINE") - .setObjectRef(null); - - // CASE A: It is already a URI (e.g. from user input) - if (part.fileData().isPresent()) { - FileData fileData = part.fileData().get(); - partBuilder - .setStorageMode("EXTERNAL_URI") - .setUri(fileData.fileUri().orElse(null)) - .setMimeType(fileData.mimeType().orElse(null)); - } - // CASE B: It is Binary/Inline Data (Image/Blob) - else if (part.inlineData().isPresent()) { - // TODO: (b/485571635) Implement GCS offloading here. - partBuilder - .setText("[BINARY DATA]") - .setMimeType(part.inlineData().get().mimeType().orElse("")); - } - // CASE C: Text - else if (part.text().isPresent()) { - String text = part.text().get(); - // TODO: (b/485571635) Implement GCS offloading if text length exceeds maxLength. - if (text.length() > maxLength) { - text = truncate(text, maxLength); - isTruncated = true; - } - partBuilder.setText(text); - summaryText.add(text); - } else if (part.functionCall().isPresent()) { - FunctionCall fc = part.functionCall().get(); - ObjectNode partAttributes = mapper.createObjectNode(); - partAttributes.put("function_name", fc.name().orElse("unknown")); - partBuilder - .setMimeType("application/json") - .setText("Function: " + fc.name().orElse("unknown")) - .setPartAttributes(partAttributes.toString()); - } - contentParts.add(mapper.valueToTree(partBuilder.build())); - } - - String summaryResult = String.join(" | ", summaryText); - if (summaryResult.length() > maxLength) { - summaryResult = truncate(summaryResult, maxLength); - isTruncated = true; - } - - return ParsedContentObject.create(contentParts, summaryResult, isTruncated); - } - - /** Formats Content parts into an ArrayNode for BigQuery logging. */ - static ArrayNode formatContentParts(Optional content, int maxLength) { - ArrayNode partsArray = mapper.createArrayNode(); - if (content.isEmpty() || content.get().parts() == null) { - return partsArray; - } - - List parts = content.get().parts().orElse(ImmutableList.of()); - - for (int i = 0; i < parts.size(); i++) { - Part part = parts.get(i); - ObjectNode partObj = mapper.createObjectNode(); - partObj.put("part_index", i); - partObj.put("storage_mode", "INLINE"); - - if (part.text().isPresent()) { - partObj.put("mime_type", "text/plain"); - partObj.put("text", truncate(part.text().get(), maxLength)); - } else if (part.inlineData().isPresent()) { - Blob blob = part.inlineData().get(); - partObj.put("mime_type", blob.mimeType().orElse("")); - partObj.put("text", "[BINARY DATA]"); - } else if (part.fileData().isPresent()) { - FileData fileData = part.fileData().get(); - partObj.put("mime_type", fileData.mimeType().orElse("")); - partObj.put("uri", fileData.fileUri().orElse("")); - partObj.put("storage_mode", "EXTERNAL_URI"); - } - partsArray.add(partObj); - } - return partsArray; - } - /** Recursively truncates long strings inside an object and returns a TruncationResult. */ static TruncationResult smartTruncate(Object obj, int maxLength) { if (obj == null) { @@ -328,7 +71,7 @@ private static TruncationResult recursiveSmartTruncate(JsonNode node, int maxLen boolean isTruncated = false; if (node.isTextual()) { String text = node.asText(); - if (text.length() > maxLength) { + if (Utf8.encodedLength(text) > maxLength) { return TruncationResult.create(mapper.valueToTree(truncate(text, maxLength)), true); } return TruncationResult.create(node, false); @@ -353,21 +96,59 @@ private static TruncationResult recursiveSmartTruncate(JsonNode node, int maxLen return TruncationResult.create(node, false); } - private static TruncationResult truncateWithStatus(String s, int maxLength) { + static TruncationResult truncateWithStatus(String s, int maxLength) { if (s == null) { return TruncationResult.create(mapper.nullNode(), false); } - if (s.length() <= maxLength) { + if (Utf8.encodedLength(s) <= maxLength) { return TruncationResult.create(mapper.valueToTree(s), false); } return TruncationResult.create(mapper.valueToTree(truncate(s, maxLength)), true); } - private static String truncate(String s, int maxLength) { - if (s == null || s.length() <= maxLength) { + static @Nullable String truncate(String s, int budget) { + return truncateAndAddSuffix(s, budget, TRUNCATION_SUFFIX); + } + + static @Nullable String truncateAndAddSuffix(String s, int budget, String suffix) { + if (s == null) { + return null; + } + if (Utf8.encodedLength(s) <= budget) { return s; } - return s.substring(0, maxLength) + "...[truncated]"; + int suffixBytes = Utf8.encodedLength(suffix); + int effectiveBudget = Math.max(0, budget - suffixBytes); + // Fallback in case the budget is too small + if (effectiveBudget == 0) { + return suffix.substring(0, budget); + } + + int byteCount = 0; + int charIndex = 0; + for (int i = 0; i < s.length(); ) { + int codePoint = s.codePointAt(i); + int codePointLen = Character.charCount(codePoint); + int codePointBytes; + if (codePoint < 0x80) { + codePointBytes = 1; + } else if (codePoint < 0x800) { + codePointBytes = 2; + } else if (codePoint < 0x10000) { + codePointBytes = 3; + } else { + codePointBytes = 4; + } + + if (byteCount + codePointBytes > effectiveBudget) { + break; + } + byteCount += codePointBytes; + charIndex += codePointLen; + i += codePointLen; + } + + return s.substring(0, charIndex) + suffix; } /** Converts a JsonNode to a standard Java object (Map, List, etc.). */ diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java new file mode 100644 index 000000000..5db8be46c --- /dev/null +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/Parser.java @@ -0,0 +1,382 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static com.google.adk.plugins.agentanalytics.JsonFormatter.mapper; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.smartTruncate; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.truncate; +import static com.google.adk.plugins.agentanalytics.JsonFormatter.truncateWithStatus; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.adk.plugins.agentanalytics.JsonFormatter.TruncationResult; +import com.google.auto.value.AutoValue; +import com.google.common.base.Utf8; +import com.google.common.collect.ImmutableList; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.Part; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import org.jspecify.annotations.Nullable; + +/** Utility for parsing content for BigQuery logging. */ +final class Parser { + private static final String BINARY_DATA_MESSAGE = "[BINARY DATA]"; + private final int maxLength; + + Parser(int maxLength) { + this.maxLength = maxLength; + } + + @AutoValue + abstract static class ParsedContent { + abstract ImmutableList parts(); + + abstract JsonNode content(); + + abstract boolean isTruncated(); + + static ParsedContent create( + ImmutableList parts, JsonNode content, boolean isTruncated) { + return new AutoValue_Parser_ParsedContent(parts, content, isTruncated); + } + } + + @AutoValue + abstract static class ParsedContentObject { + abstract ArrayNode parts(); + + abstract String summary(); + + abstract boolean isTruncated(); + + static ParsedContentObject create(ArrayNode parts, String summary, boolean isTruncated) { + return new AutoValue_Parser_ParsedContentObject(parts, summary, isTruncated); + } + } + + @AutoValue + abstract static class ContentPart { + @JsonProperty("part_index") + abstract int partIndex(); + + @JsonProperty("mime_type") + abstract @Nullable String mimeType(); + + @JsonProperty("uri") + abstract @Nullable String uri(); + + @JsonProperty("text") + abstract @Nullable String text(); + + @JsonProperty("part_attributes") + abstract String partAttributes(); + + @JsonProperty("storage_mode") + abstract String storageMode(); + + @JsonProperty("object_ref") + abstract @Nullable JsonNode objectRef(); + + static Builder builder() { + return new AutoValue_Parser_ContentPart.Builder(); + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setPartIndex(int value); + + abstract Builder setMimeType(@Nullable String value); + + abstract Builder setUri(@Nullable String value); + + abstract Builder setText(@Nullable String value); + + abstract Builder setPartAttributes(String value); + + abstract Builder setStorageMode(String value); + + abstract Builder setObjectRef(@Nullable JsonNode value); + + abstract ContentPart build(); + } + } + + @AutoValue + abstract static class ObjectRef { + @JsonProperty("uri") + abstract @Nullable String uri(); + + @JsonProperty("version") + abstract @Nullable String version(); + + @JsonProperty("authorizer") + abstract @Nullable String authorizer(); + + @JsonProperty("details") + abstract @Nullable JsonNode details(); + + static ObjectRef create( + @Nullable String uri, + @Nullable String version, + @Nullable String authorizer, + @Nullable JsonNode details) { + return new AutoValue_Parser_ObjectRef(uri, version, authorizer, details); + } + } + + /** + * Parses content into JSON payload and content parts, matching Python implementation. + * + * @param content the content to parse + * @return a CompletableFuture of ParsedContent object + */ + CompletableFuture parse(Object content) { + if (content instanceof LlmRequest llmRequest) { + ObjectNode jsonPayload = mapper.createObjectNode(); + ArrayNode messages = mapper.createArrayNode(); + List> futures = new ArrayList<>(); + List contents = llmRequest.contents(); + + for (Content c : contents) { + futures.add(parseContentObject(c)); + } + + CompletableFuture systemFuture = null; + if (llmRequest.config().isPresent() + && llmRequest.config().get().systemInstruction().isPresent()) { + systemFuture = parseContentObject(llmRequest.config().get().systemInstruction().get()); + futures.add(systemFuture); + } + CompletableFuture finalSystemFuture = systemFuture; + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])) + .thenApply( + v -> { + boolean isTruncated = false; + ArrayNode contentParts = mapper.createArrayNode(); + for (int i = 0; i < contents.size(); i++) { + ParsedContentObject res = futures.get(i).join(); + isTruncated = isTruncated || res.isTruncated(); + contentParts.addAll(res.parts()); + + ObjectNode message = mapper.createObjectNode(); + message.put("role", contents.get(i).role().orElse("unknown")); + message.put("content", res.summary()); + messages.add(message); + } + if (!messages.isEmpty()) { + jsonPayload.set("prompt", messages); + } + if (finalSystemFuture != null) { + ParsedContentObject res = finalSystemFuture.join(); + isTruncated = isTruncated || res.isTruncated(); + contentParts.addAll(res.parts()); + jsonPayload.put("system_prompt", res.summary()); + } + return ParsedContent.create( + ImmutableList.copyOf(contentParts), jsonPayload, isTruncated); + }); + } + if (content instanceof LlmResponse llmResponse) { + ObjectNode jsonPayload = mapper.createObjectNode(); + return parseContentObject(llmResponse.content().orElse(null)) + .thenApply( + parsed -> { + ObjectNode summaryNode = mapper.createObjectNode(); + summaryNode.put("text_summary", parsed.summary()); + jsonPayload.set("response", summaryNode); + llmResponse + .usageMetadata() + .ifPresent( + usage -> { + ObjectNode usageNode = jsonPayload.putObject("usage"); + usage.promptTokenCount().ifPresent(c -> usageNode.put("prompt", c)); + usage + .candidatesTokenCount() + .ifPresent(c -> usageNode.put("completion", c)); + usage.totalTokenCount().ifPresent(c -> usageNode.put("total", c)); + }); + + return ParsedContent.create( + ImmutableList.copyOf(parsed.parts()), jsonPayload, parsed.isTruncated()); + }); + } + if (content instanceof Content || content instanceof Part) { + return parseContentObject(content) + .thenApply( + parsed -> { + ObjectNode summaryNode = mapper.createObjectNode(); + summaryNode.put("text_summary", parsed.summary()); + return ParsedContent.create( + ImmutableList.copyOf(parsed.parts()), summaryNode, parsed.isTruncated()); + }); + } + // Fallback for types that don't support multi-part content + TruncationResult result; + if (content instanceof String s) { + result = truncateWithStatus(s, maxLength); + } else { + result = smartTruncate(content, maxLength); + } + return CompletableFuture.completedFuture( + ParsedContent.create(ImmutableList.of(), result.node(), result.isTruncated())); + } + + /** + * Parses a Content or Part object into summary text and content parts. + * + * @param content the Content or Part object to parse + * @return a CompletableFuture of ParsedContentObject containing parts, summary, and truncation + * flag + */ + private CompletableFuture parseContentObject(Object content) { + List parts; + if (content instanceof Content c) { + parts = c.parts().orElse(ImmutableList.of()); + } else if (content instanceof Part p) { + parts = ImmutableList.of(p); + } else { + return CompletableFuture.completedFuture( + ParsedContentObject.create(mapper.createArrayNode(), "", false)); + } + + List> partFutures = new ArrayList<>(); + for (int i = 0; i < parts.size(); i++) { + partFutures.add(processPart(parts.get(i), i)); + } + + return CompletableFuture.allOf(partFutures.toArray(new CompletableFuture[0])) + .thenApply( + v -> { + ArrayNode contentParts = mapper.createArrayNode(); + List summaries = new ArrayList<>(); + boolean isTruncated = false; + + for (CompletableFuture future : partFutures) { + TruncationResult res = future.join(); + contentParts.add(res.node()); + isTruncated = isTruncated || res.isTruncated(); + JsonNode textNode = res.node().get("text"); + if (textNode != null && !textNode.isNull()) { + summaries.add(textNode.asText()); + } + } + + String summary = String.join(" | ", summaries); + if (Utf8.encodedLength(summary) > maxLength) { + summary = truncate(summary, maxLength); + isTruncated = true; + } + + return ParsedContentObject.create(contentParts, summary, isTruncated); + }); + } + + private CompletableFuture processPart(Part part, int index) { + ContentPart.Builder partBuilder = + ContentPart.builder() + .setPartIndex(index) + .setMimeType("text/plain") + .setUri(null) + .setText(null) + .setPartAttributes("{}") + .setStorageMode("INLINE") + .setObjectRef(null); + + // CASE A: It is already a URI (e.g. from user input) + if (part.fileData().isPresent()) { + FileData fileData = part.fileData().get(); + partBuilder + .setStorageMode("EXTERNAL_URI") + .setUri(fileData.fileUri().orElse(null)) + .setMimeType(fileData.mimeType().orElse(null)); + return CompletableFuture.completedFuture( + TruncationResult.create(mapper.valueToTree(partBuilder.build()), false)); + } + // CASE B: It is Binary/Inline Data (Image/Blob) + if (part.inlineData().isPresent()) { + Blob blob = part.inlineData().get(); + String mimeType = blob.mimeType().orElse("application/octet-stream"); + partBuilder.setText(BINARY_DATA_MESSAGE).setMimeType(mimeType); + return CompletableFuture.completedFuture( + TruncationResult.create(mapper.valueToTree(partBuilder.build()), false)); + } + // CASE C: Text + if (part.text().isPresent()) { + String text = part.text().get(); + TruncationResult res = truncateWithStatus(text, maxLength); + partBuilder.setText(res.node().asText()); + return CompletableFuture.completedFuture( + TruncationResult.create(mapper.valueToTree(partBuilder.build()), res.isTruncated())); + } + if (part.functionCall().isPresent()) { + FunctionCall fc = part.functionCall().get(); + ObjectNode partAttributes = mapper.createObjectNode(); + partAttributes.put("function_name", fc.name().orElse("unknown")); + partBuilder + .setMimeType("application/json") + .setText("Function: " + fc.name().orElse("unknown")) + .setPartAttributes(partAttributes.toString()); + return CompletableFuture.completedFuture( + TruncationResult.create(mapper.valueToTree(partBuilder.build()), false)); + } + return CompletableFuture.completedFuture( + TruncationResult.create(mapper.valueToTree(partBuilder.build()), false)); + } + + /** Formats Content parts into an ArrayNode for BigQuery logging. */ + ArrayNode formatContentParts(Optional content) { + ArrayNode partsArray = mapper.createArrayNode(); + if (content.isEmpty()) { + return partsArray; + } + + List parts = content.get().parts().orElse(ImmutableList.of()); + + for (int i = 0; i < parts.size(); i++) { + Part part = parts.get(i); + ObjectNode partObj = mapper.createObjectNode(); + partObj.put("part_index", i); + partObj.put("storage_mode", "INLINE"); + + if (part.text().isPresent()) { + partObj.put("mime_type", "text/plain"); + partObj.put("text", truncate(part.text().get(), maxLength)); + } else if (part.inlineData().isPresent()) { + Blob blob = part.inlineData().get(); + partObj.put("mime_type", blob.mimeType().orElse("")); + partObj.put("text", BINARY_DATA_MESSAGE); + } else if (part.fileData().isPresent()) { + FileData fileData = part.fileData().get(); + partObj.put("mime_type", fileData.mimeType().orElse("")); + partObj.put("uri", fileData.fileUri().orElse("")); + partObj.put("storage_mode", "EXTERNAL_URI"); + } + partsArray.add(partObj); + } + return partsArray; + } +} diff --git a/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java index 63c60c491..0654fab5d 100644 --- a/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java +++ b/core/src/main/java/com/google/adk/plugins/agentanalytics/PluginState.java @@ -1,6 +1,7 @@ package com.google.adk.plugins.agentanalytics; import static com.google.adk.plugins.agentanalytics.BigQueryUtils.getVersionHeaderValue; +import static com.google.common.collect.ImmutableList.toImmutableList; import static java.util.concurrent.TimeUnit.MILLISECONDS; import com.google.api.gax.core.FixedCredentialsProvider; @@ -13,15 +14,20 @@ import com.google.common.base.VerifyException; import com.google.common.cache.Cache; import com.google.common.cache.CacheBuilder; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.reactivex.rxjava3.core.Completable; import java.io.IOException; import java.util.Collection; +import java.util.Set; import java.util.UUID; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; -import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Level; import java.util.logging.Logger; import org.threeten.bp.Duration; @@ -39,12 +45,15 @@ class PluginState { private final ConcurrentHashMap traceManagers = new ConcurrentHashMap<>(); // Cache of invocation ID to Boolean indicating invocation ID has been processed. private final Cache processedInvocations; + private final Parser parser; + private final ConcurrentHashMap>> pendingTasks = + new ConcurrentHashMap<>(); PluginState(BigQueryLoggerConfig config) throws IOException { this.config = config; - ThreadFactory threadFactory = - r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement()); - this.executor = Executors.newScheduledThreadPool(1, threadFactory); + this.executor = + Executors.newScheduledThreadPool( + 2, r -> new Thread(r, "bq-analytics-plugin-" + threadCounter.getAndIncrement())); // One write client per plugin instance, shared by all invocations. this.writeClient = createWriteClient(config); this.processedInvocations = @@ -52,6 +61,7 @@ class PluginState { .maximumSize(10000) .expireAfterWrite(java.time.Duration.ofMinutes(10)) .build(); + this.parser = new Parser(config.maxContentLength()); } ScheduledExecutorService getExecutor() { @@ -132,6 +142,10 @@ BatchProcessor getBatchProcessor(String invocationId) { }); } + Parser getParser() { + return parser; + } + @VisibleForTesting Collection getTraceManagers() { return traceManagers.values(); @@ -160,27 +174,102 @@ void clearBatchProcessors() { batchProcessors.clear(); } - void close() { - for (BatchProcessor processor : getBatchProcessors()) { - processor.close(); - } - for (TraceManager traceManager : getTraceManagers()) { - traceManager.clearStack(); - } - clearBatchProcessors(); - clearTraceManagers(); + private Set> getPendingTasksForInvocation(String invocationId) { + return pendingTasks.computeIfAbsent(invocationId, k -> ConcurrentHashMap.newKeySet()); + } - if (writeClient != null) { - writeClient.close(); + void addPendingTask(String invocationId, CompletableFuture task) { + Set> tasks = getPendingTasksForInvocation(invocationId); + tasks.add(task); + var unused = task.whenComplete((res, err) -> tasks.remove(task)); + } + + Completable ensureInvocationCompleted(String invocationId) { + Set> tasks = pendingTasks.get(invocationId); + Completable tasksState = Completable.complete(); + if (tasks != null && !tasks.isEmpty()) { + tasksState = + Completable.fromCompletionStage( + CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0]))); } - try { - executor.shutdown(); - if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { - executor.shutdownNow(); - } - } catch (InterruptedException e) { - executor.shutdownNow(); - Thread.currentThread().interrupt(); + logger.info("Waiting for pending tasks to complete for invocation ID: " + invocationId); + return tasksState + .timeout(config.shutdownTimeout().toMillis(), MILLISECONDS) + .doOnError( + e -> { + if (e instanceof TimeoutException) { + logger.log( + Level.WARNING, + "Timeout while waiting for pending tasks to complete for invocation ID: " + + invocationId, + e); + } + }) + .onErrorComplete() + .doFinally( + () -> { + // Mark invocation ID as processed to avoid memory leaks. + markProcessed(invocationId); + BatchProcessor processor = removeProcessor(invocationId); + if (processor != null) { + processor.flush(); + processor.close(); + } + TraceManager traceManager = removeTraceManager(invocationId); + if (traceManager != null) { + traceManager.clearStack(); + } + logger.info("Removing pending tasks for invocation ID: " + invocationId); + pendingTasks.remove(invocationId); + }); + } + + Completable close() { + ImmutableList> tasks = + pendingTasks.values().stream().flatMap(Set::stream).collect(toImmutableList()); + Completable tasksState = Completable.complete(); + if (tasks != null && !tasks.isEmpty()) { + tasksState = + Completable.fromCompletionStage( + CompletableFuture.allOf(tasks.toArray(new CompletableFuture[0]))); } + return tasksState + .timeout(config.shutdownTimeout().toMillis(), MILLISECONDS) + .doOnError( + e -> { + if (e instanceof TimeoutException) { + logger.log( + Level.WARNING, "Timeout while waiting for pending tasks to complete.", e); + } + }) + .onErrorComplete() + .doFinally( + () -> { + for (BatchProcessor processor : getBatchProcessors()) { + processor.close(); + } + for (TraceManager traceManager : getTraceManagers()) { + traceManager.clearStack(); + } + clearBatchProcessors(); + clearTraceManagers(); + + if (writeClient != null) { + try { + writeClient.close(); + } catch (RuntimeException e) { + logger.log(Level.WARNING, "Failed to close BigQueryWriteClient", e); + } + } + try { + executor.shutdown(); + if (!executor.awaitTermination(config.shutdownTimeout().toMillis(), MILLISECONDS)) { + executor.shutdownNow(); + } + } catch (InterruptedException e) { + executor.shutdownNow(); + Thread.currentThread().interrupt(); + } + }); } } diff --git a/core/src/main/java/com/google/adk/sessions/State.java b/core/src/main/java/com/google/adk/sessions/State.java index 577559f85..9a7042a72 100644 --- a/core/src/main/java/com/google/adk/sessions/State.java +++ b/core/src/main/java/com/google/adk/sessions/State.java @@ -21,6 +21,7 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -46,16 +47,24 @@ public State(Map state) { public State(Map state, @Nullable Map delta) { Objects.requireNonNull(state, "state is null"); - this.state = - state instanceof ConcurrentMap - ? (ConcurrentMap) state - : new ConcurrentHashMap<>(state); - this.delta = - delta == null - ? new ConcurrentHashMap<>() - : delta instanceof ConcurrentMap - ? (ConcurrentMap) delta - : new ConcurrentHashMap<>(delta); + this.state = toConcurrentMap(state); + this.delta = delta == null ? new ConcurrentHashMap<>() : toConcurrentMap(delta); + } + + /** + * Converts a map to a concurrent map. Null values are converted to {@link #REMOVED} to avoid + * NPEs. + * + *

If the map is already a concurrent map, it is returned as is. Otherwise, a new concurrent + * map is created and returned. + */ + private static ConcurrentMap toConcurrentMap(Map map) { + if (map instanceof ConcurrentMap) { + return (ConcurrentMap) map; + } + ConcurrentMap concurrentMap = new ConcurrentHashMap<>(); + map.forEach((key, value) -> concurrentMap.put(key, Optional.ofNullable(value).orElse(REMOVED))); + return concurrentMap; } @Override diff --git a/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java b/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java new file mode 100644 index 000000000..aca399f92 --- /dev/null +++ b/core/src/main/java/com/google/adk/skills/AbstractSkillSource.java @@ -0,0 +1,181 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.skills; + +import static java.nio.channels.Channels.newReader; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.ByteSource; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Single; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; + +/** + * Abstract base class for SkillSource implementations that load skills from path like object. + * + * @param the type of path object + */ +public abstract class AbstractSkillSource implements SkillSource { + + private static final String THREE_DASHES = "---"; + private static final ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory()); + + /** A container class that holds a skill's name and the path to its SKILL.md file. */ + protected final class SkillMdPath { + + private final String name; + private final PathT mdPath; + + /** + * Constructs a {@code SkillMdPath}. + * + * @param name the name of the skill + * @param mdPath the path to the SKILL.md file + */ + @SuppressWarnings("ProtectedMembersInFinalClass") + protected SkillMdPath(String name, PathT mdPath) { + this.name = name; + this.mdPath = mdPath; + } + } + + @Override + public Single> listFrontmatters() { + return listSkills() + .map(skillMdPath -> loadFrontmatter(skillMdPath.name, skillMdPath.mdPath)) + .collectInto( + ImmutableMap.builder(), + (builder, frontmatter) -> builder.put(frontmatter.name(), frontmatter)) + .map(ImmutableMap.Builder::buildOrThrow); + } + + @Override + public Single loadFrontmatter(String skillName) { + return findSkillMdPath(skillName).map(path -> loadFrontmatter(skillName, path)); + } + + private Frontmatter loadFrontmatter(String skillName, PathT skillMdPath) + throws SkillSourceException { + try (BufferedReader reader = openReader(skillMdPath)) { + String yaml = readFrontmatterYaml(reader); + Frontmatter frontmatter = yamlMapper.readValue(yaml, Frontmatter.class); + if (!frontmatter.name().equals(skillName)) { + throw new SkillSourceException( + "Skill name '%s' does not match directory name '%s'." + .formatted(frontmatter.name(), skillName)); + } + return frontmatter; + } catch (IOException e) { + throw new SkillSourceException("Cannot load frontmatter for skill '" + skillName + "'", e); + } + } + + @Override + public Single loadInstructions(String skillName) { + return findSkillMdPath(skillName) + .map( + skillMdPath -> { + try (BufferedReader reader = openReader(skillMdPath)) { + return readInstructions(reader); + } catch (IOException e) { + throw new SkillSourceException( + "Failed to load instruction for skill '" + skillName + "'", e); + } + }); + } + + @Override + public Single loadResource(String skillName, String resourcePath) { + return findResourcePath(skillName, resourcePath) + .map( + path -> + new ByteSource() { + @Override + public InputStream openStream() throws IOException { + return Channels.newInputStream(AbstractSkillSource.this.openChannel(path)); + } + }); + } + + /** + * Returns a {@link Flowable} of skills as a pair of skill name and the path to the SKILL.md file. + */ + protected abstract Flowable listSkills(); + + /** Returns the path to the SKILL.md file for the given skill. */ + protected abstract Single findSkillMdPath(String skillName); + + /** Returns the path to the resource for the given skill. */ + protected abstract Single findResourcePath(String skillName, String resourcePath); + + /** Opens a {@link InputStream} for reading the content of the given path. */ + protected abstract ReadableByteChannel openChannel(PathT path) throws IOException; + + private BufferedReader openReader(PathT path) throws IOException { + return new BufferedReader(newReader(openChannel(path), UTF_8)); + } + + private String readFrontmatterYaml(BufferedReader reader) + throws IOException, SkillSourceException { + String line = reader.readLine(); + if (line == null || !line.trim().equals(THREE_DASHES)) { + throw new SkillSourceException("Skill file must start with " + THREE_DASHES); + } + + StringBuilder sb = new StringBuilder(); + while ((line = reader.readLine()) != null) { + if (line.trim().equals(THREE_DASHES)) { + return sb.toString(); + } + sb.append(line).append("\n"); + } + throw new SkillSourceException( + "Skill file frontmatter not properly closed with " + THREE_DASHES); + } + + private String readInstructions(BufferedReader reader) throws IOException, SkillSourceException { + // Skip the frontmatter block + String line = reader.readLine(); + if (line == null || !line.trim().equals(THREE_DASHES)) { + throw new SkillSourceException("Skill file must start with " + THREE_DASHES); + } + boolean dashClosed = false; + while ((line = reader.readLine()) != null) { + if (line.trim().equals(THREE_DASHES)) { + dashClosed = true; + break; + } + } + if (!dashClosed) { + throw new SkillSourceException( + "Skill file frontmatter not properly closed with " + THREE_DASHES); + } + // Read the instructions till the end of the file + StringBuilder sb = new StringBuilder(); + while ((line = reader.readLine()) != null) { + sb.append(line).append("\n"); + } + return sb.toString().trim(); + } +} diff --git a/core/src/main/java/com/google/adk/skills/Frontmatter.java b/core/src/main/java/com/google/adk/skills/Frontmatter.java new file mode 100644 index 000000000..6f9b56e9e --- /dev/null +++ b/core/src/main/java/com/google/adk/skills/Frontmatter.java @@ -0,0 +1,146 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.skills; + +import com.fasterxml.jackson.annotation.JsonAlias; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.google.adk.JsonBaseModel; +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableMap; +import com.google.common.escape.Escaper; +import com.google.common.html.HtmlEscapers; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import java.util.Map; +import java.util.Optional; +import java.util.regex.Pattern; + +/** + * Frontmatter represents the YAML metadata at the top of a SKILL.md file. For more details, see + * https://agentskills.io/specification#frontmatter. + */ +@AutoValue +@JsonDeserialize(builder = Frontmatter.Builder.class) +@JsonIgnoreProperties(ignoreUnknown = true) +public abstract class Frontmatter extends JsonBaseModel { + + private static final Pattern NAME_PATTERN = Pattern.compile("^[a-z0-9]+(-[a-z0-9]+)*$"); + + /** Skill name in kebab-case. */ + @JsonProperty("name") + public abstract String name(); + + /** What the skill does and when the model should use it. */ + @JsonProperty("description") + public abstract String description(); + + /** License for the skill. */ + @JsonProperty("license") + public abstract Optional license(); + + /** Compatibility information for the skill. */ + @JsonProperty("compatibility") + public abstract Optional compatibility(); + + /** A space-delimited list of tools that are pre-approved to run. */ + @JsonProperty("allowed-tools") + public abstract Optional allowedTools(); + + /** Key-value pairs for client-specific properties. */ + @JsonProperty("metadata") + public abstract ImmutableMap metadata(); + + public String toXml() { + Escaper escaper = HtmlEscapers.htmlEscaper(); + return String.format( + """ + + + %s + + + %s + + + """, + escaper.escape(name()), escaper.escape(description())); + } + + public static Builder builder() { + return new AutoValue_Frontmatter.Builder().metadata(ImmutableMap.of()); + } + + @AutoValue.Builder + public abstract static class Builder { + + @JsonCreator + private static Builder create() { + return builder(); + } + + @CanIgnoreReturnValue + @JsonProperty("name") + public abstract Builder name(String name); + + @CanIgnoreReturnValue + @JsonProperty("description") + public abstract Builder description(String description); + + @CanIgnoreReturnValue + @JsonProperty("license") + public abstract Builder license(String license); + + @CanIgnoreReturnValue + @JsonProperty("compatibility") + public abstract Builder compatibility(String compatibility); + + @CanIgnoreReturnValue + @JsonProperty("allowed-tools") + @JsonAlias({"allowed_tools"}) + public abstract Builder allowedTools(String allowedTools); + + @CanIgnoreReturnValue + @JsonProperty("metadata") + public abstract Builder metadata(Map metadata); + + abstract Frontmatter autoBuild(); + + public Frontmatter build() { + Frontmatter fm = autoBuild(); + if (fm.name().length() > 64) { + throw new IllegalArgumentException("name must be at most 64 characters"); + } + if (!NAME_PATTERN.matcher(fm.name()).matches()) { + throw new IllegalArgumentException( + "name must be lowercase kebab-case (a-z, 0-9, hyphens), with no leading, trailing, or" + + " consecutive hyphens"); + } + if (fm.description().isEmpty()) { + throw new IllegalArgumentException("description must not be empty"); + } + if (fm.description().length() > 1024) { + throw new IllegalArgumentException("description must be at most 1024 characters"); + } + if (fm.compatibility().isPresent() && fm.compatibility().get().length() > 500) { + throw new IllegalArgumentException("compatibility must be at most 500 characters"); + } + return fm; + } + } +} diff --git a/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java b/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java new file mode 100644 index 000000000..42916e36a --- /dev/null +++ b/core/src/main/java/com/google/adk/skills/InMemorySkillSource.java @@ -0,0 +1,177 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.skills; + +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Maps; +import com.google.common.io.ByteSource; +import com.google.errorprone.annotations.CanIgnoreReturnValue; +import io.reactivex.rxjava3.core.Single; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +/** + * An in-memory implementation of {@link SkillSource}. + * + *

Everything is provided upfront using a builder pattern. + */ +public final class InMemorySkillSource implements SkillSource { + + private final ImmutableMap skills; + + private InMemorySkillSource(ImmutableMap skills) { + this.skills = skills; + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public Single> listFrontmatters() { + return Single.just(ImmutableMap.copyOf(Maps.transformValues(skills, SkillData::frontmatter))); + } + + @Override + public Single> listResources(String skillName, String resourceDirectory) { + SkillData data = skills.get(skillName); + if (data == null) { + return Single.error(new SkillSourceException("Skill not found: " + skillName)); + } + String prefix = + resourceDirectory.isEmpty() + ? "" + : (resourceDirectory.endsWith("/") ? resourceDirectory : resourceDirectory + "/"); + + if (!resourceDirectory.isEmpty() + && data.resources().keySet().stream().noneMatch(path -> path.startsWith(prefix))) { + return Single.error( + new SkillSourceException( + "Resource directory not found: " + resourceDirectory + " for skill: " + skillName)); + } + + return Single.just( + data.resources().keySet().stream() + .filter(path -> path.startsWith(prefix)) + .collect(toImmutableList())); + } + + @Override + public Single loadFrontmatter(String skillName) { + return getSkillData(skillName).map(SkillData::frontmatter); + } + + @Override + public Single loadInstructions(String skillName) { + return getSkillData(skillName).map(SkillData::instructions); + } + + @Override + public Single loadResource(String skillName, String resourcePath) { + return getSkillData(skillName) + .map(SkillData::resources) + .mapOptional(m -> Optional.ofNullable(m.get(resourcePath))) + .switchIfEmpty( + Single.error(new SkillSourceException("Resource not found: " + resourcePath))); + } + + private Single getSkillData(String skillName) { + SkillData data = skills.get(skillName); + if (data == null) { + return Single.error(new SkillSourceException("Skill not found: " + skillName)); + } + return Single.just(data); + } + + /** Builder for {@link InMemorySkillSource}. */ + public static class Builder { + private final Map skillBuilders = new HashMap<>(); + + /** Returns a {@link SkillBuilder} for the specified skill, creating it if it doesn't exist. */ + public SkillBuilder skill(String name) { + return skillBuilders.computeIfAbsent(name, k -> new SkillBuilder()); + } + + public InMemorySkillSource build() { + return new InMemorySkillSource( + ImmutableMap.copyOf(Maps.transformValues(skillBuilders, SkillBuilder::buildSkillData))); + } + + /** Builder for a specific skill. */ + public final class SkillBuilder { + private Frontmatter frontmatter; + private String instructions; + private final ImmutableMap.Builder resourcesBuilder = + ImmutableMap.builder(); + + private SkillBuilder() {} + + @CanIgnoreReturnValue + public SkillBuilder frontmatter(Frontmatter frontmatter) { + this.frontmatter = frontmatter; + return this; + } + + @CanIgnoreReturnValue + public SkillBuilder instructions(String instructions) { + this.instructions = instructions; + return this; + } + + @CanIgnoreReturnValue + public SkillBuilder addResource(String path, ByteSource content) { + this.resourcesBuilder.put(path, content); + return this; + } + + @CanIgnoreReturnValue + public SkillBuilder addResource(String path, byte[] content) { + return addResource(path, ByteSource.wrap(content)); + } + + @CanIgnoreReturnValue + public SkillBuilder addResource(String path, String content) { + return addResource(path, content.getBytes(UTF_8)); + } + + /** Switches context to configure another skill, creating it if it doesn't exist. */ + public SkillBuilder skill(String name) { + return Builder.this.skill(name); + } + + /** Builds the {@link InMemorySkillSource} containing all configured skills. */ + public InMemorySkillSource build() { + return Builder.this.build(); + } + + private SkillData buildSkillData() { + checkState(frontmatter != null, "Frontmatter is required"); + checkState(instructions != null, "Instructions are required"); + return new SkillData(frontmatter, instructions, resourcesBuilder.buildOrThrow()); + } + } + } + + private record SkillData( + Frontmatter frontmatter, String instructions, ImmutableMap resources) {} +} diff --git a/core/src/main/java/com/google/adk/skills/LocalSkillSource.java b/core/src/main/java/com/google/adk/skills/LocalSkillSource.java new file mode 100644 index 000000000..939c30b3c --- /dev/null +++ b/core/src/main/java/com/google/adk/skills/LocalSkillSource.java @@ -0,0 +1,118 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.skills; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.nio.file.Files.isDirectory; + +import com.google.common.collect.ImmutableList; +import io.reactivex.rxjava3.core.Flowable; +import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Single; +import java.io.IOException; +import java.nio.channels.ReadableByteChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Optional; +import java.util.stream.Stream; + +/** Loads skills from the local file system. */ +public final class LocalSkillSource extends AbstractSkillSource { + + private final Path skillsBasePath; + + public LocalSkillSource(Path skillsBasePath) { + this.skillsBasePath = skillsBasePath; + } + + @Override + public Single> listResources(String skillName, String resourceDirectory) { + Path skillDir = skillsBasePath.resolve(skillName); + if (!isDirectory(skillDir)) { + return Single.error(new SkillSourceException("Skill not found: " + skillName)); + } + Path resourceDir = skillDir.resolve(resourceDirectory); + if (!isDirectory(resourceDir)) { + return Single.error( + new SkillSourceException( + "Resource directory '%s' not found for skill '%s'" + .formatted(resourceDirectory, skillName))); + } + + return Single.fromCallable( + () -> { + try (Stream paths = Files.walk(resourceDir)) { + return paths + .filter(Files::isRegularFile) + .map(skillDir::relativize) + .map(Path::toString) + .collect(toImmutableList()); + } + }) + .onErrorResumeNext( + t -> + Single.error( + new SkillSourceException( + "Failed to traverse resource directory: " + resourceDirectory, t))); + } + + @Override + @SuppressWarnings("StreamResourceLeak") + protected Flowable listSkills() { + return Flowable.using(() -> Files.list(skillsBasePath), Flowable::fromStream, Stream::close) + .onErrorResumeNext( + t -> + Flowable.error( + new SkillSourceException( + "Failed to list skills in directory: " + skillsBasePath, t))) + .filter(Files::isDirectory) + .mapOptional(this::findSkillMd) + .map(skillMd -> new SkillMdPath(skillMd.getParent().getFileName().toString(), skillMd)); + } + + @Override + protected Single findResourcePath(String skillName, String resourcePath) { + Path file = skillsBasePath.resolve(skillName).resolve(resourcePath); + if (!Files.exists(file)) { + return Single.error(new SkillSourceException("Resource not found: " + file)); + } + return Single.just(file); + } + + @Override + protected Single findSkillMdPath(String skillName) { + Path skillDir = skillsBasePath.resolve(skillName); + if (!isDirectory(skillDir)) { + return Single.error(new SkillSourceException("Skill directory not found: " + skillName)); + } + return Maybe.fromOptional(findSkillMd(skillDir)) + .switchIfEmpty( + Single.error(new SkillSourceException("SKILL.md not found in " + skillName))); + } + + @Override + protected ReadableByteChannel openChannel(Path path) throws IOException { + return Files.newByteChannel(path); + } + + private Optional findSkillMd(Path dir) { + return Optional.of(dir.resolve("SKILL.md")) + .filter(Files::exists) + .or(() -> Optional.of(dir.resolve("skill.md"))) + .filter(Files::exists); + } +} diff --git a/core/src/main/java/com/google/adk/skills/SkillSource.java b/core/src/main/java/com/google/adk/skills/SkillSource.java new file mode 100644 index 000000000..cabe60d86 --- /dev/null +++ b/core/src/main/java/com/google/adk/skills/SkillSource.java @@ -0,0 +1,92 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.skills; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.ByteSource; +import io.reactivex.rxjava3.core.Single; + +/** + * Interface for getting access to available skills. + * + *

All operations are asynchronous and communicate failures reactively through the returned + * {@link Single} error channel (terminating with {@code onError}), rather than throwing exceptions + * synchronously. Implementation must use the {@link SkillSourceException} for propagating error + * message back to the LLM. + */ +public interface SkillSource { + + /** + * Lists all available {@link Frontmatter}s for discovered skills. + * + *

If the source is misconfigured, such as directory doesn't exist, or having malformed skill, + * the returned {@link Single} will terminate with a {@link SkillSourceException} with the reason + * in the message. + * + * @return a {@link Single} emitting a map where keys are skill names and values are their {@link + * Frontmatter} + */ + Single> listFrontmatters(); + + /** + * Lists all resource files for a specific skill within a given directory. + * + *

If the skill or the resource directory does not exist, the returned {@link Single} will + * terminate with a {@link SkillSourceException}. + * + * @param skillName the name of the skill + * @param resourceDirectory the relative directory within the skill to list (e.g., "assets", + * "scripts") + * @return a {@link Single} emitting a list of resource paths relative to the skill directory + */ + Single> listResources(String skillName, String resourceDirectory); + + /** + * Loads the {@link Frontmatter} for a specific skill. + * + *

If the skill is not found or its frontmatter is malformed, the returned {@link Single} will + * terminate with a {@link SkillSourceException} or parsing error. + * + * @param skillName the name of the skill + * @return a {@link Single} emitting the {@link Frontmatter} for the skill + */ + Single loadFrontmatter(String skillName); + + /** + * Loads the instructions (body of SKILL.md) for a specific skill. + * + *

If the skill is not found or its file structure is invalid (e.g., unclosed frontmatter + * blocks), the returned {@link Single} will terminate with a {@link SkillSourceException}. + * + * @param skillName the name of the skill + * @return a {@link Single} emitting the instructions as a String + */ + Single loadInstructions(String skillName); + + /** + * Loads a specific resource file content. + * + *

If the skill or the specific resource path cannot be found, the returned {@link Single} will + * terminate with a {@link SkillSourceException}. + * + * @param skillName the name of the skill + * @param resourcePath the path to the resource file relative to the skill directory + * @return a {@link Single} emitting the {@link ByteSource} for the resource content + */ + Single loadResource(String skillName, String resourcePath); +} diff --git a/core/src/main/java/com/google/adk/skills/SkillSourceException.java b/core/src/main/java/com/google/adk/skills/SkillSourceException.java new file mode 100644 index 000000000..be23291da --- /dev/null +++ b/core/src/main/java/com/google/adk/skills/SkillSourceException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.skills; + +/** + * Exception for {@link SkillSource} implementations to signal recoverable errors that will have the + * message sending back to the LLM. + */ +public final class SkillSourceException extends Exception { + + public SkillSourceException(String message) { + super(message); + } + + public SkillSourceException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/core/src/main/java/com/google/adk/tools/mcp/ConversionUtils.java b/core/src/main/java/com/google/adk/tools/mcp/ConversionUtils.java index 8ad7d1b46..c115cfca0 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/ConversionUtils.java +++ b/core/src/main/java/com/google/adk/tools/mcp/ConversionUtils.java @@ -19,6 +19,7 @@ import com.google.adk.tools.BaseTool; import com.google.genai.types.FunctionDeclaration; import com.google.genai.types.Schema; +import io.modelcontextprotocol.json.McpJsonDefaults; import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.spec.McpSchema; import java.util.Optional; @@ -26,7 +27,7 @@ /** Utility class for converting between different representations of MCP tools. */ public final class ConversionUtils { - private static final McpJsonMapper jsonMapper = McpJsonMapper.getDefault(); + private static final McpJsonMapper jsonMapper = McpJsonDefaults.getMapper(); public McpSchema.Tool adkToMcpToolType(BaseTool tool) { Optional toolDeclaration = tool.declaration(); diff --git a/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java b/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java index b82ebae84..84c882c4e 100644 --- a/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java +++ b/core/src/main/java/com/google/adk/tools/mcp/DefaultMcpTransportBuilder.java @@ -5,6 +5,7 @@ import io.modelcontextprotocol.client.transport.HttpClientStreamableHttpTransport; import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.json.McpJsonDefaults; import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.spec.McpClientTransport; import java.util.Collection; @@ -18,7 +19,7 @@ */ public class DefaultMcpTransportBuilder implements McpTransportBuilder { - private static final McpJsonMapper jsonMapper = McpJsonMapper.getDefault(); + private static final McpJsonMapper jsonMapper = McpJsonDefaults.getMapper(); @Override public McpClientTransport build(Object connectionParams) { diff --git a/core/src/test/java/com/google/adk/events/EventActionsTest.java b/core/src/test/java/com/google/adk/events/EventActionsTest.java index b1e645e1a..4b542c0e9 100644 --- a/core/src/test/java/com/google/adk/events/EventActionsTest.java +++ b/core/src/test/java/com/google/adk/events/EventActionsTest.java @@ -24,6 +24,7 @@ import com.google.common.collect.ImmutableSet; import com.google.genai.types.Content; import com.google.genai.types.Part; +import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -129,6 +130,33 @@ public void removeStateByKey_marksKeyAsRemoved() { assertThat(eventActions.stateDelta()).containsExactly("key1", State.REMOVED); } + @Test + public void builderStateDelta_withNullMap_initializesEmptyMap() { + EventActions eventActions = EventActions.builder().stateDelta(null).build(); + + assertThat(eventActions.stateDelta()).isEmpty(); + } + + @Test + public void builderStateDelta_withNullValue_marksKeyAsRemoved() { + Map inputDelta = new HashMap<>(); + inputDelta.put("key1", "value1"); + inputDelta.put("key2", null); + + EventActions eventActions = EventActions.builder().stateDelta(inputDelta).build(); + + assertThat(eventActions.stateDelta()).containsExactly("key1", "value1", "key2", State.REMOVED); + } + + @Test + public void jsonDeserialization_withNullValueInStateDelta_deserializesAsRemoved() + throws Exception { + String json = "{\"stateDelta\":{\"key1\":\"value1\",\"key2\":null}}"; + EventActions deserialized = EventActions.fromJsonString(json, EventActions.class); + + assertThat(deserialized.stateDelta()).containsExactly("key1", "value1", "key2", State.REMOVED); + } + @Test public void jsonSerialization_works() throws Exception { EventActions eventActions = diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsHttpClientTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsHttpClientTest.java new file mode 100644 index 000000000..175ca777e --- /dev/null +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsHttpClientTest.java @@ -0,0 +1,477 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.models.chat; + +import static com.google.common.truth.Truth.assertThat; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.Assert.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; +import com.google.adk.models.LlmRequest; +import com.google.adk.models.LlmResponse; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; +import com.google.genai.types.FinishReason; +import com.google.genai.types.HttpOptions; +import com.google.genai.types.Part; +import io.reactivex.rxjava3.subscribers.TestSubscriber; +import java.io.IOException; +import java.lang.reflect.Field; +import java.time.Duration; +import okhttp3.Call; +import okhttp3.Callback; +import okhttp3.MediaType; +import okhttp3.OkHttpClient; +import okhttp3.Protocol; +import okhttp3.Request; +import okhttp3.Response; +import okhttp3.ResponseBody; +import okio.Buffer; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public final class ChatCompletionsHttpClientTest { + private static final ObjectMapper objectMapper = JsonBaseModel.getMapper(); + private static final MediaType JSON = MediaType.get("application/json"); + + /** + * Bounded wait for {@link TestSubscriber#await} so a buggy callback wiring cannot hang the test + * JVM. The mock callbacks fire synchronously in the same thread, so this value is intentionally + * short -- on a successful run the await returns in microseconds, and on a hung run we fail fast + * instead of stalling the test suite. + */ + private static final Duration AWAIT_TIMEOUT = Duration.ofMillis(500); + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock private OkHttpClient mockHttpClient; + @Mock private Call mockCall; + + private ChatCompletionsHttpClient client; + + @Before + public void setUp() throws Exception { + client = + new ChatCompletionsHttpClient( + HttpOptions.builder().baseUrl("https://example.com/").build()); + swapInMockHttpClient(client); + } + + /** + * Reflectively replaces the production {@link OkHttpClient} on a {@link + * ChatCompletionsHttpClient} with the test's mock so callbacks can be captured. Used by both + * setUp and tests that construct their own client (e.g. timeout tests, header tests). + */ + private void swapInMockHttpClient(ChatCompletionsHttpClient target) throws Exception { + when(mockHttpClient.newCall(any())).thenReturn(mockCall); + Field clientField = ChatCompletionsHttpClient.class.getDeclaredField("client"); + clientField.setAccessible(true); + clientField.set(target, mockHttpClient); + } + + private Response createMockResponse(String body, MediaType mediaType) { + return createMockResponse(body, mediaType, 200, "OK"); + } + + private Response createMockResponse(String body, MediaType mediaType, int code, String message) { + Response.Builder builder = + new Response.Builder() + .request(new Request.Builder().url("https://example.com/chat/completions").build()) + .protocol(Protocol.HTTP_1_1) + .code(code) + .message(message); + // OkHttp's Response.Builder rejects a null body via its Kotlin @NotNull contract; omit + // the body() call entirely to model an empty/null response body. + if (body != null) { + builder.body(ResponseBody.create(body, mediaType)); + } + return builder.build(); + } + + /** Returns a minimal {@link LlmRequest} suitable for tests that don't care about the payload. */ + private static LlmRequest minimalRequest() { + return LlmRequest.builder() + .model("gpt-4") + .contents(ImmutableList.of(Content.builder().parts(Part.fromText("hello")).build())) + .build(); + } + + @Test + public void complete_nonStreaming_sendsCorrectPayload() throws Exception { + String responseBody = + """ + { + "choices": [ + { + "message": { + "role": "assistant", + "content": "Hi" + }, + "finish_reason": "stop" + } + ] + } + """; + + Response mockResponse = createMockResponse(responseBody, JSON); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Callback.class); + doNothing().when(mockCall).enqueue(callbackCaptor.capture()); + + TestSubscriber testSubscriber = client.complete(minimalRequest(), false).test(); + + callbackCaptor.getValue().onResponse(mockCall, mockResponse); + testSubscriber.await(AWAIT_TIMEOUT.toMillis(), MILLISECONDS); + + LlmResponse response = testSubscriber.values().get(0); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(Request.class); + verify(mockHttpClient).newCall(requestCaptor.capture()); + Request capturedRequest = requestCaptor.getValue(); + assertThat(capturedRequest.url().encodedPath()).isEqualTo("/chat/completions"); + + Buffer buffer = new Buffer(); + capturedRequest.body().writeTo(buffer); + JsonNode requestBodyJson = objectMapper.readTree(buffer.readUtf8()); + assertThat(requestBodyJson.get("model").asText()).isEqualTo("gpt-4"); + assertThat(requestBodyJson.get("messages").get(0).get("role").asText()).isEqualTo("user"); + assertThat(requestBodyJson.get("messages").get(0).get("content").asText()).isEqualTo("hello"); + + LlmResponse expectedResponse = + LlmResponse.builder() + .content( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.fromText("Hi"))) + .build()) + .finishReason(new FinishReason(FinishReason.Known.STOP.toString())) + .customMetadata(ImmutableList.of()) + .build(); + + assertThat(response).isEqualTo(expectedResponse); + } + + @Test + public void complete_nonStreaming_propagateFailure() throws Exception { + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Callback.class); + doNothing().when(mockCall).enqueue(callbackCaptor.capture()); + + TestSubscriber testSubscriber = client.complete(minimalRequest(), false).test(); + + callbackCaptor.getValue().onFailure(mockCall, new IOException("Network Error")); + testSubscriber.await(AWAIT_TIMEOUT.toMillis(), MILLISECONDS); + + testSubscriber.assertError(IOException.class); + } + + // -- Header, error-propagation, and timeout coverage. ---------------------------------- + + /** + * Verifies that an HTTP error status (e.g. 500) propagates as a stream error and that the error + * message includes the response body so callers can debug. Covers the {@code + * !response.isSuccessful()} branch of the non-streaming path. The streaming counterpart lives in + * the streaming follow-up CL. + */ + @Test + public void complete_nonStreaming_propagatesHttpErrorStatus() throws Exception { + Response mockResponse = + createMockResponse("{\"error\":\"server exploded\"}", JSON, 500, "Internal Server Error"); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Callback.class); + doNothing().when(mockCall).enqueue(callbackCaptor.capture()); + + TestSubscriber testSubscriber = client.complete(minimalRequest(), false).test(); + + callbackCaptor.getValue().onResponse(mockCall, mockResponse); + testSubscriber.await(AWAIT_TIMEOUT.toMillis(), MILLISECONDS); + + testSubscriber.assertError( + e -> + e instanceof IOException + && e.getMessage().contains("Unexpected code") + && e.getMessage().contains("server exploded")); + } + + /** + * Verifies that an empty response body propagates as a stream error rather than silently emitting + * an empty value. The exact exception class depends on OkHttp's behavior: + * + *

    + *
  • If OkHttp produces a {@code null} body, our code surfaces an {@link IOException} with the + * message {@code "Empty response body"}. + *
  • If OkHttp produces an empty (non-null) body, Jackson surfaces a {@link + * com.fasterxml.jackson.databind.exc.MismatchedInputException} ("No content to map"). + *
+ * + * Both outcomes satisfy the contract: empty body must NOT silently produce a successful empty + * {@link LlmResponse}. + */ + @Test + public void complete_nonStreaming_propagatesEmptyBody() throws Exception { + Response mockResponse = createMockResponse(null, JSON); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Callback.class); + doNothing().when(mockCall).enqueue(callbackCaptor.capture()); + + TestSubscriber testSubscriber = client.complete(minimalRequest(), false).test(); + + callbackCaptor.getValue().onResponse(mockCall, mockResponse); + testSubscriber.await(AWAIT_TIMEOUT.toMillis(), MILLISECONDS); + + testSubscriber.assertNoValues(); + testSubscriber.assertError(Throwable.class); + } + + /** + * Verifies that caller-supplied headers reach the wire on the captured {@link Request}. This is + * the most common production failure mode (missing or wrong Authorization header), so it gets its + * own test rather than being implicit in other tests. + */ + @Test + public void complete_sendsCustomHeaders() throws Exception { + ChatCompletionsHttpClient clientWithHeaders = + new ChatCompletionsHttpClient( + HttpOptions.builder() + .baseUrl("https://example.com/") + .headers(ImmutableMap.of("Authorization", "Bearer test-token", "X-Custom", "value")) + .build()); + swapInMockHttpClient(clientWithHeaders); + + String responseBody = + """ + {"choices":[{"message":{"role":"assistant","content":"Hi"},"finish_reason":"stop"}]} + """; + Response mockResponse = createMockResponse(responseBody, JSON); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Callback.class); + doNothing().when(mockCall).enqueue(callbackCaptor.capture()); + + TestSubscriber testSubscriber = + clientWithHeaders.complete(minimalRequest(), false).test(); + + callbackCaptor.getValue().onResponse(mockCall, mockResponse); + testSubscriber.await(AWAIT_TIMEOUT.toMillis(), MILLISECONDS); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(Request.class); + verify(mockHttpClient).newCall(requestCaptor.capture()); + Request capturedRequest = requestCaptor.getValue(); + assertThat(capturedRequest.header("Authorization")).isEqualTo("Bearer test-token"); + assertThat(capturedRequest.header("X-Custom")).isEqualTo("value"); + // Content-Type is forced to application/json regardless of caller input. + assertThat(capturedRequest.header("Content-Type")).contains("application/json"); + } + + /** + * Verifies that even when a caller passes a conflicting {@code Content-Type} header, the client + * overrides it with {@code application/json} so the upstream API does not reject the request as a + * malformed payload. + */ + @Test + public void complete_overridesCallerContentType() throws Exception { + ChatCompletionsHttpClient clientWithBadHeader = + new ChatCompletionsHttpClient( + HttpOptions.builder() + .baseUrl("https://example.com/") + .headers(ImmutableMap.of("Content-Type", "text/plain")) + .build()); + swapInMockHttpClient(clientWithBadHeader); + + String responseBody = + """ + {"choices":[{"message":{"role":"assistant","content":"Hi"},"finish_reason":"stop"}]} + """; + Response mockResponse = createMockResponse(responseBody, JSON); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Callback.class); + doNothing().when(mockCall).enqueue(callbackCaptor.capture()); + + TestSubscriber testSubscriber = + clientWithBadHeader.complete(minimalRequest(), false).test(); + + callbackCaptor.getValue().onResponse(mockCall, mockResponse); + testSubscriber.await(AWAIT_TIMEOUT.toMillis(), MILLISECONDS); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(Request.class); + verify(mockHttpClient).newCall(requestCaptor.capture()); + Request capturedRequest = requestCaptor.getValue(); + // Should be exactly one Content-Type header, not two. + assertThat(capturedRequest.headers("Content-Type")).hasSize(1); + assertThat(capturedRequest.header("Content-Type")).contains("application/json"); + } + + /** + * Verifies that a {@code baseUrl} without a trailing slash still produces the correct {@code + * /chat/completions} path. {@link okhttp3.HttpUrl#newBuilder()} normalizes path segments + * regardless of the trailing-slash state of the base URL. + */ + @Test + public void complete_handlesBaseUrlWithoutTrailingSlash() throws Exception { + ChatCompletionsHttpClient clientNoSlash = + new ChatCompletionsHttpClient(HttpOptions.builder().baseUrl("https://example.com").build()); + swapInMockHttpClient(clientNoSlash); + + String responseBody = + """ + {"choices":[{"message":{"role":"assistant","content":"Hi"},"finish_reason":"stop"}]} + """; + Response mockResponse = createMockResponse(responseBody, JSON); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Callback.class); + doNothing().when(mockCall).enqueue(callbackCaptor.capture()); + + TestSubscriber testSubscriber = + clientNoSlash.complete(minimalRequest(), false).test(); + + callbackCaptor.getValue().onResponse(mockCall, mockResponse); + testSubscriber.await(AWAIT_TIMEOUT.toMillis(), MILLISECONDS); + + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(Request.class); + verify(mockHttpClient).newCall(requestCaptor.capture()); + assertThat(requestCaptor.getValue().url().encodedPath()).isEqualTo("/chat/completions"); + } + + /** + * Verifies that omitting {@code headers} on the supplied {@link HttpOptions} is treated as no + * extra headers, not as an NPE. + */ + @Test + public void constructor_missingHeaders_isTreatedAsEmpty() throws Exception { + ChatCompletionsHttpClient clientWithoutHeaders = + new ChatCompletionsHttpClient( + HttpOptions.builder().baseUrl("https://example.com/").build()); + swapInMockHttpClient(clientWithoutHeaders); + + String responseBody = + """ + {"choices":[{"message":{"role":"assistant","content":"Hi"},"finish_reason":"stop"}]} + """; + Response mockResponse = createMockResponse(responseBody, JSON); + + ArgumentCaptor callbackCaptor = ArgumentCaptor.forClass(Callback.class); + doNothing().when(mockCall).enqueue(callbackCaptor.capture()); + + TestSubscriber testSubscriber = + clientWithoutHeaders.complete(minimalRequest(), false).test(); + + callbackCaptor.getValue().onResponse(mockCall, mockResponse); + testSubscriber.await(AWAIT_TIMEOUT.toMillis(), MILLISECONDS); + + testSubscriber.assertNoErrors(); + testSubscriber.assertValueCount(1); + } + + /** Verifies that a {@code null} {@link HttpOptions} is rejected at construction time. */ + @Test + public void constructor_nullHttpOptions_throws() { + assertThrows(NullPointerException.class, () -> new ChatCompletionsHttpClient(null)); + } + + /** + * Verifies that an {@link HttpOptions} without a {@code baseUrl} is rejected at construction time + * as bad configuration. {@link IllegalArgumentException} (not NPE) is the conventional signal for + * missing required configuration. + */ + @Test + public void constructor_missingBaseUrl_throws() { + HttpOptions noBaseUrl = HttpOptions.builder().build(); + assertThrows(IllegalArgumentException.class, () -> new ChatCompletionsHttpClient(noBaseUrl)); + } + + /** + * Verifies that an {@link HttpOptions} with a malformed (non-HTTP(S)) {@code baseUrl} is rejected + * at construction time, rather than failing later at the first {@code complete()} call with a + * confusing NPE from {@link okhttp3.HttpUrl#parse}. + */ + @Test + public void constructor_malformedBaseUrl_throws() { + HttpOptions malformed = HttpOptions.builder().baseUrl("not a url").build(); + assertThrows(IllegalArgumentException.class, () -> new ChatCompletionsHttpClient(malformed)); + } + + // -- Tri-state timeout policy. ---------------------------------------------------------- + + /** + * Verifies that when {@code httpOptions} omits {@code timeout()}, the client applies the 5-minute + * default call timeout to prevent indefinite hangs in callers that did not explicitly configure a + * timeout. + */ + @Test + public void constructor_missingTimeout_appliesDefaultFiveMinuteTimeout() { + ChatCompletionsHttpClient defaultClient = + new ChatCompletionsHttpClient( + HttpOptions.builder().baseUrl("https://example.com/").build()); + + OkHttpClient internal = readInternalClient(defaultClient); + assertThat(internal.callTimeoutMillis()) + .isEqualTo((int) Duration.ofMinutes(5).toMillis()); // 300_000 + } + + /** + * Verifies that when the caller explicitly sets {@code httpOptions.timeout() == 0}, the client + * respects this as the explicit opt-in to infinite hang. This is the migration path for + * long-running streams or batch jobs that need no timeout. + */ + @Test + public void constructor_zeroTimeout_respectsInfiniteHang() { + HttpOptions zeroTimeout = + HttpOptions.builder().baseUrl("https://example.com/").timeout(0).build(); + ChatCompletionsHttpClient infiniteClient = new ChatCompletionsHttpClient(zeroTimeout); + + OkHttpClient internal = readInternalClient(infiniteClient); + assertThat(internal.callTimeoutMillis()).isEqualTo(0); // OkHttp: 0 = no timeout + } + + /** + * Verifies that when the caller sets a positive timeout, that value (in milliseconds) is used as + * the call timeout. + */ + @Test + public void constructor_explicitTimeout_appliesIt() { + HttpOptions tenSeconds = + HttpOptions.builder().baseUrl("https://example.com/").timeout(10_000).build(); + ChatCompletionsHttpClient timedClient = new ChatCompletionsHttpClient(tenSeconds); + + OkHttpClient internal = readInternalClient(timedClient); + assertThat(internal.callTimeoutMillis()).isEqualTo(10_000); + } + + /** Reflectively reads the internal {@link OkHttpClient} to inspect the resolved timeout. */ + private static OkHttpClient readInternalClient(ChatCompletionsHttpClient target) { + try { + Field clientField = ChatCompletionsHttpClient.class.getDeclaredField("client"); + clientField.setAccessible(true); + return (OkHttpClient) clientField.get(target); + } catch (ReflectiveOperationException e) { + throw new LinkageError("Failed to read internal client", e); + } + } +} diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java index 9dc63c5d6..1f41189a2 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java @@ -17,11 +17,28 @@ package com.google.adk.models.chat; import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.adk.JsonBaseModel; +import com.google.adk.models.LlmRequest; import com.google.common.collect.ImmutableList; -import java.util.HashMap; -import java.util.Map; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.FunctionCall; +import com.google.genai.types.FunctionCallingConfig; +import com.google.genai.types.FunctionCallingConfigMode.Known; +import com.google.genai.types.FunctionDeclaration; +import com.google.genai.types.FunctionResponse; +import com.google.genai.types.GenerateContentConfig; +import com.google.genai.types.Part; +import com.google.genai.types.Tool; +import com.google.genai.types.ToolConfig; +import java.util.AbstractMap; +import java.util.List; +import java.util.Set; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -34,17 +51,17 @@ public final class ChatCompletionsRequestTest { @Before public void setUp() { - objectMapper = new ObjectMapper(); + objectMapper = JsonBaseModel.getMapper(); } @Test public void testSerializeChatCompletionRequest_standard() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - ChatCompletionsRequest.Message message = new ChatCompletionsRequest.Message(); message.role = "user"; message.content = new ChatCompletionsRequest.MessageContent("Hello"); + + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; request.messages = ImmutableList.of(message); String json = objectMapper.writeValueAsString(request); @@ -56,24 +73,20 @@ public void testSerializeChatCompletionRequest_standard() throws Exception { @Test public void testSerializeChatCompletionRequest_withExtraBody() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - ChatCompletionsRequest.Message message = new ChatCompletionsRequest.Message(); message.role = "user"; message.content = new ChatCompletionsRequest.MessageContent("Explain to me how AI works"); - request.messages = ImmutableList.of(message); - - Map thinkingConfig = new HashMap<>(); - thinkingConfig.put("thinking_level", "low"); - thinkingConfig.put("include_thoughts", true); - - Map google = new HashMap<>(); - google.put("thinking_config", thinkingConfig); - Map extraBody = new HashMap<>(); - extraBody.put("google", google); + ImmutableMap extraBody = + ImmutableMap.of( + "google", + ImmutableMap.of( + "thinking_config", + ImmutableMap.of("thinking_level", "low", "include_thoughts", true))); + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + request.messages = ImmutableList.of(message); request.extraBody = extraBody; String json = objectMapper.writeValueAsString(request); @@ -85,9 +98,6 @@ public void testSerializeChatCompletionRequest_withExtraBody() throws Exception @Test public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - ChatCompletionsRequest.Message userMessage = new ChatCompletionsRequest.Message(); userMessage.role = "user"; userMessage.content = new ChatCompletionsRequest.MessageContent("Check flight status"); @@ -104,11 +114,8 @@ public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() th function.arguments = "{\"flight\":\"AA100\"}"; toolCall.function = function; - Map google = new HashMap<>(); - google.put("thought_signature", ""); - - Map extraContent = new HashMap<>(); - extraContent.put("google", google); + ImmutableMap extraContent = + ImmutableMap.of("google", ImmutableMap.of("thought_signature", "")); toolCall.extraContent = extraContent; @@ -120,6 +127,8 @@ public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() th toolMessage.toolCallId = "function-call-1"; toolMessage.content = new ChatCompletionsRequest.MessageContent("{\"status\":\"delayed\"}"); + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; request.messages = ImmutableList.of(userMessage, modelMessage, toolMessage); String json = objectMapper.writeValueAsString(request); @@ -134,45 +143,38 @@ public void testSerializeChatCompletionRequest_withToolCallsAndExtraContent() th @Test public void testSerializeChatCompletionRequest_comprehensive() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - - // Developer message with name ChatCompletionsRequest.Message devMsg = new ChatCompletionsRequest.Message(); devMsg.role = "developer"; devMsg.content = new ChatCompletionsRequest.MessageContent("System instruction"); devMsg.name = "system-bot"; - request.messages = ImmutableList.of(devMsg); - - // Response Format JSON Schema ChatCompletionsRequest.ResponseFormatJsonSchema format = new ChatCompletionsRequest.ResponseFormatJsonSchema(); format.jsonSchema = new ChatCompletionsRequest.ResponseFormatJsonSchema.JsonSchema(); format.jsonSchema.name = "MySchema"; format.jsonSchema.strict = true; - request.responseFormat = format; - // Tool Choice Named ChatCompletionsRequest.NamedToolChoice choice = new ChatCompletionsRequest.NamedToolChoice(); choice.function = new ChatCompletionsRequest.NamedToolChoice.FunctionName(); choice.function.name = "my_function"; + + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + request.messages = ImmutableList.of(devMsg); + request.responseFormat = format; request.toolChoice = choice; String json = objectMapper.writeValueAsString(request); - // Assert Developer Message assertThat(json).contains("\"role\":\"developer\""); assertThat(json).contains("\"name\":\"system-bot\""); assertThat(json).contains("\"content\":\"System instruction\""); - // Assert Response Format assertThat(json).contains("\"response_format\":{"); assertThat(json).contains("\"type\":\"json_schema\""); assertThat(json).contains("\"name\":\"MySchema\""); assertThat(json).contains("\"strict\":true"); - // Assert Tool Choice assertThat(json).contains("\"tool_choice\":{"); assertThat(json).contains("\"type\":\"function\""); assertThat(json).contains("\"name\":\"my_function\""); @@ -182,7 +184,7 @@ public void testSerializeChatCompletionRequest_comprehensive() throws Exception public void testSerializeChatCompletionRequest_withToolChoiceMode() throws Exception { ChatCompletionsRequest request = new ChatCompletionsRequest(); request.model = "gemini-3-flash-preview"; - + request.messages = ImmutableList.of(); request.toolChoice = new ChatCompletionsRequest.ToolChoiceMode("none"); String json = objectMapper.writeValueAsString(request); @@ -192,13 +194,15 @@ public void testSerializeChatCompletionRequest_withToolChoiceMode() throws Excep @Test public void testSerializeChatCompletionRequest_withStopAndVoice() throws Exception { - ChatCompletionsRequest request = new ChatCompletionsRequest(); - request.model = "gemini-3-flash-preview"; - - request.stop = new ChatCompletionsRequest.StopCondition("STOP"); + ChatCompletionsRequest.StopCondition stop = new ChatCompletionsRequest.StopCondition("STOP"); ChatCompletionsRequest.AudioParam audio = new ChatCompletionsRequest.AudioParam(); audio.voice = new ChatCompletionsRequest.VoiceConfig("alloy"); + + ChatCompletionsRequest request = new ChatCompletionsRequest(); + request.model = "gemini-3-flash-preview"; + request.messages = ImmutableList.of(); + request.stop = stop; request.audio = audio; String json = objectMapper.writeValueAsString(request); @@ -211,11 +215,465 @@ public void testSerializeChatCompletionRequest_withStopAndVoice() throws Excepti public void testSerializeChatCompletionRequest_withStopList() throws Exception { ChatCompletionsRequest request = new ChatCompletionsRequest(); request.model = "gemini-3-flash-preview"; - + request.messages = ImmutableList.of(); request.stop = new ChatCompletionsRequest.StopCondition(ImmutableList.of("STOP1", "STOP2")); String json = objectMapper.writeValueAsString(request); assertThat(json).contains("\"stop\":[\"STOP1\",\"STOP2\"]"); } + + @Test + public void testFromLlmRequest_basic() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("user") + .parts(ImmutableList.of(Part.fromText("Hello"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.model).isEqualTo("gemini-1.5-pro"); + assertThat(request.stream).isFalse(); + assertThat(request.messages).hasSize(1); + assertThat(request.messages.get(0).role).isEqualTo("user"); + assertThat(request.messages.get(0).content.getValue()).isEqualTo("Hello"); + } + + @Test + public void testFromLlmRequest_withRefusal() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.fromText("Regular text response"), + Part.fromText( + ChatCompletionsCommon.REFUSAL_PREFIX + "I cannot do that."))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.role).isEqualTo("assistant"); + assertThat(message.refusal).isEqualTo("I cannot do that."); + assertThat(message.content.getValue()).isEqualTo("Regular text response"); + } + + @Test + public void testFromLlmRequest_withRefusalEmbeddedAfterNewline() throws Exception { + // A single Part containing both content and refusal, separated by "\n[[REFUSAL]]: ". + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.fromText( + "Partial text answer\n" + + ChatCompletionsCommon.REFUSAL_PREFIX + + "System error or refusal"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.role).isEqualTo("assistant"); + assertThat(message.content.getValue()).isEqualTo("Partial text answer"); + assertThat(message.refusal).isEqualTo("System error or refusal"); + } + + @Test + public void testFromLlmRequest_withMultipleRefusalsJoinedWithNewline() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.fromText(ChatCompletionsCommon.REFUSAL_PREFIX + "First"), + Part.fromText(ChatCompletionsCommon.REFUSAL_PREFIX + "Second"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.role).isEqualTo("assistant"); + assertThat(message.refusal).isEqualTo("First\nSecond"); + assertThat(message.content).isNull(); + } + + @Test + public void testFromLlmRequest_withRefusalOnlyHasNullContent() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.fromText( + ChatCompletionsCommon.REFUSAL_PREFIX + "Only a refusal"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.role).isEqualTo("assistant"); + assertThat(message.refusal).isEqualTo("Only a refusal"); + assertThat(message.content).isNull(); + } + + @Test + public void testFromLlmRequest_withRefusalPrefixAfterEmptyContentLine() throws Exception { + // Edge case: text begins with "\n[[REFUSAL]]: ..." -- empty content before the prefix. + // Expectation: no content part, refusal populated. + String text = "\n" + ChatCompletionsCommon.REFUSAL_PREFIX + "Refusal only"; + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.fromText(text))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.refusal).isEqualTo("Refusal only"); + assertThat(message.content).isNull(); + } + + @Test + public void testFromLlmRequest_withRefusalPrefixMidLineIsNotSplit() throws Exception { + // The prefix is intentionally NOT recognized mid-line without a preceding newline. + String inlineText = "foo " + ChatCompletionsCommon.REFUSAL_PREFIX + "bar"; + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts(ImmutableList.of(Part.fromText(inlineText))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message message = request.messages.get(0); + assertThat(message.refusal).isNull(); + assertThat(message.content.getValue()).isEqualTo(inlineText); + } + + @Test + public void testFromLlmRequest_withSystemInstruction() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gpt-4") + .config( + GenerateContentConfig.builder() + .systemInstruction( + Content.builder() + .parts(ImmutableList.of(Part.fromText("Be helpful"))) + .build()) + .temperature(0.7f) + .topP(0.9f) + .maxOutputTokens(100) + .stopSequences(ImmutableList.of("END")) + .candidateCount(2) + .presencePenalty(0.5f) + .frequencyPenalty(0.3f) + .seed(12345) + .tools( + ImmutableList.of( + Tool.builder() + .functionDeclarations( + ImmutableList.of( + FunctionDeclaration.builder() + .name("get_weather") + .description("Get current weather") + .build())) + .build())) + .toolConfig( + ToolConfig.builder() + .functionCallingConfig( + FunctionCallingConfig.builder().mode(Known.ANY).build()) + .build()) + .build()) + .contents( + ImmutableList.of( + Content.builder() + .role("user") + .parts(ImmutableList.of(Part.fromText("Hello"))) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(2); + assertThat(request.messages.get(0).role).isEqualTo("system"); + assertThat(request.messages.get(0).content.getValue()).isEqualTo("Be helpful"); + assertThat(request.temperature).isWithin(0.001).of(0.7); + assertThat(request.topP).isWithin(0.001).of(0.9); + assertThat(request.maxCompletionTokens).isEqualTo(100); + assertThat((List) request.stop.getValue()).containsExactly("END"); + assertThat(request.n).isEqualTo(2); + assertThat(request.presencePenalty).isWithin(0.001).of(0.5); + assertThat(request.frequencyPenalty).isWithin(0.001).of(0.3); + assertThat(request.seed).isEqualTo(12345L); + assertThat(request.tools).hasSize(1); + assertThat(request.tools.get(0).function.name).isEqualTo("get_weather"); + assertThat(request.tools.get(0).function.description).isEqualTo("Get current weather"); + assertThat(((ChatCompletionsRequest.ToolChoiceMode) request.toolChoice).getMode()) + .isEqualTo("required"); + } + + @Test + public void testFromLlmRequest_withInlineData() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("user") + .parts( + ImmutableList.of( + Part.builder() + .inlineData( + Blob.builder() + .mimeType("image/jpeg") + .data("base64data".getBytes(UTF_8)) + .build()) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + + @SuppressWarnings( + "unchecked") // Safe in unit tests and this is the expected type from msg.content + List parts = + (List) msg.content.getValue(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).type).isEqualTo("image_url"); + assertThat(parts.get(0).imageUrl.url).contains("base64,"); + } + + @Test + public void testFromLlmRequest_withFileData() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("user") + .parts( + ImmutableList.of( + Part.builder() + .fileData( + FileData.builder() + .fileUri("gs://bucket/file.jpg") + .mimeType("image/jpeg") + .build()) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + + @SuppressWarnings( + "unchecked") // Safe in unit tests and this is the expected type from msg.content + List parts = + (List) msg.content.getValue(); + assertThat(parts).hasSize(1); + assertThat(parts.get(0).type).isEqualTo("image_url"); + assertThat(parts.get(0).imageUrl.url).isEqualTo("gs://bucket/file.jpg"); + } + + @Test + public void testFromLlmRequest_withFunctionCall() throws Exception { + ImmutableMap args = ImmutableMap.of("location", "Paris"); + + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("model") + .parts( + ImmutableList.of( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_123") + .name("get_weather") + .args(args) + .build()) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(1); + ChatCompletionsRequest.Message msg = request.messages.get(0); + assertThat(msg.role).isEqualTo("assistant"); + assertThat(msg.toolCalls).hasSize(1); + assertThat(msg.toolCalls.get(0).id).isEqualTo("call_123"); + assertThat(msg.toolCalls.get(0).type).isEqualTo("function"); + assertThat(msg.toolCalls.get(0).function.name).isEqualTo("get_weather"); + assertThat(msg.toolCalls.get(0).function.arguments).isEqualTo("{\"location\":\"Paris\"}"); + } + + @Test + public void testFromLlmRequest_withStreamOptions() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder().model("gemini-1.5-pro").contents(ImmutableList.of()).build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, true); + + assertThat(request.stream).isTrue(); + assertThat(request.streamOptions).isNotNull(); + assertThat(request.streamOptions.includeUsage).isTrue(); + } + + private static class BadMap extends AbstractMap { + @Override + public Set> entrySet() { + throw new RuntimeException("Serialization failed!"); + } + } + + @Test + public void testFromLlmRequest_withFunctionResponse() throws Exception { + ImmutableMap respData = ImmutableMap.of("result", "ok"); + + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .contents( + ImmutableList.of( + Content.builder() + .role("tool") + .parts( + ImmutableList.of( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id("call_999") + .response(respData) + .build()) + .build(), + Part.builder() + .functionResponse(FunctionResponse.builder().build()) + .build(), + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id("call_faulty") + .response(new BadMap()) + .build()) + .build())) + .build())) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.messages).hasSize(3); + assertThat(request.messages.get(0).role).isEqualTo("tool"); + assertThat(request.messages.get(0).toolCallId).isEqualTo("call_999"); + assertThat(request.messages.get(0).content.getValue()).isEqualTo("{\"result\":\"ok\"}"); + + assertThat(request.messages.get(1).role).isEqualTo("tool"); + assertThat(request.messages.get(1).toolCallId).isEmpty(); + assertThat(request.messages.get(1).content).isNull(); + + assertThat(request.messages.get(2).role).isEqualTo("tool"); + assertThat(request.messages.get(2).toolCallId).isEqualTo("call_faulty"); + assertThat(request.messages.get(2).content).isNull(); + } + + @Test + public void testFromLlmRequest_withConfigSchemaAndLogprobs() throws Exception { + ImmutableMap schemaDef = ImmutableMap.of("type", "object"); + + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .config( + GenerateContentConfig.builder() + .responseJsonSchema(schemaDef) + .responseLogprobs(true) + .logprobs(5) + .build()) + .contents(ImmutableList.of()) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.responseFormat) + .isInstanceOf(ChatCompletionsRequest.ResponseFormatJsonSchema.class); + ChatCompletionsRequest.ResponseFormatJsonSchema format = + (ChatCompletionsRequest.ResponseFormatJsonSchema) request.responseFormat; + assertThat(format.jsonSchema.name).isEqualTo("response_schema"); + assertThat(format.jsonSchema.strict).isTrue(); + assertThat(format.jsonSchema.schema).isEqualTo(schemaDef); + assertThat(request.logprobs).isTrue(); + assertThat(request.topLogprobs).isEqualTo(5); + } + + @Test + public void testFromLlmRequest_withConfigResponseMimeTypeJson() throws Exception { + LlmRequest llmRequest = + LlmRequest.builder() + .model("gemini-1.5-pro") + .config(GenerateContentConfig.builder().responseMimeType("application/json").build()) + .contents(ImmutableList.of()) + .build(); + + ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false); + + assertThat(request.responseFormat) + .isInstanceOf(ChatCompletionsRequest.ResponseFormatJsonObject.class); + } } diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java index ad1839019..367545207 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java @@ -504,6 +504,7 @@ public void testToLlmResponse_withRefusal() throws Exception { "index": 0, "message": { "role": "assistant", + "content": "Partial text answer", "refusal": "System error or refusal" }, "finish_reason": "stop" @@ -521,8 +522,11 @@ public void testToLlmResponse_withRefusal() throws Exception { // Content assertThat(response.content().get().role()).hasValue("model"); + assertThat(response.content().get().parts().get()).hasSize(2); assertThat(response.content().get().parts().get().get(0).text()) - .hasValue("System error or refusal"); + .hasValue("Partial text answer"); + assertThat(response.content().get().parts().get().get(1).text()) + .hasValue("[[REFUSAL]]: System error or refusal"); // Custom Metadata List metadata = response.customMetadata().get(); diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java index 5a149d3e2..836442cad 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/BigQueryAgentAnalyticsPluginTest.java @@ -17,6 +17,7 @@ package com.google.adk.plugins.agentanalytics; import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.concurrent.TimeUnit.MILLISECONDS; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; @@ -75,6 +76,7 @@ import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; @@ -190,6 +192,19 @@ public void onUserMessageCallback_appendsToWriter() throws Exception { verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } + @Test + public void onUserMessageCallback_ensuresInvocationSpan() throws Exception { + Content content = Content.builder().build(); + + // Verify initial state + assertTrue(state.getTraceManager("invocation_id").getCurrentSpanId().isEmpty()); + + plugin.onUserMessageCallback(mockInvocationContext, content).blockingSubscribe(); + + // Verify that ensureInvocationSpan was called and created a span + assertTrue(state.getTraceManager("invocation_id").getCurrentSpanId().isPresent()); + } + @Test public void beforeRunCallback_appendsToWriter() throws Exception { plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); @@ -198,6 +213,65 @@ public void beforeRunCallback_appendsToWriter() throws Exception { verify(mockWriter, atLeastOnce()).append(any(ArrowRecordBatch.class)); } + @Test + public void beforeRunCallback_ensuresInvocationSpan() throws Exception { + // Verify initial state + assertTrue(state.getTraceManager("invocation_id").getCurrentSpanId().isEmpty()); + + plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); + + // Verify that ensureInvocationSpan was called and created a span + assertTrue(state.getTraceManager("invocation_id").getCurrentSpanId().isPresent()); + } + + @Test + public void beforeRunCallback_addPendingTask() throws Exception { + final boolean[] addPendingTaskCalled = {false}; + PluginState customState = + new PluginState(config) { + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + return mockWriter; + } + + @Override + void addPendingTask(String invocationId, CompletableFuture task) { + super.addPendingTask(invocationId, task); + addPendingTaskCalled[0] = true; + } + }; + BigQueryAgentAnalyticsPlugin customPlugin = + new BigQueryAgentAnalyticsPlugin(config, mockBigQuery, customState); + + customPlugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); + + assertTrue("addPendingTask should have been called", addPendingTaskCalled[0]); + } + + @Test + public void afterRunCallback_waitsForPendingTasks() throws Exception { + CompletableFuture pendingTask = new CompletableFuture<>(); + String invocationId = "invocation_id"; + + // Manually add a pending task to the state + state.addPendingTask(invocationId, pendingTask); + + // Complete the task after a short delay + var unused = + Executors.newSingleThreadScheduledExecutor() + .schedule(() -> pendingTask.complete(null), 100, MILLISECONDS); + + // afterRunCallback should wait for the pending task + plugin.afterRunCallback(mockInvocationContext).blockingSubscribe(); + + assertTrue("Pending task should be completed after afterRunCallback", pendingTask.isDone()); + } + @Test public void afterRunCallback_flushesAndAppends() throws Exception { plugin.beforeRunCallback(mockInvocationContext).blockingSubscribe(); @@ -225,7 +299,8 @@ public void getStreamName_returnsCorrectFormat() { @Test public void formatContentParts_populatesCorrectFields() { Content content = Content.fromParts(Part.fromText("hello")); - ArrayNode nodes = JsonFormatter.formatContentParts(Optional.of(content), 100); + ArrayNode nodes = state.getParser().formatContentParts(Optional.of(content)); + assertEquals(1, nodes.size()); ObjectNode node = (ObjectNode) nodes.get(0); assertEquals(0, node.get("part_index").asInt()); @@ -727,7 +802,6 @@ public void logEvent_handlesExceptionFromFormatter() throws Exception { (content, eventType) -> { throw new RuntimeException("Formatter error"); }; - BigQueryLoggerConfig formattedConfig = config.toBuilder().contentFormatter(formatter).build(); PluginState formattedState = new PluginState(formattedConfig) { diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java index 739f3a7c3..4883438b6 100644 --- a/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/JsonFormatterTest.java @@ -18,6 +18,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import com.fasterxml.jackson.databind.JsonNode; @@ -40,7 +41,7 @@ public class JsonFormatterTest { @Test - public void parse_llmRequest_populatesPrompt() { + public void parse_llmRequest_populatesPrompt() throws Exception { LlmRequest request = LlmRequest.builder() .contents( @@ -48,7 +49,7 @@ public void parse_llmRequest_populatesPrompt() { Content.fromParts(Part.fromText("hello")).toBuilder().role("user").build())) .build(); - JsonFormatter.ParsedContent result = JsonFormatter.parse(request, 100); + Parser.ParsedContent result = new Parser(100).parse(request).get(); assertTrue(result.content().has("prompt")); ArrayNode prompt = (ArrayNode) result.content().get("prompt"); @@ -59,7 +60,7 @@ public void parse_llmRequest_populatesPrompt() { } @Test - public void parse_llmRequest_populatesSystemPrompt() { + public void parse_llmRequest_populatesSystemPrompt() throws Exception { LlmRequest request = LlmRequest.builder() .config( @@ -68,7 +69,7 @@ public void parse_llmRequest_populatesSystemPrompt() { .build()) .build(); - JsonFormatter.ParsedContent result = JsonFormatter.parse(request, 100); + Parser.ParsedContent result = new Parser(100).parse(request).get(); assertTrue(result.content().has("system_prompt")); assertEquals("be helpful", result.content().get("system_prompt").asText()); @@ -76,38 +77,39 @@ public void parse_llmRequest_populatesSystemPrompt() { } @Test - public void parse_string_truncates() { + public void parse_string_truncates() throws Exception { String longString = "this is a very long string that should be truncated"; - JsonFormatter.ParsedContent result = JsonFormatter.parse(longString, 10); + Parser.ParsedContent result = new Parser(24).parse(longString).get(); assertTrue(result.isTruncated()); assertEquals("this is a ...[truncated]", result.content().asText()); } @Test - public void parse_map_truncatesNested() { - ImmutableMap map = ImmutableMap.of("key", "this is a long value"); - JsonFormatter.ParsedContent result = JsonFormatter.parse(map, 10); + public void parse_map_truncatesNested() throws Exception { + ImmutableMap map = + ImmutableMap.of("key", "this is a very long value that should definitely be truncated"); + Parser.ParsedContent result = new Parser(24).parse(map).get(); assertTrue(result.isTruncated()); assertEquals("this is a ...[truncated]", result.content().get("key").asText()); } @Test - public void parse_content_returnsSummary() { + public void parse_content_returnsSummary() throws Exception { Content content = Content.fromParts(Part.fromText("part 1"), Part.fromText("part 2")); - JsonFormatter.ParsedContent result = JsonFormatter.parse(content, 100); + Parser.ParsedContent result = new Parser(100).parse(content).get(); assertEquals("part 1 | part 2", result.content().get("text_summary").asText()); assertEquals(2, result.parts().size()); } @Test - public void parse_content_withFileData() { + public void parse_content_withFileData() throws Exception { FileData fileData = FileData.builder().fileUri("gs://bucket/file.txt").mimeType("text/plain").build(); Content content = Content.fromParts(Part.builder().fileData(fileData).build()); - JsonFormatter.ParsedContent result = JsonFormatter.parse(content, 100); + Parser.ParsedContent result = new Parser(100).parse(content).get(); assertEquals(1, result.parts().size()); JsonNode partData = result.parts().get(0); @@ -117,10 +119,10 @@ public void parse_content_withFileData() { } @Test - public void parse_content_withFunctionCall() { + public void parse_content_withFunctionCall() throws Exception { FunctionCall fc = FunctionCall.builder().name("myFunction").build(); Content content = Content.fromParts(Part.builder().functionCall(fc).build()); - JsonFormatter.ParsedContent result = JsonFormatter.parse(content, 100); + Parser.ParsedContent result = new Parser(100).parse(content).get(); assertEquals(1, result.parts().size()); JsonNode partData = result.parts().get(0); @@ -130,10 +132,10 @@ public void parse_content_withFunctionCall() { } @Test - public void parse_list_truncatesElements() { + public void parse_list_truncatesElements() throws Exception { List list = Arrays.asList("short", "this is a very long string that should be truncated"); - JsonFormatter.ParsedContent result = JsonFormatter.parse(list, 10); + Parser.ParsedContent result = new Parser(24).parse(list).get(); assertTrue(result.isTruncated()); JsonNode arrayNode = result.content(); @@ -142,4 +144,62 @@ public void parse_list_truncatesElements() { assertEquals("short", arrayNode.get(0).asText()); assertEquals("this is a ...[truncated]", arrayNode.get(1).asText()); } + + @Test + public void truncate_variousInputs() { + assertNull(JsonFormatter.truncate(null, 10)); + assertEquals("", JsonFormatter.truncate("", 10)); + assertEquals("short", JsonFormatter.truncate("short", 10)); + assertEquals("exactlyten", JsonFormatter.truncate("exactlyten", 10)); + + // Simple truncation + String truncated = JsonFormatter.truncate("this is a long string for budget 24", 24); + assertEquals("this is a ...[truncated]", truncated); + + // Multi-byte truncation (UTF-8) + // "こんにちはこんにちは" is 30 bytes + String nihongo = "こんにちはこんにちは"; + String truncatedNihongo = JsonFormatter.truncate(nihongo, 20); // Should keep 2 chars (6 bytes) + assertEquals("こん...[truncated]", truncatedNihongo); + } + + @Test + public void truncate_budgetSmallerThanSuffix_returnsPartialSuffix() { + String longString = "this is a long string that should be truncated"; + assertEquals("...[t", JsonFormatter.truncate(longString, 5)); + assertEquals("", JsonFormatter.truncate(longString, 0)); + assertEquals("...[truncated]", JsonFormatter.truncate(longString, 14)); + } + + @Test + public void truncateAndAddSuffix_coversCodePointSizes() { + String s = "aαこ😀extra"; + String suffix = "..."; + + assertEquals("a...", JsonFormatter.truncateAndAddSuffix(s, 4, suffix)); + assertEquals("aα...", JsonFormatter.truncateAndAddSuffix(s, 6, suffix)); + assertEquals("aαこ...", JsonFormatter.truncateAndAddSuffix(s, 9, suffix)); + assertEquals("aαこ😀...", JsonFormatter.truncateAndAddSuffix(s, 13, suffix)); + assertEquals("aαこ...", JsonFormatter.truncateAndAddSuffix(s, 12, suffix)); + } + + @Test + public void parse_multibyteString_truncatesBasedOnBytes() throws Exception { + // "こんにちはこんにちは" is 30 bytes, but 10 characters. + String nihongo = "こんにちはこんにちは"; + // With budget 20, effective budget is 6, so only 2 characters (6 bytes) should be kept. + Parser.ParsedContent result = new Parser(20).parse(nihongo).get(); + + assertTrue(result.isTruncated()); + assertEquals("こん...[truncated]", result.content().asText()); + } + + @Test + public void parse_multibyteContent_truncatesBasedOnBytes() throws Exception { + Content content = Content.fromParts(Part.fromText("こんにちはこんにちは")); + Parser.ParsedContent result = new Parser(20).parse(content).get(); + + assertTrue(result.isTruncated()); + assertEquals("こん...[truncated]", result.content().get("text_summary").asText()); + } } diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java new file mode 100644 index 000000000..9bae03331 --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/ParserTest.java @@ -0,0 +1,112 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.fasterxml.jackson.databind.node.ArrayNode; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.genai.types.Blob; +import com.google.genai.types.Content; +import com.google.genai.types.FileData; +import com.google.genai.types.Part; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class ParserTest { + private Parser parser; + + @Before + public void setUp() { + parser = new Parser(100); + } + + @Test + public void parse_part_coversLine280() throws Exception { + Part part = Part.fromText("test part"); + CompletableFuture future = parser.parse(part); + Parser.ParsedContent result = future.get(); + + assertEquals("{\"text_summary\":\"test part\"}", result.content().toString()); + assertEquals(1, result.parts().size()); + assertEquals("test part", result.parts().get(0).get("text").asText()); + } + + @Test + public void parse_part_withInlineData_coversProcessPart() throws Exception { + Blob blob = Blob.builder().mimeType("image/png").data(new byte[] {1, 2, 3}).build(); + Part part = Part.builder().inlineData(blob).build(); + CompletableFuture future = parser.parse(part); + Parser.ParsedContent result = future.get(); + + assertEquals(1, result.parts().size()); + ObjectNode node = (ObjectNode) result.parts().get(0); + assertEquals("image/png", node.get("mime_type").asText()); + assertEquals("[BINARY DATA]", node.get("text").asText()); + assertEquals("INLINE", node.get("storage_mode").asText()); + } + + @Test + public void formatContentParts_inlineData_coversLine446() { + Blob blob = Blob.builder().mimeType("image/png").data(new byte[] {1, 2, 3}).build(); + Part part = Part.builder().inlineData(blob).build(); + Content content = Content.fromParts(part); + + ArrayNode nodes = parser.formatContentParts(Optional.of(content)); + + assertEquals(1, nodes.size()); + ObjectNode node = (ObjectNode) nodes.get(0); + assertEquals("image/png", node.get("mime_type").asText()); + assertEquals("[BINARY DATA]", node.get("text").asText()); + } + + @Test + public void formatContentParts_fileData_coversLine450() { + FileData fileData = + FileData.builder().mimeType("application/pdf").fileUri("gs://bucket/file.pdf").build(); + Part part = Part.builder().fileData(fileData).build(); + Content content = Content.fromParts(part); + + ArrayNode nodes = parser.formatContentParts(Optional.of(content)); + + assertEquals(1, nodes.size()); + ObjectNode node = (ObjectNode) nodes.get(0); + assertEquals("application/pdf", node.get("mime_type").asText()); + assertEquals("gs://bucket/file.pdf", node.get("uri").asText()); + assertEquals("EXTERNAL_URI", node.get("storage_mode").asText()); + } + + @Test + public void parse_multipartContent_coversLine310() throws Exception { + // maxLength is 100. + String longText = "a".repeat(100); + Content content = Content.fromParts(Part.fromText("Part 1"), Part.fromText(longText)); + + // Call private method using helper if necessary, but parseContentObject is private. + // However, parse(Object content, ...) calls it. + CompletableFuture future = parser.parse(content); + Parser.ParsedContent result = future.get(); + + assertTrue(result.isTruncated()); + } +} diff --git a/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java b/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java new file mode 100644 index 000000000..444cc8a6d --- /dev/null +++ b/core/src/test/java/com/google/adk/plugins/agentanalytics/PluginStateTest.java @@ -0,0 +1,245 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.plugins.agentanalytics; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.api.core.ApiFutures; +import com.google.cloud.bigquery.storage.v1.AppendRowsResponse; +import com.google.cloud.bigquery.storage.v1.BigQueryWriteClient; +import com.google.cloud.bigquery.storage.v1.StreamWriter; +import java.io.IOException; +import java.time.Duration; +import java.time.Instant; +import java.util.concurrent.CompletableFuture; +import java.util.logging.Handler; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.apache.arrow.vector.ipc.message.ArrowRecordBatch; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; + +@RunWith(JUnit4.class) +public final class PluginStateTest { + private BigQueryLoggerConfig config; + private TestPluginState pluginState; + private Handler mockHandler; + private Logger pluginLogger; + private Level originalLevel; + + private static class TestPluginState extends PluginState { + TestPluginState(BigQueryLoggerConfig config) throws IOException { + super(config); + } + + private BigQueryWriteClient mockWriteClient; + + @Override + protected BigQueryWriteClient createWriteClient(BigQueryLoggerConfig config) { + mockWriteClient = mock(BigQueryWriteClient.class); + return mockWriteClient; + } + + BigQueryWriteClient getMockWriteClient() { + return mockWriteClient; + } + + @Override + protected StreamWriter createWriter() { + StreamWriter writer = mock(StreamWriter.class); + when(writer.append(any(ArrowRecordBatch.class))) + .thenReturn(ApiFutures.immediateFuture(AppendRowsResponse.newBuilder().build())); + return writer; + } + } + + @Before + public void setUp() throws IOException { + config = + BigQueryLoggerConfig.builder() + .projectId("test-project") + .datasetId("test-dataset") + .tableName("test-table") + .gcsBucketName("") + .build(); + pluginState = new TestPluginState(config); + + pluginLogger = Logger.getLogger(PluginState.class.getName()); + mockHandler = mock(Handler.class); + originalLevel = pluginLogger.getLevel(); + pluginLogger.setLevel(Level.INFO); + pluginLogger.addHandler(mockHandler); + } + + @After + public void tearDown() { + pluginLogger.removeHandler(mockHandler); + pluginLogger.setLevel(originalLevel); + } + + @Test + public void addPendingTask_removedTaskOnCompletion() { + String invocationId = "testInvocation"; + CompletableFuture task = new CompletableFuture<>(); + pluginState.addPendingTask(invocationId, task); + + task.complete(null); + pluginState.ensureInvocationCompleted(invocationId).blockingAwait(); + + // No specific log to check now, but we verify it completes without error. + } + + @Test + public void ensureInvocationCompleted_noTasks_succeeds() { + String invocationId = "testInvocation"; + + pluginState.ensureInvocationCompleted(invocationId).test().assertComplete(); + } + + @Test + public void ensureInvocationCompleted_executionException_completesSuccessfully() + throws InterruptedException { + String invocationId = "testInvocation"; + CompletableFuture task = new CompletableFuture<>(); + pluginState.addPendingTask(invocationId, task); + + task.completeExceptionally(new RuntimeException("test exception")); + + pluginState.ensureInvocationCompleted(invocationId).test().assertComplete(); + } + + @Test + public void ensureInvocationCompleted_interrupted_logsNothing() throws InterruptedException { + String invocationId = "testInvocation"; + CompletableFuture task = new CompletableFuture<>(); + pluginState.addPendingTask(invocationId, task); + + Thread testThread = + new Thread( + () -> { + pluginLogger.addHandler(mockHandler); + pluginState.ensureInvocationCompleted(invocationId).blockingAwait(); + }); + testThread.start(); + Thread.sleep(50); + testThread.interrupt(); + testThread.join(1000); + + // RxJava handles interruption differently, we just verify it doesn't crash here. + } + + @Test + public void ensureInvocationCompleted_timeout_logsWarning() throws IOException { + config = config.toBuilder().shutdownTimeout(Duration.ofMillis(100)).build(); + pluginState = new TestPluginState(config); + + String invocationId = "testInvocation"; + CompletableFuture task = new CompletableFuture<>(); // Never completes + pluginState.addPendingTask(invocationId, task); + + pluginState.ensureInvocationCompleted(invocationId).test().awaitDone(1, SECONDS); + + // Wait for cleanup side effects which run after terminal signal. + long deadline = Instant.now().plusMillis(1000).toEpochMilli(); + while (!pluginState.isProcessed(invocationId) && Instant.now().toEpochMilli() < deadline) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + ArgumentCaptor captor = ArgumentCaptor.forClass(LogRecord.class); + verify(mockHandler, atLeastOnce()).publish(captor.capture()); + + boolean found = + captor.getAllValues().stream() + .anyMatch( + record -> + record.getLevel().equals(Level.WARNING) + && record + .getMessage() + .contains("Timeout while waiting for pending tasks to complete")); + assertTrue( + "Expected log message 'Timeout while waiting for pending tasks to complete' not found", + found); + } + + @Test + public void ensureInvocationCompleted_timeout_cleansUpState() throws IOException { + config = config.toBuilder().shutdownTimeout(Duration.ofMillis(100)).build(); + pluginState = new TestPluginState(config); + + String invocationId = "testInvocation"; + CompletableFuture task = new CompletableFuture<>(); // Never completes + pluginState.addPendingTask(invocationId, task); + + // Populate processor and trace manager. + var unusedProcessor = pluginState.getBatchProcessor(invocationId); + var unusedTraceManager = pluginState.getTraceManager(invocationId); + + pluginState.ensureInvocationCompleted(invocationId).test().awaitDone(1, SECONDS); + + // Wait for cleanup side effects which run after terminal signal. + long deadline = Instant.now().plusMillis(1000).toEpochMilli(); + while (!pluginState.isProcessed(invocationId) && Instant.now().toEpochMilli() < deadline) { + try { + Thread.sleep(10); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + // Verify cleanup + assertTrue( + "Invocation ID should be marked as processed", pluginState.isProcessed(invocationId)); + assertTrue(pluginState.getBatchProcessors().isEmpty()); + assertTrue(pluginState.getTraceManagers().isEmpty()); + } + + @Test + public void close_succeedsAndCleansUp() throws Exception { + String invocationId = "testInvocation"; + CompletableFuture task = new CompletableFuture<>(); + pluginState.addPendingTask(invocationId, task); + + // Populate processor and trace manager. + var unusedProcessor = pluginState.getBatchProcessor(invocationId); + var unusedTraceManager = pluginState.getTraceManager(invocationId); + + // Complete the task so close doesn't time out. + task.complete(null); + + pluginState.close().test().assertComplete(); + + // Verify cleanup + assertTrue(pluginState.getBatchProcessors().isEmpty()); + assertTrue(pluginState.getTraceManagers().isEmpty()); + assertTrue(pluginState.getExecutor().isShutdown()); + } +} diff --git a/core/src/test/java/com/google/adk/sessions/StateTest.java b/core/src/test/java/com/google/adk/sessions/StateTest.java index e1fcaeadc..295466d1c 100644 --- a/core/src/test/java/com/google/adk/sessions/StateTest.java +++ b/core/src/test/java/com/google/adk/sessions/StateTest.java @@ -6,7 +6,6 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; -import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -22,11 +21,6 @@ public void constructor_nullDelta_createsEmptyConcurrentHashMap() { assertThat(state.hasDelta()).isTrue(); } - @Test - public void constructor_nullState_throwsException() { - Assert.assertThrows(NullPointerException.class, () -> new State(null, new HashMap<>())); - } - @Test public void constructor_regularMapState() { Map stateMap = new HashMap<>(); @@ -47,4 +41,14 @@ public void constructor_singleArgument() { state.put("key", "value"); assertThat(state.hasDelta()).isTrue(); } + + @Test + public void constructor_stateMapWithNullValues_replacesWithRemoved() { + Map stateMap = new HashMap<>(); + stateMap.put("key1", "value1"); + stateMap.put("key2", null); + State state = new State(stateMap); + assertThat(state).containsEntry("key1", "value1"); + assertThat(state).containsEntry("key2", State.REMOVED); + } } diff --git a/core/src/test/java/com/google/adk/skills/FrontmatterTest.java b/core/src/test/java/com/google/adk/skills/FrontmatterTest.java new file mode 100644 index 000000000..0f910eb46 --- /dev/null +++ b/core/src/test/java/com/google/adk/skills/FrontmatterTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.skills; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class FrontmatterTest { + + private static final ObjectMapper yamlMapper = new ObjectMapper(new YAMLFactory()); + + @Test + public void testValidFrontmatter() throws Exception { + String yaml = + """ + name: test-skill + description: This is a test + allowed-tools: "tool1 tool2" + compatibility: "1.0" + """; + Frontmatter fm = yamlMapper.readValue(yaml, Frontmatter.class); + + assertThat(fm.name()).isEqualTo("test-skill"); + assertThat(fm.description()).isEqualTo("This is a test"); + assertThat(fm.allowedTools()).hasValue("tool1 tool2"); + assertThat(fm.compatibility()).hasValue("1.0"); + } + + @Test + public void testFrontmatterWithMetadata() throws Exception { + String yaml = + """ + name: test-skill-metadata + description: Test with metadata + metadata: + key1: value1 + key2: 123 + """; + Frontmatter fm = yamlMapper.readValue(yaml, Frontmatter.class); + + assertThat(fm.name()).isEqualTo("test-skill-metadata"); + assertThat(fm.metadata()).containsEntry("key1", "value1"); + assertThat(fm.metadata()).containsEntry("key2", 123); + } + + @Test + public void testInvalidName() { + Frontmatter.Builder builder = Frontmatter.builder().name("Invalid_Name").description("test"); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, builder::build); + assertThat(ex).hasMessageThat().contains("lowercase kebab-case"); + } + + @Test + public void testLongName() { + String longName = "a".repeat(65); + Frontmatter.Builder builder = Frontmatter.builder().name(longName).description("test"); + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, builder::build); + assertThat(ex).hasMessageThat().contains("must be at most 64 characters"); + } +} diff --git a/core/src/test/java/com/google/adk/skills/InMemorySkillSourceTest.java b/core/src/test/java/com/google/adk/skills/InMemorySkillSourceTest.java new file mode 100644 index 000000000..6723dfe0c --- /dev/null +++ b/core/src/test/java/com/google/adk/skills/InMemorySkillSourceTest.java @@ -0,0 +1,187 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.skills; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.ByteSource; +import java.io.IOException; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class InMemorySkillSourceTest { + + @Test + public void testListFrontmatters() { + Frontmatter fm1 = Frontmatter.builder().name("skill-1").description("desc1").build(); + Frontmatter fm2 = Frontmatter.builder().name("skill-2").description("desc2").build(); + + SkillSource source = + InMemorySkillSource.builder() + .skill("skill-1") + .frontmatter(fm1) + .instructions("body1") + .skill("skill-2") + .frontmatter(fm2) + .instructions("body2") + .build(); + + ImmutableMap frontmatters = source.listFrontmatters().blockingGet(); + + assertThat(frontmatters).hasSize(2); + assertThat(frontmatters.get("skill-1")).isEqualTo(fm1); + assertThat(frontmatters.get("skill-2")).isEqualTo(fm2); + } + + @Test + public void testListResources() { + Frontmatter fm = Frontmatter.builder().name("my-skill").description("desc").build(); + + SkillSource source = + InMemorySkillSource.builder() + .skill("my-skill") + .frontmatter(fm) + .instructions("body") + .addResource("assets/file1.txt", "content1") + .addResource("assets/subdir/file2.txt", "content2") + .addResource("other/file3.txt", "content3") + .build(); + + ImmutableList resources = source.listResources("my-skill", "assets").blockingGet(); + + assertThat(resources).containsExactly("assets/file1.txt", "assets/subdir/file2.txt"); + } + + @Test + public void testListResources_skillNotFound() { + SkillSource source = InMemorySkillSource.builder().build(); + + var single = source.listResources("non-existent", "assets"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception.getCause()).isInstanceOf(SkillSourceException.class); + } + + @Test + public void testListResources_directoryNotFound() { + Frontmatter fm = Frontmatter.builder().name("my-skill").description("desc").build(); + SkillSource source = + InMemorySkillSource.builder() + .skill("my-skill") + .frontmatter(fm) + .instructions("body") + .addResource("assets/file1.txt", "content1") + .build(); + + var single = source.listResources("my-skill", "non-existent"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception.getCause()).isInstanceOf(SkillSourceException.class); + } + + @Test + public void testLoadFrontmatter() { + Frontmatter fm = Frontmatter.builder().name("my-skill").description("desc").build(); + + SkillSource source = + InMemorySkillSource.builder() + .skill("my-skill") + .frontmatter(fm) + .instructions("body") + .build(); + + assertThat(source.loadFrontmatter("my-skill").blockingGet()).isEqualTo(fm); + } + + @Test + public void testLoadInstructions() { + Frontmatter fm = Frontmatter.builder().name("my-skill").description("desc").build(); + + SkillSource source = + InMemorySkillSource.builder() + .skill("my-skill") + .frontmatter(fm) + .instructions("my instructions") + .build(); + + assertThat(source.loadInstructions("my-skill").blockingGet()).isEqualTo("my instructions"); + } + + @Test + public void testLoadResource() throws IOException { + Frontmatter fm = Frontmatter.builder().name("my-skill").description("desc").build(); + + SkillSource source = + InMemorySkillSource.builder() + .skill("my-skill") + .frontmatter(fm) + .instructions("body") + .addResource("assets/file1.txt", "hello content") + .build(); + + ByteSource resource = source.loadResource("my-skill", "assets/file1.txt").blockingGet(); + + assertThat(new String(resource.read(), UTF_8)).isEqualTo("hello content"); + } + + @Test + public void testLoadResource_notFound() { + Frontmatter fm = Frontmatter.builder().name("my-skill").description("desc").build(); + + SkillSource source = + InMemorySkillSource.builder() + .skill("my-skill") + .frontmatter(fm) + .instructions("body") + .build(); + + var single = source.loadResource("my-skill", "non-existent.txt"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception.getCause()).isInstanceOf(SkillSourceException.class); + } + + @Test + public void testLoadFrontmatter_skillNotFound() { + SkillSource source = InMemorySkillSource.builder().build(); + + var single = source.loadFrontmatter("non-existent"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception.getCause()).isInstanceOf(SkillSourceException.class); + } + + @Test + public void testBuilder_missingFrontmatter() { + InMemorySkillSource.Builder builder = InMemorySkillSource.builder(); + builder.skill("my-skill").addResource("path", "content"); + + assertThrows(IllegalStateException.class, builder::build); + } + + @Test + public void testBuilder_missingInstructions() { + InMemorySkillSource.Builder builder = InMemorySkillSource.builder(); + Frontmatter fm = Frontmatter.builder().name("my-skill").description("desc").build(); + + builder.skill("my-skill").frontmatter(fm); + + assertThrows(IllegalStateException.class, builder::build); + } +} diff --git a/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java b/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java new file mode 100644 index 000000000..256f1d66a --- /dev/null +++ b/core/src/test/java/com/google/adk/skills/LocalSkillSourceTest.java @@ -0,0 +1,253 @@ +/* + * Copyright 2026 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.skills; + +import static com.google.common.truth.Truth.assertThat; +import static java.nio.charset.StandardCharsets.UTF_8; +import static org.junit.Assert.assertThrows; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.ByteSource; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TemporaryFolder; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public final class LocalSkillSourceTest { + + @Rule public TemporaryFolder tempFolder = new TemporaryFolder(); + + @Test + public void testListFrontmatters() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skill1 = skillsBase.resolve("skill-1"); + Files.createDirectory(skill1); + Files.writeString( + skill1.resolve("SKILL.md"), + """ + --- + name: skill-1 + description: test1 + --- + body + """); + + Path skill2 = skillsBase.resolve("skill-2"); + Files.createDirectory(skill2); + Files.writeString( + skill2.resolve("SKILL.md"), + """ + --- + name: skill-2 + description: test2 + --- + body + """); + + SkillSource source = new LocalSkillSource(skillsBase); + ImmutableMap skills = source.listFrontmatters().blockingGet(); + + assertThat(skills).hasSize(2); + assertThat(skills).containsKey("skill-1"); + assertThat(skills).containsKey("skill-2"); + assertThat(skills.get("skill-1").description()).isEqualTo("test1"); + } + + @Test + public void testListResources() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Path assetsDir = skillDir.resolve("assets"); + Files.createDirectory(assetsDir); + + Files.createFile(assetsDir.resolve("file1.txt")); + Path subDir = assetsDir.resolve("subdir"); + Files.createDirectory(subDir); + Files.createFile(subDir.resolve("file2.txt")); + + SkillSource source = new LocalSkillSource(skillsBase); + ImmutableList resources = source.listResources("my-skill", "assets").blockingGet(); + + assertThat(resources).containsExactly("assets/file1.txt", "assets/subdir/file2.txt"); + } + + @Test + public void testListResources_notDirectory() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + // No assets directory created + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.listResources("my-skill", "assets"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + } + + @Test + public void testListResources_skillNotFound() { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.listResources("non-existent", "assets"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + } + + @Test + public void testLoadFrontmatter() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString( + skillDir.resolve("SKILL.md"), + """ + --- + name: my-skill + description: This is a test skill + --- + body + """); + + SkillSource source = new LocalSkillSource(skillsBase); + Frontmatter fm = source.loadFrontmatter("my-skill").blockingGet(); + + assertThat(fm.name()).isEqualTo("my-skill"); + assertThat(fm.description()).isEqualTo("This is a test skill"); + } + + @Test + public void testLoadInstructions() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString( + skillDir.resolve("SKILL.md"), + """ + --- + name: my-skill + description: Test + --- + Some Markdown Body + """); + + SkillSource source = new LocalSkillSource(skillsBase); + String instructions = source.loadInstructions("my-skill").blockingGet(); + + assertThat(instructions).isEqualTo("Some Markdown Body"); + } + + @Test + public void testLoadInstructions_unclosedFrontmatter() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Files.writeString( + skillDir.resolve("SKILL.md"), + """ + --- + name: my-skill + description: Test + Some Markdown Body without closing dashes + """); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadInstructions("my-skill"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + assertThat(exception) + .hasCauseThat() + .hasMessageThat() + .contains("Skill file frontmatter not properly closed with ---"); + } + + @Test + public void testLoadResource() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + Path assetsDir = skillDir.resolve("assets"); + Files.createDirectory(assetsDir); + Path file = assetsDir.resolve("file1.txt"); + Files.writeString(file, "hello content"); + + SkillSource source = new LocalSkillSource(skillsBase); + ByteSource resource = source.loadResource("my-skill", "assets/file1.txt").blockingGet(); + + assertThat(new String(resource.read(), UTF_8)).isEqualTo("hello content"); + } + + @Test + public void testLoadResource_notFound() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + Path skillDir = skillsBase.resolve("my-skill"); + Files.createDirectory(skillDir); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadResource("my-skill", "non-existent.txt"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + } + + @Test + public void testLoadFrontmatter_skillNotFound() { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + + SkillSource source = new LocalSkillSource(skillsBase); + var single = source.loadFrontmatter("non-existent"); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + } + + @Test + public void testListSkillMdPaths_skillSourceException() throws IOException { + Path skillsBase = tempFolder.getRoot().toPath().resolve("skills"); + Files.createDirectory(skillsBase); + + SkillSource source = new LocalSkillSource(skillsBase); + + // Delete the directory to trigger IOException on Files.list + Files.delete(skillsBase); + + var single = source.listFrontmatters(); + RuntimeException exception = assertThrows(RuntimeException.class, single::blockingGet); + assertThat(exception).hasCauseThat().isInstanceOf(SkillSourceException.class); + } +} diff --git a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java index 0db218347..f00091d2d 100644 --- a/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java +++ b/core/src/test/java/com/google/adk/tools/mcp/McpToolsetTest.java @@ -31,6 +31,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.json.McpJsonDefaults; import io.modelcontextprotocol.json.McpJsonMapper; import io.modelcontextprotocol.spec.McpSchema; import java.util.List; @@ -50,7 +51,7 @@ public class McpToolsetTest { @Mock private McpSyncClient mockMcpSyncClient; @Mock private ReadonlyContext mockReadonlyContext; - private static final McpJsonMapper jsonMapper = McpJsonMapper.getDefault(); + private static final McpJsonMapper jsonMapper = McpJsonDefaults.getMapper(); private static final ImmutableMap STDIO_SERVER_PARAMS = ImmutableMap.of( diff --git a/core/src/test/java/com/google/adk/tools/mcp/StdioServerParametersTest.java b/core/src/test/java/com/google/adk/tools/mcp/StdioServerParametersTest.java index 166665309..7c970117d 100644 --- a/core/src/test/java/com/google/adk/tools/mcp/StdioServerParametersTest.java +++ b/core/src/test/java/com/google/adk/tools/mcp/StdioServerParametersTest.java @@ -22,6 +22,7 @@ import com.google.common.collect.ImmutableMap; import io.modelcontextprotocol.client.transport.ServerParameters; import io.modelcontextprotocol.client.transport.StdioClientTransport; +import io.modelcontextprotocol.json.McpJsonDefaults; import io.modelcontextprotocol.json.McpJsonMapper; import java.lang.reflect.Field; import java.util.List; @@ -34,7 +35,7 @@ @RunWith(JUnit4.class) public final class StdioServerParametersTest { - private static final McpJsonMapper jsonMapper = McpJsonMapper.getDefault(); + private static final McpJsonMapper jsonMapper = McpJsonDefaults.getMapper(); @Test public void toServerParameters_withNullArgs_createsValidServerParameters() { diff --git a/dev/pom.xml b/dev/pom.xml index 32cfa6441..8b5910f35 100644 --- a/dev/pom.xml +++ b/dev/pom.xml @@ -18,7 +18,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 google-adk-dev diff --git a/dev/src/main/java/com/google/adk/deploy/AgentEngineDeployer.java b/dev/src/main/java/com/google/adk/deploy/AgentEngineDeployer.java new file mode 100644 index 000000000..fa34954f3 --- /dev/null +++ b/dev/src/main/java/com/google/adk/deploy/AgentEngineDeployer.java @@ -0,0 +1,183 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.deploy; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.time.Instant; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Command line application to deploy an ADK Java Agent to Vertex AI Agent Engine (Reasoning + * Engine). + */ +class AgentEngineDeployer { + private static final Logger logger = Logger.getLogger(AgentEngineDeployer.class.getName()); + + private final String region; + private final String projectId; + private final String agentName; + private final int serverPort; + private final String sourceDir; + private final Path tempDir; + + private AgentEngineDeployer( + String region, + String projectId, + String agentName, + int serverPort, + String sourceDir, + Path tempDir) { + this.region = region; + this.projectId = projectId; + this.agentName = agentName; + this.serverPort = serverPort; + this.sourceDir = sourceDir; + this.tempDir = tempDir; + } + + /** Creates a temporary Dockerfile and bundles the application for Reasoning Engine deployment. */ + static Path prepareBundle(int serverPort) throws IOException { + Path tempDir = Files.createTempDirectory("agentEngineDeploy"); + Path dockerfile = tempDir.resolve("Dockerfile"); + + String dockerfileContent = + String.format( + "FROM eclipse-temurin:21-jdk\n" + + "WORKDIR /app\n" + + "COPY . .\n" + + "RUN ./mvnw clean package -DskipTests\n" + + "EXPOSE %d\n" + + "CMD [\"java\", \"-jar\", \"target/app.jar\"]\n", + serverPort); + + Files.writeString(dockerfile, dockerfileContent); + logger.info("Prepared Dockerfile at " + dockerfile.toAbsolutePath()); + return tempDir; + } + + /** Orchestrates the deployment process. */ + public void deploy() throws IOException { + logger.info("Starting Agent Engine deployment..."); + logger.info( + String.format( + "Deploying Agent '%s' to project '%s' in region '%s'...", + agentName, projectId, region)); + + // TODO: Integrate with Vertex AI CreateReasoningEngine API client. + logger.info("Preparation complete. Skipping actual deployment to Vertex AI for now."); + } + + /** Builder for {@link AgentEngineDeployer}. */ + public static class Builder { + private String region; + private String projectId; + private String agentName; + private int serverPort; + private String sourceDir; + + public Builder region(String region) { + this.region = region; + return this; + } + + public Builder projectId(String projectId) { + this.projectId = projectId; + return this; + } + + public Builder agentName(String agentName) { + this.agentName = agentName; + return this; + } + + public Builder serverPort(int serverPort) { + this.serverPort = serverPort; + return this; + } + + public Builder sourceDir(String sourceDir) { + this.sourceDir = sourceDir; + return this; + } + + public AgentEngineDeployer build() throws IOException { + if (projectId == null || projectId.isEmpty()) { + throw new IllegalStateException("Project ID must be specified."); + } + if (agentName == null || agentName.isEmpty()) { + agentName = "ADK Java Agent: " + Instant.now().toString(); + } + if (sourceDir == null || sourceDir.isEmpty()) { + sourceDir = System.getProperty("user.dir"); + } + Path tempDir = AgentEngineDeployer.prepareBundle(serverPort); + return new AgentEngineDeployer(region, projectId, agentName, serverPort, sourceDir, tempDir); + } + } + + public static Builder builder() { + return new Builder(); + } + + public static void main(String[] args) { + Builder builder = AgentEngineDeployer.builder().region("us-central1").serverPort(8080); + + // Minimal argument parsing logic + for (int i = 0; i < args.length; i++) { + switch (args[i]) { + case "--project": + if (i + 1 < args.length) { + builder.projectId(args[++i]); + } + break; + case "--region": + if (i + 1 < args.length) { + builder.region(args[++i]); + } + break; + case "--name": + if (i + 1 < args.length) { + builder.agentName(args[++i]); + } + break; + case "--port": + if (i + 1 < args.length) { + builder.serverPort(Integer.parseInt(args[++i])); + } + break; + case "--source-dir": + if (i + 1 < args.length) { + builder.sourceDir(args[++i]); + } + break; + default: + logger.warning("Unknown argument: " + args[i]); + } + } + + try { + AgentEngineDeployer deployer = builder.build(); + deployer.deploy(); + } catch (Exception e) { + logger.log(Level.SEVERE, "Deployment failed", e); + System.exit(1); + } + } +} diff --git a/dev/src/test/java/com/google/adk/deploy/AgentEngineDeployerTest.java b/dev/src/test/java/com/google/adk/deploy/AgentEngineDeployerTest.java new file mode 100644 index 000000000..ccd306c16 --- /dev/null +++ b/dev/src/test/java/com/google/adk/deploy/AgentEngineDeployerTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.google.adk.deploy; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class AgentEngineDeployerTest { + + @Test + public void build_withMissingProjectId_throwsException() { + AgentEngineDeployer.Builder builder = AgentEngineDeployer.builder(); + // ProjectId is not set. + assertThrows(IllegalStateException.class, builder::build); + } + + @Test + public void deploy_createsDockerfileWithCorrectPort() throws IOException { + int serverPort = 9090; + + Path tempDir = AgentEngineDeployer.prepareBundle(serverPort); + Path dockerfile = tempDir.resolve("Dockerfile"); + + assertThat(Files.exists(dockerfile)).isTrue(); + String content = Files.readString(dockerfile); + assertThat(content).contains("EXPOSE " + serverPort); + } +} diff --git a/maven_plugin/examples/custom_tools/pom.xml b/maven_plugin/examples/custom_tools/pom.xml index 910e74439..be3bb3656 100644 --- a/maven_plugin/examples/custom_tools/pom.xml +++ b/maven_plugin/examples/custom_tools/pom.xml @@ -4,7 +4,7 @@ com.example custom-tools-example - 1.2.0 + 1.3.0 jar ADK Custom Tools Example diff --git a/maven_plugin/examples/simple-agent/pom.xml b/maven_plugin/examples/simple-agent/pom.xml index 17256f364..f3d22c47d 100644 --- a/maven_plugin/examples/simple-agent/pom.xml +++ b/maven_plugin/examples/simple-agent/pom.xml @@ -4,7 +4,7 @@ com.example simple-adk-agent - 1.2.0 + 1.3.0 jar Simple ADK Agent Example diff --git a/maven_plugin/pom.xml b/maven_plugin/pom.xml index 3cf9bd7bd..c6448e22f 100644 --- a/maven_plugin/pom.xml +++ b/maven_plugin/pom.xml @@ -5,7 +5,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 ../pom.xml diff --git a/pom.xml b/pom.xml index 4a20a5106..2902b8b90 100644 --- a/pom.xml +++ b/pom.xml @@ -17,7 +17,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 pom Google Agent Development Kit Maven Parent POM @@ -50,12 +50,12 @@ cloud libraries. Once they update their otel dependencies we can consider updating ours here as well --> 1.51.0 - 0.17.2 + 1.1.2 2.47.0 1.44.0 4.33.5 5.11.4 - 5.20.0 + 5.23.0 1.6.0 2.20.2 5.3.2 @@ -73,7 +73,7 @@ 3.27.7 2.15.0 3.9.0 - 5.6 + 5.6.1 4.1.118.Final @{jacoco.agent.argLine} --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.text=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/jdk.internal.misc=ALL-UNNAMED -Dio.netty.tryReflectionSetAccessible=true @@ -454,7 +454,7 @@ org.jacoco jacoco-maven-plugin - 0.8.12 + 0.8.14 diff --git a/tutorials/city-time-weather/pom.xml b/tutorials/city-time-weather/pom.xml index 8bcbb5887..262024ba1 100644 --- a/tutorials/city-time-weather/pom.xml +++ b/tutorials/city-time-weather/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 ../../pom.xml diff --git a/tutorials/live-audio-single-agent/pom.xml b/tutorials/live-audio-single-agent/pom.xml index 463565e7c..bdd2b3ff4 100644 --- a/tutorials/live-audio-single-agent/pom.xml +++ b/tutorials/live-audio-single-agent/pom.xml @@ -20,7 +20,7 @@ com.google.adk google-adk-parent - 1.2.0 + 1.3.0 ../../pom.xml