|
| 1 | +package com.google.adk.sessions; |
| 2 | + |
| 3 | +import static java.util.concurrent.TimeUnit.SECONDS; |
| 4 | + |
| 5 | +import com.fasterxml.jackson.databind.JsonNode; |
| 6 | +import com.fasterxml.jackson.databind.ObjectMapper; |
| 7 | +import com.google.adk.JsonBaseModel; |
| 8 | +import com.google.auth.oauth2.GoogleCredentials; |
| 9 | +import com.google.common.base.Splitter; |
| 10 | +import com.google.common.collect.Iterables; |
| 11 | +import com.google.genai.types.HttpOptions; |
| 12 | +import java.io.IOException; |
| 13 | +import java.io.UncheckedIOException; |
| 14 | +import java.util.List; |
| 15 | +import java.util.Optional; |
| 16 | +import java.util.concurrent.ConcurrentHashMap; |
| 17 | +import java.util.concurrent.ConcurrentMap; |
| 18 | +import javax.annotation.Nullable; |
| 19 | +import okhttp3.ResponseBody; |
| 20 | +import org.slf4j.Logger; |
| 21 | +import org.slf4j.LoggerFactory; |
| 22 | + |
| 23 | +/** Client for interacting with the Vertex AI Session API. */ |
| 24 | +final class VertexAiClient { |
| 25 | + private static final int MAX_RETRY_ATTEMPTS = 5; |
| 26 | + private static final ObjectMapper objectMapper = JsonBaseModel.getMapper(); |
| 27 | + private static final Logger logger = LoggerFactory.getLogger(VertexAiClient.class); |
| 28 | + |
| 29 | + private final HttpApiClient apiClient; |
| 30 | + |
| 31 | + VertexAiClient(String project, String location, HttpApiClient apiClient) { |
| 32 | + this.apiClient = apiClient; |
| 33 | + } |
| 34 | + |
| 35 | + VertexAiClient() { |
| 36 | + this.apiClient = |
| 37 | + new HttpApiClient(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); |
| 38 | + } |
| 39 | + |
| 40 | + VertexAiClient( |
| 41 | + String project, |
| 42 | + String location, |
| 43 | + Optional<GoogleCredentials> credentials, |
| 44 | + Optional<HttpOptions> httpOptions) { |
| 45 | + this.apiClient = |
| 46 | + new HttpApiClient(Optional.of(project), Optional.of(location), credentials, httpOptions); |
| 47 | + } |
| 48 | + |
| 49 | + @Nullable |
| 50 | + JsonNode createSession( |
| 51 | + String reasoningEngineId, String userId, ConcurrentMap<String, Object> state) { |
| 52 | + ConcurrentHashMap<String, Object> sessionJsonMap = new ConcurrentHashMap<>(); |
| 53 | + sessionJsonMap.put("userId", userId); |
| 54 | + if (state != null) { |
| 55 | + sessionJsonMap.put("sessionState", state); |
| 56 | + } |
| 57 | + |
| 58 | + String sessId; |
| 59 | + String operationId; |
| 60 | + try { |
| 61 | + String sessionJson = objectMapper.writeValueAsString(sessionJsonMap); |
| 62 | + try (ApiResponse apiResponse = |
| 63 | + apiClient.request( |
| 64 | + "POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson)) { |
| 65 | + logger.debug("Create Session response {}", apiResponse.getResponseBody()); |
| 66 | + if (apiResponse == null || apiResponse.getResponseBody() == null) { |
| 67 | + return null; |
| 68 | + } |
| 69 | + |
| 70 | + JsonNode jsonResponse = getJsonResponse(apiResponse); |
| 71 | + if (jsonResponse == null) { |
| 72 | + return null; |
| 73 | + } |
| 74 | + String sessionName = jsonResponse.get("name").asText(); |
| 75 | + List<String> parts = Splitter.on('/').splitToList(sessionName); |
| 76 | + sessId = parts.get(parts.size() - 3); |
| 77 | + operationId = Iterables.getLast(parts); |
| 78 | + } |
| 79 | + } catch (IOException e) { |
| 80 | + throw new UncheckedIOException(e); |
| 81 | + } |
| 82 | + |
| 83 | + for (int i = 0; i < MAX_RETRY_ATTEMPTS; i++) { |
| 84 | + try (ApiResponse lroResponse = apiClient.request("GET", "operations/" + operationId, "")) { |
| 85 | + JsonNode lroJsonResponse = getJsonResponse(lroResponse); |
| 86 | + if (lroJsonResponse != null && lroJsonResponse.get("done") != null) { |
| 87 | + break; |
| 88 | + } |
| 89 | + } |
| 90 | + try { |
| 91 | + SECONDS.sleep(1); |
| 92 | + } catch (InterruptedException e) { |
| 93 | + logger.warn("Error during sleep", e); |
| 94 | + Thread.currentThread().interrupt(); |
| 95 | + } |
| 96 | + } |
| 97 | + return getSession(reasoningEngineId, sessId); |
| 98 | + } |
| 99 | + |
| 100 | + JsonNode listSessions(String reasoningEngineId, String userId) { |
| 101 | + try (ApiResponse apiResponse = |
| 102 | + apiClient.request( |
| 103 | + "GET", |
| 104 | + "reasoningEngines/" + reasoningEngineId + "/sessions?filter=user_id=" + userId, |
| 105 | + "")) { |
| 106 | + return getJsonResponse(apiResponse); |
| 107 | + } |
| 108 | + } |
| 109 | + |
| 110 | + JsonNode listEvents(String reasoningEngineId, String sessionId) { |
| 111 | + try (ApiResponse apiResponse = |
| 112 | + apiClient.request( |
| 113 | + "GET", |
| 114 | + "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events", |
| 115 | + "")) { |
| 116 | + logger.debug("List events response {}", apiResponse); |
| 117 | + return getJsonResponse(apiResponse); |
| 118 | + } |
| 119 | + } |
| 120 | + |
| 121 | + JsonNode getSession(String reasoningEngineId, String sessionId) { |
| 122 | + try (ApiResponse apiResponse = |
| 123 | + apiClient.request( |
| 124 | + "GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) { |
| 125 | + return getJsonResponse(apiResponse); |
| 126 | + } |
| 127 | + } |
| 128 | + |
| 129 | + void deleteSession(String reasoningEngineId, String sessionId) { |
| 130 | + try (ApiResponse response = |
| 131 | + apiClient.request( |
| 132 | + "DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) {} |
| 133 | + } |
| 134 | + |
| 135 | + void appendEvent(String reasoningEngineId, String sessionId, String eventJson) { |
| 136 | + try (ApiResponse response = |
| 137 | + apiClient.request( |
| 138 | + "POST", |
| 139 | + "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + ":appendEvent", |
| 140 | + eventJson)) { |
| 141 | + if (response.getResponseBody().string().contains("com.google.genai.errors.ClientException")) { |
| 142 | + logger.warn("Failed to append event: {}", eventJson); |
| 143 | + } |
| 144 | + } catch (IOException e) { |
| 145 | + throw new UncheckedIOException(e); |
| 146 | + } |
| 147 | + } |
| 148 | + |
| 149 | + /** |
| 150 | + * Parses the JSON response body from the given API response. |
| 151 | + * |
| 152 | + * @throws UncheckedIOException if parsing fails. |
| 153 | + */ |
| 154 | + @Nullable |
| 155 | + private static JsonNode getJsonResponse(ApiResponse apiResponse) { |
| 156 | + if (apiResponse == null || apiResponse.getResponseBody() == null) { |
| 157 | + return null; |
| 158 | + } |
| 159 | + try { |
| 160 | + ResponseBody responseBody = apiResponse.getResponseBody(); |
| 161 | + String responseString = responseBody.string(); |
| 162 | + if (responseString.isEmpty()) { |
| 163 | + return null; |
| 164 | + } |
| 165 | + return objectMapper.readTree(responseString); |
| 166 | + } catch (IOException e) { |
| 167 | + throw new UncheckedIOException(e); |
| 168 | + } |
| 169 | + } |
| 170 | +} |
0 commit comments