Skip to content

Commit 168089f

Browse files
committed
refactor: separate aibridge provider and interceptor configs
1 parent 75909f5 commit 168089f

32 files changed

Lines changed: 784 additions & 664 deletions

aibridge/bridge.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -262,8 +262,8 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
262262
Client: string(client),
263263
ClientSessionID: sessionID,
264264
CorrelatingToolCallID: interceptor.CorrelatingToolCallID(),
265-
CredentialKind: string(cred.Kind),
266-
CredentialHint: cred.Hint,
265+
CredentialKind: string(cred.Kind()),
266+
CredentialHint: cred.Hint(),
267267
}); err != nil {
268268
span.SetStatus(codes.Error, fmt.Sprintf("failed to record interception: %v", err))
269269
logger.Warn(ctx, "failed to record interception", slog.Error(err))
@@ -277,16 +277,16 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
277277
slog.F("provider", p.Name()),
278278
slog.F("user_agent", r.UserAgent()),
279279
slog.F("streaming", interceptor.Streaming()),
280-
slog.F("credential_kind", string(cred.Kind)),
280+
slog.F("credential_kind", string(cred.Kind())),
281281
)
282282

283283
// Log BYOK credentials. Centralized credentials are set by
284284
// the key failover loop.
285285
credLogFields := []slog.Field{}
286-
if cred.Kind == intercept.CredentialKindBYOK {
286+
if cred.Kind() == intercept.CredentialKindBYOK {
287287
credLogFields = append(credLogFields,
288-
slog.F("credential_hint", cred.Hint),
289-
slog.F("credential_length", cred.Length),
288+
slog.F("credential_hint", cred.Hint()),
289+
slog.F("credential_length", cred.Length()),
290290
)
291291
}
292292
log.Debug(ctx, "interception started", credLogFields...)
@@ -303,8 +303,7 @@ func newInterceptionProcessor(p provider.Provider, cbs *circuitbreaker.ProviderC
303303
})
304304
// For centralized, the hint now reflects the last attempted
305305
// key from the failover loop.
306-
credHint := interceptor.Credential().Hint
307-
credLen := interceptor.Credential().Length
306+
credHint, credLen := cred.Hint(), cred.Length()
308307
if execErr != nil {
309308
if m != nil {
310309
m.InterceptionCount.WithLabelValues(p.Name(), interceptor.Model(), metrics.InterceptionCountStatusFailed, route, r.Method, actor.ID, string(client)).Add(1)

aibridge/config/config.go

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,32 +12,17 @@ const (
1212
ProviderCopilot = "copilot"
1313
)
1414

15-
// Anthropic carries configuration for an Anthropic provider.
16-
//
17-
// Authentication is mutually exclusive across these three fields,
18-
// set per interception in the provider's CreateInterceptor:
19-
// - KeyPool: centralized requests with automatic key failover.
20-
// - Key: BYOK with X-Api-Key (single attempt, no failover).
21-
// - BYOKBearerToken: BYOK with Authorization Bearer (single
22-
// attempt, no failover).
23-
//
24-
// TODO(ssncferreira): consolidate the three authentication
25-
// fields into a single abstraction per
26-
// https://github.com/coder/aibridge/issues/266.
15+
// Anthropic carries configuration for an Anthropic provider. KeyPool holds the
16+
// centralized keys (with failover); BYOK credentials are resolved per request
17+
// from the incoming headers, not configured here.
2718
type Anthropic struct {
2819
// Name is the provider instance name. If empty, defaults to "anthropic".
2920
Name string
3021
BaseURL string
31-
Key string
3222
KeyPool *keypool.Pool
3323
APIDumpDir string
3424
CircuitBreaker *CircuitBreaker
3525
SendActorHeaders bool
36-
ExtraHeaders map[string]string
37-
// BYOKBearerToken is set in BYOK mode when the user authenticates
38-
// with a access token. When set, the access token is used for upstream
39-
// LLM requests instead of the API key.
40-
BYOKBearerToken string
4126
}
4227

4328
type AWSBedrock struct {
@@ -50,26 +35,17 @@ type AWSBedrock struct {
5035
BaseURL string
5136
}
5237

53-
// OpenAI carries configuration for an OpenAI provider.
54-
//
55-
// Authentication is mutually exclusive across these two fields,
56-
// set per interception in the provider's CreateInterceptor:
57-
// - KeyPool: centralized requests with automatic key failover.
58-
// - Key: BYOK with Authorization Bearer (single attempt, no
59-
// failover).
60-
//
61-
// TODO(ssncferreira): consolidate the authentication fields per
62-
// https://github.com/coder/aibridge/issues/266.
38+
// OpenAI carries configuration for an OpenAI provider. KeyPool holds the
39+
// centralized keys (with failover); BYOK credentials are resolved per request
40+
// from the incoming headers, not configured here.
6341
type OpenAI struct {
6442
// Name is the provider instance name. If empty, defaults to "openai".
6543
Name string
6644
BaseURL string
67-
Key string
6845
KeyPool *keypool.Pool
6946
APIDumpDir string
7047
CircuitBreaker *CircuitBreaker
7148
SendActorHeaders bool
72-
ExtraHeaders map[string]string
7349
}
7450

7551
type Copilot struct {

aibridge/intercept/chatcompletions/base.go

Lines changed: 20 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"go.opentelemetry.io/otel/trace"
1818

1919
"cdr.dev/slog/v3"
20-
"github.com/coder/coder/v2/aibridge/config"
2120
aibcontext "github.com/coder/coder/v2/aibridge/context"
2221
"github.com/coder/coder/v2/aibridge/intercept"
2322
"github.com/coder/coder/v2/aibridge/intercept/apidump"
@@ -29,56 +28,46 @@ import (
2928
)
3029

3130
type interceptionBase struct {
32-
id uuid.UUID
33-
providerName string
34-
req *ChatCompletionNewParamsWrapper
35-
cfg config.OpenAI
31+
id uuid.UUID
32+
req *ChatCompletionNewParamsWrapper
33+
34+
cfg intercept.Config
35+
cred intercept.Credential
3636

3737
// clientHeaders are the original HTTP headers from the client request.
38-
clientHeaders http.Header
39-
authHeaderName string
38+
clientHeaders http.Header
4039

4140
logger slog.Logger
4241
tracer trace.Tracer
4342

44-
recorder recorder.Recorder
45-
mcpProxy mcp.ServerProxier
46-
credential intercept.CredentialInfo
43+
recorder recorder.Recorder
44+
mcpProxy mcp.ServerProxier
4745
}
4846

4947
// newCompletionsService builds the SDK service used for upstream
5048
// calls. BYOK auth is set here. Centralized auth is set
5149
// per-attempt by the failover loop.
5250
func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
53-
// TODO(ssncferreira): validate auth is configured per
54-
// https://github.com/coder/aibridge/issues/266.
55-
5651
var opts []option.RequestOption
57-
// BYOK auth.
58-
if i.cfg.KeyPool == nil {
59-
opts = append(opts, option.WithAPIKey(i.cfg.Key))
52+
// BYOK sets its key here; centralized injects per-attempt in the failover
53+
// loop. The OpenAI SDK presents the key as an Authorization bearer.
54+
if byok, ok := intercept.AsBYOK(i.cred); ok {
55+
opts = append(opts, option.WithAPIKey(byok.Secret))
6056
}
6157
opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))
6258

63-
// Add extra headers if configured.
64-
// Some providers require additional headers that are not added by the SDK.
65-
// TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192
66-
for key, value := range i.cfg.ExtraHeaders {
67-
opts = append(opts, option.WithHeader(key, value))
68-
}
69-
7059
// Forward client headers to upstream. This middleware runs after the SDK
7160
// has built the request, and replaces the outgoing headers with the sanitized
7261
// client headers plus provider auth.
7362
if i.clientHeaders != nil {
7463
opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
75-
req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName)
64+
req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.cred.AuthHeader())
7665
return next(req)
7766
}))
7867
}
7968

8069
// Add API dump middleware if configured
81-
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
70+
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.cfg.ProviderName, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
8271
opts = append(opts, option.WithMiddleware(mw))
8372
}
8473

@@ -89,8 +78,8 @@ func (i *interceptionBase) ID() uuid.UUID {
8978
return i.id
9079
}
9180

92-
func (i *interceptionBase) Credential() intercept.CredentialInfo {
93-
return i.credential
81+
func (i *interceptionBase) Credential() intercept.Credential {
82+
return i.cred
9483
}
9584

9685
func (i *interceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
@@ -117,7 +106,7 @@ func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool)
117106
attribute.String(tracing.RequestPath, r.URL.Path),
118107
attribute.String(tracing.InterceptionID, i.id.String()),
119108
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
120-
attribute.String(tracing.Provider, i.providerName),
109+
attribute.String(tracing.Provider, i.cfg.ProviderName),
121110
attribute.String(tracing.Model, i.Model()),
122111
attribute.Bool(tracing.Streaming, streaming),
123112
}
@@ -219,14 +208,15 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *int
219208
// code. Returns true if the status was a key-specific failover
220209
// trigger so callers can retry with the next key.
221210
func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, err error) bool {
222-
if i.cfg.KeyPool == nil {
211+
centralized, ok := intercept.AsCentralized(i.cred)
212+
if !ok {
223213
return false
224214
}
225215
var apiErr *openai.Error
226216
if !errors.As(err, &apiErr) {
227217
return false
228218
}
229-
return i.cfg.KeyPool.MarkKeyOnStatus(
219+
return centralized.Pool.MarkKeyOnStatus(
230220
ctx, key, apiErr.Response, i.logger,
231221
)
232222
}

aibridge/intercept/chatcompletions/base_internal_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ func TestMarkKeyOnError(t *testing.T) {
141141
key, keyPoolErr := pool.Walker().Next()
142142
require.Nil(t, keyPoolErr)
143143

144-
base := &interceptionBase{cfg: config.OpenAI{KeyPool: pool}, logger: slog.Make()}
144+
base := &interceptionBase{cred: &intercept.Centralized{Pool: pool}, logger: slog.Make()}
145145

146146
got := base.markKeyOnError(context.Background(), key, tc.err)
147147
assert.Equal(t, tc.expectedReturn, got)

aibridge/intercept/chatcompletions/blocking.go

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ import (
1616
"golang.org/x/xerrors"
1717

1818
"cdr.dev/slog/v3"
19-
"github.com/coder/coder/v2/aibridge/config"
2019
aibcontext "github.com/coder/coder/v2/aibridge/context"
2120
"github.com/coder/coder/v2/aibridge/intercept"
2221
"github.com/coder/coder/v2/aibridge/intercept/eventstream"
@@ -33,22 +32,18 @@ type BlockingInterception struct {
3332
func NewBlockingInterceptor(
3433
id uuid.UUID,
3534
req *ChatCompletionNewParamsWrapper,
36-
providerName string,
37-
cfg config.OpenAI,
35+
cfg intercept.Config,
36+
cred intercept.Credential,
3837
clientHeaders http.Header,
39-
authHeaderName string,
4038
tracer trace.Tracer,
41-
cred intercept.CredentialInfo,
4239
) *BlockingInterception {
4340
return &BlockingInterception{interceptionBase: interceptionBase{
44-
id: id,
45-
providerName: providerName,
46-
req: req,
47-
cfg: cfg,
48-
clientHeaders: clientHeaders,
49-
authHeaderName: authHeaderName,
50-
tracer: tracer,
51-
credential: cred,
41+
id: id,
42+
req: req,
43+
cfg: cfg,
44+
cred: cred,
45+
clientHeaders: clientHeaders,
46+
tracer: tracer,
5247
}}
5348
}
5449

@@ -91,7 +86,11 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
9186
// Sum the key attempts across all iterations and record once when the
9287
// interception completes.
9388
var totalKeyAttempts int
94-
defer func() { i.cfg.KeyPool.RecordAttempts(totalKeyAttempts) }()
89+
defer func() {
90+
if centralized, ok := intercept.AsCentralized(i.cred); ok {
91+
centralized.Pool.RecordAttempts(totalKeyAttempts)
92+
}
93+
}()
9594

9695
for {
9796
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
@@ -278,12 +277,15 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
278277
// failover, returning the upstream completion, the number of key attempts
279278
// made for this call, and any error.
280279
func (i *BlockingInterception) newChatCompletion(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (*openai.ChatCompletion, int, error) {
281-
// BYOK: single attempt, no failover.
282-
if i.cfg.KeyPool == nil {
280+
switch i.cred.Kind() {
281+
case intercept.CredentialKindCentralized:
282+
return i.newChatCompletionWithKeyFailover(ctx, svc, opts)
283+
case intercept.CredentialKindBYOK:
283284
completion, err := i.newChatCompletionWithKey(ctx, svc, opts)
284285
return completion, 0, err
286+
default:
287+
return nil, 0, xerrors.New("no credential configured")
285288
}
286-
return i.newChatCompletionWithKeyFailover(ctx, svc, opts)
287289
}
288290

289291
// newChatCompletionWithKey performs a single upstream call.
@@ -300,16 +302,21 @@ func (i *BlockingInterception) newChatCompletionWithKey(ctx context.Context, svc
300302
// trigger failover and are returned to the caller. It returns the upstream
301303
// completion, the number of key attempts made for this call, and any error.
302304
func (i *BlockingInterception) newChatCompletionWithKeyFailover(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (*openai.ChatCompletion, int, error) {
303-
walker := i.cfg.KeyPool.Walker()
305+
centralized, ok := intercept.AsCentralized(i.cred)
306+
if !ok {
307+
// Centralized but pool-less (no centralized keys configured): one attempt.
308+
completion, err := i.newChatCompletionWithKey(ctx, svc, opts)
309+
return completion, 0, err
310+
}
311+
walker := centralized.Pool.Walker()
304312
for {
305313
key, keyPoolErr := walker.Next()
306314
if keyPoolErr != nil {
307315
return nil, walker.Attempts(), keyPoolErr
308316
}
309-
// Record the key in use so the hint reflects the last attempted key.
310-
i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value())
317+
centralized.SetKey(key.Value())
311318
i.logger.Debug(ctx, "using centralized api key",
312-
slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length))
319+
slog.F("credential_hint", i.cred.Hint()), slog.F("credential_length", i.cred.Length()))
313320

314321
requestOpts := append([]option.RequestOption{}, opts...)
315322
requestOpts = append(requestOpts,

0 commit comments

Comments
 (0)