Skip to content

Commit e884108

Browse files
committed
feat: Add inputAudioTranscription support to Java ADK
1 parent 9dfc4c3 commit e884108

13 files changed

Lines changed: 518 additions & 322 deletions

File tree

contrib/samples/helloworld/HelloWorldAgent.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
14-
package internal.samples.helloworld;
14+
package com.example.helloworld;
1515

1616
import com.google.adk.agents.LlmAgent;
1717
import com.google.adk.tools.FunctionTool;

contrib/samples/helloworld/HelloWorldRun.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package internal.samples.helloworld;
1+
package com.example.helloworld;
22

33
import com.google.adk.agents.RunConfig;
44
import com.google.adk.artifacts.InMemoryArtifactService;
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Hello World Agent Sample
2+
3+
This directory contains the minimal Java sample for the Google ADK. It defines a
4+
single agent (`com.example.helloworld.HelloWorldAgent`) and a small console
5+
runner (`HelloWorldRun`) that demonstrates tool invocation for dice rolling and
6+
prime checking.
7+
8+
For configuration-driven examples that complement this code sample, see the
9+
config-based collection in `../configagent/README.md`.
10+
11+
## Project Layout
12+
13+
```
14+
├── HelloWorldAgent.java // Agent definition and tool wiring
15+
├── HelloWorldRun.java // Console runner entry point
16+
├── pom.xml // Maven configuration and exec main class
17+
└── README.md // This file
18+
```
19+
20+
## Prerequisites
21+
22+
- Java 17+
23+
- Maven 3.9+
24+
25+
## Build and Run
26+
27+
Compile the project and launch the sample conversation:
28+
29+
```bash
30+
mvn clean compile exec:java
31+
```
32+
33+
The runner sends a starter prompt (`Hi. Roll a die of 60 sides.`) and prints the
34+
agent's response. To explore additional prompts, pass the `--run-extended`
35+
argument:
36+
37+
```bash
38+
mvn exec:java -Dexec.args="--run-extended"
39+
```
40+
41+
## Next Steps
42+
43+
* Review `HelloWorldAgent.java` to see how function tools are registered.
44+
* Compare with the configuration-based samples in `../configagent/README.md` for
45+
more complex agent setups (callbacks, multi-agent coordination, and custom
46+
registries).

contrib/samples/helloworld/pom.xml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
<name>Google ADK - Sample - Hello World</name>
2424
<description>
2525
A sample "Hello World" application demonstrating basic agent and tool usage with the Google ADK,
26-
runnable via internal.samples.helloworld.HelloWorldRun.
26+
runnable via com.example.helloworld.HelloWorldRun.
2727
</description>
2828
<packaging>jar</packaging>
2929

@@ -32,7 +32,7 @@
3232
<java.version>17</java.version>
3333
<auto-value.version>1.11.0</auto-value.version>
3434
<!-- Main class for exec-maven-plugin -->
35-
<exec.mainClass>internal.samples.helloworld.HelloWorldRun</exec.mainClass>
35+
<exec.mainClass>com.example.helloworld.HelloWorldRun</exec.mainClass>
3636
<google-adk.version>0.3.0</google-adk.version>
3737
</properties>
3838

core/src/main/java/com/google/adk/agents/RunConfig.java

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,12 @@ public enum StreamingMode {
4848

4949
public abstract @Nullable AudioTranscriptionConfig outputAudioTranscription();
5050

51+
public abstract @Nullable AudioTranscriptionConfig inputAudioTranscription();
52+
5153
public abstract int maxLlmCalls();
5254

55+
public abstract Builder toBuilder();
56+
5357
public static Builder builder() {
5458
return new AutoValue_RunConfig.Builder()
5559
.setSaveInputBlobsAsArtifacts(false)
@@ -65,7 +69,8 @@ public static Builder builder(RunConfig runConfig) {
6569
.setMaxLlmCalls(runConfig.maxLlmCalls())
6670
.setResponseModalities(runConfig.responseModalities())
6771
.setSpeechConfig(runConfig.speechConfig())
68-
.setOutputAudioTranscription(runConfig.outputAudioTranscription());
72+
.setOutputAudioTranscription(runConfig.outputAudioTranscription())
73+
.setInputAudioTranscription(runConfig.inputAudioTranscription());
6974
}
7075

7176
/** Builder for {@link RunConfig}. */
@@ -88,6 +93,10 @@ public abstract static class Builder {
8893
public abstract Builder setOutputAudioTranscription(
8994
AudioTranscriptionConfig outputAudioTranscription);
9095

96+
@CanIgnoreReturnValue
97+
public abstract Builder setInputAudioTranscription(
98+
AudioTranscriptionConfig inputAudioTranscription);
99+
91100
@CanIgnoreReturnValue
92101
public abstract Builder setMaxLlmCalls(int maxLlmCalls);
93102

core/src/main/java/com/google/adk/flows/llmflows/Basic.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ public Single<RequestProcessor.RequestProcessingResult> processRequest(
4848
.ifPresent(liveConnectConfigBuilder::speechConfig);
4949
Optional.ofNullable(context.runConfig().outputAudioTranscription())
5050
.ifPresent(liveConnectConfigBuilder::outputAudioTranscription);
51+
Optional.ofNullable(context.runConfig().inputAudioTranscription())
52+
.ifPresent(liveConnectConfigBuilder::inputAudioTranscription);
5153

5254
LlmRequest.Builder builder =
5355
request.toBuilder()

core/src/main/java/com/google/adk/runner/Runner.java

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,9 @@ public Flowable<Event> runAsync(Session session, Content newMessage, RunConfig r
310310
private InvocationContext newInvocationContextForLive(
311311
Session session, Optional<LiveRequestQueue> liveRequestQueue, RunConfig runConfig) {
312312
RunConfig.Builder runConfigBuilder = RunConfig.builder(runConfig);
313-
if (!CollectionUtils.isNullOrEmpty(runConfig.responseModalities())
314-
&& liveRequestQueue.isPresent()) {
313+
if (liveRequestQueue.isPresent() && !this.agent.subAgents().isEmpty()) {
314+
// Parity with Python: apply modality defaults and transcription settings
315+
// only for multi-agent live scenarios.
315316
// Default to AUDIO modality if not specified.
316317
if (CollectionUtils.isNullOrEmpty(runConfig.responseModalities())) {
317318
runConfigBuilder.setResponseModalities(
@@ -324,6 +325,10 @@ private InvocationContext newInvocationContextForLive(
324325
runConfigBuilder.setOutputAudioTranscription(AudioTranscriptionConfig.builder().build());
325326
}
326327
}
328+
// Need input transcription for agent transferring in live mode.
329+
if (runConfig.inputAudioTranscription() == null) {
330+
runConfigBuilder.setInputAudioTranscription(AudioTranscriptionConfig.builder().build());
331+
}
327332
}
328333
return newInvocationContext(
329334
session, /* newMessage= */ Optional.empty(), liveRequestQueue, runConfigBuilder.build());
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
package com.google.adk.sessions;
2+
3+
import static java.util.concurrent.TimeUnit.SECONDS;
4+
5+
import com.fasterxml.jackson.databind.JsonNode;
6+
import com.fasterxml.jackson.databind.ObjectMapper;
7+
import com.google.adk.JsonBaseModel;
8+
import com.google.auth.oauth2.GoogleCredentials;
9+
import com.google.common.base.Splitter;
10+
import com.google.common.collect.Iterables;
11+
import com.google.genai.types.HttpOptions;
12+
import java.io.IOException;
13+
import java.io.UncheckedIOException;
14+
import java.util.List;
15+
import java.util.Optional;
16+
import java.util.concurrent.ConcurrentHashMap;
17+
import java.util.concurrent.ConcurrentMap;
18+
import javax.annotation.Nullable;
19+
import okhttp3.ResponseBody;
20+
import org.slf4j.Logger;
21+
import org.slf4j.LoggerFactory;
22+
23+
/** Client for interacting with the Vertex AI Session API. */
24+
final class VertexAiClient {
25+
private static final int MAX_RETRY_ATTEMPTS = 5;
26+
private static final ObjectMapper objectMapper = JsonBaseModel.getMapper();
27+
private static final Logger logger = LoggerFactory.getLogger(VertexAiClient.class);
28+
29+
private final HttpApiClient apiClient;
30+
31+
VertexAiClient(String project, String location, HttpApiClient apiClient) {
32+
this.apiClient = apiClient;
33+
}
34+
35+
VertexAiClient() {
36+
this.apiClient =
37+
new HttpApiClient(Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty());
38+
}
39+
40+
VertexAiClient(
41+
String project,
42+
String location,
43+
Optional<GoogleCredentials> credentials,
44+
Optional<HttpOptions> httpOptions) {
45+
this.apiClient =
46+
new HttpApiClient(Optional.of(project), Optional.of(location), credentials, httpOptions);
47+
}
48+
49+
@Nullable
50+
JsonNode createSession(
51+
String reasoningEngineId, String userId, ConcurrentMap<String, Object> state) {
52+
ConcurrentHashMap<String, Object> sessionJsonMap = new ConcurrentHashMap<>();
53+
sessionJsonMap.put("userId", userId);
54+
if (state != null) {
55+
sessionJsonMap.put("sessionState", state);
56+
}
57+
58+
String sessId;
59+
String operationId;
60+
try {
61+
String sessionJson = objectMapper.writeValueAsString(sessionJsonMap);
62+
try (ApiResponse apiResponse =
63+
apiClient.request(
64+
"POST", "reasoningEngines/" + reasoningEngineId + "/sessions", sessionJson)) {
65+
logger.debug("Create Session response {}", apiResponse.getResponseBody());
66+
if (apiResponse == null || apiResponse.getResponseBody() == null) {
67+
return null;
68+
}
69+
70+
JsonNode jsonResponse = getJsonResponse(apiResponse);
71+
if (jsonResponse == null) {
72+
return null;
73+
}
74+
String sessionName = jsonResponse.get("name").asText();
75+
List<String> parts = Splitter.on('/').splitToList(sessionName);
76+
sessId = parts.get(parts.size() - 3);
77+
operationId = Iterables.getLast(parts);
78+
}
79+
} catch (IOException e) {
80+
throw new UncheckedIOException(e);
81+
}
82+
83+
for (int i = 0; i < MAX_RETRY_ATTEMPTS; i++) {
84+
try (ApiResponse lroResponse = apiClient.request("GET", "operations/" + operationId, "")) {
85+
JsonNode lroJsonResponse = getJsonResponse(lroResponse);
86+
if (lroJsonResponse != null && lroJsonResponse.get("done") != null) {
87+
break;
88+
}
89+
}
90+
try {
91+
SECONDS.sleep(1);
92+
} catch (InterruptedException e) {
93+
logger.warn("Error during sleep", e);
94+
Thread.currentThread().interrupt();
95+
}
96+
}
97+
return getSession(reasoningEngineId, sessId);
98+
}
99+
100+
JsonNode listSessions(String reasoningEngineId, String userId) {
101+
try (ApiResponse apiResponse =
102+
apiClient.request(
103+
"GET",
104+
"reasoningEngines/" + reasoningEngineId + "/sessions?filter=user_id=" + userId,
105+
"")) {
106+
return getJsonResponse(apiResponse);
107+
}
108+
}
109+
110+
JsonNode listEvents(String reasoningEngineId, String sessionId) {
111+
try (ApiResponse apiResponse =
112+
apiClient.request(
113+
"GET",
114+
"reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + "/events",
115+
"")) {
116+
logger.debug("List events response {}", apiResponse);
117+
return getJsonResponse(apiResponse);
118+
}
119+
}
120+
121+
JsonNode getSession(String reasoningEngineId, String sessionId) {
122+
try (ApiResponse apiResponse =
123+
apiClient.request(
124+
"GET", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) {
125+
return getJsonResponse(apiResponse);
126+
}
127+
}
128+
129+
void deleteSession(String reasoningEngineId, String sessionId) {
130+
try (ApiResponse response =
131+
apiClient.request(
132+
"DELETE", "reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId, "")) {}
133+
}
134+
135+
void appendEvent(String reasoningEngineId, String sessionId, String eventJson) {
136+
try (ApiResponse response =
137+
apiClient.request(
138+
"POST",
139+
"reasoningEngines/" + reasoningEngineId + "/sessions/" + sessionId + ":appendEvent",
140+
eventJson)) {
141+
if (response.getResponseBody().string().contains("com.google.genai.errors.ClientException")) {
142+
logger.warn("Failed to append event: {}", eventJson);
143+
}
144+
} catch (IOException e) {
145+
throw new UncheckedIOException(e);
146+
}
147+
}
148+
149+
/**
150+
* Parses the JSON response body from the given API response.
151+
*
152+
* @throws UncheckedIOException if parsing fails.
153+
*/
154+
@Nullable
155+
private static JsonNode getJsonResponse(ApiResponse apiResponse) {
156+
if (apiResponse == null || apiResponse.getResponseBody() == null) {
157+
return null;
158+
}
159+
try {
160+
ResponseBody responseBody = apiResponse.getResponseBody();
161+
String responseString = responseBody.string();
162+
if (responseString.isEmpty()) {
163+
return null;
164+
}
165+
return objectMapper.readTree(responseString);
166+
} catch (IOException e) {
167+
throw new UncheckedIOException(e);
168+
}
169+
}
170+
}

0 commit comments

Comments
 (0)