Skip to content

Commit c4adf90

Browse files
committed
feat: support multiple provider configs per provider family
Adds multi-provider-config support to the chat system, allowing administrators to define multiple separate credentials and endpoints for the same provider family (e.g., separate OpenAI instances for different environments). Key changes: - Migration 000460: drops unique constraint on chat_providers.provider, adds nullable UUID FK (provider_config_id) on chat_model_configs with ON DELETE SET NULL, and backfills existing models. - Oldest-enabled-config precedence for deterministic fallback. - Auto-bind on model creation to oldest enabled config. - Application-level soft-delete of bound models on provider deletion to preserve customer tracking and chat history. - Per-config effective API key reporting in provider listings. - Frontend: config dropdown in ModelForm, family-level availability aggregation, multi-config provider display. - Comprehensive tests for migration backfills, precedence, binding, soft-delete, and non-admin visibility filtering.
1 parent 1c4a9ed commit c4adf90

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+4783
-2258
lines changed

cli/agent.go

Lines changed: 6 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"strings"
1818
"time"
1919

20-
"github.com/google/uuid"
2120
"github.com/prometheus/client_golang/prometheus"
2221
"golang.org/x/xerrors"
2322
"gopkg.in/natefinch/lumberjack.v2"
@@ -273,14 +272,11 @@ func workspaceAgent() *serpent.Command {
273272
logger.Info(ctx, "agent devcontainer detection not enabled")
274273
}
275274

276-
reinitCtx, reinitCancel := context.WithCancel(ctx)
277-
defer reinitCancel()
278-
reinitEvents := agentsdk.WaitForReinitLoop(reinitCtx, logger, client)
275+
reinitEvents := agentsdk.WaitForReinitLoop(ctx, logger, client)
279276

280277
var (
281-
lastOwnerID uuid.UUID
282-
lastErr error
283-
mustExit bool
278+
lastErr error
279+
mustExit bool
284280
)
285281
for {
286282
prometheusRegistry := prometheus.NewRegistry()
@@ -347,32 +343,9 @@ func workspaceAgent() *serpent.Command {
347343
case <-ctx.Done():
348344
logger.Info(ctx, "agent shutting down", slog.Error(context.Cause(ctx)))
349345
mustExit = true
350-
case event, ok := <-reinitEvents:
351-
switch {
352-
case !ok:
353-
// Channel closed — the reinit loop exited
354-
// (terminal 409 or context expired). Keep
355-
// running the current agent until the parent
356-
// context is canceled.
357-
logger.Info(ctx, "reinit channel closed, running without reinit capability")
358-
reinitEvents = nil
359-
<-ctx.Done()
360-
mustExit = true
361-
case event.OwnerID != uuid.Nil && event.OwnerID == lastOwnerID:
362-
// Duplicate reinit for same owner — already
363-
// reinitialized. Cancel the reinit loop
364-
// goroutine and keep the current agent.
365-
logger.Info(ctx, "skipping redundant reinit, owner unchanged",
366-
slog.F("owner_id", event.OwnerID))
367-
reinitCancel()
368-
reinitEvents = nil
369-
<-ctx.Done()
370-
mustExit = true
371-
default:
372-
lastOwnerID = event.OwnerID
373-
logger.Info(ctx, "agent received instruction to reinitialize",
374-
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
375-
}
346+
case event := <-reinitEvents:
347+
logger.Info(ctx, "agent received instruction to reinitialize",
348+
slog.F("workspace_id", event.WorkspaceID), slog.F("reason", event.Reason))
376349
}
377350

378351
lastErr = agnt.Close()

coderd/apidoc/docs.go

Lines changed: 2 additions & 21 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/apidoc/swagger.json

Lines changed: 2 additions & 21 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/chatproviders.go

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package database
2+
3+
import (
4+
"cmp"
5+
"slices"
6+
)
7+
8+
// ChatProvidersByFamilyPrecedence selects the effective provider config for
9+
// each family. The oldest config wins so runtime behavior matches the single-
10+
// row lookup used elsewhere in the API.
11+
func ChatProvidersByFamilyPrecedence(providers []ChatProvider) []ChatProvider {
12+
providers = slices.Clone(providers)
13+
slices.SortStableFunc(providers, func(a, b ChatProvider) int {
14+
if byProvider := cmp.Compare(a.Provider, b.Provider); byProvider != 0 {
15+
return byProvider
16+
}
17+
if byCreatedAt := a.CreatedAt.Compare(b.CreatedAt); byCreatedAt != 0 {
18+
return byCreatedAt
19+
}
20+
return cmp.Compare(a.ID.String(), b.ID.String())
21+
})
22+
selected := providers[:0]
23+
for _, provider := range providers {
24+
if len(selected) > 0 && selected[len(selected)-1].Provider == provider.Provider {
25+
continue
26+
}
27+
selected = append(selected, provider)
28+
}
29+
return selected
30+
}
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
//nolint:testpackage // Internal test for unexported ChatProvidersByFamilyPrecedence helper.
2+
package database
3+
4+
import (
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestChatProvidersByFamilyPrecedence(t *testing.T) {
13+
t.Parallel()
14+
15+
provider := func(id, family string, createdAt time.Time) ChatProvider {
16+
return ChatProvider{
17+
ID: uuid.MustParse(id),
18+
Provider: family,
19+
CreatedAt: createdAt,
20+
}
21+
}
22+
23+
timeAt := func(day int) time.Time {
24+
return time.Date(2024, time.January, day, 12, 0, 0, 0, time.UTC)
25+
}
26+
27+
cases := []struct {
28+
name string
29+
input []ChatProvider
30+
want []ChatProvider
31+
}{
32+
{
33+
name: "empty input",
34+
input: nil,
35+
want: nil,
36+
},
37+
{
38+
name: "single row",
39+
input: []ChatProvider{
40+
provider("00000000-0000-0000-0000-000000000001", "openai", timeAt(2)),
41+
},
42+
want: []ChatProvider{
43+
provider("00000000-0000-0000-0000-000000000001", "openai", timeAt(2)),
44+
},
45+
},
46+
{
47+
name: "all same family keeps oldest",
48+
input: []ChatProvider{
49+
provider("00000000-0000-0000-0000-000000000003", "openai", timeAt(3)),
50+
provider("00000000-0000-0000-0000-000000000001", "openai", timeAt(1)),
51+
provider("00000000-0000-0000-0000-000000000002", "openai", timeAt(2)),
52+
},
53+
want: []ChatProvider{
54+
provider("00000000-0000-0000-0000-000000000001", "openai", timeAt(1)),
55+
},
56+
},
57+
{
58+
name: "multiple families keep oldest per family",
59+
input: []ChatProvider{
60+
provider("00000000-0000-0000-0000-000000000006", "openai", timeAt(4)),
61+
provider("00000000-0000-0000-0000-000000000004", "anthropic", timeAt(2)),
62+
provider("00000000-0000-0000-0000-000000000008", "zeta", timeAt(3)),
63+
provider("00000000-0000-0000-0000-000000000005", "openai", timeAt(1)),
64+
provider("00000000-0000-0000-0000-000000000007", "anthropic", timeAt(5)),
65+
},
66+
want: []ChatProvider{
67+
provider("00000000-0000-0000-0000-000000000004", "anthropic", timeAt(2)),
68+
provider("00000000-0000-0000-0000-000000000005", "openai", timeAt(1)),
69+
provider("00000000-0000-0000-0000-000000000008", "zeta", timeAt(3)),
70+
},
71+
},
72+
{
73+
name: "same provider and timestamp breaks ties by id",
74+
input: []ChatProvider{
75+
provider("00000000-0000-0000-0000-00000000000b", "openai", timeAt(1)),
76+
provider("00000000-0000-0000-0000-00000000000a", "openai", timeAt(1)),
77+
},
78+
want: []ChatProvider{
79+
provider("00000000-0000-0000-0000-00000000000a", "openai", timeAt(1)),
80+
},
81+
},
82+
}
83+
84+
for _, tc := range cases {
85+
t.Run(tc.name, func(t *testing.T) {
86+
t.Parallel()
87+
88+
got := ChatProvidersByFamilyPrecedence(tc.input)
89+
90+
require.Equal(t, tc.want, got)
91+
})
92+
}
93+
}

coderd/database/dbauthz/dbauthz.go

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1731,6 +1731,13 @@ func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLog
17311731
return q.db.CountAuthorizedAuditLogs(ctx, arg, prep)
17321732
}
17331733

1734+
func (q *querier) CountChatProvidersByProviderExcludingID(ctx context.Context, arg database.CountChatProvidersByProviderExcludingIDParams) (int32, error) {
1735+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
1736+
return 0, err
1737+
}
1738+
return q.db.CountChatProvidersByProviderExcludingID(ctx, arg)
1739+
}
1740+
17341741
func (q *querier) CountConnectionLogs(ctx context.Context, arg database.CountConnectionLogsParams) (int64, error) {
17351742
// Just like the actual query, shortcut if the user is an owner.
17361743
err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceConnectionLog)
@@ -2671,13 +2678,6 @@ func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (databa
26712678
return q.db.GetChatProviderByID(ctx, id)
26722679
}
26732680

2674-
func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
2675-
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
2676-
return database.ChatProvider{}, err
2677-
}
2678-
return q.db.GetChatProviderByProvider(ctx, provider)
2679-
}
2680-
26812681
func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
26822682
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
26832683
return nil, err
@@ -2875,6 +2875,13 @@ func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.Ch
28752875
return q.db.GetEnabledChatModelConfigs(ctx)
28762876
}
28772877

2878+
func (q *querier) GetEnabledChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
2879+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
2880+
return database.ChatProvider{}, err
2881+
}
2882+
return q.db.GetEnabledChatProviderByProvider(ctx, provider)
2883+
}
2884+
28782885
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
28792886
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
28802887
return nil, err
@@ -5641,6 +5648,13 @@ func (q *querier) SelectUsageEventsForPublishing(ctx context.Context, arg time.T
56415648
return q.db.SelectUsageEventsForPublishing(ctx, arg)
56425649
}
56435650

5651+
func (q *querier) SoftDeleteBoundChatModelConfigsByProviderConfigID(ctx context.Context, providerConfigID uuid.UUID) (int64, error) {
5652+
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
5653+
return 0, err
5654+
}
5655+
return q.db.SoftDeleteBoundChatModelConfigsByProviderConfigID(ctx, providerConfigID)
5656+
}
5657+
56445658
func (q *querier) SoftDeleteChatMessageByID(ctx context.Context, id int64) error {
56455659
msg, err := q.db.GetChatMessageByID(ctx, id)
56465660
if err != nil {
@@ -5667,6 +5681,13 @@ func (q *querier) SoftDeleteChatMessagesAfterID(ctx context.Context, arg databas
56675681
return q.db.SoftDeleteChatMessagesAfterID(ctx, arg)
56685682
}
56695683

5684+
func (q *querier) SoftDeleteUnboundChatModelConfigsByProvider(ctx context.Context, provider string) (int64, error) {
5685+
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
5686+
return 0, err
5687+
}
5688+
return q.db.SoftDeleteUnboundChatModelConfigsByProvider(ctx, provider)
5689+
}
5690+
56705691
func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) {
56715692
return q.db.TryAcquireLock(ctx, id)
56725693
}

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -644,12 +644,6 @@ func (s *MethodTestSuite) TestChats() {
644644
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
645645
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
646646
}))
647-
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
648-
providerName := "test-provider"
649-
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
650-
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
651-
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
652-
}))
653647
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
654648
providerA := testutil.Fake(s.T(), faker, database.ChatProvider{})
655649
providerB := testutil.Fake(s.T(), faker, database.ChatProvider{})

0 commit comments

Comments
 (0)