Skip to content

Commit f122ca9

Browse files
committed
feat(enterprise/aibridged): hot-reload provider pool and keys from DB on pubsub
Switches the in-memory aibridged daemon from a static, env-derived provider list to a database-backed list that hot-reloads via pubsub.

After this PR:
- aibridged loads providers from ai_providers at startup (system actor, dbauthz-gated) and joins them with ai_provider_keys to pick the operator-preferred primary key (first by created_at).
- Non-Bedrock providers with zero ai_provider_keys are skipped with a warning; Bedrock providers always have zero keys and authenticate via the encrypted settings blob (AWS access key + secret).
- The CRUD handlers from the previous PR publish on 'ai_providers_changed' after every successful Insert/Update/SoftDelete of a provider AND after every Insert/Delete of a key, because key changes alone affect the runtime pool.
- Each replica subscribes to that channel and triggers aibridged.Server.Reload, which atomically swaps the providers slice on the pool and clears the cached RequestBridge instances.
- In-flight requests continue against their existing RequestBridge until completion; the cache's OnEvict shutdown closes MCP connections in the background after a 5-second grace period.

The proxy daemon is intentionally NOT reloaded yet to keep this PR focused; it still receives the boot-time provider snapshot. A follow-up will introduce a Pooler interface for the proxy and mirror this pattern.

Pool changes:
- CachedBridgePool stores providers via atomic.Pointer[[]Provider] instead of a fixed slice.
- New Reload(providers) method on the Pooler interface that atomically swaps the snapshot, calls cache.Clear, and waits for buffered writes to drain so a subsequent Acquire always sees the new set.

Tests:
- TestPoolReload covers the happy path: build a pool, acquire a bridge, Reload, ensure the next Acquire targets the new provider set.
- TestPoolReloadAfterShutdown ensures Reload is a no-op post-Shutdown so a stale subscriber notification cannot resurrect a torn-down pool.
- TestAIProvidersPubsubPublish exercises the producer side: each of Insert/Update/Delete on a provider emits a notification on AIProvidersChangedChannel.
- TestAIProviderKeysPubsubPublish does the same for the keys sub-resource (Insert and Delete).
1 parent 9cb8faa commit f122ca9

9 files changed

Lines changed: 493 additions & 13 deletions

File tree

coderd/ai_providers.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ func (api *API) aiProvidersCreate(rw http.ResponseWriter, r *http.Request) {
235235
aReq.New = row
236236

237237
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, aiProviderKeyChanges{Added: keys})
238+
publishAIProvidersChanged(ctx, api.Pubsub, api.Logger)
238239

239240
sdk, err := db2sdk.AIProvider(row, keys)
240241
if err != nil {
@@ -400,6 +401,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) {
400401
}
401402

402403
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, keyChanges)
404+
publishAIProvidersChanged(ctx, api.Pubsub, api.Logger)
403405

404406
sdk, err := db2sdk.AIProvider(updated, keys)
405407
if err != nil {
@@ -453,6 +455,8 @@ func (api *API) aiProvidersDelete(rw http.ResponseWriter, r *http.Request) {
453455
return
454456
}
455457

458+
publishAIProvidersChanged(ctx, api.Pubsub, api.Logger)
459+
456460
rw.WriteHeader(http.StatusNoContent)
457461
}
458462

coderd/ai_providers_pubsub.go

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package coderd
2+
3+
import (
4+
"context"
5+
6+
"cdr.dev/slog/v3"
7+
"github.com/coder/coder/v2/coderd/database/pubsub"
8+
)
9+
10+
// AIProvidersChangedChannel is the pubsub channel published whenever an
11+
// ai_providers or ai_provider_keys row is inserted, updated, or
12+
// soft-deleted via the API. Subscribers (currently aibridged in every
13+
// replica, but the channel is provider-generic) refresh their cached
14+
// state by re-querying the database.
15+
//
16+
// Messages have no payload; receivers re-read the rows themselves.
17+
// This keeps the channel agnostic to dbcrypt-key changes and avoids
18+
// bus traffic carrying secrets.
19+
const AIProvidersChangedChannel = "ai_providers_changed"
20+
21+
// publishAIProvidersChanged publishes a notification on the providers-
22+
// changed channel. Errors are logged but never returned to callers; a
23+
// missed notification only delays consumers catching up to the new
24+
// state, and the next mutation will retry.
25+
func publishAIProvidersChanged(ctx context.Context, ps pubsub.Pubsub, logger slog.Logger) {
26+
if ps == nil {
27+
return
28+
}
29+
if err := ps.Publish(AIProvidersChangedChannel, nil); err != nil {
30+
logger.Warn(ctx, "failed to publish ai_providers_changed",
31+
slog.Error(err))
32+
}
33+
}

coderd/ai_providers_pubsub_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
package coderd_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/stretchr/testify/require"
8+
9+
"github.com/coder/coder/v2/coderd"
10+
"github.com/coder/coder/v2/coderd/coderdtest"
11+
"github.com/coder/coder/v2/coderd/database/dbtestutil"
12+
"github.com/coder/coder/v2/codersdk"
13+
"github.com/coder/coder/v2/testutil"
14+
)
15+
16+
// TestAIProvidersPubsubPublish verifies that mutating an AI provider
17+
// publishes on the AIProvidersChangedChannel so each replica's
18+
// RequestBridge pool can invalidate.
19+
func TestAIProvidersPubsubPublish(t *testing.T) {
20+
t.Parallel()
21+
22+
db, ps := dbtestutil.NewDB(t)
23+
client := coderdtest.New(t, &coderdtest.Options{
24+
Database: db,
25+
Pubsub: ps,
26+
})
27+
_ = coderdtest.CreateFirstUser(t, client)
28+
ctx := testutil.Context(t, testutil.WaitLong)
29+
30+
notified := make(chan struct{}, 4)
31+
cancel, err := ps.Subscribe(coderd.AIProvidersChangedChannel, func(_ context.Context, _ []byte) {
32+
select {
33+
case notified <- struct{}{}:
34+
default:
35+
}
36+
})
37+
require.NoError(t, err)
38+
t.Cleanup(cancel)
39+
40+
// Create publishes.
41+
//nolint:gocritic // Owner role is the audience for this endpoint.
42+
created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
43+
Type: codersdk.AIProviderTypeOpenAI,
44+
Name: "pubsub-test",
45+
Enabled: true,
46+
BaseURL: "https://api.openai.com/v1",
47+
})
48+
require.NoError(t, err)
49+
select {
50+
case <-notified:
51+
case <-ctx.Done():
52+
t.Fatalf("timed out waiting for pubsub notify after create")
53+
}
54+
55+
// Update publishes.
56+
display := "Renamed"
57+
//nolint:gocritic // Owner role is the audience for this endpoint.
58+
_, err = client.UpdateAIProvider(ctx, created.Name, codersdk.UpdateAIProviderRequest{
59+
DisplayName: &display,
60+
})
61+
require.NoError(t, err)
62+
select {
63+
case <-notified:
64+
case <-ctx.Done():
65+
t.Fatalf("timed out waiting for pubsub notify after update")
66+
}
67+
68+
// Delete publishes.
69+
//nolint:gocritic // Owner role is the audience for this endpoint.
70+
require.NoError(t, client.DeleteAIProvider(ctx, created.Name))
71+
select {
72+
case <-notified:
73+
case <-ctx.Done():
74+
t.Fatalf("timed out waiting for pubsub notify after delete")
75+
}
76+
}

enterprise/aibridged/aibridged.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"golang.org/x/xerrors"
1313

1414
"cdr.dev/slog/v3"
15+
"github.com/coder/coder/v2/aibridge"
1516
"github.com/coder/coder/v2/codersdk"
1617
"github.com/coder/retry"
1718
)
@@ -154,6 +155,22 @@ func (s *Server) GetRequestHandler(ctx context.Context, req Request) (http.Handl
154155
return reqBridge, nil
155156
}
156157

158+
// Reload swaps the providers used to construct future RequestBridge
159+
// instances and invalidates the existing cache. It is the entry
160+
// point the CLI calls when it receives an ai_providers_changed
161+
// pubsub event.
162+
//
163+
// Reload is safe to call concurrently with serving requests; existing
164+
// in-flight requests continue against their previously-cached
165+
// bridge until completion, while subsequent requests get a freshly-
166+
// built bridge using the new provider set.
167+
func (s *Server) Reload(providers []aibridge.Provider) {
168+
if s.requestBridgePool == nil {
169+
return
170+
}
171+
s.requestBridgePool.Reload(providers)
172+
}
173+
157174
// isShutdown returns whether the Server is shutdown or not.
158175
func (s *Server) isShutdown() bool {
159176
select {

enterprise/aibridged/aibridgedmock/poolmock.go

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enterprise/aibridged/pool.go

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"net/http"
66
"sync"
7+
"sync/atomic"
78
"time"
89

910
"github.com/dgraph-io/ristretto/v2"
@@ -26,6 +27,7 @@ const (
2627
// One [*aibridge.RequestBridge] instance is created per given key.
2728
type Pooler interface {
2829
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
30+
Reload(providers []aibridge.Provider)
2931
Shutdown(ctx context.Context) error
3032
}
3133

@@ -46,8 +48,12 @@ var DefaultPoolOptions = PoolOptions{MaxItems: 5000, TTL: time.Minute * 15}
4648
var _ Pooler = &CachedBridgePool{}
4749

4850
type CachedBridgePool struct {
49-
cache *ristretto.Cache[string, *aibridge.RequestBridge]
50-
providers []aibridge.Provider
51+
cache *ristretto.Cache[string, *aibridge.RequestBridge]
52+
// providers holds an atomic slice of aibridge.Provider. Use
53+
// loadProviders() to read and Reload() to swap. The atomic
54+
// indirection lets us hot-swap the live provider set in
55+
// response to configuration changes without locking the cache.
56+
providers atomic.Pointer[[]aibridge.Provider]
5157
logger slog.Logger
5258
options PoolOptions
5359

@@ -85,18 +91,20 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
8591
return nil, xerrors.Errorf("create cache: %w", err)
8692
}
8793

88-
return &CachedBridgePool{
89-
cache: cache,
90-
providers: providers,
91-
options: options,
92-
metrics: metrics,
93-
tracer: tracer,
94-
logger: logger,
94+
pool := &CachedBridgePool{
95+
cache: cache,
96+
options: options,
97+
metrics: metrics,
98+
tracer: tracer,
99+
logger: logger,
95100

96101
singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{},
97102

98103
shuttingDownCh: make(chan struct{}),
99-
}, nil
104+
}
105+
copied := append([]aibridge.Provider(nil), providers...)
106+
pool.providers.Store(&copied)
107+
return pool, nil
100108
}
101109

102110
// Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key.
@@ -171,7 +179,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
171179
}
172180
}
173181

174-
bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, p.logger, p.metrics, p.tracer)
182+
bridge, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer)
175183
if err != nil {
176184
return nil, xerrors.Errorf("create new request bridge: %w", err)
177185
}
@@ -184,6 +192,42 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
184192
return instance, err
185193
}
186194

195+
// Reload swaps the providers used to construct future RequestBridge
196+
// instances and clears the cache. Existing in-flight requests
197+
// continue against their previously-cached bridge until completion;
198+
// the next Acquire returns a freshly-built bridge using the new
199+
// providers slice.
200+
//
201+
// Reload is safe to call concurrently with Acquire.
202+
func (p *CachedBridgePool) Reload(providers []aibridge.Provider) {
203+
select {
204+
case <-p.shuttingDownCh:
205+
return
206+
default:
207+
}
208+
copied := append([]aibridge.Provider(nil), providers...)
209+
p.providers.Store(&copied)
210+
// Clear evicts every cached bridge; OnEvict will gracefully
211+
// shut each one down in the background. Wait for buffered
212+
// writes to drain so a Reload immediately followed by an
213+
// Acquire always sees the cleared cache.
214+
p.cache.Clear()
215+
p.cache.Wait()
216+
p.logger.Info(context.Background(), "request bridge pool reloaded",
217+
slog.F("provider_count", len(copied)),
218+
)
219+
}
220+
221+
// loadProviders returns the current providers slice. The returned
222+
// slice must not be mutated.
223+
func (p *CachedBridgePool) loadProviders() []aibridge.Provider {
224+
ptr := p.providers.Load()
225+
if ptr == nil {
226+
return nil
227+
}
228+
return *ptr
229+
}
230+
187231
func (p *CachedBridgePool) CacheMetrics() PoolMetrics {
188232
if p.cache == nil {
189233
return nil

0 commit comments

Comments
 (0)