Skip to content

Commit 677b6d7

Browse files
Mateusz Krawieccopybara-github
authored andcommitted
fix: parallel agent execution
PiperOrigin-RevId: 889140710
1 parent 8a7f816 commit 677b6d7

2 files changed

Lines changed: 108 additions & 8 deletions

File tree

core/src/main/java/com/google/adk/agents/ParallelAgent.java

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,14 @@
1616
package com.google.adk.agents;
1717

1818
import static com.google.common.base.Strings.isNullOrEmpty;
19-
import static com.google.common.collect.ImmutableList.toImmutableList;
2019

2120
import com.google.adk.agents.ConfigAgentUtils.ConfigurationException;
2221
import com.google.adk.events.Event;
22+
import com.google.errorprone.annotations.CanIgnoreReturnValue;
2323
import io.reactivex.rxjava3.core.Flowable;
24+
import io.reactivex.rxjava3.core.Scheduler;
25+
import io.reactivex.rxjava3.schedulers.Schedulers;
26+
import java.util.ArrayList;
2427
import java.util.List;
2528
import org.slf4j.Logger;
2629
import org.slf4j.LoggerFactory;
@@ -35,6 +38,7 @@
3538
public class ParallelAgent extends BaseAgent {
3639

3740
private static final Logger logger = LoggerFactory.getLogger(ParallelAgent.class);
41+
private final Scheduler scheduler;
3842

3943
/**
4044
* Constructor for ParallelAgent.
@@ -44,24 +48,35 @@ public class ParallelAgent extends BaseAgent {
4448
* @param subAgents The list of sub-agents to run in parallel.
4549
* @param beforeAgentCallback Optional callback before the agent runs.
4650
* @param afterAgentCallback Optional callback after the agent runs.
51+
* @param scheduler The scheduler to use for parallel execution.
4752
*/
4853
private ParallelAgent(
4954
String name,
5055
String description,
5156
List<? extends BaseAgent> subAgents,
5257
List<Callbacks.BeforeAgentCallback> beforeAgentCallback,
53-
List<Callbacks.AfterAgentCallback> afterAgentCallback) {
58+
List<Callbacks.AfterAgentCallback> afterAgentCallback,
59+
Scheduler scheduler) {
5460

5561
super(name, description, subAgents, beforeAgentCallback, afterAgentCallback);
62+
this.scheduler = scheduler;
5663
}
5764

5865
/** Builder for {@link ParallelAgent}. */
5966
public static class Builder extends BaseAgent.Builder<Builder> {
6067

68+
private Scheduler scheduler = Schedulers.io();
69+
70+
@CanIgnoreReturnValue
71+
public Builder scheduler(Scheduler scheduler) {
72+
this.scheduler = scheduler;
73+
return this;
74+
}
75+
6176
@Override
6277
public ParallelAgent build() {
6378
return new ParallelAgent(
64-
name, description, subAgents, beforeAgentCallback, afterAgentCallback);
79+
name, description, subAgents, beforeAgentCallback, afterAgentCallback, scheduler);
6580
}
6681
}
6782

@@ -129,10 +144,11 @@ protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
129144
}
130145

131146
var updatedInvocationContext = setBranchForCurrentAgent(this, invocationContext);
132-
return Flowable.merge(
133-
currentSubAgents.stream()
134-
.map(subAgent -> subAgent.runAsync(updatedInvocationContext))
135-
.collect(toImmutableList()));
147+
List<Flowable<Event>> agentFlowables = new ArrayList<>();
148+
for (BaseAgent subAgent : currentSubAgents) {
149+
agentFlowables.add(subAgent.runAsync(updatedInvocationContext).subscribeOn(scheduler));
150+
}
151+
return Flowable.merge(agentFlowables);
136152
}
137153

138154
/**

core/src/test/java/com/google/adk/agents/ParallelAgentTest.java

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,10 @@
2525
import com.google.genai.types.Content;
2626
import com.google.genai.types.Part;
2727
import io.reactivex.rxjava3.core.Flowable;
28+
import io.reactivex.rxjava3.core.Scheduler;
2829
import io.reactivex.rxjava3.schedulers.Schedulers;
30+
import io.reactivex.rxjava3.schedulers.TestScheduler;
31+
import io.reactivex.rxjava3.subscribers.TestSubscriber;
2932
import java.util.List;
3033
import org.junit.Test;
3134
import org.junit.runner.RunWith;
@@ -36,10 +39,16 @@ public final class ParallelAgentTest {
3639

3740
static class TestingAgent extends BaseAgent {
3841
private final long delayMillis;
42+
private final Scheduler scheduler;
3943

4044
private TestingAgent(String name, String description, long delayMillis) {
45+
this(name, description, delayMillis, Schedulers.computation());
46+
}
47+
48+
private TestingAgent(String name, String description, long delayMillis, Scheduler scheduler) {
4149
super(name, description, ImmutableList.of(), null, null);
4250
this.delayMillis = delayMillis;
51+
this.scheduler = scheduler;
4352
}
4453

4554
@Override
@@ -55,7 +64,7 @@ protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
5564
.build());
5665

5766
if (delayMillis > 0) {
58-
return event.delay(delayMillis, MILLISECONDS, Schedulers.computation());
67+
return event.delay(delayMillis, MILLISECONDS, scheduler);
5968
}
6069
return event;
6170
}
@@ -110,4 +119,79 @@ public void runAsync_noSubAgents_returnsEmptyFlowable() {
110119

111120
assertThat(events).isEmpty();
112121
}
122+
123+
static class BlockingAgent extends BaseAgent {
124+
private final long sleepMillis;
125+
126+
private BlockingAgent(String name, long sleepMillis) {
127+
super(name, "Blocking Agent", ImmutableList.of(), null, null);
128+
this.sleepMillis = sleepMillis;
129+
}
130+
131+
@Override
132+
protected Flowable<Event> runAsyncImpl(InvocationContext invocationContext) {
133+
return Flowable.fromCallable(
134+
() -> {
135+
Thread.sleep(sleepMillis);
136+
return Event.builder()
137+
.author(name())
138+
.branch(invocationContext.branch().orElse(null))
139+
.invocationId(invocationContext.invocationId())
140+
.content(Content.fromParts(Part.fromText("Done")))
141+
.build();
142+
});
143+
}
144+
145+
@Override
146+
protected Flowable<Event> runLiveImpl(InvocationContext invocationContext) {
147+
throw new UnsupportedOperationException("Not implemented");
148+
}
149+
}
150+
151+
@Test
152+
public void runAsync_blockingSubAgents_shouldExecuteInParallel() {
153+
long sleepTime = 1000;
154+
BlockingAgent agent1 = new BlockingAgent("agent1", sleepTime);
155+
BlockingAgent agent2 = new BlockingAgent("agent2", sleepTime);
156+
157+
ParallelAgent parallelAgent =
158+
ParallelAgent.builder().name("parallel_agent").subAgents(agent1, agent2).build();
159+
160+
InvocationContext invocationContext = createInvocationContext(parallelAgent);
161+
162+
long startTime = System.currentTimeMillis();
163+
List<Event> events = parallelAgent.runAsync(invocationContext).toList().blockingGet();
164+
long duration = System.currentTimeMillis() - startTime;
165+
166+
assertThat(events).hasSize(2);
167+
// If parallel, duration should be less than 1.5 * sleepTime (1500ms).
168+
assertThat(duration).isAtLeast(sleepTime);
169+
assertThat(duration).isLessThan((long) (1.5 * sleepTime));
170+
}
171+
172+
@Test
173+
public void runAsync_withTestScheduler_usesVirtualTime() {
174+
TestScheduler testScheduler = new TestScheduler();
175+
long delayMillis = 1000;
176+
TestingAgent agent =
177+
new TestingAgent("delayed_agent", "Delayed Agent", delayMillis, testScheduler);
178+
179+
ParallelAgent parallelAgent =
180+
ParallelAgent.builder()
181+
.name("parallel_agent")
182+
.subAgents(agent)
183+
.scheduler(testScheduler)
184+
.build();
185+
186+
InvocationContext invocationContext = createInvocationContext(parallelAgent);
187+
188+
TestSubscriber<Event> testSubscriber = parallelAgent.runAsync(invocationContext).test();
189+
190+
testScheduler.advanceTimeBy(delayMillis - 100, MILLISECONDS);
191+
testSubscriber.assertNoValues();
192+
testSubscriber.assertNotComplete();
193+
testScheduler.advanceTimeBy(200, MILLISECONDS);
194+
testSubscriber.assertValueCount(1);
195+
testSubscriber.assertComplete();
196+
}
113197
}

0 commit comments

Comments
 (0)