-
Notifications
You must be signed in to change notification settings - Fork 1.3k
fix(coderd/x/chatd): drop foreign provider-executed tools on model switch #26555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
johnstcn
wants to merge
6
commits into
main
Choose a base branch
from
cian/codagt-471-sanitize-provider-switch-tool-history
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+352
−0
Draft
Changes from all commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
ea1d6c9
fix(coderd/x/chatd): drop foreign provider-executed tool history on p…
johnstcn 1f581a2
fixup! fix(coderd/x/chatd): drop foreign provider-executed tool histo…
johnstcn 6bca249
fixup! fix(coderd/x/chatd): drop foreign provider-executed tool histo…
johnstcn 80a6a93
fixup! fix(coderd/x/chatd): drop foreign provider-executed tool histo…
johnstcn e339f11
fixup! fix(coderd/x/chatd): drop foreign provider-executed tool histo…
johnstcn c46d4ad
fixup! fix(coderd/x/chatd): drop foreign provider-executed tool histo…
johnstcn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| package chatd | ||
|
|
||
| import ( | ||
| "context" | ||
|
|
||
| "github.com/google/uuid" | ||
|
|
||
| "cdr.dev/slog/v3" | ||
| "github.com/coder/coder/v2/coderd/database" | ||
| "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" | ||
| "github.com/coder/coder/v2/codersdk" | ||
| ) | ||
|
|
||
// providerSwitchStripStats counts the provider-executed tool history removed
// when sanitizing a prompt for a model-provider switch. The zero value means
// nothing was removed.
type providerSwitchStripStats struct {
	// RemovedToolCalls is the number of provider-executed tool-call parts
	// removed, including those on rows that were dropped entirely.
	RemovedToolCalls int
	// RemovedToolResults is the number of provider-executed tool-result parts
	// removed, including those on rows that were dropped entirely.
	RemovedToolResults int
	// DroppedMessages is the number of assistant rows omitted because
	// stripping left them with no parts.
	DroppedMessages int
}
|
|
||
| // stripForeignProviderExecutedToolRows removes provider-executed tool blocks | ||
| // (both calls and results) from assistant history rows whose producing provider | ||
| // differs from targetProvider. Provider-executed tool blocks are only valid for | ||
| // the provider that produced them: a provider sharing another's wire format can | ||
| // still reject them (e.g. Bedrock rejects Anthropic web_search_tool_result with | ||
| // HTTP 400), so switching providers mid-chat must drop the foreign blocks. | ||
| // | ||
| // originProvider resolves a row's ModelConfigID to a normalized provider name; | ||
| // ok is false when the origin cannot be determined, in which case the row is | ||
| // treated as foreign (fail closed). Rows from the target provider, non-assistant | ||
| // rows, rows with no provider-executed parts, and rows that fail to parse or | ||
| // re-marshal are returned unchanged. Rows emptied by stripping are dropped. | ||
| // | ||
| // Provenance is the model config provider (derived from the AI provider type), | ||
| // not anything fantasy reports, so it stays correct when requests route through | ||
| // aibridged, which serializes both Anthropic and Bedrock as the Anthropic wire | ||
| // format. | ||
| func stripForeignProviderExecutedToolRows( | ||
| rows []database.ChatMessage, | ||
| targetProvider string, | ||
| originProvider func(uuid.NullUUID) (string, bool), | ||
| ) ([]database.ChatMessage, providerSwitchStripStats) { | ||
| var stats providerSwitchStripStats | ||
| if targetProvider == "" || len(rows) == 0 { | ||
| return rows, stats | ||
| } | ||
|
|
||
| out := make([]database.ChatMessage, 0, len(rows)) | ||
| for _, row := range rows { | ||
| // Provider-executed blocks that must be replayed live on assistant rows. | ||
| // Provider-executed results orphaned onto tool rows are dropped during | ||
| // prompt conversion, so they never reach the provider. | ||
| if row.Role != database.ChatMessageRoleAssistant { | ||
| out = append(out, row) | ||
| continue | ||
| } | ||
| if origin, ok := originProvider(row.ModelConfigID); ok && origin == targetProvider { | ||
| out = append(out, row) | ||
| continue | ||
| } | ||
|
|
||
| parts, err := chatprompt.ParseContent(row) | ||
| if err != nil { | ||
| // Leave unparsable rows untouched; the converter handles them. | ||
| out = append(out, row) | ||
| continue | ||
| } | ||
|
|
||
| kept := make([]codersdk.ChatMessagePart, 0, len(parts)) | ||
| var removedCalls, removedResults int | ||
| for _, part := range parts { | ||
| switch { | ||
| case part.Type == codersdk.ChatMessagePartTypeToolCall && part.ProviderExecuted: | ||
| removedCalls++ | ||
| case part.Type == codersdk.ChatMessagePartTypeToolResult && part.ProviderExecuted: | ||
| removedResults++ | ||
| default: | ||
| kept = append(kept, part) | ||
| } | ||
| } | ||
| if removedCalls == 0 && removedResults == 0 { | ||
| out = append(out, row) | ||
| continue | ||
| } | ||
| if len(kept) == 0 { | ||
| stats.RemovedToolCalls += removedCalls | ||
| stats.RemovedToolResults += removedResults | ||
| stats.DroppedMessages++ | ||
| continue | ||
| } | ||
|
|
||
| content, err := chatprompt.MarshalParts(kept) | ||
| if err != nil { | ||
| // Keep the original row rather than corrupting history. | ||
| out = append(out, row) | ||
| continue | ||
| } | ||
| row.Content = content | ||
| row.ContentVersion = chatprompt.ContentVersionV1 | ||
| stats.RemovedToolCalls += removedCalls | ||
| stats.RemovedToolResults += removedResults | ||
| out = append(out, row) | ||
| } | ||
| return out, stats | ||
| } | ||
|
|
||
| // sanitizeForeignProviderExecutedToolRows strips provider-executed tool history | ||
| // produced by a provider other than the one targeted by modelConfigID. It | ||
| // resolves each row's provider via the model config cache and logs a single | ||
| // summary when anything is removed. | ||
| func (server *Server) sanitizeForeignProviderExecutedToolRows( | ||
| ctx context.Context, | ||
| logger slog.Logger, | ||
| rows []database.ChatMessage, | ||
| modelConfigID uuid.UUID, | ||
| ) []database.ChatMessage { | ||
| _, target, err := server.resolveModelConfigAndNormalizedProvider(ctx, modelConfigID) | ||
| if err != nil || target == "" { | ||
|
johnstcn marked this conversation as resolved.
|
||
| // Without a known target provider we cannot classify history; leave it. | ||
| logger.Debug(ctx, "skipping provider-switch sanitization: target provider unresolved", | ||
| slog.F("model_config_id", modelConfigID), | ||
| slog.Error(err), | ||
| ) | ||
| return rows | ||
| } | ||
|
|
||
| cache := make(map[uuid.UUID]string) | ||
| originProvider := func(id uuid.NullUUID) (string, bool) { | ||
| if !id.Valid { | ||
| return "", false | ||
| } | ||
| if provider, seen := cache[id.UUID]; seen { | ||
| return provider, provider != "" | ||
| } | ||
| _, provider, rErr := server.resolveModelConfigAndNormalizedProvider(ctx, id.UUID) | ||
|
johnstcn marked this conversation as resolved.
|
||
| if rErr != nil { | ||
|
johnstcn marked this conversation as resolved.
|
||
| // Unresolvable origin (e.g. a since-disabled or deleted config) is | ||
| // treated as foreign so we fail closed rather than replay blocks the | ||
| // target may reject. | ||
| logger.Debug(ctx, "provider-switch sanitization: origin provider unresolved, treating as foreign", | ||
| slog.F("model_config_id", id.UUID), | ||
| slog.Error(rErr), | ||
| ) | ||
| provider = "" | ||
| } | ||
| cache[id.UUID] = provider | ||
|
johnstcn marked this conversation as resolved.
|
||
| return provider, provider != "" | ||
| } | ||
|
|
||
| sanitized, stats := stripForeignProviderExecutedToolRows(rows, target, originProvider) | ||
| if stats != (providerSwitchStripStats{}) { | ||
| logger.Warn(ctx, "stripped foreign provider-executed tool history", | ||
| slog.F("phase", "provider_switch"), | ||
| slog.F("target_provider", target), | ||
| slog.F("removed_tool_calls", stats.RemovedToolCalls), | ||
| slog.F("removed_tool_results", stats.RemovedToolResults), | ||
| slog.F("dropped_messages", stats.DroppedMessages), | ||
| ) | ||
| } | ||
| return sanitized | ||
| } | ||
183 changes: 183 additions & 0 deletions
183
coderd/x/chatd/provider_switch_sanitize_internal_test.go
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,183 @@ | ||
| package chatd | ||
|
|
||
| import ( | ||
| "encoding/json" | ||
| "testing" | ||
|
|
||
| "github.com/google/uuid" | ||
| "github.com/sqlc-dev/pqtype" | ||
| "github.com/stretchr/testify/require" | ||
|
|
||
| "github.com/coder/coder/v2/coderd/database" | ||
| "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" | ||
| "github.com/coder/coder/v2/codersdk" | ||
| ) | ||
|
|
||
| func TestStripForeignProviderExecutedToolRows(t *testing.T) { | ||
| t.Parallel() | ||
|
|
||
| const ( | ||
| anthropic = "anthropic" | ||
| bedrock = "bedrock" | ||
| openai = "openai" | ||
| ) | ||
|
|
||
| anthropicCfg := uuid.New() | ||
| openAICfg := uuid.New() | ||
| unknownCfg := uuid.New() | ||
|
|
||
| peCall := func(id string) codersdk.ChatMessagePart { | ||
| p := codersdk.ChatMessageToolCall(id, "web_search", json.RawMessage(`{"query":"x"}`)) | ||
| p.ProviderExecuted = true | ||
| return p | ||
| } | ||
| peResult := func(id string) codersdk.ChatMessagePart { | ||
| p := codersdk.ChatMessageToolResult(id, "web_search", json.RawMessage(`{"ok":true}`), false, false) | ||
| p.ProviderExecuted = true | ||
| return p | ||
| } | ||
| localCall := func(id string) codersdk.ChatMessagePart { | ||
| return codersdk.ChatMessageToolCall(id, "read_file", json.RawMessage(`{}`)) | ||
| } | ||
| text := codersdk.ChatMessageText | ||
|
|
||
| assistantRow := func(t *testing.T, cfg uuid.UUID, parts ...codersdk.ChatMessagePart) database.ChatMessage { | ||
| t.Helper() | ||
| content, err := chatprompt.MarshalParts(parts) | ||
| require.NoError(t, err) | ||
| return database.ChatMessage{ | ||
| Role: database.ChatMessageRoleAssistant, | ||
| ModelConfigID: uuid.NullUUID{UUID: cfg, Valid: cfg != uuid.Nil}, | ||
| Content: content, | ||
| ContentVersion: chatprompt.ContentVersionV1, | ||
| } | ||
| } | ||
| userRow := func(t *testing.T, s string) database.ChatMessage { | ||
| t.Helper() | ||
| content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{text(s)}) | ||
| require.NoError(t, err) | ||
| return database.ChatMessage{ | ||
| Role: database.ChatMessageRoleUser, | ||
| Content: content, | ||
| ContentVersion: chatprompt.ContentVersionV1, | ||
| } | ||
| } | ||
|
|
||
| // origin maps a model config ID to its normalized provider. unknownCfg is | ||
| // intentionally absent so the resolver reports an unknown origin. | ||
| origin := func(providerByConfig map[uuid.UUID]string) func(uuid.NullUUID) (string, bool) { | ||
| return func(id uuid.NullUUID) (string, bool) { | ||
| if !id.Valid { | ||
| return "", false | ||
| } | ||
| provider, ok := providerByConfig[id.UUID] | ||
| return provider, ok | ||
| } | ||
| } | ||
| resolver := origin(map[uuid.UUID]string{ | ||
| anthropicCfg: anthropic, | ||
| openAICfg: openai, | ||
|
johnstcn marked this conversation as resolved.
|
||
| }) | ||
|
|
||
| // partsOf parses a row's content back into SDK parts for comparison. | ||
| partsOf := func(t *testing.T, row database.ChatMessage) []codersdk.ChatMessagePart { | ||
| t.Helper() | ||
| parts, err := chatprompt.ParseContent(row) | ||
| require.NoError(t, err) | ||
| return parts | ||
| } | ||
|
|
||
| t.Run("same provider kept", func(t *testing.T) { | ||
| t.Parallel() | ||
| rows := []database.ChatMessage{ | ||
| userRow(t, "hi"), | ||
| assistantRow(t, anthropicCfg, peCall("ws"), peResult("ws"), text("done")), | ||
| } | ||
| got, stats := stripForeignProviderExecutedToolRows(rows, anthropic, resolver) | ||
| require.Equal(t, rows, got) | ||
| require.Zero(t, stats) | ||
| }) | ||
|
|
||
| t.Run("anthropic to bedrock drops provider blocks", func(t *testing.T) { | ||
| t.Parallel() | ||
| rows := []database.ChatMessage{ | ||
| userRow(t, "hi"), | ||
| assistantRow(t, anthropicCfg, peCall("ws"), peResult("ws"), text("done")), | ||
| } | ||
| got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver) | ||
| require.Len(t, got, 2) | ||
| require.Equal(t, []codersdk.ChatMessagePart{text("done")}, partsOf(t, got[1])) | ||
| require.Equal(t, providerSwitchStripStats{RemovedToolCalls: 1, RemovedToolResults: 1}, stats) | ||
| }) | ||
|
|
||
| t.Run("foreign-only row dropped", func(t *testing.T) { | ||
| t.Parallel() | ||
| rows := []database.ChatMessage{ | ||
| userRow(t, "hi"), | ||
| assistantRow(t, anthropicCfg, peCall("ws")), | ||
| userRow(t, "again"), | ||
| } | ||
| got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver) | ||
| require.Len(t, got, 2) | ||
| require.Equal(t, database.ChatMessageRoleUser, got[0].Role) | ||
| require.Equal(t, database.ChatMessageRoleUser, got[1].Role) | ||
| require.Equal(t, providerSwitchStripStats{RemovedToolCalls: 1, DroppedMessages: 1}, stats) | ||
| }) | ||
|
|
||
| t.Run("multi-provider keeps native strips foreign", func(t *testing.T) { | ||
| t.Parallel() | ||
| rows := []database.ChatMessage{ | ||
| assistantRow(t, openAICfg, peCall("os"), peResult("os"), text("openai")), | ||
| assistantRow(t, anthropicCfg, peCall("as"), peResult("as"), text("anthropic")), | ||
| } | ||
| got, stats := stripForeignProviderExecutedToolRows(rows, anthropic, resolver) | ||
| require.Len(t, got, 2) | ||
| require.Equal(t, []codersdk.ChatMessagePart{text("openai")}, partsOf(t, got[0])) | ||
| require.Equal(t, rows[1], got[1]) | ||
| require.Equal(t, providerSwitchStripStats{RemovedToolCalls: 1, RemovedToolResults: 1}, stats) | ||
| }) | ||
|
|
||
| t.Run("non-provider-executed parts untouched", func(t *testing.T) { | ||
| t.Parallel() | ||
| rows := []database.ChatMessage{ | ||
| assistantRow(t, anthropicCfg, text("hello"), localCall("local")), | ||
| } | ||
| got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver) | ||
| require.Equal(t, rows, got) | ||
| require.Zero(t, stats) | ||
| }) | ||
|
|
||
| t.Run("empty target is a no-op", func(t *testing.T) { | ||
| t.Parallel() | ||
| rows := []database.ChatMessage{ | ||
| assistantRow(t, anthropicCfg, peCall("ws"), peResult("ws")), | ||
| } | ||
| got, stats := stripForeignProviderExecutedToolRows(rows, "", resolver) | ||
| require.Equal(t, rows, got) | ||
| require.Zero(t, stats) | ||
| }) | ||
|
|
||
| t.Run("unknown origin fails closed", func(t *testing.T) { | ||
| t.Parallel() | ||
| rows := []database.ChatMessage{ | ||
| assistantRow(t, unknownCfg, peResult("ws"), text("done")), | ||
| } | ||
| got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver) | ||
| require.Len(t, got, 1) | ||
| require.Equal(t, []codersdk.ChatMessagePart{text("done")}, partsOf(t, got[0])) | ||
| require.Equal(t, providerSwitchStripStats{RemovedToolResults: 1}, stats) | ||
| }) | ||
|
|
||
| t.Run("unparsable foreign row kept unchanged", func(t *testing.T) { | ||
| t.Parallel() | ||
| rows := []database.ChatMessage{{ | ||
| Role: database.ChatMessageRoleAssistant, | ||
| ModelConfigID: uuid.NullUUID{UUID: anthropicCfg, Valid: true}, | ||
| Content: pqtype.NullRawMessage{RawMessage: []byte("{not json"), Valid: true}, | ||
| ContentVersion: chatprompt.ContentVersionV1, | ||
| }} | ||
| got, stats := stripForeignProviderExecutedToolRows(rows, bedrock, resolver) | ||
| require.Equal(t, rows, got) | ||
| require.Zero(t, stats) | ||
| }) | ||
| } | ||
|
johnstcn marked this conversation as resolved.
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.