diff --git a/cli/aibridged_internal_test.go b/cli/aibridged_internal_test.go index 6b3e1eb7ac731..bb47a4e58eccd 100644 --- a/cli/aibridged_internal_test.go +++ b/cli/aibridged_internal_test.go @@ -154,13 +154,13 @@ func TestBuildProviders(t *testing.T) { require.NoError(t, err) names := providerNames(providers) - assert.Equal(t, []string{aibridge.ProviderAnthropic}, names) + assert.ElementsMatch(t, []string{aibridge.ProviderAnthropic, "bedrock"}, names) }) t.Run("LegacyBedrockWithoutAnthropicKey", func(t *testing.T) { t.Parallel() - // Bedrock credentials alone should be enough to create an - // Anthropic provider — no CODER_AIBRIDGE_ANTHROPIC_KEY needed. + // Bedrock credentials alone should be enough to create a + // Bedrock provider without CODER_AIBRIDGE_ANTHROPIC_KEY. cfg := codersdk.AIBridgeConfig{} cfg.LegacyBedrock.Region = serpent.String("us-west-2") cfg.LegacyBedrock.AccessKey = serpent.String("AKID") @@ -172,7 +172,7 @@ func TestBuildProviders(t *testing.T) { p := providers[0] assert.Equal(t, aibridge.ProviderAnthropic, p.Type()) - assert.Equal(t, aibridge.ProviderAnthropic, p.Name()) + assert.Equal(t, "bedrock", p.Name()) }) t.Run("UnknownType", func(t *testing.T) { diff --git a/coderd/ai_provider_canonical.go b/coderd/ai_provider_canonical.go new file mode 100644 index 0000000000000..e0a43992c365d --- /dev/null +++ b/coderd/ai_provider_canonical.go @@ -0,0 +1,15 @@ +package coderd + +import ( + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/codersdk" +) + +func canonicalDatabaseAIProviderType(providerType database.AIProviderType, settings codersdk.AIProviderSettings) database.AIProviderType { + return database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(providerType), settings)) +} + +func canonicalAIProviderTypeForRow(provider database.AIProvider) (database.AIProviderType, error) { + return db2sdk.CanonicalAIProviderType(provider) +} diff --git a/coderd/ai_providers.go b/coderd/ai_providers.go index 0637822592c68..5e1565f4e9155 100644 --- a/coderd/ai_providers.go +++ b/coderd/ai_providers.go @@ -159,6 +159,7 @@ func (api *API) aiProvidersCreate(rw http.ResponseWriter, r *http.Request) { if !httpapi.Read(ctx, rw, r, &req) { return } + req.Type = codersdk.CanonicalAIProviderType(req.Type, req.Settings) if validations := req.Validate(); len(validations) > 0 { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ @@ -168,16 +169,6 @@ func (api *API) aiProvidersCreate(rw http.ResponseWriter, r *http.Request) { return } - // Bedrock providers authenticate via the settings blob, not via a - // bearer key, so registering an api_keys list against them would - // be silently unused. - if req.Settings.Bedrock != nil && len(req.APIKeys) > 0 { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Bedrock providers do not accept api_keys; configure access credentials via settings.", - }) - return - } - settings, err := encodeAIProviderSettings(req.Settings) if err != nil { api.Logger.Error(ctx, "encode AI provider settings", slog.Error(err)) @@ -320,15 +311,17 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) { if req.Settings != nil { existing = mergeAIProviderSettings(existing, *req.Settings) } - // Bedrock settings are only meaningful for anthropic- or - // bedrock-typed providers; rejecting the mismatch keeps a - // misconfiguration from sitting silently in the encrypted - // blob. - if existing.Bedrock != nil && - old.Type != database.AiProviderTypeAnthropic && - old.Type != database.AiProviderTypeBedrock { + targetType := canonicalDatabaseAIProviderType(old.Type, existing) + targetBaseURL := ptr.NilToDefault(req.BaseURL, old.BaseUrl) + // Bedrock settings are only meaningful for Bedrock providers; + // rejecting the mismatch keeps a misconfiguration from sitting + // silently in the encrypted blob. + if existing.Bedrock != nil && targetType != database.AiProviderTypeBedrock { return errAIProviderBedrockTypeMismatch } + if targetType == database.AiProviderTypeBedrock && !codersdk.IsBedrockProviderConfigured(targetBaseURL, existing.Bedrock) { + return errAIProviderBedrockSettingsRequired + } settings, err := encodeAIProviderSettings(existing) if err != nil { return xerrors.Errorf("encode settings: %w", err) @@ -336,11 +329,11 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) { // Reject keys against Bedrock providers (whether the existing // row is Bedrock or the patch would make it so). - if req.APIKeys != nil && existing.Bedrock != nil && len(*req.APIKeys) > 0 { + if req.APIKeys != nil && targetType == database.AiProviderTypeBedrock && len(*req.APIKeys) > 0 { return errBedrockRejectsAPIKeys } - if req.APIKeys != nil && old.Type == database.AiProviderTypeCopilot && len(*req.APIKeys) > 0 { + if req.APIKeys != nil && targetType == database.AiProviderTypeCopilot && len(*req.APIKeys) > 0 { return errCopilotRejectsAPIKeys } @@ -351,9 +344,10 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) { } params := database.UpdateAIProviderParams{ ID: old.ID, + Type: targetType, DisplayName: displayName, Enabled: ptr.NilToDefault(req.Enabled, old.Enabled), - BaseUrl: ptr.NilToDefault(req.BaseURL, old.BaseUrl), + BaseUrl: targetBaseURL, Settings: settings, // SettingsKeyID is set by the dbcrypt wrapper. SettingsKeyID: sql.NullString{}, @@ -393,9 +387,15 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) { }) return } + if errors.Is(err, errAIProviderBedrockSettingsRequired) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "type=bedrock requires bedrock settings or base_url.", + }) + return + } if errors.Is(err, errAIProviderBedrockTypeMismatch) { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Bedrock settings are only valid for type=anthropic or type=bedrock.", + Message: "Bedrock settings are only valid for type=bedrock.", }) return } @@ -501,9 +501,13 @@ var errCopilotRejectsAPIKeys = xerrors.New("copilot providers do not accept api_ // errAIProviderBedrockTypeMismatch is the sentinel returned from // inside the update transaction when the post-merge settings carry a -// Bedrock block but the provider is not anthropic- or bedrock-typed; -// the outer handler translates it into a 400. -var errAIProviderBedrockTypeMismatch = xerrors.New("bedrock settings are only valid for type=anthropic or type=bedrock") +// Bedrock block but the provider is not Bedrock-typed; the outer handler +// translates it into a 400. +var errAIProviderBedrockTypeMismatch = xerrors.New("bedrock settings are only valid for type=bedrock") + +// errAIProviderBedrockSettingsRequired is returned when a Bedrock provider +// would be stored without enough Bedrock connection settings. +var errAIProviderBedrockSettingsRequired = xerrors.New("type=bedrock requires bedrock settings or base_url") // errAIProviderInvalidName is returned from lookupAIProvider when the // idOrName parameter is neither a UUID nor a syntactically-valid name. diff --git a/coderd/ai_providers_migrate.go b/coderd/ai_providers_migrate.go index 055877ecce9d5..93a383e7b8ed5 100644 --- a/coderd/ai_providers_migrate.go +++ b/coderd/ai_providers_migrate.go @@ -103,6 +103,31 @@ func SeedAIProvidersFromEnv( } existing, found := byName[dp.Name] + if found && !existing.Deleted && dp.Type == database.AiProviderTypeAnthropic { + existingSettings, err := db2sdk.AIProviderSettings(existing.Settings) + if err != nil { + return xerrors.Errorf("decode existing settings for %q: %w", existing.Name, err) + } + if canonicalDatabaseAIProviderType(existing.Type, existingSettings) == database.AiProviderTypeBedrock { + logger.Warn(sysCtx, "skipping legacy Anthropic env seed because an existing Anthropic-named row contains Bedrock settings", + slog.F("name", dp.Name), + ) + continue + } + } + if !found && dp.Type == database.AiProviderTypeBedrock { + candidate, ok := byName[aibridge.ProviderAnthropic] + if ok { + candidateSettings, err := db2sdk.AIProviderSettings(candidate.Settings) + if err != nil { + return xerrors.Errorf("decode existing settings for %q: %w", candidate.Name, err) + } + if canonicalDatabaseAIProviderType(candidate.Type, candidateSettings) == database.AiProviderTypeBedrock { + existing = candidate + found = true + } + } + } switch { case found && existing.Deleted: // The provider was created here, then explicitly @@ -127,7 +152,7 @@ func SeedAIProvidersFromEnv( existingKeys = append(existingKeys, k.APIKey) } existingDP := desiredAIProvider{ - Type: existing.Type, + Type: canonicalDatabaseAIProviderType(existing.Type, existingSettings), BaseURL: existing.BaseUrl, Bedrock: existingSettings.Bedrock, Keys: existingKeys, @@ -136,6 +161,9 @@ func SeedAIProvidersFromEnv( if existingHash == dp.Hash { continue } + if existing.Name != dp.Name { + return xerrors.Errorf("AI provider %q matches existing legacy row %q and differs from the current environment configuration; update the provider through the API or remove the CODER_AIBRIDGE_* env vars to stop seeding it", dp.Name, existing.Name) + } return xerrors.Errorf("AI provider %q already exists in the database and differs from the current environment configuration; update the provider through the API or remove the CODER_AIBRIDGE_* env vars to stop seeding it", dp.Name) } @@ -310,11 +338,9 @@ func providersFromEnv(ctx context.Context, cfg codersdk.AIBridgeConfig, logger s addLegacy(aibridge.ProviderOpenAI, dp) } - // Legacy Anthropic + Bedrock. Anthropic is enabled if either an - // Anthropic key OR any Bedrock setting is explicitly configured. - // Detection goes through AIProviderBedrockSettings.IsConfigured() - // so the legacy and indexed paths agree on what counts as a - // Bedrock provider. + // Legacy Anthropic and Bedrock env vars seed independent providers. + // Detection goes through IsBedrockConfigured so the legacy and + // indexed paths agree on what counts as a Bedrock provider. bedrock := codersdk.NewAIProviderBedrockSettings( cfg.LegacyBedrock.Region.String(), cfg.LegacyBedrock.AccessKey.String(), @@ -323,29 +349,27 @@ func providersFromEnv(ctx context.Context, cfg codersdk.AIBridgeConfig, logger s cfg.LegacyBedrock.SmallFastModel.String(), ) hasAnthropicKey := cfg.LegacyAnthropic.Key.String() != "" - hasLegacyBedrock := codersdk.IsBedrockConfigured(cfg.LegacyBedrock.BaseURL.String(), bedrock) - if hasAnthropicKey || hasLegacyBedrock { + if hasAnthropicKey { dp := desiredAIProvider{ - Name: aibridge.ProviderAnthropic, - Type: database.AiProviderTypeAnthropic, - } - if hasLegacyBedrock { - if hasAnthropicKey { - logger.Warn(ctx, "ignoring legacy Anthropic API key because Bedrock credentials are configured; Bedrock authenticates via access keys or credential chain", - slog.F("provider", aibridge.ProviderAnthropic), - ) - } - // Bedrock-only deployments use CODER_AIBRIDGE_BEDROCK_BASE_URL - // for custom VPC, FIPS, or proxy endpoints. - dp.BaseURL = cfg.LegacyBedrock.BaseURL.String() - dp.Bedrock = &bedrock - } else { - dp.BaseURL = cfg.LegacyAnthropic.BaseURL.String() - dp.Keys = []string{cfg.LegacyAnthropic.Key.String()} + Name: aibridge.ProviderAnthropic, + Type: database.AiProviderTypeAnthropic, + BaseURL: cfg.LegacyAnthropic.BaseURL.String(), + Keys: []string{cfg.LegacyAnthropic.Key.String()}, } dp.Hash = computeProviderHash(dp.canonical()) addLegacy(aibridge.ProviderAnthropic, dp) } + hasLegacyBedrock := codersdk.IsBedrockConfigured(cfg.LegacyBedrock.BaseURL.String(), bedrock) + if hasLegacyBedrock { + dp := desiredAIProvider{ + Name: "bedrock", + Type: database.AiProviderTypeBedrock, + BaseURL: cfg.LegacyBedrock.BaseURL.String(), + Bedrock: &bedrock, + } + dp.Hash = computeProviderHash(dp.canonical()) + addLegacy(dp.Name, dp) + } // Indexed providers. for _, p := range cfg.Providers { @@ -398,6 +422,7 @@ func providersFromEnv(ctx context.Context, cfg codersdk.AIBridgeConfig, logger s ) isBedrock = codersdk.IsBedrockConfigured(p.BedrockBaseURL, bedrock) if isBedrock { + dp.Type = database.AiProviderTypeBedrock dp.Bedrock = &bedrock // Always overwrite the generic BaseURL so removing // BASE_URL later doesn't trigger drift. Empty is fine: diff --git a/coderd/ai_providers_migrate_test.go b/coderd/ai_providers_migrate_test.go index 89165002b0da6..ba91091400a3b 100644 --- a/coderd/ai_providers_migrate_test.go +++ b/coderd/ai_providers_migrate_test.go @@ -2,8 +2,10 @@ package coderd_test import ( "bytes" + "database/sql" "testing" + "github.com/google/uuid" "github.com/stretchr/testify/require" "cdr.dev/slog/v3" @@ -145,8 +147,8 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - // Bedrock fields without an Anthropic key produce a Bedrock- - // authenticated Anthropic provider with no bearer keys. + // Bedrock fields without an Anthropic key produce an independent + // Bedrock provider with no bearer keys. cfg := codersdk.AIBridgeConfig{ LegacyBedrock: codersdk.AIBridgeBedrockConfig{ Region: serpent.String("us-west-2"), @@ -158,9 +160,9 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { } require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) - row, err := db.GetAIProviderByName(ctx, "anthropic") + row, err := db.GetAIProviderByName(ctx, "bedrock") require.NoError(t, err) - require.Equal(t, database.AiProviderTypeAnthropic, row.Type) + require.Equal(t, database.AiProviderTypeBedrock, row.Type) require.Contains(t, row.Settings.String, "us-west-2") require.Contains(t, row.Settings.String, "anthropic.claude-3-5-sonnet") require.Contains(t, row.Settings.String, "anthropic.claude-3-5-haiku") @@ -202,6 +204,116 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { require.Equal(t, "sk-ant-only", keys[0].APIKey) }) + t.Run("LegacyAnthropicAndBedrockCreateIndependentProviders", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + + cfg := codersdk.AIBridgeConfig{ + LegacyAnthropic: codersdk.AIBridgeAnthropicConfig{ + BaseURL: serpent.String("https://api.anthropic.com/"), + Key: serpent.String("sk-ant"), + }, + LegacyBedrock: codersdk.AIBridgeBedrockConfig{ + Region: serpent.String("us-west-2"), + AccessKey: serpent.String("AKIA"), + AccessKeySecret: serpent.String("secret"), + Model: serpent.String("anthropic.claude-3-5-sonnet"), + }, + } + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) + + anthropic, err := db.GetAIProviderByName(ctx, "anthropic") + require.NoError(t, err) + require.Equal(t, database.AiProviderTypeAnthropic, anthropic.Type) + anthropicKeys, err := db.GetAIProviderKeysByProviderID(ctx, anthropic.ID) + require.NoError(t, err) + require.Len(t, anthropicKeys, 1) + require.Equal(t, "sk-ant", anthropicKeys[0].APIKey) + + bedrock, err := db.GetAIProviderByName(ctx, "bedrock") + require.NoError(t, err) + require.Equal(t, database.AiProviderTypeBedrock, bedrock.Type) + bedrockKeys, err := db.GetAIProviderKeysByProviderID(ctx, bedrock.ID) + require.NoError(t, err) + require.Empty(t, bedrockKeys) + }) + + t.Run("LegacyBedrockRowSkipsBothEnvSeeds", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + + legacySettings := `{"_type":"bedrock","_version":1,"region":"us-west-2","model":"anthropic.claude-3-5-sonnet","access_key":"AKIA","access_key_secret":"secret"}` + _, err := db.InsertAIProvider(ctx, database.InsertAIProviderParams{ + ID: uuid.New(), + Type: database.AiProviderTypeAnthropic, + Name: "anthropic", + DisplayName: sql.NullString{String: "anthropic", Valid: true}, + Enabled: true, + BaseUrl: "", + Settings: sql.NullString{String: legacySettings, Valid: true}, + SettingsKeyID: sql.NullString{}, + }) + require.NoError(t, err) + + cfg := codersdk.AIBridgeConfig{ + LegacyAnthropic: codersdk.AIBridgeAnthropicConfig{ + BaseURL: serpent.String("https://api.anthropic.com/"), + Key: serpent.String("sk-ant"), + }, + LegacyBedrock: codersdk.AIBridgeBedrockConfig{ + Region: serpent.String("us-west-2"), + AccessKey: serpent.String("AKIA"), + AccessKeySecret: serpent.String("secret"), + Model: serpent.String("anthropic.claude-3-5-sonnet"), + }, + } + + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) + + row, err := db.GetAIProviderByName(ctx, "anthropic") + require.NoError(t, err) + require.Equal(t, database.AiProviderTypeAnthropic, row.Type) + require.Contains(t, row.Settings.String, "us-west-2") + keys, err := db.GetAIProviderKeysByProviderID(ctx, row.ID) + require.NoError(t, err) + require.Empty(t, keys) + }) + + t.Run("LegacyBedrockRowNamedAnthropicReportsBedrockDrift", func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + + legacySettings := `{"_type":"bedrock","_version":1,"region":"us-west-2","model":"anthropic.claude-3-5-sonnet","access_key":"AKIA","access_key_secret":"secret"}` + _, err := db.InsertAIProvider(ctx, database.InsertAIProviderParams{ + ID: uuid.New(), + Type: database.AiProviderTypeAnthropic, + Name: "anthropic", + DisplayName: sql.NullString{String: "anthropic", Valid: true}, + Enabled: true, + BaseUrl: "", + Settings: sql.NullString{String: legacySettings, Valid: true}, + SettingsKeyID: sql.NullString{}, + }) + require.NoError(t, err) + + cfg := codersdk.AIBridgeConfig{ + LegacyBedrock: codersdk.AIBridgeBedrockConfig{ + Region: serpent.String("eu-west-1"), + AccessKey: serpent.String("AKIA"), + AccessKeySecret: serpent.String("secret"), + Model: serpent.String("anthropic.claude-3-5-sonnet"), + }, + } + + err = coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) + require.Error(t, err) + require.Contains(t, err.Error(), `AI provider "bedrock" matches existing legacy row "anthropic"`) + require.Contains(t, err.Error(), "differs from the current environment configuration") + }) + t.Run("BedrockWithoutCredentialsUsesAWSEnvAuth", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) @@ -218,8 +330,9 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { } require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) - row, err := db.GetAIProviderByName(ctx, "anthropic") + row, err := db.GetAIProviderByName(ctx, "bedrock") require.NoError(t, err) + require.Equal(t, database.AiProviderTypeBedrock, row.Type) require.True(t, row.Settings.Valid, "Bedrock metadata must produce a settings blob") require.Contains(t, row.Settings.String, "us-east-1") require.Contains(t, row.Settings.String, "anthropic.claude-3-5-sonnet") @@ -242,13 +355,14 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { }, } require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) - row, err := db.GetAIProviderByName(ctx, "anthropic") + row, err := db.GetAIProviderByName(ctx, "bedrock") require.NoError(t, err) + require.Equal(t, database.AiProviderTypeBedrock, row.Type) require.Contains(t, row.Settings.String, "us-east-1") require.Contains(t, row.Settings.String, "AKIAONLY") require.Contains(t, row.Settings.String, "secretonly") - // Bedrock-only Anthropic has zero ai_provider_keys: it - // authenticates via the settings blob. + // Bedrock has zero ai_provider_keys: it authenticates via the + // settings blob. keys, err := db.GetAIProviderKeysByProviderID(ctx, row.ID) require.NoError(t, err) require.Empty(t, keys) @@ -371,6 +485,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { row, err := db.GetAIProviderByName(ctx, "bedrock-anthropic") require.NoError(t, err) + require.Equal(t, database.AiProviderTypeBedrock, row.Type) require.Contains(t, row.Settings.String, "AKIA-indexed") require.Contains(t, row.Settings.String, "indexed-secret") // Crucially, no ai_provider_keys rows for Bedrock providers. diff --git a/coderd/ai_providers_test.go b/coderd/ai_providers_test.go index b9bfd283f1c9e..5532d6f05de17 100644 --- a/coderd/ai_providers_test.go +++ b/coderd/ai_providers_test.go @@ -1,6 +1,7 @@ package coderd_test import ( + "database/sql" "encoding/json" "io" "net/http" @@ -87,11 +88,11 @@ func TestAIProvidersCRUD(t *testing.T) { // Create. req := codersdk.CreateAIProviderRequest{ - Type: codersdk.AIProviderTypeAnthropic, - Name: "primary-anthropic", - DisplayName: "Primary Anthropic", + Type: codersdk.AIProviderTypeBedrock, + Name: "primary-bedrock", + DisplayName: "Primary Bedrock", Enabled: true, - BaseURL: "https://api.anthropic.com/", + BaseURL: "https://bedrock-runtime.us-east-1.amazonaws.com/", Settings: codersdk.AIProviderSettings{ Bedrock: &codersdk.AIProviderBedrockSettings{ Region: "us-east-1", @@ -128,7 +129,7 @@ func TestAIProvidersCRUD(t *testing.T) { // Update. newDisplay := "Updated Display" - newURL := "https://api.anthropic.com/v1" + newURL := "https://bedrock-runtime.us-west-2.amazonaws.com/" disabled := false updated, err := client.UpdateAIProvider(ctx, created.Name, codersdk.UpdateAIProviderRequest{ DisplayName: &newDisplay, @@ -174,6 +175,26 @@ func TestAIProvidersCRUD(t *testing.T) { require.Equal(t, req.Name, recreated.Name) }) + t.Run("CreateLegacyAnthropicBedrockNormalizesType", func(t *testing.T) { + t.Parallel() + client := coderdtest.New(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx := testutil.Context(t, testutil.WaitLong) + + //nolint:gocritic // Owner role is the audience for this endpoint. + created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeAnthropic, + Name: "compat-bedrock", + Enabled: true, + BaseURL: "https://bedrock-runtime.us-east-1.amazonaws.com/", + Settings: codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }, + }) + require.NoError(t, err) + require.Equal(t, codersdk.AIProviderTypeBedrock, created.Type) + }) + t.Run("DefaultDisplayName", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) @@ -406,21 +427,21 @@ func TestAIProvidersCRUD(t *testing.T) { require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) }) - t.Run("UpdateSettingsEmptyObjectRejected", func(t *testing.T) { + t.Run("UpdateSettingsEmptyObjectClearsSettings", func(t *testing.T) { t.Parallel() - // "settings": {} cannot decode because the _type discriminator - // is missing. The handler must reject with 400; nothing about - // the provider should change. client := coderdtest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) ctx := testutil.Context(t, testutil.WaitLong) //nolint:gocritic // Owner role is the audience for this endpoint. created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ - Type: codersdk.AIProviderTypeOpenAI, + Type: codersdk.AIProviderTypeBedrock, Name: "patch-settings-empty", Enabled: true, - BaseURL: "https://api.openai.com/v1", + BaseURL: "https://bedrock-runtime.us-east-1.amazonaws.com/", + Settings: codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }, }) require.NoError(t, err) @@ -430,11 +451,11 @@ func TestAIProvidersCRUD(t *testing.T) { ) require.NoError(t, err) defer res.Body.Close() - require.Equal(t, http.StatusBadRequest, res.StatusCode) - var body codersdk.Response - require.NoError(t, json.NewDecoder(res.Body).Decode(&body)) - require.Contains(t, body.Message, "valid JSON") - require.Contains(t, body.Detail, "_type discriminator") + require.Equal(t, http.StatusOK, res.StatusCode) + + updated, err := client.AIProvider(ctx, created.Name) + require.NoError(t, err) + require.Nil(t, updated.Settings.Bedrock) }) t.Run("NotFound", func(t *testing.T) { @@ -535,7 +556,7 @@ func TestAIProvidersCRUD(t *testing.T) { require.NotEmpty(t, sdkErr.Message) }) - t.Run("BedrockSettingsRequireAnthropic", func(t *testing.T) { + t.Run("BedrockSettingsRequireBedrock", func(t *testing.T) { t.Parallel() client := coderdtest.New(t, nil) _ = coderdtest.CreateFirstUser(t, client) @@ -565,7 +586,7 @@ func TestAIProvidersCRUD(t *testing.T) { require.Contains(t, sdkErr.Message, "Invalid AI provider request") require.NotEmpty(t, sdkErr.Validations) require.Equal(t, "settings", sdkErr.Validations[0].Field) - require.Contains(t, sdkErr.Validations[0].Detail, "bedrock settings are only valid for type=anthropic") + require.Contains(t, sdkErr.Validations[0].Detail, "bedrock settings are only valid for type=bedrock") // Update: existing OpenAI provider patched with Bedrock settings // must also be rejected. @@ -584,7 +605,38 @@ func TestAIProvidersCRUD(t *testing.T) { require.Error(t, err) require.ErrorAs(t, err, &sdkErr) require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) - require.Contains(t, sdkErr.Message, "Bedrock settings are only valid for type=anthropic") + require.Contains(t, sdkErr.Message, "Bedrock settings are only valid for type=bedrock") + }) + + t.Run("BedrockUpdateCannotClearSettings", func(t *testing.T) { + t.Parallel() + client, db := coderdtest.NewWithDatabase(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx := testutil.Context(t, testutil.WaitLong) + + _, err := db.InsertAIProvider(dbauthz.AsSystemRestricted(ctx), database.InsertAIProviderParams{ + ID: uuid.New(), + Type: database.AiProviderTypeBedrock, + Name: "bedrock-clear-settings", + DisplayName: sql.NullString{String: "bedrock-clear-settings", Valid: true}, + Enabled: true, + Settings: sql.NullString{ + String: `{"_type":"bedrock","_version":1,"region":"us-east-1"}`, + Valid: true, + }, + }) + require.NoError(t, err) + + enabled := false + _, err = client.UpdateAIProvider(ctx, "bedrock-clear-settings", codersdk.UpdateAIProviderRequest{ + Enabled: &enabled, + Settings: &codersdk.AIProviderSettings{}, + }) + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + require.Contains(t, sdkErr.Message, "type=bedrock requires bedrock settings or base_url") }) t.Run("BedrockSecretsHidden", func(t *testing.T) { @@ -598,7 +650,7 @@ func TestAIProvidersCRUD(t *testing.T) { // back, so callers cannot recover them after creation. //nolint:gocritic // Owner role is the audience for this endpoint. _, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ - Type: codersdk.AIProviderTypeAnthropic, + Type: codersdk.AIProviderTypeBedrock, Name: "bedrock-secret-leak", Enabled: true, BaseURL: "https://bedrock-runtime.us-east-1.amazonaws.com/", @@ -885,7 +937,10 @@ func TestAIProvidersKeyManagement(t *testing.T) { var sdkErr *codersdk.Error require.ErrorAs(t, err, &sdkErr) require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) - require.Contains(t, sdkErr.Message, "Bedrock providers do not accept api_keys") + require.Contains(t, sdkErr.Message, "Invalid AI provider request") + require.NotEmpty(t, sdkErr.Validations) + require.Equal(t, "api_keys", sdkErr.Validations[0].Field) + require.Contains(t, sdkErr.Validations[0].Detail, "type=bedrock does not accept api_keys") }) t.Run("BedrockRejectsUpdateWithKeys", func(t *testing.T) { diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index bc93df7cd3178..0ac0de75be6d7 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -43,6 +43,15 @@ func APIAllowListTarget(entry rbac.AllowListElement) codersdk.APIAllowListTarget } } +// CanonicalAIProviderType returns the runtime provider type for a database row. +func CanonicalAIProviderType(row database.AIProvider) (database.AIProviderType, error) { + settings, err := AIProviderSettings(row.Settings) + if err != nil { + return "", xerrors.Errorf("decode settings: %w", err) + } + return database.AIProviderType(codersdk.CanonicalAIProviderType(codersdk.AIProviderType(row.Type), settings)), nil +} + // AIProvider converts a database row plus its API keys into the // codersdk shape. The caller is responsible for ensuring the row and // keys have been decrypted (i.e. fetched through the dbcrypt-wrapped @@ -50,29 +59,40 @@ func APIAllowListTarget(entry rbac.AllowListElement) codersdk.APIAllowListTarget // write-only fields on Settings are stripped, so the result is safe // to echo back in API responses. func AIProvider(row database.AIProvider, keys []database.AIProviderKey) (codersdk.AIProvider, error) { - display := row.Name - if row.DisplayName.Valid && row.DisplayName.String != "" { - display = row.DisplayName.String + s, err := AIProviderSettings(row.Settings) + if err != nil { + return codersdk.AIProvider{}, xerrors.Errorf("decode settings: %w", err) } + providerType := codersdk.CanonicalAIProviderType(codersdk.AIProviderType(row.Type), s) out := codersdk.AIProvider{ ID: row.ID, - Type: codersdk.AIProviderType(row.Type), + Type: providerType, Name: row.Name, - DisplayName: display, + DisplayName: AIProviderDisplayName(row, providerType), Enabled: row.Enabled, BaseURL: row.BaseUrl, APIKeys: maskAIProviderKeys(keys), + Settings: redactAIProviderSettings(s), CreatedAt: row.CreatedAt, UpdatedAt: row.UpdatedAt, } - s, err := AIProviderSettings(row.Settings) - if err != nil { - return codersdk.AIProvider{}, xerrors.Errorf("decode settings: %w", err) - } - out.Settings = redactAIProviderSettings(s) return out, nil } +// AIProviderDisplayName returns the presentation name for an AI provider row. +func AIProviderDisplayName(row database.AIProvider, providerType codersdk.AIProviderType) string { + display := row.Name + if row.DisplayName.Valid && row.DisplayName.String != "" { + display = row.DisplayName.String + } + if providerType == codersdk.AIProviderTypeBedrock && + (strings.EqualFold(display, string(codersdk.AIProviderTypeAnthropic)) || + strings.EqualFold(display, string(codersdk.AIProviderTypeBedrock))) { + return codersdk.AIProviderDisplayNameBedrock + } + return display +} + // AIProviderSettings parses the on-disk JSON form back into a codersdk // settings value. SQL NULL and the empty string decode to the zero // value. diff --git a/coderd/database/db2sdk/db2sdk_test.go b/coderd/database/db2sdk/db2sdk_test.go index 7dce695afc773..fa10566d38dea 100644 --- a/coderd/database/db2sdk/db2sdk_test.go +++ b/coderd/database/db2sdk/db2sdk_test.go @@ -23,6 +23,97 @@ import ( "github.com/coder/coder/v2/provisionersdk/proto" ) +func TestAIProviderCanonicalTypeAndDisplayName(t *testing.T) { + t.Parallel() + + bedrockSettings := marshalAIProviderSettings(t, codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }) + + cases := []struct { + name string + row database.AIProvider + wantType codersdk.AIProviderType + wantDisplay string + }{ + { + name: "anthropic without bedrock stays anthropic", + row: database.AIProvider{ + Type: database.AiProviderTypeAnthropic, + Name: "anthropic", + }, + wantType: codersdk.AIProviderTypeAnthropic, + wantDisplay: "anthropic", + }, + { + name: "legacy anthropic with bedrock promotes to bedrock", + row: database.AIProvider{ + Type: database.AiProviderTypeAnthropic, + Name: "anthropic", + Settings: bedrockSettings, + }, + wantType: codersdk.AIProviderTypeBedrock, + wantDisplay: codersdk.AIProviderDisplayNameBedrock, + }, + { + name: "bedrock type passes through", + row: database.AIProvider{ + Type: database.AiProviderTypeBedrock, + Name: "bedrock", + DisplayName: sql.NullString{String: "bedrock", Valid: true}, + Settings: bedrockSettings, + }, + wantType: codersdk.AIProviderTypeBedrock, + wantDisplay: codersdk.AIProviderDisplayNameBedrock, + }, + { + name: "other types ignore bedrock settings", + row: database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "openai", + Settings: bedrockSettings, + }, + wantType: codersdk.AIProviderTypeOpenAI, + wantDisplay: "openai", + }, + { + name: "custom bedrock display name is preserved", + row: database.AIProvider{ + Type: database.AiProviderTypeBedrock, + Name: "bedrock", + DisplayName: sql.NullString{String: "Claude in AWS", Valid: true}, + Settings: bedrockSettings, + }, + wantType: codersdk.AIProviderTypeBedrock, + wantDisplay: "Claude in AWS", + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + gotType, err := db2sdk.CanonicalAIProviderType(tt.row) + require.NoError(t, err) + require.Equal(t, database.AIProviderType(tt.wantType), gotType) + require.Equal(t, tt.wantDisplay, db2sdk.AIProviderDisplayName(tt.row, tt.wantType)) + + got, err := db2sdk.AIProvider(tt.row, nil) + require.NoError(t, err) + require.Equal(t, tt.wantType, got.Type) + require.Equal(t, tt.wantDisplay, got.DisplayName) + }) + } +} + +func marshalAIProviderSettings(t *testing.T, settings codersdk.AIProviderSettings) sql.NullString { + t.Helper() + + raw, err := json.Marshal(settings) + require.NoError(t, err) + return sql.NullString{String: string(raw), Valid: true} +} + func TestProvisionerJobStatus(t *testing.T) { t.Parallel() diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 1d7f7a62c3a74..f7be57cd569a3 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -6543,6 +6543,7 @@ func (s *MethodTestSuite) TestAIBridge() { provider := testutil.Fake(s.T(), faker, database.AIProvider{}) arg := database.UpdateAIProviderParams{ ID: provider.ID, + Type: provider.Type, Enabled: true, BaseUrl: "https://api.example.com/", } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 6b7b39c02178f..6b02662cf9bdf 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -738,19 +738,21 @@ const updateAIProvider = `-- name: UpdateAIProvider :one UPDATE ai_providers SET - display_name = $1::text, - enabled = $2::boolean, - base_url = $3::text, - settings = $4::text, - settings_key_id = $5::text, + type = $1::ai_provider_type, + display_name = $2::text, + enabled = $3::boolean, + base_url = $4::text, + settings = $5::text, + settings_key_id = $6::text, updated_at = NOW() WHERE - id = $6::uuid AND deleted = FALSE + id = $7::uuid AND deleted = FALSE RETURNING id, type, name, display_name, enabled, deleted, base_url, settings, settings_key_id, created_at, updated_at ` type UpdateAIProviderParams struct { + Type AIProviderType `db:"type" json:"type"` DisplayName sql.NullString `db:"display_name" json:"display_name"` Enabled bool `db:"enabled" json:"enabled"` BaseUrl string `db:"base_url" json:"base_url"` @@ -761,6 +763,7 @@ type UpdateAIProviderParams struct { func (q *sqlQuerier) UpdateAIProvider(ctx context.Context, arg UpdateAIProviderParams) (AIProvider, error) { row := q.db.QueryRowContext(ctx, updateAIProvider, + arg.Type, arg.DisplayName, arg.Enabled, arg.BaseUrl, diff --git a/coderd/database/queries/ai_providers.sql b/coderd/database/queries/ai_providers.sql index 1c9a977e4c479..f7b4d5ec97768 100644 --- a/coderd/database/queries/ai_providers.sql +++ b/coderd/database/queries/ai_providers.sql @@ -66,6 +66,7 @@ RETURNING UPDATE ai_providers SET + type = @type::ai_provider_type, display_name = sqlc.narg('display_name')::text, enabled = @enabled::boolean, base_url = @base_url::text, diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 7b15178142ede..4b87484fee6bc 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -827,11 +827,15 @@ func (api *API) getUserChatProviderAvailability( enabledProviderIDs: make(map[uuid.UUID]struct{}, len(enabledProviders)), providerStatusByID: make(map[uuid.UUID]chatprovider.ProviderAvailability, len(enabledProviders)), } + providerNameByID := make(map[uuid.UUID]string, len(configuredProviders)) for _, configuredProvider := range configuredProviders { normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider) if normalizedProvider != "" { availability.enabledProviderNames[normalizedProvider] = struct{}{} } + if configuredProvider.ProviderID != uuid.Nil && normalizedProvider != "" { + providerNameByID[configuredProvider.ProviderID] = normalizedProvider + } if configuredProvider.ProviderID != uuid.Nil { availability.enabledProviderIDs[configuredProvider.ProviderID] = struct{}{} } @@ -884,9 +888,19 @@ func (api *API) getUserChatProviderAvailability( mergeProviderStatus(providerStatusByType, normalizedProvider, status) } + modelProviderName := func(model database.ChatModelConfig) string { + if model.AIProviderID.Valid { + if provider, ok := providerNameByID[model.AIProviderID.UUID]; ok { + return provider + } + } + return model.Provider + } + modelStatusByType := make(map[string]chatprovider.ProviderAvailability, len(enabledModels)) for _, model := range enabledModels { - normalizedProvider := chatprovider.NormalizeProvider(model.Provider) + modelProvider := modelProviderName(model) + normalizedProvider := chatprovider.NormalizeProvider(modelProvider) if normalizedProvider == "" { continue } @@ -907,7 +921,8 @@ func (api *API) getUserChatProviderAvailability( } for _, model := range enabledModels { - normalizedProvider := chatprovider.NormalizeProvider(model.Provider) + modelProvider := modelProviderName(model) + normalizedProvider := chatprovider.NormalizeProvider(modelProvider) if model.AIProviderID.Valid { status, ok := availability.providerStatusByID[model.AIProviderID.UUID] if !ok { @@ -918,7 +933,7 @@ func (api *API) getUserChatProviderAvailability( } } availability.configuredModels = append(availability.configuredModels, chatprovider.ConfiguredModel{ - Provider: model.Provider, + Provider: modelProvider, Model: model.Model, DisplayName: model.DisplayName, }) @@ -6516,19 +6531,20 @@ func parseUserAIProviderID(r *http.Request) (uuid.UUID, error) { return uuid.Parse(chi.URLParam(r, "aiProvider")) } -func convertAIProviderSummary(provider database.AIProvider) codersdk.AIProviderSummary { - displayName := provider.Name - if provider.DisplayName.Valid && provider.DisplayName.String != "" { - displayName = provider.DisplayName.String +func convertAIProviderSummary(provider database.AIProvider) (codersdk.AIProviderSummary, error) { + providerType, err := canonicalAIProviderTypeForRow(provider) + if err != nil { + return codersdk.AIProviderSummary{}, err } + sdkProviderType := codersdk.AIProviderType(providerType) return codersdk.AIProviderSummary{ ID: provider.ID, - Type: codersdk.AIProviderType(provider.Type), + Type: sdkProviderType, Name: provider.Name, - DisplayName: displayName, + DisplayName: db2sdk.AIProviderDisplayName(provider, sdkProviderType), Enabled: provider.Enabled, Deleted: provider.Deleted, - } + }, nil } func (api *API) listUserAIProviderKeyConfigs(rw http.ResponseWriter, r *http.Request) { @@ -6581,10 +6597,16 @@ func (api *API) listUserAIProviderKeyConfigs(rw http.ResponseWriter, r *http.Req byokEnabled := api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value() configs := make([]codersdk.UserAIProviderKeyConfig, 0, len(visibleProviders)) for _, provider := range visibleProviders { + providerSummary, err := convertAIProviderSummary(provider) + if err != nil { + api.Logger.Error(ctx, "failed to convert AI provider summary", slog.Error(err), slog.F("provider_id", provider.ID)) + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to convert AI provider."}) + return + } _, hasUserKey := keysByProviderID[provider.ID] _, hasProviderKey := providerKeysByProviderID[provider.ID] configs = append(configs, codersdk.UserAIProviderKeyConfig{ - Provider: convertAIProviderSummary(provider), + Provider: providerSummary, HasUserAPIKey: hasUserKey, HasProviderAPIKey: hasProviderKey, BYOKEnabled: byokEnabled, @@ -6661,8 +6683,14 @@ func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update user AI provider key."}) return } + providerSummary, err := convertAIProviderSummary(provider) + if err != nil { + api.Logger.Error(ctx, "failed to convert AI provider summary", slog.Error(err), slog.F("provider_id", provider.ID)) + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to convert AI provider."}) + return + } httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserAIProviderKeyConfig{ - Provider: convertAIProviderSummary(provider), + Provider: providerSummary, HasUserAPIKey: true, HasProviderAPIKey: len(providerKeys) > 0, BYOKEnabled: true, @@ -6703,12 +6731,21 @@ func (api *API) configuredProvidersFromAIProviders(ctx context.Context, provider } configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers)) for _, provider := range providers { - configuredProviders = append(configuredProviders, api.configuredProviderFromAIProviderKeys(provider, keysByProviderID[provider.ID])) + configuredProvider, err := api.configuredProviderFromAIProviderKeys(ctx, provider, keysByProviderID[provider.ID]) + if err != nil { + return nil, err + } + configuredProviders = append(configuredProviders, configuredProvider) } return configuredProviders, nil } -func (api *API) configuredProviderFromAIProviderKeys(provider database.AIProvider, keys []database.AIProviderKey) chatprovider.ConfiguredProvider { +func (api *API) configuredProviderFromAIProviderKeys(ctx context.Context, provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) { + providerType, err := canonicalAIProviderTypeForRow(provider) + if err != nil { + api.Logger.Error(ctx, "failed to decode AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + return chatprovider.ConfiguredProvider{}, err + } apiKey := "" for _, key := range keys { if key.APIKey != "" { @@ -6718,13 +6755,13 @@ func (api *API) configuredProviderFromAIProviderKeys(provider database.AIProvide } return chatprovider.ConfiguredProvider{ ProviderID: provider.ID, - Provider: string(provider.Type), + Provider: string(providerType), APIKey: apiKey, BaseURL: provider.BaseUrl, CentralAPIKeyEnabled: true, AllowUserAPIKey: api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(), AllowCentralAPIKeyFallback: true, - } + }, nil } func writeLegacyChatProviderGone(rw http.ResponseWriter, r *http.Request) { @@ -6786,9 +6823,35 @@ func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) { return } + providerNameByID := map[uuid.UUID]string{} + //nolint:gocritic // All authenticated users need canonical provider metadata to render model configs. + providers, err := api.Database.GetAIProviders(dbauthz.AsChatd(ctx), database.GetAIProvidersParams{ + IncludeDeleted: true, + IncludeDisabled: true, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to list AI providers.", + Detail: err.Error(), + }) + return + } + for _, provider := range providers { + providerType, err := canonicalAIProviderTypeForRow(provider) + if err != nil { + api.Logger.Error(ctx, "failed to decode AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to decode AI provider settings.", + Detail: err.Error(), + }) + return + } + providerNameByID[provider.ID] = string(providerType) + } + resp := make([]codersdk.ChatModelConfig, 0, len(configs)) for _, config := range configs { - resp = append(resp, convertChatModelConfig(config)) + resp = append(resp, convertChatModelConfig(config, providerNameByID)) } httpapi.Write(ctx, rw, http.StatusOK, resp) @@ -6848,7 +6911,15 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."}) return } - provider := string(aiProvider.Type) + providerType, err := canonicalAIProviderTypeForRow(aiProvider) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to decode AI provider settings.", + Detail: err.Error(), + }) + return + } + provider := string(providerType) aiProviderID := uuid.NullUUID{UUID: aiProvider.ID, Valid: true} model := strings.TrimSpace(req.Model) @@ -6930,7 +7001,11 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { if !lockedAIProvider.Enabled { return errChatProviderNotConfigured } - insertParams.Provider = string(lockedAIProvider.Type) + lockedProviderType, err := canonicalAIProviderTypeForRow(lockedAIProvider) + if err != nil { + return xerrors.Errorf("canonicalize provider type for %q: %w", lockedAIProvider.Name, err) + } + insertParams.Provider = string(lockedProviderType) if err := validateChatModelConfigProviderModel(lockedAIProvider, insertParams.Model); err != nil { return err } @@ -7001,7 +7076,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventModelConfig, inserted.ID) - httpapi.Write(ctx, rw, http.StatusCreated, convertChatModelConfig(inserted)) + httpapi.Write(ctx, rw, http.StatusCreated, convertChatModelConfig(inserted, nil)) } func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { @@ -7035,20 +7110,43 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { return } + provider := existing.Provider + aiProviderID := existing.AIProviderID + if req.AIProviderID == nil && aiProviderID.Valid { + //nolint:gocritic // The route already authorized chat model config updates. + aiProvider, err := api.Database.GetAIProviderByID(dbauthz.AsChatd(ctx), aiProviderID.UUID) + if err != nil { + if !xerrors.Is(err, sql.ErrNoRows) { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get AI provider.", + Detail: err.Error(), + }) + return + } + } else { + providerType, err := canonicalAIProviderTypeForRow(aiProvider) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to decode AI provider settings.", + Detail: err.Error(), + }) + return + } + provider = string(providerType) + } + } + if strings.TrimSpace(req.Provider) != "" && req.AIProviderID == nil { requestedProvider := chatprovider.NormalizeProvider(req.Provider) if requestedProvider == "" { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid provider."}) return } - if requestedProvider != existing.Provider { + if requestedProvider != provider { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider ID is required when updating provider."}) return } } - - provider := existing.Provider - aiProviderID := existing.AIProviderID if req.AIProviderID != nil { //nolint:gocritic // The route already authorized chat model config updates. aiProvider, err := api.Database.GetAIProviderByID(dbauthz.AsChatd(ctx), *req.AIProviderID) @@ -7067,7 +7165,15 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."}) return } - provider = string(aiProvider.Type) + providerType, err := canonicalAIProviderTypeForRow(aiProvider) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to decode AI provider settings.", + Detail: err.Error(), + }) + return + } + provider = string(providerType) aiProviderID = uuid.NullUUID{UUID: aiProvider.ID, Valid: true} } @@ -7156,7 +7262,11 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { if !aiProvider.Enabled { return errChatProviderNotConfigured } - updateParams.Provider = string(aiProvider.Type) + providerType, err := canonicalAIProviderTypeForRow(aiProvider) + if err != nil { + return xerrors.Errorf("canonicalize provider type for %q: %w", aiProvider.Name, err) + } + updateParams.Provider = string(providerType) if err := validateChatModelConfigProviderModel(aiProvider, updateParams.Model); err != nil { return err } @@ -7233,7 +7343,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventModelConfig, updated.ID) - httpapi.Write(ctx, rw, http.StatusOK, convertChatModelConfig(updated)) + httpapi.Write(ctx, rw, http.StatusOK, convertChatModelConfig(updated, nil)) } func (api *API) deleteChatModelConfig(rw http.ResponseWriter, r *http.Request) { @@ -7405,14 +7515,20 @@ func parseChatModelConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, return modelConfigID, true } -func convertChatModelConfig(config database.ChatModelConfig) codersdk.ChatModelConfig { +func convertChatModelConfig(config database.ChatModelConfig, providerNameByID map[uuid.UUID]string) codersdk.ChatModelConfig { var aiProviderID *uuid.UUID + provider := config.Provider if config.AIProviderID.Valid { aiProviderID = &config.AIProviderID.UUID + if providerNameByID != nil { + if providerName, ok := providerNameByID[config.AIProviderID.UUID]; ok { + provider = providerName + } + } } return codersdk.ChatModelConfig{ ID: config.ID, - Provider: config.Provider, + Provider: provider, AIProviderID: aiProviderID, Model: config.Model, DisplayName: config.DisplayName, diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index c55d58c269eea..1cdaf5a881048 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -1623,6 +1623,122 @@ func TestListChatModels(t *testing.T) { require.True(t, openAIProvider.Available) }) + t.Run("BedrockOnlyProviderRemainsVisibleWithoutAnthropic", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + bedrockProvider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeBedrock, + Name: "test-bedrock-" + uuid.NewString(), + Enabled: true, + BaseURL: "https://bedrock-runtime.us-east-1.amazonaws.com/", + Settings: codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }, + }) + require.NoError(t, err) + + contextLimit := int64(4096) + bedrockModel, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "anthropic", + AIProviderID: &bedrockProvider.ID, + Model: "anthropic.claude-sonnet-4-20250514-v1:0", + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + require.Equal(t, "bedrock", bedrockModel.Provider) + + findProvider := func(models codersdk.ChatModelsResponse, provider string) *codersdk.ChatModelProvider { + for i := range models.Providers { + if models.Providers[i].Provider == provider { + return &models.Providers[i] + } + } + return nil + } + + models, err := client.ListChatModels(ctx) + require.NoError(t, err) + bedrock := findProvider(models, "bedrock") + require.NotNil(t, bedrock) + require.True(t, bedrock.Available) + require.Nil(t, findProvider(models, "anthropic")) + + anthropicProvider := createAIProviderForTest(t, client, "anthropic", "test-api-key") + anthropicModel, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "anthropic", + AIProviderID: &anthropicProvider.ID, + Model: "claude-sonnet-4-20250514", + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + require.Equal(t, "anthropic", anthropicModel.Provider) + + models, err = client.ListChatModels(ctx) + require.NoError(t, err) + require.NotNil(t, findProvider(models, "bedrock")) + require.NotNil(t, findProvider(models, "anthropic")) + + enabled := false + _, err = client.UpdateAIProvider(ctx, anthropicProvider.ID.String(), codersdk.UpdateAIProviderRequest{ + Enabled: &enabled, + }) + require.NoError(t, err) + + models, err = client.ListChatModels(ctx) + require.NoError(t, err) + bedrock = findProvider(models, "bedrock") + require.NotNil(t, bedrock) + require.True(t, bedrock.Available) + require.Nil(t, findProvider(models, "anthropic")) + }) + + t.Run("LegacyBedrockModelConfigUsesCanonicalProvider", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + + bedrockProvider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeAnthropic, + Name: "legacy-bedrock-list-models-" + uuid.NewString(), + Settings: sql.NullString{ + String: `{"_type":"bedrock","_version":1,"region":"us-east-1"}`, + Valid: true, + }, + }) + storedConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "anthropic", + AIProviderID: uuid.NullUUID{UUID: bedrockProvider.ID, Valid: true}, + Model: "anthropic.claude-3-5-sonnet", + DisplayName: "Claude via Bedrock", + CreatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true}, + ContextLimit: 4096, + CompressionThreshold: 80, + }) + + models, err := client.ListChatModels(ctx) + require.NoError(t, err) + + var provider *codersdk.ChatModelProvider + for i := range models.Providers { + if models.Providers[i].Provider == "bedrock" { + provider = &models.Providers[i] + } + require.NotEqual(t, "anthropic", models.Providers[i].Provider) + } + require.NotNil(t, provider) + require.True(t, provider.Available) + require.True(t, slices.ContainsFunc(provider.Models, func(model codersdk.ChatModel) bool { + return model.Provider == "bedrock" && model.Model == storedConfig.Model + })) + }) + t.Run("UserOnlyProviderRequiresUserKey", func(t *testing.T) { t.Parallel() @@ -2133,6 +2249,34 @@ func TestUserAIProviderKeys(t *testing.T) { require.False(t, cfg.HasUserAPIKey) }) + t.Run("LegacyBedrockProviderSummaryUsesBedrockDisplay", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeAnthropic, + Name: "legacy-bedrock-summary-" + uuid.NewString(), + DisplayName: sql.NullString{String: "anthropic", Valid: true}, + Enabled: true, + Settings: sql.NullString{ + String: `{"_type":"bedrock","_version":1,"region":"us-east-1"}`, + Valid: true, + }, + }) + + configs, err := memberClient.ListUserAIProviderKeyConfigs(ctx, "me") + require.NoError(t, err) + cfg := findUserAIProviderKeyConfig(t, configs, provider.ID) + require.NotNil(t, cfg) + require.Equal(t, codersdk.AIProviderTypeBedrock, cfg.Provider.Type) + require.Equal(t, codersdk.AIProviderDisplayNameBedrock, cfg.Provider.DisplayName) + }) + t.Run("ListsDisabledProviderWithSavedUserKey", func(t *testing.T) { t.Parallel() @@ -3458,6 +3602,43 @@ func TestListChatModelConfigs(t *testing.T) { require.True(t, configs[0].Enabled) }) + t.Run("CanonicalizesLegacyBedrockProvider", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeAnthropic, + Name: "legacy-bedrock-provider", + Settings: sql.NullString{ + String: `{"_type":"bedrock","_version":1,"region":"us-east-1"}`, + Valid: true, + }, + }) + storedConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "anthropic", + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + Model: "anthropic.claude-3-5-sonnet", + DisplayName: "Claude via Bedrock", + CreatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true}, + ContextLimit: 4096, + CompressionThreshold: 80, + }) + + configs, err := client.ListChatModelConfigs(ctx) + require.NoError(t, err) + for _, config := range configs { + if config.ID == storedConfig.ID { + require.Equal(t, "bedrock", config.Provider) + return + } + } + require.Fail(t, "expected legacy Bedrock model config") + }) + t.Run("DeserializesLegacyPricingJSON", func(t *testing.T) { t.Parallel() @@ -3924,6 +4105,45 @@ func TestUpdateChatModelConfig(t *testing.T) { require.Contains(t, sdkErr.Detail, "Change the AI provider type to openrouter or openai-compat.") }) + t.Run("CanonicalizesLegacyBedrockProviderWithoutAIProviderID", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeAnthropic, + Name: "legacy-bedrock-update-provider", + Settings: sql.NullString{ + String: `{"_type":"bedrock","_version":1,"region":"us-east-1"}`, + Valid: true, + }, + }) + storedConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "anthropic", + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + Model: "anthropic.claude-3-5-sonnet", + DisplayName: "Claude via Bedrock", + CreatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true}, + ContextLimit: 4096, + CompressionThreshold: 80, + }) + + updated, err := client.UpdateChatModelConfig(ctx, storedConfig.ID, codersdk.UpdateChatModelConfigRequest{ + Provider: "bedrock", + DisplayName: "Claude via Bedrock Updated", + }) + require.NoError(t, err) + require.Equal(t, "bedrock", updated.Provider) + require.Equal(t, "Claude via Bedrock Updated", updated.DisplayName) + + persisted, err := db.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), storedConfig.ID) + require.NoError(t, err) + require.Equal(t, "bedrock", persisted.Provider) + }) + t.Run("DisablePreservesRecordAndHidesItFromNonAdmins", func(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/ai_provider_canonical.go b/coderd/x/chatd/ai_provider_canonical.go new file mode 100644 index 0000000000000..fdf8b991649b1 --- /dev/null +++ b/coderd/x/chatd/ai_provider_canonical.go @@ -0,0 +1,55 @@ +package chatd + +import ( + "context" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" +) + +func canonicalAIProviderType(provider database.AIProvider) (database.AIProviderType, error) { + return db2sdk.CanonicalAIProviderType(provider) +} + +func canonicalAIProviderTypeString(provider database.AIProvider) (string, error) { + providerType, err := canonicalAIProviderType(provider) + if err != nil { + return "", err + } + return string(providerType), nil +} + +func bestEffortCanonicalAIProviderType(ctx context.Context, logger slog.Logger, provider database.AIProvider) database.AIProviderType { + providerType, err := canonicalAIProviderType(provider) + if err != nil { + logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + return provider.Type + } + return providerType +} + +func bestEffortCanonicalAIProviderTypeString(ctx context.Context, logger slog.Logger, provider database.AIProvider) string { + return string(bestEffortCanonicalAIProviderType(ctx, logger, provider)) +} + +func aiProviderTypeCanSatisfyRequest(candidateProviderType string, requestedProviderType string) bool { + if candidateProviderType == requestedProviderType { + return true + } + return requestedProviderType == string(database.AiProviderTypeAnthropic) && + candidateProviderType == string(database.AiProviderTypeBedrock) +} + +func aiProviderMatchesCanonicalType(provider database.AIProvider, normalizedProviderType string) (bool, error) { + providerType, err := canonicalAIProviderTypeString(provider) + if err != nil { + return false, err + } + return aiProviderTypeCanSatisfyRequest(chatprovider.NormalizeProvider(providerType), normalizedProviderType), nil +} + +func aiProviderMatchesRawType(provider database.AIProvider, normalizedProviderType string) bool { + return aiProviderTypeCanSatisfyRequest(chatprovider.NormalizeProvider(string(provider.Type)), normalizedProviderType) +} diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 5a5ba7fb60a95..fd78e0338893e 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -8642,10 +8642,10 @@ func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvi if err != nil { return chatprovider.ConfiguredProvider{}, xerrors.Errorf("get AI provider keys: %w", err) } - return p.aiProviderConfigFromKeys(provider, keys) + return p.aiProviderConfigFromKeys(ctx, provider, keys) } -func (p *Server) aiProviderConfigFromKeys(provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) { +func (p *Server) aiProviderConfigFromKeys(ctx context.Context, provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) { if !provider.Enabled { return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID) } @@ -8660,7 +8660,7 @@ func (p *Server) aiProviderConfigFromKeys(provider database.AIProvider, keys []d } return chatprovider.ConfiguredProvider{ ProviderID: provider.ID, - Provider: string(provider.Type), + Provider: bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider), APIKey: apiKey, BaseURL: provider.BaseUrl, CentralAPIKeyEnabled: true, @@ -8687,7 +8687,7 @@ func (p *Server) aiProviderConfigs(ctx context.Context, providers []database.AIP } configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers)) for _, provider := range providers { - configuredProvider, err := p.aiProviderConfigFromKeys(provider, keysByProviderID[provider.ID]) + configuredProvider, err := p.aiProviderConfigFromKeys(ctx, provider, keysByProviderID[provider.ID]) if err != nil { return nil, err } @@ -8763,17 +8763,40 @@ func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType( return chatprovider.ProviderAPIKeys{}, nil, xerrors.Errorf("get enabled AI providers: %w", err) } normalizedProviderType := chatprovider.NormalizeProvider(providerType) - for _, provider := range providers { - if chatprovider.NormalizeProvider(string(provider.Type)) != normalizedProviderType { - continue - } + keysForProvider := func(provider database.AIProvider, providerKeysType string) (chatprovider.ProviderAPIKeys, *database.AIProvider, error) { keys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider) if err != nil { return chatprovider.ProviderAPIKeys{}, nil, err } - if userCanUseProviderKeys(keys, normalizedProviderType) { + if userCanUseProviderKeys(keys, providerKeysType) { return keys, &provider, nil } + return chatprovider.ProviderAPIKeys{}, nil, nil + } + for _, provider := range providers { + canonicalProviderType, err := canonicalAIProviderTypeString(provider) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + continue + } + providerKeysType := chatprovider.NormalizeProvider(canonicalProviderType) + if !aiProviderTypeCanSatisfyRequest(providerKeysType, normalizedProviderType) { + continue + } + keys, matchedProvider, err := keysForProvider(provider, providerKeysType) + if err != nil || matchedProvider != nil { + return keys, matchedProvider, err + } + } + for _, provider := range providers { + if !aiProviderMatchesRawType(provider, normalizedProviderType) { + continue + } + providerKeysType := chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider)) + keys, matchedProvider, err := keysForProvider(provider, providerKeysType) + if err != nil || matchedProvider != nil { + return keys, matchedProvider, err + } } keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) if err != nil { diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 965b6b474e9f7..0286cf3394a9c 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -202,6 +202,111 @@ func TestResolveUserProviderAPIKeysAndProviderForProviderTypeProviderMatch(t *te require.Equal(t, database.AiProviderTypeOpenai, aiProvider.Type) } +func TestResolveDirectModelRouteForProviderTypeMatchesCanonicalBedrockForAnthropicRequest(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + providerID := uuid.New() + + rawSettings, err := json.Marshal(codersdk.AIProviderSettings{ + Bedrock: &codersdk.AIProviderBedrockSettings{Region: "us-east-1"}, + }) + require.NoError(t, err) + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeAnthropic, + Enabled: true, + Settings: sql.NullString{String: string(rawSettings), Valid: true}, + } + + db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return([]database.AIProvider{provider}, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return(nil, nil) + + server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + route, err := server.resolveDirectModelRouteForProviderType( + ctx, + ownerID, + chattool.ComputerUseProviderAnthropic, + ) + require.NoError(t, err) + + providerHint, err := route.providerHint() + require.NoError(t, err) + require.Equal(t, "bedrock", providerHint) + require.True(t, route.directProviderKeys().HasProvider("bedrock")) +} + +func TestResolveDirectModelRouteForProviderTypeUsesBedrockProviderForAnthropicComputerUse(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + providerID := uuid.New() + + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeBedrock, + Enabled: true, + } + + db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return([]database.AIProvider{provider}, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return(nil, nil) + + server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + route, err := server.resolveDirectModelRouteForProviderType( + ctx, + ownerID, + chattool.ComputerUseProviderAnthropic, + ) + require.NoError(t, err) + + providerHint, err := route.providerHint() + require.NoError(t, err) + require.Equal(t, "bedrock", providerHint) + require.True(t, route.directProviderKeys().HasProvider("bedrock")) +} + +func TestResolveDirectModelRouteForProviderTypeFallsBackToRawProviderType(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + providerID := uuid.New() + + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Enabled: true, + Settings: sql.NullString{String: "{", Valid: true}, + } + + db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return([]database.AIProvider{provider}, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{ + ProviderID: providerID, + APIKey: "test-key", + }}, nil) + + server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + route, err := server.resolveDirectModelRouteForProviderType( + ctx, + ownerID, + chattool.ComputerUseProviderOpenAI, + ) + require.NoError(t, err) + + providerHint, err := route.providerHint() + require.NoError(t, err) + require.Equal(t, "openai", providerHint) + require.Equal(t, "test-key", route.directProviderKeys().APIKey("openai")) +} + func TestResolveModelRouteForProviderTypeAIGatewayRequiresProvider(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index 353769dd02376..7b0a66eb43371 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -6310,6 +6310,7 @@ func setOpenAIProviderBaseURL( } _, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ ID: provider.ID, + Type: provider.Type, DisplayName: provider.DisplayName, Enabled: provider.Enabled, BaseUrl: baseURL, diff --git a/coderd/x/chatd/chaterror/message.go b/coderd/x/chatd/chaterror/message.go index 3ebe6366e7f7e..d67ffb3f66bec 100644 --- a/coderd/x/chatd/chaterror/message.go +++ b/coderd/x/chatd/chaterror/message.go @@ -130,7 +130,7 @@ func providerDisplayName(provider string) string { case "azure": return "Azure OpenAI" case "bedrock": - return "AWS Bedrock" + return codersdk.AIProviderDisplayNameBedrock case "google": return "Google" case "openai": diff --git a/coderd/x/chatd/chatprovider/chatprovider.go b/coderd/x/chatd/chatprovider/chatprovider.go index ac817e094034e..4a98eb51b0798 100644 --- a/coderd/x/chatd/chatprovider/chatprovider.go +++ b/coderd/x/chatd/chatprovider/chatprovider.go @@ -44,7 +44,7 @@ var envPresetProviderNames = []string{ var providerDisplayNameByName = map[string]string{ fantasyanthropic.Name: "Anthropic", fantasyazure.Name: "Azure OpenAI", - fantasybedrock.Name: "AWS Bedrock", + fantasybedrock.Name: codersdk.AIProviderDisplayNameBedrock, fantasygoogle.Name: "Google", fantasyopenai.Name: "OpenAI", fantasyopenaicompat.Name: "OpenAI Compatible", diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index a732da1a952dc..077f072083340 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -13,6 +13,7 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" @@ -168,7 +169,11 @@ func (p *Server) newAIGatewayModel( baseRT = &chatdebug.RecordingTransport{Base: baseRT} } - config := fantasyConfigForAIBridge(route.Provider.Type) + providerType, err := canonicalAIProviderType(route.Provider) + if err != nil { + return nil, xerrors.Errorf("canonicalize provider type for %q: %w", route.Provider.Name, err) + } + config := fantasyConfigForAIBridge(providerType) return newLanguageModel( config.ProviderHint, req.ModelName, @@ -258,11 +263,15 @@ func (p *Server) resolveAIGatewayRoute( provider database.AIProvider, modelProviderHint string, ) (resolvedModelRoute, error) { + providerType, err := canonicalAIProviderType(provider) + if err != nil { + return resolvedModelRoute{}, xerrors.Errorf("canonicalize provider type for %q: %w", provider.Name, err) + } auth, err := p.aiGatewayProviderAuthForUser( ctx, ownerID, provider, - aiGatewayRequestFormatForProviderType(provider.Type), + aiGatewayRequestFormatForProviderType(providerType), ) if err != nil { return resolvedModelRoute{}, xerrors.Errorf("resolve AI Gateway provider auth: %w", err) @@ -279,7 +288,11 @@ func (p *Server) resolveAIGatewayModelRouteForConfig( if err != nil { return resolvedModelRoute{}, err } - return p.resolveAIGatewayRoute(ctx, ownerID, provider, string(provider.Type)) + providerType, err := canonicalAIProviderType(provider) + if err != nil { + return resolvedModelRoute{}, xerrors.Errorf("canonicalize provider type for %q: %w", provider.Name, err) + } + return p.resolveAIGatewayRoute(ctx, ownerID, provider, string(providerType)) } func (p *Server) resolveAIGatewayModelRouteForProviderType( @@ -295,7 +308,7 @@ func (p *Server) resolveAIGatewayModelRouteForProviderType( ctx, ownerID, provider, - chatprovider.NormalizeProvider(providerType), + bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider), ) } @@ -326,7 +339,18 @@ func (p *Server) aiProviderForProviderType( if !provider.Enabled { continue } - if chatprovider.NormalizeProvider(string(provider.Type)) != normalizedProviderType { + matches, err := aiProviderMatchesCanonicalType(provider, normalizedProviderType) + if err != nil { + p.logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + continue + } + if !matches { + continue + } + return provider, nil + } + for _, provider := range providers { + if !provider.Enabled || !aiProviderMatchesRawType(provider, normalizedProviderType) { continue } return provider, nil diff --git a/coderd/x/chatd/model_routing_direct.go b/coderd/x/chatd/model_routing_direct.go index 8173aa75c92ba..52792d3ef5bc8 100644 --- a/coderd/x/chatd/model_routing_direct.go +++ b/coderd/x/chatd/model_routing_direct.go @@ -71,11 +71,15 @@ func (p *Server) resolveDirectModelRouteForProviderType( providerType string, ) (resolvedModelRoute, error) { normalizedProviderType := chatprovider.NormalizeProvider(providerType) - keys, _, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(ctx, ownerID, providerType) + keys, provider, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(ctx, ownerID, providerType) if err != nil { return resolvedModelRoute{}, err } - return newDirectModelRoute(normalizedProviderType, keys), nil + providerHint := normalizedProviderType + if provider != nil { + providerHint = chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, *provider)) + } + return newDirectModelRoute(providerHint, keys), nil } func (p *Server) directProviderHintAndProviderForConfig( @@ -89,5 +93,5 @@ func (p *Server) directProviderHintAndProviderForConfig( if err != nil { return "", nil, err } - return string(provider.Type), &provider, nil + return bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider), &provider, nil } diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 450397416b788..5c9c54ab9ccc4 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -155,12 +155,24 @@ func validateModelConfigAndResolveProvider( } func enabledProviderContainsName( + ctx context.Context, + logger slog.Logger, providers []database.AIProvider, providerName string, ) bool { normalizedProviderName := chatprovider.NormalizeProvider(providerName) for _, provider := range providers { - if chatprovider.NormalizeProvider(string(provider.Type)) == normalizedProviderName { + matches, err := aiProviderMatchesCanonicalType(provider, normalizedProviderName) + if err != nil { + logger.Warn(ctx, "parse AI provider settings", slog.F("provider_id", provider.ID), slog.Error(err)) + continue + } + if matches { + return true + } + } + for _, provider := range providers { + if aiProviderMatchesRawType(provider, normalizedProviderName) { return true } } @@ -506,7 +518,7 @@ func (p *Server) resolveModelConfigAndNormalizedProvider( if !provider.Enabled { return database.ChatModelConfig{}, "", sql.ErrNoRows } - providerName := chatprovider.NormalizeProvider(string(provider.Type)) + providerName := chatprovider.NormalizeProvider(bestEffortCanonicalAIProviderTypeString(ctx, p.logger, provider)) if providerName == "" { return database.ChatModelConfig{}, "", errInvalidModelOverrideMetadata } @@ -523,7 +535,7 @@ func (p *Server) resolveModelConfigAndNormalizedProvider( if err != nil { return database.ChatModelConfig{}, "", err } - if !enabledProviderContainsName(enabledProviders, providerName) { + if !enabledProviderContainsName(ctx, p.logger, enabledProviders, providerName) { return database.ChatModelConfig{}, "", sql.ErrNoRows } return modelConfig, providerName, nil diff --git a/codersdk/aiproviders.go b/codersdk/aiproviders.go index 7b513340bca62..fafe625307718 100644 --- a/codersdk/aiproviders.go +++ b/codersdk/aiproviders.go @@ -47,6 +47,9 @@ const ( AIProviderTypeCopilot AIProviderType = "copilot" ) +// AIProviderDisplayNameBedrock is the default display name for AWS Bedrock providers. +const AIProviderDisplayNameBedrock = "AWS Bedrock" + // AIProviderSettings is the discriminated container for type-specific // provider settings stored in ai_providers.settings. Providers that // need no type-specific configuration (current OpenAI and standard @@ -60,8 +63,9 @@ const ( // concrete settings struct directly. type AIProviderSettings struct { // Bedrock, when set, indicates this provider authenticates against - // AWS Bedrock instead of api.anthropic.com. Only meaningful for - // AIProviderTypeAnthropic. + // AWS Bedrock. Only meaningful for AIProviderTypeBedrock. Legacy rows + // may carry this field with AIProviderTypeAnthropic; callers should use + // CanonicalAIProviderType to normalize before checking the type. Bedrock *AIProviderBedrockSettings `json:"-"` } @@ -86,7 +90,7 @@ func (s AIProviderSettings) MarshalJSON() ([]byte, error) { func (s *AIProviderSettings) UnmarshalJSON(data []byte) error { *s = AIProviderSettings{} trimmed := bytes.TrimSpace(data) - if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) { + if len(trimmed) == 0 || bytes.Equal(trimmed, []byte("null")) || bytes.Equal(trimmed, []byte("{}")) { return nil } var header aiProviderSettingsHeader @@ -158,6 +162,26 @@ func marshalSettings(s settingsTyped) ([]byte, error) { return json.Marshal(m) } +func IsBedrockProviderConfigured(baseURL string, settings *AIProviderBedrockSettings) bool { + var bedrock AIProviderBedrockSettings + if settings != nil { + bedrock = *settings + } + return IsBedrockConfigured(baseURL, bedrock) +} + +// CanonicalAIProviderType returns the runtime provider type for a row +// or request payload. Bedrock has a dedicated provider type, but older +// clients may still send Bedrock settings with type=anthropic. Treat +// that shape as Bedrock at API boundaries while new writes use the +// dedicated type. +func CanonicalAIProviderType(providerType AIProviderType, settings AIProviderSettings) AIProviderType { + if providerType == AIProviderTypeAnthropic && settings.Bedrock != nil { + return AIProviderTypeBedrock + } + return providerType +} + // AIProvider represents an AI provider configuration row as returned // by the API. Each APIKey entry carries the row's ID so callers can // reference it in an UpdateAIProviderRequest; the plaintext value is @@ -226,18 +250,16 @@ func (req CreateAIProviderRequest) Validate() []ValidationError { validations = append(validations, validateAIProviderName(req.Name)...) validations = append(validations, validateRequiredAIProviderBaseURL(req.BaseURL)...) validations = append(validations, validateAIProviderAPIKeys(req.APIKeys)...) - if req.Settings.Bedrock != nil && - req.Type != AIProviderTypeAnthropic && - req.Type != AIProviderTypeBedrock { + if req.Settings.Bedrock != nil && req.Type != AIProviderTypeBedrock { validations = append(validations, ValidationError{ Field: "settings", - Detail: "bedrock settings are only valid for type=anthropic or type=bedrock", + Detail: "bedrock settings are only valid for type=bedrock", }) } - if req.Type == AIProviderTypeBedrock && (req.Settings.Bedrock == nil || !req.Settings.Bedrock.IsConfigured()) { + if req.Type == AIProviderTypeBedrock && !IsBedrockProviderConfigured(req.BaseURL, req.Settings.Bedrock) { validations = append(validations, ValidationError{ Field: "settings", - Detail: "type=bedrock requires bedrock settings", + Detail: "type=bedrock requires bedrock settings or base_url", }) } if req.Type == AIProviderTypeBedrock && len(req.APIKeys) > 0 { @@ -270,6 +292,31 @@ type UpdateAIProviderRequest struct { Settings *AIProviderSettings `json:"settings,omitempty"` } +func (req UpdateAIProviderRequest) MarshalJSON() ([]byte, error) { + var settings any + if req.Settings != nil { + if req.Settings.IsZero() { + settings = struct{}{} + } else { + settings = req.Settings + } + } + type wireUpdateAIProviderRequest struct { + DisplayName *string `json:"display_name,omitempty"` + Enabled *bool `json:"enabled,omitempty"` + BaseURL *string `json:"base_url,omitempty"` + APIKeys *[]AIProviderKeyMutation `json:"api_keys,omitempty"` + Settings any `json:"settings,omitempty"` + } + return json.Marshal(wireUpdateAIProviderRequest{ + DisplayName: req.DisplayName, + Enabled: req.Enabled, + BaseURL: req.BaseURL, + APIKeys: req.APIKeys, + Settings: settings, + }) +} + // AIProviderKeyMutation describes the intended state of a single key // in an UpdateAIProviderRequest. Exactly one of ID or APIKey must be // set: diff --git a/codersdk/aiproviders_test.go b/codersdk/aiproviders_test.go index 97baad6535dda..8ad4a7b571b08 100644 --- a/codersdk/aiproviders_test.go +++ b/codersdk/aiproviders_test.go @@ -80,6 +80,13 @@ func TestAIProviderSettings_Unmarshal(t *testing.T) { require.True(t, s.IsZero()) }) + t.Run("EmptyObjectZeroes", func(t *testing.T) { + t.Parallel() + var s codersdk.AIProviderSettings + require.NoError(t, json.Unmarshal([]byte(`{}`), &s)) + require.True(t, s.IsZero()) + }) + t.Run("BedrockSupportedVersion", func(t *testing.T) { t.Parallel() var s codersdk.AIProviderSettings diff --git a/codersdk/deployment.go b/codersdk/deployment.go index 68d6ae0f7cbc0..fb5f81b674a51 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -4710,7 +4710,7 @@ type AIProviderConfig struct { // BaseURL is the base URL of the upstream provider API. BaseURL string `json:"base_url"` - // Bedrock fields (only applicable when Type == "anthropic"). + // Bedrock fields apply when Type is "anthropic" or "bedrock". BedrockBaseURL string `json:"-"` BedrockRegion string `json:"bedrock_region,omitempty"` // BedrockAccessKeys and BedrockAccessKeySecrets hold one or diff --git a/enterprise/coderd/x/chatd/chatd_test.go b/enterprise/coderd/x/chatd/chatd_test.go index 9587c7e6b3e20..ec43e1092845c 100644 --- a/enterprise/coderd/x/chatd/chatd_test.go +++ b/enterprise/coderd/x/chatd/chatd_test.go @@ -200,6 +200,7 @@ func setOpenAIProviderBaseURL( } _, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ ID: provider.ID, + Type: provider.Type, DisplayName: provider.DisplayName, Enabled: provider.Enabled, BaseUrl: baseURL, diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index acdb0fcbbb006..8f2c4b916a210 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1195,6 +1195,7 @@ func TestAIProviders(t *testing.T) { const newSettings = `{"_type":"bedrock","_version":1,"region":"us-east-1","model":"anthropic.claude-sonnet-4-5-20250929-v1:0","access_key":"AKIA-test","access_key_secret":"test-secret"}` updated, err := crypt.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ ID: provider.ID, + Type: provider.Type, DisplayName: provider.DisplayName, Enabled: provider.Enabled, BaseUrl: provider.BaseUrl, @@ -1211,6 +1212,7 @@ func TestAIProviders(t *testing.T) { provider := insertProvider(t, crypt, ciphers) updated, err := crypt.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ ID: provider.ID, + Type: provider.Type, DisplayName: provider.DisplayName, Enabled: provider.Enabled, BaseUrl: provider.BaseUrl, diff --git a/site/src/api/queries/chats.test.ts b/site/src/api/queries/chats.test.ts index 8b17d17e06393..7ffd73a5efb93 100644 --- a/site/src/api/queries/chats.test.ts +++ b/site/src/api/queries/chats.test.ts @@ -6,6 +6,10 @@ import { ERROR_STATUSES, SUCCESS_STATUSES, } from "#/pages/AgentsPage/components/RightPanel/DebugPanel/debugPanelUtils"; +import { + MockAIProviderBedrock, + MockAIProviderBedrockAsAnthropic, +} from "#/testHelpers/entities"; import { buildOptimisticEditedMessage } from "./chatMessageEdits"; import { addChildToParentInCache, @@ -21,6 +25,7 @@ import { chatDiffContentsKey, chatKey, chatMessagesKey, + chatProviderConfigs, chatSearch, chatsKey, createChat, @@ -52,28 +57,33 @@ import { updateInfiniteChatsCache, } from "./chats"; -vi.mock("#/api/api", () => ({ - API: { - experimental: { - updateChat: vi.fn(), - createChat: vi.fn(), - deleteChatQueuedMessage: vi.fn(), - getChats: vi.fn(), - getChatCostSummary: vi.fn(), - getChatCostUsers: vi.fn(), - createChatMessage: vi.fn(), - editChatMessage: vi.fn(), - interruptChat: vi.fn(), - promoteChatQueuedMessage: vi.fn(), - proposeChatTitle: vi.fn(), - regenerateChatTitle: vi.fn(), - getChatAdvisorConfig: vi.fn(), - updateChatAdvisorConfig: vi.fn(), - getChatACL: vi.fn(), - updateChatACL: vi.fn(), +vi.mock("#/api/api", async (importOriginal) => { + const actual = await importOriginal(); + return { + ...actual, + API: { + experimental: { + updateChat: vi.fn(), + createChat: vi.fn(), + deleteChatQueuedMessage: vi.fn(), + getChats: vi.fn(), + getChatCostSummary: vi.fn(), + getChatCostUsers: vi.fn(), + createChatMessage: vi.fn(), + editChatMessage: vi.fn(), + interruptChat: vi.fn(), + promoteChatQueuedMessage: vi.fn(), + proposeChatTitle: vi.fn(), + regenerateChatTitle: vi.fn(), + getChatAdvisorConfig: vi.fn(), + updateChatAdvisorConfig: vi.fn(), + getChatACL: vi.fn(), + updateChatACL: vi.fn(), + listAIProviders: vi.fn(), + }, }, - }, -})); + }; +}); type InfiniteChatsTestOptions = Parameters[0]; @@ -141,6 +151,22 @@ const createTestQueryClient = (): QueryClient => }, }); +describe("chatProviderConfigs", () => { + it("maps canonical and legacy Bedrock AI providers to Bedrock chat providers", async () => { + vi.mocked(API.experimental.listAIProviders).mockResolvedValue([ + MockAIProviderBedrock, + MockAIProviderBedrockAsAnthropic, + ]); + + const configs = await chatProviderConfigs().queryFn(); + + expect(configs.map((config) => config.provider)).toEqual([ + "bedrock", + "bedrock", + ]); + }); +}); + describe("advisor config query factories", () => { it("builds the advisor config query and delegates to the API", async () => { const advisorConfig: TypesGen.AdvisorConfig = { diff --git a/site/src/api/queries/chats.ts b/site/src/api/queries/chats.ts index 0da5ec219761f..209a888bf1c09 100644 --- a/site/src/api/queries/chats.ts +++ b/site/src/api/queries/chats.ts @@ -12,7 +12,10 @@ import { import type * as TypesGen from "#/api/typesGenerated"; import { type AIProviderType, AIProviderTypes } from "#/api/typesGenerated"; import type { UsePaginatedQueryOptions } from "#/hooks/usePaginatedQuery"; -import { formatProviderLabel } from "#/utils/aiProviders"; +import { + canonicalAIProviderType, + formatProviderLabel, +} from "#/utils/aiProviders"; import { projectEditedConversationIntoCache, reconcileEditedMessageInCache, @@ -1662,20 +1665,23 @@ const chatProviderConfigsKey = ["chat-provider-configs"] as const; const toChatProviderConfig = ( provider: TypesGen.AIProvider, -): TypesGen.ChatProviderConfig => ({ - id: provider.id, - provider: provider.type, - display_name: provider.display_name || provider.type, - enabled: provider.enabled, - has_api_key: provider.api_keys.length > 0, - central_api_key_enabled: true, - allow_user_api_key: true, - allow_central_api_key_fallback: true, - base_url: provider.base_url, - source: "database", - created_at: provider.created_at, - updated_at: provider.updated_at, -}); +): TypesGen.ChatProviderConfig => { + const providerType = canonicalAIProviderType(provider); + return { + id: provider.id, + provider: providerType, + display_name: provider.display_name || providerType, + enabled: provider.enabled, + has_api_key: provider.api_keys.length > 0, + central_api_key_enabled: true, + allow_user_api_key: true, + allow_central_api_key_fallback: true, + base_url: provider.base_url, + source: "database", + created_at: provider.created_at, + updated_at: provider.updated_at, + }; +}; export const chatProviderConfigs = () => ({ queryKey: chatProviderConfigsKey, @@ -1776,6 +1782,9 @@ const normalizeAIProviderType = (provider: string): AIProviderType => { export const createChatProviderConfig = (queryClient: QueryClient) => ({ mutationFn: (req: TypesGen.CreateChatProviderConfigRequest) => { const providerType = normalizeAIProviderType(req.provider); + if (providerType === "bedrock") { + throw new Error("Configure AWS Bedrock in AI settings."); + } const apiKey = req.api_key; return API.experimental.createAIProvider({ type: providerType, diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 2f3f46e6794d1..24c00843be190 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -410,6 +410,12 @@ export interface AIProviderConfig { readonly bedrock_small_fast_model?: string; } +// From codersdk/aiproviders.go +/** + * AIProviderDisplayNameBedrock is the default display name for AWS Bedrock providers. + */ +export const AIProviderDisplayNameBedrock = "AWS Bedrock"; + // From codersdk/aiproviders.go /** * AIProviderKey is a single API key registered on a provider. The diff --git a/site/src/pages/AISettingsPage/ProvidersPage/components/providerFormApiMap.test.ts b/site/src/pages/AISettingsPage/ProvidersPage/components/providerFormApiMap.test.ts index b02e1413dcb4c..4879dbbecd21c 100644 --- a/site/src/pages/AISettingsPage/ProvidersPage/components/providerFormApiMap.test.ts +++ b/site/src/pages/AISettingsPage/ProvidersPage/components/providerFormApiMap.test.ts @@ -3,6 +3,7 @@ import type { AIProvider } from "#/api/typesGenerated"; import { MockAIProviderAnthropic, MockAIProviderBedrock, + MockAIProviderBedrockAsAnthropic, MockAIProviderCopilot, MockAIProviderOpenAI, } from "#/testHelpers/entities"; @@ -148,6 +149,10 @@ describe("isBedrockProvider", () => { expect(isBedrockProvider(provider)).toBe(true); }); + it("recognises a Bedrock provider stored with the Anthropic type", () => { + expect(isBedrockProvider(MockAIProviderBedrockAsAnthropic)).toBe(true); + }); + it("rejects an OpenAI provider", () => { expect(isBedrockProvider(MockAIProviderOpenAI)).toBe(false); }); @@ -187,6 +192,12 @@ describe("getProviderDisplayType", () => { expect(getProviderDisplayType(MockAIProviderBedrock)).toBe("bedrock"); }); + it("returns bedrock for a Bedrock provider stored with the Anthropic type", () => { + expect(getProviderDisplayType(MockAIProviderBedrockAsAnthropic)).toBe( + "bedrock", + ); + }); + it("returns anthropic for a non-Bedrock Anthropic provider", () => { expect(getProviderDisplayType(MockAIProviderAnthropic)).toBe("anthropic"); }); @@ -329,9 +340,9 @@ describe("providerFormValuesToCreate", () => { }); describe("Bedrock", () => { - it('maps Bedrock to a wire `type:"anthropic"`', () => { + it('maps Bedrock to a wire `type:"bedrock"`', () => { const req = providerFormValuesToCreate(baseBedrockFormValues); - expect(req.type).toBe("anthropic"); + expect(req.type).toBe("bedrock"); }); it("derives the region from a canonical AWS URL", () => { @@ -623,13 +634,11 @@ describe("aiProviderToFormValues", () => { }); it("handles a Bedrock provider whose settings are null", () => { - // `isBedrockProvider` will return false, so the provider falls - // through to the anthropic branch. The helper must not throw. const provider: AIProvider = { ...MockAIProviderBedrock, settings: null as unknown as AIProvider["settings"], }; const values = aiProviderToFormValues(provider); - expect(values.type).toBe("anthropic"); + expect(values.type).toBe("bedrock"); }); }); diff --git a/site/src/pages/AISettingsPage/ProvidersPage/components/providerFormApiMap.ts b/site/src/pages/AISettingsPage/ProvidersPage/components/providerFormApiMap.ts index 67eec7e4d913a..f56c9b8596617 100644 --- a/site/src/pages/AISettingsPage/ProvidersPage/components/providerFormApiMap.ts +++ b/site/src/pages/AISettingsPage/ProvidersPage/components/providerFormApiMap.ts @@ -1,12 +1,17 @@ -import type { - AIProvider, - AIProviderBedrockSettings, - AIProviderKeyMutation, - AIProviderSettings, - AIProviderType, - CreateAIProviderRequest, - UpdateAIProviderRequest, +import { + type AIProvider, + type AIProviderBedrockSettings, + AIProviderBedrockSettingsVersion, + type AIProviderKeyMutation, + type AIProviderSettings, + type AIProviderType, + type CreateAIProviderRequest, + type UpdateAIProviderRequest, } from "#/api/typesGenerated"; +import { + AI_PROVIDER_SETTINGS_TYPE_BEDROCK, + isBedrockAIProvider, +} from "#/utils/aiProviders"; import { type ProviderFormValues, parseBedrockRegionFromBaseUrl, @@ -30,12 +35,9 @@ const sanitizeCredential = ( // The generated `AIProviderSettings` interface is empty (the Go side uses // a custom marshaler), so we redeclare the structural wire shape here. -const BEDROCK_SETTINGS_TYPE = "bedrock"; -const BEDROCK_SETTINGS_VERSION = 1; - type BedrockSettingsWire = AIProviderBedrockSettings & { - _type: typeof BEDROCK_SETTINGS_TYPE; - _version: typeof BEDROCK_SETTINGS_VERSION; + _type: typeof AI_PROVIDER_SETTINGS_TYPE_BEDROCK; + _version: typeof AIProviderBedrockSettingsVersion; }; type SettingsWire = AIProviderSettings & @@ -44,23 +46,15 @@ type SettingsWire = AIProviderSettings & _version?: number; }; -// Bedrock providers are identified by the settings discriminator. The -// generated type marks settings as non-null, but Go serializes zero settings -// as JSON `null`. -export const isBedrockProvider = (provider: AIProvider): boolean => { - if (provider.type !== "anthropic" && provider.type !== "bedrock") { - return false; - } - const s = provider.settings as SettingsWire | null; - return s !== null && s._type === BEDROCK_SETTINGS_TYPE; -}; +export const isBedrockProvider = isBedrockAIProvider; export const hasBedrockStoredCredentials = (provider: AIProvider): boolean => { if (!isBedrockProvider(provider)) { return false; } - // Bedrock secrets are write-only. The server only persists Bedrock - // settings if credentials were supplied, so presence implies "on file". + // Bedrock secrets are write-only, so API responses cannot distinguish + // static keys from ambient AWS credentials. Keep existing values masked + // unless the administrator enters a replacement pair. return true; }; @@ -110,8 +104,8 @@ const buildBedrockSettings = ( accessKey: string, accessKeySecret: string, ): BedrockSettingsWire => ({ - _type: BEDROCK_SETTINGS_TYPE, - _version: BEDROCK_SETTINGS_VERSION, + _type: AI_PROVIDER_SETTINGS_TYPE_BEDROCK, + _version: AIProviderBedrockSettingsVersion, ...(region ? { region } : {}), model, small_fast_model: smallFastModel, @@ -143,7 +137,7 @@ export const providerFormValuesToCreate = ( sanitizeCredential(values.accessKeySecret), ); return { - type: "anthropic", + type: "bedrock", ...base, settings: settings as AIProviderSettings, }; diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx index 7d6fce9b0ead1..fff0607620ef9 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ChatModelAdminPanel.stories.tsx @@ -719,13 +719,17 @@ export const ProviderFormBedrockAmbientCredentials: Story = { await expect(apiKeyInput).not.toBeRequired(); await expect(apiKeyInput).toHaveAttribute( "placeholder", - "Enter bearer token", + "Managed in AI settings", ); + await expect(apiKeyInput).toBeDisabled(); await expect( body.findByText( - "Bearer token for Bedrock authentication. Leave empty to use ambient AWS credentials.", + "AWS credentials for Bedrock are managed in AI settings.", ), ).resolves.toBeInTheDocument(); + await expect( + body.findByText("Configure AWS Bedrock in AI settings"), + ).resolves.toBeInTheDocument(); await expect( body.findByText( /Bedrock runtime endpoint\.\s+Use the AWS region for the models this provider should call\./i, @@ -736,24 +740,8 @@ export const ProviderFormBedrockAmbientCredentials: Story = { baseURLInput, "https://bedrock-runtime.us-east-1.amazonaws.com", ); - await waitFor(() => { - expect(createButton).toBeEnabled(); - }); - - await userEvent.click(createButton); - await waitFor(() => { - expect(args.onCreateProvider).toHaveBeenCalledTimes(1); - }); - const createProviderMock = args.onCreateProvider as ReturnType; - const createRequest = createProviderMock.mock.calls[0][0] as Record< - string, - unknown - >; - expect(createRequest).toMatchObject({ - provider: "bedrock", - base_url: "https://bedrock-runtime.us-east-1.amazonaws.com", - }); - expect(createRequest).not.toHaveProperty("api_key"); + await expect(createButton).toBeDisabled(); + await expect(args.onCreateProvider).not.toHaveBeenCalled(); }, }; @@ -785,21 +773,9 @@ export const ProviderFormBedrockBearerToken: Story = { await expect(apiKeyInput).not.toBeRequired(); await expect(apiKeyInput).toHaveValue("••••••••••••••••"); - - await userEvent.click(apiKeyInput); - await userEvent.type(apiKeyInput, "bedrock-bearer-token"); - await waitFor(() => { - expect(saveButton).toBeEnabled(); - }); - await userEvent.click(saveButton); - - await waitFor(() => { - expect(args.onUpdateProvider).toHaveBeenCalledTimes(1); - }); - expect(args.onUpdateProvider).toHaveBeenCalledWith( - "provider-bedrock-bearer", - expect.objectContaining({ api_key: "bedrock-bearer-token" }), - ); + await expect(apiKeyInput).toBeDisabled(); + await expect(saveButton).toBeDisabled(); + await expect(args.onUpdateProvider).not.toHaveBeenCalled(); }, }; @@ -827,26 +803,15 @@ export const ProviderFormBedrockClearBearerToken: Story = { ); const apiKeyInput = await body.findByLabelText(/^API Key$/i); - const clearStoredTokenButton = body.getByRole("button", { - name: /Clear stored token/i, - }); const saveButton = body.getByRole("button", { name: "Save changes" }); await expect(apiKeyInput).toHaveValue("••••••••••••••••"); - await userEvent.click(clearStoredTokenButton); - await waitFor(() => { - expect(apiKeyInput).toHaveValue(""); - expect(saveButton).toBeEnabled(); - }); - await userEvent.click(saveButton); - - await waitFor(() => { - expect(args.onUpdateProvider).toHaveBeenCalledTimes(1); - }); - expect(args.onUpdateProvider).toHaveBeenCalledWith( - "provider-bedrock-clear", - expect.objectContaining({ api_key: "" }), - ); + await expect(apiKeyInput).toBeDisabled(); + expect( + body.queryByRole("button", { name: /Clear stored token/i }), + ).not.toBeInTheDocument(); + await expect(saveButton).toBeDisabled(); + await expect(args.onUpdateProvider).not.toHaveBeenCalled(); }, }; @@ -2092,6 +2057,41 @@ export const ModelFormBedrock: Story = { }, }; +export const ModelListUsesLinkedProviderPresentation: Story = { + args: { + section: "models" as ChatModelAdminSection, + providerConfigsData: [ + createProviderConfig({ + id: "provider-bedrock-linked", + provider: "bedrock", + display_name: "AWS Bedrock", + source: "database", + has_api_key: false, + central_api_key_enabled: true, + }), + ], + modelConfigsData: [ + createModelConfig({ + id: "model-bedrock-linked", + provider: "anthropic", + ai_provider_id: "provider-bedrock-linked", + model: "anthropic.claude-3-5-sonnet", + display_name: "Claude on Bedrock", + }), + ], + }, + play: async ({ canvasElement }) => { + const body = within(canvasElement.ownerDocument.body); + await expect( + await body.findByText("Claude on Bedrock"), + ).toBeInTheDocument(); + await expect( + await body.findByAltText("AWS Bedrock logo"), + ).toBeInTheDocument(); + expect(body.queryByAltText("Anthropic logo")).not.toBeInTheDocument(); + }, +}; + export const ModelPricingWarningInList: Story = { args: { section: "models" as ChatModelAdminSection, diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.tsx index f4738bc9f334c..4350c97b404fb 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ModelsSection.tsx @@ -332,10 +332,15 @@ export const ModelsSection: FC = ({ : `Set as default model: ${modelName}`; const starUnavailable = isUpdating || modelConfig.is_default || !modelConfig.enabled; + const providerKey = modelConfigProviderKey( + modelConfig, + providerStates, + ); const providerState = providerStates.find( - (ps) => - ps.key === modelConfigProviderKey(modelConfig, providerStates), + (ps) => ps.key === providerKey, ); + const displayProvider = + providerState?.provider ?? modelConfig.provider; const duplicateUnavailable = Boolean( providerState && !canManageProviderModels(providerState), ); @@ -355,7 +360,7 @@ export const ModelsSection: FC = ({ className="flex min-w-0 flex-1 cursor-pointer items-center gap-3.5 border-0 bg-transparent p-0 text-left" >
diff --git a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ProviderForm.tsx b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ProviderForm.tsx index fe817705f7001..70c0f5d63b9ce 100644 --- a/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ProviderForm.tsx +++ b/site/src/pages/AgentsPage/components/ChatModelAdminPanel/ProviderForm.tsx @@ -94,24 +94,29 @@ export const ProviderForm: FC = ({ const hasAPIKeyWhitespace = hasTypedAPIKey && effectiveApiKey.trim() !== effectiveApiKey; // Clearing a saved provider-scoped key switches the provider to - // BYOK-only behavior, or ambient AWS credentials for Bedrock. + // BYOK-only behavior. const isClearingAPIKey = providerState.hasManagedAPIKey && apiKeyModified && effectiveApiKey === ""; - const hasPendingAPIKeyChange = hasTypedAPIKey || isClearingAPIKey; - const shouldCreateAPIKey = hasTypedAPIKey; + const hasPendingAPIKeyChange = + !isBedrockProvider && (hasTypedAPIKey || isClearingAPIKey); + const shouldCreateAPIKey = !isBedrockProvider && hasTypedAPIKey; const apiKeyDescription = isBedrockProvider - ? "Bearer token for Bedrock authentication. Leave empty to use ambient AWS credentials." + ? "AWS credentials for Bedrock are managed in AI settings." : "Secret key used to authenticate requests to this provider."; const baseURLDescription = isBedrockProvider ? "Bedrock runtime endpoint. Use the AWS region for the models this provider should call." : "Endpoint used to call this provider."; - const apiKeyPlaceholder = isBedrockProvider ? "Enter bearer token" : "sk-..."; + const apiKeyPlaceholder = isBedrockProvider + ? "Managed in AI settings" + : "sk-..."; const deleteProviderDescription = "Are you sure you want to delete this provider? The provider will be " + "disabled and hidden from new model configuration. Existing model " + "configs that reference it remain saved but cannot run until updated."; const hasNewProviderConfiguration = !providerConfig; + const requiresBedrockAISettings = isBedrockProvider && !providerConfig; + const isDirty = displayName.trim() !== initialValues.displayName || hasPendingAPIKeyChange || @@ -123,6 +128,7 @@ export const ProviderForm: FC = ({ !providerConfigsUnavailable && !isProviderMutationPending && !isAPIKeyEnvManaged && + !requiresBedrockAISettings && isDirty && hasBaseURL && !hasAPIKeyWhitespace && @@ -145,6 +151,7 @@ export const ProviderForm: FC = ({ providerConfigsUnavailable || isProviderMutationPending || isAPIKeyEnvManaged || + requiresBedrockAISettings || !hasBaseURL || hasAPIKeyWhitespace ) { @@ -255,6 +262,24 @@ export const ProviderForm: FC = ({ data-form-type="other" >
+ {requiresBedrockAISettings && ( + + Configure AWS Bedrock in AI settings + + Bedrock providers require AWS region and credential settings. + Create the provider in AI settings, then return here to add + models. + + + + )} = ({ setApiKeyTouched(true); setApiKeyModified(true); }} - disabled={isDisabled} + disabled={isDisabled || isBedrockProvider} /> {hasAPIKeyWhitespace && (

API key must not contain leading or trailing whitespace.

)} - {isBedrockProvider && - providerState.hasManagedAPIKey && - !isDisabled && - (!apiKeyModified || apiKey !== "") && ( -
- -
- )}
diff --git a/site/src/pages/AgentsPage/utils/modelOptions.test.ts b/site/src/pages/AgentsPage/utils/modelOptions.test.ts index 9a6ee41375b20..627b60f0997e9 100644 --- a/site/src/pages/AgentsPage/utils/modelOptions.test.ts +++ b/site/src/pages/AgentsPage/utils/modelOptions.test.ts @@ -378,6 +378,54 @@ describe("getModelOptionsFromConfigs", () => { expect(getModelOptionsFromConfigs(configs, catalog)).toEqual([]); }); + it("includes Bedrock configs without Anthropic availability", () => { + const configs = [ + createConfig({ + id: "config-bedrock", + provider: "bedrock", + model: "anthropic.claude-sonnet-4-20250514-v1:0", + display_name: "Claude Sonnet via Bedrock", + context_limit: 200_000, + }), + ]; + const catalog = createCatalog([ + { + provider: "bedrock", + available: true, + models: [], + }, + ]); + + expect(getModelOptionsFromConfigs(configs, catalog)).toEqual([ + { + id: "config-bedrock", + provider: "bedrock", + model: "anthropic.claude-sonnet-4-20250514-v1:0", + displayName: "Claude Sonnet via Bedrock", + contextLimit: 200_000, + }, + ]); + }); + + it("does not use Anthropic availability for Bedrock configs", () => { + const configs = [ + createConfig({ + id: "config-bedrock", + provider: "bedrock", + model: "anthropic.claude-sonnet-4-20250514-v1:0", + }), + ]; + const catalog = createCatalog([ + { + provider: "anthropic", + available: true, + models: [], + }, + ]); + + expect(getModelOptionsFromConfigs(configs, catalog)).toEqual([]); + }); + it("excludes disabled configs", () => { const configs = [ createConfig({ diff --git a/site/src/testHelpers/entities.ts b/site/src/testHelpers/entities.ts index f85f3d4753899..3c35330f64d5d 100644 --- a/site/src/testHelpers/entities.ts +++ b/site/src/testHelpers/entities.ts @@ -5550,14 +5550,9 @@ export const MockAIProviderAnthropic: TypesGen.AIProvider = { updated_at: "2026-05-14T10:00:00Z", }; -/** - * Bedrock providers come over the wire with `type: "anthropic"` and a - * `settings._type: "bedrock"` discriminator. `isBedrockProvider` and the - * backend (see `coderd/ai_providers.go`) enforce this convention. - */ export const MockAIProviderBedrock: TypesGen.AIProvider = { id: "9c2e3b41-2e9f-4c97-9a4f-2e1a3d8f9f21", - type: "anthropic", + type: "bedrock", name: "bedrock", display_name: "Bedrock", base_url: "https://bedrock-runtime.us-east-2.amazonaws.com", @@ -5574,6 +5569,13 @@ export const MockAIProviderBedrock: TypesGen.AIProvider = { updated_at: "2026-05-14T10:00:00Z", }; +export const MockAIProviderBedrockAsAnthropic: TypesGen.AIProvider = { + ...MockAIProviderBedrock, + id: "da8b7fcb-ec30-4e15-a31d-73dfd87a8501", + type: "anthropic", + name: "anthropic-bedrock", +}; + export const MockAIProviderCopilot: TypesGen.AIProvider = { id: "b3f0d2c8-6a4e-4d11-8c2f-1e9a7c5b4d31", type: "copilot", diff --git a/site/src/utils/aiProviders.ts b/site/src/utils/aiProviders.ts index e91b37cf66a2f..1aadf3f0d47bd 100644 --- a/site/src/utils/aiProviders.ts +++ b/site/src/utils/aiProviders.ts @@ -1,3 +1,31 @@ +import type { AIProvider, AIProviderType } from "#/api/typesGenerated"; + +export const AI_PROVIDER_SETTINGS_TYPE_BEDROCK = "bedrock"; + +type SettingsWire = { + readonly _type?: string; +}; + +export const isBedrockAIProvider = ( + provider: Pick, +): boolean => { + if (provider.type === "bedrock") { + return true; + } + if (provider.type !== "anthropic") { + return false; + } + const settings = provider.settings as SettingsWire | null | undefined; + return ( + settings != null && settings._type === AI_PROVIDER_SETTINGS_TYPE_BEDROCK + ); +}; + +export const canonicalAIProviderType = ( + provider: Pick, +): AIProviderType => + isBedrockAIProvider(provider) ? "bedrock" : provider.type; + export const formatProviderLabel = (provider: string): string => { const normalized = provider.trim().toLowerCase(); switch (normalized) {