Skip to content

Commit fe2d5e5

Browse files
committed
feat(enterprise/aibridged): hot-reload provider pool and keys from DB on pubsub
Switches the in-memory aibridged daemon from a static, env-derived provider list to a database-backed list that hot-reloads via pubsub. After this PR: - aibridged loads providers from ai_providers at startup (system actor, dbauthz-gated) and joins them with ai_provider_keys to pick the operator-preferred primary key (first by created_at). - Non-Bedrock providers with zero ai_provider_keys are skipped with a warning; Bedrock providers always have zero keys and authenticate via the encrypted settings blob (AWS access key + secret). - The CRUD handlers from the previous PR publish on 'ai_providers_changed' after every successful Insert/Update/ SoftDelete of a provider AND after every Insert/Delete of a key, because key changes alone affect the runtime pool. - Each replica subscribes to that channel and triggers aibridged.Server.Reload, which atomically swaps the providers slice on the pool and clears the cached RequestBridge instances. - In-flight requests continue against their existing RequestBridge until completion; the cache's OnEvict shutdown closes MCP connections in the background after a 5-second grace period. The proxy daemon is intentionally NOT reloaded yet to keep this PR focused; it still receives the boot-time provider snapshot. A follow-up will introduce a Pooler interface for the proxy and mirror this pattern. Pool changes: - CachedBridgePool stores providers via atomic.Pointer[[]Provider] instead of a fixed slice. - New Reload(providers) method on the Pooler interface that atomically swaps the snapshot, calls cache.Clear, and waits for buffered writes to drain so a subsequent Acquire always sees the new set. Tests: - TestPoolReload covers the happy path: build a pool, acquire a bridge, Reload, ensure the next Acquire targets the new provider set. - TestPoolReloadAfterShutdown ensures Reload is a no-op post-Close so a stale subscriber notification cannot resurrect a torn-down pool. - TestAIProvidersPubsubPublish exercises the producer side: each of Insert/Update/Delete on a provider emits a notification on AIBridgeProvidersChangedChannel. - TestAIProviderKeysPubsubPublish does the same for the keys sub-resource (Insert and Delete).
1 parent eb0f556 commit fe2d5e5

9 files changed

Lines changed: 604 additions & 14 deletions

File tree

enterprise/aibridged/aibridged.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"golang.org/x/xerrors"
1313

1414
"cdr.dev/slog/v3"
15+
"github.com/coder/coder/v2/aibridge"
1516
"github.com/coder/coder/v2/codersdk"
1617
"github.com/coder/retry"
1718
)
@@ -154,6 +155,22 @@ func (s *Server) GetRequestHandler(ctx context.Context, req Request) (http.Handl
154155
return reqBridge, nil
155156
}
156157

158+
// Reload swaps the providers used to construct future RequestBridge
159+
// instances and invalidates the existing cache. It is the entry
160+
// point that the CLI subscribes to ai_providers_changed pubsub
161+
// events on.
162+
//
163+
// Reload is safe to call concurrently with serving requests; existing
164+
// in-flight requests continue against their previously-cached
165+
// bridge until completion, while subsequent requests get a freshly-
166+
// built bridge using the new provider set.
167+
func (s *Server) Reload(providers []aibridge.Provider) {
168+
if s.requestBridgePool == nil {
169+
return
170+
}
171+
s.requestBridgePool.Reload(providers)
172+
}
173+
157174
// isShutdown returns whether the Server is shutdown or not.
158175
func (s *Server) isShutdown() bool {
159176
select {

enterprise/aibridged/aibridgedmock/poolmock.go

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

enterprise/aibridged/pool.go

Lines changed: 55 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"net/http"
66
"sync"
7+
"sync/atomic"
78
"time"
89

910
"github.com/dgraph-io/ristretto/v2"
@@ -26,6 +27,7 @@ const (
2627
// One [*aibridge.RequestBridge] instance is created per given key.
2728
type Pooler interface {
2829
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
30+
Reload(providers []aibridge.Provider)
2931
Shutdown(ctx context.Context) error
3032
}
3133

@@ -46,8 +48,12 @@ var DefaultPoolOptions = PoolOptions{MaxItems: 5000, TTL: time.Minute * 15}
4648
var _ Pooler = &CachedBridgePool{}
4749

4850
type CachedBridgePool struct {
49-
cache *ristretto.Cache[string, *aibridge.RequestBridge]
50-
providers []aibridge.Provider
51+
cache *ristretto.Cache[string, *aibridge.RequestBridge]
52+
// providers holds an atomic slice of aibridge.Provider. Use
53+
// loadProviders() to read and Reload() to swap. The atomic
54+
// indirection lets us hot-swap the live provider set in
55+
// response to configuration changes without locking the cache.
56+
providers atomic.Pointer[[]aibridge.Provider]
5157
logger slog.Logger
5258
options PoolOptions
5359

@@ -85,18 +91,20 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
8591
return nil, xerrors.Errorf("create cache: %w", err)
8692
}
8793

88-
return &CachedBridgePool{
89-
cache: cache,
90-
providers: providers,
91-
options: options,
92-
metrics: metrics,
93-
tracer: tracer,
94-
logger: logger,
94+
pool := &CachedBridgePool{
95+
cache: cache,
96+
options: options,
97+
metrics: metrics,
98+
tracer: tracer,
99+
logger: logger,
95100

96101
singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{},
97102

98103
shuttingDownCh: make(chan struct{}),
99-
}, nil
104+
}
105+
copied := append([]aibridge.Provider(nil), providers...)
106+
pool.providers.Store(&copied)
107+
return pool, nil
100108
}
101109

102110
// Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key.
@@ -171,7 +179,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
171179
}
172180
}
173181

174-
bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, p.logger, p.metrics, p.tracer)
182+
bridge, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer)
175183
if err != nil {
176184
return nil, xerrors.Errorf("create new request bridge: %w", err)
177185
}
@@ -184,6 +192,42 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
184192
return instance, err
185193
}
186194

195+
// Reload swaps the providers used to construct future RequestBridge
196+
// instances and clears the cache. Existing in-flight requests
197+
// continue against their previously-cached bridge until completion;
198+
// the next Acquire returns a freshly-built bridge using the new
199+
// providers slice.
200+
//
201+
// Reload is safe to call concurrently with Acquire.
202+
func (p *CachedBridgePool) Reload(providers []aibridge.Provider) {
203+
select {
204+
case <-p.shuttingDownCh:
205+
return
206+
default:
207+
}
208+
copied := append([]aibridge.Provider(nil), providers...)
209+
p.providers.Store(&copied)
210+
// Clear evicts every cached bridge; OnEvict will gracefully
211+
// shut each one down in the background. Wait for buffered
212+
// writes to drain so a Reload immediately followed by an
213+
// Acquire always sees the cleared cache.
214+
p.cache.Clear()
215+
p.cache.Wait()
216+
p.logger.Info(context.Background(), "request bridge pool reloaded",
217+
slog.F("provider_count", len(copied)),
218+
)
219+
}
220+
221+
// loadProviders returns the current providers slice. The returned
222+
// slice must not be mutated.
223+
func (p *CachedBridgePool) loadProviders() []aibridge.Provider {
224+
ptr := p.providers.Load()
225+
if ptr == nil {
226+
return nil
227+
}
228+
return *ptr
229+
}
230+
187231
func (p *CachedBridgePool) CacheMetrics() PoolMetrics {
188232
if p.cache == nil {
189233
return nil
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
package aibridged_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
"time"
7+
8+
"github.com/google/uuid"
9+
"github.com/stretchr/testify/require"
10+
"go.uber.org/mock/gomock"
11+
12+
"cdr.dev/slog/v3/sloggers/slogtest"
13+
"github.com/coder/coder/v2/aibridge"
14+
"github.com/coder/coder/v2/aibridge/mcpmock"
15+
"github.com/coder/coder/v2/enterprise/aibridged"
16+
mock "github.com/coder/coder/v2/enterprise/aibridged/aibridgedmock"
17+
)
18+
19+
// TestPoolReload exercises CachedBridgePool.Reload, ensuring that
20+
// the cache is cleared after a hot-swap so the next Acquire builds a
21+
// fresh RequestBridge against the new providers.
22+
func TestPoolReload(t *testing.T) {
23+
t.Parallel()
24+
25+
logger := slogtest.Make(t, nil)
26+
ctrl := gomock.NewController(t)
27+
client := mock.NewMockDRPCClient(ctrl)
28+
mcpProxy := mcpmock.NewMockServerProxier(ctrl)
29+
clientFn := func() (aibridged.DRPCClient, error) {
30+
return client, nil
31+
}
32+
33+
mcpProxy.EXPECT().Init(gomock.Any()).AnyTimes().Return(nil)
34+
mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil)
35+
36+
opts := aibridged.PoolOptions{MaxItems: 8, TTL: time.Minute}
37+
pool, err := aibridged.NewCachedBridgePool(opts, []aibridge.Provider{
38+
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{
39+
Name: "openai",
40+
BaseURL: "https://api.openai.com/v1",
41+
Key: "sk-old",
42+
}),
43+
}, logger, nil, testTracer)
44+
require.NoError(t, err)
45+
t.Cleanup(func() { _ = pool.Shutdown(context.Background()) })
46+
47+
id, apiKeyID := uuid.New(), uuid.New()
48+
req := aibridged.Request{InitiatorID: id, APIKeyID: apiKeyID.String()}
49+
50+
// Prime the cache.
51+
_, err = pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy))
52+
require.NoError(t, err)
53+
54+
cm := pool.CacheMetrics()
55+
require.EqualValues(t, 1, cm.Misses())
56+
require.EqualValues(t, 1, cm.KeysAdded())
57+
58+
// Reload with a new provider set. ristretto.Cache.Clear()
59+
// resets metrics to zero alongside emptying the cache.
60+
pool.Reload([]aibridge.Provider{
61+
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{
62+
Name: "openai",
63+
BaseURL: "https://api.openai.com/v1",
64+
Key: "sk-new",
65+
}),
66+
})
67+
68+
cm = pool.CacheMetrics()
69+
require.EqualValues(t, 0, cm.KeysAdded(), "expected metrics to be reset after Reload")
70+
71+
// After Reload, the next Acquire should be a miss because the
72+
// cache was cleared, and a new RequestBridge gets built.
73+
_, err = pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy))
74+
require.NoError(t, err)
75+
76+
cm = pool.CacheMetrics()
77+
require.EqualValues(t, 1, cm.Misses(), "expected one cache miss after reload")
78+
require.EqualValues(t, 1, cm.KeysAdded(), "expected new key added after reload")
79+
80+
// Wait briefly for ristretto's eviction goroutines (spawned by
81+
// the OnEvict callback during Reload) to settle so gomock's
82+
// teardown does not race with their Shutdown calls. The Shutdown
83+
// expectation is set with AnyTimes() so the assertion does not
84+
// require an exact count, but ctrl.Finish does need to see the
85+
// call complete.
86+
time.Sleep(100 * time.Millisecond)
87+
}
88+
89+
// TestPoolReloadAfterShutdown verifies Reload is a no-op when the
90+
// pool has already been shut down.
91+
func TestPoolReloadAfterShutdown(t *testing.T) {
92+
t.Parallel()
93+
94+
logger := slogtest.Make(t, nil)
95+
pool, err := aibridged.NewCachedBridgePool(
96+
aibridged.DefaultPoolOptions, nil, logger, nil, testTracer,
97+
)
98+
require.NoError(t, err)
99+
require.NoError(t, pool.Shutdown(context.Background()))
100+
101+
// Should not panic or hang.
102+
pool.Reload([]aibridge.Provider{
103+
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{
104+
Name: "openai",
105+
BaseURL: "https://api.openai.com/v1",
106+
Key: "sk",
107+
}),
108+
})
109+
}

0 commit comments

Comments
 (0)