Skip to content

Commit 2821e52

Browse files
committed
fix(coderd): address review findings for multi-config providers
1 parent 2667354 commit 2821e52

12 files changed

Lines changed: 121 additions & 69 deletions

File tree

coderd/database/dbauthz/dbauthz.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2640,13 +2640,6 @@ func (q *querier) GetChatProviderByID(ctx context.Context, id uuid.UUID) (databa
26402640
return q.db.GetChatProviderByID(ctx, id)
26412641
}
26422642

2643-
func (q *querier) GetChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
2644-
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
2645-
return database.ChatProvider{}, err
2646-
}
2647-
return q.db.GetChatProviderByProvider(ctx, provider)
2648-
}
2649-
26502643
func (q *querier) GetChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
26512644
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
26522645
return nil, err
@@ -2820,6 +2813,13 @@ func (q *querier) GetEnabledChatModelConfigs(ctx context.Context) ([]database.Ch
28202813
return q.db.GetEnabledChatModelConfigs(ctx)
28212814
}
28222815

2816+
func (q *querier) GetEnabledChatProviderByProvider(ctx context.Context, provider string) (database.ChatProvider, error) {
2817+
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
2818+
return database.ChatProvider{}, err
2819+
}
2820+
return q.db.GetEnabledChatProviderByProvider(ctx, provider)
2821+
}
2822+
28232823
func (q *querier) GetEnabledChatProviders(ctx context.Context) ([]database.ChatProvider, error) {
28242824
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
28252825
return nil, err

coderd/database/dbauthz/dbauthz_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -617,10 +617,10 @@ func (s *MethodTestSuite) TestChats() {
617617
dbm.EXPECT().GetChatProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
618618
check.Args(provider.ID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
619619
}))
620-
s.Run("GetChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
620+
s.Run("GetEnabledChatProviderByProvider", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
621621
providerName := "test-provider"
622622
provider := testutil.Fake(s.T(), faker, database.ChatProvider{Provider: providerName})
623-
dbm.EXPECT().GetChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
623+
dbm.EXPECT().GetEnabledChatProviderByProvider(gomock.Any(), providerName).Return(provider, nil).AnyTimes()
624624
check.Args(providerName).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(provider)
625625
}))
626626
s.Run("GetChatProviders", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {

coderd/database/dbmetrics/querymetrics.go

Lines changed: 8 additions & 8 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/dbmock/dbmock.go

Lines changed: 15 additions & 15 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/querier.go

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries.sql.go

Lines changed: 28 additions & 28 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/database/queries/chatproviders.sql

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ FROM
66
WHERE
77
id = @id::uuid;
88

9-
-- name: GetChatProviderByProvider :one
9+
-- name: GetEnabledChatProviderByProvider :one
1010
SELECT
1111
*
1212
FROM

coderd/exp_chats.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3653,15 +3653,15 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
36533653
}
36543654

36553655
providersByName := make(map[string]database.ChatProvider, len(providers))
3656-
configuredFamilies := make(map[string]bool, len(providers))
3656+
configuredFamilies := make(map[string]struct{}, len(providers))
36573657
for i := range providers {
36583658
normalizedProvider := normalizeChatProvider(providers[i].Provider)
36593659
if normalizedProvider == "" {
36603660
continue
36613661
}
36623662
providers[i].Provider = normalizedProvider
36633663
providersByName[normalizedProvider] = providers[i]
3664-
configuredFamilies[normalizedProvider] = true
3664+
configuredFamilies[normalizedProvider] = struct{}{}
36653665
}
36663666
configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providersByName))
36673667
for _, provider := range providersByName {
@@ -3727,7 +3727,7 @@ func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
37273727
)
37283728
}
37293729
for _, provider := range supportedProviders {
3730-
if configuredFamilies[provider] {
3730+
if _, ok := configuredFamilies[provider]; ok {
37313731
continue
37323732
}
37333733

coderd/exp_chats_test.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1117,6 +1117,58 @@ func TestUpdateChatProvider(t *testing.T) {
11171117
require.Equal(t, "Only one enabled provider config per provider family is allowed.", sdkErr.Message)
11181118
})
11191119

1120+
t.Run("FlipEnabled", func(t *testing.T) {
1121+
t.Parallel()
1122+
1123+
ctx := testutil.Context(t, testutil.WaitLong)
1124+
client := newChatClient(t)
1125+
_ = coderdtest.CreateFirstUser(t, client.Client)
1126+
1127+
first, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
1128+
Provider: "openai",
1129+
DisplayName: "OpenAI Primary",
1130+
APIKey: "key-1",
1131+
})
1132+
require.NoError(t, err)
1133+
1134+
disabled := false
1135+
second, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
1136+
Provider: "openai",
1137+
DisplayName: "OpenAI Secondary",
1138+
APIKey: "key-2",
1139+
Enabled: &disabled,
1140+
})
1141+
require.NoError(t, err)
1142+
1143+
firstDisabled, err := client.UpdateChatProvider(ctx, first.ID, codersdk.UpdateChatProviderConfigRequest{
1144+
DisplayName: first.DisplayName,
1145+
Enabled: &disabled,
1146+
})
1147+
require.NoError(t, err)
1148+
require.False(t, firstDisabled.Enabled)
1149+
1150+
enabled := true
1151+
secondEnabled, err := client.UpdateChatProvider(ctx, second.ID, codersdk.UpdateChatProviderConfigRequest{
1152+
DisplayName: second.DisplayName,
1153+
Enabled: &enabled,
1154+
})
1155+
require.NoError(t, err)
1156+
require.True(t, secondEnabled.Enabled)
1157+
1158+
providers, err := client.ListChatProviders(ctx)
1159+
require.NoError(t, err)
1160+
1161+
providerByID := make(map[uuid.UUID]codersdk.ChatProviderConfig, len(providers))
1162+
for _, provider := range providers {
1163+
providerByID[provider.ID] = provider
1164+
}
1165+
1166+
require.Contains(t, providerByID, first.ID)
1167+
require.Contains(t, providerByID, second.ID)
1168+
require.False(t, providerByID[first.ID].Enabled)
1169+
require.True(t, providerByID[second.ID].Enabled)
1170+
})
1171+
11201172
t.Run("NotFound", func(t *testing.T) {
11211173
t.Parallel()
11221174

coderd/x/chatd/chatd_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2548,7 +2548,7 @@ func setOpenAIProviderBaseURL(
25482548
) {
25492549
t.Helper()
25502550

2551-
provider, err := db.GetChatProviderByProvider(ctx, "openai")
2551+
provider, err := db.GetEnabledChatProviderByProvider(ctx, "openai")
25522552
require.NoError(t, err)
25532553

25542554
_, err = db.UpdateChatProvider(ctx, database.UpdateChatProviderParams{

0 commit comments

Comments
 (0)