Skip to content

Commit 64e4c6a

Browse files
committed
feat: prepare schema for multiple provider configs per family
Drop the unique constraint on chat_providers.provider and add provider_config_id to chat_model_configs. This prepares the database for multiple provider configs per family while keeping one-per-family enforcement at the application layer. Schema changes: - Drop unique constraint on chat_providers.provider - Add provider_config_id column to chat_model_configs with backfill - Add indexes for provider_config_id and provider lookups Backend changes: - Support provider_config_id binding in model config CRUD - Add provider-enablement filtering to GetEnabledChatModelConfigs - Soft-delete model configs on provider deletion (with updated_by) - Add application-level uniqueness guard for provider creation - Add ensureDefaultChatModelConfig call to deleteChatProvider - Align has_api_key / has_effective_api_key semantics - Update dbauthz, dbmetrics, dbmock, dbcrypt wiring Runtime and frontend changes for full multi-provider support will follow in a separate PR.
1 parent da5395a commit 64e4c6a

25 files changed

Lines changed: 2857 additions & 423 deletions

coderd/database/chatproviders.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package database
2+
3+
import (
4+
"cmp"
5+
"slices"
6+
)
7+
8+
// ChatProvidersByFamilyPrecedence selects the effective provider config for
9+
// each family. The oldest config wins so runtime behavior matches the single-
10+
// row lookup used elsewhere in the API.
11+
func ChatProvidersByFamilyPrecedence(providers []ChatProvider) []ChatProvider {
12+
providers = slices.Clone(providers)
13+
slices.SortStableFunc(providers, func(a, b ChatProvider) int {
14+
if byProvider := cmp.Compare(a.Provider, b.Provider); byProvider != 0 {
15+
return byProvider
16+
}
17+
if byCreatedAt := a.CreatedAt.Compare(b.CreatedAt); byCreatedAt != 0 {
18+
return byCreatedAt
19+
}
20+
return cmp.Compare(a.ID.String(), b.ID.String())
21+
})
22+
selected := providers[:0]
23+
for _, provider := range providers {
24+
if len(selected) > 0 && selected[len(selected)-1].Provider == provider.Provider {
25+
continue
26+
}
27+
selected = append(selected, provider)
28+
}
29+
return selected
30+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//nolint:testpackage // Internal test for unexported ChatProvidersByFamilyPrecedence helper.
2+
package database
3+
4+
import (
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestChatProvidersByFamilyPrecedence(t *testing.T) {
13+
t.Parallel()
14+
15+
provider := func(id, family string, createdAt time.Time) ChatProvider {
16+
return ChatProvider{
17+
ID: uuid.MustParse(id),
18+
Provider: family,
19+
CreatedAt: createdAt,
20+
}
21+
}
22+
23+
timeAt := func(day int) time.Time {
24+
return time.Date(2024, time.January, day, 12, 0, 0, 0, time.UTC)
25+
}
26+
27+
cases := []struct {
28+
name string
29+
input []ChatProvider
30+
want []ChatProvider
31+
}{
32+
{
33+
name: "empty input",
34+
input: nil,
35+
want: nil,
36+
},
37+
{
38+
name: "single row",
39+
input: []ChatProvider{
40+
provider("00000000-0000-0000-0000-000000000001", "openai", timeAt(2)),
41+
},
42+
want: []ChatProvider{
43+
provider("00000000-0000-0000-0000-000000000001", "openai", timeAt(2)),
44+
},
45+
},
46+
{
47+
name: "all same family keeps oldest",
48+
input: []ChatProvider{
49+
provider("00000000-0000-0000-0000-000000000003", "openai", timeAt(3)),
50+
provider("00000000-0000-0000-0000-000000000001", "openai", timeAt(1)),
51+
provider("00000000-0000-0000-0000-000000000002", "openai", timeAt(2)),
52+
},
53+
want: []ChatProvider{
54+
provider("00000000-0000-0000-0000-000000000001", "openai", timeAt(1)),
55+
},
56+
},
57+
{
58+
name: "multiple families keep oldest per family",
59+
input: []ChatProvider{
60+
provider("00000000-0000-0000-0000-000000000006", "openai", timeAt(4)),
61+
provider("00000000-0000-0000-0000-000000000004", "anthropic", timeAt(2)),
62+
provider("00000000-0000-0000-0000-000000000008", "zeta", timeAt(3)),
63+
provider("00000000-0000-0000-0000-000000000005", "openai", timeAt(1)),
64+
provider("00000000-0000-0000-0000-000000000007", "anthropic", timeAt(5)),
65+
},
66+
want: []ChatProvider{
67+
provider("00000000-0000-0000-0000-000000000004", "anthropic", timeAt(2)),
68+
provider("00000000-0000-0000-0000-000000000005", "openai", timeAt(1)),
69+
provider("00000000-0000-0000-0000-000000000008", "zeta", timeAt(3)),
70+
},
71+
},
72+
{
73+
name: "same provider and timestamp breaks ties by id",
74+
input: []ChatProvider{
75+
provider("00000000-0000-0000-0000-00000000000b", "openai", timeAt(1)),
76+
provider("00000000-0000-0000-0000-00000000000a", "openai", timeAt(1)),
77+
},
78+
want: []ChatProvider{
79+
provider("00000000-0000-0000-0000-00000000000a", "openai", timeAt(1)),
80+
},
81+
},
82+
}
83+
84+
for _, tc := range cases {
85+
t.Run(tc.name, func(t *testing.T) {
86+
t.Parallel()
87+
88+
got := ChatProvidersByFamilyPrecedence(tc.input)
89+
90+
require.Equal(t, tc.want, got)
91+
})
92+
}
93+
}

coderd/database/dbauthz/dbauthz.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1738,6 +1738,13 @@ func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLog
17381738
return q.db.CountAuthorizedAuditLogs(ctx, arg, prep)
17391739
}
17401740

1741+
func (q *querier) CountChatProvidersByProviderExcludingID(ctx context.Context, arg database.CountChatProvidersByProviderExcludingIDParams) (int32, error) {
1742+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
1743+
return 0, err
1744+
}
1745+
return q.db.CountChatProvidersByProviderExcludingID(ctx, arg)
1746+
}
1747+
17411748
func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
17421749
// Just like the actual query, shortcut if the user is an owner.
17431750
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
@@ -2904,6 +2911,13 @@ func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.Ch
29042911
return q.db.GetEnabledChatModelConfigs(ctx)
29052912
}
29062913

2914+
func (q *querier) GetEnabledChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
2915+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
2916+
return database.ChatProvider{}, err
2917+
}
2918+
return q.db.GetEnabledChatProviderByProvider(ctx, provider)
2919+
}
2920+
29072921
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
29082922
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
29092923
return nil, err
@@ -5678,6 +5692,13 @@ func (q *querier) SelectUsageEventsForPublishing(ctx context.Context, arg time.T
56785692
return q.db.SelectUsageEventsForPublishing(ctx, arg)
56795693
}
56805694

5695+
func (q *querier) SoftDeleteBoundChatModelConfigsByProviderConfigID(ctx context.Context, providerConfigID database.SoftDeleteBoundChatModelConfigsByProviderConfigIDParams) (int64, error) {
5696+
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
5697+
return 0, err
5698+
}
5699+
return q.db.SoftDeleteBoundChatModelConfigsByProviderConfigID(ctx, providerConfigID)
5700+
}
5701+
56815702
func (q *querier) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
56825703
msg, err := q.db.GetChatMessageByID(ctx, id)
56835704
if err != nil {
@@ -5704,6 +5725,13 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
57045725
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
57055726
}
57065727

5728+
func (q *querier) SoftDeleteUnboundChatModelConfigsByProvider(ctx context.Context, provider database.SoftDeleteUnboundChatModelConfigsByProviderParams) (int64, error) {
5729+
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
5730+
return 0, err
5731+
}
5732+
return q.db.SoftDeleteUnboundChatModelConfigsByProvider(ctx, provider)
5733+
}
5734+
57075735
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
57085736
return q.db.TryAcquireLock(ctx, id)
57095737
}

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 46 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,14 @@ func (s *MethodTestSuite) TestConnectionLogs() {
365365
dbm.EXPECT().CountConnectionLogs(gomock.Any(), database.CountConnectionLogsParams{}).Return(int64(0), nil).AnyTimes()
366366
check.Args(database.CountConnectionLogsParams{}, emptyPreparedAuthorized{}).Asserts(rbac.ResourceConnectionLog, policy.ActionRead)
367367
}))
368+
s.Run("DeleteOldChatFiles", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
369+
dbm.EXPECT().DeleteOldChatFiles(gomock.Any(), database.DeleteOldChatFilesParams{}).Return(int64(0), nil).AnyTimes()
370+
check.Args(database.DeleteOldChatFilesParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
371+
}))
372+
s.Run("DeleteOldChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
373+
dbm.EXPECT().DeleteOldChats(gomock.Any(), database.DeleteOldChatsParams{}).Return(int64(0), nil).AnyTimes()
374+
check.Args(database.DeleteOldChatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
375+
}))
368376
s.Run("DeleteOldConnectionLogs", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
369377
dbm.EXPECT().DeleteOldConnectionLogs(gomock.Any(), database.DeleteOldConnectionLogsParams{}).Return(int64(0), nil).AnyTimes()
370378
check.Args(database.DeleteOldConnectionLogsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
@@ -423,6 +431,14 @@ func (s *MethodTestSuite) TestChats() {
423431
dbm.EXPECT().UnpinChatByID(gomock.Any(), chat.ID).Return(nil).AnyTimes()
424432
check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns()
425433
}))
434+
s.Run("SoftDeleteBoundChatModelConfigsByProviderConfigID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
435+
arg := database.SoftDeleteBoundChatModelConfigsByProviderConfigIDParams{
436+
UpdatedBy: uuid.NullUUID{UUID: uuid.New(), Valid: true},
437+
ProviderConfigID: uuid.New(),
438+
}
439+
dbm.EXPECT().SoftDeleteBoundChatModelConfigsByProviderConfigID(gomock.Any(), arg).Return(int64(2), nil).AnyTimes()
440+
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(int64(2))
441+
}))
426442
s.Run("SoftDeleteChatMessagesAfterID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
427443
chat := testutil.Fake(s.T(), faker, database.Chat{})
428444
arg := database.SoftDeleteChatMessagesAfterIDParams{
@@ -444,6 +460,14 @@ func (s *MethodTestSuite) TestChats() {
444460
dbm.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), msg.ID).Return(nil).AnyTimes()
445461
check.Args(msg.ID).Asserts(chat, policy.ActionUpdate).Returns()
446462
}))
463+
s.Run("SoftDeleteUnboundChatModelConfigsByProvider", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
464+
arg := database.SoftDeleteUnboundChatModelConfigsByProviderParams{
465+
UpdatedBy: uuid.NullUUID{UUID: uuid.New(), Valid: true},
466+
Provider: "openai",
467+
}
468+
dbm.EXPECT().SoftDeleteUnboundChatModelConfigsByProvider(gomock.Any(), arg).Return(int64(1), nil).AnyTimes()
469+
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns(int64(1))
470+
}))
447471
s.Run("DeleteChatModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
448472
id := uuid.New()
449473
dbm.EXPECT().DeleteChatModelConfigByID(gomock.Any(), id).Return(nil).AnyTimes()
@@ -553,6 +577,14 @@ func (s *MethodTestSuite) TestChats() {
553577
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
554578
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
555579
}))
580+
s.Run("CountChatProvidersByProviderExcludingID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
581+
arg := database.CountChatProvidersByProviderExcludingIDParams{
582+
Provider: "openai",
583+
ID: uuid.New(),
584+
}
585+
dbm.EXPECT().CountChatProvidersByProviderExcludingID(gomock.Any(), arg).Return(int32(3), nil).AnyTimes()
586+
check.Args(arg).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int32(3))
587+
}))
556588
s.Run("CountEnabledModelsWithoutPricing", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
557589
dbm.EXPECT().CountEnabledModelsWithoutPricing(gomock.Any()).Return(int64(3), nil).AnyTimes()
558590
check.Args().Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(int64(3))
@@ -600,22 +632,6 @@ func (s *MethodTestSuite) TestChats() {
600632
dbm.EXPECT().GetChatFileMetadataByChatID(gomock.Any(), file.ID).Return(rows, nil).AnyTimes()
601633
check.Args(file.ID).Asserts(rbac.ResourceChat.WithOwner(file.OwnerID.String()).InOrg(file.OrganizationID).WithID(file.ID), policy.ActionRead).Returns(rows)
602634
}))
603-
s.Run("DeleteOldChatFiles", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
604-
dbm.EXPECT().DeleteOldChatFiles(gomock.Any(), database.DeleteOldChatFilesParams{}).Return(int64(0), nil).AnyTimes()
605-
check.Args(database.DeleteOldChatFilesParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
606-
}))
607-
s.Run("DeleteOldChats", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
608-
dbm.EXPECT().DeleteOldChats(gomock.Any(), database.DeleteOldChatsParams{}).Return(int64(0), nil).AnyTimes()
609-
check.Args(database.DeleteOldChatsParams{}).Asserts(rbac.ResourceSystem, policy.ActionDelete)
610-
}))
611-
s.Run("GetChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
612-
dbm.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes()
613-
check.Args().Asserts()
614-
}))
615-
s.Run("UpsertChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
616-
dbm.EXPECT().UpsertChatRetentionDays(gomock.Any(), int32(30)).Return(nil).AnyTimes()
617-
check.Args(int32(30)).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
618-
}))
619635
s.Run("GetChatMessageByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
620636
chat := testutil.Fake(s.T(), faker, database.Chat{})
621637
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
@@ -683,12 +699,22 @@ func (s *MethodTestSuite) TestChats() {
683699
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
684700
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
685701
}))
702+
s.Run("GetEnabledChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
703+
providerName := "test-provider"
704+
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
705+
dbm.EXPECT().GetEnabledChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
706+
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
707+
}))
686708
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
687709
providerName := "test-provider"
688710
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
689711
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
690712
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
691713
}))
714+
s.Run("GetChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
715+
dbm.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes()
716+
check.Args().Asserts().Returns(int32(30))
717+
}))
692718
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
693719
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
694720
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})
@@ -1006,6 +1032,10 @@ func (s *MethodTestSuite) TestChats() {
10061032
dbm.EXPECT().UpsertChatIncludeDefaultSystemPrompt(gomock.Any(), false).Return(nil).AnyTimes()
10071033
check.Args(false).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
10081034
}))
1035+
s.Run("UpsertChatRetentionDays", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
1036+
dbm.EXPECT().UpsertChatRetentionDays(gomock.Any(), int32(30)).Return(nil).AnyTimes()
1037+
check.Args(int32(30)).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)
1038+
}))
10091039
s.Run("UpsertChatSystemPrompt", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
10101040
dbm.EXPECT().UpsertChatSystemPrompt(gomock.Any(), "").Return(nil).AnyTimes()
10111041
check.Args("").Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate)

coderd/database/dbmetrics/querymetrics.go

Lines changed: 32 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)