Skip to content
54 changes: 50 additions & 4 deletions cli/aibridged.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@ import (
"github.com/coder/quartz"
)

func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*aibridged.Server, error) {
// newAIBridgeDaemon constructs the in-memory aibridge daemon and wires
// up a subscription that hot-reloads the provider pool from the
// database on every ai_providers change event. The returned unsubscribe
// function tears down the subscription; callers must invoke it
// alongside Server.Close on shutdown.
func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider, cfg codersdk.AIBridgeConfig) (*aibridged.Server, func(), error) {
ctx := context.Background()
coderAPI.Logger.Debug(ctx, "starting in-memory aibridge daemon")

Expand All @@ -37,17 +42,58 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*ai
// Create pool for reusable stateful [aibridge.RequestBridge] instances (one per user).
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger.Named("pool"), metrics, tracer) // TODO: configurable size.
if err != nil {
return nil, xerrors.Errorf("create request pool: %w", err)
return nil, nil, xerrors.Errorf("create request pool: %w", err)
}

// Subscribe to ai_providers change events so the pool tracks the
// database without a restart. The boot-time `providers` snapshot
// derives from env config and serves as a fallback if the database
// load fails inside the reloader.
reloader := &poolDBReloader{
pool: pool,
db: coderAPI.Database,
cfg: cfg,
logger: logger.Named("provider-loader"),
}
unsubscribe, err := aibridged.SubscribeProviderReload(ctx, coderAPI.Pubsub, reloader, logger.Named("provider-reload"))
if err != nil {
// Pool is still usable with the boot-time snapshot; subscription
// failure is logged but not fatal so the daemon still serves.
logger.Warn(ctx, "subscribe to ai providers change channel", slog.Error(err))
unsubscribe = func() {}
}

// Create daemon.
srv, err := aibridged.New(ctx, pool, func(dialCtx context.Context) (aibridged.DRPCClient, error) {
return coderAPI.CreateInMemoryAIBridgeServer(dialCtx)
}, logger, tracer)
if err != nil {
return nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err)
unsubscribe()
return nil, nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err)
}
return srv, unsubscribe, nil
}

// poolDBReloader implements [aibridged.ProviderReloader] by loading
// the live provider set from the database and forwarding it to the
// pool.
type poolDBReloader struct {
pool *aibridged.CachedBridgePool
db database.Store
cfg codersdk.AIBridgeConfig
logger slog.Logger
}

func (r *poolDBReloader) Reload(ctx context.Context) error {
providers, err := BuildProviders(ctx, r.db, r.cfg, r.logger)
if err != nil {
// Keep the previous snapshot in place: dropping all providers
// because the DB read failed would compound the visible failure
// mode beyond the operator's actual misconfiguration.
return xerrors.Errorf("load ai providers from database: %w", err)
}
return srv, nil
r.pool.ReplaceProviders(providers)
return nil
}

// BuildProviders loads every enabled ai_providers row, attaches its
Expand Down
4 changes: 3 additions & 1 deletion cli/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
if err != nil {
return xerrors.Errorf("build AI providers: %w", err)
}
aibridgeDaemon, err = newAIBridgeDaemon(coderAPI, aibridgeProviders)
var unsubscribeProviderReload func()
aibridgeDaemon, unsubscribeProviderReload, err = newAIBridgeDaemon(coderAPI, aibridgeProviders, vals.AI.BridgeConfig)
if err != nil {
return xerrors.Errorf("create aibridged: %w", err)
}
Expand All @@ -1055,6 +1056,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
// daemon does not affect in-flight requests but is needed to
// release pool/recorder resources at shutdown.
defer aibridgeDaemon.Close()
defer unsubscribeProviderReload()
}

if vals.Prometheus.Enable {
Expand Down
19 changes: 19 additions & 0 deletions coderd/ai_providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
Expand Down Expand Up @@ -235,6 +236,7 @@ func (api *API) aiProvidersCreate(rw http.ResponseWriter, r *http.Request) {
aReq.New = row

auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, aiProviderKeyChanges{Added: keys})
api.publishAIProvidersChanged(ctx)

sdk, err := db2sdk.AIProvider(row, keys)
if err != nil {
Expand Down Expand Up @@ -400,6 +402,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) {
}

auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, keyChanges)
api.publishAIProvidersChanged(ctx)

sdk, err := db2sdk.AIProvider(updated, keys)
if err != nil {
Expand Down Expand Up @@ -453,9 +456,25 @@ func (api *API) aiProvidersDelete(rw http.ResponseWriter, r *http.Request) {
return
}

api.publishAIProvidersChanged(ctx)

rw.WriteHeader(http.StatusNoContent)
}

// publishAIProvidersChanged notifies subscribers (aibridged,
// aibridgeproxyd) that the live provider set changed and they should
// refetch from the database. Pubsub failures are logged but not
// propagated: subscribers refresh authoritatively from the DB, so a
// dropped notification only delays convergence.
func (api *API) publishAIProvidersChanged(ctx context.Context) {
Comment thread
dannykopping marked this conversation as resolved.
if api.Pubsub == nil {
return
}
if err := api.Pubsub.Publish(coderpubsub.AIProvidersChangedChannel, nil); err != nil {
api.Logger.Warn(ctx, "publish ai providers changed event", slog.Error(err))
}
}

// errBedrockRejectsAPIKeys is the sentinel returned from inside the
// update transaction when a caller attempts to attach api_keys to a
// Bedrock-typed provider; the outer handler translates it into a 400.
Expand Down
62 changes: 62 additions & 0 deletions coderd/ai_providers_pubsub_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package coderd_test

import (
"context"
"sync/atomic"
"testing"

"github.com/stretchr/testify/require"

"github.com/coder/coder/v2/coderd/coderdtest"
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)

// TestAIProvidersChangedPubsub asserts that the CRUD handlers publish
// on AIProvidersChangedChannel for the operations that affect the
// runtime provider set. Subscribers (aibridged, aibridgeproxyd) depend
// on these notifications to trigger their pool reload.
//
// The handlers publish best-effort and the payload is empty, so we
// assert "at least one event per mutation" via a counter.
func TestAIProvidersChangedPubsub(t *testing.T) {
t.Parallel()

client, _, api := coderdtest.NewWithAPI(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)

var count atomic.Int64
unsubscribe, err := api.Pubsub.Subscribe(coderpubsub.AIProvidersChangedChannel, func(_ context.Context, _ []byte) {
count.Add(1)
})
require.NoError(t, err)
t.Cleanup(unsubscribe)

// Create.
req := codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "pubsub-openai",
Enabled: true,
BaseURL: "https://api.openai.com/v1/",
APIKeys: []string{"k1"},
}
//nolint:gocritic // Owner role is the audience for this endpoint.
created, err := client.CreateAIProvider(ctx, req)
require.NoError(t, err)
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 1 }, testutil.IntervalFast)

// Update.
newKey := "k2"
_, err = client.UpdateAIProvider(ctx, created.ID.String(), codersdk.UpdateAIProviderRequest{
APIKeys: &[]codersdk.AIProviderKeyMutation{{APIKey: &newKey}},
})
require.NoError(t, err)
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 2 }, testutil.IntervalFast)

// Delete.
err = client.DeleteAIProvider(ctx, created.ID.String())
require.NoError(t, err)
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 3 }, testutil.IntervalFast)
}
13 changes: 13 additions & 0 deletions coderd/aibridged/aibridgedmock/poolmock.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

92 changes: 71 additions & 21 deletions coderd/aibridged/pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@ package aibridged
import (
"context"
"net/http"
"slices"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/dgraph-io/ristretto/v2"
Expand All @@ -26,6 +29,9 @@ const (
// One [*aibridge.RequestBridge] instance is created per given key.
type Pooler interface {
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
// ReplaceProviders swaps the providers used to construct future
// RequestBridge instances and clears the cache.
ReplaceProviders(providers []aibridge.Provider)
Shutdown(ctx context.Context) error
}

Expand All @@ -46,10 +52,12 @@ var DefaultPoolOptions = PoolOptions{MaxItems: 5000, TTL: time.Minute * 15}
var _ Pooler = &CachedBridgePool{}

type CachedBridgePool struct {
cache *ristretto.Cache[string, *aibridge.RequestBridge]
providers []aibridge.Provider
logger slog.Logger
options PoolOptions
cache *ristretto.Cache[string, *aibridge.RequestBridge]
// providers is the live provider set used by new RequestBridge instances.
providers atomic.Pointer[[]aibridge.Provider]
providerVersion atomic.Int64
logger slog.Logger
options PoolOptions

singleflight *singleflight.Group[string, *aibridge.RequestBridge]

Expand All @@ -71,32 +79,70 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
if item == nil || item.Value == nil {
return
}

shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*5)
defer shutdownCancel()

// Run the eviction in the background since ristretto blocks sets until a free slot is available.
// Capture the value synchronously: ristretto reuses the
// item slot after OnEvict returns, so reading item.Value
// from the goroutine below races with the caller of
// Clear/Set. The shutdown still runs in the background to
// avoid blocking ristretto's eviction loop.
bridge := item.Value
go func() {
_ = item.Value.Shutdown(shutdownCtx)
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_ = bridge.Shutdown(shutdownCtx)
}()
},
})
if err != nil {
return nil, xerrors.Errorf("create cache: %w", err)
}

return &CachedBridgePool{
cache: cache,
providers: providers,
options: options,
metrics: metrics,
tracer: tracer,
logger: logger,
pool := &CachedBridgePool{
cache: cache,
options: options,
metrics: metrics,
tracer: tracer,
logger: logger,

singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{},

shuttingDownCh: make(chan struct{}),
}, nil
}
initial := slices.Clone(providers)
pool.providers.Store(&initial)
return pool, nil
}

// ReplaceProviders swaps the provider snapshot used by future Acquires.
// It is safe to call concurrently with Acquire and is a no-op after
// Shutdown.
func (p *CachedBridgePool) ReplaceProviders(providers []aibridge.Provider) {
select {
Comment thread
dannykopping marked this conversation as resolved.
case <-p.shuttingDownCh:
return
default:
}
snapshot := slices.Clone(providers)
p.providers.Store(&snapshot)
version := time.Now().UnixNano()
Comment thread
dannykopping marked this conversation as resolved.
p.providerVersion.Store(version)
// Clear evicts every cached bridge; OnEvict shuts each one down in
// the background. Wait for buffered writes to drain so a replacement
// immediately followed by an Acquire always sees the cleared cache.
p.cache.Clear()
p.cache.Wait()
p.logger.Info(context.Background(), "request bridge pool reloaded",
slog.F("provider_count", len(snapshot)),
slog.F("provider_version", version),
)
}

// loadProviders returns the current providers snapshot. The returned
// slice must not be mutated.
func (p *CachedBridgePool) loadProviders() []aibridge.Provider {
if ptr := p.providers.Load(); ptr != nil {
return *ptr
}
return nil
}

// Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key.
Expand Down Expand Up @@ -140,6 +186,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
}

span.AddEvent("cache_miss")
providerVersion := p.providerVersion.Load()
recorder := aibridge.NewRecorder(p.logger.Named("recorder"), p.tracer, func() (aibridge.Recorder, error) {
client, err := clientFn()
if err != nil {
Expand All @@ -152,7 +199,8 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
// Slow path.
// Creating an *aibridge.RequestBridge may take some time, so gate all subsequent callers behind the initial request and return the resulting value.
// TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs).
instance, err, _ := p.singleflight.Do(req.InitiatorID.String(), func() (*aibridge.RequestBridge, error) {
singleflightKey := cacheKey + "|" + strconv.FormatInt(providerVersion, 10)
instance, err, _ := p.singleflight.Do(singleflightKey, func() (*aibridge.RequestBridge, error) {
var (
mcpServers mcp.ServerProxier
err error
Expand All @@ -171,12 +219,14 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
}
}

bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, p.logger, p.metrics, p.tracer)
bridge, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer)
Comment thread
dannykopping marked this conversation as resolved.
Comment thread
dannykopping marked this conversation as resolved.
if err != nil {
return nil, xerrors.Errorf("create new request bridge: %w", err)
}

p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
if p.providerVersion.Load() == providerVersion {
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
}

return bridge, nil
})
Expand Down
Loading
Loading