Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions coderd/x/chatd/generation_preparer.go
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,13 @@ func (server *Server) prepareGeneration(
resolvedUserPrompt string
)

// Drop provider-executed tool history produced by a different provider
// before building the prompt. A provider that shares another's wire format
// (e.g. Bedrock and Anthropic) can still reject the other's
// provider-executed blocks, so a mid-chat provider switch must not replay
// them.
promptRows = server.sanitizeForeignProviderExecutedToolRows(ctx, logger, promptRows, modelConfig.ID)

persistedSkills := skillsFromParts(promptRows)
hasContextFiles := false
if chat.WorkspaceID.Valid {
Expand Down
162 changes: 162 additions & 0 deletions coderd/x/chatd/provider_switch_sanitize.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package chatd

import (
"context"

"github.com/google/uuid"

"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
)

// providerSwitchStripStats counts the provider-executed tool history removed
// when sanitizing a prompt for a model-provider switch.
type providerSwitchStripStats struct {
RemovedToolCalls int
RemovedToolResults int
DroppedMessages int
}

// stripForeignProviderExecutedToolRows removes provider-executed tool blocks
// (both calls and results) from assistant history rows whose producing provider
// differs from targetProvider. Provider-executed tool blocks are only valid for
// the provider that produced them: a provider sharing another's wire format can
// still reject them (e.g. Bedrock rejects Anthropic web_search_tool_result with
// HTTP 400), so switching providers mid-chat must drop the foreign blocks.
//
// originProvider resolves a row's ModelConfigID to a normalized provider name;
// ok is false when the origin cannot be determined, in which case the row is
// treated as foreign (fail closed). Rows from the target provider, non-assistant
// rows, rows with no provider-executed parts, and rows that fail to parse or
// re-marshal are returned unchanged. Rows emptied by stripping are dropped.
//
// Provenance is the model config provider (derived from the AI provider type),
// not anything fantasy reports, so it stays correct when requests route through
// aibridged, which serializes both Anthropic and Bedrock as the Anthropic wire
// format.
func stripForeignProviderExecutedToolRows(
rows []database.ChatMessage,
targetProvider string,
originProvider func(uuid.NullUUID) (string, bool),
) ([]database.ChatMessage, providerSwitchStripStats) {
var stats providerSwitchStripStats
if targetProvider == "" || len(rows) == 0 {
return rows, stats
}

out := make([]database.ChatMessage, 0, len(rows))
for _, row := range rows {
// Provider-executed blocks that must be replayed live on assistant rows.
// Provider-executed results orphaned onto tool rows are dropped during
// prompt conversion, so they never reach the provider.
if row.Role != database.ChatMessageRoleAssistant {
out = append(out, row)
continue
}
if origin, ok := originProvider(row.ModelConfigID); ok && origin == targetProvider {
out = append(out, row)
continue
}

parts, err := chatprompt.ParseContent(row)
if err != nil {
// Leave unparsable rows untouched; the converter handles them.
out = append(out, row)
continue
}

kept := make([]codersdk.ChatMessagePart, 0, len(parts))
var removedCalls, removedResults int
for _, part := range parts {
switch {
case part.Type == codersdk.ChatMessagePartTypeToolCall && part.ProviderExecuted:
removedCalls++
case part.Type == codersdk.ChatMessagePartTypeToolResult && part.ProviderExecuted:
removedResults++
default:
kept = append(kept, part)
}
}
if removedCalls == 0 && removedResults == 0 {
out = append(out, row)
continue
}
if len(kept) == 0 {
Comment thread
johnstcn marked this conversation as resolved.
stats.RemovedToolCalls += removedCalls
stats.RemovedToolResults += removedResults
stats.DroppedMessages++
continue
}

content, err := chatprompt.MarshalParts(kept)
if err != nil {
// Keep the original row rather than corrupting history.
out = append(out, row)
continue
}
row.Content = content
row.ContentVersion = chatprompt.ContentVersionV1
stats.RemovedToolCalls += removedCalls
stats.RemovedToolResults += removedResults
out = append(out, row)
}
return out, stats
}

// sanitizeForeignProviderExecutedToolRows strips provider-executed tool history
// produced by a provider other than the one targeted by modelConfigID. It
// resolves each row's provider via the model config cache and logs a single
// summary when anything is removed.
func (server *Server) sanitizeForeignProviderExecutedToolRows(
ctx context.Context,
logger slog.Logger,
rows []database.ChatMessage,
modelConfigID uuid.UUID,
) []database.ChatMessage {
_, target, err := server.resolveModelConfigAndNormalizedProvider(ctx, modelConfigID)
if err != nil || target == "" {
Comment thread
johnstcn marked this conversation as resolved.
// Without a known target provider we cannot classify history; leave it.
logger.Debug(ctx, "skipping provider-switch sanitization: target provider unresolved",
slog.F("model_config_id", modelConfigID),
slog.Error(err),
)
return rows
}

cache := make(map[uuid.UUID]string)
originProvider := func(id uuid.NullUUID) (string, bool) {
if !id.Valid {
return "", false
}
if provider, seen := cache[id.UUID]; seen {
return provider, provider != ""
}
_, provider, rErr := server.resolveModelConfigAndNormalizedProvider(ctx, id.UUID)
Comment thread
johnstcn marked this conversation as resolved.
if rErr != nil {
Comment thread
johnstcn marked this conversation as resolved.
// Unresolvable origin (e.g. a since-disabled or deleted config) is
// treated as foreign so we fail closed rather than replay blocks the
// target may reject.
logger.Debug(ctx, "provider-switch sanitization: origin provider unresolved, treating as foreign",
slog.F("model_config_id", id.UUID),
slog.Error(rErr),
)
provider = ""
}
cache[id.UUID] = provider
Comment thread
johnstcn marked this conversation as resolved.
return provider, provider != ""
}

sanitized, stats := stripForeignProviderExecutedToolRows(rows, target, originProvider)
if stats != (providerSwitchStripStats{}) {
logger.Warn(ctx, "stripped foreign provider-executed tool history",
slog.F("phase", "provider_switch"),
slog.F("target_provider", target),
slog.F("removed_tool_calls", stats.RemovedToolCalls),
slog.F("removed_tool_results", stats.RemovedToolResults),
slog.F("dropped_messages", stats.DroppedMessages),
)
}
return sanitized
}
183 changes: 183 additions & 0 deletions coderd/x/chatd/provider_switch_sanitize_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
package chatd

import (
"encoding/json"
"testing"

"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/codersdk"
)

func TestStripForeignProviderExecutedToolRows(t *testing.T) {
t.Parallel()

const (
anthropic = "anthropic"
bedrock = "bedrock"
openai = "openai"
)

anthropicCfg := uuid.New()
openAICfg := uuid.New()
unknownCfg := uuid.New()

peCall := func(id string) codersdk.ChatMessagePart {
p := codersdk.ChatMessageToolCall(id, "web_search", json.RawMessage(`{"query":"x"}`))
p.ProviderExecuted = true
return p
}
peResult := func(id string) codersdk.ChatMessagePart {
p := codersdk.ChatMessageToolResult(id, "web_search", json.RawMessage(`{"ok":true}`), false, false)
p.ProviderExecuted = true
return p
}
localCall := func(id string) codersdk.ChatMessagePart {
return codersdk.ChatMessageToolCall(id, "read_file", json.RawMessage(`{}`))
}
text := codersdk.ChatMessageText

assistantRow := func(t *testing.T, cfg uuid.UUID, parts ...codersdk.ChatMessagePart) database.ChatMessage {
t.Helper()
content, err := chatprompt.MarshalParts(parts)
require.NoError(t, err)
return database.ChatMessage{
Role: database.ChatMessageRoleAssistant,
ModelConfigID: uuid.NullUUID{UUID: cfg, Valid: cfg != uuid.Nil},
Content: content,
ContentVersion: chatprompt.ContentVersionV1,
}
}
userRow := func(t *testing.T, s string) database.ChatMessage {
t.Helper()
content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{text(s)})
require.NoError(t, err)
return database.ChatMessage{
Role: database.ChatMessageRoleUser,
Content: content,
ContentVersion: chatprompt.ContentVersionV1,
}
}

// origin maps a model config ID to its normalized provider. unknownCfg is
// intentionally absent so the resolver reports an unknown origin.
origin := func(providerByConfig map[uuid.UUID]string) func(uuid.NullUUID) (string, bool) {
return func(id uuid.NullUUID) (string, bool) {
if !id.Valid {
return "", false
}
provider, ok := providerByConfig[id.UUID]
return provider, ok
}
}
resolver := origin(map[uuid.UUID]string{
anthropicCfg: anthropic,
openAICfg: openai,
Comment thread
johnstcn marked this conversation as resolved.
})

// partsOf parses a row's content back into SDK parts for comparison.
partsOf := func(t *testing.T, row database.ChatMessage) []codersdk.ChatMessagePart {
t.Helper()
parts, err := chatprompt.ParseContent(row)
require.NoError(t, err)
return parts
}

t.Run("same provider kept", func(t *testing.T) {
t.Parallel()
rows := []database.ChatMessage{
userRow(t, "hi"),
assistantRow(t, anthropicCfg, peCall("ws"), peResult("ws"), text("done")),
}
got, stats := stripForeignProviderExecutedToolRows(rows, anthropic, resolver)
require.Equal(t, rows, got)
require.Zero(t, stats)
})

t.Run("anthropic to bedrock drops provider blocks", func(t *testing.T) {
t.Parallel()
rows := []database.ChatMessage{
userRow(t, "hi"),
assistantRow(t, anthropicCfg, peCall("ws"), peResult("ws"), text("done")),
}
got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
require.Len(t, got, 2)
require.Equal(t, []codersdk.ChatMessagePart{text("done")}, partsOf(t, got[1]))
require.Equal(t, providerSwitchStripStats{RemovedToolCalls: 1, RemovedToolResults: 1}, stats)
})

t.Run("foreign-only row dropped", func(t *testing.T) {
t.Parallel()
rows := []database.ChatMessage{
userRow(t, "hi"),
assistantRow(t, anthropicCfg, peCall("ws")),
userRow(t, "again"),
}
got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
require.Len(t, got, 2)
require.Equal(t, database.ChatMessageRoleUser, got[0].Role)
require.Equal(t, database.ChatMessageRoleUser, got[1].Role)
require.Equal(t, providerSwitchStripStats{RemovedToolCalls: 1, DroppedMessages: 1}, stats)
})

t.Run("multi-provider keeps native strips foreign", func(t *testing.T) {
t.Parallel()
rows := []database.ChatMessage{
assistantRow(t, openAICfg, peCall("os"), peResult("os"), text("openai")),
assistantRow(t, anthropicCfg, peCall("as"), peResult("as"), text("anthropic")),
}
got, stats := stripForeignProviderExecutedToolRows(rows, anthropic, resolver)
require.Len(t, got, 2)
require.Equal(t, []codersdk.ChatMessagePart{text("openai")}, partsOf(t, got[0]))
require.Equal(t, rows[1], got[1])
require.Equal(t, providerSwitchStripStats{RemovedToolCalls: 1, RemovedToolResults: 1}, stats)
})

t.Run("non-provider-executed parts untouched", func(t *testing.T) {
t.Parallel()
rows := []database.ChatMessage{
assistantRow(t, anthropicCfg, text("hello"), localCall("local")),
}
got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
require.Equal(t, rows, got)
require.Zero(t, stats)
})

t.Run("empty target is a no-op", func(t *testing.T) {
t.Parallel()
rows := []database.ChatMessage{
assistantRow(t, anthropicCfg, peCall("ws"), peResult("ws")),
}
got, stats := stripForeignProviderExecutedToolRows(rows, "", resolver)
require.Equal(t, rows, got)
require.Zero(t, stats)
})

t.Run("unknown origin fails closed", func(t *testing.T) {
t.Parallel()
rows := []database.ChatMessage{
assistantRow(t, unknownCfg, peResult("ws"), text("done")),
}
got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
require.Len(t, got, 1)
require.Equal(t, []codersdk.ChatMessagePart{text("done")}, partsOf(t, got[0]))
require.Equal(t, providerSwitchStripStats{RemovedToolResults: 1}, stats)
})

t.Run("unparsable foreign row kept unchanged", func(t *testing.T) {
t.Parallel()
rows := []database.ChatMessage{{
Role: database.ChatMessageRoleAssistant,
ModelConfigID: uuid.NullUUID{UUID: anthropicCfg, Valid: true},
Content: pqtype.NullRawMessage{RawMessage: []byte("{not json"), Valid: true},
ContentVersion: chatprompt.ContentVersionV1,
}}
got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
require.Equal(t, rows, got)
require.Zero(t, stats)
})
}
Comment thread
johnstcn marked this conversation as resolved.
Loading