# Patch summary (review notes; this text precedes the first "diff --git"
# header and is not part of any hunk).
#
# Purpose: when a chat switches model provider mid-conversation, remove
# provider-executed tool calls/results that were recorded by a different
# provider from the prompt history before that history is replayed.
#
# coderd/x/chatd/generation_preparer.go
#   prepareGeneration reassigns promptRows to the result of
#   server.sanitizeForeignProviderExecutedToolRows(ctx, logger, promptRows,
#   modelConfig.ID) before skillsFromParts(promptRows) is evaluated.
#
# coderd/x/chatd/provider_switch_sanitize.go (new file)
#   providerSwitchStripStats
#     Counters: removed tool calls, removed tool results, dropped messages.
#   stripForeignProviderExecutedToolRows(rows, targetProvider, originProvider)
#     1. Empty targetProvider or zero rows: input slice returned as-is with
#        zero stats.
#     2. Non-assistant rows are copied through.
#     3. Assistant rows whose origin resolves (ok == true) and equals
#        targetProvider are copied through.
#     4. Rows for which chatprompt.ParseContent fails are copied through.
#     5. Tool-call and tool-result parts with ProviderExecuted set are removed;
#        every other part is kept.
#     6. Nothing removed: the original row is copied through.
#     7. Nothing kept: the row is omitted, the removal counts are added and
#        DroppedMessages is incremented.
#     8. chatprompt.MarshalParts fails: the original row is kept and no stats
#        are recorded for it.
#     9. Otherwise Content is replaced, ContentVersion is set to
#        chatprompt.ContentVersionV1 and the removal counts are added.
#   (*Server).sanitizeForeignProviderExecutedToolRows
#     Resolves the target provider through
#     server.resolveModelConfigAndNormalizedProvider; on error or an empty
#     target it logs at debug level and returns rows unchanged. Origin lookups
#     are memoized in a map created per call; a lookup that errors is stored as
#     "" and reported as unknown (ok == false). One Warn entry is logged when
#     the returned stats are non-zero.
#
# coderd/x/chatd/provider_switch_sanitize_internal_test.go (new file)
#   Subtests: same provider kept; anthropic to bedrock drops provider blocks;
#   foreign-only row dropped; multi-provider keeps native strips foreign;
#   non-provider-executed parts untouched; empty target is a no-op; unknown
#   origin fails closed; unparsable foreign row kept unchanged.
#
# NOTE(review): the "target provider unresolved" branch is also taken when
#   err == nil and target == "", so slog.Error(err) can receive a nil error
#   there -- presumably harmless; confirm how the logger renders it.
# NOTE(review): the "foreign-only row dropped" subtest asserts that dropping an
#   assistant row can leave two user rows adjacent; confirm downstream prompt
#   conversion accepts that shape.
# NOTE(review): tabs and gofmt column alignment below were reconstructed; the
#   added-line counts match the hunk headers (7, 162, 183), but confirm the
#   context-line whitespace of the first hunk against the tree before applying.

diff --git a/coderd/x/chatd/generation_preparer.go b/coderd/x/chatd/generation_preparer.go
index a0aec403ba279..b206486b97669 100644
--- a/coderd/x/chatd/generation_preparer.go
+++ b/coderd/x/chatd/generation_preparer.go
@@ -215,6 +215,13 @@ func (server *Server) prepareGeneration(
 		resolvedUserPrompt string
 	)
 
+	// Drop provider-executed tool history produced by a different provider
+	// before building the prompt. A provider that shares another's wire format
+	// (e.g. Bedrock and Anthropic) can still reject the other's
+	// provider-executed blocks, so a mid-chat provider switch must not replay
+	// them.
+	promptRows = server.sanitizeForeignProviderExecutedToolRows(ctx, logger, promptRows, modelConfig.ID)
+
 	persistedSkills := skillsFromParts(promptRows)
 	hasContextFiles := false
 	if chat.WorkspaceID.Valid {
diff --git a/coderd/x/chatd/provider_switch_sanitize.go b/coderd/x/chatd/provider_switch_sanitize.go
new file mode 100644
index 0000000000000..3a178847ed664
--- /dev/null
+++ b/coderd/x/chatd/provider_switch_sanitize.go
@@ -0,0 +1,162 @@
+package chatd
+
+import (
+	"context"
+
+	"github.com/google/uuid"
+
+	"cdr.dev/slog/v3"
+	"github.com/coder/coder/v2/coderd/database"
+	"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
+	"github.com/coder/coder/v2/codersdk"
+)
+
+// providerSwitchStripStats counts the provider-executed tool history removed
+// when sanitizing a prompt for a model-provider switch.
+type providerSwitchStripStats struct {
+	RemovedToolCalls   int
+	RemovedToolResults int
+	DroppedMessages    int
+}
+
+// stripForeignProviderExecutedToolRows removes provider-executed tool blocks
+// (both calls and results) from assistant history rows whose producing provider
+// differs from targetProvider. Provider-executed tool blocks are only valid for
+// the provider that produced them: a provider sharing another's wire format can
+// still reject them (e.g. Bedrock rejects Anthropic web_search_tool_result with
+// HTTP 400), so switching providers mid-chat must drop the foreign blocks.
+//
+// originProvider resolves a row's ModelConfigID to a normalized provider name;
+// ok is false when the origin cannot be determined, in which case the row is
+// treated as foreign (fail closed). Rows from the target provider, non-assistant
+// rows, rows with no provider-executed parts, and rows that fail to parse or
+// re-marshal are returned unchanged. Rows emptied by stripping are dropped.
+//
+// Provenance is the model config provider (derived from the AI provider type),
+// not anything fantasy reports, so it stays correct when requests route through
+// aibridged, which serializes both Anthropic and Bedrock as the Anthropic wire
+// format.
+func stripForeignProviderExecutedToolRows(
+	rows []database.ChatMessage,
+	targetProvider string,
+	originProvider func(uuid.NullUUID) (string, bool),
+) ([]database.ChatMessage, providerSwitchStripStats) {
+	var stats providerSwitchStripStats
+	if targetProvider == "" || len(rows) == 0 {
+		return rows, stats
+	}
+
+	out := make([]database.ChatMessage, 0, len(rows))
+	for _, row := range rows {
+		// Provider-executed blocks that must be replayed live on assistant rows.
+		// Provider-executed results orphaned onto tool rows are dropped during
+		// prompt conversion, so they never reach the provider.
+		if row.Role != database.ChatMessageRoleAssistant {
+			out = append(out, row)
+			continue
+		}
+		if origin, ok := originProvider(row.ModelConfigID); ok && origin == targetProvider {
+			out = append(out, row)
+			continue
+		}
+
+		parts, err := chatprompt.ParseContent(row)
+		if err != nil {
+			// Leave unparsable rows untouched; the converter handles them.
+			out = append(out, row)
+			continue
+		}
+
+		kept := make([]codersdk.ChatMessagePart, 0, len(parts))
+		var removedCalls, removedResults int
+		for _, part := range parts {
+			switch {
+			case part.Type == codersdk.ChatMessagePartTypeToolCall && part.ProviderExecuted:
+				removedCalls++
+			case part.Type == codersdk.ChatMessagePartTypeToolResult && part.ProviderExecuted:
+				removedResults++
+			default:
+				kept = append(kept, part)
+			}
+		}
+		if removedCalls == 0 && removedResults == 0 {
+			out = append(out, row)
+			continue
+		}
+		if len(kept) == 0 {
+			stats.RemovedToolCalls += removedCalls
+			stats.RemovedToolResults += removedResults
+			stats.DroppedMessages++
+			continue
+		}
+
+		content, err := chatprompt.MarshalParts(kept)
+		if err != nil {
+			// Keep the original row rather than corrupting history.
+			out = append(out, row)
+			continue
+		}
+		row.Content = content
+		row.ContentVersion = chatprompt.ContentVersionV1
+		stats.RemovedToolCalls += removedCalls
+		stats.RemovedToolResults += removedResults
+		out = append(out, row)
+	}
+	return out, stats
+}
+
+// sanitizeForeignProviderExecutedToolRows strips provider-executed tool history
+// produced by a provider other than the one targeted by modelConfigID. It
+// resolves each row's provider via the model config cache and logs a single
+// summary when anything is removed.
+func (server *Server) sanitizeForeignProviderExecutedToolRows(
+	ctx context.Context,
+	logger slog.Logger,
+	rows []database.ChatMessage,
+	modelConfigID uuid.UUID,
+) []database.ChatMessage {
+	_, target, err := server.resolveModelConfigAndNormalizedProvider(ctx, modelConfigID)
+	if err != nil || target == "" {
+		// Without a known target provider we cannot classify history; leave it.
+		logger.Debug(ctx, "skipping provider-switch sanitization: target provider unresolved",
+			slog.F("model_config_id", modelConfigID),
+			slog.Error(err),
+		)
+		return rows
+	}
+
+	cache := make(map[uuid.UUID]string)
+	originProvider := func(id uuid.NullUUID) (string, bool) {
+		if !id.Valid {
+			return "", false
+		}
+		if provider, seen := cache[id.UUID]; seen {
+			return provider, provider != ""
+		}
+		_, provider, rErr := server.resolveModelConfigAndNormalizedProvider(ctx, id.UUID)
+		if rErr != nil {
+			// Unresolvable origin (e.g. a since-disabled or deleted config) is
+			// treated as foreign so we fail closed rather than replay blocks the
+			// target may reject.
+			logger.Debug(ctx, "provider-switch sanitization: origin provider unresolved, treating as foreign",
+				slog.F("model_config_id", id.UUID),
+				slog.Error(rErr),
+			)
+			provider = ""
+		}
+		cache[id.UUID] = provider
+		return provider, provider != ""
+	}
+
+	sanitized, stats := stripForeignProviderExecutedToolRows(rows, target, originProvider)
+	if stats != (providerSwitchStripStats{}) {
+		logger.Warn(ctx, "stripped foreign provider-executed tool history",
+			slog.F("phase", "provider_switch"),
+			slog.F("target_provider", target),
+			slog.F("removed_tool_calls", stats.RemovedToolCalls),
+			slog.F("removed_tool_results", stats.RemovedToolResults),
+			slog.F("dropped_messages", stats.DroppedMessages),
+		)
+	}
+	return sanitized
+}
diff --git a/coderd/x/chatd/provider_switch_sanitize_internal_test.go b/coderd/x/chatd/provider_switch_sanitize_internal_test.go
new file mode 100644
index 0000000000000..7302a2c43414e
--- /dev/null
+++ b/coderd/x/chatd/provider_switch_sanitize_internal_test.go
@@ -0,0 +1,183 @@
+package chatd
+
+import (
+	"encoding/json"
+	"testing"
+
+	"github.com/google/uuid"
+	"github.com/sqlc-dev/pqtype"
+	"github.com/stretchr/testify/require"
+
+	"github.com/coder/coder/v2/coderd/database"
+	"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
+	"github.com/coder/coder/v2/codersdk"
+)
+
+func TestStripForeignProviderExecutedToolRows(t *testing.T) {
+	t.Parallel()
+
+	const (
+		anthropic = "anthropic"
+		bedrock   = "bedrock"
+		openai    = "openai"
+	)
+
+	anthropicCfg := uuid.New()
+	openAICfg := uuid.New()
+	unknownCfg := uuid.New()
+
+	peCall := func(id string) codersdk.ChatMessagePart {
+		p := codersdk.ChatMessageToolCall(id, "web_search", json.RawMessage(`{"query":"x"}`))
+		p.ProviderExecuted = true
+		return p
+	}
+	peResult := func(id string) codersdk.ChatMessagePart {
+		p := codersdk.ChatMessageToolResult(id, "web_search", json.RawMessage(`{"ok":true}`), false, false)
+		p.ProviderExecuted = true
+		return p
+	}
+	localCall := func(id string) codersdk.ChatMessagePart {
+		return codersdk.ChatMessageToolCall(id, "read_file", json.RawMessage(`{}`))
+	}
+	text := codersdk.ChatMessageText
+
+	assistantRow := func(t *testing.T, cfg uuid.UUID, parts ...codersdk.ChatMessagePart) database.ChatMessage {
+		t.Helper()
+		content, err := chatprompt.MarshalParts(parts)
+		require.NoError(t, err)
+		return database.ChatMessage{
+			Role:           database.ChatMessageRoleAssistant,
+			ModelConfigID:  uuid.NullUUID{UUID: cfg, Valid: cfg != uuid.Nil},
+			Content:        content,
+			ContentVersion: chatprompt.ContentVersionV1,
+		}
+	}
+	userRow := func(t *testing.T, s string) database.ChatMessage {
+		t.Helper()
+		content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{text(s)})
+		require.NoError(t, err)
+		return database.ChatMessage{
+			Role:           database.ChatMessageRoleUser,
+			Content:        content,
+			ContentVersion: chatprompt.ContentVersionV1,
+		}
+	}
+
+	// origin maps a model config ID to its normalized provider. unknownCfg is
+	// intentionally absent so the resolver reports an unknown origin.
+	origin := func(providerByConfig map[uuid.UUID]string) func(uuid.NullUUID) (string, bool) {
+		return func(id uuid.NullUUID) (string, bool) {
+			if !id.Valid {
+				return "", false
+			}
+			provider, ok := providerByConfig[id.UUID]
+			return provider, ok
+		}
+	}
+	resolver := origin(map[uuid.UUID]string{
+		anthropicCfg: anthropic,
+		openAICfg:    openai,
+	})
+
+	// partsOf parses a row's content back into SDK parts for comparison.
+	partsOf := func(t *testing.T, row database.ChatMessage) []codersdk.ChatMessagePart {
+		t.Helper()
+		parts, err := chatprompt.ParseContent(row)
+		require.NoError(t, err)
+		return parts
+	}
+
+	t.Run("same provider kept", func(t *testing.T) {
+		t.Parallel()
+		rows := []database.ChatMessage{
+			userRow(t, "hi"),
+			assistantRow(t, anthropicCfg, peCall("ws"), peResult("ws"), text("done")),
+		}
+		got, stats := stripForeignProviderExecutedToolRows(rows, anthropic, resolver)
+		require.Equal(t, rows, got)
+		require.Zero(t, stats)
+	})
+
+	t.Run("anthropic to bedrock drops provider blocks", func(t *testing.T) {
+		t.Parallel()
+		rows := []database.ChatMessage{
+			userRow(t, "hi"),
+			assistantRow(t, anthropicCfg, peCall("ws"), peResult("ws"), text("done")),
+		}
+		got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
+		require.Len(t, got, 2)
+		require.Equal(t, []codersdk.ChatMessagePart{text("done")}, partsOf(t, got[1]))
+		require.Equal(t, providerSwitchStripStats{RemovedToolCalls: 1, RemovedToolResults: 1}, stats)
+	})
+
+	t.Run("foreign-only row dropped", func(t *testing.T) {
+		t.Parallel()
+		rows := []database.ChatMessage{
+			userRow(t, "hi"),
+			assistantRow(t, anthropicCfg, peCall("ws")),
+			userRow(t, "again"),
+		}
+		got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
+		require.Len(t, got, 2)
+		require.Equal(t, database.ChatMessageRoleUser, got[0].Role)
+		require.Equal(t, database.ChatMessageRoleUser, got[1].Role)
+		require.Equal(t, providerSwitchStripStats{RemovedToolCalls: 1, DroppedMessages: 1}, stats)
+	})
+
+	t.Run("multi-provider keeps native strips foreign", func(t *testing.T) {
+		t.Parallel()
+		rows := []database.ChatMessage{
+			assistantRow(t, openAICfg, peCall("os"), peResult("os"), text("openai")),
+			assistantRow(t, anthropicCfg, peCall("as"), peResult("as"), text("anthropic")),
+		}
+		got, stats := stripForeignProviderExecutedToolRows(rows, anthropic, resolver)
+		require.Len(t, got, 2)
+		require.Equal(t, []codersdk.ChatMessagePart{text("openai")}, partsOf(t, got[0]))
+		require.Equal(t, rows[1], got[1])
+		require.Equal(t, providerSwitchStripStats{RemovedToolCalls: 1, RemovedToolResults: 1}, stats)
+	})
+
+	t.Run("non-provider-executed parts untouched", func(t *testing.T) {
+		t.Parallel()
+		rows := []database.ChatMessage{
+			assistantRow(t, anthropicCfg, text("hello"), localCall("local")),
+		}
+		got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
+		require.Equal(t, rows, got)
+		require.Zero(t, stats)
+	})
+
+	t.Run("empty target is a no-op", func(t *testing.T) {
+		t.Parallel()
+		rows := []database.ChatMessage{
+			assistantRow(t, anthropicCfg, peCall("ws"), peResult("ws")),
+		}
+		got, stats := stripForeignProviderExecutedToolRows(rows, "", resolver)
+		require.Equal(t, rows, got)
+		require.Zero(t, stats)
+	})
+
+	t.Run("unknown origin fails closed", func(t *testing.T) {
+		t.Parallel()
+		rows := []database.ChatMessage{
+			assistantRow(t, unknownCfg, peResult("ws"), text("done")),
+		}
+		got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
+		require.Len(t, got, 1)
+		require.Equal(t, []codersdk.ChatMessagePart{text("done")}, partsOf(t, got[0]))
+		require.Equal(t, providerSwitchStripStats{RemovedToolResults: 1}, stats)
+	})
+
+	t.Run("unparsable foreign row kept unchanged", func(t *testing.T) {
+		t.Parallel()
+		rows := []database.ChatMessage{{
+			Role:           database.ChatMessageRoleAssistant,
+			ModelConfigID:  uuid.NullUUID{UUID: anthropicCfg, Valid: true},
+			Content:        pqtype.NullRawMessage{RawMessage: []byte("{not json"), Valid: true},
+			ContentVersion: chatprompt.ContentVersionV1,
+		}}
+		got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver)
+		require.Equal(t, rows, got)
+		require.Zero(t, stats)
+	})
+}