Skip to content

Commit c350d1f

Browse files
committed
refactor(aibridge): apply key pool failover follow-ups
1 parent e6514e3 commit c350d1f

10 files changed

Lines changed: 286 additions & 255 deletions

File tree

aibridge/intercept/keyfailover_test.go

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -62,15 +62,6 @@ type interceptorCase struct {
6262
newInterceptor func(t *testing.T, streaming bool, upstreamURL string, reqBody []byte, pool *keypool.Pool, byokKey string) intercept.Interceptor
6363
}
6464

65-
// keyFromHeader reads the API key an upstream request carried in the named auth
66-
// header.
67-
func keyFromHeader(name string, h http.Header) string {
68-
if name == "Authorization" {
69-
return utils.ExtractBearerToken(h.Get(name))
70-
}
71-
return h.Get(name)
72-
}
73-
7465
// interceptorCases is the set of interceptors the failover tests run against,
7566
// one entry per supported API.
7667
var interceptorCases = []interceptorCase{
@@ -365,7 +356,7 @@ func TestInterception_KeyFailover(t *testing.T) {
365356

366357
var seenKeys []string
367358
for _, r := range upstream.ReceivedRequests() {
368-
seenKeys = append(seenKeys, keyFromHeader(ic.authHeader, r.Header))
359+
seenKeys = append(seenKeys, testutil.KeyFromHeader(ic.authHeader, r.Header))
369360
}
370361
assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys")
371362

@@ -540,7 +531,7 @@ func TestInterception_AgenticLoopFailover(t *testing.T) {
540531

541532
var seenKeys []string
542533
for _, r := range upstream.ReceivedRequests() {
543-
seenKeys = append(seenKeys, keyFromHeader(ic.authHeader, r.Header))
534+
seenKeys = append(seenKeys, testutil.KeyFromHeader(ic.authHeader, r.Header))
544535
}
545536
assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys")
546537

aibridge/intercept/messages/base.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,9 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
580580
// ResponseErrorFromKeyPool translates a *keypool.Error into
581581
// a developer-facing ResponseError shaped for the Anthropic API.
582582
func ResponseErrorFromKeyPool(keyPoolErr *keypool.Error) *ResponseError {
583+
if keyPoolErr == nil {
584+
return nil
585+
}
583586
switch keyPoolErr.Kind {
584587
case keypool.ErrorKindPermanent:
585588
return newResponseError(

aibridge/intercept/messages/base_internal_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,10 @@ func TestResponseErrorFromKeyPool(t *testing.T) {
10411041
expectedStatus int
10421042
expectedRetryAfter time.Duration
10431043
}{
1044+
{
1045+
name: "nil_returns_nil",
1046+
keyPoolErr: nil,
1047+
},
10441048
{
10451049
// Rate-limited with no cooldown: 429, no Retry-After.
10461050
name: "rate_limited_zero_retry_after",
@@ -1067,6 +1071,10 @@ func TestResponseErrorFromKeyPool(t *testing.T) {
10671071
t.Run(tc.name, func(t *testing.T) {
10681072
t.Parallel()
10691073
got := ResponseErrorFromKeyPool(tc.keyPoolErr)
1074+
if tc.keyPoolErr == nil {
1075+
assert.Nil(t, got)
1076+
return
1077+
}
10701078
require.NotNil(t, got)
10711079
assert.Equal(t, tc.expectedStatus, got.StatusCode)
10721080
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)

aibridge/intercept/messages/blocking.go

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,8 +112,14 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
112112

113113
for {
114114
// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
115+
116+
// Rebuilt per iteration: i.reqPayload mutates when an agentic
117+
// continuation appends tool results, so withBody must reflect
118+
// the latest payload on every upstream call.
119+
callOpts := []option.RequestOption{i.withBody()}
120+
115121
var keyAttempts int
116-
resp, keyAttempts, err = i.newMessage(ctx, svc)
122+
resp, keyAttempts, err = i.newMessage(ctx, svc, callOpts)
117123
totalKeyAttempts += keyAttempts
118124
if err != nil {
119125
if eventstream.IsConnError(err) {
@@ -352,24 +358,23 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
352358
// newMessage routes by credential type, returning the upstream message, the
353359
// number of key attempts made for this call, and any error. Centralized
354360
// credentials fail over across the key pool, while BYOK makes a single attempt.
355-
func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, int, error) {
361+
func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService, opts []option.RequestOption) (*anthropic.Message, int, error) {
356362
switch i.cred.Kind() {
357363
case intercept.CredentialKindCentralized:
358-
return i.newMessageWithKeyFailover(ctx, svc)
364+
return i.newMessageWithKeyFailover(ctx, svc, opts)
359365
case intercept.CredentialKindBYOK:
360-
msg, err := i.newMessageWithKey(ctx, svc)
366+
msg, err := i.newMessageWithKey(ctx, svc, opts...)
361367
return msg, 0, err
362368
default:
363369
return nil, 0, xerrors.New("no credential configured")
364370
}
365371
}
366372

367373
// newMessageWithKey performs a single upstream call.
368-
func (i *BlockingInterception) newMessageWithKey(ctx context.Context, svc anthropic.MessageService, extraOpts ...option.RequestOption) (_ *anthropic.Message, outErr error) {
374+
func (i *BlockingInterception) newMessageWithKey(ctx context.Context, svc anthropic.MessageService, opts ...option.RequestOption) (_ *anthropic.Message, outErr error) {
369375
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
370376
defer tracing.EndSpanErr(span, &outErr)
371377

372-
opts := append([]option.RequestOption{i.withBody()}, extraOpts...)
373378
return svc.New(ctx, anthropic.MessageNewParams{}, opts...)
374379
}
375380

@@ -378,12 +383,12 @@ func (i *BlockingInterception) newMessageWithKey(ctx context.Context, svc anthro
378383
// 429 and permanent on 401/403. Errors that aren't key-specific don't trigger
379384
// failover and are returned to the caller. It returns the upstream message,
380385
// the number of key attempts made for this call, and any error.
381-
func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, int, error) {
386+
func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, svc anthropic.MessageService, opts []option.RequestOption) (*anthropic.Message, int, error) {
382387
centralized, ok := intercept.AsCentralized(i.cred)
383388
if !ok {
384389
// Centralized but pool-less: Bedrock, which signs via AWS. A single
385390
// attempt with no failover.
386-
msg, err := i.newMessageWithKey(ctx, svc)
391+
msg, err := i.newMessageWithKey(ctx, svc, opts...)
387392
return msg, 0, err
388393
}
389394
walker := centralized.Pool.Walker()
@@ -396,12 +401,14 @@ func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, sv
396401
i.logger.Debug(ctx, "using centralized api key",
397402
slog.F("credential_hint", i.cred.Hint()), slog.F("credential_length", i.cred.Length()))
398403

399-
msg, err := i.newMessageWithKey(ctx, svc,
404+
requestOpts := append([]option.RequestOption{}, opts...)
405+
requestOpts = append(requestOpts,
400406
option.WithAPIKey(key.Value()),
401407
// Disable SDK retries because the failover loop
402408
// handles retries via key rotation.
403409
option.WithMaxRetries(0),
404410
)
411+
msg, err := i.newMessageWithKey(ctx, svc, requestOpts...)
405412
// Key-specific failure: try the next key.
406413
if i.markKeyOnError(ctx, key, err) {
407414
continue

aibridge/intercept/messages/streaming.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ newStream:
179179
walker = centralized.Pool.Walker()
180180
}
181181

182-
var streamOpts []option.RequestOption
182+
streamOpts := []option.RequestOption{i.withBody()}
183183
var currentPoolKey *keypool.Key
184184
if walker != nil {
185185
key, keyPoolErr := walker.Next()
@@ -696,10 +696,9 @@ func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte
696696
}
697697

698698
// newStream traces svc.NewStreaming() call.
699-
func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, extraOpts ...option.RequestOption) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
699+
func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, opts ...option.RequestOption) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
700700
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
701701
defer span.End()
702702

703-
opts := append([]option.RequestOption{i.withBody()}, extraOpts...)
704703
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, opts...)
705704
}

aibridge/intercept/openai_errors.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ func (e *ResponseError) ToResponse() *http.Response {
7373
// ResponseErrorFromKeyPool translates a *keypool.Error into
7474
// a developer-facing ResponseError shaped for the OpenAI API.
7575
func ResponseErrorFromKeyPool(keyPoolErr *keypool.Error) *ResponseError {
76+
if keyPoolErr == nil {
77+
return nil
78+
}
7679
switch keyPoolErr.Kind {
7780
case keypool.ErrorKindPermanent:
7881
return NewResponseError(

aibridge/intercept/openai_errors_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ func TestResponseErrorFromKeyPool(t *testing.T) {
2121
expectedStatus int
2222
expectedRetryAfter time.Duration
2323
}{
24+
{
25+
name: "nil_returns_nil",
26+
keyPoolErr: nil,
27+
},
2428
{
2529
// Rate-limited with no cooldown: 429, no Retry-After.
2630
name: "rate_limited_zero_retry_after",
@@ -47,6 +51,10 @@ func TestResponseErrorFromKeyPool(t *testing.T) {
4751
t.Run(tc.name, func(t *testing.T) {
4852
t.Parallel()
4953
got := intercept.ResponseErrorFromKeyPool(tc.keyPoolErr)
54+
if tc.keyPoolErr == nil {
55+
assert.Nil(t, got)
56+
return
57+
}
5058
require.NotNil(t, got)
5159
assert.Equal(t, tc.expectedStatus, got.StatusCode)
5260
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)

0 commit comments

Comments
 (0)