Skip to content

Commit 79e007c

Browse files
authored
feat: hot-reload aibridged and aibridgeproxyd providers on DB changes (#25673)
Previously the in-process aibridge daemon and the enterprise aibridgeproxy daemon both snapshotted their provider routing once at boot. Any `ai_providers` or `ai_provider_keys` mutation required a restart for either to pick it up. Add an `ai_providers_changed` pubsub channel that the CRUD handlers publish on after Create / Update / Delete. Both daemons subscribe: - **aibridged** rebuilds its `[]aibridge.Provider` snapshot via `BuildProviders` and swaps it into the pool atomically. Inflight requests keep serving against the bridge they already acquired; new acquires build against the new snapshot. Per-provider construction errors stay scoped to the offending row. - **aibridgeproxyd** rebuilds its routing snapshot from `GetAIProviders` and swaps the host→provider map atomically. The MITM listener picks up new providers without restart. DB read for aibridgeproxyd uses the existing `AsAIProviderMetadataReader` subject for routing-only access.
1 parent 6acfe6c commit 79e007c

19 files changed

Lines changed: 1677 additions & 226 deletions

cli/aibridged.go

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@ import (
2424
"github.com/coder/quartz"
2525
)
2626

27-
func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*aibridged.Server, error) {
27+
// newAIBridgeDaemon constructs the in-memory aibridge daemon and wires
28+
// up a subscription that hot-reloads the provider pool from the
29+
// database on every ai_providers change event. The returned unsubscribe
30+
// function tears down the subscription; callers must invoke it
31+
// alongside Server.Close on shutdown.
32+
func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider, cfg codersdk.AIBridgeConfig) (*aibridged.Server, func(), error) {
2833
ctx := context.Background()
2934
coderAPI.Logger.Debug(ctx, "starting in-memory aibridge daemon")
3035

@@ -37,17 +42,58 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*ai
3742
// Create pool for reusable stateful [aibridge.RequestBridge] instances (one per user).
3843
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger.Named("pool"), metrics, tracer) // TODO: configurable size.
3944
if err != nil {
40-
return nil, xerrors.Errorf("create request pool: %w", err)
45+
return nil, nil, xerrors.Errorf("create request pool: %w", err)
46+
}
47+
48+
// Subscribe to ai_providers change events so the pool tracks the
49+
// database without a restart. The boot-time `providers` snapshot
50+
// derives from env config and serves as a fallback if the database
51+
// load fails inside the reloader.
52+
reloader := &poolDBReloader{
53+
pool: pool,
54+
db: coderAPI.Database,
55+
cfg: cfg,
56+
logger: logger.Named("provider-loader"),
57+
}
58+
unsubscribe, err := aibridged.SubscribeProviderReload(ctx, coderAPI.Pubsub, reloader, logger.Named("provider-reload"))
59+
if err != nil {
60+
// Pool is still usable with the boot-time snapshot; subscription
61+
// failure is logged but not fatal so the daemon still serves.
62+
logger.Warn(ctx, "subscribe to ai providers change channel", slog.Error(err))
63+
unsubscribe = func() {}
4164
}
4265

4366
// Create daemon.
4467
srv, err := aibridged.New(ctx, pool, func(dialCtx context.Context) (aibridged.DRPCClient, error) {
4568
return coderAPI.CreateInMemoryAIBridgeServer(dialCtx)
4669
}, logger, tracer)
4770
if err != nil {
48-
return nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err)
71+
unsubscribe()
72+
return nil, nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err)
73+
}
74+
return srv, unsubscribe, nil
75+
}
76+
77+
// poolDBReloader implements [aibridged.ProviderReloader] by loading
78+
// the live provider set from the database and forwarding it to the
79+
// pool.
80+
type poolDBReloader struct {
81+
pool *aibridged.CachedBridgePool
82+
db database.Store
83+
cfg codersdk.AIBridgeConfig
84+
logger slog.Logger
85+
}
86+
87+
func (r *poolDBReloader) Reload(ctx context.Context) error {
88+
providers, err := BuildProviders(ctx, r.db, r.cfg, r.logger)
89+
if err != nil {
90+
// Keep the previous snapshot in place: dropping all providers
91+
// because the DB read failed would compound the visible failure
92+
// mode beyond the operator's actual misconfiguration.
93+
return xerrors.Errorf("load ai providers from database: %w", err)
4994
}
50-
return srv, nil
95+
r.pool.ReplaceProviders(providers)
96+
return nil
5197
}
5298

5399
// BuildProviders loads every enabled ai_providers row, attaches its

cli/server.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1046,7 +1046,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
10461046
if err != nil {
10471047
return xerrors.Errorf("build AI providers: %w", err)
10481048
}
1049-
aibridgeDaemon, err = newAIBridgeDaemon(coderAPI, aibridgeProviders)
1049+
var unsubscribeProviderReload func()
1050+
aibridgeDaemon, unsubscribeProviderReload, err = newAIBridgeDaemon(coderAPI, aibridgeProviders, vals.AI.BridgeConfig)
10501051
if err != nil {
10511052
return xerrors.Errorf("create aibridged: %w", err)
10521053
}
@@ -1055,6 +1056,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
10551056
// daemon does not affect in-flight requests but is needed to
10561057
// release pool/recorder resources at shutdown.
10571058
defer aibridgeDaemon.Close()
1059+
defer unsubscribeProviderReload()
10581060
}
10591061

10601062
if vals.Prometheus.Enable {

coderd/ai_providers.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"github.com/coder/coder/v2/coderd/database/dbtime"
2222
"github.com/coder/coder/v2/coderd/httpapi"
2323
"github.com/coder/coder/v2/coderd/httpmw"
24+
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
2425
"github.com/coder/coder/v2/coderd/util/ptr"
2526
"github.com/coder/coder/v2/codersdk"
2627
)
@@ -235,6 +236,7 @@ func (api *API) aiProvidersCreate(rw http.ResponseWriter, r *http.Request) {
235236
aReq.New = row
236237

237238
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, aiProviderKeyChanges{Added: keys})
239+
api.publishAIProvidersChanged(ctx)
238240

239241
sdk, err := db2sdk.AIProvider(row, keys)
240242
if err != nil {
@@ -400,6 +402,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) {
400402
}
401403

402404
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, keyChanges)
405+
api.publishAIProvidersChanged(ctx)
403406

404407
sdk, err := db2sdk.AIProvider(updated, keys)
405408
if err != nil {
@@ -453,9 +456,25 @@ func (api *API) aiProvidersDelete(rw http.ResponseWriter, r *http.Request) {
453456
return
454457
}
455458

459+
api.publishAIProvidersChanged(ctx)
460+
456461
rw.WriteHeader(http.StatusNoContent)
457462
}
458463

464+
// publishAIProvidersChanged notifies subscribers (aibridged,
465+
// aibridgeproxyd) that the live provider set changed and they should
466+
// refetch from the database. Pubsub failures are logged but not
467+
// propagated: subscribers refresh authoritatively from the DB, so a
468+
// dropped notification only delays convergence.
469+
func (api *API) publishAIProvidersChanged(ctx context.Context) {
470+
if api.Pubsub == nil {
471+
return
472+
}
473+
if err := api.Pubsub.Publish(coderpubsub.AIProvidersChangedChannel, nil); err != nil {
474+
api.Logger.Warn(ctx, "publish ai providers changed event", slog.Error(err))
475+
}
476+
}
477+
459478
// errBedrockRejectsAPIKeys is the sentinel returned from inside the
460479
// update transaction when a caller attempts to attach api_keys to a
461480
// Bedrock-typed provider; the outer handler translates it into a 400.

coderd/ai_providers_pubsub_test.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package coderd_test
2+
3+
import (
4+
"context"
5+
"sync/atomic"
6+
"testing"
7+
8+
"github.com/stretchr/testify/require"
9+
10+
"github.com/coder/coder/v2/coderd/coderdtest"
11+
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
12+
"github.com/coder/coder/v2/codersdk"
13+
"github.com/coder/coder/v2/testutil"
14+
)
15+
16+
// TestAIProvidersChangedPubsub asserts that the CRUD handlers publish
17+
// on AIProvidersChangedChannel for the operations that affect the
18+
// runtime provider set. Subscribers (aibridged, aibridgeproxyd) depend
19+
// on these notifications to trigger their pool reload.
20+
//
21+
// The handlers publish best-effort and the payload is empty, so we
22+
// assert "at least one event per mutation" via a counter.
23+
func TestAIProvidersChangedPubsub(t *testing.T) {
24+
t.Parallel()
25+
26+
client, _, api := coderdtest.NewWithAPI(t, nil)
27+
_ = coderdtest.CreateFirstUser(t, client)
28+
ctx := testutil.Context(t, testutil.WaitLong)
29+
30+
var count atomic.Int64
31+
unsubscribe, err := api.Pubsub.Subscribe(coderpubsub.AIProvidersChangedChannel, func(_ context.Context, _ []byte) {
32+
count.Add(1)
33+
})
34+
require.NoError(t, err)
35+
t.Cleanup(unsubscribe)
36+
37+
// Create.
38+
req := codersdk.CreateAIProviderRequest{
39+
Type: codersdk.AIProviderTypeOpenAI,
40+
Name: "pubsub-openai",
41+
Enabled: true,
42+
BaseURL: "https://api.openai.com/v1/",
43+
APIKeys: []string{"k1"},
44+
}
45+
//nolint:gocritic // Owner role is the audience for this endpoint.
46+
created, err := client.CreateAIProvider(ctx, req)
47+
require.NoError(t, err)
48+
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 1 }, testutil.IntervalFast)
49+
50+
// Update.
51+
newKey := "k2"
52+
_, err = client.UpdateAIProvider(ctx, created.ID.String(), codersdk.UpdateAIProviderRequest{
53+
APIKeys: &[]codersdk.AIProviderKeyMutation{{APIKey: &newKey}},
54+
})
55+
require.NoError(t, err)
56+
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 2 }, testutil.IntervalFast)
57+
58+
// Delete.
59+
err = client.DeleteAIProvider(ctx, created.ID.String())
60+
require.NoError(t, err)
61+
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 3 }, testutil.IntervalFast)
62+
}

coderd/aibridged/aibridgedmock/poolmock.go

Lines changed: 13 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

coderd/aibridged/pool.go

Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@ package aibridged
33
import (
44
"context"
55
"net/http"
6+
"slices"
7+
"strconv"
68
"sync"
9+
"sync/atomic"
710
"time"
811

912
"github.com/dgraph-io/ristretto/v2"
@@ -26,6 +29,9 @@ const (
2629
// One [*aibridge.RequestBridge] instance is created per given key.
2730
type Pooler interface {
2831
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
32+
// ReplaceProviders swaps the providers used to construct future
33+
// RequestBridge instances and clears the cache.
34+
ReplaceProviders(providers []aibridge.Provider)
2935
Shutdown(ctx context.Context) error
3036
}
3137

@@ -46,10 +52,12 @@ var DefaultPoolOptions = PoolOptions{MaxItems: 5000, TTL: time.Minute * 15}
4652
var _ Pooler = &CachedBridgePool{}
4753

4854
type CachedBridgePool struct {
49-
cache *ristretto.Cache[string, *aibridge.RequestBridge]
50-
providers []aibridge.Provider
51-
logger slog.Logger
52-
options PoolOptions
55+
cache *ristretto.Cache[string, *aibridge.RequestBridge]
56+
// providers is the live provider set used by new RequestBridge instances.
57+
providers atomic.Pointer[[]aibridge.Provider]
58+
providerVersion atomic.Int64
59+
logger slog.Logger
60+
options PoolOptions
5361

5462
singleflight *singleflight.Group[string, *aibridge.RequestBridge]
5563

@@ -71,32 +79,70 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
7179
if item == nil || item.Value == nil {
7280
return
7381
}
74-
75-
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*5)
76-
defer shutdownCancel()
77-
78-
// Run the eviction in the background since ristretto blocks sets until a free slot is available.
82+
// Capture the value synchronously: ristretto reuses the
83+
// item slot after OnEvict returns, so reading item.Value
84+
// from the goroutine below races with the caller of
85+
// Clear/Set. The shutdown still runs in the background to
86+
// avoid blocking ristretto's eviction loop.
87+
bridge := item.Value
7988
go func() {
80-
_ = item.Value.Shutdown(shutdownCtx)
89+
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
90+
defer cancel()
91+
_ = bridge.Shutdown(shutdownCtx)
8192
}()
8293
},
8394
})
8495
if err != nil {
8596
return nil, xerrors.Errorf("create cache: %w", err)
8697
}
8798

88-
return &CachedBridgePool{
89-
cache: cache,
90-
providers: providers,
91-
options: options,
92-
metrics: metrics,
93-
tracer: tracer,
94-
logger: logger,
99+
pool := &CachedBridgePool{
100+
cache: cache,
101+
options: options,
102+
metrics: metrics,
103+
tracer: tracer,
104+
logger: logger,
95105

96106
singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{},
97107

98108
shuttingDownCh: make(chan struct{}),
99-
}, nil
109+
}
110+
initial := slices.Clone(providers)
111+
pool.providers.Store(&initial)
112+
return pool, nil
113+
}
114+
115+
// ReplaceProviders swaps the provider snapshot used by future Acquires.
116+
// It is safe to call concurrently with Acquire and is a no-op after
117+
// Shutdown.
118+
func (p *CachedBridgePool) ReplaceProviders(providers []aibridge.Provider) {
119+
select {
120+
case <-p.shuttingDownCh:
121+
return
122+
default:
123+
}
124+
snapshot := slices.Clone(providers)
125+
p.providers.Store(&snapshot)
126+
version := time.Now().UnixNano()
127+
p.providerVersion.Store(version)
128+
// Clear evicts every cached bridge; OnEvict shuts each one down in
129+
// the background. Wait for buffered writes to drain so a replacement
130+
// immediately followed by an Acquire always sees the cleared cache.
131+
p.cache.Clear()
132+
p.cache.Wait()
133+
p.logger.Info(context.Background(), "request bridge pool reloaded",
134+
slog.F("provider_count", len(snapshot)),
135+
slog.F("provider_version", version),
136+
)
137+
}
138+
139+
// loadProviders returns the current providers snapshot. The returned
140+
// slice must not be mutated.
141+
func (p *CachedBridgePool) loadProviders() []aibridge.Provider {
142+
if ptr := p.providers.Load(); ptr != nil {
143+
return *ptr
144+
}
145+
return nil
100146
}
101147

102148
// Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key.
@@ -140,6 +186,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
140186
}
141187

142188
span.AddEvent("cache_miss")
189+
providerVersion := p.providerVersion.Load()
143190
recorder := aibridge.NewRecorder(p.logger.Named("recorder"), p.tracer, func() (aibridge.Recorder, error) {
144191
client, err := clientFn()
145192
if err != nil {
@@ -152,7 +199,8 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
152199
// Slow path.
153200
// Creating an *aibridge.RequestBridge may take some time, so gate all subsequent callers behind the initial request and return the resulting value.
154201
// TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs).
155-
instance, err, _ := p.singleflight.Do(req.InitiatorID.String(), func() (*aibridge.RequestBridge, error) {
202+
singleflightKey := cacheKey + "|" + strconv.FormatInt(providerVersion, 10)
203+
instance, err, _ := p.singleflight.Do(singleflightKey, func() (*aibridge.RequestBridge, error) {
156204
var (
157205
mcpServers mcp.ServerProxier
158206
err error
@@ -171,12 +219,14 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
171219
}
172220
}
173221

174-
bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, p.logger, p.metrics, p.tracer)
222+
bridge, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer)
175223
if err != nil {
176224
return nil, xerrors.Errorf("create new request bridge: %w", err)
177225
}
178226

179-
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
227+
if p.providerVersion.Load() == providerVersion {
228+
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
229+
}
180230

181231
return bridge, nil
182232
})

0 commit comments

Comments
 (0)