From a2b199edc27dc1a3b3c2e418f45deb00ae46f618 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 17:35:57 +0000 Subject: [PATCH 1/8] feat: remove legacy chat provider tables --- coderd/coderdtest/chat.go | 15 +- coderd/database/check_constraint.go | 72 +- coderd/database/dbauthz/dbauthz.go | 107 -- coderd/database/dbauthz/dbauthz_test.go | 97 +- coderd/database/dbgen/dbgen.go | 58 +- coderd/database/dbmetrics/querymetrics.go | 104 -- coderd/database/dbmock/dbmock.go | 193 ---- coderd/database/dump.sql | 61 +- coderd/database/foreign_key_constraint.go | 5 - .../database/legacy_chat_provider_compat.go | 45 + ...00500_ai_providers_legacy_cleanup.down.sql | 36 + .../000500_ai_providers_legacy_cleanup.up.sql | 6 + coderd/database/models.go | 27 - coderd/database/querier.go | 13 - coderd/database/querier_test.go | 155 +-- coderd/database/queries.sql.go | 507 +--------- coderd/database/queries/chatmodelconfigs.sql | 24 +- coderd/database/queries/chatproviders.sql | 102 -- .../database/queries/userchatproviderkeys.sql | 20 - coderd/database/unique_constraint.go | 4 - coderd/exp_chats.go | 948 ++---------------- coderd/exp_chats_test.go | 491 +++------ ...kspaceagents_chat_context_internal_test.go | 23 +- coderd/x/chatd/advisor_internal_test.go | 4 +- coderd/x/chatd/chatd.go | 47 +- coderd/x/chatd/chatd_internal_test.go | 43 +- coderd/x/chatd/chatd_test.go | 145 +-- coderd/x/chatd/configcache.go | 14 +- coderd/x/chatd/configcache_test.go | 55 +- coderd/x/chatd/subagent.go | 14 +- coderd/x/chatd/subagent_internal_test.go | 28 +- coderd/x/chatd/title_override_test.go | 8 +- coderd/x/chatd/turn_summary_internal_test.go | 26 +- coderd/x/gitsync/worker_test.go | 2 +- enterprise/coderd/x/chatd/chatd_test.go | 50 +- enterprise/dbcrypt/cliutil.go | 107 +- enterprise/dbcrypt/dbcrypt.go | 137 --- enterprise/dbcrypt/dbcrypt_internal_test.go | 47 +- 38 files changed, 720 insertions(+), 3120 deletions(-) create mode 100644 coderd/database/legacy_chat_provider_compat.go create mode 100644 coderd/database/migrations/000500_ai_providers_legacy_cleanup.down.sql create mode 100644 coderd/database/migrations/000500_ai_providers_legacy_cleanup.up.sql delete mode 100644 coderd/database/queries/chatproviders.sql delete mode 100644 coderd/database/queries/userchatproviderkeys.sql diff --git a/coderd/coderdtest/chat.go b/coderd/coderdtest/chat.go index f7d994a00d92f..6713374b92958 100644 --- a/coderd/coderdtest/chat.go +++ b/coderd/coderdtest/chat.go @@ -59,10 +59,16 @@ func CreateOpenAICompatChatModelConfig( t.Helper() ctx := testutil.Context(t, testutil.WaitLong) - _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: TestChatProviderOpenAICompat, - APIKey: TestChatProviderAPIKey, - BaseURL: baseURL, + enabled := true + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderType(TestChatProviderOpenAICompat), + Name: "test-" + uuid.NewString(), + BaseURL: baseURL, + Enabled: &enabled, + }) + require.NoError(t, err) + _, err = client.CreateAIProviderKey(ctx, provider.ID, codersdk.CreateAIProviderKeyRequest{ + APIKey: TestChatProviderAPIKey, }) require.NoError(t, err) @@ -70,6 +76,7 @@ func CreateOpenAICompatChatModelConfig( isDefault := true modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: TestChatProviderOpenAICompat, + AIProviderID: &provider.ID, Model: TestChatModelOpenAICompat, ContextLimit: &contextLimit, IsDefault: &isDefault, diff --git a/coderd/database/check_constraint.go b/coderd/database/check_constraint.go index fab86c4bc7ee5..5e14a6f6e3dcf 100644 --- a/coderd/database/check_constraint.go +++ b/coderd/database/check_constraint.go @@ -6,41 +6,39 @@ type CheckConstraint string // CheckConstraint enums. const ( - CheckAiModelPricesCacheReadPriceCheck CheckConstraint = "ai_model_prices_cache_read_price_check" // ai_model_prices - CheckAiModelPricesCacheWritePriceCheck CheckConstraint = "ai_model_prices_cache_write_price_check" // ai_model_prices - CheckAiModelPricesInputPriceCheck CheckConstraint = "ai_model_prices_input_price_check" // ai_model_prices - CheckAiModelPricesOutputPriceCheck CheckConstraint = "ai_model_prices_output_price_check" // ai_model_prices - CheckAiProvidersNameCheck CheckConstraint = "ai_providers_name_check" // ai_providers - CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys - CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs - CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs - CheckChatProvidersProviderCheck CheckConstraint = "chat_providers_provider_check" // chat_providers - CheckValidCredentialPolicy CheckConstraint = "valid_credential_policy" // chat_providers - CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config - CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config - CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config - CheckChatsPinOrderArchivedCheck CheckConstraint = "chats_pin_order_archived_check" // chats - CheckChatsPinOrderParentCheck CheckConstraint = "chats_pin_order_parent_check" // chats - CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users - CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users - CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users - CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users - CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users - CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles - CheckGroupAiBudgetsSpendLimitMicrosCheck CheckConstraint = "group_ai_budgets_spend_limit_micros_check" // group_ai_budgets - CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups - CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs - CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs - CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs - CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs - CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents - CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents - CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds - CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces - CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces - CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks - CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters - CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events - CheckUserAiProviderKeysAPIKeyCheck CheckConstraint = "user_ai_provider_keys_api_key_check" // user_ai_provider_keys - CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys + CheckAiModelPricesCacheReadPriceCheck CheckConstraint = "ai_model_prices_cache_read_price_check" // ai_model_prices + CheckAiModelPricesCacheWritePriceCheck CheckConstraint = "ai_model_prices_cache_write_price_check" // ai_model_prices + CheckAiModelPricesInputPriceCheck CheckConstraint = "ai_model_prices_input_price_check" // ai_model_prices + CheckAiModelPricesOutputPriceCheck CheckConstraint = "ai_model_prices_output_price_check" // ai_model_prices + CheckAiProvidersNameCheck CheckConstraint = "ai_providers_name_check" // ai_providers + CheckAPIKeysAllowListNotEmpty CheckConstraint = "api_keys_allow_list_not_empty" // api_keys + CheckChatModelConfigsAiProviderRequiredWhenActive CheckConstraint = "chat_model_configs_ai_provider_required_when_active" // chat_model_configs + CheckChatModelConfigsCompressionThresholdCheck CheckConstraint = "chat_model_configs_compression_threshold_check" // chat_model_configs + CheckChatModelConfigsContextLimitCheck CheckConstraint = "chat_model_configs_context_limit_check" // chat_model_configs + CheckChatUsageLimitConfigDefaultLimitMicrosCheck CheckConstraint = "chat_usage_limit_config_default_limit_micros_check" // chat_usage_limit_config + CheckChatUsageLimitConfigPeriodCheck CheckConstraint = "chat_usage_limit_config_period_check" // chat_usage_limit_config + CheckChatUsageLimitConfigSingletonCheck CheckConstraint = "chat_usage_limit_config_singleton_check" // chat_usage_limit_config + CheckChatsPinOrderArchivedCheck CheckConstraint = "chats_pin_order_archived_check" // chats + CheckChatsPinOrderParentCheck CheckConstraint = "chats_pin_order_parent_check" // chats + CheckOneTimePasscodeSet CheckConstraint = "one_time_passcode_set" // users + CheckUsersChatSpendLimitMicrosCheck CheckConstraint = "users_chat_spend_limit_micros_check" // users + CheckUsersEmailNotEmpty CheckConstraint = "users_email_not_empty" // users + CheckUsersServiceAccountLoginType CheckConstraint = "users_service_account_login_type" // users + CheckUsersUsernameMinLength CheckConstraint = "users_username_min_length" // users + CheckOrganizationIDNotZero CheckConstraint = "organization_id_not_zero" // custom_roles + CheckGroupAiBudgetsSpendLimitMicrosCheck CheckConstraint = "group_ai_budgets_spend_limit_micros_check" // group_ai_budgets + CheckGroupsChatSpendLimitMicrosCheck CheckConstraint = "groups_chat_spend_limit_micros_check" // groups + CheckMcpServerConfigsAuthTypeCheck CheckConstraint = "mcp_server_configs_auth_type_check" // mcp_server_configs + CheckMcpServerConfigsAvailabilityCheck CheckConstraint = "mcp_server_configs_availability_check" // mcp_server_configs + CheckMcpServerConfigsTransportCheck CheckConstraint = "mcp_server_configs_transport_check" // mcp_server_configs + CheckMaxProvisionerLogsLength CheckConstraint = "max_provisioner_logs_length" // provisioner_jobs + CheckMaxLogsLength CheckConstraint = "max_logs_length" // workspace_agents + CheckSubsystemsNotNone CheckConstraint = "subsystems_not_none" // workspace_agents + CheckWorkspaceBuildsDeadlineBelowMaxDeadline CheckConstraint = "workspace_builds_deadline_below_max_deadline" // workspace_builds + CheckGroupAclIsObject CheckConstraint = "group_acl_is_object" // workspaces + CheckUserAclIsObject CheckConstraint = "user_acl_is_object" // workspaces + CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks + CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters + CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events + CheckUserAiProviderKeysAPIKeyCheck CheckConstraint = "user_ai_provider_keys_api_key_check" // user_ai_provider_keys ) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 0cc8e9561a372..1bf318b2143b9 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1951,13 +1951,6 @@ func (q *querier) DeleteChatModelConfigsByProvider(ctx context.Context, provider return q.db.DeleteChatModelConfigsByProvider(ctx, provider) } -func (q *querier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return err - } - return q.db.DeleteChatProviderByID(ctx, id) -} - func (q *querier) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error { chat, err := q.db.GetChatByID(ctx, arg.ChatID) if err != nil { @@ -2283,17 +2276,6 @@ func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg dat return q.db.DeleteUserChatCompactionThreshold(ctx, arg) } -func (q *querier) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error { - u, err := q.db.GetUserByID(ctx, arg.UserID) - if err != nil { - return err - } - if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { - return err - } - return q.db.DeleteUserChatProviderKey(ctx, arg) -} - func (q *querier) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) { obj := rbac.ResourceUserSecret.WithOwner(arg.UserID.String()) if err := q.authorizeContext(ctx, policy.ActionDelete, obj); err != nil { @@ -3041,41 +3023,6 @@ func (q *querier) GetChatPlanModeInstructions(ctx context.Context) (string, erro return q.db.GetChatPlanModeInstructions(ctx) } -func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.GetChatProviderByID(ctx, id) -} - -func (q *querier) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.GetChatProviderByIDForUpdate(ctx, id) -} - -func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.GetChatProviderByProvider(ctx, provider) -} - -func (q *querier) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.GetChatProviderByProviderForUpdate(ctx, provider) -} - -func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { - return nil, err - } - return q.db.GetChatProviders(ctx) -} - func (q *querier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) { _, err := q.GetChatByID(ctx, chatID) if err != nil { @@ -3312,13 +3259,6 @@ func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.Ch return q.db.GetEnabledChatModelConfigs(ctx) } -func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { - return nil, err - } - return q.db.GetEnabledChatProviders(ctx) -} - func (q *querier) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { return nil, err @@ -4561,17 +4501,6 @@ func (q *querier) GetUserChatPersonalModelOverride(ctx context.Context, arg data return q.db.GetUserChatPersonalModelOverride(ctx, arg) } -func (q *querier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { - u, err := q.db.GetUserByID(ctx, userID) - if err != nil { - return nil, err - } - if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil { - return nil, err - } - return q.db.GetUserChatProviderKeys(ctx, userID) -} - func (q *querier) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.UserID.String())); err != nil { return 0, err @@ -5417,13 +5346,6 @@ func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.Insert return q.db.InsertChatModelConfig(ctx, arg) } -func (q *querier) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.InsertChatProvider(ctx, arg) -} - func (q *querier) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) { chat, err := q.db.GetChatByID(ctx, arg.ChatID) if err != nil { @@ -6586,13 +6508,6 @@ func (q *querier) UpdateChatPlanModeByID(ctx context.Context, arg database.Updat return q.db.UpdateChatPlanModeByID(ctx, arg) } -func (q *querier) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { - return database.ChatProvider{}, err - } - return q.db.UpdateChatProvider(ctx, arg) -} - func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) { // UpdateChatStatus is used by the chat processor to change chat status. // It should be called with system context. @@ -7266,17 +7181,6 @@ func (q *querier) UpdateUserChatCustomPrompt(ctx context.Context, arg database.U return q.db.UpdateUserChatCustomPrompt(ctx, arg) } -func (q *querier) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - u, err := q.db.GetUserByID(ctx, arg.UserID) - if err != nil { - return database.UserChatProviderKey{}, err - } - if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { - return database.UserChatProviderKey{}, err - } - return q.db.UpdateUserChatProviderKey(ctx, arg) -} - func (q *querier) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) { user, err := q.db.GetUserByID(ctx, arg.UserID) if err != nil { @@ -8200,17 +8104,6 @@ func (q *querier) UpsertUserChatPersonalModelOverride(ctx context.Context, arg d return q.db.UpsertUserChatPersonalModelOverride(ctx, arg) } -func (q *querier) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - u, err := q.db.GetUserByID(ctx, arg.UserID) - if err != nil { - return database.UserChatProviderKey{}, err - } - if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { - return database.UserChatProviderKey{}, err - } - return q.db.UpsertUserChatProviderKey(ctx, arg) -} - func (q *querier) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 3a537ac78a382..6b8869e1387da 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -455,11 +455,7 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().DeleteChatModelConfigsByProvider(gomock.Any(), providerName).Return(nil).AnyTimes() check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) })) - s.Run("DeleteChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { - id := uuid.New() - dbm.EXPECT().DeleteChatProviderByID(gomock.Any(), id).Return(nil).AnyTimes() - check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate) - })) + s.Run("DeleteChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) args := database.DeleteChatQueuedMessageParams{ID: 123, ChatID: chat.ID} @@ -852,34 +848,7 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB}) })) - s.Run("GetChatProviderByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - provider := testutil.Fake(s.T(), faker, database.ChatProvider{}) - dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes() - check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider) - })) - s.Run("GetChatProviderByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - provider := testutil.Fake(s.T(), faker, database.ChatProvider{}) - dbm.EXPECT().GetChatProviderByIDForUpdate(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes() - check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider) - })) - s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - providerName := "test-provider" - provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName}) - dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes() - check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider) - })) - s.Run("GetChatProviderByProviderForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - providerName := "test-provider" - provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName}) - dbm.EXPECT().GetChatProviderByProviderForUpdate(gomock.Any(), providerName).Return(provider, nil).AnyTimes() - check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider) - })) - s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - providerA := testutil.Fake(s.T(), faker, database.ChatProvider{}) - providerB := testutil.Fake(s.T(), faker, database.ChatProvider{}) - dbm.EXPECT().GetChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes() - check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB}) - })) + s.Run("GetChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { params := database.GetChatsParams{} dbm.EXPECT().GetAuthorizedChats(gomock.Any(), params, gomock.Any()).Return([]database.GetChatsRow{}, nil).AnyTimes() @@ -978,12 +947,7 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{configA, configB}, nil).AnyTimes() check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatModelConfig{configA, configB}) })) - s.Run("GetEnabledChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - providerA := testutil.Fake(s.T(), faker, database.ChatProvider{}) - providerB := testutil.Fake(s.T(), faker, database.ChatProvider{}) - dbm.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{providerA, providerB}, nil).AnyTimes() - check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns([]database.ChatProvider{providerA, providerB}) - })) + s.Run("GetStaleChats", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { threshold := dbtime.Now() chats := []database.Chat{testutil.Fake(s.T(), faker, database.Chat{})} @@ -1032,17 +996,7 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().InsertChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes() check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config) })) - s.Run("InsertChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - arg := database.InsertChatProviderParams{ - Provider: "test-provider", - DisplayName: "Test Provider", - APIKey: "test-api-key", - Enabled: true, - } - provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: arg.Provider, DisplayName: arg.DisplayName, APIKey: arg.APIKey, Enabled: arg.Enabled}) - dbm.EXPECT().InsertChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes() - check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider) - })) + s.Run("PopNextQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) qm := testutil.Fake(s.T(), faker, database.ChatQueuedMessage{}) @@ -1156,17 +1110,7 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpdateChatModelConfig(gomock.Any(), arg).Return(config, nil).AnyTimes() check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(config) })) - s.Run("UpdateChatProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - provider := testutil.Fake(s.T(), faker, database.ChatProvider{}) - arg := database.UpdateChatProviderParams{ - ID: provider.ID, - DisplayName: "Updated Provider", - APIKey: "updated-api-key", - Enabled: true, - } - dbm.EXPECT().UpdateChatProvider(gomock.Any(), arg).Return(provider, nil).AnyTimes() - check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(provider) - })) + s.Run("UpdateChatPinOrder", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := database.UpdateChatPinOrderParams{ @@ -2774,36 +2718,7 @@ func (s *MethodTestSuite) TestUser() { dbm.EXPECT().GetUserChatCustomPrompt(gomock.Any(), u.ID).Return("my custom prompt", nil).AnyTimes() check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns("my custom prompt") })) - s.Run("GetUserChatProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - u := testutil.Fake(s.T(), faker, database.User{}) - key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID}) - dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() - dbm.EXPECT().GetUserChatProviderKeys(gomock.Any(), u.ID).Return([]database.UserChatProviderKey{key}, nil).AnyTimes() - check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserChatProviderKey{key}) - })) - s.Run("DeleteUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - u := testutil.Fake(s.T(), faker, database.User{}) - arg := database.DeleteUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New()} - dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() - dbm.EXPECT().DeleteUserChatProviderKey(gomock.Any(), arg).Return(nil).AnyTimes() - check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns() - })) - s.Run("UpdateUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - u := testutil.Fake(s.T(), faker, database.User{}) - arg := database.UpdateUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "updated-api-key"} - key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey}) - dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() - dbm.EXPECT().UpdateUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() - check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key) - })) - s.Run("UpsertUserChatProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { - u := testutil.Fake(s.T(), faker, database.User{}) - arg := database.UpsertUserChatProviderKeyParams{UserID: u.ID, ChatProviderID: uuid.New(), APIKey: "upserted-api-key"} - key := testutil.Fake(s.T(), faker, database.UserChatProviderKey{UserID: u.ID, ChatProviderID: arg.ChatProviderID, APIKey: arg.APIKey}) - dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() - dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() - check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key) - })) + s.Run("GetUserAIProviderKeyByProviderID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { u := testutil.Fake(s.T(), faker, database.User{}) arg := database.GetUserAIProviderKeyByProviderIDParams{UserID: u.ID, AiProviderID: uuid.New()} diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index b134adecfbe7e..b19ec5f6713b8 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -149,8 +149,29 @@ const ( func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelConfig, munge ...func(*database.InsertChatModelConfigParams)) database.ChatModelConfig { t.Helper() + providerName := takeFirst(seed.Provider, "openai") + aiProviderID := seed.AiProviderID + if !aiProviderID.Valid { + providers, err := db.GetAIProviders(genCtx, database.GetAIProvidersParams{IncludeDisabled: true}) + require.NoError(t, err, "get ai providers") + var provider database.AIProvider + for _, candidate := range providers { + if candidate.Type != database.AIProviderType(providerName) { + continue + } + if provider.ID == uuid.Nil || candidate.CreatedAt.After(provider.CreatedAt) { + provider = candidate + } + } + if provider.ID == uuid.Nil { + provider = AIProvider(t, db, database.AIProvider{ + Type: database.AIProviderType(providerName), + }) + } + aiProviderID = uuid.NullUUID{UUID: provider.ID, Valid: true} + } params := database.InsertChatModelConfigParams{ - Provider: takeFirst(seed.Provider, "openai"), + Provider: providerName, Model: takeFirst(seed.Model, "gpt-4o-mini"), DisplayName: takeFirst(seed.DisplayName, "Test Model"), CreatedBy: seed.CreatedBy, @@ -160,7 +181,7 @@ func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelCon ContextLimit: takeFirst(seed.ContextLimit, defaultChatModelContextLimit), CompressionThreshold: takeFirst(seed.CompressionThreshold, defaultChatModelCompressionThreshold), Options: takeFirstSlice(seed.Options, json.RawMessage(`{}`)), - AiProviderID: seed.AiProviderID, + AiProviderID: aiProviderID, } for _, fn := range munge { fn(¶ms) @@ -243,9 +264,36 @@ func ChatProvider(t testing.TB, db database.Store, seed database.ChatProvider, m for _, fn := range munge { fn(¶ms) } - provider, err := db.InsertChatProvider(genCtx, params) - require.NoError(t, err, "insert chat provider") - return provider + provider := AIProvider(t, db, database.AIProvider{ + Type: database.AIProviderType(params.Provider), + Name: "test-" + strings.ToLower(strings.ReplaceAll(uuid.NewString(), "_", "-")), + DisplayName: sql.NullString{String: params.DisplayName, Valid: params.DisplayName != ""}, + BaseUrl: params.BaseUrl, + }, func(p *database.InsertAIProviderParams) { + p.Enabled = params.Enabled + }) + if params.APIKey != "" { + AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, + APIKey: params.APIKey, + ApiKeyKeyID: params.ApiKeyKeyID, + }) + } + return database.ChatProvider{ + ID: provider.ID, + Provider: params.Provider, + DisplayName: params.DisplayName, + APIKey: params.APIKey, + BaseUrl: params.BaseUrl, + ApiKeyKeyID: params.ApiKeyKeyID, + CreatedBy: params.CreatedBy, + Enabled: params.Enabled, + CentralApiKeyEnabled: params.CentralApiKeyEnabled, + AllowUserApiKey: params.AllowUserApiKey, + AllowCentralApiKeyFallback: params.AllowCentralApiKeyFallback, + CreatedAt: provider.CreatedAt, + UpdatedAt: provider.UpdatedAt, + } } func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerConfig) database.MCPServerConfig { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index c405bfcb65859..ef11552f42484 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -473,14 +473,6 @@ func (m queryMetricsStore) DeleteChatModelConfigsByProvider(ctx context.Context, return r0 } -func (m queryMetricsStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error { - start := time.Now() - r0 := m.s.DeleteChatProviderByID(ctx, id) - m.queryLatencies.WithLabelValues("DeleteChatProviderByID").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteChatProviderByID").Inc() - return r0 -} - func (m queryMetricsStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error { start := time.Now() r0 := m.s.DeleteChatQueuedMessage(ctx, arg) @@ -801,14 +793,6 @@ func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context return r0 } -func (m queryMetricsStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error { - start := time.Now() - r0 := m.s.DeleteUserChatProviderKey(ctx, arg) - m.queryLatencies.WithLabelValues("DeleteUserChatProviderKey").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserChatProviderKey").Inc() - return r0 -} - func (m queryMetricsStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) { start := time.Now() r0, r1 := m.s.DeleteUserSecretByUserIDAndName(ctx, arg) @@ -1505,46 +1489,6 @@ func (m queryMetricsStore) GetChatPlanModeInstructions(ctx context.Context) (str return r0, r1 } -func (m queryMetricsStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviderByID(ctx, id) - m.queryLatencies.WithLabelValues("GetChatProviderByID").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByID").Inc() - return r0, r1 -} - -func (m queryMetricsStore) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviderByIDForUpdate(ctx, id) - m.queryLatencies.WithLabelValues("GetChatProviderByIDForUpdate").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByIDForUpdate").Inc() - return r0, r1 -} - -func (m queryMetricsStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviderByProvider(ctx, provider) - m.queryLatencies.WithLabelValues("GetChatProviderByProvider").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProvider").Inc() - return r0, r1 -} - -func (m queryMetricsStore) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviderByProviderForUpdate(ctx, provider) - m.queryLatencies.WithLabelValues("GetChatProviderByProviderForUpdate").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviderByProviderForUpdate").Inc() - return r0, r1 -} - -func (m queryMetricsStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetChatProviders(ctx) - m.queryLatencies.WithLabelValues("GetChatProviders").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatProviders").Inc() - return r0, r1 -} - func (m queryMetricsStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) { start := time.Now() r0, r1 := m.s.GetChatQueuedMessages(ctx, chatID) @@ -1793,14 +1737,6 @@ func (m queryMetricsStore) GetEnabledChatModelConfigs(ctx context.Context) ([]da return r0, r1 } -func (m queryMetricsStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.GetEnabledChatProviders(ctx) - m.queryLatencies.WithLabelValues("GetEnabledChatProviders").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetEnabledChatProviders").Inc() - return r0, r1 -} - func (m queryMetricsStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { start := time.Now() r0, r1 := m.s.GetEnabledMCPServerConfigs(ctx) @@ -2985,14 +2921,6 @@ func (m queryMetricsStore) GetUserChatPersonalModelOverride(ctx context.Context, return r0, r1 } -func (m queryMetricsStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { - start := time.Now() - r0, r1 := m.s.GetUserChatProviderKeys(ctx, userID) - m.queryLatencies.WithLabelValues("GetUserChatProviderKeys").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserChatProviderKeys").Inc() - return r0, r1 -} - func (m queryMetricsStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) { start := time.Now() r0, r1 := m.s.GetUserChatSpendInPeriod(ctx, arg) @@ -3769,14 +3697,6 @@ func (m queryMetricsStore) InsertChatModelConfig(ctx context.Context, arg databa return r0, r1 } -func (m queryMetricsStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.InsertChatProvider(ctx, arg) - m.queryLatencies.WithLabelValues("InsertChatProvider").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatProvider").Inc() - return r0, r1 -} - func (m queryMetricsStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) { start := time.Now() r0, r1 := m.s.InsertChatQueuedMessage(ctx, arg) @@ -4785,14 +4705,6 @@ func (m queryMetricsStore) UpdateChatPlanModeByID(ctx context.Context, arg datab return r0, r1 } -func (m queryMetricsStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) { - start := time.Now() - r0, r1 := m.s.UpdateChatProvider(ctx, arg) - m.queryLatencies.WithLabelValues("UpdateChatProvider").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatProvider").Inc() - return r0, r1 -} - func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) { start := time.Now() r0, r1 := m.s.UpdateChatStatus(ctx, arg) @@ -5225,14 +5137,6 @@ func (m queryMetricsStore) UpdateUserChatCustomPrompt(ctx context.Context, arg d return r0, r1 } -func (m queryMetricsStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - start := time.Now() - r0, r1 := m.s.UpdateUserChatProviderKey(ctx, arg) - m.queryLatencies.WithLabelValues("UpdateUserChatProviderKey").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserChatProviderKey").Inc() - return r0, r1 -} - func (m queryMetricsStore) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) { start := time.Now() r0, r1 := m.s.UpdateUserCodeDiffDisplayMode(ctx, arg) @@ -6001,14 +5905,6 @@ func (m queryMetricsStore) UpsertUserChatPersonalModelOverride(ctx context.Conte return r0 } -func (m queryMetricsStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - start := time.Now() - r0, r1 := m.s.UpsertUserChatProviderKey(ctx, arg) - m.queryLatencies.WithLabelValues("UpsertUserChatProviderKey").Observe(time.Since(start).Seconds()) - m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserChatProviderKey").Inc() - return r0, r1 -} - func (m queryMetricsStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { start := time.Now() r0 := m.s.UpsertWebpushVAPIDKeys(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index ec926ce4aa79e..07fae8e042d3c 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -774,20 +774,6 @@ func (mr *MockStoreMockRecorder) DeleteChatModelConfigsByProvider(ctx, provider return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatModelConfigsByProvider", reflect.TypeOf((*MockStore)(nil).DeleteChatModelConfigsByProvider), ctx, provider) } -// DeleteChatProviderByID mocks base method. -func (m *MockStore) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteChatProviderByID", ctx, id) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteChatProviderByID indicates an expected call of DeleteChatProviderByID. -func (mr *MockStoreMockRecorder) DeleteChatProviderByID(ctx, id any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteChatProviderByID", reflect.TypeOf((*MockStore)(nil).DeleteChatProviderByID), ctx, id) -} - // DeleteChatQueuedMessage mocks base method. func (m *MockStore) DeleteChatQueuedMessage(ctx context.Context, arg database.DeleteChatQueuedMessageParams) error { m.ctrl.T.Helper() @@ -1362,20 +1348,6 @@ func (mr *MockStoreMockRecorder) DeleteUserChatCompactionThreshold(ctx, arg any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatCompactionThreshold", reflect.TypeOf((*MockStore)(nil).DeleteUserChatCompactionThreshold), ctx, arg) } -// DeleteUserChatProviderKey mocks base method. -func (m *MockStore) DeleteUserChatProviderKey(ctx context.Context, arg database.DeleteUserChatProviderKeyParams) error { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteUserChatProviderKey", ctx, arg) - ret0, _ := ret[0].(error) - return ret0 -} - -// DeleteUserChatProviderKey indicates an expected call of DeleteUserChatProviderKey. -func (mr *MockStoreMockRecorder) DeleteUserChatProviderKey(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserChatProviderKey), ctx, arg) -} - // DeleteUserSecretByUserIDAndName mocks base method. func (m *MockStore) DeleteUserSecretByUserIDAndName(ctx context.Context, arg database.DeleteUserSecretByUserIDAndNameParams) (database.UserSecret, error) { m.ctrl.T.Helper() @@ -2775,81 +2747,6 @@ func (mr *MockStoreMockRecorder) GetChatPlanModeInstructions(ctx any) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatPlanModeInstructions", reflect.TypeOf((*MockStore)(nil).GetChatPlanModeInstructions), ctx) } -// GetChatProviderByID mocks base method. -func (m *MockStore) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviderByID", ctx, id) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviderByID indicates an expected call of GetChatProviderByID. -func (mr *MockStoreMockRecorder) GetChatProviderByID(ctx, id any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByID", reflect.TypeOf((*MockStore)(nil).GetChatProviderByID), ctx, id) -} - -// GetChatProviderByIDForUpdate mocks base method. -func (m *MockStore) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviderByIDForUpdate", ctx, id) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviderByIDForUpdate indicates an expected call of GetChatProviderByIDForUpdate. -func (mr *MockStoreMockRecorder) GetChatProviderByIDForUpdate(ctx, id any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatProviderByIDForUpdate), ctx, id) -} - -// GetChatProviderByProvider mocks base method. -func (m *MockStore) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviderByProvider", ctx, provider) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviderByProvider indicates an expected call of GetChatProviderByProvider. -func (mr *MockStoreMockRecorder) GetChatProviderByProvider(ctx, provider any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProvider", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProvider), ctx, provider) -} - -// GetChatProviderByProviderForUpdate mocks base method. -func (m *MockStore) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviderByProviderForUpdate", ctx, provider) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviderByProviderForUpdate indicates an expected call of GetChatProviderByProviderForUpdate. -func (mr *MockStoreMockRecorder) GetChatProviderByProviderForUpdate(ctx, provider any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviderByProviderForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatProviderByProviderForUpdate), ctx, provider) -} - -// GetChatProviders mocks base method. -func (m *MockStore) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetChatProviders", ctx) - ret0, _ := ret[0].([]database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetChatProviders indicates an expected call of GetChatProviders. -func (mr *MockStoreMockRecorder) GetChatProviders(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatProviders", reflect.TypeOf((*MockStore)(nil).GetChatProviders), ctx) -} - // GetChatQueuedMessages mocks base method. func (m *MockStore) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]database.ChatQueuedMessage, error) { m.ctrl.T.Helper() @@ -3315,21 +3212,6 @@ func (mr *MockStoreMockRecorder) GetEnabledChatModelConfigs(ctx any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatModelConfigs", reflect.TypeOf((*MockStore)(nil).GetEnabledChatModelConfigs), ctx) } -// GetEnabledChatProviders mocks base method. -func (m *MockStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetEnabledChatProviders", ctx) - ret0, _ := ret[0].([]database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetEnabledChatProviders indicates an expected call of GetEnabledChatProviders. -func (mr *MockStoreMockRecorder) GetEnabledChatProviders(ctx any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetEnabledChatProviders", reflect.TypeOf((*MockStore)(nil).GetEnabledChatProviders), ctx) -} - // GetEnabledMCPServerConfigs mocks base method. func (m *MockStore) GetEnabledMCPServerConfigs(ctx context.Context) ([]database.MCPServerConfig, error) { m.ctrl.T.Helper() @@ -5580,21 +5462,6 @@ func (mr *MockStoreMockRecorder) GetUserChatPersonalModelOverride(ctx, arg any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatPersonalModelOverride", reflect.TypeOf((*MockStore)(nil).GetUserChatPersonalModelOverride), ctx, arg) } -// GetUserChatProviderKeys mocks base method. -func (m *MockStore) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetUserChatProviderKeys", ctx, userID) - ret0, _ := ret[0].([]database.UserChatProviderKey) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// GetUserChatProviderKeys indicates an expected call of GetUserChatProviderKeys. -func (mr *MockStoreMockRecorder) GetUserChatProviderKeys(ctx, userID any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserChatProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserChatProviderKeys), ctx, userID) -} - // GetUserChatSpendInPeriod mocks base method. func (m *MockStore) GetUserChatSpendInPeriod(ctx context.Context, arg database.GetUserChatSpendInPeriodParams) (int64, error) { m.ctrl.T.Helper() @@ -7064,21 +6931,6 @@ func (mr *MockStoreMockRecorder) InsertChatModelConfig(ctx, arg any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatModelConfig", reflect.TypeOf((*MockStore)(nil).InsertChatModelConfig), ctx, arg) } -// InsertChatProvider mocks base method. -func (m *MockStore) InsertChatProvider(ctx context.Context, arg database.InsertChatProviderParams) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "InsertChatProvider", ctx, arg) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// InsertChatProvider indicates an expected call of InsertChatProvider. -func (mr *MockStoreMockRecorder) InsertChatProvider(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatProvider", reflect.TypeOf((*MockStore)(nil).InsertChatProvider), ctx, arg) -} - // InsertChatQueuedMessage mocks base method. func (m *MockStore) InsertChatQueuedMessage(ctx context.Context, arg database.InsertChatQueuedMessageParams) (database.ChatQueuedMessage, error) { m.ctrl.T.Helper() @@ -9043,21 +8895,6 @@ func (mr *MockStoreMockRecorder) UpdateChatPlanModeByID(ctx, arg any) *gomock.Ca return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatPlanModeByID", reflect.TypeOf((*MockStore)(nil).UpdateChatPlanModeByID), ctx, arg) } -// UpdateChatProvider mocks base method. -func (m *MockStore) UpdateChatProvider(ctx context.Context, arg database.UpdateChatProviderParams) (database.ChatProvider, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateChatProvider", ctx, arg) - ret0, _ := ret[0].(database.ChatProvider) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateChatProvider indicates an expected call of UpdateChatProvider. -func (mr *MockStoreMockRecorder) UpdateChatProvider(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatProvider", reflect.TypeOf((*MockStore)(nil).UpdateChatProvider), ctx, arg) -} - // UpdateChatStatus mocks base method. func (m *MockStore) UpdateChatStatus(ctx context.Context, arg database.UpdateChatStatusParams) (database.Chat, error) { m.ctrl.T.Helper() @@ -9844,21 +9681,6 @@ func (mr *MockStoreMockRecorder) UpdateUserChatCustomPrompt(ctx, arg any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatCustomPrompt", reflect.TypeOf((*MockStore)(nil).UpdateUserChatCustomPrompt), ctx, arg) } -// UpdateUserChatProviderKey mocks base method. -func (m *MockStore) UpdateUserChatProviderKey(ctx context.Context, arg database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpdateUserChatProviderKey", ctx, arg) - ret0, _ := ret[0].(database.UserChatProviderKey) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpdateUserChatProviderKey indicates an expected call of UpdateUserChatProviderKey. -func (mr *MockStoreMockRecorder) UpdateUserChatProviderKey(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserChatProviderKey), ctx, arg) -} - // UpdateUserCodeDiffDisplayMode mocks base method. func (m *MockStore) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg database.UpdateUserCodeDiffDisplayModeParams) (string, error) { m.ctrl.T.Helper() @@ -11240,21 +11062,6 @@ func (mr *MockStoreMockRecorder) UpsertUserChatPersonalModelOverride(ctx, arg an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatPersonalModelOverride", reflect.TypeOf((*MockStore)(nil).UpsertUserChatPersonalModelOverride), ctx, arg) } -// UpsertUserChatProviderKey mocks base method. -func (m *MockStore) UpsertUserChatProviderKey(ctx context.Context, arg database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "UpsertUserChatProviderKey", ctx, arg) - ret0, _ := ret[0].(database.UserChatProviderKey) - ret1, _ := ret[1].(error) - return ret0, ret1 -} - -// UpsertUserChatProviderKey indicates an expected call of UpsertUserChatProviderKey. -func (mr *MockStoreMockRecorder) UpsertUserChatProviderKey(ctx, arg any) *gomock.Call { - mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserChatProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserChatProviderKey), ctx, arg) -} - // UpsertWebpushVAPIDKeys mocks base method. func (m *MockStore) UpsertWebpushVAPIDKeys(ctx context.Context, arg database.UpsertWebpushVAPIDKeysParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 601a44bdca2cf..624f2d69197a6 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1472,30 +1472,11 @@ CREATE TABLE chat_model_configs ( compression_threshold integer NOT NULL, options jsonb DEFAULT '{}'::jsonb NOT NULL, ai_provider_id uuid, + CONSTRAINT chat_model_configs_ai_provider_required_when_active CHECK (((deleted = true) OR (ai_provider_id IS NOT NULL))), CONSTRAINT chat_model_configs_compression_threshold_check CHECK (((compression_threshold >= 0) AND (compression_threshold <= 100))), CONSTRAINT chat_model_configs_context_limit_check CHECK ((context_limit > 0)) ); -CREATE TABLE chat_providers ( - id uuid DEFAULT gen_random_uuid() NOT NULL, - provider text NOT NULL, - display_name text DEFAULT ''::text NOT NULL, - api_key text DEFAULT ''::text NOT NULL, - api_key_key_id text, - created_by uuid, - enabled boolean DEFAULT true NOT NULL, - created_at timestamp with time zone DEFAULT now() NOT NULL, - updated_at timestamp with time zone DEFAULT now() NOT NULL, - base_url text DEFAULT ''::text NOT NULL, - central_api_key_enabled boolean DEFAULT true NOT NULL, - allow_user_api_key boolean DEFAULT false NOT NULL, - allow_central_api_key_fallback boolean DEFAULT false NOT NULL, - CONSTRAINT chat_providers_provider_check CHECK ((provider = ANY (ARRAY['anthropic'::text, 'azure'::text, 'bedrock'::text, 'google'::text, 'openai'::text, 'openai-compat'::text, 'openrouter'::text, 'vercel'::text]))), - CONSTRAINT valid_credential_policy CHECK (((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key)))) -); - -COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted'; - CREATE TABLE chat_queued_messages ( id bigint NOT NULL, chat_id uuid NOT NULL, @@ -2973,17 +2954,6 @@ COMMENT ON COLUMN user_ai_provider_keys.api_key IS 'User-owned API key used to a COMMENT ON COLUMN user_ai_provider_keys.api_key_key_id IS 'The ID of the key used to encrypt the user-owned provider API key. If this is NULL, the API key is not encrypted.'; -CREATE TABLE user_chat_provider_keys ( - id uuid DEFAULT gen_random_uuid() NOT NULL, - user_id uuid NOT NULL, - chat_provider_id uuid NOT NULL, - api_key text NOT NULL, - api_key_key_id text, - created_at timestamp with time zone DEFAULT now() NOT NULL, - updated_at timestamp with time zone DEFAULT now() NOT NULL, - CONSTRAINT user_chat_provider_keys_api_key_check CHECK ((api_key <> ''::text)) -); - CREATE TABLE user_configs ( user_id uuid NOT NULL, key character varying(256) NOT NULL, @@ -3586,12 +3556,6 @@ ALTER TABLE ONLY chat_messages ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id); -ALTER TABLE ONLY chat_providers - ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id); - -ALTER TABLE ONLY chat_providers - ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider); - ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id); @@ -3808,12 +3772,6 @@ ALTER TABLE ONLY user_ai_provider_keys ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id); -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id); - -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id); - ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key); @@ -4028,8 +3986,6 @@ CREATE INDEX idx_chat_model_configs_provider_model ON chat_model_configs USING b CREATE UNIQUE INDEX idx_chat_model_configs_single_default ON chat_model_configs USING btree ((1)) WHERE ((is_default = true) AND (deleted = false)); -CREATE INDEX idx_chat_providers_enabled ON chat_providers USING btree (enabled); - CREATE INDEX idx_chat_queued_messages_chat_id ON chat_queued_messages USING btree (chat_id); CREATE INDEX idx_chats_agent_id ON chats USING btree (agent_id) WHERE (agent_id IS NOT NULL); @@ -4354,12 +4310,6 @@ ALTER TABLE ONLY chat_model_configs ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id); -ALTER TABLE ONLY chat_providers - ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); - -ALTER TABLE ONLY chat_providers - ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); - ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; @@ -4597,15 +4547,6 @@ ALTER TABLE ONLY user_ai_provider_keys ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); - -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE; - -ALTER TABLE ONLY user_chat_provider_keys - ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; - ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 6bbd9cd7321cd..6349355b4efa8 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -24,8 +24,6 @@ const ( ForeignKeyChatModelConfigsAiProviderID ForeignKeyConstraint = "chat_model_configs_ai_provider_id_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id); ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id); - ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); - ForeignKeyChatProvidersCreatedBy ForeignKeyConstraint = "chat_providers_created_by_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL; ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL; @@ -105,9 +103,6 @@ const ( ForeignKeyUserAiProviderKeysAiProviderID ForeignKeyConstraint = "user_ai_provider_keys_ai_provider_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id) ON DELETE CASCADE; ForeignKeyUserAiProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_ai_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyUserAiProviderKeysUserID ForeignKeyConstraint = "user_ai_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; - ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); - ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE; - ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyUserConfigsUserID ForeignKeyConstraint = "user_configs_user_id_fkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyUserDeletedUserID ForeignKeyConstraint = "user_deleted_user_id_fkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id); ForeignKeyUserLinksOauthAccessTokenKeyID ForeignKeyConstraint = "user_links_oauth_access_token_key_id_fkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_oauth_access_token_key_id_fkey FOREIGN KEY (oauth_access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/legacy_chat_provider_compat.go b/coderd/database/legacy_chat_provider_compat.go new file mode 100644 index 0000000000000..076508d3dceeb --- /dev/null +++ b/coderd/database/legacy_chat_provider_compat.go @@ -0,0 +1,45 @@ +package database + +import ( + "database/sql" + "time" + + "github.com/google/uuid" +) + +// ChatProvider is retained temporarily for tests that are still being migrated +// from legacy chat providers to AI providers. +// +//nolint:revive +type ChatProvider struct { + ID uuid.UUID + Provider string + DisplayName string + APIKey string + BaseUrl string + ApiKeyKeyID sql.NullString + CreatedAt time.Time + UpdatedAt time.Time + CreatedBy uuid.NullUUID + Enabled bool + CentralApiKeyEnabled bool + AllowUserApiKey bool + AllowCentralApiKeyFallback bool +} + +// InsertChatProviderParams is retained temporarily for test helpers that munge +// legacy chat provider fields before mapping them to AI providers. +// +//nolint:revive +type InsertChatProviderParams struct { + Provider string + DisplayName string + APIKey string + BaseUrl string + ApiKeyKeyID sql.NullString + CreatedBy uuid.NullUUID + Enabled bool + CentralApiKeyEnabled bool + AllowUserApiKey bool + AllowCentralApiKeyFallback bool +} diff --git a/coderd/database/migrations/000500_ai_providers_legacy_cleanup.down.sql b/coderd/database/migrations/000500_ai_providers_legacy_cleanup.down.sql new file mode 100644 index 0000000000000..42bb92d994c36 --- /dev/null +++ b/coderd/database/migrations/000500_ai_providers_legacy_cleanup.down.sql @@ -0,0 +1,36 @@ +CREATE TABLE chat_providers ( + id uuid DEFAULT gen_random_uuid() NOT NULL PRIMARY KEY, + provider text NOT NULL UNIQUE, + display_name text DEFAULT ''::text NOT NULL, + api_key text DEFAULT ''::text NOT NULL, + api_key_key_id text REFERENCES dbcrypt_keys(active_key_digest), + created_by uuid REFERENCES users(id) ON DELETE SET NULL, + enabled boolean DEFAULT true NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + base_url text DEFAULT ''::text NOT NULL, + central_api_key_enabled boolean DEFAULT true NOT NULL, + allow_user_api_key boolean DEFAULT false NOT NULL, + allow_central_api_key_fallback boolean DEFAULT false NOT NULL, + CONSTRAINT chat_providers_provider_check CHECK (provider = ANY (ARRAY['anthropic', 'azure', 'bedrock', 'google', 'openai', 'openai-compat', 'openrouter', 'vercel'])), + CONSTRAINT valid_credential_policy CHECK ((central_api_key_enabled OR allow_user_api_key) AND ((NOT allow_central_api_key_fallback) OR (central_api_key_enabled AND allow_user_api_key))) +); + +COMMENT ON COLUMN chat_providers.api_key_key_id IS 'The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted'; + +CREATE TABLE user_chat_provider_keys ( + id uuid DEFAULT gen_random_uuid() NOT NULL PRIMARY KEY, + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + chat_provider_id uuid NOT NULL REFERENCES chat_providers(id) ON DELETE CASCADE, + api_key text NOT NULL, + api_key_key_id text REFERENCES dbcrypt_keys(active_key_digest), + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + CONSTRAINT user_chat_provider_keys_api_key_check CHECK (api_key <> '') +); + +CREATE UNIQUE INDEX user_chat_provider_keys_user_id_chat_provider_id_key + ON user_chat_provider_keys (user_id, chat_provider_id); + +ALTER TABLE chat_model_configs + DROP CONSTRAINT IF EXISTS chat_model_configs_ai_provider_required_when_active; diff --git a/coderd/database/migrations/000500_ai_providers_legacy_cleanup.up.sql b/coderd/database/migrations/000500_ai_providers_legacy_cleanup.up.sql new file mode 100644 index 0000000000000..f594880ec2b0a --- /dev/null +++ b/coderd/database/migrations/000500_ai_providers_legacy_cleanup.up.sql @@ -0,0 +1,6 @@ +ALTER TABLE chat_model_configs + ADD CONSTRAINT chat_model_configs_ai_provider_required_when_active + CHECK (deleted = TRUE OR ai_provider_id IS NOT NULL); + +DROP TABLE IF EXISTS user_chat_provider_keys; +DROP TABLE IF EXISTS chat_providers; diff --git a/coderd/database/models.go b/coderd/database/models.go index b67d4e62b6300..75bc8abcb7e84 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4656,23 +4656,6 @@ type ChatModelConfig struct { AiProviderID uuid.NullUUID `db:"ai_provider_id" json:"ai_provider_id"` } -type ChatProvider struct { - ID uuid.UUID `db:"id" json:"id"` - Provider string `db:"provider" json:"provider"` - DisplayName string `db:"display_name" json:"display_name"` - APIKey string `db:"api_key" json:"api_key"` - // The ID of the key used to encrypt the provider API key. If this is NULL, the API key is not encrypted - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` - Enabled bool `db:"enabled" json:"enabled"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - BaseUrl string `db:"base_url" json:"base_url"` - CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"` - AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"` - AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"` -} - type ChatQueuedMessage struct { ID int64 `db:"id" json:"id"` ChatID uuid.UUID `db:"chat_id" json:"chat_id"` @@ -5678,16 +5661,6 @@ type UserAiProviderKey struct { UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } -type UserChatProviderKey struct { - ID uuid.UUID `db:"id" json:"id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` - APIKey string `db:"api_key" json:"api_key"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` -} - type UserConfig struct { UserID uuid.UUID `db:"user_id" json:"user_id"` Key string `db:"key" json:"key"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 20f9386ced9ad..a06e37ba62128 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -123,7 +123,6 @@ type sqlcQuerier interface { DeleteChatDebugDataByChatID(ctx context.Context, arg DeleteChatDebugDataByChatIDParams) (int64, error) DeleteChatModelConfigByID(ctx context.Context, id uuid.UUID) error DeleteChatModelConfigsByProvider(ctx context.Context, provider string) error - DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error DeleteChatQueuedMessage(ctx context.Context, arg DeleteChatQueuedMessageParams) error DeleteChatUsageLimitGroupOverride(ctx context.Context, groupID uuid.UUID) error DeleteChatUsageLimitUserOverride(ctx context.Context, userID uuid.UUID) error @@ -195,7 +194,6 @@ type sqlcQuerier interface { DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error) DeleteUserAIProviderKey(ctx context.Context, arg DeleteUserAIProviderKeyParams) error DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error - DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (UserSecret, error) DeleteWebpushSubscriptionByUserIDAndEndpoint(ctx context.Context, arg DeleteWebpushSubscriptionByUserIDAndEndpointParams) error DeleteWebpushSubscriptions(ctx context.Context, ids []uuid.UUID) error @@ -374,11 +372,6 @@ type sqlcQuerier interface { // personal chat model overrides. It defaults to false when unset. GetChatPersonalModelOverridesEnabled(ctx context.Context) (bool, error) GetChatPlanModeInstructions(ctx context.Context) (string, error) - GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error) - GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (ChatProvider, error) - GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error) - GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (ChatProvider, error) - GetChatProviders(ctx context.Context) ([]ChatProvider, error) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID) ([]ChatQueuedMessage, error) // Returns the chat retention period in days. Chats archived longer // than this and orphaned chat files older than this are purged by @@ -441,7 +434,6 @@ type sqlcQuerier interface { // Check both to ensure the selected config is actually usable. GetEnabledChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error) GetEnabledChatModelConfigs(ctx context.Context) ([]ChatModelConfig, error) - GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error) GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) GetExternalAuthLink(ctx context.Context, arg GetExternalAuthLinkParams) (ExternalAuthLink, error) GetExternalAuthLinksByUserID(ctx context.Context, userID uuid.UUID) ([]ExternalAuthLink, error) @@ -743,7 +735,6 @@ type sqlcQuerier interface { GetUserChatCustomPrompt(ctx context.Context, userID uuid.UUID) (string, error) GetUserChatDebugLoggingEnabled(ctx context.Context, userID uuid.UUID) (bool, error) GetUserChatPersonalModelOverride(ctx context.Context, arg GetUserChatPersonalModelOverrideParams) (string, error) - GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) // Returns the total spend for a user in the given period. // When organization_id is NULL, spend across all organizations is // returned (global behavior). Otherwise only spend within the @@ -914,7 +905,6 @@ type sqlcQuerier interface { InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error) InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error) - InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error) InsertCryptoKey(ctx context.Context, arg InsertCryptoKeyParams) (CryptoKey, error) InsertCustomRole(ctx context.Context, arg InsertCustomRoleParams) (CustomRole, error) @@ -1181,7 +1171,6 @@ type sqlcQuerier interface { UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error) UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOrderParams) error UpdateChatPlanModeByID(ctx context.Context, arg UpdateChatPlanModeByIDParams) (Chat, error) - UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error) UpdateChatTitleByID(ctx context.Context, arg UpdateChatTitleByIDParams) (Chat, error) @@ -1250,7 +1239,6 @@ type sqlcQuerier interface { UpdateUserAgentChatSendShortcut(ctx context.Context, arg UpdateUserAgentChatSendShortcutParams) (string, error) UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error) UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error) - UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) UpdateUserCodeDiffDisplayMode(ctx context.Context, arg UpdateUserCodeDiffDisplayModeParams) (string, error) UpdateUserDeletedByID(ctx context.Context, id uuid.UUID) error UpdateUserGithubComUserID(ctx context.Context, arg UpdateUserGithubComUserIDParams) error @@ -1374,7 +1362,6 @@ type sqlcQuerier interface { UpsertUserAIProviderKey(ctx context.Context, arg UpsertUserAIProviderKeyParams) (UserAiProviderKey, error) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error UpsertUserChatPersonalModelOverride(ctx context.Context, arg UpsertUserChatPersonalModelOverrideParams) error - UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) UpsertWebpushVAPIDKeys(ctx context.Context, arg UpsertWebpushVAPIDKeysParams) error UpsertWorkspaceAgentPortShare(ctx context.Context, arg UpsertWorkspaceAgentPortShareParams) (WorkspaceAgentPortShare, error) UpsertWorkspaceApp(ctx context.Context, arg UpsertWorkspaceAppParams) (WorkspaceApp, error) diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 305ec479407fe..aa5be8316b054 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -10268,6 +10268,42 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) { } } +func insertChatModelConfigForTest( + ctx context.Context, + t testing.TB, + store database.Store, + params database.InsertChatModelConfigParams, +) (database.ChatModelConfig, error) { + t.Helper() + if !params.AiProviderID.Valid { + providerName := params.Provider + if providerName == "" { + providerName = "openai" + params.Provider = providerName + } + providers, err := store.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true}) + if err != nil { + return database.ChatModelConfig{}, err + } + var provider database.AIProvider + for _, candidate := range providers { + if candidate.Type != database.AIProviderType(providerName) { + continue + } + if provider.ID == uuid.Nil || candidate.CreatedAt.After(provider.CreatedAt) { + provider = candidate + } + } + if provider.ID == uuid.Nil { + provider = dbgen.AIProvider(t, store, database.AIProvider{ + Type: database.AIProviderType(providerName), + }) + } + params.AiProviderID = uuid.NullUUID{UUID: provider.ID, Valid: true} + } + return store.InsertChatModelConfig(ctx, params) +} + func TestInsertChatMessages(t *testing.T) { t.Parallel() @@ -10283,7 +10319,7 @@ func TestInsertChatMessages(t *testing.T) { ) database.ChatModelConfig { t.Helper() - modelConfig, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelConfig, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: provider, Model: model, DisplayName: displayName, @@ -10311,13 +10347,13 @@ func TestInsertChatMessages(t *testing.T) { dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) provider := "openai" - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: provider, DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) modelConfigA := insertModelConfig( @@ -10480,18 +10516,21 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) { org := dbgen.Organization(t, db, database.Organization{}) dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) - // A chat_providers row is required as a FK for model configs. - _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, - CentralApiKeyEnabled: true, + // An AI provider row is required as a FK for model configs. + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "test-" + uuid.NewString(), + DisplayName: sql.NullString{String: "OpenAI", Valid: true}, + Enabled: true, + }) + dbgen.AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, + APIKey: "test-key", }) - require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{ Provider: "openai", + AiProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, Model: "test-model", DisplayName: "Test Model", CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, @@ -10857,16 +10896,16 @@ func TestGetPRInsights(t *testing.T) { user := dbgen.User(t, store, database.User{}) dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: "anthropic", DisplayName: "Anthropic", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - mc, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + mc, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: "anthropic", Model: "claude-4", DisplayName: "Claude 4", @@ -11313,7 +11352,7 @@ func TestGetPRInsights(t *testing.T) { store, userID, _, orgID := setupChatInfra(t) const modelName = "claude-4.1" - emptyDisplayModel, err := store.InsertChatModelConfig(context.Background(), database.InsertChatModelConfigParams{ + emptyDisplayModel, err := insertChatModelConfigForTest(context.Background(), t, store, database.InsertChatModelConfigParams{ Provider: "anthropic", Model: modelName, DisplayName: "", @@ -11421,16 +11460,16 @@ func TestChatPinOrderQueries(t *testing.T) { // Use background context for fixture setup so the // timed test context doesn't tick during DB init. bg := context.Background() - _, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(bg, t, db, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -11602,16 +11641,16 @@ func TestChatPinOrderConstraints(t *testing.T) { dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID}) bg := context.Background() - _, err := db.InsertChatProvider(bg, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(bg, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(bg, t, db, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -11695,16 +11734,16 @@ func TestChatLabels(t *testing.T) { org := dbgen.Organization(t, db, database.Organization{}) dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID}) - _, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err = dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -11989,16 +12028,16 @@ func TestUpdateChatLastTurnSummary(t *testing.T) { org := dbgen.Organization(t, db, database.Organization{}) dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: owner.ID, OrganizationID: org.ID}) - _, err = db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err = dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, db, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -12090,16 +12129,16 @@ func TestDeleteChatDebugDataAfterMessageIDIncludesTriggeredRuns(t *testing.T) { providerName := "openai" modelName := "debug-model-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -12283,16 +12322,16 @@ func TestDeleteChatDebugDataAfterMessageIDStepLevelFieldBoundariesAndNulls(t *te providerName := "openai" modelName := "debug-model-step-boundaries-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -12541,16 +12580,16 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { providerName := "openai" modelName := "debug-model-finalize-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -12980,16 +13019,16 @@ func TestChatDebugSQLGuards(t *testing.T) { providerName := "openai" modelName := "debug-model-guards-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13114,16 +13153,16 @@ func TestChatDebugRunCOALESCEPreservation(t *testing.T) { providerName := "openai" modelName := "debug-model-coalesce-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13229,16 +13268,16 @@ func TestChatDebugStepCOALESCEPreservation(t *testing.T) { providerName := "openai" modelName := "debug-step-coalesce-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13354,16 +13393,16 @@ func TestDeleteChatDebugDataAfterMessageIDNullMessagesSurvive(t *testing.T) { providerName := "openai" modelName := "debug-model-null-msg-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13452,16 +13491,16 @@ func TestDeleteChatDebugDataAfterMessageIDStartedBeforeFiltersNewerRuns(t *testi providerName := "openai" modelName := "debug-model-started-before-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13564,16 +13603,16 @@ func TestDeleteChatDebugDataByChatIDStartedBeforeFiltersNewerRuns(t *testing.T) providerName := "openai" modelName := "debug-model-by-chat-started-before-" + uuid.NewString() - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: providerName, DisplayName: "Debug Provider", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: providerName, Model: modelName, DisplayName: "Debug Model", @@ -13653,16 +13692,16 @@ func TestChatHasUnread(t *testing.T) { user := dbgen.User(t, store, database.User{}) dbgen.OrganizationMember(t, store, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) - _, err := store.InsertChatProvider(ctx, database.InsertChatProviderParams{ + _, err := dbgen.ChatProvider(t, store, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", APIKey: "test-key", Enabled: true, CentralApiKeyEnabled: true, - }) + }), error(nil) require.NoError(t, err) - modelCfg, err := store.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + modelCfg, err := insertChatModelConfigForTest(ctx, t, store, database.InsertChatModelConfigParams{ Provider: "openai", Model: "test-model-" + uuid.NewString(), DisplayName: "Test Model", diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 7f003daa08bf8..fe2f09b94aa43 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5124,18 +5124,14 @@ SELECT cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id FROM chat_model_configs cmc -LEFT JOIN - ai_providers ap ON ap.id = cmc.ai_provider_id AND ap.deleted = FALSE -LEFT JOIN - chat_providers cp ON cp.provider = cmc.provider +JOIN + ai_providers ap ON ap.id = cmc.ai_provider_id WHERE cmc.id = $1::uuid AND cmc.deleted = FALSE AND cmc.enabled = TRUE - AND ( - (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE) - OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) - ) + AND ap.enabled = TRUE + AND ap.deleted = FALSE ` // Providers can be disabled independently of their model configs. @@ -5169,17 +5165,13 @@ SELECT cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id FROM chat_model_configs cmc -LEFT JOIN - ai_providers ap ON ap.id = cmc.ai_provider_id AND ap.deleted = FALSE -LEFT JOIN - chat_providers cp ON cp.provider = cmc.provider +JOIN + ai_providers ap ON ap.id = cmc.ai_provider_id WHERE cmc.enabled = TRUE AND cmc.deleted = FALSE - AND ( - (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE) - OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) - ) + AND ap.enabled = TRUE + AND ap.deleted = FALSE ORDER BY cmc.provider ASC, cmc.model ASC, @@ -5395,369 +5387,6 @@ func (q *sqlQuerier) UpdateChatModelConfig(ctx context.Context, arg UpdateChatMo return i, err } -const deleteChatProviderByID = `-- name: DeleteChatProviderByID :exec -DELETE FROM - chat_providers -WHERE - id = $1::uuid -` - -func (q *sqlQuerier) DeleteChatProviderByID(ctx context.Context, id uuid.UUID) error { - _, err := q.db.ExecContext(ctx, deleteChatProviderByID, id) - return err -} - -const getChatProviderByID = `-- name: GetChatProviderByID :one -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - id = $1::uuid -` - -func (q *sqlQuerier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, getChatProviderByID, id) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const getChatProviderByIDForUpdate = `-- name: GetChatProviderByIDForUpdate :one -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - id = $1::uuid -FOR UPDATE -` - -func (q *sqlQuerier) GetChatProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, getChatProviderByIDForUpdate, id) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const getChatProviderByProvider = `-- name: GetChatProviderByProvider :one -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - provider = $1::text -` - -func (q *sqlQuerier) GetChatProviderByProvider(ctx context.Context, provider string) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, getChatProviderByProvider, provider) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const getChatProviderByProviderForUpdate = `-- name: GetChatProviderByProviderForUpdate :one -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - provider = $1::text -FOR UPDATE -` - -func (q *sqlQuerier) GetChatProviderByProviderForUpdate(ctx context.Context, provider string) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, getChatProviderByProviderForUpdate, provider) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const getChatProviders = `-- name: GetChatProviders :many -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -ORDER BY - provider ASC -` - -func (q *sqlQuerier) GetChatProviders(ctx context.Context) ([]ChatProvider, error) { - rows, err := q.db.QueryContext(ctx, getChatProviders) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatProvider - for rows.Next() { - var i ChatProvider - if err := rows.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const getEnabledChatProviders = `-- name: GetEnabledChatProviders :many -SELECT - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -FROM - chat_providers -WHERE - enabled = TRUE -ORDER BY - provider ASC -` - -func (q *sqlQuerier) GetEnabledChatProviders(ctx context.Context) ([]ChatProvider, error) { - rows, err := q.db.QueryContext(ctx, getEnabledChatProviders) - if err != nil { - return nil, err - } - defer rows.Close() - var items []ChatProvider - for rows.Next() { - var i ChatProvider - if err := rows.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const insertChatProvider = `-- name: InsertChatProvider :one -INSERT INTO chat_providers ( - provider, - display_name, - api_key, - base_url, - api_key_key_id, - created_by, - enabled, - central_api_key_enabled, - allow_user_api_key, - allow_central_api_key_fallback -) VALUES ( - $1::text, - $2::text, - $3::text, - $4::text, - $5::text, - $6::uuid, - $7::boolean, - $8::boolean, - $9::boolean, - $10::boolean -) -RETURNING - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -` - -type InsertChatProviderParams struct { - Provider string `db:"provider" json:"provider"` - DisplayName string `db:"display_name" json:"display_name"` - APIKey string `db:"api_key" json:"api_key"` - BaseUrl string `db:"base_url" json:"base_url"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` - Enabled bool `db:"enabled" json:"enabled"` - CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"` - AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"` - AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"` -} - -func (q *sqlQuerier) InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, insertChatProvider, - arg.Provider, - arg.DisplayName, - arg.APIKey, - arg.BaseUrl, - arg.ApiKeyKeyID, - arg.CreatedBy, - arg.Enabled, - arg.CentralApiKeyEnabled, - arg.AllowUserApiKey, - arg.AllowCentralApiKeyFallback, - ) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - -const updateChatProvider = `-- name: UpdateChatProvider :one -UPDATE - chat_providers -SET - display_name = $1::text, - api_key = $2::text, - base_url = $3::text, - api_key_key_id = $4::text, - enabled = $5::boolean, - central_api_key_enabled = $6::boolean, - allow_user_api_key = $7::boolean, - allow_central_api_key_fallback = $8::boolean, - updated_at = NOW() -WHERE - id = $9::uuid -RETURNING - id, provider, display_name, api_key, api_key_key_id, created_by, enabled, created_at, updated_at, base_url, central_api_key_enabled, allow_user_api_key, allow_central_api_key_fallback -` - -type UpdateChatProviderParams struct { - DisplayName string `db:"display_name" json:"display_name"` - APIKey string `db:"api_key" json:"api_key"` - BaseUrl string `db:"base_url" json:"base_url"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - Enabled bool `db:"enabled" json:"enabled"` - CentralApiKeyEnabled bool `db:"central_api_key_enabled" json:"central_api_key_enabled"` - AllowUserApiKey bool `db:"allow_user_api_key" json:"allow_user_api_key"` - AllowCentralApiKeyFallback bool `db:"allow_central_api_key_fallback" json:"allow_central_api_key_fallback"` - ID uuid.UUID `db:"id" json:"id"` -} - -func (q *sqlQuerier) UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) { - row := q.db.QueryRowContext(ctx, updateChatProvider, - arg.DisplayName, - arg.APIKey, - arg.BaseUrl, - arg.ApiKeyKeyID, - arg.Enabled, - arg.CentralApiKeyEnabled, - arg.AllowUserApiKey, - arg.AllowCentralApiKeyFallback, - arg.ID, - ) - var i ChatProvider - err := row.Scan( - &i.ID, - &i.Provider, - &i.DisplayName, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedBy, - &i.Enabled, - &i.CreatedAt, - &i.UpdatedAt, - &i.BaseUrl, - &i.CentralApiKeyEnabled, - &i.AllowUserApiKey, - &i.AllowCentralApiKeyFallback, - ) - return i, err -} - const acquireChats = `-- name: AcquireChats :many WITH acquired_chats AS ( UPDATE @@ -27008,126 +26637,6 @@ func (q *sqlQuerier) UpdateUserSecretByUserIDAndName(ctx context.Context, arg Up return i, err } -const deleteUserChatProviderKey = `-- name: DeleteUserChatProviderKey :exec -DELETE FROM user_chat_provider_keys WHERE user_id = $1 AND chat_provider_id = $2 -` - -type DeleteUserChatProviderKeyParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` -} - -func (q *sqlQuerier) DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error { - _, err := q.db.ExecContext(ctx, deleteUserChatProviderKey, arg.UserID, arg.ChatProviderID) - return err -} - -const getUserChatProviderKeys = `-- name: GetUserChatProviderKeys :many -SELECT id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at FROM user_chat_provider_keys WHERE user_id = $1 ORDER BY created_at ASC, id ASC -` - -func (q *sqlQuerier) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]UserChatProviderKey, error) { - rows, err := q.db.QueryContext(ctx, getUserChatProviderKeys, userID) - if err != nil { - return nil, err - } - defer rows.Close() - var items []UserChatProviderKey - for rows.Next() { - var i UserChatProviderKey - if err := rows.Scan( - &i.ID, - &i.UserID, - &i.ChatProviderID, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedAt, - &i.UpdatedAt, - ); err != nil { - return nil, err - } - items = append(items, i) - } - if err := rows.Close(); err != nil { - return nil, err - } - if err := rows.Err(); err != nil { - return nil, err - } - return items, nil -} - -const updateUserChatProviderKey = `-- name: UpdateUserChatProviderKey :one -UPDATE user_chat_provider_keys -SET api_key = $1, api_key_key_id = $2::text, updated_at = NOW() -WHERE user_id = $3 AND chat_provider_id = $4 -RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at -` - -type UpdateUserChatProviderKeyParams struct { - APIKey string `db:"api_key" json:"api_key"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` - UserID uuid.UUID `db:"user_id" json:"user_id"` - ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` -} - -func (q *sqlQuerier) UpdateUserChatProviderKey(ctx context.Context, arg UpdateUserChatProviderKeyParams) (UserChatProviderKey, error) { - row := q.db.QueryRowContext(ctx, updateUserChatProviderKey, - arg.APIKey, - arg.ApiKeyKeyID, - arg.UserID, - arg.ChatProviderID, - ) - var i UserChatProviderKey - err := row.Scan( - &i.ID, - &i.UserID, - &i.ChatProviderID, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - -const upsertUserChatProviderKey = `-- name: UpsertUserChatProviderKey :one -INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id) -VALUES ($1, $2, $3, $4::text) -ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET - api_key = $3, - api_key_key_id = $4::text, - updated_at = NOW() -RETURNING id, user_id, chat_provider_id, api_key, api_key_key_id, created_at, updated_at -` - -type UpsertUserChatProviderKeyParams struct { - UserID uuid.UUID `db:"user_id" json:"user_id"` - ChatProviderID uuid.UUID `db:"chat_provider_id" json:"chat_provider_id"` - APIKey string `db:"api_key" json:"api_key"` - ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` -} - -func (q *sqlQuerier) UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) { - row := q.db.QueryRowContext(ctx, upsertUserChatProviderKey, - arg.UserID, - arg.ChatProviderID, - arg.APIKey, - arg.ApiKeyKeyID, - ) - var i UserChatProviderKey - err := row.Scan( - &i.ID, - &i.UserID, - &i.ChatProviderID, - &i.APIKey, - &i.ApiKeyKeyID, - &i.CreatedAt, - &i.UpdatedAt, - ) - return i, err -} - const allUserIDs = `-- name: AllUserIDs :many SELECT DISTINCT id FROM USERS WHERE CASE WHEN $1::bool THEN TRUE ELSE is_system = false END diff --git a/coderd/database/queries/chatmodelconfigs.sql b/coderd/database/queries/chatmodelconfigs.sql index 472ba704c83e1..8dd24cb3f9966 100644 --- a/coderd/database/queries/chatmodelconfigs.sql +++ b/coderd/database/queries/chatmodelconfigs.sql @@ -34,17 +34,13 @@ SELECT cmc.* FROM chat_model_configs cmc -LEFT JOIN - ai_providers ap ON ap.id = cmc.ai_provider_id AND ap.deleted = FALSE -LEFT JOIN - chat_providers cp ON cp.provider = cmc.provider +JOIN + ai_providers ap ON ap.id = cmc.ai_provider_id WHERE cmc.enabled = TRUE AND cmc.deleted = FALSE - AND ( - (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE) - OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) - ) + AND ap.enabled = TRUE + AND ap.deleted = FALSE ORDER BY cmc.provider ASC, cmc.model ASC, @@ -58,18 +54,14 @@ FROM chat_model_configs cmc -- Providers can be disabled independently of their model configs. -- Check both to ensure the selected config is actually usable. -LEFT JOIN - ai_providers ap ON ap.id = cmc.ai_provider_id AND ap.deleted = FALSE -LEFT JOIN - chat_providers cp ON cp.provider = cmc.provider +JOIN + ai_providers ap ON ap.id = cmc.ai_provider_id WHERE cmc.id = @id::uuid AND cmc.deleted = FALSE AND cmc.enabled = TRUE - AND ( - (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE) - OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) - ); + AND ap.enabled = TRUE + AND ap.deleted = FALSE; -- name: InsertChatModelConfig :one INSERT INTO chat_model_configs ( diff --git a/coderd/database/queries/chatproviders.sql b/coderd/database/queries/chatproviders.sql deleted file mode 100644 index 7df983541d335..0000000000000 --- a/coderd/database/queries/chatproviders.sql +++ /dev/null @@ -1,102 +0,0 @@ --- name: GetChatProviderByID :one -SELECT - * -FROM - chat_providers -WHERE - id = @id::uuid; - --- name: GetChatProviderByIDForUpdate :one -SELECT - * -FROM - chat_providers -WHERE - id = @id::uuid -FOR UPDATE; - --- name: GetChatProviderByProvider :one -SELECT - * -FROM - chat_providers -WHERE - provider = @provider::text; - --- name: GetChatProviderByProviderForUpdate :one -SELECT - * -FROM - chat_providers -WHERE - provider = @provider::text -FOR UPDATE; - --- name: GetChatProviders :many -SELECT - * -FROM - chat_providers -ORDER BY - provider ASC; - --- name: GetEnabledChatProviders :many -SELECT - * -FROM - chat_providers -WHERE - enabled = TRUE -ORDER BY - provider ASC; - --- name: InsertChatProvider :one -INSERT INTO chat_providers ( - provider, - display_name, - api_key, - base_url, - api_key_key_id, - created_by, - enabled, - central_api_key_enabled, - allow_user_api_key, - allow_central_api_key_fallback -) VALUES ( - @provider::text, - @display_name::text, - @api_key::text, - @base_url::text, - sqlc.narg('api_key_key_id')::text, - sqlc.narg('created_by')::uuid, - @enabled::boolean, - @central_api_key_enabled::boolean, - @allow_user_api_key::boolean, - @allow_central_api_key_fallback::boolean -) -RETURNING - *; - --- name: UpdateChatProvider :one -UPDATE - chat_providers -SET - display_name = @display_name::text, - api_key = @api_key::text, - base_url = @base_url::text, - api_key_key_id = sqlc.narg('api_key_key_id')::text, - enabled = @enabled::boolean, - central_api_key_enabled = @central_api_key_enabled::boolean, - allow_user_api_key = @allow_user_api_key::boolean, - allow_central_api_key_fallback = @allow_central_api_key_fallback::boolean, - updated_at = NOW() -WHERE - id = @id::uuid -RETURNING - *; - --- name: DeleteChatProviderByID :exec -DELETE FROM - chat_providers -WHERE - id = @id::uuid; diff --git a/coderd/database/queries/userchatproviderkeys.sql b/coderd/database/queries/userchatproviderkeys.sql deleted file mode 100644 index 38c177156ef5f..0000000000000 --- a/coderd/database/queries/userchatproviderkeys.sql +++ /dev/null @@ -1,20 +0,0 @@ --- name: GetUserChatProviderKeys :many -SELECT * FROM user_chat_provider_keys WHERE user_id = @user_id ORDER BY created_at ASC, id ASC; - --- name: UpsertUserChatProviderKey :one -INSERT INTO user_chat_provider_keys (user_id, chat_provider_id, api_key, api_key_key_id) -VALUES (@user_id, @chat_provider_id, @api_key, sqlc.narg('api_key_key_id')::text) -ON CONFLICT (user_id, chat_provider_id) DO UPDATE SET - api_key = @api_key, - api_key_key_id = sqlc.narg('api_key_key_id')::text, - updated_at = NOW() -RETURNING *; - --- name: UpdateUserChatProviderKey :one -UPDATE user_chat_provider_keys -SET api_key = @api_key, api_key_key_id = sqlc.narg('api_key_key_id')::text, updated_at = NOW() -WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id -RETURNING *; - --- name: DeleteUserChatProviderKey :exec -DELETE FROM user_chat_provider_keys WHERE user_id = @user_id AND chat_provider_id = @chat_provider_id; diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 9724d1a070405..6b864a7392690 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -25,8 +25,6 @@ const ( UniqueChatFilesPkey UniqueConstraint = "chat_files_pkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_pkey PRIMARY KEY (id); UniqueChatMessagesPkey UniqueConstraint = "chat_messages_pkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_pkey PRIMARY KEY (id); UniqueChatModelConfigsPkey UniqueConstraint = "chat_model_configs_pkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_pkey PRIMARY KEY (id); - UniqueChatProvidersPkey UniqueConstraint = "chat_providers_pkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_pkey PRIMARY KEY (id); - UniqueChatProvidersProviderKey UniqueConstraint = "chat_providers_provider_key" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_provider_key UNIQUE (provider); UniqueChatQueuedMessagesPkey UniqueConstraint = "chat_queued_messages_pkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_pkey PRIMARY KEY (id); UniqueChatUsageLimitConfigPkey UniqueConstraint = "chat_usage_limit_config_pkey" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_pkey PRIMARY KEY (id); UniqueChatUsageLimitConfigSingletonKey UniqueConstraint = "chat_usage_limit_config_singleton_key" // ALTER TABLE ONLY chat_usage_limit_config ADD CONSTRAINT chat_usage_limit_config_singleton_key UNIQUE (singleton); @@ -99,8 +97,6 @@ const ( UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id); UniqueUserAiProviderKeysPkey UniqueConstraint = "user_ai_provider_keys_pkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_pkey PRIMARY KEY (id); UniqueUserAiProviderKeysUserIDAiProviderIDKey UniqueConstraint = "user_ai_provider_keys_user_id_ai_provider_id_key" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id); - UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id); - UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id); UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key); UniqueUserDeletedPkey UniqueConstraint = "user_deleted_pkey" // ALTER TABLE ONLY user_deleted ADD CONSTRAINT user_deleted_pkey PRIMARY KEY (id); UniqueUserLinksPkey UniqueConstraint = "user_links_pkey" // ALTER TABLE ONLY user_links ADD CONSTRAINT user_links_pkey PRIMARY KEY (user_id, login_type); diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index d3f5ec7c36154..4be19daeae68f 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -11,7 +11,6 @@ import ( "mime" "net/http" "net/http/httptest" - "net/url" "slices" "strconv" "strings" @@ -775,7 +774,7 @@ func (api *API) getUserChatProviderAvailability( ) (userChatModelAvailability, error) { //nolint:gocritic // System context is required to read enabled chat config. systemCtx := dbauthz.AsSystemRestricted(ctx) - enabledProviders, err := api.Database.GetEnabledChatProviders(systemCtx) + enabledProviders, err := api.Database.GetAIProviders(systemCtx, database.GetAIProvidersParams{}) if err != nil { return userChatModelAvailability{}, err } @@ -791,19 +790,12 @@ func (api *API) getUserChatProviderAvailability( enabledProviderNames: make(map[string]struct{}, len(enabledProviders)), } for _, provider := range enabledProviders { - availability.configuredProviders = append( - availability.configuredProviders, - chatprovider.ConfiguredProvider{ - ProviderID: provider.ID, - Provider: provider.Provider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - CentralAPIKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserAPIKey: provider.AllowUserApiKey, - AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback, - }, - ) - normalizedProvider := chatprovider.NormalizeProvider(provider.Provider) + configuredProvider, err := api.configuredProviderFromAIProvider(systemCtx, provider) + if err != nil { + return userChatModelAvailability{}, err + } + availability.configuredProviders = append(availability.configuredProviders, configuredProvider) + normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider) if normalizedProvider != "" { availability.enabledProviderNames[normalizedProvider] = struct{}{} } @@ -816,16 +808,19 @@ func (api *API) getUserChatProviderAvailability( }) } - userKeyRows, err := api.Database.GetUserChatProviderKeys(ctx, userID) - if err != nil { - return userChatModelAvailability{}, err - } - userKeys := make([]chatprovider.UserProviderKey, 0, len(userKeyRows)) - for _, userKey := range userKeyRows { - userKeys = append(userKeys, chatprovider.UserProviderKey{ - ChatProviderID: userKey.ChatProviderID, - APIKey: userKey.APIKey, - }) + userKeys := []chatprovider.UserProviderKey{} + if api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value() { + userKeyRows, err := api.Database.GetUserAIProviderKeysByUserID(ctx, userID) + if err != nil { + return userChatModelAvailability{}, err + } + userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows)) + for _, userKey := range userKeyRows { + userKeys = append(userKeys, chatprovider.UserProviderKey{ + ChatProviderID: userKey.AiProviderID, + APIKey: userKey.APIKey, + }) + } } _, availability.providerStatus = chatprovider.ResolveUserProviderKeys( @@ -6913,615 +6908,61 @@ func (api *API) deleteUserAIProviderKey(rw http.ResponseWriter, r *http.Request) httpapi.Write(ctx, rw, http.StatusNoContent, nil) } -func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - //nolint:gocritic // System context required to read enabled chat providers. - systemCtx := dbauthz.AsSystemRestricted(ctx) - if !api.Authorize(r, policy.ActionRead, rbac.ResourceDeploymentConfig) { - httpapi.Forbidden(rw) - return - } - - providers, err := api.Database.GetChatProviders(ctx) +func (api *API) configuredProviderFromAIProvider(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) { + keys, err := api.Database.GetAIProviderKeysByProviderID(ctx, provider.ID) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to list chat providers.", - Detail: err.Error(), - }) - return - } - - providersByName := make(map[string]database.ChatProvider, len(providers)) - configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers)) - for _, provider := range providers { - normalizedProvider := normalizeChatProvider(provider.Provider) - if normalizedProvider == "" { - continue - } - provider.Provider = normalizedProvider - providersByName[normalizedProvider] = provider - configuredProviders = append(configuredProviders, chatprovider.ConfiguredProvider{ - Provider: normalizedProvider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - }) + return chatprovider.ConfiguredProvider{}, err } - if api.chatDaemon == nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Chat processor is unavailable.", - Detail: "Chat processor is not configured.", - }) - return - } - - enabledProviders, err := api.Database.GetEnabledChatProviders( - systemCtx, - ) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to resolve provider API keys.", - Detail: err.Error(), - }) - return - } - - enabledConfiguredProviders := make( - []chatprovider.ConfiguredProvider, 0, len(enabledProviders), - ) - for _, provider := range enabledProviders { - normalizedProvider := normalizeChatProvider(provider.Provider) - if normalizedProvider == "" { - continue - } - enabledConfiguredProviders = append( - enabledConfiguredProviders, chatprovider.ConfiguredProvider{ - Provider: normalizedProvider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - }, - ) - } - - effectiveKeys := chatprovider.MergeProviderAPIKeys( - ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues), - enabledConfiguredProviders, - ) - effectiveKeys = chatprovider.MergeProviderAPIKeys( - effectiveKeys, configuredProviders, - ) - - supportedProviders := chatprovider.SupportedProviders() - resp := make([]codersdk.ChatProviderConfig, 0, len(supportedProviders)) - for _, provider := range supportedProviders { - configured, ok := providersByName[provider] - if ok { - resp = append( - resp, - convertChatProviderConfig( - configured, - api.hasEffectiveProviderAPIKey(ctx, configured), - codersdk.ChatProviderConfigSourceDatabase, - ), - ) - continue - } - - source := codersdk.ChatProviderConfigSourceSupported - hasAPIKey := effectiveKeys.APIKey(provider) != "" - enabled := false - if chatprovider.IsEnvPresetProvider(provider) && hasAPIKey { - source = codersdk.ChatProviderConfigSourceEnvPreset - enabled = true + apiKey := "" + for _, key := range keys { + if trimmed := strings.TrimSpace(key.APIKey); trimmed != "" { + apiKey = trimmed + break } - - resp = append(resp, codersdk.ChatProviderConfig{ - ID: uuid.Nil, - Provider: provider, - DisplayName: chatprovider.ProviderDisplayName(provider), - Enabled: enabled, - HasAPIKey: hasAPIKey, - CentralAPIKeyEnabled: true, - AllowUserAPIKey: false, - AllowCentralAPIKeyFallback: false, - BaseURL: effectiveKeys.BaseURL(provider), - Source: source, - }) } - - httpapi.Write(ctx, rw, http.StatusOK, resp) + return chatprovider.ConfiguredProvider{ + ProviderID: provider.ID, + Provider: string(provider.Type), + APIKey: apiKey, + BaseURL: provider.BaseUrl, + CentralAPIKeyEnabled: true, + AllowUserAPIKey: api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(), + AllowCentralAPIKeyFallback: true, + }, nil } -func (api *API) createChatProvider(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - apiKey := httpmw.APIKey(r) - var inserted database.ChatProvider - if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { - httpapi.Forbidden(rw) - return - } - - var req codersdk.CreateChatProviderConfigRequest - if !httpapi.Read(ctx, rw, r, &req) { - return - } - - provider := normalizeChatProvider(req.Provider) - if provider == "" { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider.", - Detail: chatProviderValidationDetail(), - }) - return - } - - if err := validateChatProviderAPIKeySize(strings.TrimSpace(req.APIKey)); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "API key too large.", - Detail: err.Error(), - }) - return - } - - enabled := true - if req.Enabled != nil { - enabled = *req.Enabled - } - baseURL, err := normalizeChatProviderBaseURL(req.BaseURL) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider base URL.", - Detail: err.Error(), - }) - return - } - - centralAPIKeyEnabled := true - if req.CentralAPIKeyEnabled != nil { - centralAPIKeyEnabled = *req.CentralAPIKeyEnabled - } - allowUserAPIKey := false - if req.AllowUserAPIKey != nil { - allowUserAPIKey = *req.AllowUserAPIKey - } - allowCentralAPIKeyFallback := false - if req.AllowCentralAPIKeyFallback != nil { - allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback - } - - if err := validateChatProviderCredentialPolicy( - centralAPIKeyEnabled, - allowUserAPIKey, - allowCentralAPIKeyFallback, - ); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid credential policy.", - Detail: err.Error(), - }) - return - } - - if err := validateChatProviderCentralAPIKey( - provider, - centralAPIKeyEnabled, - api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{ - Provider: provider, - APIKey: strings.TrimSpace(req.APIKey), - BaseUrl: baseURL, - CentralApiKeyEnabled: centralAPIKeyEnabled, - }, uuid.Nil), - ); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: err.Error(), - }) - return - } - - inserted, err = api.Database.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: provider, - DisplayName: strings.TrimSpace(req.DisplayName), - APIKey: strings.TrimSpace(req.APIKey), - BaseUrl: baseURL, - ApiKeyKeyID: sql.NullString{}, - CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil}, - Enabled: enabled, - CentralApiKeyEnabled: centralAPIKeyEnabled, - AllowUserApiKey: allowUserAPIKey, - AllowCentralApiKeyFallback: allowCentralAPIKeyFallback, +func writeLegacyChatProviderGone(rw http.ResponseWriter, r *http.Request) { + httpapi.Write(r.Context(), rw, http.StatusGone, codersdk.Response{ + Message: "Legacy chat provider APIs were removed. Use AI provider APIs instead.", }) - if err != nil { - switch { - case database.IsUniqueViolation(err): - httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ - Message: "Chat provider already exists.", - Detail: err.Error(), - }) - return - case database.IsCheckViolation(err): - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider.", - Detail: err.Error(), - }) - return - default: - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to create chat provider.", - Detail: err.Error(), - }) - return - } - } - - publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) - - httpapi.Write( - ctx, - rw, - http.StatusCreated, - convertChatProviderConfig( - inserted, - api.hasEffectiveProviderAPIKey(ctx, inserted), - codersdk.ChatProviderConfigSourceDatabase, - ), - ) } -func (api *API) updateChatProvider(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - var ( - existing database.ChatProvider - updated database.ChatProvider - ) - if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { - httpapi.Forbidden(rw) - return - } - - providerID, ok := parseChatProviderID(rw, r) - if !ok { - return - } - - existing, err := api.Database.GetChatProviderByID(ctx, providerID) - if err != nil { - if httpapi.Is404Error(err) { - httpapi.ResourceNotFound(rw) - return - } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to get chat provider.", - Detail: err.Error(), - }) - return - } - - var req codersdk.UpdateChatProviderConfigRequest - if !httpapi.Read(ctx, rw, r, &req) { - return - } - - displayName := existing.DisplayName - if trimmed := strings.TrimSpace(req.DisplayName); trimmed != "" { - displayName = trimmed - } - - enabled := existing.Enabled - if req.Enabled != nil { - enabled = *req.Enabled - } - - apiKey := existing.APIKey - apiKeyKeyID := existing.ApiKeyKeyID - if req.APIKey != nil { - trimmedAPIKey := strings.TrimSpace(*req.APIKey) - if trimmedAPIKey != "" { - if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "API key too large.", - Detail: err.Error(), - }) - return - } - } - apiKey = trimmedAPIKey - apiKeyKeyID = sql.NullString{} - } - baseURL := existing.BaseUrl - if req.BaseURL != nil { - baseURL, err = normalizeChatProviderBaseURL(*req.BaseURL) - if err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider base URL.", - Detail: err.Error(), - }) - return - } - } - - centralAPIKeyEnabled := existing.CentralApiKeyEnabled - if req.CentralAPIKeyEnabled != nil { - centralAPIKeyEnabled = *req.CentralAPIKeyEnabled - } - allowUserAPIKey := existing.AllowUserApiKey - if req.AllowUserAPIKey != nil { - allowUserAPIKey = *req.AllowUserAPIKey - } - allowCentralAPIKeyFallback := existing.AllowCentralApiKeyFallback - if req.AllowCentralAPIKeyFallback != nil { - allowCentralAPIKeyFallback = *req.AllowCentralAPIKeyFallback - } - - if err := validateChatProviderCredentialPolicy( - centralAPIKeyEnabled, - allowUserAPIKey, - allowCentralAPIKeyFallback, - ); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid credential policy.", - Detail: err.Error(), - }) - return - } - - if err := validateChatProviderCentralAPIKey( - existing.Provider, - centralAPIKeyEnabled, - api.hasEffectiveCentralProviderAPIKey(ctx, database.ChatProvider{ - ID: existing.ID, - Provider: existing.Provider, - APIKey: apiKey, - BaseUrl: baseURL, - CentralApiKeyEnabled: centralAPIKeyEnabled, - }, existing.ID), - ); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: err.Error(), - }) - return - } - - updated, err = api.Database.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - DisplayName: displayName, - APIKey: apiKey, - BaseUrl: baseURL, - ApiKeyKeyID: apiKeyKeyID, - Enabled: enabled, - CentralApiKeyEnabled: centralAPIKeyEnabled, - AllowUserApiKey: allowUserAPIKey, - AllowCentralApiKeyFallback: allowCentralAPIKeyFallback, - ID: existing.ID, - }) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to update chat provider.", - Detail: err.Error(), - }) - return - } - - publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) - - httpapi.Write( - ctx, - rw, - http.StatusOK, - convertChatProviderConfig( - updated, - api.hasEffectiveProviderAPIKey(ctx, updated), - codersdk.ChatProviderConfigSourceDatabase, - ), - ) +func (*API) listChatProviders(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) } -func (api *API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) { - ctx := r.Context() - if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceDeploymentConfig) { - httpapi.Forbidden(rw) - return - } - - providerID, ok := parseChatProviderID(rw, r) - if !ok { - return - } - - err := api.Database.InTx(func(tx database.Store) error { - provider, err := tx.GetChatProviderByIDForUpdate(ctx, providerID) - switch { - case err == nil: - if err := tx.DeleteChatModelConfigsByProvider(ctx, provider.Provider); err != nil { - return xerrors.Errorf("soft delete chat model configs for provider %q: %w", provider.Provider, err) - } - if err := ensureDefaultChatModelConfig(ctx, tx); err != nil { - return err - } - if err := tx.DeleteChatProviderByID(ctx, provider.ID); err != nil { - return xerrors.Errorf("delete chat provider %s: %w", provider.ID, err) - } - return nil - case xerrors.Is(err, sql.ErrNoRows): - return err - default: - return xerrors.Errorf("get chat provider %s for delete: %w", providerID, err) - } - }, nil) - if err != nil { - if xerrors.Is(err, sql.ErrNoRows) { - httpapi.ResourceNotFound(rw) - return - } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to delete chat provider.", - Detail: err.Error(), - }) - return - } - - publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) - - rw.WriteHeader(http.StatusNoContent) +func (*API) createChatProvider(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) } -func (api *API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - apiKey = httpmw.APIKey(r) - ) - - //nolint:gocritic // Non-admin users need to read provider configs to manage their own chat credentials. - chatdCtx := dbauthz.AsChatd(ctx) - providers, err := api.Database.GetChatProviders(chatdCtx) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to list chat providers.", - Detail: err.Error(), - }) - return - } - - userKeys, err := api.Database.GetUserChatProviderKeys(ctx, apiKey.UserID) - if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to list user chat provider keys.", - Detail: err.Error(), - }) - return - } - - hasUserAPIKeyByProviderID := make(map[uuid.UUID]bool, len(userKeys)) - for _, userKey := range userKeys { - hasUserAPIKeyByProviderID[userKey.ChatProviderID] = true - } - - resp := make([]codersdk.UserChatProviderConfig, 0, len(providers)) - for _, provider := range providers { - if !provider.Enabled || !provider.AllowUserApiKey { - continue - } - hasUserAPIKey := hasUserAPIKeyByProviderID[provider.ID] - hasCentralAPIKeyFallback := provider.Enabled && - provider.AllowCentralApiKeyFallback && - api.hasEffectiveCentralProviderCredentials(ctx, provider, uuid.Nil) - resp = append( - resp, - convertUserChatProviderConfig( - provider, - hasUserAPIKey, - hasCentralAPIKeyFallback, - ), - ) - } - - httpapi.Write(ctx, rw, http.StatusOK, resp) +func (*API) updateChatProvider(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) } -func (api *API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - apiKey = httpmw.APIKey(r) - ) - - providerID, ok := parseChatProviderID(rw, r) - if !ok { - return - } - - //nolint:gocritic // Non-admin users need to validate provider availability before storing their own key. - provider, err := api.Database.GetChatProviderByID(dbauthz.AsChatd(ctx), providerID) - if err != nil { - if httpapi.Is404Error(err) { - httpapi.ResourceNotFound(rw) - return - } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to get chat provider.", - Detail: err.Error(), - }) - return - } - if !provider.Enabled { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Provider is disabled.", - }) - return - } - if !provider.AllowUserApiKey { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Provider does not allow user API keys.", - }) - return - } - - var req codersdk.CreateUserChatProviderKeyRequest - if !httpapi.Read(ctx, rw, r, &req) { - return - } - - trimmedAPIKey := strings.TrimSpace(req.APIKey) - if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "API key too large.", - Detail: err.Error(), - }) - return - } - if trimmedAPIKey == "" { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "API key is required.", - }) - return - } - - if _, err := api.Database.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ - UserID: apiKey.UserID, - ChatProviderID: providerID, - APIKey: trimmedAPIKey, - ApiKeyKeyID: sql.NullString{}, - }); err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to save user chat provider key.", - Detail: err.Error(), - }) - return - } - - hasCentralAPIKeyFallback := provider.Enabled && - provider.AllowCentralApiKeyFallback && - api.hasEffectiveCentralProviderCredentials(ctx, provider, uuid.Nil) - httpapi.Write( - ctx, - rw, - http.StatusOK, - convertUserChatProviderConfig( - provider, - true, - hasCentralAPIKeyFallback, - ), - ) +func (*API) deleteChatProvider(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) } -func (api *API) deleteUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - apiKey = httpmw.APIKey(r) - ) - - providerID, ok := parseChatProviderID(rw, r) - if !ok { - return - } +func (*API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) +} - if err := api.Database.DeleteUserChatProviderKey(ctx, database.DeleteUserChatProviderKeyParams{ - UserID: apiKey.UserID, - ChatProviderID: providerID, - }); err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to delete user chat provider key.", - Detail: err.Error(), - }) - return - } +func (*API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) +} - rw.WriteHeader(http.StatusNoContent) +func (*API) deleteUserChatProviderKey(rw http.ResponseWriter, r *http.Request) { + writeLegacyChatProviderGone(rw, r) } func (api *API) listChatModelConfigs(rw http.ResponseWriter, r *http.Request) { @@ -7569,37 +7010,28 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { return } - provider := "" - aiProviderID := uuid.NullUUID{} - if req.AIProviderID != nil { - aiProvider, err := api.Database.GetAIProviderByID(ctx, *req.AIProviderID) - if err != nil { - if httpapi.Is404Error(err) { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider is not configured."}) - return - } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ - Message: "Failed to get AI provider.", - Detail: err.Error(), - }) - return - } - if !aiProvider.Enabled { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider is disabled."}) - return - } - provider = string(aiProvider.Type) - aiProviderID = uuid.NullUUID{UUID: aiProvider.ID, Valid: true} - } else { - provider = normalizeChatProvider(req.Provider) - if provider == "" { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider.", - Detail: chatProviderValidationDetail(), - }) + if req.AIProviderID == nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider ID is required."}) + return + } + aiProvider, err := api.Database.GetAIProviderByID(ctx, *req.AIProviderID) + if err != nil { + if httpapi.Is404Error(err) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider is not configured."}) return } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to get AI provider.", + Detail: err.Error(), + }) + return } + if !aiProvider.Enabled { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider is disabled."}) + return + } + provider := string(aiProvider.Type) + aiProviderID := uuid.NullUUID{UUID: aiProvider.ID, Valid: true} model := strings.TrimSpace(req.Model) if model == "" { @@ -7797,16 +7229,6 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { } provider = string(aiProvider.Type) aiProviderID = uuid.NullUUID{UUID: aiProvider.ID, Valid: true} - } else if strings.TrimSpace(req.Provider) != "" { - provider = normalizeChatProvider(req.Provider) - if provider == "" { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid provider.", - Detail: chatProviderValidationDetail(), - }) - return - } - aiProviderID = uuid.NullUUID{} } model := existing.Model @@ -8123,18 +7545,6 @@ func parseChatUsageLimitUserID(rw http.ResponseWriter, r *http.Request) (uuid.UU return userID, true } -func parseChatProviderID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) { - providerID, err := uuid.Parse(chi.URLParam(r, "providerConfig")) - if err != nil { - httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{ - Message: "Invalid chat provider ID.", - Detail: err.Error(), - }) - return uuid.Nil, false - } - return providerID, true -} - func parseChatModelConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, bool) { modelConfigID, err := uuid.Parse(chi.URLParam(r, "modelConfig")) if err != nil { @@ -8147,51 +7557,6 @@ func parseChatModelConfigID(rw http.ResponseWriter, r *http.Request) (uuid.UUID, return modelConfigID, true } -func convertChatProviderConfig( - provider database.ChatProvider, - hasAPIKey bool, - source codersdk.ChatProviderConfigSource, -) codersdk.ChatProviderConfig { - displayName := strings.TrimSpace(provider.DisplayName) - if displayName == "" { - displayName = chatprovider.ProviderDisplayName(provider.Provider) - } - - return codersdk.ChatProviderConfig{ - ID: provider.ID, - Provider: provider.Provider, - DisplayName: displayName, - Enabled: provider.Enabled, - HasAPIKey: hasAPIKey, - CentralAPIKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserAPIKey: provider.AllowUserApiKey, - AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback, - BaseURL: strings.TrimSpace(provider.BaseUrl), - Source: source, - CreatedAt: provider.CreatedAt, - UpdatedAt: provider.UpdatedAt, - } -} - -func convertUserChatProviderConfig( - provider database.ChatProvider, - hasUserAPIKey bool, - hasCentralAPIKeyFallback bool, -) codersdk.UserChatProviderConfig { - displayName := strings.TrimSpace(provider.DisplayName) - if displayName == "" { - displayName = chatprovider.ProviderDisplayName(provider.Provider) - } - - return codersdk.UserChatProviderConfig{ - ProviderID: provider.ID, - Provider: provider.Provider, - DisplayName: displayName, - HasUserAPIKey: hasUserAPIKey, - HasCentralAPIKeyFallback: hasCentralAPIKeyFallback, - } -} - func convertChatModelConfig(config database.ChatModelConfig) codersdk.ChatModelConfig { var aiProviderID *uuid.UUID if config.AiProviderID.Valid { @@ -8325,103 +7690,11 @@ func isZeroChatModelProviderOptions(options *codersdk.ChatModelProviderOptions) options.Vercel == nil } -func normalizeChatProvider(provider string) string { - return chatprovider.NormalizeProvider(provider) -} - -func normalizeChatProviderBaseURL(raw string) (string, error) { - trimmed := strings.TrimSpace(raw) - if trimmed == "" { - return "", nil - } - - parsed, err := url.Parse(trimmed) - if err != nil { - return "", err - } - if parsed.Scheme == "" || parsed.Host == "" { - return "", xerrors.New("Base URL must be an absolute URL with scheme and host.") - } - if parsed.Scheme != "http" && parsed.Scheme != "https" { - return "", xerrors.New("Base URL scheme must be http or https.") - } - return parsed.String(), nil -} - -func chatProviderValidationDetail() string { - return "Provider must be one of: " + strings.Join(chatprovider.SupportedProviders(), ", ") + "." -} - var ( errChatModelConfigNotFound = xerrors.New("chat model config not found") errChatProviderNotConfigured = xerrors.New("chat provider is not configured") ) -// requireChatProviderForModelConfig takes a FOR UPDATE lock on the provider -// row to serialize model-config writes with deleteChatProvider. Do not swap -// this call for the non-locking provider lookup. -func requireChatProviderForModelConfig( - ctx context.Context, - tx database.Store, - provider string, -) error { - _, err := tx.GetChatProviderByProviderForUpdate(ctx, provider) - switch { - case err == nil: - return nil - case xerrors.Is(err, sql.ErrNoRows): - return errChatProviderNotConfigured - default: - return xerrors.Errorf("get chat provider %q: %w", provider, err) - } -} - -const maxChatProviderAPIKeySize = 10240 // 10 KB - -func validateChatProviderAPIKeySize(apiKey string) error { - if len(apiKey) > maxChatProviderAPIKeySize { - return xerrors.Errorf("API key exceeds maximum size of %d bytes", maxChatProviderAPIKeySize) - } - return nil -} - -//nolint:revive // This helper validates the explicit credential policy tuple. -func validateChatProviderCredentialPolicy( - centralEnabled, allowUserKey, allowFallback bool, -) error { - if !centralEnabled && !allowUserKey { - return xerrors.New( - "At least one credential source must be enabled: central API key or user API key.", - ) - } - if allowFallback && !centralEnabled { - return xerrors.New( - "Central API key fallback requires central API key to be enabled.", - ) - } - if allowFallback && !allowUserKey { - return xerrors.New( - "Central API key fallback requires user API key to be enabled.", - ) - } - return nil -} - -//nolint:revive // This helper validates central-key requirements. -func validateChatProviderCentralAPIKey( - provider string, - centralEnabled bool, - hasCentralAPIKey bool, -) error { - if !centralEnabled || hasCentralAPIKey { - return nil - } - if chatprovider.ProviderAllowsAmbientCredentials(provider) { - return nil - } - return xerrors.New("API key is required when central API key is enabled.") -} - // ChatProviderAPIKeysFromDeploymentValues returns deployment-backed chat // provider API keys. func ChatProviderAPIKeysFromDeploymentValues( @@ -8433,77 +7706,6 @@ func ChatProviderAPIKeysFromDeploymentValues( return chatprovider.ProviderAPIKeys{} } -func (api *API) hasEffectiveProviderAPIKey(ctx context.Context, provider database.ChatProvider) bool { - return api.hasEffectiveCentralProviderAPIKey(ctx, provider, uuid.Nil) -} - -func (api *API) hasEffectiveCentralProviderCredentials( - ctx context.Context, - provider database.ChatProvider, - excludeProviderID uuid.UUID, -) bool { - if api.hasEffectiveCentralProviderAPIKey(ctx, provider, excludeProviderID) { - return true - } - return provider.CentralApiKeyEnabled && - chatprovider.ProviderAllowsAmbientCredentials(provider.Provider) -} - -func (api *API) hasEffectiveCentralProviderAPIKey( - ctx context.Context, - provider database.ChatProvider, - excludeProviderID uuid.UUID, -) bool { - if !provider.CentralApiKeyEnabled { - return false - } - if strings.TrimSpace(provider.APIKey) != "" { - return true - } - deploymentKeys := ChatProviderAPIKeysFromDeploymentValues(api.DeploymentValues) - if deploymentKeys.APIKey(provider.Provider) != "" { - return true - } - if api.chatDaemon == nil { - return false - } - //nolint:gocritic // System context required to read enabled chat providers. - systemCtx := dbauthz.AsSystemRestricted(ctx) - - enabledProviders, err := api.Database.GetEnabledChatProviders( - systemCtx, - ) - if err != nil { - api.Logger.Warn(ctx, "failed to resolve provider API keys", - slog.F("provider", provider.Provider), - slog.Error(err), - ) - return false - } - - enabledConfiguredProviders := make( - []chatprovider.ConfiguredProvider, 0, len(enabledProviders), - ) - for _, configured := range enabledProviders { - if excludeProviderID != uuid.Nil && configured.ID == excludeProviderID { - continue - } - enabledConfiguredProviders = append( - enabledConfiguredProviders, chatprovider.ConfiguredProvider{ - Provider: configured.Provider, - APIKey: configured.APIKey, - BaseURL: configured.BaseUrl, - }, - ) - } - - effectiveKeys := chatprovider.MergeProviderAPIKeys( - deploymentKeys, - enabledConfiguredProviders, - ) - return effectiveKeys.APIKey(provider.Provider) != "" -} - // @Summary Get PR insights // @ID get-pr-insights // @Security CoderSessionToken diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 14860f53eb853..26e3bb6eb675c 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -1616,16 +1616,12 @@ func TestListChatModels(t *testing.T) { client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) - provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "anthropic", - CentralAPIKeyEnabled: ptr.Ref(false), - AllowUserAPIKey: ptr.Ref(true), - }) - require.NoError(t, err) + provider := createAIProviderForTest(t, client, "anthropic", "") contextLimit := int64(4096) - _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "anthropic", + AIProviderID: &provider.ID, Model: "claude-sonnet", ContextLimit: &contextLimit, }) @@ -1645,7 +1641,7 @@ func TestListChatModels(t *testing.T) { require.False(t, anthropicProvider.Available) require.Equal(t, codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, anthropicProvider.UnavailableReason) - _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + _, err = client.UpsertUserAIProviderKey(ctx, provider.ID, codersdk.CreateUserAIProviderKeyRequest{ APIKey: "user-api-key", }) require.NoError(t, err) @@ -1671,18 +1667,12 @@ func TestListChatModels(t *testing.T) { client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) - provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "google", - APIKey: "central-api-key", - CentralAPIKeyEnabled: ptr.Ref(true), - AllowUserAPIKey: ptr.Ref(true), - AllowCentralAPIKeyFallback: ptr.Ref(true), - }) - require.NoError(t, err) + provider := createAIProviderForTest(t, client, "google", "provider-api-key") contextLimit := int64(4096) - _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "google", + AIProviderID: &provider.ID, Model: "gemini-1.5-pro", ContextLimit: &contextLimit, }) @@ -1701,7 +1691,7 @@ func TestListChatModels(t *testing.T) { require.NotNil(t, googleProvider) require.True(t, googleProvider.Available) - _, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{ + _, err = client.UpsertUserAIProviderKey(ctx, provider.ID, codersdk.CreateUserAIProviderKeyRequest{ APIKey: "user-api-key", }) require.NoError(t, err) @@ -1729,15 +1719,12 @@ func TestListChatModels(t *testing.T) { client := newChatClientWithDeploymentValues(t, values) _ = coderdtest.CreateFirstUser(t, client.Client) - provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-key", - }) - require.NoError(t, err) + provider := createAIProviderForTest(t, client, "openai", "test-key") contextLimit := int64(4096) - _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", + AIProviderID: &provider.ID, Model: "gpt-4o-mini", ContextLimit: &contextLimit, }) @@ -1751,7 +1738,7 @@ func TestListChatModels(t *testing.T) { require.Equal(t, "gpt-4o-mini", models.Providers[0].Models[0].Model) enabled := false - _, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ + _, err = client.UpdateAIProvider(ctx, provider.ID, codersdk.UpdateAIProviderRequest{ Enabled: &enabled, }) require.NoError(t, err) @@ -2546,6 +2533,7 @@ func TestUserAIProviderKeys(t *testing.T) { func TestListChatProviders(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") t.Run("Success", func(t *testing.T) { t.Parallel() @@ -2615,6 +2603,7 @@ func TestListChatProviders(t *testing.T) { func TestCreateChatProvider(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") t.Run("Success", func(t *testing.T) { t.Parallel() @@ -2915,6 +2904,7 @@ func TestCreateChatProvider(t *testing.T) { func TestUpdateChatProvider(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") t.Run("Success", func(t *testing.T) { t.Parallel() @@ -3207,276 +3197,7 @@ func TestUpdateChatProvider(t *testing.T) { func TestDeleteChatProvider(t *testing.T) { t.Parallel() - - t.Run("Success", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client := newChatClient(t) - _ = coderdtest.CreateFirstUser(t, client.Client) - - provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) - - err = client.DeleteChatProvider(ctx, provider.ID) - require.NoError(t, err) - - providers, err := client.ListChatProviders(ctx) - require.NoError(t, err) - for _, listed := range providers { - require.NotEqual(t, provider.ID, listed.ID) - } - }) - - t.Run("SuccessWithHistoricalChats", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client, db := newChatClientWithDatabase(t) - firstUser := coderdtest.CreateFirstUser(t, client.Client) - - providerToDelete, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "delete-api-key", - AllowUserAPIKey: ptr.Ref(true), - }) - require.NoError(t, err) - - deleteContextLimit := int64(4096) - deleteIsDefault := true - configToDelete, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: providerToDelete.Provider, - Model: "gpt-4o-delete-provider", - ContextLimit: &deleteContextLimit, - IsDefault: &deleteIsDefault, - }) - require.NoError(t, err) - - keepProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "anthropic", - APIKey: "keep-api-key", - }) - require.NoError(t, err) - - keepContextLimit := int64(8192) - keepConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: keepProvider.Provider, - Model: "claude-keep-provider", - ContextLimit: &keepContextLimit, - }) - require.NoError(t, err) - - chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ - OrganizationID: firstUser.OrganizationID, - ModelConfigID: ptr.Ref(configToDelete.ID), - Content: []codersdk.ChatInputPart{{ - Type: codersdk.ChatInputPartTypeText, - Text: "provider delete history " + t.Name(), - }}, - }) - require.NoError(t, err) - require.Equal(t, configToDelete.ID, chat.LastModelConfigID) - - insertAssistantCostMessage(t, db, chat.ID, configToDelete.ID, 500) - - _, err = client.UpsertUserChatProviderKey(ctx, providerToDelete.ID, codersdk.CreateUserChatProviderKeyRequest{ - APIKey: "user-delete-key", - }) - require.NoError(t, err) - - userKeys, err := db.GetUserChatProviderKeys(dbauthz.AsSystemRestricted(ctx), firstUser.UserID) - require.NoError(t, err) - require.Len(t, userKeys, 1) - require.Equal(t, providerToDelete.ID, userKeys[0].ChatProviderID) - - err = client.DeleteChatProvider(ctx, providerToDelete.ID) - require.NoError(t, err) - - _, err = db.GetChatProviderByID(dbauthz.AsSystemRestricted(ctx), providerToDelete.ID) - require.ErrorIs(t, err, sql.ErrNoRows) - - providers, err := client.ListChatProviders(ctx) - require.NoError(t, err) - foundKeepProvider := false - for _, listed := range providers { - require.NotEqual(t, providerToDelete.ID, listed.ID) - if listed.ID == keepProvider.ID { - foundKeepProvider = true - } - } - require.True(t, foundKeepProvider) - - configs, err := client.ListChatModelConfigs(ctx) - require.NoError(t, err) - foundDeletedConfig := false - foundKeepConfig := false - for _, config := range configs { - if config.ID == configToDelete.ID { - foundDeletedConfig = true - } - if config.ID == keepConfig.ID { - foundKeepConfig = true - require.True(t, config.IsDefault) - } - } - require.False(t, foundDeletedConfig) - require.True(t, foundKeepConfig) - - defaultConfig, err := db.GetDefaultChatModelConfig(dbauthz.AsSystemRestricted(ctx)) - require.NoError(t, err) - require.Equal(t, keepConfig.ID, defaultConfig.ID) - - _, err = db.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), configToDelete.ID) - require.ErrorIs(t, err, sql.ErrNoRows) - - gotChat, err := client.GetChat(ctx, chat.ID) - require.NoError(t, err) - require.Equal(t, chat.ID, gotChat.ID) - require.Equal(t, configToDelete.ID, gotChat.LastModelConfigID) - - messages, err := client.GetChatMessages(ctx, chat.ID, nil) - require.NoError(t, err) - foundHistoricalMessage := false - for _, message := range messages.Messages { - if message.ModelConfigID != nil && *message.ModelConfigID == configToDelete.ID { - foundHistoricalMessage = true - break - } - } - require.True(t, foundHistoricalMessage) - - userKeys, err = db.GetUserChatProviderKeys(dbauthz.AsSystemRestricted(ctx), firstUser.UserID) - require.NoError(t, err) - require.Empty(t, userKeys) - }) - - t.Run("SuccessWithHistoricalChatsAndNoReplacementConfig", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client, db := newChatClientWithDatabase(t) - firstUser := coderdtest.CreateFirstUser(t, client.Client) - - provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "only-provider-api-key", - }) - require.NoError(t, err) - - contextLimit := int64(4096) - isDefault := true - config, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, - Model: "gpt-4o-only-provider", - ContextLimit: &contextLimit, - IsDefault: &isDefault, - }) - require.NoError(t, err) - - chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ - OrganizationID: firstUser.OrganizationID, - ModelConfigID: ptr.Ref(config.ID), - Content: []codersdk.ChatInputPart{{ - Type: codersdk.ChatInputPartTypeText, - Text: "only provider delete history " + t.Name(), - }}, - }) - require.NoError(t, err) - require.Equal(t, config.ID, chat.LastModelConfigID) - - insertAssistantCostMessage(t, db, chat.ID, config.ID, 250) - - err = client.DeleteChatProvider(ctx, provider.ID) - require.NoError(t, err) - - providers, err := client.ListChatProviders(ctx) - require.NoError(t, err) - for _, listed := range providers { - require.NotEqual(t, provider.ID, listed.ID) - } - - _, err = db.GetChatProviderByID(dbauthz.AsSystemRestricted(ctx), provider.ID) - require.ErrorIs(t, err, sql.ErrNoRows) - - _, err = db.GetChatModelConfigByID(dbauthz.AsSystemRestricted(ctx), config.ID) - require.ErrorIs(t, err, sql.ErrNoRows) - - _, err = db.GetDefaultChatModelConfig(dbauthz.AsSystemRestricted(ctx)) - require.ErrorIs(t, err, sql.ErrNoRows) - - configs, err := client.ListChatModelConfigs(ctx) - require.NoError(t, err) - require.Empty(t, configs) - - gotChat, err := client.GetChat(ctx, chat.ID) - require.NoError(t, err) - require.Equal(t, config.ID, gotChat.LastModelConfigID) - - messages, err := client.GetChatMessages(ctx, chat.ID, nil) - require.NoError(t, err) - foundHistoricalMessage := false - for _, message := range messages.Messages { - if message.ModelConfigID != nil && *message.ModelConfigID == config.ID { - foundHistoricalMessage = true - break - } - } - require.True(t, foundHistoricalMessage) - }) - - t.Run("NotFound", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client := newChatClient(t) - _ = coderdtest.CreateFirstUser(t, client.Client) - - err := client.DeleteChatProvider(ctx, uuid.New()) - requireSDKError(t, err, http.StatusNotFound) - }) - - t.Run("InvalidProviderID", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - client := newChatClient(t) - _ = coderdtest.CreateFirstUser(t, client.Client) - - res, err := client.Request( - ctx, - http.MethodDelete, - "/api/experimental/chats/providers/not-a-uuid", - nil, - ) - require.NoError(t, err) - defer res.Body.Close() - - err = codersdk.ReadBodyAsError(res) - sdkErr := requireSDKError(t, err, http.StatusBadRequest) - require.Equal(t, "Invalid chat provider ID.", sdkErr.Message) - }) - - t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - adminClient := newChatClient(t) - firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) - memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) - memberClient := codersdk.NewExperimentalClient(memberClientRaw) - - provider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) - - err = memberClient.DeleteChatProvider(ctx, provider.ID) - requireSDKError(t, err, http.StatusForbidden) - }) + t.Skip("legacy chat provider API removed in favor of AI provider API") } func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) { @@ -3504,6 +3225,7 @@ func TestChatProviderAPIKeysFromDeploymentValues(t *testing.T) { func TestUserChatProviderConfigs(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") requireUserProviderConfig := func(t *testing.T, configs []codersdk.UserChatProviderConfig, provider string) codersdk.UserChatProviderConfig { t.Helper() @@ -3899,6 +3621,7 @@ func TestUserChatProviderConfigs(t *testing.T) { func TestUpsertUserChatProviderKey(t *testing.T) { t.Parallel() + t.Skip("legacy chat provider API removed in favor of AI provider API") t.Run("RejectsTooLargeAPIKey", func(t *testing.T) { t.Parallel() @@ -3978,16 +3701,13 @@ func TestListChatModelConfigs(t *testing.T) { client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) - _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) + aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key") contextLimit := int64(4096) enabled := false disabledConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", + AIProviderID: &aiProvider.ID, Model: "gpt-4o-disabled", DisplayName: "GPT-4o Disabled", Enabled: &enabled, @@ -4024,6 +3744,7 @@ func TestListChatModelConfigs(t *testing.T) { enabled := false _, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: enabledConfig.Provider, + AIProviderID: enabledConfig.AIProviderID, Model: "gpt-4o-disabled", DisplayName: "GPT-4o Disabled", Enabled: &enabled, @@ -4045,15 +3766,12 @@ func TestListChatModelConfigs(t *testing.T) { client, db := newChatClientWithDatabase(t) firstUser := coderdtest.CreateFirstUser(t, client.Client) - _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) + aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key") legacyOptions := json.RawMessage(`{"input_price_per_million_tokens":0.15,"output_price_per_million_tokens":0.6,"cache_read_price_per_million_tokens":0.03,"cache_write_price_per_million_tokens":0.3}`) storedConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ Provider: "openai", + AiProviderID: uuid.NullUUID{UUID: aiProvider.ID, Valid: true}, Model: "gpt-4o-mini-legacy", DisplayName: "GPT-4o Mini Legacy", CreatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true}, @@ -4114,11 +3832,7 @@ func TestCreateChatModelConfig(t *testing.T) { client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) - _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) + aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key") contextLimit := int64(4096) isDefault := true @@ -4132,6 +3846,7 @@ func TestCreateChatModelConfig(t *testing.T) { } modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", + AIProviderID: &aiProvider.ID, Model: "gpt-4o-mini", ContextLimit: &contextLimit, IsDefault: &isDefault, @@ -4158,15 +3873,12 @@ func TestCreateChatModelConfig(t *testing.T) { client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) - _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) + aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key") contextLimit := int64(4096) - _, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", + AIProviderID: &aiProvider.ID, Model: "gpt-4o-mini", ContextLimit: &contextLimit, ModelConfig: &codersdk.ChatModelCallConfig{ @@ -4190,10 +3902,12 @@ func TestCreateChatModelConfig(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) client := newChatClient(t) _ = coderdtest.CreateFirstUser(t, client.Client) + aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key") _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: "openai", - Model: "gpt-4o-mini", + Provider: "openai", + AIProviderID: &aiProvider.ID, + Model: "gpt-4o-mini", }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) require.Equal(t, "Context limit is required.", sdkErr.Message) @@ -4207,13 +3921,15 @@ func TestCreateChatModelConfig(t *testing.T) { _ = coderdtest.CreateFirstUser(t, client.Client) contextLimit := int64(4096) + missingProviderID := uuid.New() _, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", + AIProviderID: &missingProviderID, Model: "gpt-4o-mini", ContextLimit: &contextLimit, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) - require.Equal(t, "Chat provider is not configured.", sdkErr.Message) + require.Equal(t, "AI provider is not configured.", sdkErr.Message) }) t.Run("WithAIProviderID", func(t *testing.T) { @@ -4291,15 +4007,12 @@ func TestCreateChatModelConfig(t *testing.T) { memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) - _, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) + aiProvider := createAIProviderForTest(t, adminClient, "openai", "test-api-key") contextLimit := int64(4096) - _, err = memberClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + _, err := memberClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", + AIProviderID: &aiProvider.ID, Model: "gpt-4o-mini", ContextLimit: &contextLimit, }) @@ -4390,16 +4103,13 @@ func TestUpdateChatModelConfig(t *testing.T) { memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) memberClient := codersdk.NewExperimentalClient(memberClientRaw) - _, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - APIKey: "test-api-key", - }) - require.NoError(t, err) + aiProvider := createAIProviderForTest(t, adminClient, "openai", "test-api-key") contextLimit := int64(4096) enabled := false modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "openai", + AIProviderID: &aiProvider.ID, Model: "gpt-4o-reenable", DisplayName: "GPT-4o Re-enable", Enabled: &enabled, @@ -4533,11 +4243,12 @@ func TestUpdateChatModelConfig(t *testing.T) { _ = coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) + missingProviderID := uuid.New() _, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ - Provider: "anthropic", + AIProviderID: &missingProviderID, }) sdkErr := requireSDKError(t, err, http.StatusBadRequest) - require.Equal(t, "Chat provider is not configured.", sdkErr.Message) + require.Equal(t, "AI provider is not configured.", sdkErr.Message) }) t.Run("NotFoundWhenTargetRowDisappearsInTx", func(t *testing.T) { @@ -4577,16 +4288,13 @@ func TestUpdateChatModelConfig(t *testing.T) { _ = coderdtest.CreateFirstUser(t, client.Client) defaultConfig := createChatModelConfig(t, client) - _, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "anthropic", - APIKey: "candidate-api-key", - }) - require.NoError(t, err) + aiProvider := createAIProviderForTest(t, client, "anthropic", "candidate-api-key") contextLimit := int64(4096) isDefault := false candidateConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: "anthropic", + AIProviderID: &aiProvider.ID, Model: "claude-3-5-sonnet", ContextLimit: &contextLimit, IsDefault: &isDefault, @@ -10908,6 +10616,31 @@ func TestWatchChatGitAuthz(t *testing.T) { require.Equal(t, http.StatusForbidden, res.StatusCode) } +func createAIProviderForTest( + t testing.TB, + client *codersdk.ExperimentalClient, + provider string, + apiKey string, +) codersdk.AIProvider { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitLong) + enabled := true + aiProvider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderType(provider), + Name: "test-" + provider + "-" + uuid.NewString(), + Enabled: &enabled, + }) + require.NoError(t, err) + if apiKey != "" { + _, err = client.CreateAIProviderKey(ctx, aiProvider.ID, codersdk.CreateAIProviderKeyRequest{ + APIKey: apiKey, + }) + require.NoError(t, err) + } + return aiProvider +} + func createChatModelConfig(t testing.TB, client *codersdk.ExperimentalClient) codersdk.ChatModelConfig { t.Helper() return coderdtest.CreateOpenAICompatChatModelConfig(t, client, "") @@ -10942,10 +10675,12 @@ func createAdditionalChatModelConfig( t.Helper() ctx := testutil.Context(t, testutil.WaitLong) + aiProvider := createAIProviderForTest(t, client, provider, "test-api-key") contextLimit := int64(4096) isDefault := false modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ Provider: provider, + AIProviderID: &aiProvider.ID, Model: model, ContextLimit: &contextLimit, IsDefault: &isDefault, @@ -10976,32 +10711,27 @@ func enableUserChatProviderKey( adminClient *codersdk.ExperimentalClient, userClient *codersdk.ExperimentalClient, providerName string, -) codersdk.ChatProviderConfig { +) codersdk.AIProvider { t.Helper() ctx := testutil.Context(t, testutil.WaitLong) - providers, err := adminClient.ListChatProviders(ctx) + providers, err := adminClient.ListAIProviders(ctx) require.NoError(t, err) - var provider codersdk.ChatProviderConfig + var provider codersdk.AIProvider for _, candidate := range providers { - if candidate.Provider == providerName && candidate.Source == codersdk.ChatProviderConfigSourceDatabase { + if candidate.Type == codersdk.AIProviderType(providerName) { provider = candidate break } } require.NotEqual(t, uuid.Nil, provider.ID) - updated, err := adminClient.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{ - AllowUserAPIKey: ptr.Ref(true), - }) - require.NoError(t, err) - - _, err = userClient.UpsertUserChatProviderKey(ctx, updated.ID, codersdk.CreateUserChatProviderKeyRequest{ + _, err = userClient.UpsertUserAIProviderKey(ctx, provider.ID, codersdk.CreateUserAIProviderKeyRequest{ APIKey: "test-user-api-key-" + uuid.NewString(), }) require.NoError(t, err) - return updated + return provider } //nolint:tparallel,paralleltest // Subtests share a single coderdtest instance. @@ -11815,13 +11545,20 @@ func TestUserChatPersonalModelOverrides(t *testing.T) { defaultModelConfig := createChatModelConfig(t, adminClient) provider := enableUserChatProviderKey(t, adminClient, memberClient, defaultModelConfig.Provider) - modelConfig := createAdditionalChatModelConfig( - t, - adminClient, - defaultModelConfig.Provider, - "gpt-4o-personal-"+uuid.NewString(), - ) - err := adminClient.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextGeneral, codersdk.UpdateChatModelOverrideRequest{ + modelProvider := createAIProviderForTest(t, adminClient, "anthropic", "") + _, err := memberClient.UpsertUserAIProviderKey(ctx, modelProvider.ID, codersdk.CreateUserAIProviderKeyRequest{ + APIKey: "test-user-api-key-" + uuid.NewString(), + }) + require.NoError(t, err) + contextLimit := int64(4096) + modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "anthropic", + AIProviderID: &modelProvider.ID, + Model: "claude-personal-" + uuid.NewString(), + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + err = adminClient.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextGeneral, codersdk.UpdateChatModelOverrideRequest{ ModelConfigID: modelConfig.ID.String(), }) require.NoError(t, err) @@ -11836,19 +11573,20 @@ func TestUserChatPersonalModelOverrides(t *testing.T) { defaultModelConfig.Provider, "gpt-4o-personal-disabled-"+uuid.NewString(), ) - disabledProvider, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "anthropic", - Enabled: ptr.Ref(false), - CentralAPIKeyEnabled: ptr.Ref(false), - AllowUserAPIKey: ptr.Ref(true), + disabledProvider := createAIProviderForTest(t, adminClient, "google", "test-api-key") + contextLimit = int64(4096) + disabledProviderModelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "google", + AIProviderID: &disabledProvider.ID, + Model: "gemini-personal-disabled-provider-" + uuid.NewString(), + ContextLimit: &contextLimit, + }) + require.NoError(t, err) + enabled := false + disabledProvider, err = adminClient.UpdateAIProvider(ctx, disabledProvider.ID, codersdk.UpdateAIProviderRequest{ + Enabled: &enabled, }) require.NoError(t, err) - disabledProviderModelConfig := createAdditionalChatModelConfig( - t, - adminClient, - "anthropic", - "claude-personal-disabled-provider-"+uuid.NewString(), - ) require.NotEqual(t, uuid.Nil, provider.ID) require.NotEqual(t, uuid.Nil, disabledProvider.ID) @@ -12176,12 +11914,19 @@ func TestCreateChatPersonalModelOverrideRoot(t *testing.T) { firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) defaultModel := createChatModelConfig(t, adminClient) _ = enableUserChatProviderKey(t, adminClient, adminClient, defaultModel.Provider) - overrideModel := createAdditionalChatModelConfig( - t, - adminClient, - defaultModel.Provider, - "gpt-4o-root-personal-"+uuid.NewString(), - ) + overrideProvider := createAIProviderForTest(t, adminClient, "anthropic", "") + _, err := adminClient.UpsertUserAIProviderKey(ctx, overrideProvider.ID, codersdk.CreateUserAIProviderKeyRequest{ + APIKey: "test-user-api-key-" + uuid.NewString(), + }) + require.NoError(t, err) + contextLimit := int64(4096) + overrideModel, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: "anthropic", + AIProviderID: &overrideProvider.ID, + Model: "claude-root-personal-" + uuid.NewString(), + ContextLimit: &contextLimit, + }) + require.NoError(t, err) disabledModel := createDisabledChatModelConfig( t, adminClient, @@ -12226,7 +11971,7 @@ func TestCreateChatPersonalModelOverrideRoot(t *testing.T) { require.NoError(t, err) } - err := adminClient.UpdateChatPersonalModelOverridesAdminSettings(ctx, codersdk.UpdateChatPersonalModelOverridesAdminSettingsRequest{ + err = adminClient.UpdateChatPersonalModelOverridesAdminSettings(ctx, codersdk.UpdateChatPersonalModelOverridesAdminSettingsRequest{ AllowUsers: true, }) require.NoError(t, err) diff --git a/coderd/workspaceagents_chat_context_internal_test.go b/coderd/workspaceagents_chat_context_internal_test.go index 5a2c8e25be19a..8439e2139ea96 100644 --- a/coderd/workspaceagents_chat_context_internal_test.go +++ b/coderd/workspaceagents_chat_context_internal_test.go @@ -2,6 +2,7 @@ package coderd import ( "context" + "database/sql" "encoding/json" "testing" "time" @@ -96,17 +97,21 @@ func insertAgentChatTestModelConfig( createdBy := uuid.NullUUID{UUID: userID, Valid: true} - _ = dbgen.ChatProvider(t, db, database.ChatProvider{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-api-key", - CreatedBy: createdBy, + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "test-openai", + DisplayName: sql.NullString{String: "OpenAI", Valid: true}, + }) + dbgen.AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, + APIKey: "test-api-key", }) return dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ - Provider: "openai", - CreatedBy: createdBy, - UpdatedBy: createdBy, - IsDefault: true, + Provider: "openai", + AiProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + CreatedBy: createdBy, + UpdatedBy: createdBy, + IsDefault: true, }) } diff --git a/coderd/x/chatd/advisor_internal_test.go b/coderd/x/chatd/advisor_internal_test.go index ad81e580b2422..48a937b3d8c08 100644 --- a/coderd/x/chatd/advisor_internal_test.go +++ b/coderd/x/chatd/advisor_internal_test.go @@ -24,8 +24,8 @@ import ( // advisorOverrideStubStore stubs only the database methods that // resolveAdvisorModelOverride exercises. The prod code calls -// GetEnabledChatModelConfigByID so the query joins chat_providers and -// filters both enabled flags atomically; tests simulate that by returning +// GetEnabledChatModelConfigByID so the query joins ai_providers and +// filters both enabled flags atomically. Tests simulate that by returning // configs the stub treats as enabled. type advisorOverrideStubStore struct { database.Store diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 6a0105453e9dd..fedf3aed895a9 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -333,13 +333,13 @@ func (p *Server) resolveAdvisorModelOverride( return fallbackModel, fallbackCallConfig } - // GetEnabledChatModelConfigByID joins on chat_providers.enabled = TRUE + // GetEnabledChatModelConfigByID joins on ai_providers.enabled = TRUE // and chat_model_configs.enabled = TRUE, so it returns sql.ErrNoRows // the moment an admin disables either the model config or its provider. // Using the cached ModelConfigByID here would keep resolving an override - // whose provider was just disabled, and an env or central fallback key - // would let ModelFromConfig succeed, silently routing advisor prompts - // to a provider the admin expects to be off. + // whose provider was just disabled, and a fallback key would let + // ModelFromConfig succeed, silently routing advisor prompts to a provider + // the admin expects to be off. overrideConfig, err := p.db.GetEnabledChatModelConfigByID( ctx, advisorCfg.ModelConfigID, @@ -8144,9 +8144,6 @@ func (p *Server) resolveUserProviderAPIKeys( ownerID uuid.UUID, selectedAIProviderID uuid.UUID, ) (chatprovider.ProviderAPIKeys, error) { - var configuredProviders []chatprovider.ConfiguredProvider - userKeys := []chatprovider.UserProviderKey{} - if selectedAIProviderID != uuid.Nil { provider, err := p.db.GetAIProviderByID(ctx, selectedAIProviderID) if err != nil { @@ -8154,48 +8151,38 @@ func (p *Server) resolveUserProviderAPIKeys( } return p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider) } + providers, err := p.configCache.EnabledProviders(ctx) if err != nil { return chatprovider.ProviderAPIKeys{}, xerrors.Errorf( - "get enabled chat providers: %w", + "get enabled AI providers: %w", err, ) } - configuredProviders = make( + configuredProviders := make( []chatprovider.ConfiguredProvider, 0, len(providers), ) for _, provider := range providers { - configuredProviders = append( - configuredProviders, chatprovider.ConfiguredProvider{ - ProviderID: provider.ID, - Provider: provider.Provider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - CentralAPIKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserAPIKey: provider.AllowUserApiKey, - AllowCentralAPIKeyFallback: provider.AllowCentralApiKeyFallback, - }, - ) - } - allowAnyUserAPIKey := false - for _, provider := range configuredProviders { - if provider.AllowUserAPIKey { - allowAnyUserAPIKey = true - break + configuredProvider, err := p.aiProviderConfig(ctx, provider) + if err != nil { + return chatprovider.ProviderAPIKeys{}, err } + configuredProviders = append(configuredProviders, configuredProvider) } - if allowAnyUserAPIKey { - userKeyRows, err := p.db.GetUserChatProviderKeys(ctx, ownerID) + + userKeys := []chatprovider.UserProviderKey{} + if p.allowBYOK { + userKeyRows, err := p.db.GetUserAIProviderKeysByUserID(ctx, ownerID) if err != nil { return chatprovider.ProviderAPIKeys{}, xerrors.Errorf( - "get user chat provider keys: %w", + "get user AI provider keys: %w", err, ) } userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows)) for _, userKey := range userKeyRows { userKeys = append(userKeys, chatprovider.UserProviderKey{ - ChatProviderID: userKey.ChatProviderID, + ChatProviderID: userKey.AiProviderID, APIKey: userKey.APIKey, }) } diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index af4cde4a78525..3d5643239792f 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -792,12 +792,13 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { } db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "openai", - CentralApiKeyEnabled: true, - APIKey: "test-key", - BaseUrl: serverURL, + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ + ID: uuid.New(), + Type: database.AiProviderTypeOpenai, + Enabled: true, + BaseUrl: serverURL, }}, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), gomock.Any()).Return([]database.AIProviderKey{{APIKey: "test-key"}}, nil).AnyTimes() db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows) db.EXPECT().GetChatMessagesByChatIDAscPaginated( gomock.Any(), @@ -956,12 +957,13 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t } db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "openai", - CentralApiKeyEnabled: true, - APIKey: "test-key", - BaseUrl: serverURL, + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ + ID: uuid.New(), + Type: database.AiProviderTypeOpenai, + Enabled: true, + BaseUrl: serverURL, }}, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), gomock.Any()).Return([]database.AIProviderKey{{APIKey: "test-key"}}, nil).AnyTimes() db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows) db.EXPECT().GetChatMessagesByChatIDAscPaginated( gomock.Any(), @@ -1105,11 +1107,12 @@ func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) { }, } - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "anthropic", - CentralApiKeyEnabled: true, - AllowCentralApiKeyFallback: true, + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ + ID: uuid.New(), + Type: database.AiProviderTypeAnthropic, + Enabled: true, }}, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) require.NoError(t, err) @@ -1146,10 +1149,12 @@ func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKe }, } - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ - Provider: "openai", - CentralApiKeyEnabled: true, + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ + ID: uuid.New(), + Type: database.AiProviderTypeOpenai, + Enabled: true, }}, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) require.NoError(t, err) @@ -3541,7 +3546,7 @@ func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) { db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return( database.ChatModelConfig{}, xerrors.New("no model configured"), ).AnyTimes() - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes() db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return( database.ChatUsageLimitConfig{}, sql.ErrNoRows, @@ -5315,7 +5320,7 @@ func TestAutoPromote_InsertFailureSkipsStatusUpdate(t *testing.T) { return database.ChatModelConfig{}, chatloop.ErrInterrupted }, ).AnyTimes() - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes() db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return( database.ChatUsageLimitConfig{}, sql.ErrNoRows, diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index ed90c601d5770..002e5b799c752 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -287,22 +287,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { ) }) - _, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai-compat", - APIKey: "test-api-key", - BaseURL: openAIURL, - }) - require.NoError(t, err) - - contextLimit := int64(4096) - isDefault := true - _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: "openai-compat", - Model: "gpt-4o-mini", - ContextLimit: &contextLimit, - IsDefault: &isDefault, - }) - require.NoError(t, err) + coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL) // Create a root chat whose first model call will spawn a subagent. chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{ @@ -482,22 +467,7 @@ func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) { ) }) - _, err = expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai-compat", - APIKey: "test-api-key", - BaseURL: openAIURL, - }) - require.NoError(t, err) - - contextLimit := int64(4096) - isDefault := true - _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: "openai-compat", - Model: "gpt-4o-mini", - ContextLimit: &contextLimit, - IsDefault: &isDefault, - }) - require.NoError(t, err) + coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL) chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, @@ -637,24 +607,9 @@ func TestExploreSubagentIsReadOnly(t *testing.T) { ) }) - _, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai-compat", - APIKey: "test-api-key", - BaseURL: openAIURL, - }) - require.NoError(t, err) - - contextLimit := int64(4096) - isDefault := true - _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: "openai-compat", - Model: "gpt-4o-mini", - ContextLimit: &contextLimit, - IsDefault: &isDefault, - }) - require.NoError(t, err) + coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL) - _, err = expClient.CreateChat(ctx, codersdk.CreateChatRequest{ + _, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, WorkspaceID: &workspace.ID, Content: []codersdk.ChatInputPart{ @@ -4947,22 +4902,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) { ) }) - _, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai-compat", - APIKey: "test-api-key", - BaseURL: openAIURL, - }) - require.NoError(t, err) - - contextLimit := int64(4096) - isDefault := true - _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: "openai-compat", - Model: "gpt-4o-mini", - ContextLimit: &contextLimit, - IsDefault: &isDefault, - }) - require.NoError(t, err) + coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL) chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{ OrganizationID: user.OrganizationID, @@ -5117,22 +5057,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) { ) }) - _, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai-compat", - APIKey: "test-api-key", - BaseURL: openAIURL, - }) - require.NoError(t, err) - - contextLimit := int64(4096) - isDefault := true - _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: "openai-compat", - Model: "gpt-4o-mini", - ContextLimit: &contextLimit, - IsDefault: &isDefault, - }) - require.NoError(t, err) + coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL) // Create a chat with the stopped workspace pre-associated. chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{ @@ -6189,21 +6114,25 @@ func setOpenAIProviderBaseURL( ) { t.Helper() - provider, err := db.GetChatProviderByProvider(ctx, "openai") - require.NoError(t, err) - - _, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - ID: provider.ID, - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: baseURL, - ApiKeyKeyID: provider.ApiKeyKeyID, - Enabled: provider.Enabled, - CentralApiKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserApiKey: provider.AllowUserApiKey, - AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback, - }) + providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true}) require.NoError(t, err) + for _, provider := range providers { + if provider.Type != database.AiProviderTypeOpenai { + continue + } + _, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ + ID: provider.ID, + Name: provider.Name, + DisplayName: provider.DisplayName, + Enabled: provider.Enabled, + BaseUrl: baseURL, + Settings: provider.Settings, + SettingsKeyID: provider.SettingsKeyID, + }) + require.NoError(t, err) + return + } + require.Fail(t, "openai provider not found") } func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) { @@ -7262,10 +7191,11 @@ func TestProcessChat_UserProviderKey_Success(t *testing.T) { true, false, ) - _, err := db.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ - UserID: user.ID, - ChatProviderID: provider.ID, - APIKey: userAPIKey, + _, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: user.ID, + AiProviderID: provider.ID, + APIKey: userAPIKey, }) require.NoError(t, err) @@ -8580,22 +8510,7 @@ func TestAgentContextFilesAndSkillsLoadedIntoChat(t *testing.T) { ) }) - _, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai-compat", - APIKey: "test-api-key", - BaseURL: openAIURL, - }) - require.NoError(t, err) - - contextLimit := int64(4096) - isDefault := true - _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: "openai-compat", - Model: "gpt-4o-mini", - ContextLimit: &contextLimit, - IsDefault: &isDefault, - }) - require.NoError(t, err) + coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL) workspaceID := workspace.ID chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{ diff --git a/coderd/x/chatd/configcache.go b/coderd/x/chatd/configcache.go index e23509df8b302..69470a9473a28 100644 --- a/coderd/x/chatd/configcache.go +++ b/coderd/x/chatd/configcache.go @@ -30,7 +30,7 @@ const ( ) type cachedProviders struct { - providers []database.ChatProvider + providers []database.AIProvider expiresAt time.Time } @@ -74,7 +74,7 @@ type chatConfigCache struct { // Providers (singleton). providers *cachedProviders providerGeneration uint64 - providerFetches singleflight.Group[string, []database.ChatProvider] + providerFetches singleflight.Group[string, []database.AIProvider] // Model configs (keyed by ID). modelTopologyEpoch uint64 @@ -131,7 +131,7 @@ func singleflightDoChan[K comparable, V any]( } } -func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.ChatProvider, error) { +func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.AIProvider, error) { if providers, ok := c.cachedProviders(); ok { return providers, nil } @@ -141,12 +141,12 @@ func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.Chat ctx, &c.providerFetches, fmt.Sprintf("%d:providers", generation), - func() ([]database.ChatProvider, error) { + func() ([]database.AIProvider, error) { if cached, ok := c.cachedProviders(); ok { return cached, nil } - fetched, err := c.db.GetEnabledChatProviders(c.ctx) + fetched, err := c.db.GetAIProviders(c.ctx, database.GetAIProvidersParams{}) if err != nil { return nil, err } @@ -161,7 +161,7 @@ func (c *chatConfigCache) EnabledProviders(ctx context.Context) ([]database.Chat return slices.Clone(providers), nil } -func (c *chatConfigCache) cachedProviders() ([]database.ChatProvider, bool) { +func (c *chatConfigCache) cachedProviders() ([]database.AIProvider, bool) { c.mu.RLock() entry := c.providers c.mu.RUnlock() @@ -188,7 +188,7 @@ func (c *chatConfigCache) providersGeneration() uint64 { return generation } -func (c *chatConfigCache) storeProviders(generation uint64, providers []database.ChatProvider) { +func (c *chatConfigCache) storeProviders(generation uint64, providers []database.AIProvider) { c.mu.Lock() defer c.mu.Unlock() diff --git a/coderd/x/chatd/configcache_test.go b/coderd/x/chatd/configcache_test.go index 8213cd5d9bb0c..0079892126016 100644 --- a/coderd/x/chatd/configcache_test.go +++ b/coderd/x/chatd/configcache_test.go @@ -22,7 +22,7 @@ import ( type stubChatConfigStore struct { database.Store - getEnabledChatProviders func(context.Context) ([]database.ChatProvider, error) + getEnabledChatProviders func(context.Context) ([]database.AIProvider, error) getChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error) getDefaultChatModelConfig func(context.Context) (database.ChatModelConfig, error) getUserChatCustomPrompt func(context.Context, uuid.UUID) (string, error) @@ -35,10 +35,10 @@ type stubChatConfigStore struct { advisorConfigCalls atomic.Int32 } -func (s *stubChatConfigStore) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { +func (s *stubChatConfigStore) GetAIProviders(ctx context.Context, _ database.GetAIProvidersParams) ([]database.AIProvider, error) { s.enabledProvidersCalls.Add(1) if s.getEnabledChatProviders == nil { - panic("unexpected GetEnabledChatProviders call") + panic("unexpected GetAIProviders call") } return s.getEnabledChatProviders(ctx) } @@ -80,9 +80,9 @@ func TestConfigCache_EnabledProviders_CacheHit(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) clock := quartz.NewMock(t) - providers := []database.ChatProvider{testChatProvider("provider-a")} + providers := []database.AIProvider{testChatProvider("provider-a")} store := &stubChatConfigStore{ - getEnabledChatProviders: func(context.Context) ([]database.ChatProvider, error) { + getEnabledChatProviders: func(context.Context) ([]database.AIProvider, error) { return providers, nil }, } @@ -104,9 +104,9 @@ func TestConfigCache_EnabledProviders_TTLExpiry(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) clock := quartz.NewMock(t) store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { call := store.enabledProvidersCalls.Load() - return []database.ChatProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil + return []database.AIProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil } cache := newChatConfigCache(ctx, store, clock) @@ -126,9 +126,9 @@ func TestConfigCache_EnabledProviders_Invalidation(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) clock := quartz.NewMock(t) store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { call := store.enabledProvidersCalls.Load() - return []database.ChatProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil + return []database.AIProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil } cache := newChatConfigCache(ctx, store, clock) @@ -398,12 +398,12 @@ func TestConfigCache_Singleflight(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) clock := quartz.NewMock(t) - providers := []database.ChatProvider{testChatProvider("provider-a")} + providers := []database.AIProvider{testChatProvider("provider-a")} fetchStarted := make(chan struct{}) releaseFetch := make(chan struct{}) var startedOnce sync.Once store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { startedOnce.Do(func() { close(fetchStarted) }) <-releaseFetch return providers, nil @@ -411,7 +411,7 @@ func TestConfigCache_Singleflight(t *testing.T) { cache := newChatConfigCache(ctx, store, clock) const callers = 8 - results := make([][]database.ChatProvider, callers) + results := make([][]database.AIProvider, callers) errs := make([]error, callers) var wg sync.WaitGroup start := make(chan struct{}) @@ -441,13 +441,13 @@ func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) clock := quartz.NewMock(t) - firstProviders := []database.ChatProvider{testChatProvider("provider-a")} - secondProviders := []database.ChatProvider{testChatProvider("provider-b")} + firstProviders := []database.AIProvider{testChatProvider("provider-a")} + secondProviders := []database.AIProvider{testChatProvider("provider-b")} fetchStarted := make(chan struct{}) releaseFetch := make(chan struct{}) var startedOnce sync.Once store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { call := store.enabledProvidersCalls.Load() if call == 1 { startedOnce.Do(func() { close(fetchStarted) }) @@ -458,7 +458,7 @@ func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) { } cache := newChatConfigCache(ctx, store, clock) - resultCh := make(chan []database.ChatProvider, 1) + resultCh := make(chan []database.AIProvider, 1) errCh := make(chan error, 1) go func() { providers, err := cache.EnabledProviders(ctx) @@ -494,14 +494,14 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing ctx := testutil.Context(t, testutil.WaitMedium) clock := quartz.NewMock(t) - staleProviders := []database.ChatProvider{testChatProvider("provider-stale")} - freshProviders := []database.ChatProvider{testChatProvider("provider-fresh")} + staleProviders := []database.AIProvider{testChatProvider("provider-stale")} + freshProviders := []database.AIProvider{testChatProvider("provider-fresh")} firstStarted := make(chan struct{}) secondStarted := make(chan struct{}) releaseFirst := make(chan struct{}) releaseSecond := make(chan struct{}) store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.ChatProvider, error) { + store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { switch call := store.enabledProvidersCalls.Load(); call { case 1: close(firstStarted) @@ -518,7 +518,7 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing cache := newChatConfigCache(ctx, store, clock) type result struct { - providers []database.ChatProvider + providers []database.AIProvider err error } @@ -670,11 +670,12 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightModelConfig(t *testi require.Equal(t, int32(2), store.modelConfigByIDCalls.Load()) } -func testChatProvider(name string) database.ChatProvider { - return database.ChatProvider{ +func testChatProvider(name string) database.AIProvider { + return database.AIProvider{ ID: uuid.New(), - Provider: name, - DisplayName: name, + Type: database.AIProviderType(name), + Name: name, + DisplayName: sql.NullString{String: name, Valid: true}, Enabled: true, CreatedAt: time.Unix(0, 0).UTC(), UpdatedAt: time.Unix(0, 0).UTC(), @@ -737,19 +738,19 @@ func TestConfigCache_CallerCancellation(t *testing.T) { name: "EnabledProviders", setupBlocked: func(store *stubChatConfigStore, started, release chan struct{}) { var once sync.Once - store.getEnabledChatProviders = func(ctx context.Context) ([]database.ChatProvider, error) { + store.getEnabledChatProviders = func(ctx context.Context) ([]database.AIProvider, error) { once.Do(func() { close(started) }) select { case <-ctx.Done(): return nil, ctx.Err() case <-release: - return []database.ChatProvider{testChatProvider("p")}, nil + return []database.AIProvider{testChatProvider("p")}, nil } } }, setupCtxSensitive: func(store *stubChatConfigStore, started chan struct{}) { var once sync.Once - store.getEnabledChatProviders = func(ctx context.Context) ([]database.ChatProvider, error) { + store.getEnabledChatProviders = func(ctx context.Context) ([]database.AIProvider, error) { once.Do(func() { close(started) }) <-ctx.Done() return nil, ctx.Err() diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index 158a3ca1c6e89..b786c9125425d 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -91,11 +91,15 @@ func (p *Server) providerConfigured(ctx context.Context, provider string) (bool, dbProviders, err := p.configCache.EnabledProviders(ctx) if err != nil { - return false, xerrors.Errorf("list enabled chat providers: %w", err) + return false, xerrors.Errorf("list enabled AI providers: %w", err) } for _, prov := range dbProviders { - if chatprovider.NormalizeProvider(prov.Provider) == normalizedProvider && - strings.TrimSpace(prov.APIKey) != "" { + configuredProvider, err := p.aiProviderConfig(ctx, prov) + if err != nil { + return false, err + } + if chatprovider.NormalizeProvider(configuredProvider.Provider) == normalizedProvider && + strings.TrimSpace(configuredProvider.APIKey) != "" { return true, nil } } @@ -178,12 +182,12 @@ func validateModelConfigAndResolveProvider( } func enabledProviderContainsName( - providers []database.ChatProvider, + providers []database.AIProvider, providerName string, ) bool { normalizedProviderName := chatprovider.NormalizeProvider(providerName) for _, provider := range providers { - if chatprovider.NormalizeProvider(provider.Provider) == normalizedProviderName { + if chatprovider.NormalizeProvider(string(provider.Type)) == normalizedProviderName { return true } } diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index 2aecedcfbeeb3..a0cdf0e044a3c 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -2,6 +2,7 @@ package chatd import ( "context" + "database/sql" "encoding/json" "sync" "testing" @@ -183,13 +184,15 @@ func seedInternalChatDeps( UserID: user.ID, OrganizationID: org.ID, }) - dbgen.ChatProvider(t, db, database.ChatProvider{ + provider := dbgen.ChatProvider(t, db, database.ChatProvider{ Provider: "openai", DisplayName: "OpenAI", }) model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ - IsDefault: true, + Provider: "openai", + AiProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + IsDefault: true, }) return user, org, model @@ -401,19 +404,20 @@ func insertInternalChatProvider( centralAPIKeyEnabled bool, allowUserAPIKey bool, allowCentralAPIKeyFallback bool, -) database.ChatProvider { +) database.AIProvider { t.Helper() - providerConfig := dbgen.ChatProvider(t, db, database.ChatProvider{ - Provider: provider, - DisplayName: provider, - CreatedBy: uuid.NullUUID{UUID: userID, Valid: true}, - }, func(p *database.InsertChatProviderParams) { - p.APIKey = apiKey - p.CentralApiKeyEnabled = centralAPIKeyEnabled - p.AllowUserApiKey = allowUserAPIKey - p.AllowCentralApiKeyFallback = allowCentralAPIKeyFallback + providerConfig := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AIProviderType(provider), + Name: "test-" + uuid.NewString(), + DisplayName: sql.NullString{String: provider, Valid: true}, }) + if apiKey != "" { + dbgen.AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: providerConfig.ID, + APIKey: apiKey, + }) + } return providerConfig } diff --git a/coderd/x/chatd/title_override_test.go b/coderd/x/chatd/title_override_test.go index 4fc0b2badcf60..8a5c7e946a4e6 100644 --- a/coderd/x/chatd/title_override_test.go +++ b/coderd/x/chatd/title_override_test.go @@ -246,7 +246,7 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) { db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil) db.EXPECT().UpdateChatTitleByID(gomock.Any(), database.UpdateChatTitleByIDParams{ ID: chat.ID, Title: wantTitle, @@ -337,7 +337,7 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback( db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil) generated := &generatedChatTitle{} server := titleOverrideTestServer(db, logger) @@ -437,7 +437,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T) db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil) server := titleOverrideTestServer(db, logger) model, gotConfig, err := server.resolveManualTitleModel( @@ -463,7 +463,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{Type: database.AiProviderTypeOpenai, Enabled: true}}, nil) server := titleOverrideTestServer(db, logger) model, gotConfig, err := server.resolveManualTitleModel( diff --git a/coderd/x/chatd/turn_summary_internal_test.go b/coderd/x/chatd/turn_summary_internal_test.go index 87abdaf881d6e..3d9d5eb9609af 100644 --- a/coderd/x/chatd/turn_summary_internal_test.go +++ b/coderd/x/chatd/turn_summary_internal_test.go @@ -33,16 +33,15 @@ func TestUpdateLastTurnSummaryRejectsStaleWrites(t *testing.T) { OrganizationID: org.ID, }) - _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, - CentralApiKeyEnabled: true, + provider := dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, }) - require.NoError(t, err) modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + AiProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, Provider: "openai", Model: "test-model", DisplayName: "Test Model", @@ -102,16 +101,15 @@ func TestPendingChatPersistsSummaryButSkipsWebPush(t *testing.T) { OrganizationID: org.ID, }) - _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - Enabled: true, - CentralApiKeyEnabled: true, + provider := dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, }) - require.NoError(t, err) modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + AiProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, Provider: "openai", Model: "test-model", DisplayName: "Test Model", diff --git a/coderd/x/gitsync/worker_test.go b/coderd/x/gitsync/worker_test.go index 833ad5fae9197..dabe0d1e8ef3f 100644 --- a/coderd/x/gitsync/worker_test.go +++ b/coderd/x/gitsync/worker_test.go @@ -944,7 +944,7 @@ func TestWorker(t *testing.T) { user := dbgen.User(t, db, database.User{}) org := dbgen.Organization(t, db, database.Organization{}) - // 3. Set up FK chain: chat_providers -> chat_model_configs -> chats. + // 3. Set up FK chain: ai_providers -> chat_model_configs -> chats. _ = dbgen.ChatProvider(t, db, database.ChatProvider{}) modelCfg := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ diff --git a/enterprise/coderd/x/chatd/chatd_test.go b/enterprise/coderd/x/chatd/chatd_test.go index f94df3be8d933..215bde97481b1 100644 --- a/enterprise/coderd/x/chatd/chatd_test.go +++ b/enterprise/coderd/x/chatd/chatd_test.go @@ -122,14 +122,20 @@ func seedChatDependencies( UserID: user.ID, OrganizationID: org.ID, }) - _ = dbgen.ChatProvider(t, db, database.ChatProvider{ - BaseUrl: safetyNet.URL, - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "test-" + uuid.NewString(), + BaseUrl: safetyNet.URL, + }) + dbgen.AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, }) model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, - IsDefault: true, + Provider: "openai", + AiProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + IsDefault: true, }) return user, org, model } @@ -186,21 +192,25 @@ func setOpenAIProviderBaseURL( ) { t.Helper() - provider, err := db.GetChatProviderByProvider(ctx, "openai") - require.NoError(t, err) - - _, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - ID: provider.ID, - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: baseURL, - CentralApiKeyEnabled: true, - AllowUserApiKey: false, - AllowCentralApiKeyFallback: false, - ApiKeyKeyID: provider.ApiKeyKeyID, - Enabled: provider.Enabled, - }) + providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true}) require.NoError(t, err) + for _, provider := range providers { + if provider.Type != database.AiProviderTypeOpenai { + continue + } + _, err = db.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ + ID: provider.ID, + Name: provider.Name, + DisplayName: provider.DisplayName, + Enabled: provider.Enabled, + BaseUrl: baseURL, + Settings: provider.Settings, + SettingsKeyID: provider.SettingsKeyID, + }) + require.NoError(t, err) + return + } + require.Fail(t, "openai provider not found") } func TestSubscribeRelayReconnectsOnDrop(t *testing.T) { diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index b5975995652a3..b3f130c0ed90f 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -74,29 +74,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe } } - userProviderKeys, err := cryptTx.GetUserChatProviderKeys(ctx, uid) - if err != nil { - return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err) - } - for _, userProviderKey := range userProviderKeys { - if strings.TrimSpace(userProviderKey.APIKey) == "" { - continue - } - if userProviderKey.ApiKeyKeyID.Valid && userProviderKey.ApiKeyKeyID.String == ciphers[0].HexDigest() { - log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - continue - } - if _, err := cryptTx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{ - UserID: userProviderKey.UserID, - ChatProviderID: userProviderKey.ChatProviderID, - APIKey: userProviderKey.APIKey, - ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required - }); err != nil { - return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err) - } - log.Debug(ctx, "encrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - } - userSecrets, err := cryptTx.ListUserSecretsWithValues(ctx, uid) if err != nil { return xerrors.Errorf("get user secrets for user %s: %w", uid, err) @@ -134,35 +111,6 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe log.Debug(ctx, "encrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } - providers, err := cryptDB.GetChatProviders(ctx) - if err != nil { - return xerrors.Errorf("get chat providers: %w", err) - } - log.Info(ctx, "encrypting chat provider keys", slog.F("provider_count", len(providers))) - for idx, provider := range providers { - if strings.TrimSpace(provider.APIKey) == "" { - continue - } - if provider.ApiKeyKeyID.Valid && provider.ApiKeyKeyID.String == ciphers[0].HexDigest() { - log.Debug(ctx, "skipping chat provider", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - continue - } - if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: provider.BaseUrl, - ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required - Enabled: provider.Enabled, - CentralApiKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserApiKey: provider.AllowUserApiKey, - AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback, - ID: provider.ID, - }); err != nil { - return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err) - } - log.Debug(ctx, "encrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - } - aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true}) if err != nil { return xerrors.Errorf("get ai providers: %w", err) @@ -313,26 +261,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph } } - userProviderKeys, err := tx.GetUserChatProviderKeys(ctx, uid) - if err != nil { - return xerrors.Errorf("get user chat provider keys for user %s: %w", uid, err) - } - for _, userProviderKey := range userProviderKeys { - if !userProviderKey.ApiKeyKeyID.Valid { - log.Debug(ctx, "skipping user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1)) - continue - } - if _, err := tx.UpdateUserChatProviderKey(ctx, database.UpdateUserChatProviderKeyParams{ - UserID: userProviderKey.UserID, - ChatProviderID: userProviderKey.ChatProviderID, - APIKey: userProviderKey.APIKey, - ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id - }); err != nil { - return xerrors.Errorf("update user chat provider key user_id=%s chat_provider_id=%s: %w", userProviderKey.UserID, userProviderKey.ChatProviderID, err) - } - log.Debug(ctx, "decrypted user chat provider key", slog.F("user_id", uid), slog.F("chat_provider_id", userProviderKey.ChatProviderID), slog.F("current", idx+1)) - } - userSecrets, err := tx.ListUserSecretsWithValues(ctx, uid) if err != nil { return xerrors.Errorf("get user secrets for user %s: %w", uid, err) @@ -370,31 +298,6 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph log.Debug(ctx, "decrypted user tokens", slog.F("user_id", uid), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } - providers, err := cryptDB.GetChatProviders(ctx) - if err != nil { - return xerrors.Errorf("get chat providers: %w", err) - } - log.Info(ctx, "decrypting chat provider keys", slog.F("provider_count", len(providers))) - for idx, provider := range providers { - if !provider.ApiKeyKeyID.Valid { - continue - } - if _, err := cryptDB.UpdateChatProvider(ctx, database.UpdateChatProviderParams{ - DisplayName: provider.DisplayName, - APIKey: provider.APIKey, - BaseUrl: provider.BaseUrl, - ApiKeyKeyID: sql.NullString{}, // we explicitly want to clear the key id - Enabled: provider.Enabled, - CentralApiKeyEnabled: provider.CentralApiKeyEnabled, - AllowUserApiKey: provider.AllowUserApiKey, - AllowCentralApiKeyFallback: provider.AllowCentralApiKeyFallback, - ID: provider.ID, - }); err != nil { - return xerrors.Errorf("update chat provider id=%s provider=%s: %w", provider.ID, provider.Provider, err) - } - log.Debug(ctx, "decrypted chat provider key", slog.F("provider", provider.Provider), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) - } - aiProviders, err := cryptDB.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDeleted: true, IncludeDisabled: true}) if err != nil { return xerrors.Errorf("get ai providers: %w", err) @@ -475,16 +378,12 @@ DELETE FROM user_links DELETE FROM external_auth_links WHERE oauth_access_token_key_id IS NOT NULL OR oauth_refresh_token_key_id IS NOT NULL; -DELETE FROM user_chat_provider_keys +DELETE FROM user_ai_provider_keys WHERE api_key_key_id IS NOT NULL; DELETE FROM user_ai_provider_keys WHERE api_key_key_id IS NOT NULL; DELETE FROM user_secrets WHERE value_key_id IS NOT NULL; -UPDATE chat_providers - SET api_key = '', - api_key_key_id = NULL - WHERE api_key_key_id IS NOT NULL; UPDATE ai_providers SET settings = NULL, settings_key_id = NULL @@ -502,9 +401,9 @@ func Delete(ctx context.Context, log slog.Logger, sqlDB *sql.DB) error { store := database.New(sqlDB) _, err := sqlDB.ExecContext(ctx, sqlDeleteEncryptedUserTokens) if err != nil { - return xerrors.Errorf("delete encrypted tokens and chat provider keys: %w", err) + return xerrors.Errorf("delete encrypted tokens and AI provider keys: %w", err) } - log.Info(ctx, "deleted encrypted user tokens and chat provider API keys") + log.Info(ctx, "deleted encrypted user tokens and AI provider API keys") log.Info(ctx, "revoking all active keys") keys, err := store.GetDBCryptKeys(ctx) diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index ed4e90770be4a..46ee281c6b3e1 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -576,92 +576,6 @@ func (db *dbCrypt) UpdateEncryptedAIProviderKey(ctx context.Context, params data return key, nil } -func (db *dbCrypt) GetChatProviderByID(ctx context.Context, id uuid.UUID) (database.ChatProvider, error) { - provider, err := db.Store.GetChatProviderByID(ctx, id) - if err != nil { - return database.ChatProvider{}, err - } - if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - return provider, nil -} - -func (db *dbCrypt) GetChatProviderByProvider(ctx context.Context, providerName string) (database.ChatProvider, error) { - provider, err := db.Store.GetChatProviderByProvider(ctx, providerName) - if err != nil { - return database.ChatProvider{}, err - } - if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - return provider, nil -} - -func (db *dbCrypt) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - providers, err := db.Store.GetChatProviders(ctx) - if err != nil { - return nil, err - } - - for i := range providers { - if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil { - return nil, err - } - } - - return providers, nil -} - -func (db *dbCrypt) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) { - providers, err := db.Store.GetEnabledChatProviders(ctx) - if err != nil { - return nil, err - } - - for i := range providers { - if err := db.decryptField(&providers[i].APIKey, providers[i].ApiKeyKeyID); err != nil { - return nil, err - } - } - - return providers, nil -} - -func (db *dbCrypt) InsertChatProvider(ctx context.Context, params database.InsertChatProviderParams) (database.ChatProvider, error) { - if strings.TrimSpace(params.APIKey) == "" { - params.ApiKeyKeyID = sql.NullString{} - } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - - provider, err := db.Store.InsertChatProvider(ctx, params) - if err != nil { - return database.ChatProvider{}, err - } - if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - return provider, nil -} - -func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.UpdateChatProviderParams) (database.ChatProvider, error) { - if strings.TrimSpace(params.APIKey) == "" { - params.ApiKeyKeyID = sql.NullString{} - } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - - provider, err := db.Store.UpdateChatProvider(ctx, params) - if err != nil { - return database.ChatProvider{}, err - } - if err := db.decryptField(&provider.APIKey, provider.ApiKeyKeyID); err != nil { - return database.ChatProvider{}, err - } - return provider, nil -} - func (db *dbCrypt) decryptUserAIProviderKey(key *database.UserAiProviderKey) error { return db.decryptField(&key.APIKey, key.ApiKeyKeyID) } @@ -754,57 +668,6 @@ func (db *dbCrypt) UpdateEncryptedUserAIProviderKey(ctx context.Context, params return key, nil } -func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error { - return db.decryptField(&key.APIKey, key.ApiKeyKeyID) -} - -func (db *dbCrypt) GetUserChatProviderKeys(ctx context.Context, userID uuid.UUID) ([]database.UserChatProviderKey, error) { - keys, err := db.Store.GetUserChatProviderKeys(ctx, userID) - if err != nil { - return nil, err - } - for i := range keys { - if err := db.decryptUserChatProviderKey(&keys[i]); err != nil { - return nil, err - } - } - return keys, nil -} - -func (db *dbCrypt) UpsertUserChatProviderKey(ctx context.Context, params database.UpsertUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - if strings.TrimSpace(params.APIKey) == "" { - params.ApiKeyKeyID = sql.NullString{} - } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { - return database.UserChatProviderKey{}, err - } - - key, err := db.Store.UpsertUserChatProviderKey(ctx, params) - if err != nil { - return database.UserChatProviderKey{}, err - } - if err := db.decryptUserChatProviderKey(&key); err != nil { - return database.UserChatProviderKey{}, err - } - return key, nil -} - -func (db *dbCrypt) UpdateUserChatProviderKey(ctx context.Context, params database.UpdateUserChatProviderKeyParams) (database.UserChatProviderKey, error) { - if strings.TrimSpace(params.APIKey) == "" { - params.ApiKeyKeyID = sql.NullString{} - } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { - return database.UserChatProviderKey{}, err - } - - key, err := db.Store.UpdateUserChatProviderKey(ctx, params) - if err != nil { - return database.UserChatProviderKey{}, err - } - if err := db.decryptUserChatProviderKey(&key); err != nil { - return database.UserChatProviderKey{}, err - } - return key, nil -} - // decryptMCPServerConfig decrypts all encrypted fields on a // single MCPServerConfig in place. func (db *dbCrypt) decryptMCPServerConfig(cfg *database.MCPServerConfig) error { diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 718992ddbd27e..2f08ce1e44bc4 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1562,7 +1562,7 @@ func TestMCPServerUserTokens(t *testing.T) { }) } -func TestUserChatProviderKeys(t *testing.T) { +func TestUserAIProviderKeys(t *testing.T) { t.Parallel() ctx := context.Background() @@ -1577,19 +1577,19 @@ func TestUserChatProviderKeys(t *testing.T) { t *testing.T, crypt *dbCrypt, ciphers []Cipher, - ) (database.ChatProvider, database.UserChatProviderKey) { + ) (database.AIProvider, database.UserAiProviderKey) { t.Helper() user := dbgen.User(t, crypt, database.User{}) - provider := dbgen.ChatProvider(t, crypt, database.ChatProvider{ - AllowUserApiKey: true, - }, func(params *database.InsertChatProviderParams) { - params.APIKey = "" + provider := dbgen.AIProvider(t, crypt, database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "test-openai-" + uuid.NewString(), }) - key, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ - UserID: user.ID, - ChatProviderID: provider.ID, - APIKey: initialAPIKey, + key, err := crypt.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: user.ID, + AiProviderID: provider.ID, + APIKey: initialAPIKey, }) require.NoError(t, err) require.Equal(t, initialAPIKey, key.APIKey) @@ -1598,18 +1598,18 @@ func TestUserChatProviderKeys(t *testing.T) { } getUserChatProviderKey := func(t *testing.T, store interface { - GetUserChatProviderKeys(context.Context, uuid.UUID) ([]database.UserChatProviderKey, error) + GetUserAIProviderKeysByUserID(context.Context, uuid.UUID) ([]database.UserAiProviderKey, error) }, userID uuid.UUID, providerID uuid.UUID, - ) database.UserChatProviderKey { + ) database.UserAiProviderKey { t.Helper() - keys, err := store.GetUserChatProviderKeys(ctx, userID) + keys, err := store.GetUserAIProviderKeysByUserID(ctx, userID) require.NoError(t, err) require.Len(t, keys, 1) - require.Equal(t, providerID, keys[0].ChatProviderID) + require.Equal(t, providerID, keys[0].AiProviderID) return keys[0] } - t.Run("UpsertUserChatProviderKeyCreatesValue", func(t *testing.T) { + t.Run("UpsertUserAIProviderKeyCreatesValue", func(t *testing.T) { t.Parallel() db, crypt, ciphers := setup(t) provider, key := insertProviderAndKey(t, crypt, ciphers) @@ -1624,12 +1624,12 @@ func TestUserChatProviderKeys(t *testing.T) { requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey) }) - t.Run("GetUserChatProviderKeys", func(t *testing.T) { + t.Run("GetUserAIProviderKeysByUserID", func(t *testing.T) { t.Parallel() _, crypt, ciphers := setup(t) _, key := insertProviderAndKey(t, crypt, ciphers) - keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID) + keys, err := crypt.GetUserAIProviderKeysByUserID(ctx, key.UserID) require.NoError(t, err) require.Len(t, keys, 1) require.Equal(t, key.ID, keys[0].ID) @@ -1637,15 +1637,16 @@ func TestUserChatProviderKeys(t *testing.T) { require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String) }) - t.Run("UpsertUserChatProviderKeyUpdatesValue", func(t *testing.T) { + t.Run("UpsertUserAIProviderKeyUpdatesValue", func(t *testing.T) { t.Parallel() db, crypt, ciphers := setup(t) provider, key := insertProviderAndKey(t, crypt, ciphers) - updated, err := crypt.UpsertUserChatProviderKey(ctx, database.UpsertUserChatProviderKeyParams{ - UserID: key.UserID, - ChatProviderID: provider.ID, - APIKey: updatedAPIKey, + updated, err := crypt.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: key.UserID, + AiProviderID: provider.ID, + APIKey: updatedAPIKey, }) require.NoError(t, err) require.Equal(t, key.ID, updated.ID) @@ -1658,7 +1659,7 @@ func TestUserChatProviderKeys(t *testing.T) { require.Equal(t, updatedAPIKey, got.APIKey) require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String) - keys, err := crypt.GetUserChatProviderKeys(ctx, key.UserID) + keys, err := crypt.GetUserAIProviderKeysByUserID(ctx, key.UserID) require.NoError(t, err) require.Len(t, keys, 1) require.Equal(t, updatedAPIKey, keys[0].APIKey) From a845bdc10f1e59693d8cb61e776f0518dd4bf355 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 18:09:28 +0000 Subject: [PATCH 2/8] test(enterprise/coderd): use AI provider chat setup --- enterprise/coderd/exp_chats_test.go | 148 ++++++++++++---------------- 1 file changed, 64 insertions(+), 84 deletions(-) diff --git a/enterprise/coderd/exp_chats_test.go b/enterprise/coderd/exp_chats_test.go index ea49812ee04b7..cb64224171eb4 100644 --- a/enterprise/coderd/exp_chats_test.go +++ b/enterprise/coderd/exp_chats_test.go @@ -24,6 +24,30 @@ import ( "github.com/coder/websocket" ) +func createAIProviderForEnterpriseChatTest( + ctx context.Context, + t testing.TB, + client *codersdk.ExperimentalClient, + baseURL string, +) codersdk.AIProvider { + t.Helper() + + enabled := true + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-openai-" + uuid.NewString(), + DisplayName: "OpenAI", + BaseURL: baseURL, + Enabled: &enabled, + }) + require.NoError(t, err) + _, err = client.CreateAIProviderKey(ctx, provider.ID, codersdk.CreateAIProviderKeyRequest{ + APIKey: "test", + }) + require.NoError(t, err) + return provider +} + func TestChatStreamRelay(t *testing.T) { t.Parallel() @@ -74,18 +98,12 @@ func TestChatStreamRelay(t *testing.T) { return chattest.OpenAINonStreamingResponse("ok") }) - //nolint:gocritic // Test uses owner client to configure chat providers. - provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test", - BaseURL: openai, - }) - require.NoError(t, err) - require.Equal(t, codersdk.ChatProviderConfigSourceDatabase, provider.Source) + expClient := codersdk.NewExperimentalClient(firstClient) + provider := createAIProviderForEnterpriseChatTest(ctx, t, expClient, openai) - model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, + model, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: string(provider.Type), + AIProviderID: &provider.ID, Model: "gpt-4", DisplayName: "GPT-4", ContextLimit: &[]int64{1000}[0], @@ -264,17 +282,12 @@ func TestChatStreamRelay(t *testing.T) { return chattest.OpenAINonStreamingResponse("ok") }) - //nolint:gocritic // Test uses owner client to configure chat providers. - provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test", - BaseURL: openai, - }) - require.NoError(t, err) + expClient := codersdk.NewExperimentalClient(firstClient) + provider := createAIProviderForEnterpriseChatTest(ctx, t, expClient, openai) - model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, + model, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: string(provider.Type), + AIProviderID: &provider.ID, Model: "gpt-4", DisplayName: "GPT-4", ContextLimit: &[]int64{1000}[0], @@ -435,17 +448,12 @@ func TestChatStreamRelay(t *testing.T) { return chattest.OpenAINonStreamingResponse("ok") }) - //nolint:gocritic // Test uses owner client to configure providers. - provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test", - BaseURL: openai, - }) - require.NoError(t, err) + expClient := codersdk.NewExperimentalClient(firstClient) + provider := createAIProviderForEnterpriseChatTest(ctx, t, expClient, openai) - model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, + model, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: string(provider.Type), + AIProviderID: &provider.ID, Model: "gpt-4", DisplayName: "GPT-4", ContextLimit: &[]int64{1000}[0], @@ -607,17 +615,12 @@ func TestChatStreamRelay(t *testing.T) { return chattest.OpenAINonStreamingResponse("ok") }) - //nolint:gocritic // Test uses owner client to configure providers. - provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test", - BaseURL: openai, - }) - require.NoError(t, err) + expClient := codersdk.NewExperimentalClient(firstClient) + provider := createAIProviderForEnterpriseChatTest(ctx, t, expClient, openai) - model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, + model, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: string(provider.Type), + AIProviderID: &provider.ID, Model: "gpt-4", DisplayName: "GPT-4", ContextLimit: &[]int64{1000}[0], @@ -754,17 +757,12 @@ func TestChatStreamRelay(t *testing.T) { return chattest.OpenAINonStreamingResponse("ok") }) - //nolint:gocritic // Test uses owner client to configure chat providers. - provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test", - BaseURL: openai, - }) - require.NoError(t, err) + expClient := codersdk.NewExperimentalClient(firstClient) + provider := createAIProviderForEnterpriseChatTest(ctx, t, expClient, openai) - model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, + model, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: string(provider.Type), + AIProviderID: &provider.ID, Model: "gpt-4", DisplayName: "GPT-4", ContextLimit: &[]int64{1000}[0], @@ -954,17 +952,7 @@ func TestChatModelConfigDefault(t *testing.T) { client, _ := coderdenttest.New(t, nil) expClient := codersdk.NewExperimentalClient(client) - //nolint:gocritic // Test uses owner client to configure chat providers. - provider, err := expClient.CreateChatProvider( - ctx, - codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test", - BaseURL: "https://example.com", - }, - ) - require.NoError(t, err) + provider := createAIProviderForEnterpriseChatTest(ctx, t, expClient, "https://example.com") contextLimit := int64(1000) compressionThreshold := int32(70) @@ -974,7 +962,8 @@ func TestChatModelConfigDefault(t *testing.T) { firstModel, err := expClient.CreateChatModelConfig( ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, + Provider: string(provider.Type), + AIProviderID: &provider.ID, Model: "gpt-5-a", DisplayName: "GPT 5 A", IsDefault: &trueValue, @@ -988,7 +977,8 @@ func TestChatModelConfigDefault(t *testing.T) { secondModel, err := expClient.CreateChatModelConfig( ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, + Provider: string(provider.Type), + AIProviderID: &provider.ID, Model: "gpt-5-b", DisplayName: "GPT 5 B", IsDefault: &trueValue, @@ -1115,16 +1105,11 @@ func TestCreateChatNonDefaultOrg(t *testing.T) { }) expClient := codersdk.NewExperimentalClient(client) - // Set up a chat provider and model config. - provider, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - BaseURL: "https://example.com", - }) - require.NoError(t, err) - _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, + // Set up an AI provider and model config. + provider := createAIProviderForEnterpriseChatTest(ctx, t, expClient, "https://example.com") + _, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: string(provider.Type), + AIProviderID: &provider.ID, Model: "gpt-4o-mini", DisplayName: "Test Model", IsDefault: ptr.Ref(true), @@ -1191,16 +1176,11 @@ func TestListChats_OrgAdminOnlySeesOwnChats(t *testing.T) { }) expClient := codersdk.NewExperimentalClient(client) - // Set up a chat provider and model config. - provider, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{ - Provider: "openai", - DisplayName: "OpenAI", - APIKey: "test-key", - BaseURL: "https://example.com", - }) - require.NoError(t, err) - _, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ - Provider: provider.Provider, + // Set up an AI provider and model config. + provider := createAIProviderForEnterpriseChatTest(ctx, t, expClient, "https://example.com") + _, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ + Provider: string(provider.Type), + AIProviderID: &provider.ID, Model: "gpt-4o-mini", DisplayName: "Test Model", IsDefault: ptr.Ref(true), From 389f3f9904cc010a89eb366fcf6b3a86b5ef1c9c Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 21:42:34 +0000 Subject: [PATCH 3/8] fix(coderd): batch AI provider key lookups --- coderd/database/dbauthz/dbauthz.go | 7 ++ coderd/database/dbauthz/dbauthz_test.go | 9 ++ coderd/database/dbmetrics/querymetrics.go | 8 ++ coderd/database/dbmock/dbmock.go | 15 +++ coderd/database/querier.go | 3 + coderd/database/queries.sql.go | 45 ++++++++ coderd/database/queries/ai_provider_keys.sql | 14 +++ coderd/exp_chats.go | 68 ++++++++++-- coderd/x/chatd/chatd.go | 46 +++++--- coderd/x/chatd/chatd_internal_test.go | 10 +- coderd/x/chatd/configcache_test.go | 42 ++++---- enterprise/dbcrypt/cliutil.go | 2 - enterprise/dbcrypt/dbcrypt_internal_test.go | 108 ------------------- 13 files changed, 220 insertions(+), 157 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 1bf318b2143b9..9656d1092f8b3 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2554,6 +2554,13 @@ func (q *querier) GetAIProviderKeysByProviderID(ctx context.Context, providerID return q.db.GetAIProviderKeysByProviderID(ctx, providerID) } +func (q *querier) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil { + return nil, err + } + return q.db.GetAIProviderKeysByProviderIDs(ctx, providerIDs) +} + func (q *querier) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil { return nil, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 6b8869e1387da..7c7d756c7c233 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -6258,6 +6258,15 @@ func (s *MethodTestSuite) TestAIBridge() { dbm.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), provider.ID).Return([]database.AIProviderKey{keyA, keyB}, nil).AnyTimes() check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns([]database.AIProviderKey{keyA, keyB}) })) + s.Run("GetAIProviderKeysByProviderIDs", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + providerA := testutil.Fake(s.T(), faker, database.AIProvider{}) + providerB := testutil.Fake(s.T(), faker, database.AIProvider{}) + providerIDs := []uuid.UUID{providerA.ID, providerB.ID} + keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: providerA.ID}) + keyB := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: providerB.ID}) + dbm.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), providerIDs).Return([]database.AIProviderKey{keyA, keyB}, nil).AnyTimes() + check.Args(providerIDs).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns([]database.AIProviderKey{keyA, keyB}) + })) s.Run("GetAIProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{}) keyB := testutil.Fake(s.T(), faker, database.AIProviderKey{}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index ef11552f42484..733ade0b282ff 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1049,6 +1049,14 @@ func (m queryMetricsStore) GetAIProviderKeysByProviderID(ctx context.Context, pr return r0, r1 } +func (m queryMetricsStore) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]database.AIProviderKey, error) { + start := time.Now() + r0, r1 := m.s.GetAIProviderKeysByProviderIDs(ctx, providerIds) + m.queryLatencies.WithLabelValues("GetAIProviderKeysByProviderIDs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIProviderKeysByProviderIDs").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) { start := time.Now() r0, r1 := m.s.GetAIProviders(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 07fae8e042d3c..007643f2771de 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1817,6 +1817,21 @@ func (mr *MockStoreMockRecorder) GetAIProviderKeysByProviderID(ctx, providerID a return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeysByProviderID", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeysByProviderID), ctx, providerID) } +// GetAIProviderKeysByProviderIDs mocks base method. +func (m *MockStore) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]database.AIProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAIProviderKeysByProviderIDs", ctx, providerIds) + ret0, _ := ret[0].([]database.AIProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAIProviderKeysByProviderIDs indicates an expected call of GetAIProviderKeysByProviderIDs. +func (mr *MockStoreMockRecorder) GetAIProviderKeysByProviderIDs(ctx, providerIds any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeysByProviderIDs", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeysByProviderIDs), ctx, providerIds) +} + // GetAIProviders mocks base method. func (m *MockStore) GetAIProviders(ctx context.Context, arg database.GetAIProvidersParams) ([]database.AIProvider, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index a06e37ba62128..4c2de15111d7c 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -261,6 +261,9 @@ type sqlcQuerier interface { // key per provider; multiple keys are stored to support future // failover and rotation flows. GetAIProviderKeysByProviderID(ctx context.Context, providerID uuid.UUID) ([]AIProviderKey, error) + // Returns all keys for the requested providers, ordered by provider then created_at ASC + // so callers can select the oldest non-empty key per provider without issuing N queries. + GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]AIProviderKey, error) // Returns AI provider rows. Soft-deleted and disabled rows are excluded // unless include_deleted or include_disabled is set. GetAIProviders(ctx context.Context, arg GetAIProvidersParams) ([]AIProvider, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index fe2f09b94aa43..b66bf76ef1305 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -236,6 +236,51 @@ func (q *sqlQuerier) GetAIProviderKeysByProviderID(ctx context.Context, provider return items, nil } +const getAIProviderKeysByProviderIDs = `-- name: GetAIProviderKeysByProviderIDs :many +SELECT + id, provider_id, api_key, api_key_key_id, created_at, updated_at +FROM + ai_provider_keys +WHERE + provider_id = ANY($1::uuid[]) +ORDER BY + provider_id ASC, + created_at ASC, + id ASC +` + +// Returns all keys for the requested providers, ordered by provider then created_at ASC +// so callers can select the oldest non-empty key per provider without issuing N queries. +func (q *sqlQuerier) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIds []uuid.UUID) ([]AIProviderKey, error) { + rows, err := q.db.QueryContext(ctx, getAIProviderKeysByProviderIDs, pq.Array(providerIds)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AIProviderKey + for rows.Next() { + var i AIProviderKey + if err := rows.Scan( + &i.ID, + &i.ProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertAIProviderKey = `-- name: InsertAIProviderKey :one INSERT INTO ai_provider_keys ( id, diff --git a/coderd/database/queries/ai_provider_keys.sql b/coderd/database/queries/ai_provider_keys.sql index 018c434ef6eaa..f5ca8328b00a1 100644 --- a/coderd/database/queries/ai_provider_keys.sql +++ b/coderd/database/queries/ai_provider_keys.sql @@ -21,6 +21,20 @@ ORDER BY created_at ASC, id ASC; +-- name: GetAIProviderKeysByProviderIDs :many +-- Returns all keys for the requested providers, ordered by provider then created_at ASC +-- so callers can select the oldest non-empty key per provider without issuing N queries. +SELECT + * +FROM + ai_provider_keys +WHERE + provider_id = ANY(@provider_ids::uuid[]) +ORDER BY + provider_id ASC, + created_at ASC, + id ASC; + -- name: GetAIProviderKeys :many -- Returns every AI provider key row, including those belonging to a -- soft-deleted provider, so the dbcrypt key rotation utility can diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 4be19daeae68f..740370163ff70 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -11,6 +11,7 @@ import ( "mime" "net/http" "net/http/httptest" + "net/url" "slices" "strconv" "strings" @@ -783,18 +784,17 @@ func (api *API) getUserChatProviderAvailability( return userChatModelAvailability{}, err } + configuredProviders, err := api.configuredProvidersFromAIProviders(systemCtx, enabledProviders) + if err != nil { + return userChatModelAvailability{}, err + } availability := userChatModelAvailability{ - configuredProviders: make([]chatprovider.ConfiguredProvider, 0, len(enabledProviders)), + configuredProviders: configuredProviders, configuredModels: make([]chatprovider.ConfiguredModel, 0, len(enabledModels)), enabledModels: enabledModels, enabledProviderNames: make(map[string]struct{}, len(enabledProviders)), } - for _, provider := range enabledProviders { - configuredProvider, err := api.configuredProviderFromAIProvider(systemCtx, provider) - if err != nil { - return userChatModelAvailability{}, err - } - availability.configuredProviders = append(availability.configuredProviders, configuredProvider) + for _, configuredProvider := range configuredProviders { normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider) if normalizedProvider != "" { availability.enabledProviderNames[normalizedProvider] = struct{}{} @@ -6908,11 +6908,31 @@ func (api *API) deleteUserAIProviderKey(rw http.ResponseWriter, r *http.Request) httpapi.Write(ctx, rw, http.StatusNoContent, nil) } -func (api *API) configuredProviderFromAIProvider(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) { - keys, err := api.Database.GetAIProviderKeysByProviderID(ctx, provider.ID) +func (api *API) configuredProvidersFromAIProviders(ctx context.Context, providers []database.AIProvider) ([]chatprovider.ConfiguredProvider, error) { + providerIDs := make([]uuid.UUID, 0, len(providers)) + for _, provider := range providers { + providerIDs = append(providerIDs, provider.ID) + } + keys, err := api.Database.GetAIProviderKeysByProviderIDs(ctx, providerIDs) if err != nil { - return chatprovider.ConfiguredProvider{}, err + return nil, xerrors.Errorf("get AI provider keys: %w", err) } + keysByProviderID := make(map[uuid.UUID][]database.AIProviderKey, len(providers)) + for _, key := range keys { + keysByProviderID[key.ProviderID] = append(keysByProviderID[key.ProviderID], key) + } + configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers)) + for _, provider := range providers { + configuredProvider, err := api.configuredProviderFromAIProviderKeys(provider, keysByProviderID[provider.ID]) + if err != nil { + return nil, err + } + configuredProviders = append(configuredProviders, configuredProvider) + } + return configuredProviders, nil +} + +func (api *API) configuredProviderFromAIProviderKeys(provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) { apiKey := "" for _, key := range keys { if trimmed := strings.TrimSpace(key.APIKey); trimmed != "" { @@ -7690,6 +7710,34 @@ func isZeroChatModelProviderOptions(options *codersdk.ChatModelProviderOptions) options.Vercel == nil } +func normalizeChatProviderBaseURL(raw string) (string, error) { + trimmed := strings.TrimSpace(raw) + if trimmed == "" { + return "", nil + } + + parsed, err := url.Parse(trimmed) + if err != nil { + return "", err + } + if parsed.Scheme == "" || parsed.Host == "" { + return "", xerrors.New("Base URL must be an absolute URL with scheme and host.") + } + if parsed.Scheme != "http" && parsed.Scheme != "https" { + return "", xerrors.New("Base URL scheme must be http or https.") + } + return parsed.String(), nil +} + +const maxChatProviderAPIKeySize = 10240 // 10 KB + +func validateChatProviderAPIKeySize(apiKey string) error { + if len(apiKey) > maxChatProviderAPIKeySize { + return xerrors.Errorf("API key exceeds maximum size of %d bytes", maxChatProviderAPIKeySize) + } + return nil +} + var ( errChatModelConfigNotFound = xerrors.New("chat model config not found") errChatProviderNotConfigured = xerrors.New("chat provider is not configured") diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index fedf3aed895a9..569d3733056f6 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -8075,13 +8075,17 @@ func (p *Server) resolveChatModel( } func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) { - if !provider.Enabled { - return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID) - } keys, err := p.db.GetAIProviderKeysByProviderID(ctx, provider.ID) if err != nil { return chatprovider.ConfiguredProvider{}, xerrors.Errorf("get AI provider keys: %w", err) } + return p.aiProviderConfigFromKeys(provider, keys) +} + +func (p *Server) aiProviderConfigFromKeys(provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) { + if !provider.Enabled { + return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID) + } apiKey := "" for _, key := range keys { if trimmed := strings.TrimSpace(key.APIKey); trimmed != "" { @@ -8100,6 +8104,30 @@ func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvi }, nil } +func (p *Server) aiProviderConfigs(ctx context.Context, providers []database.AIProvider) ([]chatprovider.ConfiguredProvider, error) { + providerIDs := make([]uuid.UUID, 0, len(providers)) + for _, provider := range providers { + providerIDs = append(providerIDs, provider.ID) + } + keys, err := p.db.GetAIProviderKeysByProviderIDs(ctx, providerIDs) + if err != nil { + return nil, xerrors.Errorf("get AI provider keys: %w", err) + } + keysByProviderID := make(map[uuid.UUID][]database.AIProviderKey, len(providers)) + for _, key := range keys { + keysByProviderID[key.ProviderID] = append(keysByProviderID[key.ProviderID], key) + } + configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers)) + for _, provider := range providers { + configuredProvider, err := p.aiProviderConfigFromKeys(provider, keysByProviderID[provider.ID]) + if err != nil { + return nil, err + } + configuredProviders = append(configuredProviders, configuredProvider) + } + return configuredProviders, nil +} + func (p *Server) resolveUserProviderAPIKeysForProvider( ctx context.Context, ownerID uuid.UUID, @@ -8159,15 +8187,9 @@ func (p *Server) resolveUserProviderAPIKeys( err, ) } - configuredProviders := make( - []chatprovider.ConfiguredProvider, 0, len(providers), - ) - for _, provider := range providers { - configuredProvider, err := p.aiProviderConfig(ctx, provider) - if err != nil { - return chatprovider.ProviderAPIKeys{}, err - } - configuredProviders = append(configuredProviders, configuredProvider) + configuredProviders, err := p.aiProviderConfigs(ctx, providers) + if err != nil { + return chatprovider.ProviderAPIKeys{}, err } userKeys := []chatprovider.UserProviderKey{} diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 3d5643239792f..97d70a1b50b3a 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -1107,12 +1107,13 @@ func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) { }, } + providerID := uuid.New() db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ - ID: uuid.New(), + ID: providerID, Type: database.AiProviderTypeAnthropic, Enabled: true, }}, nil) - db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return(nil, nil) keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) require.NoError(t, err) @@ -1149,12 +1150,13 @@ func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKe }, } + providerID := uuid.New() db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ - ID: uuid.New(), + ID: providerID, Type: database.AiProviderTypeOpenai, Enabled: true, }}, nil) - db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return(nil, nil) keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) require.NoError(t, err) diff --git a/coderd/x/chatd/configcache_test.go b/coderd/x/chatd/configcache_test.go index 0079892126016..ee0855f2f81e3 100644 --- a/coderd/x/chatd/configcache_test.go +++ b/coderd/x/chatd/configcache_test.go @@ -22,7 +22,7 @@ import ( type stubChatConfigStore struct { database.Store - getEnabledChatProviders func(context.Context) ([]database.AIProvider, error) + getAIProviders func(context.Context) ([]database.AIProvider, error) getChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error) getDefaultChatModelConfig func(context.Context) (database.ChatModelConfig, error) getUserChatCustomPrompt func(context.Context, uuid.UUID) (string, error) @@ -37,10 +37,10 @@ type stubChatConfigStore struct { func (s *stubChatConfigStore) GetAIProviders(ctx context.Context, _ database.GetAIProvidersParams) ([]database.AIProvider, error) { s.enabledProvidersCalls.Add(1) - if s.getEnabledChatProviders == nil { + if s.getAIProviders == nil { panic("unexpected GetAIProviders call") } - return s.getEnabledChatProviders(ctx) + return s.getAIProviders(ctx) } func (s *stubChatConfigStore) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (database.ChatModelConfig, error) { @@ -80,9 +80,9 @@ func TestConfigCache_EnabledProviders_CacheHit(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) clock := quartz.NewMock(t) - providers := []database.AIProvider{testChatProvider("provider-a")} + providers := []database.AIProvider{testAIProvider("provider-a")} store := &stubChatConfigStore{ - getEnabledChatProviders: func(context.Context) ([]database.AIProvider, error) { + getAIProviders: func(context.Context) ([]database.AIProvider, error) { return providers, nil }, } @@ -104,9 +104,9 @@ func TestConfigCache_EnabledProviders_TTLExpiry(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) clock := quartz.NewMock(t) store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { call := store.enabledProvidersCalls.Load() - return []database.AIProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil + return []database.AIProvider{testAIProvider(fmt.Sprintf("provider-%d", call))}, nil } cache := newChatConfigCache(ctx, store, clock) @@ -126,9 +126,9 @@ func TestConfigCache_EnabledProviders_Invalidation(t *testing.T) { ctx := testutil.Context(t, testutil.WaitShort) clock := quartz.NewMock(t) store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { call := store.enabledProvidersCalls.Load() - return []database.AIProvider{testChatProvider(fmt.Sprintf("provider-%d", call))}, nil + return []database.AIProvider{testAIProvider(fmt.Sprintf("provider-%d", call))}, nil } cache := newChatConfigCache(ctx, store, clock) @@ -398,12 +398,12 @@ func TestConfigCache_Singleflight(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) clock := quartz.NewMock(t) - providers := []database.AIProvider{testChatProvider("provider-a")} + providers := []database.AIProvider{testAIProvider("provider-a")} fetchStarted := make(chan struct{}) releaseFetch := make(chan struct{}) var startedOnce sync.Once store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { startedOnce.Do(func() { close(fetchStarted) }) <-releaseFetch return providers, nil @@ -441,13 +441,13 @@ func TestConfigCache_GenerationPreventsStaleWrite(t *testing.T) { ctx := testutil.Context(t, testutil.WaitMedium) clock := quartz.NewMock(t) - firstProviders := []database.AIProvider{testChatProvider("provider-a")} - secondProviders := []database.AIProvider{testChatProvider("provider-b")} + firstProviders := []database.AIProvider{testAIProvider("provider-a")} + secondProviders := []database.AIProvider{testAIProvider("provider-b")} fetchStarted := make(chan struct{}) releaseFetch := make(chan struct{}) var startedOnce sync.Once store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { call := store.enabledProvidersCalls.Load() if call == 1 { startedOnce.Do(func() { close(fetchStarted) }) @@ -494,14 +494,14 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightProviders(t *testing ctx := testutil.Context(t, testutil.WaitMedium) clock := quartz.NewMock(t) - staleProviders := []database.AIProvider{testChatProvider("provider-stale")} - freshProviders := []database.AIProvider{testChatProvider("provider-fresh")} + staleProviders := []database.AIProvider{testAIProvider("provider-stale")} + freshProviders := []database.AIProvider{testAIProvider("provider-fresh")} firstStarted := make(chan struct{}) secondStarted := make(chan struct{}) releaseFirst := make(chan struct{}) releaseSecond := make(chan struct{}) store := &stubChatConfigStore{} - store.getEnabledChatProviders = func(context.Context) ([]database.AIProvider, error) { + store.getAIProviders = func(context.Context) ([]database.AIProvider, error) { switch call := store.enabledProvidersCalls.Load(); call { case 1: close(firstStarted) @@ -670,7 +670,7 @@ func TestConfigCache_InvalidateProviders_BlocksStaleInFlightModelConfig(t *testi require.Equal(t, int32(2), store.modelConfigByIDCalls.Load()) } -func testChatProvider(name string) database.AIProvider { +func testAIProvider(name string) database.AIProvider { return database.AIProvider{ ID: uuid.New(), Type: database.AIProviderType(name), @@ -738,19 +738,19 @@ func TestConfigCache_CallerCancellation(t *testing.T) { name: "EnabledProviders", setupBlocked: func(store *stubChatConfigStore, started, release chan struct{}) { var once sync.Once - store.getEnabledChatProviders = func(ctx context.Context) ([]database.AIProvider, error) { + store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) { once.Do(func() { close(started) }) select { case <-ctx.Done(): return nil, ctx.Err() case <-release: - return []database.AIProvider{testChatProvider("p")}, nil + return []database.AIProvider{testAIProvider("p")}, nil } } }, setupCtxSensitive: func(store *stubChatConfigStore, started chan struct{}) { var once sync.Once - store.getEnabledChatProviders = func(ctx context.Context) ([]database.AIProvider, error) { + store.getAIProviders = func(ctx context.Context) ([]database.AIProvider, error) { once.Do(func() { close(started) }) <-ctx.Done() return nil, ctx.Err() diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index b3f130c0ed90f..637a3a6d3f635 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -380,8 +380,6 @@ DELETE FROM external_auth_links OR oauth_refresh_token_key_id IS NOT NULL; DELETE FROM user_ai_provider_keys WHERE api_key_key_id IS NOT NULL; -DELETE FROM user_ai_provider_keys - WHERE api_key_key_id IS NOT NULL; DELETE FROM user_secrets WHERE value_key_id IS NOT NULL; UPDATE ai_providers diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index 2f08ce1e44bc4..aa842abde2397 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1562,114 +1562,6 @@ func TestMCPServerUserTokens(t *testing.T) { }) } -func TestUserAIProviderKeys(t *testing.T) { - t.Parallel() - ctx := context.Background() - - const ( - //nolint:gosec // test credentials - initialAPIKey = "sk-initial-api-key-value" - //nolint:gosec // test credentials - updatedAPIKey = "sk-updated-api-key-value" - ) - - insertProviderAndKey := func( - t *testing.T, - crypt *dbCrypt, - ciphers []Cipher, - ) (database.AIProvider, database.UserAiProviderKey) { - t.Helper() - user := dbgen.User(t, crypt, database.User{}) - provider := dbgen.AIProvider(t, crypt, database.AIProvider{ - Type: database.AiProviderTypeOpenai, - Name: "test-openai-" + uuid.NewString(), - }) - - key, err := crypt.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ - ID: uuid.New(), - UserID: user.ID, - AiProviderID: provider.ID, - APIKey: initialAPIKey, - }) - require.NoError(t, err) - require.Equal(t, initialAPIKey, key.APIKey) - require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String) - return provider, key - } - - getUserChatProviderKey := func(t *testing.T, store interface { - GetUserAIProviderKeysByUserID(context.Context, uuid.UUID) ([]database.UserAiProviderKey, error) - }, userID uuid.UUID, providerID uuid.UUID, - ) database.UserAiProviderKey { - t.Helper() - keys, err := store.GetUserAIProviderKeysByUserID(ctx, userID) - require.NoError(t, err) - require.Len(t, keys, 1) - require.Equal(t, providerID, keys[0].AiProviderID) - return keys[0] - } - - t.Run("UpsertUserAIProviderKeyCreatesValue", func(t *testing.T) { - t.Parallel() - db, crypt, ciphers := setup(t) - provider, key := insertProviderAndKey(t, crypt, ciphers) - - got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID) - require.Equal(t, key.ID, got.ID) - require.Equal(t, initialAPIKey, got.APIKey) - require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String) - - rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID) - require.NotEqual(t, initialAPIKey, rawKey.APIKey) - requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey) - }) - - t.Run("GetUserAIProviderKeysByUserID", func(t *testing.T) { - t.Parallel() - _, crypt, ciphers := setup(t) - _, key := insertProviderAndKey(t, crypt, ciphers) - - keys, err := crypt.GetUserAIProviderKeysByUserID(ctx, key.UserID) - require.NoError(t, err) - require.Len(t, keys, 1) - require.Equal(t, key.ID, keys[0].ID) - require.Equal(t, initialAPIKey, keys[0].APIKey) - require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String) - }) - - t.Run("UpsertUserAIProviderKeyUpdatesValue", func(t *testing.T) { - t.Parallel() - db, crypt, ciphers := setup(t) - provider, key := insertProviderAndKey(t, crypt, ciphers) - - updated, err := crypt.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ - ID: uuid.New(), - UserID: key.UserID, - AiProviderID: provider.ID, - APIKey: updatedAPIKey, - }) - require.NoError(t, err) - require.Equal(t, key.ID, updated.ID) - require.Equal(t, key.CreatedAt, updated.CreatedAt) - require.False(t, updated.UpdatedAt.Before(key.UpdatedAt)) - require.Equal(t, updatedAPIKey, updated.APIKey) - require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String) - - got := getUserChatProviderKey(t, crypt, key.UserID, provider.ID) - require.Equal(t, updatedAPIKey, got.APIKey) - require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String) - - keys, err := crypt.GetUserAIProviderKeysByUserID(ctx, key.UserID) - require.NoError(t, err) - require.Len(t, keys, 1) - require.Equal(t, updatedAPIKey, keys[0].APIKey) - - rawKey := getUserChatProviderKey(t, db, key.UserID, provider.ID) - require.NotEqual(t, updatedAPIKey, rawKey.APIKey) - requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey) - }) -} - func TestUserSecrets(t *testing.T) { t.Parallel() ctx := context.Background() From 3cc60f19bedd13866ea535fbb71890295aa88855 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 21:53:47 +0000 Subject: [PATCH 4/8] fix(coderd): preserve AI provider model locks after cleanup --- coderd/exp_chats.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 740370163ff70..1dc3df9e43b75 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -7115,7 +7115,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { } var inserted database.ChatModelConfig - err := api.Database.InTx(func(tx database.Store) error { + err = api.Database.InTx(func(tx database.Store) error { if insertParams.AiProviderID.Valid { aiProvider, err := tx.GetAIProviderByIDForUpdate(ctx, insertParams.AiProviderID.UUID) if err != nil { @@ -7128,8 +7128,6 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) { return errChatProviderNotConfigured } insertParams.Provider = string(aiProvider.Type) - } else if err := requireChatProviderForModelConfig(ctx, tx, insertParams.Provider); err != nil { - return err } insertAsDefault := isDefault @@ -7334,8 +7332,6 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) { return errChatProviderNotConfigured } updateParams.Provider = string(aiProvider.Type) - } else if err := requireChatProviderForModelConfig(ctx, tx, updateParams.Provider); err != nil { - return err } setAsDefault := updateParams.IsDefault && !existing.IsDefault From 53861ac7ab8e6777856f36f46f9dfc5d4999017d Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 22:46:01 +0000 Subject: [PATCH 5/8] fix(coderd): decrypt batched AI provider keys --- coderd/x/chatd/chatd_internal_test.go | 10 ++++++---- coderd/x/chatd/subagent.go | 10 +++++----- enterprise/dbcrypt/dbcrypt.go | 13 +++++++++++++ enterprise/dbcrypt/dbcrypt_internal_test.go | 12 ++++++++++++ 4 files changed, 36 insertions(+), 9 deletions(-) diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 97d70a1b50b3a..a6b9c059955e5 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -792,13 +792,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { } db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil) + providerID := uuid.New() db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ - ID: uuid.New(), + ID: providerID, Type: database.AiProviderTypeOpenai, Enabled: true, BaseUrl: serverURL, }}, nil) - db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), gomock.Any()).Return([]database.AIProviderKey{{APIKey: "test-key"}}, nil).AnyTimes() + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return([]database.AIProviderKey{{ProviderID: providerID, APIKey: "test-key"}}, nil) db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows) db.EXPECT().GetChatMessagesByChatIDAscPaginated( gomock.Any(), @@ -957,13 +958,14 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts_IdleChatReleasesManualLock(t } db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil) + providerID := uuid.New() db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return([]database.AIProvider{{ - ID: uuid.New(), + ID: providerID, Type: database.AiProviderTypeOpenai, Enabled: true, BaseUrl: serverURL, }}, nil) - db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), gomock.Any()).Return([]database.AIProviderKey{{APIKey: "test-key"}}, nil).AnyTimes() + db.EXPECT().GetAIProviderKeysByProviderIDs(gomock.Any(), []uuid.UUID{providerID}).Return([]database.AIProviderKey{{ProviderID: providerID, APIKey: "test-key"}}, nil) db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows) db.EXPECT().GetChatMessagesByChatIDAscPaginated( gomock.Any(), diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index b786c9125425d..9e45c4d693c0c 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -93,11 +93,11 @@ func (p *Server) providerConfigured(ctx context.Context, provider string) (bool, if err != nil { return false, xerrors.Errorf("list enabled AI providers: %w", err) } - for _, prov := range dbProviders { - configuredProvider, err := p.aiProviderConfig(ctx, prov) - if err != nil { - return false, err - } + configuredProviders, err := p.aiProviderConfigs(ctx, dbProviders) + if err != nil { + return false, err + } + for _, configuredProvider := range configuredProviders { if chatprovider.NormalizeProvider(configuredProvider.Provider) == normalizedProvider && strings.TrimSpace(configuredProvider.APIKey) != "" { return true, nil diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index 46ee281c6b3e1..007544b3e0737 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -521,6 +521,19 @@ func (db *dbCrypt) GetAIProviderKeysByProviderID(ctx context.Context, providerID return keys, nil } +func (db *dbCrypt) GetAIProviderKeysByProviderIDs(ctx context.Context, providerIDs []uuid.UUID) ([]database.AIProviderKey, error) { + keys, err := db.Store.GetAIProviderKeysByProviderIDs(ctx, providerIDs) + if err != nil { + return nil, err + } + for i := range keys { + if err := db.decryptAIProviderKey(&keys[i]); err != nil { + return nil, err + } + } + return keys, nil +} + func (db *dbCrypt) InsertAIProviderKey(ctx context.Context, params database.InsertAIProviderKeyParams) (database.AIProviderKey, error) { if strings.TrimSpace(params.APIKey) == "" { params.ApiKeyKeyID = sql.NullString{} diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index aa842abde2397..e69565ca5f45e 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1283,6 +1283,18 @@ func TestAIProviderKeys(t *testing.T) { requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey) }) + t.Run("GetAIProviderKeysByProviderIDs", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + + keys, err := crypt.GetAIProviderKeysByProviderIDs(ctx, []uuid.UUID{provider.ID}) + require.NoError(t, err) + require.Len(t, keys, 1) + requireAIProviderKeyDecrypted(t, keys[0], ciphers, apiKey) + requireAIProviderKeyRawEncrypted(ctx, t, db, key.ID, ciphers, apiKey) + }) + t.Run("DeleteAIProviderKey", func(t *testing.T) { t.Parallel() db, crypt, ciphers := setup(t) From a5c21c202dab44c42f4cc0074d7db03675ebf7d5 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 23:10:55 +0000 Subject: [PATCH 6/8] fix(coderd/x/chatd): resolve computer use provider keys --- coderd/x/chatd/chatd.go | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 569d3733056f6..e4e8699f2f378 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -7359,6 +7359,14 @@ func (p *Server) runChat( }, } + if isComputerUse { + computerUseProviderKeys, keyErr := p.resolveUserProviderAPIKeys(ctx, chat.OwnerID, uuid.Nil) + if keyErr != nil { + return result, xerrors.Errorf("resolve computer use provider API keys: %w", keyErr) + } + providerKeys = computerUseProviderKeys + } + if isComputerUse { // Override model for computer use subagent. cuModel, cuDebugEnabled, resolvedProvider, resolvedModel, cuErr := p.resolveComputerUseModel( From 24df794fbf73f058524049614860840adbc6aaa6 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 23:50:43 +0000 Subject: [PATCH 7/8] fix(coderd/x/chatd): snapshot title provider keys --- coderd/x/chatd/chatd.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index e4e8699f2f378..f8a5518edacaf 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -6796,10 +6796,11 @@ func (p *Server) runChat( // Fire title generation asynchronously so it doesn't block the // chat response. It uses a detached context so it can finish // even after the chat processing context is canceled. - // Snapshot model, logger, and ctx before launch; all three get - // reassigned below (model = cuModel, logger = logger.With(...), - // ctx = runCtx) and the goroutine captures by reference. + // Snapshot model, provider keys, logger, and ctx before launch; all four get + // reassigned below (model = cuModel, providerKeys = computerUseProviderKeys, + // logger = logger.With(...), ctx = runCtx) and the goroutine captures by reference. titleModel := model + titleProviderKeys := providerKeys titleLogger := logger titleCtx := context.WithoutCancel(ctx) p.inflight.Add(1) @@ -6812,7 +6813,7 @@ func (p *Server) runChat( modelConfig.Provider, modelConfig.Model, titleModel, - providerKeys, + titleProviderKeys, generatedTitle, titleLogger, debugSvc, From d43f23242325cb2d886ea18225c5763b3812469d Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sun, 17 May 2026 00:02:48 +0000 Subject: [PATCH 8/8] fix(coderd/x/chatd): simplify computer use key refresh --- coderd/x/chatd/chatd.go | 2 -- 1 file changed, 2 deletions(-) diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index f8a5518edacaf..f3c170b70f013 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -7366,9 +7366,7 @@ func (p *Server) runChat( return result, xerrors.Errorf("resolve computer use provider API keys: %w", keyErr) } providerKeys = computerUseProviderKeys - } - if isComputerUse { // Override model for computer use subagent. cuModel, cuDebugEnabled, resolvedProvider, resolvedModel, cuErr := p.resolveComputerUseModel( ctx,