Skip to content

Commit b50db1b

Browse files
Poggeccicopybara-github
authored andcommitted
fix!: update basellmflow postprocessing to allow emitting original response prior to generating new events
PiperOrigin-RevId: 814456467
1 parent 76e4b5a commit b50db1b

1 file changed

Lines changed: 42 additions & 58 deletions

File tree

core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 42 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@
5252
import io.reactivex.rxjava3.observers.DisposableCompletableObserver;
5353
import io.reactivex.rxjava3.schedulers.Schedulers;
5454
import java.util.ArrayList;
55-
import java.util.Collections;
5655
import java.util.List;
5756
import java.util.Optional;
5857
import java.util.Set;
@@ -131,9 +130,10 @@ protected Single<RequestProcessingResult> preprocess(
131130

132131
/**
133132
* Post-processes the LLM response after receiving it from the LLM. Executes all registered {@link
134-
* ResponseProcessor} instances. Handles function calls if present in the response.
133+
* ResponseProcessor} instances. Emits events for the model response and any subsequent function
134+
* calls.
135135
*/
136-
protected Single<ResponseProcessingResult> postprocess(
136+
protected Flowable<Event> postprocess(
137137
InvocationContext context,
138138
Event baseEventForLlmResponse,
139139
LlmRequest llmRequest,
@@ -154,46 +154,36 @@ protected Single<ResponseProcessingResult> postprocess(
154154
.map(ResponseProcessingResult::updatedResponse);
155155
}
156156

157-
return currentLlmResponse.flatMap(
157+
return currentLlmResponse.flatMapPublisher(
158158
updatedResponse -> {
159+
Flowable<Event> processorEvents = Flowable.fromIterable(Iterables.concat(eventIterables));
160+
159161
if (updatedResponse.content().isEmpty()
160162
&& updatedResponse.errorCode().isEmpty()
161163
&& !updatedResponse.interrupted().orElse(false)
162164
&& !updatedResponse.turnComplete().orElse(false)) {
163-
return Single.just(
164-
ResponseProcessingResult.create(
165-
updatedResponse, Iterables.concat(eventIterables), Optional.empty()));
165+
return processorEvents;
166166
}
167167

168168
Event modelResponseEvent =
169169
buildModelResponseEvent(baseEventForLlmResponse, llmRequest, updatedResponse);
170-
eventIterables.add(Collections.singleton(modelResponseEvent));
171170

172-
Maybe<Event> maybeFunctionCallEvent;
171+
Flowable<Event> modelEventStream = Flowable.just(modelResponseEvent);
172+
173173
if (modelResponseEvent.functionCalls().isEmpty()) {
174-
maybeFunctionCallEvent = Maybe.empty();
175-
} else if (context.runConfig().streamingMode() == StreamingMode.BIDI) {
174+
return processorEvents.concatWith(modelEventStream);
175+
}
176+
177+
Maybe<Event> maybeFunctionCallEvent;
178+
if (context.runConfig().streamingMode() == StreamingMode.BIDI) {
176179
maybeFunctionCallEvent =
177180
Functions.handleFunctionCallsLive(context, modelResponseEvent, llmRequest.tools());
178181
} else {
179182
maybeFunctionCallEvent =
180183
Functions.handleFunctionCalls(context, modelResponseEvent, llmRequest.tools());
181184
}
182-
return maybeFunctionCallEvent
183-
.map(Optional::of)
184-
.defaultIfEmpty(Optional.empty())
185-
.map(
186-
functionCallEventOpt -> {
187-
Optional<String> transferToAgent = Optional.empty();
188-
if (functionCallEventOpt.isPresent()) {
189-
Event functionCallEvent = functionCallEventOpt.get();
190-
eventIterables.add(Collections.singleton(functionCallEvent));
191-
transferToAgent = functionCallEvent.actions().transferToAgent();
192-
}
193-
Iterable<Event> combinedEvents = Iterables.concat(eventIterables);
194-
return ResponseProcessingResult.create(
195-
updatedResponse, combinedEvents, transferToAgent);
196-
});
185+
186+
return processorEvents.concatWith(modelEventStream).concatWith(maybeFunctionCallEvent);
197187
});
198188
}
199189

@@ -374,33 +364,27 @@ private Flowable<Event> runOneStep(InvocationContext context) {
374364
Flowable<Event> restOfFlow =
375365
callLlm(context, llmRequestAfterPreprocess, mutableEventTemplate)
376366
.concatMap(
377-
llmResponse -> {
378-
Single<ResponseProcessingResult> postResultSingle =
379-
postprocess(
380-
context,
381-
mutableEventTemplate,
382-
llmRequestAfterPreprocess,
383-
llmResponse);
384-
385-
return postResultSingle
386-
.doOnSuccess(
387-
ignored -> {
388-
String oldId = mutableEventTemplate.id();
389-
mutableEventTemplate.setId(Event.generateEventId());
390-
logger.debug(
391-
"Updated mutableEventTemplate ID from {} to {} for next"
392-
+ " LlmResponse",
393-
oldId,
394-
mutableEventTemplate.id());
395-
})
396-
.toFlowable();
397-
})
367+
llmResponse ->
368+
postprocess(
369+
context,
370+
mutableEventTemplate,
371+
llmRequestAfterPreprocess,
372+
llmResponse)
373+
.doFinally(
374+
() -> {
375+
String oldId = mutableEventTemplate.id();
376+
mutableEventTemplate.setId(Event.generateEventId());
377+
logger.debug(
378+
"Updated mutableEventTemplate ID from {} to {} for"
379+
+ " next LlmResponse",
380+
oldId,
381+
mutableEventTemplate.id());
382+
}))
398383
.concatMap(
399-
postResult -> {
400-
Flowable<Event> postProcessedEvents =
401-
Flowable.fromIterable(postResult.events());
402-
if (postResult.transferToAgent().isPresent()) {
403-
String agentToTransfer = postResult.transferToAgent().get();
384+
event -> {
385+
Flowable<Event> postProcessedEvents = Flowable.just(event);
386+
if (event.actions().transferToAgent().isPresent()) {
387+
String agentToTransfer = event.actions().transferToAgent().get();
404388
logger.debug("Transferring to agent: {}", agentToTransfer);
405389
BaseAgent rootAgent = context.agent().rootAgent();
406390
BaseAgent nextAgent = rootAgent.findAgent(agentToTransfer);
@@ -569,7 +553,7 @@ public void onError(Throwable e) {
569553
Flowable<Event> receiveFlow =
570554
connection
571555
.receive()
572-
.flatMapSingle(
556+
.flatMap(
573557
llmResponse -> {
574558
Event baseEventForThisLlmResponse =
575559
liveEventBuilderTemplate.id(Event.generateEventId()).build();
@@ -580,15 +564,15 @@ public void onError(Throwable e) {
580564
llmResponse);
581565
})
582566
.flatMap(
583-
postResult -> {
584-
Flowable<Event> events = Flowable.fromIterable(postResult.events());
585-
if (postResult.transferToAgent().isPresent()) {
567+
event -> {
568+
Flowable<Event> events = Flowable.just(event);
569+
if (event.actions().transferToAgent().isPresent()) {
586570
BaseAgent rootAgent = invocationContext.agent().rootAgent();
587571
BaseAgent nextAgent =
588-
rootAgent.findAgent(postResult.transferToAgent().get());
572+
rootAgent.findAgent(event.actions().transferToAgent().get());
589573
if (nextAgent == null) {
590574
throw new IllegalStateException(
591-
"Agent not found: " + postResult.transferToAgent().get());
575+
"Agent not found: " + event.actions().transferToAgent().get());
592576
}
593577
Flowable<Event> nextAgentEvents =
594578
nextAgent.runLive(invocationContext);

0 commit comments

Comments
 (0)