Skip to content

Commit 4b76cbc

Browse files
authored
[AI] Add TemplateChat for multi-turn template interactions (#7986)
Introduces `TemplateChat` for managing multi-turn conversations with server-side prompt templates, automatically handling chat history. Based on the code from #7954 Internal b/496905495
1 parent 51b37cc commit 4b76cbc

7 files changed

Lines changed: 399 additions & 3 deletions

File tree

ai-logic/firebase-ai/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Unreleased
22

3+
- [feature] Added support for Chat interactions using server prompt templates (#7986)
34
- [fixed] Fixed an issue causing network timeouts to throw the incorrect exception type, instead of
45
`RequestTimeoutException` (#7966)
56
- [fixed] Fixed missing `toString()` implemenation for `InferenceSource` (#7970)

ai-logic/firebase-ai/api.txt

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,19 @@ package com.google.firebase.ai {
137137
public static final class OnDeviceConfig.Companion {
138138
}
139139

140+
@com.google.firebase.ai.type.PublicPreviewAPI public final class TemplateChat {
141+
method public java.util.List<com.google.firebase.ai.type.Content> getHistory();
142+
method public suspend Object? sendMessage(com.google.firebase.ai.type.Content prompt, kotlin.coroutines.Continuation<? super com.google.firebase.ai.type.GenerateContentResponse>);
143+
method public suspend Object? sendMessage(String prompt, kotlin.coroutines.Continuation<? super com.google.firebase.ai.type.GenerateContentResponse>);
144+
method public kotlinx.coroutines.flow.Flow<com.google.firebase.ai.type.GenerateContentResponse> sendMessageStream(com.google.firebase.ai.type.Content prompt);
145+
method public kotlinx.coroutines.flow.Flow<com.google.firebase.ai.type.GenerateContentResponse> sendMessageStream(String prompt);
146+
property public final java.util.List<com.google.firebase.ai.type.Content> history;
147+
}
148+
140149
@com.google.firebase.ai.type.PublicPreviewAPI public final class TemplateGenerativeModel {
141150
method public suspend Object? generateContent(String templateId, java.util.Map<java.lang.String,?> inputs, kotlin.coroutines.Continuation<? super com.google.firebase.ai.type.GenerateContentResponse>);
142151
method public kotlinx.coroutines.flow.Flow<com.google.firebase.ai.type.GenerateContentResponse> generateContentStream(String templateId, java.util.Map<java.lang.String,?> inputs);
152+
method @com.google.firebase.ai.type.PublicPreviewAPI public com.google.firebase.ai.TemplateChat startChat(String templateId, java.util.Map<java.lang.String,?> inputs, java.util.List<com.google.firebase.ai.type.Content> history = emptyList());
143153
}
144154

145155
@com.google.firebase.ai.type.PublicPreviewAPI public final class TemplateImagenModel {

ai-logic/firebase-ai/gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
version=17.10.2
15+
version=17.11.0
1616
latestReleasedVersion=17.10.1
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.ai
18+
19+
import com.google.firebase.ai.type.Content
20+
import com.google.firebase.ai.type.GenerateContentResponse
21+
import com.google.firebase.ai.type.InvalidStateException
22+
import com.google.firebase.ai.type.Part
23+
import com.google.firebase.ai.type.PublicPreviewAPI
24+
import com.google.firebase.ai.type.content
25+
import java.util.concurrent.Semaphore
26+
import kotlinx.coroutines.flow.Flow
27+
import kotlinx.coroutines.flow.onCompletion
28+
import kotlinx.coroutines.flow.onEach
29+
30+
/** Representation of a multi-turn interaction with a server template model. */
31+
@PublicPreviewAPI
32+
public class TemplateChat
33+
internal constructor(
34+
private val model: TemplateGenerativeModel,
35+
private val templateId: String,
36+
private val inputs: Map<String, Any>,
37+
public val history: MutableList<Content> = ArrayList()
38+
) {
39+
private var lock = Semaphore(1)
40+
41+
/**
42+
* Sends a message using the provided [prompt]; automatically providing the existing [history] as
43+
* context.
44+
*
45+
* @param prompt The input that, together with the history, will be given to the model as the
46+
* prompt.
47+
*/
48+
public suspend fun sendMessage(prompt: Content): GenerateContentResponse {
49+
prompt.assertComesFromUser()
50+
attemptLock()
51+
try {
52+
return model.generateContentWithHistory(templateId, inputs, history + prompt).also { resp ->
53+
history.add(prompt)
54+
history.add(resp.candidates.first().content)
55+
}
56+
} finally {
57+
lock.release()
58+
}
59+
}
60+
61+
/**
62+
* Sends a message using the provided text [prompt]; automatically providing the existing
63+
* [history] as context.
64+
*/
65+
public suspend fun sendMessage(prompt: String): GenerateContentResponse {
66+
val content = content { text(prompt) }
67+
return sendMessage(content)
68+
}
69+
70+
/**
71+
* Sends a message using the provided [prompt]; automatically providing the existing [history] as
72+
* context. Returns a flow.
73+
*/
74+
public fun sendMessageStream(prompt: Content): Flow<GenerateContentResponse> {
75+
prompt.assertComesFromUser()
76+
attemptLock()
77+
78+
val fullPrompt = history + prompt
79+
val flow = model.generateContentWithHistoryStream(templateId, inputs, fullPrompt)
80+
val tempHistory = mutableListOf<Content>()
81+
val responseParts = mutableListOf<Part>()
82+
83+
return flow
84+
.onEach { response ->
85+
response.candidates.first().content.parts.let { responseParts.addAll(it) }
86+
}
87+
.onCompletion {
88+
lock.release()
89+
if (it == null) {
90+
tempHistory.add(prompt)
91+
tempHistory.add(
92+
content("model") { responseParts.forEach { part -> this.parts.add(part) } }
93+
)
94+
history.addAll(tempHistory)
95+
}
96+
}
97+
}
98+
99+
/**
100+
* Sends a message using the provided text [prompt]; automatically providing the existing
101+
* [history] as context. Returns a flow.
102+
*/
103+
public fun sendMessageStream(prompt: String): Flow<GenerateContentResponse> {
104+
val content = content { text(prompt) }
105+
return sendMessageStream(content)
106+
}
107+
108+
private fun Content.assertComesFromUser() {
109+
if (role !in listOf("user", "function")) {
110+
throw InvalidStateException("Chat prompts should come from the 'user' or 'function' role.")
111+
}
112+
}
113+
114+
private fun attemptLock() {
115+
if (!lock.tryAcquire()) {
116+
throw InvalidStateException(
117+
"This chat instance currently has an ongoing request, please wait for it to complete " +
118+
"before sending more messages"
119+
)
120+
}
121+
}
122+
}

ai-logic/firebase-ai/src/main/kotlin/com/google/firebase/ai/TemplateGenerativeModel.kt

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,27 @@ internal constructor(
8686
public suspend fun generateContent(
8787
templateId: String,
8888
inputs: Map<String, Any>,
89+
): GenerateContentResponse = generateContentWithHistory(templateId, inputs, null)
90+
91+
/**
92+
* Generates content from a prompt template and inputs.
93+
*
94+
* @param templateId The ID of the prompt template to use.
95+
* @param inputs A map of variables to substitute into the template.
96+
* @param history Prior history in the conversation.
97+
* @return The content generated by the model.
98+
* @throws [FirebaseAIException] if the request failed.
99+
* @see [FirebaseAIException] for types of errors.
100+
*/
101+
@PublicPreviewAPI
102+
internal suspend fun generateContentWithHistory(
103+
templateId: String,
104+
inputs: Map<String, Any>,
105+
history: List<Content>?
89106
): GenerateContentResponse =
90107
try {
91108
controller
92-
.templateGenerateContent("$templateUri$templateId", constructRequest(inputs))
109+
.templateGenerateContent("$templateUri$templateId", constructRequest(inputs, history))
93110
.toPublic()
94111
.validate()
95112
} catch (e: Throwable) {
@@ -108,12 +125,44 @@ internal constructor(
108125
public fun generateContentStream(
109126
templateId: String,
110127
inputs: Map<String, Any>
128+
): Flow<GenerateContentResponse> = generateContentWithHistoryStream(templateId, inputs, null)
129+
130+
/**
131+
* Generates content as a stream from a prompt template, inputs, and history.
132+
*
133+
* @param templateId The ID of the prompt template to use.
134+
* @param inputs A map of variables to substitute into the template.
135+
* @param history Prior history in the conversation.
136+
* @return A [Flow] which will emit responses as they are returned by the model.
137+
* @throws [FirebaseAIException] if the request failed.
138+
* @see [FirebaseAIException] for types of errors.
139+
*/
140+
@PublicPreviewAPI
141+
internal fun generateContentWithHistoryStream(
142+
templateId: String,
143+
inputs: Map<String, Any>,
144+
history: List<Content>?
111145
): Flow<GenerateContentResponse> =
112146
controller
113-
.templateGenerateContentStream("$templateUri$templateId", constructRequest(inputs))
147+
.templateGenerateContentStream("$templateUri$templateId", constructRequest(inputs, history))
114148
.catch { throw FirebaseAIException.from(it) }
115149
.map { it.toPublic().validate() }
116150

151+
/**
152+
* Creates a [TemplateChat] instance using this model with the optionally provided history.
153+
*
154+
* @param templateId The ID of the prompt template to use.
155+
* @param inputs A map of variables to substitute into the template for the session.
156+
* @param history Prior history in the conversation.
157+
* @return The initialized [TemplateChat] instance.
158+
*/
159+
@PublicPreviewAPI
160+
public fun startChat(
161+
templateId: String,
162+
inputs: Map<String, Any>,
163+
history: List<Content> = emptyList()
164+
): TemplateChat = TemplateChat(this, templateId, inputs, history.toMutableList())
165+
117166
internal fun constructRequest(
118167
inputs: Map<String, Any>,
119168
history: List<Content>? = null
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
/*
2+
* Copyright 2026 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.firebase.ai
18+
19+
import com.google.firebase.ai.type.Candidate
20+
import com.google.firebase.ai.type.Content
21+
import com.google.firebase.ai.type.FinishReason
22+
import com.google.firebase.ai.type.GenerateContentResponse
23+
import com.google.firebase.ai.type.Part
24+
import com.google.firebase.ai.type.PublicPreviewAPI
25+
import com.google.firebase.ai.type.TextPart
26+
import com.google.firebase.ai.type.content
27+
import io.kotest.matchers.collections.shouldHaveSize
28+
import io.kotest.matchers.shouldBe
29+
import io.kotest.matchers.types.shouldBeInstanceOf
30+
import io.mockk.coEvery
31+
import io.mockk.every
32+
import io.mockk.mockk
33+
import kotlinx.coroutines.flow.flowOf
34+
import kotlinx.coroutines.flow.toList
35+
import kotlinx.coroutines.test.runTest
36+
import org.junit.Before
37+
import org.junit.Test
38+
import org.junit.runner.RunWith
39+
import org.robolectric.RobolectricTestRunner
40+
41+
@OptIn(PublicPreviewAPI::class)
42+
@RunWith(RobolectricTestRunner::class)
43+
class TemplateChatTests {
44+
private val model = mockk<TemplateGenerativeModel>()
45+
private val templateId = "test-template"
46+
private val inputs = mapOf("key" to "value")
47+
48+
private lateinit var chat: TemplateChat
49+
50+
@Before
51+
fun setup() {
52+
chat = TemplateChat(model, templateId, inputs)
53+
}
54+
55+
@Test
56+
fun `sendMessage(Content) adds prompt and response to history`() = runTest {
57+
val prompt = content("user") { text("hello") }
58+
val responseContent = content("model") { text("hi") }
59+
val response = createResponse(responseContent)
60+
61+
coEvery { model.generateContentWithHistory(templateId, inputs, any()) } returns response
62+
63+
chat.sendMessage(prompt)
64+
65+
chat.history shouldHaveSize 2
66+
chat.history[0] shouldBeEquivalentTo prompt
67+
chat.history[1] shouldBeEquivalentTo responseContent
68+
}
69+
70+
@Test
71+
fun `sendMessageStream(Content) adds prompt and aggregated responses to history`() = runTest {
72+
val prompt = content("user") { text("hello") }
73+
val response1 = createResponse(content("model") { text("hi ") })
74+
val response2 = createResponse(content("model") { text("there") })
75+
76+
every { model.generateContentWithHistoryStream(templateId, inputs, any()) } returns
77+
flowOf(response1, response2)
78+
79+
val flow = chat.sendMessageStream(prompt)
80+
flow.toList()
81+
82+
chat.history shouldHaveSize 2
83+
chat.history[0] shouldBeEquivalentTo prompt
84+
chat.history[1].parts shouldHaveSize 2
85+
chat.history[1].parts[0].shouldBeInstanceOf<TextPart>().text shouldBe "hi "
86+
chat.history[1].parts[1].shouldBeInstanceOf<TextPart>().text shouldBe "there"
87+
}
88+
89+
private fun createResponse(content: Content): GenerateContentResponse {
90+
return GenerateContentResponse.Internal(
91+
listOf(Candidate.Internal(content.toInternal(), finishReason = FinishReason.Internal.STOP))
92+
)
93+
.toPublic()
94+
}
95+
96+
private infix fun Content.shouldBeEquivalentTo(other: Content) {
97+
this.role shouldBe other.role
98+
this.parts shouldHaveSize other.parts.size
99+
this.parts.zip(other.parts).forEach { (a, b) -> a.shouldBeEquivalentTo(b) }
100+
}
101+
102+
private fun Part.shouldBeEquivalentTo(other: Part) {
103+
this::class shouldBe other::class
104+
if (this is TextPart && other is TextPart) {
105+
this.text shouldBe other.text
106+
}
107+
}
108+
}

0 commit comments

Comments
 (0)