From 641a866b2574d6238630a587afee0c97a9d4dd71 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 15:10:20 +0000 Subject: [PATCH 1/7] feat(codersdk): add AI provider chat APIs --- codersdk/chats.go | 238 +++++++++++++++++++++++++++++++++ site/src/api/typesGenerated.ts | 113 ++++++++++++++++ 2 files changed, 351 insertions(+) diff --git a/codersdk/chats.go b/codersdk/chats.go index 2baaf87e12724..6ed5b986bde12 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -1102,6 +1102,87 @@ type UpdateChatProviderConfigRequest struct { AllowCentralAPIKeyFallback *bool `json:"allow_central_api_key_fallback,omitempty"` } +// AIProviderType identifies the provider implementation family. +type AIProviderType string + +const ( + AIProviderTypeAnthropic AIProviderType = "anthropic" + AIProviderTypeAzure AIProviderType = "azure" + AIProviderTypeBedrock AIProviderType = "bedrock" + AIProviderTypeGoogle AIProviderType = "google" + AIProviderTypeOpenAI AIProviderType = "openai" + AIProviderTypeOpenAICompat AIProviderType = "openai-compat" + AIProviderTypeOpenRouter AIProviderType = "openrouter" + AIProviderTypeVercel AIProviderType = "vercel" +) + +// AIProvider is an admin-managed provider configuration used by Agents. +type AIProvider struct { + ID uuid.UUID `json:"id" format:"uuid"` + Type AIProviderType `json:"type"` + Name string `json:"name"` + DisplayName string `json:"display_name"` + Enabled bool `json:"enabled"` + Deleted bool `json:"deleted"` + BaseURL string `json:"base_url"` + CreatedAt time.Time `json:"created_at" format:"date-time"` + UpdatedAt time.Time `json:"updated_at" format:"date-time"` +} + +// AIProviderSummary is provider metadata embedded in other API responses. +type AIProviderSummary struct { + ID uuid.UUID `json:"id" format:"uuid"` + Type AIProviderType `json:"type"` + Name string `json:"name"` + DisplayName string `json:"display_name"` + Enabled bool `json:"enabled"` + Deleted bool `json:"deleted"` +} + +// CreateAIProviderRequest creates an AI provider. +type CreateAIProviderRequest struct { + Type AIProviderType `json:"type"` + Name string `json:"name"` + DisplayName string `json:"display_name,omitempty"` + BaseURL string `json:"base_url,omitempty"` + Enabled *bool `json:"enabled,omitempty"` +} + +// UpdateAIProviderRequest updates an AI provider. +type UpdateAIProviderRequest struct { + Name *string `json:"name,omitempty"` + DisplayName *string `json:"display_name,omitempty"` + BaseURL *string `json:"base_url,omitempty"` + Enabled *bool `json:"enabled,omitempty"` +} + +// AIProviderKey is a provider-scoped key summary. API keys are never returned. +type AIProviderKey struct { + ID uuid.UUID `json:"id" format:"uuid"` + ProviderID uuid.UUID `json:"provider_id" format:"uuid"` + CreatedAt time.Time `json:"created_at" format:"date-time"` + UpdatedAt time.Time `json:"updated_at" format:"date-time"` +} + +// CreateAIProviderKeyRequest creates a provider-scoped AI provider key. +type CreateAIProviderKeyRequest struct { + APIKey string `json:"api_key"` +} + +// UserAIProviderKeyConfig is a provider summary from the current user's +// perspective. It reports key presence but never returns key material. +type UserAIProviderKeyConfig struct { + Provider AIProviderSummary `json:"provider"` + HasUserAPIKey bool `json:"has_user_api_key"` + BYOKEnabled bool `json:"byok_enabled"` +} + +// CreateUserAIProviderKeyRequest creates or replaces a user's API key +// for an AI provider. +type CreateUserAIProviderKeyRequest struct { + APIKey string `json:"api_key"` +} + // UserChatProviderConfig is a summary of a provider that allows // user-supplied keys, as seen from the current user's perspective. type UserChatProviderConfig struct { @@ -2071,6 +2152,163 @@ func (c *ExperimentalClient) DeleteChatProvider(ctx context.Context, providerID return nil } +// ListAIProviders returns admin-managed AI providers. +func (c *ExperimentalClient) ListAIProviders(ctx context.Context) ([]AIProvider, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/ai-providers", nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, ReadBodyAsError(res) + } + + var providers []AIProvider + return providers, json.NewDecoder(res.Body).Decode(&providers) +} + +// CreateAIProvider creates an admin-managed AI provider. +func (c *ExperimentalClient) CreateAIProvider(ctx context.Context, req CreateAIProviderRequest) (AIProvider, error) { + res, err := c.Request(ctx, http.MethodPost, "/api/experimental/chats/ai-providers", req) + if err != nil { + return AIProvider{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return AIProvider{}, ReadBodyAsError(res) + } + + var provider AIProvider + return provider, json.NewDecoder(res.Body).Decode(&provider) +} + +// GetAIProvider returns an admin-managed AI provider. +func (c *ExperimentalClient) GetAIProvider(ctx context.Context, providerID uuid.UUID) (AIProvider, error) { + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/ai-providers/%s", providerID), nil) + if err != nil { + return AIProvider{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AIProvider{}, ReadBodyAsError(res) + } + + var provider AIProvider + return provider, json.NewDecoder(res.Body).Decode(&provider) +} + +// UpdateAIProvider updates an admin-managed AI provider. +func (c *ExperimentalClient) UpdateAIProvider(ctx context.Context, providerID uuid.UUID, req UpdateAIProviderRequest) (AIProvider, error) { + res, err := c.Request(ctx, http.MethodPatch, fmt.Sprintf("/api/experimental/chats/ai-providers/%s", providerID), req) + if err != nil { + return AIProvider{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AIProvider{}, ReadBodyAsError(res) + } + + var provider AIProvider + return provider, json.NewDecoder(res.Body).Decode(&provider) +} + +// DeleteAIProvider deletes an admin-managed AI provider. +func (c *ExperimentalClient) DeleteAIProvider(ctx context.Context, providerID uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/ai-providers/%s", providerID), nil) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + +// ListAIProviderKeys returns provider-scoped AI provider key summaries. +func (c *ExperimentalClient) ListAIProviderKeys(ctx context.Context, providerID uuid.UUID) ([]AIProviderKey, error) { + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/ai-providers/%s/keys", providerID), nil) + if err != nil { + return nil, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, ReadBodyAsError(res) + } + + var keys []AIProviderKey + return keys, json.NewDecoder(res.Body).Decode(&keys) +} + +// CreateAIProviderKey creates a provider-scoped AI provider key. +func (c *ExperimentalClient) CreateAIProviderKey(ctx context.Context, providerID uuid.UUID, req CreateAIProviderKeyRequest) (AIProviderKey, error) { + res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/ai-providers/%s/keys", providerID), req) + if err != nil { + return AIProviderKey{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusCreated { + return AIProviderKey{}, ReadBodyAsError(res) + } + + var key AIProviderKey + return key, json.NewDecoder(res.Body).Decode(&key) +} + +// DeleteAIProviderKey deletes a provider-scoped AI provider key. +func (c *ExperimentalClient) DeleteAIProviderKey(ctx context.Context, providerID uuid.UUID, keyID uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/ai-providers/%s/keys/%s", providerID, keyID), nil) + if err != nil { + return err + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + +// ListUserAIProviderKeyConfigs returns user-scoped AI provider key configs. +func (c *ExperimentalClient) ListUserAIProviderKeyConfigs(ctx context.Context) ([]UserAIProviderKeyConfig, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/user-ai-provider-keys", nil) + if err != nil { + return nil, xerrors.Errorf("list user AI provider key configs: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return nil, ReadBodyAsError(res) + } + var configs []UserAIProviderKeyConfig + return configs, json.NewDecoder(res.Body).Decode(&configs) +} + +// UpsertUserAIProviderKey creates or replaces a user API key for an AI provider. +func (c *ExperimentalClient) UpsertUserAIProviderKey(ctx context.Context, providerID uuid.UUID, req CreateUserAIProviderKeyRequest) (UserAIProviderKeyConfig, error) { + res, err := c.Request(ctx, http.MethodPut, fmt.Sprintf("/api/experimental/chats/user-ai-provider-keys/%s", providerID), req) + if err != nil { + return UserAIProviderKeyConfig{}, xerrors.Errorf("upsert user AI provider key: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return UserAIProviderKeyConfig{}, ReadBodyAsError(res) + } + var config UserAIProviderKeyConfig + return config, json.NewDecoder(res.Body).Decode(&config) +} + +// DeleteUserAIProviderKey deletes a user API key for an AI provider. +func (c *ExperimentalClient) DeleteUserAIProviderKey(ctx context.Context, providerID uuid.UUID) error { + res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("/api/experimental/chats/user-ai-provider-keys/%s", providerID), nil) + if err != nil { + return xerrors.Errorf("delete user AI provider key: %w", err) + } + defer res.Body.Close() + if res.StatusCode != http.StatusNoContent { + return ReadBodyAsError(res) + } + return nil +} + // ListUserChatProviderConfigs returns user-scoped chat provider configs. func (c *ExperimentalClient) ListUserChatProviderConfigs(ctx context.Context) ([]UserChatProviderConfig, error) { res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/user-provider-configs", nil) diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 9b033b058e83f..644e5b6d2dd09 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -298,6 +298,22 @@ export interface AIConfig { readonly chat?: ChatConfig; } +// From codersdk/chats.go +/** + * AIProvider is an admin-managed provider configuration used by Agents. + */ +export interface AIProvider { + readonly id: string; + readonly type: AIProviderType; + readonly name: string; + readonly display_name: string; + readonly enabled: boolean; + readonly deleted: boolean; + readonly base_url: string; + readonly created_at: string; + readonly updated_at: string; +} + // From codersdk/deployment.go /** * AIProviderConfig represents a single AI provider instance, @@ -327,6 +343,52 @@ export interface AIProviderConfig { readonly bedrock_small_fast_model?: string; } +// From codersdk/chats.go +/** + * AIProviderKey is a provider-scoped key summary. API keys are never returned. + */ +export interface AIProviderKey { + readonly id: string; + readonly provider_id: string; + readonly created_at: string; + readonly updated_at: string; +} + +// From codersdk/chats.go +/** + * AIProviderSummary is provider metadata embedded in other API responses. + */ +export interface AIProviderSummary { + readonly id: string; + readonly type: AIProviderType; + readonly name: string; + readonly display_name: string; + readonly enabled: boolean; + readonly deleted: boolean; +} + +// From codersdk/chats.go +export type AIProviderType = + | "anthropic" + | "azure" + | "bedrock" + | "google" + | "openai" + | "openai-compat" + | "openrouter" + | "vercel"; + +export const AIProviderTypes: AIProviderType[] = [ + "anthropic", + "azure", + "bedrock", + "google", + "openai", + "openai-compat", + "openrouter", + "vercel", +]; + // From codersdk/allowlist.go /** * APIAllowListTarget represents a single allow-list entry using the canonical @@ -2978,6 +3040,26 @@ export interface ConvertLoginRequest { readonly password: string; } +// From codersdk/chats.go +/** + * CreateAIProviderKeyRequest creates a provider-scoped AI provider key. + */ +export interface CreateAIProviderKeyRequest { + readonly api_key: string; +} + +// From codersdk/chats.go +/** + * CreateAIProviderRequest creates an AI provider. + */ +export interface CreateAIProviderRequest { + readonly type: AIProviderType; + readonly name: string; + readonly display_name?: string; + readonly base_url?: string; + readonly enabled?: boolean; +} + // From codersdk/chats.go /** * CreateChatMessageRequest is the request to add a message to a chat. @@ -3341,6 +3423,15 @@ export interface CreateTokenRequest { readonly allow_list?: readonly APIAllowListTarget[]; } +// From codersdk/chats.go +/** + * CreateUserAIProviderKeyRequest creates or replaces a user's API key + * for an AI provider. + */ +export interface CreateUserAIProviderKeyRequest { + readonly api_key: string; +} + // From codersdk/chats.go /** * CreateUserChatProviderKeyRequest creates or replaces a user's API key @@ -8044,6 +8135,17 @@ export interface TransitionStats { readonly P95: number | null; } +// From codersdk/chats.go +/** + * UpdateAIProviderRequest updates an AI provider. + */ +export interface UpdateAIProviderRequest { + readonly name?: string; + readonly display_name?: string; + readonly base_url?: string; + readonly enabled?: boolean; +} + // From codersdk/templates.go export interface UpdateActiveTemplateVersion { readonly id: string; @@ -8750,6 +8852,17 @@ export interface User extends ReducedUser { readonly has_ai_seat: boolean; } +// From codersdk/chats.go +/** + * UserAIProviderKeyConfig is a provider summary from the current user's + * perspective. It reports key presence but never returns key material. + */ +export interface UserAIProviderKeyConfig { + readonly provider: AIProviderSummary; + readonly has_user_api_key: boolean; + readonly byok_enabled: boolean; +} + // From codersdk/insights.go /** * UserActivity shows the session time for a user. From 18955924f5174823112b87db40d3ac7a40a7d66a Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 15:21:00 +0000 Subject: [PATCH 2/7] feat(coderd): add AI provider chat routes --- coderd/coderd.go | 21 + coderd/database/dbauthz/dbauthz_test.go | 1 + coderd/database/queries.sql.go | 15 +- coderd/database/queries/ai_providers.sql | 1 + coderd/exp_chats.go | 437 ++++++++++++++++++++ enterprise/dbcrypt/dbcrypt_internal_test.go | 2 + 6 files changed, 471 insertions(+), 6 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index dd67488264bf3..1ee5fef0f3b94 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1250,6 +1250,20 @@ func New(options *Options) *API { r.Delete("/", api.deleteChatProvider) }) }) + r.Route("/ai-providers", func(r chi.Router) { + r.Get("/", api.listAIProviders) + r.Post("/", api.createAIProvider) + r.Route("/{aiProvider}", func(r chi.Router) { + r.Get("/", api.readAIProvider) + r.Patch("/", api.updateAIProvider) + r.Delete("/", api.deleteAIProvider) + r.Route("/keys", func(r chi.Router) { + r.Get("/", api.listAIProviderKeys) + r.Post("/", api.createAIProviderKey) + r.Delete("/{aiProviderKey}", api.deleteAIProviderKey) + }) + }) + }) // TODO(cian): place under /api/experimental/chats/config r.Route("/model-configs", func(r chi.Router) { r.Get("/", api.listChatModelConfigs) @@ -1279,6 +1293,13 @@ func New(options *Options) *API { r.Delete("/", api.deleteUserChatProviderKey) }) }) + r.Route("/user-ai-provider-keys", func(r chi.Router) { + r.Get("/", api.listUserAIProviderKeyConfigs) + r.Route("/{aiProvider}", func(r chi.Router) { + r.Put("/", api.upsertUserAIProviderKey) + r.Delete("/", api.deleteUserAIProviderKey) + }) + }) r.Route("/{chat}", func(r chi.Router) { r.Use(httpmw.ExtractChatParam(options.Database)) r.Get("/", api.getChat) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index b1ac192d0ab2f..f3cd456a32db7 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -6305,6 +6305,7 @@ func (s *MethodTestSuite) TestAIBridge() { provider := testutil.Fake(s.T(), faker, database.AIProvider{}) arg := database.UpdateAIProviderParams{ ID: provider.ID, + Name: provider.Name, Enabled: true, BaseUrl: "https://api.example.com/", } diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 8386edf8f1c33..0cd7d478ebf84 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -515,19 +515,21 @@ const updateAIProvider = `-- name: UpdateAIProvider :one UPDATE ai_providers SET - display_name = $1::text, - enabled = $2::boolean, - base_url = $3::text, - settings = $4::text, - settings_key_id = $5::text, + name = $1::text, + display_name = $2::text, + enabled = $3::boolean, + base_url = $4::text, + settings = $5::text, + settings_key_id = $6::text, updated_at = NOW() WHERE - id = $6::uuid AND deleted = FALSE + id = $7::uuid AND deleted = FALSE RETURNING id, type, name, display_name, enabled, deleted, base_url, settings, settings_key_id, created_at, updated_at ` type UpdateAIProviderParams struct { + Name string `db:"name" json:"name"` DisplayName sql.NullString `db:"display_name" json:"display_name"` Enabled bool `db:"enabled" json:"enabled"` BaseUrl string `db:"base_url" json:"base_url"` @@ -538,6 +540,7 @@ type UpdateAIProviderParams struct { func (q *sqlQuerier) UpdateAIProvider(ctx context.Context, arg UpdateAIProviderParams) (AIProvider, error) { row := q.db.QueryRowContext(ctx, updateAIProvider, + arg.Name, arg.DisplayName, arg.Enabled, arg.BaseUrl, diff --git a/coderd/database/queries/ai_providers.sql b/coderd/database/queries/ai_providers.sql index 9c2302861ef74..41a8dfd717cd6 100644 --- a/coderd/database/queries/ai_providers.sql +++ b/coderd/database/queries/ai_providers.sql @@ -54,6 +54,7 @@ RETURNING UPDATE ai_providers SET + name = @name::text, display_name = sqlc.narg('display_name')::text, enabled = @enabled::boolean, base_url = @base_url::text, diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index ea27c3c2ea581..838b7e06db4d4 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -6370,6 +6370,443 @@ func convertChatMessages(messages []database.ChatMessage) []codersdk.ChatMessage return result } +func parseAIProviderID(r *http.Request) (uuid.UUID, error) { + return uuid.Parse(chi.URLParam(r, "aiProvider")) +} + +func parseAIProviderKeyID(r *http.Request) (uuid.UUID, error) { + return uuid.Parse(chi.URLParam(r, "aiProviderKey")) +} + +func aiProviderDisplayName(displayName sql.NullString, name string) string { + if displayName.Valid { + return displayName.String + } + return name +} + +func convertAIProvider(provider database.AIProvider) codersdk.AIProvider { + return codersdk.AIProvider{ + ID: provider.ID, + Type: codersdk.AIProviderType(provider.Type), + Name: provider.Name, + DisplayName: aiProviderDisplayName(provider.DisplayName, provider.Name), + Enabled: provider.Enabled, + Deleted: provider.Deleted, + BaseURL: provider.BaseUrl, + CreatedAt: provider.CreatedAt, + UpdatedAt: provider.UpdatedAt, + } +} + +func convertAIProviderSummary(provider database.AIProvider) codersdk.AIProviderSummary { + return codersdk.AIProviderSummary{ + ID: provider.ID, + Type: codersdk.AIProviderType(provider.Type), + Name: provider.Name, + DisplayName: aiProviderDisplayName(provider.DisplayName, provider.Name), + Enabled: provider.Enabled, + Deleted: provider.Deleted, + } +} + +func convertAIProviderKey(key database.AIProviderKey) codersdk.AIProviderKey { + return codersdk.AIProviderKey{ + ID: key.ID, + ProviderID: key.ProviderID, + CreatedAt: key.CreatedAt, + UpdatedAt: key.UpdatedAt, + } +} + +func validAIProviderName(name string) bool { + if name == "" || strings.HasPrefix(name, "agents-") { + return false + } + lastHyphen := true + for _, r := range name { + switch { + case r >= 'a' && r <= 'z': + lastHyphen = false + case r >= '0' && r <= '9': + lastHyphen = false + case r == '-': + if lastHyphen { + return false + } + lastHyphen = true + default: + return false + } + } + return !lastHyphen +} + +func writeInvalidAIProviderName(ctx context.Context, rw http.ResponseWriter) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid AI provider name.", + Detail: "Names must use lowercase letters, numbers, and single hyphens, and cannot use the reserved agents- prefix.", + }) +} + +func (api *API) listAIProviders(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionRead, rbac.ResourceAIProvider) { + httpapi.Forbidden(rw) + return + } + providers, err := api.Database.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true}) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI providers."}) + return + } + resp := make([]codersdk.AIProvider, 0, len(providers)) + for _, provider := range providers { + resp = append(resp, convertAIProvider(provider)) + } + httpapi.Write(ctx, rw, http.StatusOK, resp) +} + +func (api *API) createAIProvider(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionCreate, rbac.ResourceAIProvider) { + httpapi.Forbidden(rw) + return + } + var req codersdk.CreateAIProviderRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + if !validAIProviderName(req.Name) { + writeInvalidAIProviderName(ctx, rw) + return + } + typ := database.AIProviderType(req.Type) + if !typ.Valid() { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider type."}) + return + } + enabled := true + if req.Enabled != nil { + enabled = *req.Enabled + } + provider, err := api.Database.InsertAIProvider(ctx, database.InsertAIProviderParams{ + ID: uuid.New(), + Type: typ, + Name: req.Name, + DisplayName: sql.NullString{String: req.DisplayName, Valid: req.DisplayName != ""}, + Enabled: enabled, + BaseUrl: req.BaseURL, + Settings: sql.NullString{}, + SettingsKeyID: sql.NullString{}, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to create AI provider."}) + return + } + publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) + httpapi.Write(ctx, rw, http.StatusCreated, convertAIProvider(provider)) +} + +func (api *API) readAIProvider(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionRead, rbac.ResourceAIProvider) { + httpapi.Forbidden(rw) + return + } + providerID, err := parseAIProviderID(r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) + return + } + provider, err := api.Database.GetAIProviderByID(ctx, providerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) + return + } + httpapi.Write(ctx, rw, http.StatusOK, convertAIProvider(provider)) +} + +func (api *API) updateAIProvider(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionUpdate, rbac.ResourceAIProvider) { + httpapi.Forbidden(rw) + return + } + providerID, err := parseAIProviderID(r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) + return + } + current, err := api.Database.GetAIProviderByID(ctx, providerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) + return + } + var req codersdk.UpdateAIProviderRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + name := current.Name + if req.Name != nil { + name = *req.Name + if !validAIProviderName(name) { + writeInvalidAIProviderName(ctx, rw) + return + } + } + displayName := current.DisplayName + if req.DisplayName != nil { + displayName = sql.NullString{String: *req.DisplayName, Valid: *req.DisplayName != ""} + } + baseURL := current.BaseUrl + if req.BaseURL != nil { + baseURL = *req.BaseURL + } + enabled := current.Enabled + if req.Enabled != nil { + enabled = *req.Enabled + } + provider, err := api.Database.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ + ID: providerID, + Name: name, + DisplayName: displayName, + Enabled: enabled, + BaseUrl: baseURL, + Settings: current.Settings, + SettingsKeyID: current.SettingsKeyID, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update AI provider."}) + return + } + publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) + httpapi.Write(ctx, rw, http.StatusOK, convertAIProvider(provider)) +} + +func (api *API) deleteAIProvider(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionDelete, rbac.ResourceAIProvider) { + httpapi.Forbidden(rw) + return + } + providerID, err := parseAIProviderID(r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) + return + } + if err := api.Database.DeleteAIProviderByID(ctx, providerID); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete AI provider."}) + return + } + publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) + httpapi.Write(ctx, rw, http.StatusNoContent, nil) +} + +func (api *API) listAIProviderKeys(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionRead, rbac.ResourceAIProvider) { + httpapi.Forbidden(rw) + return + } + providerID, err := parseAIProviderID(r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) + return + } + keys, err := api.Database.GetAIProviderKeysByProviderID(ctx, providerID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI provider keys."}) + return + } + resp := make([]codersdk.AIProviderKey, 0, len(keys)) + for _, key := range keys { + resp = append(resp, convertAIProviderKey(key)) + } + httpapi.Write(ctx, rw, http.StatusOK, resp) +} + +func (api *API) createAIProviderKey(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionCreate, rbac.ResourceAIProvider) { + httpapi.Forbidden(rw) + return + } + providerID, err := parseAIProviderID(r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) + return + } + var req codersdk.CreateAIProviderKeyRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + if strings.TrimSpace(req.APIKey) == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key is required."}) + return + } + now := time.Now() + key, err := api.Database.InsertAIProviderKey(ctx, database.InsertAIProviderKeyParams{ + ID: uuid.New(), + ProviderID: providerID, + APIKey: req.APIKey, + ApiKeyKeyID: sql.NullString{}, + CreatedAt: now, + UpdatedAt: now, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to create AI provider key."}) + return + } + publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) + httpapi.Write(ctx, rw, http.StatusCreated, convertAIProviderKey(key)) +} + +func (api *API) deleteAIProviderKey(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionDelete, rbac.ResourceAIProvider) { + httpapi.Forbidden(rw) + return + } + providerID, err := parseAIProviderID(r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) + return + } + keyID, err := parseAIProviderKeyID(r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider key ID."}) + return + } + key, err := api.Database.GetAIProviderKeyByID(ctx, keyID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider key not found."}) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider key."}) + return + } + if key.ProviderID != providerID { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider key not found."}) + return + } + if err := api.Database.DeleteAIProviderKey(ctx, keyID); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete AI provider key."}) + return + } + publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) + httpapi.Write(ctx, rw, http.StatusNoContent, nil) +} + +func (api *API) listUserAIProviderKeyConfigs(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + //nolint:gocritic // Users can list limited enabled provider metadata without AI provider admin permissions. + providers, err := api.Database.GetAIProviders(dbauthz.AsSystemRestricted(ctx), database.GetAIProvidersParams{IncludeDisabled: true}) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI providers."}) + return + } + keys, err := api.Database.GetUserAIProviderKeysByUserID(ctx, apiKey.UserID) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list user AI provider keys."}) + return + } + keysByProviderID := make(map[uuid.UUID]struct{}, len(keys)) + for _, key := range keys { + keysByProviderID[key.AiProviderID] = struct{}{} + } + byokEnabled := api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value() + configs := make([]codersdk.UserAIProviderKeyConfig, 0, len(providers)) + for _, provider := range providers { + _, hasKey := keysByProviderID[provider.ID] + if provider.Deleted || (!provider.Enabled && !hasKey) { + continue + } + configs = append(configs, codersdk.UserAIProviderKeyConfig{ + Provider: convertAIProviderSummary(provider), + HasUserAPIKey: hasKey, + BYOKEnabled: byokEnabled, + }) + } + httpapi.Write(ctx, rw, http.StatusOK, configs) +} + +func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value() { + httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{Message: "BYOK is disabled."}) + return + } + apiKey := httpmw.APIKey(r) + providerID, err := parseAIProviderID(r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) + return + } + //nolint:gocritic // Users can attach their own key to an enabled provider without AI provider admin permissions. + provider, err := api.Database.GetAIProviderByID(dbauthz.AsSystemRestricted(ctx), providerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) + return + } + var req codersdk.CreateUserAIProviderKeyRequest + if !httpapi.Read(ctx, rw, r, &req) { + return + } + if strings.TrimSpace(req.APIKey) == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key is required."}) + return + } + now := time.Now() + _, err = api.Database.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: apiKey.UserID, + AiProviderID: providerID, + APIKey: req.APIKey, + ApiKeyKeyID: sql.NullString{}, + CreatedAt: now, + UpdatedAt: now, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update user AI provider key."}) + return + } + httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserAIProviderKeyConfig{ + Provider: convertAIProviderSummary(provider), + HasUserAPIKey: true, + BYOKEnabled: true, + }) +} + +func (api *API) deleteUserAIProviderKey(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + providerID, err := parseAIProviderID(r) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) + return + } + if err := api.Database.DeleteUserAIProviderKey(ctx, database.DeleteUserAIProviderKeyParams{UserID: apiKey.UserID, AiProviderID: providerID}); err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete user AI provider key."}) + return + } + httpapi.Write(ctx, rw, http.StatusNoContent, nil) +} + func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) { ctx := r.Context() //nolint:gocritic // System context required to read enabled chat providers. diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index fea3a4eeb65bd..0ebf93e46df76 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1195,6 +1195,7 @@ func TestAIProviders(t *testing.T) { const newSettings = `{"_type":"bedrock","_version":1,"region":"us-east-1","model":"anthropic.claude-sonnet-4-5-20250929-v1:0","access_key":"AKIA-test","access_key_secret":"test-secret"}` updated, err := crypt.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ ID: provider.ID, + Name: provider.Name, DisplayName: provider.DisplayName, Enabled: provider.Enabled, BaseUrl: provider.BaseUrl, @@ -1211,6 +1212,7 @@ func TestAIProviders(t *testing.T) { provider := insertProvider(t, crypt, ciphers) updated, err := crypt.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ ID: provider.ID, + Name: provider.Name, DisplayName: provider.DisplayName, Enabled: provider.Enabled, BaseUrl: provider.BaseUrl, From e82b66fb9661e4ebf149c7372e6af2c42283ef12 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 18:31:40 +0000 Subject: [PATCH 3/7] test(coderd): cover AI provider chat APIs --- coderd/exp_chats.go | 31 +++++- coderd/exp_chats_test.go | 223 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 250 insertions(+), 4 deletions(-) diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 838b7e06db4d4..2620e4691d404 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -6501,7 +6501,14 @@ func (api *API) createAIProvider(rw http.ResponseWriter, r *http.Request) { SettingsKeyID: sql.NullString{}, }) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to create AI provider."}) + switch { + case database.IsUniqueViolation(err): + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{Message: "AI provider already exists."}) + case database.IsCheckViolation(err): + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider."}) + default: + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to create AI provider."}) + } return } publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) @@ -6585,7 +6592,14 @@ func (api *API) updateAIProvider(rw http.ResponseWriter, r *http.Request) { SettingsKeyID: current.SettingsKeyID, }) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update AI provider."}) + switch { + case database.IsUniqueViolation(err): + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{Message: "AI provider already exists."}) + case database.IsCheckViolation(err): + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider."}) + default: + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update AI provider."}) + } return } publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) @@ -6645,6 +6659,15 @@ func (api *API) createAIProviderKey(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) return } + _, err = api.Database.GetAIProviderByID(ctx, providerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) + return + } var req codersdk.CreateAIProviderKeyRequest if !httpapi.Read(ctx, rw, r, &req) { return @@ -6653,7 +6676,7 @@ func (api *API) createAIProviderKey(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key is required."}) return } - now := time.Now() + now := api.Clock.Now() key, err := api.Database.InsertAIProviderKey(ctx, database.InsertAIProviderKeyParams{ ID: uuid.New(), ProviderID: providerID, @@ -6771,7 +6794,7 @@ func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key is required."}) return } - now := time.Now() + now := api.Clock.Now() _, err = api.Database.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ ID: uuid.New(), UserID: apiKey.UserID, diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 77ad68acb6aeb..3a826ffe708cc 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -2076,6 +2076,229 @@ func TestWatchChats(t *testing.T) { }) } +func TestAIProviderCRUD(t *testing.T) { + t.Parallel() + + t.Run("AdminLifecycle", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + name := "test-openai-" + uuid.NewString() + + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: name, + DisplayName: "OpenAI Test", + BaseURL: "https://api.openai.example.com/v1", + }) + require.NoError(t, err) + require.NotEqual(t, uuid.Nil, provider.ID) + require.Equal(t, codersdk.AIProviderTypeOpenAI, provider.Type) + require.Equal(t, name, provider.Name) + require.Equal(t, "OpenAI Test", provider.DisplayName) + require.True(t, provider.Enabled) + + got, err := client.GetAIProvider(ctx, provider.ID) + require.NoError(t, err) + require.Equal(t, provider.ID, got.ID) + + providers, err := client.ListAIProviders(ctx) + require.NoError(t, err) + require.Contains(t, providers, provider) + + updatedName := "test-anthropic-" + uuid.NewString() + disabled := false + updated, err := client.UpdateAIProvider(ctx, provider.ID, codersdk.UpdateAIProviderRequest{ + Name: &updatedName, + DisplayName: ptr.Ref("Anthropic Test"), + BaseURL: ptr.Ref("https://api.anthropic.example.com/v1"), + Enabled: &disabled, + }) + require.NoError(t, err) + require.Equal(t, updatedName, updated.Name) + require.Equal(t, "Anthropic Test", updated.DisplayName) + require.False(t, updated.Enabled) + + key, err := client.CreateAIProviderKey(ctx, provider.ID, codersdk.CreateAIProviderKeyRequest{APIKey: "test-api-key"}) + require.NoError(t, err) + require.Equal(t, provider.ID, key.ProviderID) + + keys, err := client.ListAIProviderKeys(ctx, provider.ID) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, key.ID, keys[0].ID) + + require.NoError(t, client.DeleteAIProviderKey(ctx, provider.ID, key.ID)) + keys, err = client.ListAIProviderKeys(ctx, provider.ID) + require.NoError(t, err) + require.Empty(t, keys) + + require.NoError(t, client.DeleteAIProvider(ctx, provider.ID)) + _, err = client.GetAIProvider(ctx, provider.ID) + requireSDKError(t, err, http.StatusNotFound) + }) + + t.Run("CreateValidation", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + name := "test-openai-" + uuid.NewString() + + _, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: name, + }) + require.NoError(t, err) + + _, err = client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: name, + }) + sdkErr := requireSDKError(t, err, http.StatusConflict) + require.Equal(t, "AI provider already exists.", sdkErr.Message) + + _, err = client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "InvalidName", + }) + sdkErr = requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid AI provider name.", sdkErr.Message) + + _, err = client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderType("not-a-provider"), + Name: "test-invalid-" + uuid.NewString(), + }) + sdkErr = requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid AI provider type.", sdkErr.Message) + }) + + t.Run("CreateKeyForMissingProvider", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + _, err := client.CreateAIProviderKey(ctx, uuid.New(), codersdk.CreateAIProviderKeyRequest{APIKey: "test-api-key"}) + sdkErr := requireSDKError(t, err, http.StatusNotFound) + require.Equal(t, "AI provider not found.", sdkErr.Message) + }) + + t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + _, err := memberClient.ListAIProviders(ctx) + requireSDKError(t, err, http.StatusForbidden) + _, err = memberClient.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-member-" + uuid.NewString(), + }) + requireSDKError(t, err, http.StatusForbidden) + }) +} + +func TestUserAIProviderKeys(t *testing.T) { + t.Parallel() + + t.Run("SelfServiceLifecycle", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + provider, err := adminClient.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-user-key-" + uuid.NewString(), + }) + require.NoError(t, err) + + configs, err := memberClient.ListUserAIProviderKeyConfigs(ctx) + require.NoError(t, err) + var cfg *codersdk.UserAIProviderKeyConfig + for i := range configs { + if configs[i].Provider.ID == provider.ID { + cfg = &configs[i] + break + } + } + require.NotNil(t, cfg) + require.False(t, cfg.HasUserAPIKey) + require.True(t, cfg.BYOKEnabled) + + cfgValue, err := memberClient.UpsertUserAIProviderKey(ctx, provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"}) + require.NoError(t, err) + require.Equal(t, provider.ID, cfgValue.Provider.ID) + require.True(t, cfgValue.HasUserAPIKey) + require.True(t, cfgValue.BYOKEnabled) + + configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx) + require.NoError(t, err) + cfg = nil + for i := range configs { + if configs[i].Provider.ID == provider.ID { + cfg = &configs[i] + break + } + } + require.NotNil(t, cfg) + require.True(t, cfg.HasUserAPIKey) + + require.NoError(t, memberClient.DeleteUserAIProviderKey(ctx, provider.ID)) + configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx) + require.NoError(t, err) + for _, listed := range configs { + if listed.Provider.ID == provider.ID { + require.False(t, listed.HasUserAPIKey) + return + } + } + t.Fatal("provider config not found") + }) + + t.Run("BYOKDisabledRejectsUpsertAndAllowsDelete", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + values := chatDeploymentValues(t) + values.AI.BridgeConfig.AllowBYOK = serpent.Bool(false) + client := newChatClientWithDeploymentValues(t, values) + _ = coderdtest.CreateFirstUser(t, client.Client) + + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-byok-disabled-" + uuid.NewString(), + }) + require.NoError(t, err) + + _, err = client.UpsertUserAIProviderKey(ctx, provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"}) + sdkErr := requireSDKError(t, err, http.StatusForbidden) + require.Equal(t, "BYOK is disabled.", sdkErr.Message) + + configs, err := client.ListUserAIProviderKeyConfigs(ctx) + require.NoError(t, err) + for _, cfg := range configs { + if cfg.Provider.ID == provider.ID { + require.False(t, cfg.BYOKEnabled) + break + } + } + require.NoError(t, client.DeleteUserAIProviderKey(ctx, provider.ID)) + }) +} + func TestListChatProviders(t *testing.T) { t.Parallel() From a735065158738d33e973e9941df41c9bc5ca8940 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 19:38:21 +0000 Subject: [PATCH 4/7] fix(coderd): harden AI provider API validation --- coderd/database/dbauthz/dbauthz.go | 7 + coderd/database/dbmetrics/querymetrics.go | 8 ++ coderd/database/dbmock/dbmock.go | 15 +++ coderd/database/querier.go | 1 + coderd/database/queries.sql.go | 29 ++++ coderd/database/queries/ai_providers.sql | 9 ++ coderd/exp_chats.go | 156 ++++++++++++++-------- coderd/exp_chats_test.go | 126 +++++++++++++++++ 8 files changed, 299 insertions(+), 52 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 5d4e24b689e9c..e1bd85106c9c0 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2532,6 +2532,13 @@ func (q *querier) GetAIProviderByID(ctx context.Context, id uuid.UUID) (database return q.db.GetAIProviderByID(ctx, id) } +func (q *querier) GetAIProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.AIProvider, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil { + return database.AIProvider{}, err + } + return q.db.GetAIProviderByIDForUpdate(ctx, id) +} + func (q *querier) GetAIProviderByName(ctx context.Context, name string) (database.AIProvider, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil { return database.AIProvider{}, err diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 33385b1aec811..c405bfcb65859 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1025,6 +1025,14 @@ func (m queryMetricsStore) GetAIProviderByID(ctx context.Context, id uuid.UUID) return r0, r1 } +func (m queryMetricsStore) GetAIProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.AIProvider, error) { + start := time.Now() + r0, r1 := m.s.GetAIProviderByIDForUpdate(ctx, id) + m.queryLatencies.WithLabelValues("GetAIProviderByIDForUpdate").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIProviderByIDForUpdate").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetAIProviderByName(ctx context.Context, name string) (database.AIProvider, error) { start := time.Now() r0, r1 := m.s.GetAIProviderByName(ctx, name) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index e7c2e96a413fd..ec926ce4aa79e 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1770,6 +1770,21 @@ func (mr *MockStoreMockRecorder) GetAIProviderByID(ctx, id any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderByID", reflect.TypeOf((*MockStore)(nil).GetAIProviderByID), ctx, id) } +// GetAIProviderByIDForUpdate mocks base method. +func (m *MockStore) GetAIProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.AIProvider, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetAIProviderByIDForUpdate", ctx, id) + ret0, _ := ret[0].(database.AIProvider) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetAIProviderByIDForUpdate indicates an expected call of GetAIProviderByIDForUpdate. +func (mr *MockStoreMockRecorder) GetAIProviderByIDForUpdate(ctx, id any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetAIProviderByIDForUpdate), ctx, id) +} + // GetAIProviderByName mocks base method. func (m *MockStore) GetAIProviderByName(ctx context.Context, name string) (database.AIProvider, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index e1b0d5989685c..9489e0aaf38ea 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -251,6 +251,7 @@ type sqlcQuerier interface { GetAIBridgeUserPromptsByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeUserPrompt, error) GetAIModelPriceByProviderModel(ctx context.Context, arg GetAIModelPriceByProviderModelParams) (AiModelPrice, error) GetAIProviderByID(ctx context.Context, id uuid.UUID) (AIProvider, error) + GetAIProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (AIProvider, error) GetAIProviderByName(ctx context.Context, name string) (AIProvider, error) GetAIProviderKeyByID(ctx context.Context, id uuid.UUID) (AIProviderKey, error) // Returns every AI provider key row, including those belonging to a diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 0cd7d478ebf84..b444adf15cfb4 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -366,6 +366,35 @@ func (q *sqlQuerier) GetAIProviderByID(ctx context.Context, id uuid.UUID) (AIPro return i, err } +const getAIProviderByIDForUpdate = `-- name: GetAIProviderByIDForUpdate :one +SELECT + id, type, name, display_name, enabled, deleted, base_url, settings, settings_key_id, created_at, updated_at +FROM + ai_providers +WHERE + id = $1::uuid AND deleted = FALSE +FOR UPDATE +` + +func (q *sqlQuerier) GetAIProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (AIProvider, error) { + row := q.db.QueryRowContext(ctx, getAIProviderByIDForUpdate, id) + var i AIProvider + err := row.Scan( + &i.ID, + &i.Type, + &i.Name, + &i.DisplayName, + &i.Enabled, + &i.Deleted, + &i.BaseUrl, + &i.Settings, + &i.SettingsKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getAIProviderByName = `-- name: GetAIProviderByName :one SELECT id, type, name, display_name, enabled, deleted, base_url, settings, settings_key_id, created_at, updated_at diff --git a/coderd/database/queries/ai_providers.sql b/coderd/database/queries/ai_providers.sql index 41a8dfd717cd6..da02a4a90913c 100644 --- a/coderd/database/queries/ai_providers.sql +++ b/coderd/database/queries/ai_providers.sql @@ -6,6 +6,15 @@ FROM WHERE id = @id::uuid AND deleted = FALSE; +-- name: GetAIProviderByIDForUpdate :one +SELECT + * +FROM + ai_providers +WHERE + id = @id::uuid AND deleted = FALSE +FOR UPDATE; + -- name: GetAIProviderByName :one SELECT * diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 2620e4691d404..7452b87ca87b6 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -6477,6 +6477,8 @@ func (api *API) createAIProvider(rw http.ResponseWriter, r *http.Request) { if !httpapi.Read(ctx, rw, r, &req) { return } + req.Name = strings.TrimSpace(req.Name) + req.DisplayName = strings.TrimSpace(req.DisplayName) if !validAIProviderName(req.Name) { writeInvalidAIProviderName(ctx, rw) return @@ -6490,13 +6492,21 @@ func (api *API) createAIProvider(rw http.ResponseWriter, r *http.Request) { if req.Enabled != nil { enabled = *req.Enabled } + baseURL, err := normalizeChatProviderBaseURL(req.BaseURL) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid provider base URL.", + Detail: err.Error(), + }) + return + } provider, err := api.Database.InsertAIProvider(ctx, database.InsertAIProviderParams{ ID: uuid.New(), Type: typ, Name: req.Name, DisplayName: sql.NullString{String: req.DisplayName, Valid: req.DisplayName != ""}, Enabled: enabled, - BaseUrl: req.BaseURL, + BaseUrl: baseURL, Settings: sql.NullString{}, SettingsKeyID: sql.NullString{}, }) @@ -6549,57 +6559,67 @@ func (api *API) updateAIProvider(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) return } - current, err := api.Database.GetAIProviderByID(ctx, providerID) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) - return - } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) - return - } var req codersdk.UpdateAIProviderRequest if !httpapi.Read(ctx, rw, r, &req) { return } - name := current.Name - if req.Name != nil { - name = *req.Name - if !validAIProviderName(name) { - writeInvalidAIProviderName(ctx, rw) - return + var provider database.AIProvider + if err := api.Database.InTx(func(tx database.Store) error { + current, err := tx.GetAIProviderByIDForUpdate(ctx, providerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return httperror.NewResponseError(http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) + } + return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) } - } - displayName := current.DisplayName - if req.DisplayName != nil { - displayName = sql.NullString{String: *req.DisplayName, Valid: *req.DisplayName != ""} - } - baseURL := current.BaseUrl - if req.BaseURL != nil { - baseURL = *req.BaseURL - } - enabled := current.Enabled - if req.Enabled != nil { - enabled = *req.Enabled - } - provider, err := api.Database.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ - ID: providerID, - Name: name, - DisplayName: displayName, - Enabled: enabled, - BaseUrl: baseURL, - Settings: current.Settings, - SettingsKeyID: current.SettingsKeyID, - }) - if err != nil { - switch { - case database.IsUniqueViolation(err): - httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{Message: "AI provider already exists."}) - case database.IsCheckViolation(err): - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider."}) - default: - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update AI provider."}) + name := current.Name + if req.Name != nil { + name = strings.TrimSpace(*req.Name) + if !validAIProviderName(name) { + return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider name."}) + } + } + displayName := current.DisplayName + if req.DisplayName != nil { + trimmed := strings.TrimSpace(*req.DisplayName) + displayName = sql.NullString{String: trimmed, Valid: trimmed != ""} + } + baseURL := current.BaseUrl + if req.BaseURL != nil { + baseURL, err = normalizeChatProviderBaseURL(*req.BaseURL) + if err != nil { + return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{ + Message: "Invalid provider base URL.", + Detail: err.Error(), + }) + } + } + enabled := current.Enabled + if req.Enabled != nil { + enabled = *req.Enabled + } + provider, err = tx.UpdateAIProvider(ctx, database.UpdateAIProviderParams{ + ID: providerID, + Name: name, + DisplayName: displayName, + Enabled: enabled, + BaseUrl: baseURL, + Settings: current.Settings, + SettingsKeyID: current.SettingsKeyID, + }) + if err != nil { + switch { + case database.IsUniqueViolation(err): + return httperror.NewResponseError(http.StatusConflict, codersdk.Response{Message: "AI provider already exists."}) + case database.IsCheckViolation(err): + return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider."}) + default: + return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{Message: "Failed to update AI provider."}) + } } + return nil + }, nil); err != nil { + httperror.WriteResponseError(ctx, rw, err) return } publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) @@ -6617,8 +6637,20 @@ func (api *API) deleteAIProvider(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."}) return } - if err := api.Database.DeleteAIProviderByID(ctx, providerID); err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete AI provider."}) + if err := api.Database.InTx(func(tx database.Store) error { + _, err := tx.GetAIProviderByIDForUpdate(ctx, providerID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return httperror.NewResponseError(http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) + } + return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) + } + if err := tx.DeleteAIProviderByID(ctx, providerID); err != nil { + return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete AI provider."}) + } + return nil + }, nil); err != nil { + httperror.WriteResponseError(ctx, rw, err) return } publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) @@ -6672,7 +6704,15 @@ func (api *API) createAIProviderKey(rw http.ResponseWriter, r *http.Request) { if !httpapi.Read(ctx, rw, r, &req) { return } - if strings.TrimSpace(req.APIKey) == "" { + trimmedAPIKey := strings.TrimSpace(req.APIKey) + if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "API key too large.", + Detail: err.Error(), + }) + return + } + if trimmedAPIKey == "" { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key is required."}) return } @@ -6680,7 +6720,7 @@ func (api *API) createAIProviderKey(rw http.ResponseWriter, r *http.Request) { key, err := api.Database.InsertAIProviderKey(ctx, database.InsertAIProviderKeyParams{ ID: uuid.New(), ProviderID: providerID, - APIKey: req.APIKey, + APIKey: trimmedAPIKey, ApiKeyKeyID: sql.NullString{}, CreatedAt: now, UpdatedAt: now, @@ -6786,11 +6826,23 @@ func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) return } + if !provider.Enabled { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider is disabled."}) + return + } var req codersdk.CreateUserAIProviderKeyRequest if !httpapi.Read(ctx, rw, r, &req) { return } - if strings.TrimSpace(req.APIKey) == "" { + trimmedAPIKey := strings.TrimSpace(req.APIKey) + if err := validateChatProviderAPIKeySize(trimmedAPIKey); err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "API key too large.", + Detail: err.Error(), + }) + return + } + if trimmedAPIKey == "" { httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key is required."}) return } @@ -6799,7 +6851,7 @@ func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) ID: uuid.New(), UserID: apiKey.UserID, AiProviderID: providerID, - APIKey: req.APIKey, + APIKey: trimmedAPIKey, ApiKeyKeyID: sql.NullString{}, CreatedAt: now, UpdatedAt: now, diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 3a826ffe708cc..cd88df4c3636e 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -2176,6 +2176,90 @@ func TestAIProviderCRUD(t *testing.T) { require.Equal(t, "Invalid AI provider type.", sdkErr.Message) }) + t.Run("CreateNormalizesInput", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + name := "test-normalized-" + uuid.NewString() + + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: " " + name + " ", + DisplayName: " Normalized Provider ", + BaseURL: " https://api.openai.example.com/v1 ", + }) + require.NoError(t, err) + require.Equal(t, name, provider.Name) + require.Equal(t, "Normalized Provider", provider.DisplayName) + require.Equal(t, "https://api.openai.example.com/v1", provider.BaseURL) + }) + + t.Run("CreateRejectsInvalidBaseURL", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + _, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-invalid-base-url-" + uuid.NewString(), + BaseURL: "file:///tmp/model", + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid provider base URL.", sdkErr.Message) + }) + + t.Run("UpdateRejectsInvalidBaseURL", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-update-base-url-" + uuid.NewString(), + }) + require.NoError(t, err) + + _, err = client.UpdateAIProvider(ctx, provider.ID, codersdk.UpdateAIProviderRequest{ + BaseURL: ptr.Ref("localhost:11434"), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid provider base URL.", sdkErr.Message) + }) + + t.Run("DeleteMissingProvider", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + err := client.DeleteAIProvider(ctx, uuid.New()) + sdkErr := requireSDKError(t, err, http.StatusNotFound) + require.Equal(t, "AI provider not found.", sdkErr.Message) + }) + + t.Run("CreateKeyRejectsLargeAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-large-key-" + uuid.NewString(), + }) + require.NoError(t, err) + + _, err = client.CreateAIProviderKey(ctx, provider.ID, codersdk.CreateAIProviderKeyRequest{APIKey: strings.Repeat("x", 10241)}) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key too large.", sdkErr.Message) + }) + t.Run("CreateKeyForMissingProvider", func(t *testing.T) { t.Parallel() @@ -2268,6 +2352,48 @@ func TestUserAIProviderKeys(t *testing.T) { t.Fatal("provider config not found") }) + t.Run("RejectsDisabledProvider", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + disabled := false + provider, err := adminClient.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-disabled-user-key-" + uuid.NewString(), + Enabled: &disabled, + }) + require.NoError(t, err) + + _, err = memberClient.UpsertUserAIProviderKey(ctx, provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"}) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "AI provider is disabled.", sdkErr.Message) + }) + + t.Run("RejectsLargeAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + provider, err := adminClient.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-large-user-key-" + uuid.NewString(), + }) + require.NoError(t, err) + + _, err = memberClient.UpsertUserAIProviderKey(ctx, provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: strings.Repeat("x", 10241)}) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key too large.", sdkErr.Message) + }) + t.Run("BYOKDisabledRejectsUpsertAndAllowsDelete", func(t *testing.T) { t.Parallel() From 74a5622c1d7447ec9cf164273f51a7fef69e4832 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 20:43:50 +0000 Subject: [PATCH 5/7] test(coderd/database/dbauthz): cover AI provider row lock authz --- coderd/database/dbauthz/dbauthz_test.go | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index f3cd456a32db7..8458541e522fa 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -6277,6 +6277,11 @@ func (s *MethodTestSuite) TestAIBridge() { dbm.EXPECT().GetAIProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes() check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(provider) })) + s.Run("GetAIProviderByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + provider := testutil.Fake(s.T(), faker, database.AIProvider{}) + dbm.EXPECT().GetAIProviderByIDForUpdate(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes() + check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(provider) + })) s.Run("GetAIProviderByName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { provider := testutil.Fake(s.T(), faker, database.AIProvider{}) dbm.EXPECT().GetAIProviderByName(gomock.Any(), provider.Name).Return(provider, nil).AnyTimes() From a22b9cf79efe80f04db5a8fb4e7cadb9c59da535 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 21:14:40 +0000 Subject: [PATCH 6/7] feat: add AI provider chat routes --- coderd/database/dbauthz/dbauthz.go | 2 +- coderd/database/dbauthz/dbauthz_test.go | 2 +- coderd/exp_chats.go | 69 ++++++++++---- coderd/exp_chats_test.go | 119 ++++++++++++++++++++++++ 4 files changed, 171 insertions(+), 21 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index e1bd85106c9c0..44e3a24f570a2 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2533,7 +2533,7 @@ func (q *querier) GetAIProviderByID(ctx context.Context, id uuid.UUID) (database } func (q *querier) GetAIProviderByIDForUpdate(ctx context.Context, id uuid.UUID) (database.AIProvider, error) { - if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil { + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceAIProvider); err != nil { return database.AIProvider{}, err } return q.db.GetAIProviderByIDForUpdate(ctx, id) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 8458541e522fa..3a537ac78a382 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -6280,7 +6280,7 @@ func (s *MethodTestSuite) TestAIBridge() { s.Run("GetAIProviderByIDForUpdate", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { provider := testutil.Fake(s.T(), faker, database.AIProvider{}) dbm.EXPECT().GetAIProviderByIDForUpdate(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes() - check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(provider) + check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionUpdate).Returns(provider) })) s.Run("GetAIProviderByName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { provider := testutil.Fake(s.T(), faker, database.AIProvider{}) diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 7452b87ca87b6..6057a4c14d7dc 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -6442,11 +6442,33 @@ func validAIProviderName(name string) bool { return !lastHyphen } -func writeInvalidAIProviderName(ctx context.Context, rw http.ResponseWriter) { - httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ +func invalidAIProviderNameResponse() codersdk.Response { + return codersdk.Response{ Message: "Invalid AI provider name.", Detail: "Names must use lowercase letters, numbers, and single hyphens, and cannot use the reserved agents- prefix.", - }) + } +} + +func writeInvalidAIProviderName(ctx context.Context, rw http.ResponseWriter) { + httpapi.Write(ctx, rw, http.StatusBadRequest, invalidAIProviderNameResponse()) +} + +func (api *API) writeAIProviderInternalServerError(ctx context.Context, rw http.ResponseWriter, logMessage string, response codersdk.Response, err error, fields ...slog.Field) { + fields = append(fields, slog.Error(err)) + api.Logger.Error(ctx, logMessage, fields...) + httpapi.Write(ctx, rw, http.StatusInternalServerError, response) +} + +func (api *API) writeAIProviderResponseError(ctx context.Context, rw http.ResponseWriter, err error, fields ...slog.Field) { + if responseErr, ok := httperror.IsResponder(err); ok { + status, resp := responseErr.Response() + if status >= http.StatusInternalServerError { + api.Logger.Error(ctx, resp.Message, append(fields, slog.Error(err))...) + } + httpapi.Write(ctx, rw, status, resp) + return + } + api.writeAIProviderInternalServerError(ctx, rw, "AI provider handler failed", codersdk.Response{Message: "Internal server error", Detail: err.Error()}, err, fields...) } func (api *API) listAIProviders(rw http.ResponseWriter, r *http.Request) { @@ -6457,7 +6479,7 @@ func (api *API) listAIProviders(rw http.ResponseWriter, r *http.Request) { } providers, err := api.Database.GetAIProviders(ctx, database.GetAIProvidersParams{IncludeDisabled: true}) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI providers."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to list AI providers", codersdk.Response{Message: "Failed to list AI providers."}, err) return } resp := make([]codersdk.AIProvider, 0, len(providers)) @@ -6517,7 +6539,7 @@ func (api *API) createAIProvider(rw http.ResponseWriter, r *http.Request) { case database.IsCheckViolation(err): httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider."}) default: - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to create AI provider."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to create AI provider", codersdk.Response{Message: "Failed to create AI provider."}, err) } return } @@ -6542,7 +6564,7 @@ func (api *API) readAIProvider(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) return } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to get AI provider", codersdk.Response{Message: "Failed to get AI provider."}, err, slog.F("ai_provider_id", providerID)) return } httpapi.Write(ctx, rw, http.StatusOK, convertAIProvider(provider)) @@ -6576,7 +6598,7 @@ func (api *API) updateAIProvider(rw http.ResponseWriter, r *http.Request) { if req.Name != nil { name = strings.TrimSpace(*req.Name) if !validAIProviderName(name) { - return httperror.NewResponseError(http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider name."}) + return httperror.NewResponseError(http.StatusBadRequest, invalidAIProviderNameResponse()) } } displayName := current.DisplayName @@ -6619,7 +6641,7 @@ func (api *API) updateAIProvider(rw http.ResponseWriter, r *http.Request) { } return nil }, nil); err != nil { - httperror.WriteResponseError(ctx, rw, err) + api.writeAIProviderResponseError(ctx, rw, err, slog.F("ai_provider_id", providerID)) return } publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) @@ -6645,12 +6667,21 @@ func (api *API) deleteAIProvider(rw http.ResponseWriter, r *http.Request) { } return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) } + keys, err := tx.GetAIProviderKeysByProviderID(ctx, providerID) + if err != nil { + return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI provider keys."}) + } + for _, key := range keys { + if err := tx.DeleteAIProviderKey(ctx, key.ID); err != nil { + return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete AI provider key."}) + } + } if err := tx.DeleteAIProviderByID(ctx, providerID); err != nil { return httperror.NewResponseError(http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete AI provider."}) } return nil }, nil); err != nil { - httperror.WriteResponseError(ctx, rw, err) + api.writeAIProviderResponseError(ctx, rw, err, slog.F("ai_provider_id", providerID)) return } publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) @@ -6670,7 +6701,7 @@ func (api *API) listAIProviderKeys(rw http.ResponseWriter, r *http.Request) { } keys, err := api.Database.GetAIProviderKeysByProviderID(ctx, providerID) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI provider keys."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to list AI provider keys", codersdk.Response{Message: "Failed to list AI provider keys."}, err, slog.F("ai_provider_id", providerID)) return } resp := make([]codersdk.AIProviderKey, 0, len(keys)) @@ -6697,7 +6728,7 @@ func (api *API) createAIProviderKey(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) return } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to get AI provider", codersdk.Response{Message: "Failed to get AI provider."}, err, slog.F("ai_provider_id", providerID)) return } var req codersdk.CreateAIProviderKeyRequest @@ -6726,7 +6757,7 @@ func (api *API) createAIProviderKey(rw http.ResponseWriter, r *http.Request) { UpdatedAt: now, }) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to create AI provider key."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to create AI provider key", codersdk.Response{Message: "Failed to create AI provider key."}, err, slog.F("ai_provider_id", providerID)) return } publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) @@ -6755,7 +6786,7 @@ func (api *API) deleteAIProviderKey(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider key not found."}) return } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider key."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to get AI provider key", codersdk.Response{Message: "Failed to get AI provider key."}, err, slog.F("ai_provider_id", providerID), slog.F("ai_provider_key_id", keyID)) return } if key.ProviderID != providerID { @@ -6763,7 +6794,7 @@ func (api *API) deleteAIProviderKey(rw http.ResponseWriter, r *http.Request) { return } if err := api.Database.DeleteAIProviderKey(ctx, keyID); err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete AI provider key."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to delete AI provider key", codersdk.Response{Message: "Failed to delete AI provider key."}, err, slog.F("ai_provider_id", providerID), slog.F("ai_provider_key_id", keyID)) return } publishChatConfigEvent(api.Logger, api.Pubsub, pubsub.ChatConfigEventProviders, uuid.Nil) @@ -6776,12 +6807,12 @@ func (api *API) listUserAIProviderKeyConfigs(rw http.ResponseWriter, r *http.Req //nolint:gocritic // Users can list limited enabled provider metadata without AI provider admin permissions. providers, err := api.Database.GetAIProviders(dbauthz.AsSystemRestricted(ctx), database.GetAIProvidersParams{IncludeDisabled: true}) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI providers."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to list user AI provider configs", codersdk.Response{Message: "Failed to list AI providers."}, err, slog.F("user_id", apiKey.UserID)) return } keys, err := api.Database.GetUserAIProviderKeysByUserID(ctx, apiKey.UserID) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list user AI provider keys."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to list user AI provider keys", codersdk.Response{Message: "Failed to list user AI provider keys."}, err, slog.F("user_id", apiKey.UserID)) return } keysByProviderID := make(map[uuid.UUID]struct{}, len(keys)) @@ -6823,7 +6854,7 @@ func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."}) return } - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to get AI provider", codersdk.Response{Message: "Failed to get AI provider."}, err, slog.F("ai_provider_id", providerID)) return } if !provider.Enabled { @@ -6857,7 +6888,7 @@ func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) UpdatedAt: now, }) if err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update user AI provider key."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to update user AI provider key", codersdk.Response{Message: "Failed to update user AI provider key."}, err, slog.F("user_id", apiKey.UserID), slog.F("ai_provider_id", providerID)) return } httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserAIProviderKeyConfig{ @@ -6876,7 +6907,7 @@ func (api *API) deleteUserAIProviderKey(rw http.ResponseWriter, r *http.Request) return } if err := api.Database.DeleteUserAIProviderKey(ctx, database.DeleteUserAIProviderKeyParams{UserID: apiKey.UserID, AiProviderID: providerID}); err != nil { - httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete user AI provider key."}) + api.writeAIProviderInternalServerError(ctx, rw, "failed to delete user AI provider key", codersdk.Response{Message: "Failed to delete user AI provider key."}, err, slog.F("user_id", apiKey.UserID), slog.F("ai_provider_id", providerID)) return } httpapi.Write(ctx, rw, http.StatusNoContent, nil) diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index cd88df4c3636e..3558f322f48b9 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -2119,6 +2119,7 @@ func TestAIProviderCRUD(t *testing.T) { require.NoError(t, err) require.Equal(t, updatedName, updated.Name) require.Equal(t, "Anthropic Test", updated.DisplayName) + require.Equal(t, "https://api.anthropic.example.com/v1", updated.BaseURL) require.False(t, updated.Enabled) key, err := client.CreateAIProviderKey(ctx, provider.ID, codersdk.CreateAIProviderKeyRequest{APIKey: "test-api-key"}) @@ -2212,6 +2213,50 @@ func TestAIProviderCRUD(t *testing.T) { require.Equal(t, "Invalid provider base URL.", sdkErr.Message) }) + t.Run("UpdateRejectsInvalidName", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-update-invalid-name-" + uuid.NewString(), + }) + require.NoError(t, err) + + _, err = client.UpdateAIProvider(ctx, provider.ID, codersdk.UpdateAIProviderRequest{ + Name: ptr.Ref("not a valid provider name"), + }) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "Invalid AI provider name.", sdkErr.Message) + require.NotEmpty(t, sdkErr.Detail) + }) + + t.Run("UpdateRejectsDuplicateName", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-update-duplicate-name-" + uuid.NewString(), + }) + require.NoError(t, err) + other, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-update-duplicate-other-" + uuid.NewString(), + }) + require.NoError(t, err) + + _, err = client.UpdateAIProvider(ctx, provider.ID, codersdk.UpdateAIProviderRequest{ + Name: ptr.Ref(other.Name), + }) + sdkErr := requireSDKError(t, err, http.StatusConflict) + require.Equal(t, "AI provider already exists.", sdkErr.Message) + }) + t.Run("UpdateRejectsInvalidBaseURL", func(t *testing.T) { t.Parallel() @@ -2260,6 +2305,23 @@ func TestAIProviderCRUD(t *testing.T) { require.Equal(t, "API key too large.", sdkErr.Message) }) + t.Run("CreateKeyRejectsEmptyAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-empty-key-" + uuid.NewString(), + }) + require.NoError(t, err) + + _, err = client.CreateAIProviderKey(ctx, provider.ID, codersdk.CreateAIProviderKeyRequest{APIKey: " "}) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key is required.", sdkErr.Message) + }) + t.Run("CreateKeyForMissingProvider", func(t *testing.T) { t.Parallel() @@ -2272,6 +2334,26 @@ func TestAIProviderCRUD(t *testing.T) { require.Equal(t, "AI provider not found.", sdkErr.Message) }) + t.Run("DeleteProviderDeletesProviderKeys", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-delete-keys-" + uuid.NewString(), + }) + require.NoError(t, err) + _, err = client.CreateAIProviderKey(ctx, provider.ID, codersdk.CreateAIProviderKeyRequest{APIKey: "test-api-key"}) + require.NoError(t, err) + + require.NoError(t, client.DeleteAIProvider(ctx, provider.ID)) + keys, err := client.ListAIProviderKeys(ctx, provider.ID) + require.NoError(t, err) + require.Empty(t, keys) + }) + t.Run("ForbiddenForOrganizationMember", func(t *testing.T) { t.Parallel() @@ -2340,6 +2422,23 @@ func TestUserAIProviderKeys(t *testing.T) { require.NotNil(t, cfg) require.True(t, cfg.HasUserAPIKey) + cfgValue, err = memberClient.UpsertUserAIProviderKey(ctx, provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "replacement-user-api-key"}) + require.NoError(t, err) + require.Equal(t, provider.ID, cfgValue.Provider.ID) + require.True(t, cfgValue.HasUserAPIKey) + + configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx) + require.NoError(t, err) + cfg = nil + for i := range configs { + if configs[i].Provider.ID == provider.ID { + cfg = &configs[i] + break + } + } + require.NotNil(t, cfg) + require.True(t, cfg.HasUserAPIKey) + require.NoError(t, memberClient.DeleteUserAIProviderKey(ctx, provider.ID)) configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx) require.NoError(t, err) @@ -2394,6 +2493,26 @@ func TestUserAIProviderKeys(t *testing.T) { require.Equal(t, "API key too large.", sdkErr.Message) }) + t.Run("RejectsEmptyAPIKey", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + adminClient := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, adminClient.Client) + memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID) + memberClient := codersdk.NewExperimentalClient(memberClientRaw) + + provider, err := adminClient.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "test-empty-user-key-" + uuid.NewString(), + }) + require.NoError(t, err) + + _, err = memberClient.UpsertUserAIProviderKey(ctx, provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: " "}) + sdkErr := requireSDKError(t, err, http.StatusBadRequest) + require.Equal(t, "API key is required.", sdkErr.Message) + }) + t.Run("BYOKDisabledRejectsUpsertAndAllowsDelete", func(t *testing.T) { t.Parallel() From b8b1cb6eecc8dd44d429b24df890e653f07a309b Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Sat, 16 May 2026 21:46:50 +0000 Subject: [PATCH 7/7] fix(coderd): include AI providers in enabled model queries --- coderd/database/queries.sql.go | 18 ++++++++++++++---- coderd/database/queries/chatmodelconfigs.sql | 18 ++++++++++++++---- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2157340746ad8..769f6aace4e34 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -5124,13 +5124,18 @@ SELECT cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id FROM chat_model_configs cmc -JOIN +LEFT JOIN + ai_providers ap ON ap.id = cmc.ai_provider_id AND ap.deleted = FALSE +LEFT JOIN chat_providers cp ON cp.provider = cmc.provider WHERE cmc.id = $1::uuid AND cmc.deleted = FALSE AND cmc.enabled = TRUE - AND cp.enabled = TRUE + AND ( + (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE) + OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) + ) ` // Providers can be disabled independently of their model configs. @@ -5164,12 +5169,17 @@ SELECT cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id FROM chat_model_configs cmc -JOIN +LEFT JOIN + ai_providers ap ON ap.id = cmc.ai_provider_id AND ap.deleted = FALSE +LEFT JOIN chat_providers cp ON cp.provider = cmc.provider WHERE cmc.enabled = TRUE AND cmc.deleted = FALSE - AND cp.enabled = TRUE + AND ( + (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE) + OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) + ) ORDER BY cmc.provider ASC, cmc.model ASC, diff --git a/coderd/database/queries/chatmodelconfigs.sql b/coderd/database/queries/chatmodelconfigs.sql index d129760c3dcaf..5222f64199064 100644 --- a/coderd/database/queries/chatmodelconfigs.sql +++ b/coderd/database/queries/chatmodelconfigs.sql @@ -34,12 +34,17 @@ SELECT cmc.* FROM chat_model_configs cmc -JOIN +LEFT JOIN + ai_providers ap ON ap.id = cmc.ai_provider_id AND ap.deleted = FALSE +LEFT JOIN chat_providers cp ON cp.provider = cmc.provider WHERE cmc.enabled = TRUE AND cmc.deleted = FALSE - AND cp.enabled = TRUE + AND ( + (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE) + OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) + ) ORDER BY cmc.provider ASC, cmc.model ASC, @@ -53,13 +58,18 @@ FROM chat_model_configs cmc -- Providers can be disabled independently of their model configs. -- Check both to ensure the selected config is actually usable. -JOIN +LEFT JOIN + ai_providers ap ON ap.id = cmc.ai_provider_id AND ap.deleted = FALSE +LEFT JOIN chat_providers cp ON cp.provider = cmc.provider WHERE cmc.id = @id::uuid AND cmc.deleted = FALSE AND cmc.enabled = TRUE - AND cp.enabled = TRUE; + AND ( + (cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE) + OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE) + ); -- name: InsertChatModelConfig :one INSERT INTO chat_model_configs (