From f32ae0f1dee3fc8a1f87cd2e2b8534fac53322ed Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Thu, 30 Apr 2026 09:14:51 +0000 Subject: [PATCH 01/14] feat: add automatic key failover for AI Bridge Anthropic --- aibridge/config/config.go | 19 +- aibridge/intercept/messages/base.go | 111 +++- aibridge/intercept/messages/base_test.go | 205 +++++++ aibridge/intercept/messages/blocking.go | 48 +- aibridge/intercept/messages/blocking_test.go | 452 ++++++++++++++ aibridge/intercept/messages/streaming.go | 149 ++++- aibridge/intercept/messages/streaming_test.go | 552 ++++++++++++++++++ .../integrationtest/keypool_failover_test.go | 128 ++++ aibridge/keypool/headers.go | 37 ++ aibridge/keypool/headers_test.go | 110 ++++ aibridge/keypool/keymark.go | 48 ++ aibridge/keypool/keymark_test.go | 125 ++++ aibridge/keypool/keypool.go | 89 ++- aibridge/keypool/keypool_test.go | 183 +++++- aibridge/provider/anthropic.go | 63 +- aibridge/provider/anthropic_test.go | 66 +++ enterprise/cli/aibridged.go | 35 +- 17 files changed, 2301 insertions(+), 119 deletions(-) create mode 100644 aibridge/intercept/messages/blocking_test.go create mode 100644 aibridge/intercept/messages/streaming_test.go create mode 100644 aibridge/internal/integrationtest/keypool_failover_test.go create mode 100644 aibridge/keypool/headers.go create mode 100644 aibridge/keypool/headers_test.go create mode 100644 aibridge/keypool/keymark.go create mode 100644 aibridge/keypool/keymark_test.go diff --git a/aibridge/config/config.go b/aibridge/config/config.go index 48f29bb3f5188..676e891c2baaa 100644 --- a/aibridge/config/config.go +++ b/aibridge/config/config.go @@ -1,6 +1,10 @@ package config -import "time" +import ( + "time" + + "github.com/coder/coder/v2/aibridge/keypool" +) const ( ProviderAnthropic = "anthropic" @@ -8,11 +12,24 @@ const ( ProviderCopilot = "copilot" ) +// Anthropic carries configuration for an Anthropic provider. +// +// Authentication is mutually exclusive across these three fields, +// set per interception in the provider's CreateInterceptor: +// - KeyPool: centralized requests with automatic key failover. +// - Key: BYOK with X-Api-Key (single attempt, no failover). +// - BYOKBearerToken: BYOK with Authorization Bearer (single +// attempt, no failover). +// +// TODO(ssncferreira): consolidate the three authentication +// fields into a single abstraction per +// https://github.com/coder/aibridge/issues/266. type Anthropic struct { // Name is the provider instance name. If empty, defaults to "anthropic". Name string BaseURL string Key string + KeyPool *keypool.Pool APIDumpDir string CircuitBreaker *CircuitBreaker SendActorHeaders bool diff --git a/aibridge/intercept/messages/base.go b/aibridge/intercept/messages/base.go index c5d053768e829..387228220ef58 100644 --- a/aibridge/intercept/messages/base.go +++ b/aibridge/intercept/messages/base.go @@ -5,7 +5,9 @@ import ( "encoding/json" "errors" "fmt" + "math" "net/http" + "strconv" "strings" "time" @@ -26,6 +28,7 @@ import ( aibcontext "github.com/coder/coder/v2/aibridge/context" "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/intercept/apidump" + "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/recorder" "github.com/coder/coder/v2/aibridge/tracing" @@ -202,19 +205,25 @@ func (i *interceptionBase) isSmallFastModel() bool { return strings.Contains(i.reqPayload.model(), "haiku") } +// newMessagesService builds the SDK service used for upstream +// calls. BYOK auth is set here. Centralized auth is set +// per-attempt by the failover loop. func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { - // BYOK with access token uses Authorization: Bearer. - // Otherwise use X-Api-Key (centralized or BYOK with personal API key). - if i.cfg.BYOKBearerToken != "" { - i.logger.Debug(ctx, "using byok access token auth", - slog.F("bearer_hint", utils.MaskSecret(i.cfg.BYOKBearerToken)), - ) - opts = append(opts, option.WithAuthToken(i.cfg.BYOKBearerToken)) - } else { - i.logger.Debug(ctx, "using api key auth", - slog.F("api_key_hint", utils.MaskSecret(i.cfg.Key)), - ) - opts = append(opts, option.WithAPIKey(i.cfg.Key)) + // BYOK auth. + if i.cfg.KeyPool == nil { + if i.cfg.BYOKBearerToken != "" { + // BYOK Bearer: Authorization header. + i.logger.Debug(ctx, "using byok access token auth", + slog.F("bearer_hint", utils.MaskSecret(i.cfg.BYOKBearerToken)), + ) + opts = append(opts, option.WithAuthToken(i.cfg.BYOKBearerToken)) + } else { + // BYOK X-Api-Key. + i.logger.Debug(ctx, "using api key auth", + slog.F("api_key_hint", utils.MaskSecret(i.cfg.Key)), + ) + opts = append(opts, option.WithAPIKey(i.cfg.Key)) + } } opts = append(opts, option.WithBaseURL(i.cfg.BaseURL)) @@ -427,6 +436,10 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *res } w.Header().Set("Content-Type", "application/json") + // Set Retry-After when a cooldown is configured. + if antErr.RetryAfter > 0 { + w.Header().Set("Retry-After", strconv.Itoa(int(math.Ceil(antErr.RetryAfter.Seconds())))) + } w.WriteHeader(antErr.StatusCode) out, err := json.Marshal(antErr) @@ -503,6 +516,53 @@ func accumulateUsage(dest, src any) { } } +// For centralized requests, markKeyOnError extracts an +// Anthropic SDK error from err and marks the key based on +// its status code. Returns true if the status was a key-specific +// failover trigger so callers can retry with the next key. +func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, err error) bool { + if i.cfg.KeyPool == nil { + return false + } + var apiErr *anthropic.Error + if !errors.As(err, &apiErr) { + return false + } + return keypool.MarkKeyOnStatus( + ctx, key, apiErr.StatusCode, apiErr.Response, + i.logger, i.providerName, + ) +} + +// For centralized requests, mapExhaustionError translates a +// keypool exhaustion error into a developer-facing responseError +// shaped for the Anthropic API. Returns nil if err is not an +// exhaustion error. +func (i *interceptionBase) mapExhaustionError(err error) *responseError { + if i.cfg.KeyPool == nil { + return nil + } + var transient *keypool.TransientExhaustionError + switch { + case errors.As(err, &transient): + return newErrorResponse( + "all configured keys are rate-limited", + string(constant.ValueOf[constant.RateLimitError]()), + http.StatusTooManyRequests, + transient.RetryAfter, + ) + case errors.Is(err, keypool.ErrPermanentExhaustion): + return newErrorResponse( + "all configured keys failed authentication", + string(constant.ValueOf[constant.APIError]()), + http.StatusBadGateway, + 0, + ) + default: + return nil + } +} + func getErrorResponse(err error) *responseError { var apierr *anthropic.Error if !errors.As(err, &apierr) { @@ -510,7 +570,7 @@ func getErrorResponse(err error) *responseError { } msg := apierr.Error() - typ := string(constant.ValueOf[constant.APIError]()) + errType := string(constant.ValueOf[constant.APIError]()) var detail *anthropic.APIErrorObject if field, ok := apierr.JSON.ExtraFields["error"]; ok { @@ -518,19 +578,10 @@ func getErrorResponse(err error) *responseError { } if detail != nil { msg = detail.Message - typ = string(detail.Type) + errType = string(detail.Type) } - return &responseError{ - ErrorResponse: &anthropic.ErrorResponse{ - Error: anthropic.ErrorObjectUnion{ - Message: msg, - Type: typ, - }, - Type: constant.ValueOf[constant.Error](), - }, - StatusCode: apierr.StatusCode, - } + return newErrorResponse(msg, errType, apierr.StatusCode, 0) } var _ error = &responseError{} @@ -538,17 +589,21 @@ var _ error = &responseError{} type responseError struct { *anthropic.ErrorResponse - StatusCode int `json:"-"` + StatusCode int `json:"-"` + RetryAfter time.Duration `json:"-"` } -func newErrorResponse(msg error) *responseError { +func newErrorResponse(msg, errType string, status int, retryAfter time.Duration) *responseError { return &responseError{ ErrorResponse: &shared.ErrorResponse{ Error: shared.ErrorObjectUnion{ - Message: msg.Error(), - Type: "error", + Message: msg, + Type: errType, }, + Type: constant.ValueOf[constant.Error](), }, + StatusCode: status, + RetryAfter: retryAfter, } } diff --git a/aibridge/intercept/messages/base_test.go b/aibridge/intercept/messages/base_test.go index 148c77c3fa7b6..9d3c388de8084 100644 --- a/aibridge/intercept/messages/base_test.go +++ b/aibridge/intercept/messages/base_test.go @@ -3,18 +3,24 @@ package messages //nolint:testpackage // tests unexported internals import ( "context" "net/http" + "net/http/httptest" "testing" + "time" "github.com/anthropics/anthropic-sdk-go" "github.com/anthropics/anthropic-sdk-go/shared/constant" mcpgo "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/tidwall/gjson" + "golang.org/x/xerrors" "cdr.dev/slog/v3" "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/utils" + "github.com/coder/quartz" ) func TestScanForCorrelatingToolCallID(t *testing.T) { @@ -991,3 +997,202 @@ func TestFilterBedrockBetaFlags(t *testing.T) { }) } } + +func TestMapExhaustionError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + nilKeyPool bool + err error + expectedNil bool + expectedStatus int + expectedRetryAfter time.Duration + }{ + { + // BYOK or no centralized pool: never maps. + name: "nil_keypool_returns_nil", + nilKeyPool: true, + err: &keypool.TransientExhaustionError{}, + expectedNil: true, + }, + { + // Transient with valid keys present: 429, no Retry-After. + name: "transient_zero_retry_after", + err: &keypool.TransientExhaustionError{}, + expectedStatus: http.StatusTooManyRequests, + expectedRetryAfter: 0, + }, + { + // Transient with cooldown: 429, Retry-After set. + name: "transient_with_retry_after", + err: &keypool.TransientExhaustionError{RetryAfter: 5 * time.Second}, + expectedStatus: http.StatusTooManyRequests, + expectedRetryAfter: 5 * time.Second, + }, + { + // Permanent: 502 api_error. + name: "permanent_returns_502", + err: keypool.ErrPermanentExhaustion, + expectedStatus: http.StatusBadGateway, + }, + { + // Anything else: not a pool-exhaustion error. + name: "non_pool_exhaustion_error_returns_nil", + err: xerrors.New("some other error"), + expectedNil: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + cfg := config.Anthropic{} + if !tc.nilKeyPool { + pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t)) + require.NoError(t, err) + cfg.KeyPool = pool + } + base := &interceptionBase{cfg: cfg} + + got := base.mapExhaustionError(tc.err) + if tc.expectedNil { + require.Nil(t, got) + return + } + require.NotNil(t, got) + assert.Equal(t, tc.expectedStatus, got.StatusCode) + assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter) + }) + } +} + +func TestMarkKeyOnError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err error + expectedReturn bool + expectedState keypool.KeyState + }{ + { + // Not an *anthropic.Error: no status code to act on. + name: "non_api_error_returns_false", + err: xerrors.New("network failure"), + expectedReturn: false, + expectedState: keypool.KeyStateValid, + }, + { + // Rate-limited: temporary cooldown. + name: "429_marks_temporary", + err: &anthropic.Error{StatusCode: http.StatusTooManyRequests}, + expectedReturn: true, + expectedState: keypool.KeyStateTemporary, + }, + { + // Auth failure: mark permanent. + name: "401_marks_permanent", + err: &anthropic.Error{StatusCode: http.StatusUnauthorized}, + expectedReturn: true, + expectedState: keypool.KeyStatePermanent, + }, + { + // Auth forbidden: mark permanent. + name: "403_marks_permanent", + err: &anthropic.Error{StatusCode: http.StatusForbidden}, + expectedReturn: true, + expectedState: keypool.KeyStatePermanent, + }, + { + // Server errors are not key-specific. + name: "500_does_not_mark", + err: &anthropic.Error{StatusCode: http.StatusInternalServerError}, + expectedReturn: false, + expectedState: keypool.KeyStateValid, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t)) + require.NoError(t, err) + key, err := pool.Walker().Next() + require.NoError(t, err) + + base := &interceptionBase{cfg: config.Anthropic{KeyPool: pool}, logger: slog.Make()} + + got := base.markKeyOnError(context.Background(), key, tc.err) + assert.Equal(t, tc.expectedReturn, got) + assert.Equal(t, tc.expectedState, key.State()) + }) + } +} + +func TestWriteUpstreamError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + respErr *responseError + expectStatus int + // Empty string means the header should be absent. + expectRetryAfter string + // Substring expected in the marshaled body. Empty means no body check. + expectBodyContains string + }{ + { + // Standard error: status and JSON body written. + name: "writes_status_and_body", + respErr: newErrorResponse("upstream failed", "api_error", http.StatusBadGateway, 0), + expectStatus: http.StatusBadGateway, + expectBodyContains: `"upstream failed"`, + }, + { + // Whole-second retryAfter: emitted as integer seconds. + name: "retry_after_in_seconds", + respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 60*time.Second), + expectStatus: http.StatusTooManyRequests, + expectRetryAfter: "60", + }, + { + // 500ms rounds up to Retry-After: 1. + name: "retry_after_500ms_rounds_up_to_one", + respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 500*time.Millisecond), + expectStatus: http.StatusTooManyRequests, + expectRetryAfter: "1", + }, + { + // 200ms rounds up to Retry-After: 1. + name: "retry_after_200ms_rounds_up_to_one", + respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 200*time.Millisecond), + expectStatus: http.StatusTooManyRequests, + expectRetryAfter: "1", + }, + { + // Negative retryAfter: header omitted. + name: "negative_retry_after_omits_header", + respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, -1*time.Second), + expectStatus: http.StatusTooManyRequests, + expectRetryAfter: "", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + base := &interceptionBase{logger: slog.Make()} + + w := httptest.NewRecorder() + base.writeUpstreamError(w, tc.respErr) + + assert.Equal(t, tc.expectStatus, w.Code, "status code") + assert.Equal(t, "application/json", w.Header().Get("Content-Type"), "Content-Type header") + assert.Equal(t, tc.expectRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + if tc.expectBodyContains != "" { + assert.Contains(t, w.Body.String(), tc.expectBodyContains, "response body") + } + }) + } +} diff --git a/aibridge/intercept/messages/blocking.go b/aibridge/intercept/messages/blocking.go index 610f93457841a..7251d0c07c98a 100644 --- a/aibridge/intercept/messages/blocking.go +++ b/aibridge/intercept/messages/blocking.go @@ -112,6 +112,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req return xerrors.Errorf("upstream connection closed: %w", err) } + // The failover loop may return a keypool exhaustion + // error. Check before the SDK-error path. + if keyErr := i.mapExhaustionError(err); keyErr != nil { + i.writeUpstreamError(w, keyErr) + return xerrors.Errorf("key pool exhausted: %w", err) + } + if antErr := getErrorResponse(err); antErr != nil { i.writeUpstreamError(w, antErr) return xerrors.Errorf("anthropic API error: %w", err) @@ -338,5 +345,44 @@ func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.Mes ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer tracing.EndSpanErr(span, &outErr) - return svc.New(ctx, anthropic.MessageNewParams{}, i.withBody()) + // BYOK: single attempt, no failover. + if i.cfg.KeyPool == nil { + return svc.New(ctx, anthropic.MessageNewParams{}, i.withBody()) + } + return i.newMessageWithKeyFailover(ctx, svc) +} + +// newMessageWithKeyFailover walks the centralized key pool, +// trying each key until one succeeds or the pool is exhausted. +// Keys are marked temporary on 429 and permanent on 401/403. +// Errors that aren't key-specific don't trigger failover and +// are returned to the caller. +func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, error) { + // TODO(ssncferreira): update the interception's credential + // hint with the actually-used key (the successful key on + // success, the last tried key on failure) in the upstack PR. + walker := i.cfg.KeyPool.Walker() + for { + key, err := walker.Next() + if err != nil { + return nil, err + } + + msg, err := svc.New(ctx, anthropic.MessageNewParams{}, + i.withBody(), + option.WithAPIKey(key.Value()), + // Disable SDK retries because the failover loop + // handles retries via key rotation. + option.WithMaxRetries(0), + ) + if err == nil { + return msg, nil + } + // Mark the key based on the upstream response. + if !i.markKeyOnError(ctx, key, err) { + // Not a key-specific failure: return without + // trying another key. + return nil, err + } + } } diff --git a/aibridge/intercept/messages/blocking_test.go b/aibridge/intercept/messages/blocking_test.go new file mode 100644 index 0000000000000..51d77fb0d0e15 --- /dev/null +++ b/aibridge/intercept/messages/blocking_test.go @@ -0,0 +1,452 @@ +package messages //nolint:testpackage // tests unexported internals + +import ( + "io" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/intercept" + "github.com/coder/coder/v2/aibridge/internal/testutil" + "github.com/coder/coder/v2/aibridge/keypool" + "github.com/coder/coder/v2/aibridge/mcp" + "github.com/coder/quartz" +) + +// Common request and Anthropic-shaped response bodies. +const ( + requestBody = `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}` + successBody = `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-opus-4-5","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}` + toolUseBody = `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_01","name":"test_tool","input":{}}],"model":"claude-opus-4-5","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}` + rateLimitBody = `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}` + authErrorBody = `{"type":"error","error":{"type":"authentication_error","message":"invalid key"}}` + serverErrorBody = `{"type":"error","error":{"type":"api_error","message":"server error"}}` +) + +type upstreamResponse struct { + statusCode int + body string + headers map[string]string +} + +func TestBlockingInterception_KeyFailover(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + // Centralized pool keys. Empty when byokKey is set. + keys []string + // BYOK key. Empty when keys is set. + byokKey string + // Scripted upstream responses keyed by X-Api-Key. + responses map[string]upstreamResponse + expectedRequestCount int32 + expectedStatusCode int + expectedRetryAfter string + // Expected key states after the request, by index in keys. + expectedKeyStates []keypool.KeyState + }{ + { + // Given: 1 valid key returning 200. + // Then: 1 request, 200 response, key remains valid. + name: "single_valid_key", + keys: []string{"k0"}, + responses: map[string]upstreamResponse{ + "k0": {statusCode: http.StatusOK, body: successBody}, + }, + expectedRequestCount: 1, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + }, + { + // Given: 2 keys; key-0 returns 429, key-1 returns 200. + // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. + name: "failover_after_429", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "5"}, + body: rateLimitBody, + }, + "k1": {statusCode: http.StatusOK, body: successBody}, + }, + expectedRequestCount: 2, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateValid, + }, + }, + { + // Given: 2 keys; key-0 returns 401, key-1 returns 200. + // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. + name: "failover_after_401", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1": {statusCode: http.StatusOK, body: successBody}, + }, + expectedRequestCount: 2, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStatePermanent, + keypool.KeyStateValid, + }, + }, + { + // Given: 2 keys; key-0 returns 403, key-1 returns 200. + // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. + name: "failover_after_403", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": {statusCode: http.StatusForbidden, body: authErrorBody}, + "k1": {statusCode: http.StatusOK, body: successBody}, + }, + expectedRequestCount: 2, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStatePermanent, + keypool.KeyStateValid, + }, + }, + { + // Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s. + // Then: 3 requests, 429 response with smallest Retry-After, + // all keys temporary. + name: "all_keys_rate_limited", + keys: []string{"k0", "k1", "k2"}, + responses: map[string]upstreamResponse{ + "k0": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "5"}, + body: rateLimitBody, + }, + "k1": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "3"}, + body: rateLimitBody, + }, + "k2": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "10"}, + body: rateLimitBody, + }, + }, + expectedRequestCount: 3, + expectedStatusCode: http.StatusTooManyRequests, + expectedRetryAfter: "3", + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateTemporary, + keypool.KeyStateTemporary, + }, + }, + { + // Given: 2 keys; both return 401. + // Then: 2 requests, 502 api_error response, both keys permanent. + name: "all_keys_unauthorized", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + }, + expectedRequestCount: 2, + expectedStatusCode: http.StatusBadGateway, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStatePermanent, + keypool.KeyStatePermanent, + }, + }, + { + // Given: 2 keys; key-0 returns 500. + // Then: 1 request, 500 response, both keys remain valid. + name: "server_error_no_failover", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, + }, + expectedRequestCount: 1, + expectedStatusCode: http.StatusInternalServerError, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateValid, + keypool.KeyStateValid, + }, + }, + { + // Given: BYOK with a single key returning 429. + // Then: 1 request, 429 response, no failover. + name: "byok_no_failover", + byokKey: "user-byok", + responses: map[string]upstreamResponse{ + "user-byok": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{ + "Retry-After": "5", + // BYOK doesn't set MaxRetries(0); + // suppress SDK retries to test a + // single attempt. + "x-should-retry": "false", + }, + body: rateLimitBody, + }, + }, + expectedRequestCount: 1, + expectedStatusCode: http.StatusTooManyRequests, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Mock upstream: counts requests and returns + // scripted responses keyed by X-Api-Key. An unmapped + // key falls through to 500 so misconfigured cases + // surface via the status assertion. + var requestCount atomic.Int32 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + _, _ = io.Copy(io.Discard, r.Body) + resp, ok := tc.responses[r.Header.Get("X-Api-Key")] + if !ok { + resp = upstreamResponse{statusCode: http.StatusInternalServerError} + } + w.Header().Set("Content-Type", "application/json") + for hk, hv := range resp.headers { + w.Header().Set(hk, hv) + } + w.WriteHeader(resp.statusCode) + _, _ = w.Write([]byte(resp.body)) + })) + t.Cleanup(upstream.Close) + + cfg := config.Anthropic{BaseURL: upstream.URL + "/"} + var pool *keypool.Pool + if len(tc.keys) > 0 { + var err error + pool, err = keypool.New(tc.keys, quartz.NewMock(t)) + require.NoError(t, err) + cfg.KeyPool = pool + } else if tc.byokKey != "" { + cfg.Key = tc.byokKey + } + + payload, err := NewRequestPayload([]byte(requestBody)) + require.NoError(t, err) + + interceptor := NewBlockingInterceptor( + uuid.New(), + payload, + config.ProviderAnthropic, + cfg, + nil, + http.Header{}, + "X-Api-Key", + otel.Tracer("blocking_test"), + intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + ) + interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + w := httptest.NewRecorder() + err = interceptor.ProcessRequest(w, req) + if tc.expectedStatusCode == http.StatusOK { + require.NoError(t, err) + } else { + require.Error(t, err) + } + + assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") + assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") + assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + if pool != nil { + assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") + } + }) + } +} + +// TestBlockingInterception_AgenticLoopFailover covers the +// scenarios that span an agentic-loop continuation: the initial +// client request and the subsequent tool-call continuation can +// each fail over independently. Each iteration gets its own +// walker. +func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + // Scripted upstream responses consumed in order of + // upstream request. + responses []upstreamResponse + expectedRequestCount int32 + expectedSeenKeys []string + expectedStatusCode int + expectedKeyStates []keypool.KeyState + }{ + { + // Given: 2 keys; both upstream calls succeed on key-0. + // Then: 2 requests, 200 response, both keys remain valid. + name: "happy_path", + responses: []upstreamResponse{ + {statusCode: http.StatusOK, body: toolUseBody}, + {statusCode: http.StatusOK, body: successBody}, + }, + expectedRequestCount: 2, + expectedSeenKeys: []string{"k0", "k0"}, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateValid, + keypool.KeyStateValid, + }, + }, + { + // Given: 2 keys; key-0 succeeds initially, then 429s + // during the agentic continuation, key-1 succeeds. + // Then: 3 requests, 200 response, key-0 temporary, + // key-1 valid. + name: "agentic_failover_to_k1", + responses: []upstreamResponse{ + {statusCode: http.StatusOK, body: toolUseBody}, + { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "5"}, + body: rateLimitBody, + }, + {statusCode: http.StatusOK, body: successBody}, + }, + expectedRequestCount: 3, + expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateValid, + }, + }, + { + // Given: 2 keys; key-0 succeeds initially, then both + // keys 429 during the agentic continuation. + // Then: 3 requests, 429 response with smallest + // Retry-After, both keys temporary. + name: "agentic_all_keys_fail", + responses: []upstreamResponse{ + {statusCode: http.StatusOK, body: toolUseBody}, + { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "5"}, + body: rateLimitBody, + }, + { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "3"}, + body: rateLimitBody, + }, + }, + expectedRequestCount: 3, + expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedStatusCode: http.StatusTooManyRequests, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateTemporary, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var requestCount atomic.Int32 + var seenKeysMu sync.Mutex + var seenKeys []string + + // Mock upstream: returns scripted responses in order, + // records each request's X-Api-Key for assertions. + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + idx := int(requestCount.Add(1)) - 1 + seenKeysMu.Lock() + seenKeys = append(seenKeys, r.Header.Get("X-Api-Key")) + seenKeysMu.Unlock() + _, _ = io.Copy(io.Discard, r.Body) + + if idx >= len(tc.responses) { + w.WriteHeader(http.StatusInternalServerError) + return + } + resp := tc.responses[idx] + w.Header().Set("Content-Type", "application/json") + for hk, hv := range resp.headers { + w.Header().Set(hk, hv) + } + w.WriteHeader(resp.statusCode) + _, _ = w.Write([]byte(resp.body)) + })) + t.Cleanup(upstream.Close) + + pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + require.NoError(t, err) + + cfg := config.Anthropic{ + BaseURL: upstream.URL + "/", + KeyPool: pool, + } + + payload, err := NewRequestPayload([]byte(requestBody)) + require.NoError(t, err) + + interceptor := NewBlockingInterceptor( + uuid.New(), + payload, + config.ProviderAnthropic, + cfg, + nil, + http.Header{}, + "X-Api-Key", + otel.Tracer("blocking_test"), + intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + ) + + // Mock proxy with a tool the upstream's tool_use + // response will reference. + proxy := &mockServerProxier{ + tools: []*mcp.Tool{ + { + Client: stubToolCaller{}, + ID: "test_tool", + Name: "test_tool", + ServerName: "coder", + Logger: slog.Make(), + }, + }, + } + interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + w := httptest.NewRecorder() + err = interceptor.ProcessRequest(w, req) + if tc.expectedStatusCode == http.StatusOK { + require.NoError(t, err) + } else { + require.Error(t, err) + } + + assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") + assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") + + seenKeysMu.Lock() + defer seenKeysMu.Unlock() + assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") + assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") + }) + } +} diff --git a/aibridge/intercept/messages/streaming.go b/aibridge/intercept/messages/streaming.go index 881e62dad599d..0f3b605e21f51 100644 --- a/aibridge/intercept/messages/streaming.go +++ b/aibridge/intercept/messages/streaming.go @@ -25,6 +25,7 @@ import ( aibcontext "github.com/coder/coder/v2/aibridge/context" "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/intercept/eventstream" + "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/recorder" "github.com/coder/coder/v2/aibridge/tracing" @@ -161,14 +162,66 @@ newStream: break } - stream := i.newStream(streamCtx, svc) + // Per-iteration walker. An iteration is either an agentic + // continuation (sending a tool result back in a new + // stream) or a failover retry (previous key marked, try + // the next one). + var walker *keypool.Walker + if i.cfg.KeyPool != nil { + walker = i.cfg.KeyPool.Walker() + } + + var streamOpts []option.RequestOption + var currentKey *keypool.Key + if walker != nil { + key, err := walker.Next() + if err != nil { + // Pool exhausted in this iteration. Relay the + // error to the client: as an SSE event if events + // have already been sent, or by direct write + // otherwise. + if respErr := i.mapExhaustionError(err); respErr != nil { + interceptionErr = respErr + if events.IsStreaming() { + payload, mErr := i.marshal(respErr) + if mErr != nil { + logger.Warn(ctx, "failed to marshal exhaustion error", slog.Error(mErr)) + } else if sErr := events.Send(streamCtx, payload); sErr != nil { + logger.Warn(ctx, "failed to relay exhaustion error", slog.Error(sErr)) + } + } else { + i.writeUpstreamError(w, respErr) + } + } + break + } + currentKey = key + streamOpts = append(streamOpts, + option.WithAPIKey(key.Value()), + // Disable SDK retries because the failover + // loop handles retries via key rotation. + option.WithMaxRetries(0), + ) + } + + stream := i.newStream(streamCtx, svc, streamOpts...) var message anthropic.Message var lastToolName string pendingToolCalls := make(map[string]string) + // iterationStarted is per-iteration (reset on every + // newStream loop): true once the upstream call has + // produced any events for this iteration. While false, + // a key-specific failure can still fail over to the + // next key. Distinct from events.IsStreaming(), which + // is stream-wide and stays true once iteration 1 has + // sent any event downstream. + var iterationStarted bool + for stream.Next() { + iterationStarted = true event := stream.Current() if err := message.Accumulate(event); err != nil { logger.Warn(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON())) @@ -478,40 +531,44 @@ newStream: promptFound = false //nolint:ineffassign // reset to prevent double-recording across newStream iterations } - if events.IsStreaming() { - // Check if the stream encountered any errors. - if streamErr := stream.Err(); streamErr != nil { - if eventstream.IsUnrecoverableError(streamErr) { - logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) - // We can't reflect an error back if there's a connection error or the request context was canceled. - } else if antErr := getErrorResponse(streamErr); antErr != nil { - logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr)) - interceptionErr = antErr - } else { - logger.Warn(ctx, "unknown stream error", slog.Error(streamErr)) - // Unfortunately, the Anthropic SDK does not support parsing errors received in the stream - // into known types (i.e. [shared.OverloadedError]). - // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 - // All it does is wrap the payload in an error - which is all we can return, currently. - interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr)) - } - } else if lastErr != nil { - // Otherwise check if any logical errors occurred during processing. - logger.Warn(ctx, "stream processing failed", slog.Error(lastErr)) - interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr)) - } - - if interceptionErr != nil { - payload, err := i.marshal(interceptionErr) + if iterationStarted { + // Mid-stream error or logical error: events have + // already streamed for this iteration, so the + // error is relayed as an SSE event. + if respErr := i.mapStreamError(ctx, logger, stream.Err(), lastErr); respErr != nil { + interceptionErr = respErr + payload, err := i.marshal(respErr) if err != nil { - logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", fmt.Sprintf("%+v", interceptionErr))) + logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", fmt.Sprintf("%+v", respErr))) } else if err := events.Send(streamCtx, payload); err != nil { logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) } } } else { - // Stream has not started yet; write to response if present. - i.writeUpstreamError(w, getErrorResponse(stream.Err())) + // Pre-stream failure of this iteration. For + // centralized requests, mark the key and retry with + // the next one. + if currentKey != nil && i.markKeyOnError(ctx, currentKey, stream.Err()) { + continue newStream + } + // Non-key error: relay it. + respErr := getErrorResponse(stream.Err()) + if respErr != nil { + interceptionErr = respErr + if events.IsStreaming() { + // Prior iterations have streamed, so the SSE + // connection is open: inject as an SSE event. + payload, mErr := i.marshal(respErr) + if mErr != nil { + logger.Warn(ctx, "failed to marshal error", slog.Error(mErr)) + } else if sErr := events.Send(streamCtx, payload); sErr != nil { + logger.Warn(ctx, "failed to relay error", slog.Error(sErr)) + } + } else { + // No events streamed yet, write the response directly. + i.writeUpstreamError(w, respErr) + } + } } shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30) @@ -534,6 +591,35 @@ newStream: return interceptionErr } +// mapStreamError converts a mid-stream upstream error or +// processing error into a relayable responseError. Returns nil +// when the error is unrecoverable, in which case nothing can be +// relayed back. +func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *responseError { + if streamErr != nil { + if eventstream.IsUnrecoverableError(streamErr) { + logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) + // We can't reflect an error back if there's a connection error or the request context was canceled. + return nil + } + if antErr := getErrorResponse(streamErr); antErr != nil { + logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr)) + return antErr + } + logger.Warn(ctx, "unknown stream error", slog.Error(streamErr)) + // Unfortunately, the Anthropic SDK does not support parsing errors received in the stream + // into known types (i.e. [shared.OverloadedError]). + // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 + // All it does is wrap the payload in an error - which is all we can return, currently. + return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), string(constant.ValueOf[constant.Error]()), 0, 0) + } + if lastErr != nil { + logger.Warn(ctx, "stream processing failed", slog.Error(lastErr)) + return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), string(constant.ValueOf[constant.Error]()), 0, 0) + } + return nil +} + func (i *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) { sj, err := sjson.Set(event.RawJSON(), "message.id", i.ID().String()) if err != nil { @@ -585,9 +671,10 @@ func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte } // newStream traces svc.NewStreaming() call. -func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] { +func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, extraOpts ...option.RequestOption) *ssestream.Stream[anthropic.MessageStreamEventUnion] { _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) defer span.End() - return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, i.withBody()) + opts := append([]option.RequestOption{i.withBody()}, extraOpts...) + return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, opts...) } diff --git a/aibridge/intercept/messages/streaming_test.go b/aibridge/intercept/messages/streaming_test.go new file mode 100644 index 0000000000000..bbcbd9f187053 --- /dev/null +++ b/aibridge/intercept/messages/streaming_test.go @@ -0,0 +1,552 @@ +package messages //nolint:testpackage // tests unexported internals + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + + "github.com/google/uuid" + mcplib "github.com/mark3labs/mcp-go/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/intercept" + "github.com/coder/coder/v2/aibridge/internal/testutil" + "github.com/coder/coder/v2/aibridge/keypool" + "github.com/coder/coder/v2/aibridge/mcp" + "github.com/coder/quartz" +) + +// Anthropic-shaped SSE body for a successful streaming response. +const streamingSuccessBody = `event: message_start +data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":5}} + +event: message_stop +data: {"type":"message_stop"} +` + +func TestStreamingInterception_KeyFailover(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + // Centralized pool keys. Empty when byokKey is set. + keys []string + // BYOK key. Empty when keys is set. + byokKey string + // Scripted upstream responses keyed by X-Api-Key. + responses map[string]upstreamResponse + expectedRequestCount int32 + expectedStatusCode int + expectedRetryAfter string + // Expected key states after the request, by index in keys. + expectedKeyStates []keypool.KeyState + }{ + { + // Given: 1 valid key returning a successful stream. + // Then: 1 request, 200 response, key remains valid. + name: "single_valid_key", + keys: []string{"k0"}, + responses: map[string]upstreamResponse{ + "k0": { + statusCode: http.StatusOK, + headers: map[string]string{"Content-Type": "text/event-stream"}, + body: streamingSuccessBody, + }, + }, + expectedRequestCount: 1, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + }, + { + // Given: 2 keys; key-0 returns 429 pre-stream, key-1 + // streams successfully. + // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. + name: "failover_after_429", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "5"}, + body: rateLimitBody, + }, + "k1": { + statusCode: http.StatusOK, + headers: map[string]string{"Content-Type": "text/event-stream"}, + body: streamingSuccessBody, + }, + }, + expectedRequestCount: 2, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateValid, + }, + }, + { + // Given: 2 keys; key-0 returns 401 pre-stream, key-1 + // streams successfully. + // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. + name: "failover_after_401", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1": { + statusCode: http.StatusOK, + headers: map[string]string{"Content-Type": "text/event-stream"}, + body: streamingSuccessBody, + }, + }, + expectedRequestCount: 2, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStatePermanent, + keypool.KeyStateValid, + }, + }, + { + // Given: 2 keys; key-0 returns 403 pre-stream, key-1 streams. + // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. + name: "failover_after_403", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": {statusCode: http.StatusForbidden, body: authErrorBody}, + "k1": { + statusCode: http.StatusOK, + headers: map[string]string{"Content-Type": "text/event-stream"}, + body: streamingSuccessBody, + }, + }, + expectedRequestCount: 2, + expectedStatusCode: http.StatusOK, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStatePermanent, + keypool.KeyStateValid, + }, + }, + { + // Given: 3 keys; all return 429 pre-stream with + // cooldowns 5s, 3s, 10s. + // Then: 3 requests, 429 response with smallest + // Retry-After, all keys temporary. + name: "all_keys_rate_limited", + keys: []string{"k0", "k1", "k2"}, + responses: map[string]upstreamResponse{ + "k0": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "5"}, + body: rateLimitBody, + }, + "k1": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "3"}, + body: rateLimitBody, + }, + "k2": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "10"}, + body: rateLimitBody, + }, + }, + expectedRequestCount: 3, + expectedStatusCode: http.StatusTooManyRequests, + expectedRetryAfter: "3", + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateTemporary, + keypool.KeyStateTemporary, + }, + }, + { + // Given: 2 keys; both return 401 pre-stream. + // Then: 2 requests, 502 api_error response, both keys permanent. + name: "all_keys_unauthorized", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + "k1": {statusCode: http.StatusUnauthorized, body: authErrorBody}, + }, + expectedRequestCount: 2, + expectedStatusCode: http.StatusBadGateway, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStatePermanent, + keypool.KeyStatePermanent, + }, + }, + { + // Given: 2 keys; key-0 returns 500 pre-stream. + // Then: 1 request, 500 response, both keys remain valid. + name: "server_error_no_failover", + keys: []string{"k0", "k1"}, + responses: map[string]upstreamResponse{ + "k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, + }, + expectedRequestCount: 1, + expectedStatusCode: http.StatusInternalServerError, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateValid, + keypool.KeyStateValid, + }, + }, + { + // Given: BYOK with a single key returning 429. + // Then: 1 request, 429 response, no failover. + name: "byok_no_failover", + byokKey: "user-byok", + responses: map[string]upstreamResponse{ + "user-byok": { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{ + "Retry-After": "5", + // BYOK doesn't set MaxRetries(0); + // suppress SDK retries to test a + // single attempt. + "x-should-retry": "false", + }, + body: rateLimitBody, + }, + }, + expectedRequestCount: 1, + expectedStatusCode: http.StatusTooManyRequests, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Mock upstream: counts requests and returns + // scripted responses keyed by X-Api-Key. An unmapped + // key falls through to 500 so misconfigured cases + // surface via the status assertion. + var requestCount atomic.Int32 + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + _, _ = io.Copy(io.Discard, r.Body) + resp, ok := tc.responses[r.Header.Get("X-Api-Key")] + if !ok { + resp = upstreamResponse{statusCode: http.StatusInternalServerError} + } + for hk, hv := range resp.headers { + w.Header().Set(hk, hv) + } + w.WriteHeader(resp.statusCode) + _, _ = w.Write([]byte(resp.body)) + })) + t.Cleanup(upstream.Close) + + cfg := config.Anthropic{BaseURL: upstream.URL + "/"} + var pool *keypool.Pool + if len(tc.keys) > 0 { + var err error + pool, err = keypool.New(tc.keys, quartz.NewMock(t)) + require.NoError(t, err) + cfg.KeyPool = pool + } else if tc.byokKey != "" { + cfg.Key = tc.byokKey + } + + payload, err := NewRequestPayload([]byte(requestBody)) + require.NoError(t, err) + + interceptor := NewStreamingInterceptor( + uuid.New(), + payload, + config.ProviderAnthropic, + cfg, + nil, + http.Header{}, + "X-Api-Key", + otel.Tracer("streaming_test"), + intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + ) + interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + w := httptest.NewRecorder() + err = interceptor.ProcessRequest(w, req) + if tc.expectedStatusCode == http.StatusOK { + require.NoError(t, err) + } else { + require.Error(t, err) + } + + assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") + assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") + assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + // No prior iteration streamed, so errors must be a + // direct HTTP response, not an SSE event. + assert.NotContains(t, w.Body.String(), "event: error", "error must not be relayed as an SSE event") + if pool != nil { + assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") + } + }) + } +} + +// SSE bodies covering an agentic-continuation flow. +const ( + // First response: a tool_use block referencing the injected + // "test_tool". Triggers the agentic continuation loop. + toolUseStreamBody = `event: message_start +data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"test_tool","input":{}}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{}"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":5}} + +event: message_stop +data: {"type":"message_stop"} + +` + + // Second response (after the tool result is sent back): + // a plain text completion that ends the loop. + textStreamBody = `event: message_start +data: {"type":"message_start","message":{"id":"msg_02","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":15,"output_tokens":1}}} + +event: content_block_start +data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} + +event: content_block_delta +data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"done"}} + +event: content_block_stop +data: {"type":"content_block_stop","index":0} + +event: message_delta +data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":3}} + +event: message_stop +data: {"type":"message_stop"} + +` +) + +// stubToolCaller is a minimal mcp.ToolCaller that returns a fixed +// text result, so the agentic continuation can proceed. +type stubToolCaller struct{} + +func (stubToolCaller) CallTool(_ context.Context, _ mcplib.CallToolRequest) (*mcplib.CallToolResult, error) { + return mcplib.NewToolResultText("tool result"), nil +} + +// TestStreamingInterception_AgenticLoopFailover covers the +// scenarios that span an agentic-loop continuation: the initial +// client request and the subsequent tool-call continuation can +// each fail over independently. Each iteration gets its own +// walker. +func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { + t.Parallel() + + sseHeaders := map[string]string{"Content-Type": "text/event-stream"} + + tests := []struct { + name string + // Scripted upstream responses consumed in order of + // upstream request. + responses []upstreamResponse + expectedRequestCount int32 + expectedSeenKeys []string + // Substring expected in the response body. Either a + // success marker (e.g. "done") or an error marker + // (e.g. "rate_limit_error"). + expectedBodyContains string + // True when the error must be relayed as an SSE event. + expectErrorAsSSEEvent bool + // True when ProcessRequest is expected to return an + // error (e.g. all keys exhausted). + expectedErr bool + expectedKeyStates []keypool.KeyState + }{ + { + // Given: 2 keys; both upstream calls succeed on key-0. + // Then: 2 requests, success body, both keys remain valid. + name: "happy_path", + responses: []upstreamResponse{ + {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, + {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, + }, + expectedRequestCount: 2, + expectedSeenKeys: []string{"k0", "k0"}, + expectedBodyContains: "done", + expectErrorAsSSEEvent: false, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateValid, + keypool.KeyStateValid, + }, + }, + { + // Given: 2 keys; key-0 succeeds initially, then 429s + // during the agentic continuation, key-1 succeeds. + // Then: 3 requests, success body, key-0 temporary, + // key-1 valid. + name: "agentic_failover_to_k1", + responses: []upstreamResponse{ + {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, + { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "5"}, + body: rateLimitBody, + }, + {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, + }, + expectedRequestCount: 3, + expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedBodyContains: "done", + expectErrorAsSSEEvent: false, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateValid, + }, + }, + { + // Given: 2 keys; key-0 succeeds initially, then both + // keys 429 during the agentic continuation. + // Then: 3 requests, error injected as SSE event, both + // keys temporary. + name: "agentic_all_keys_fail", + responses: []upstreamResponse{ + {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, + { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "5"}, + body: rateLimitBody, + }, + { + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "3"}, + body: rateLimitBody, + }, + }, + expectedRequestCount: 3, + expectedSeenKeys: []string{"k0", "k0", "k1"}, + expectedBodyContains: "all configured keys are rate-limited", + expectErrorAsSSEEvent: true, + expectedErr: true, + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateTemporary, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var requestCount atomic.Int32 + var seenKeysMu sync.Mutex + var seenKeys []string + + // Mock upstream: returns scripted responses in order, + // records each request's X-Api-Key for assertions. + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + idx := int(requestCount.Add(1)) - 1 + seenKeysMu.Lock() + seenKeys = append(seenKeys, r.Header.Get("X-Api-Key")) + seenKeysMu.Unlock() + _, _ = io.Copy(io.Discard, r.Body) + + if idx >= len(tc.responses) { + w.WriteHeader(http.StatusInternalServerError) + return + } + resp := tc.responses[idx] + for hk, hv := range resp.headers { + w.Header().Set(hk, hv) + } + w.WriteHeader(resp.statusCode) + _, _ = w.Write([]byte(resp.body)) + })) + t.Cleanup(upstream.Close) + + pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + require.NoError(t, err) + + cfg := config.Anthropic{ + BaseURL: upstream.URL + "/", + KeyPool: pool, + } + + payload, err := NewRequestPayload([]byte(requestBody)) + require.NoError(t, err) + + interceptor := NewStreamingInterceptor( + uuid.New(), + payload, + config.ProviderAnthropic, + cfg, + nil, + http.Header{}, + "X-Api-Key", + otel.Tracer("streaming_test"), + intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), + ) + + // Mock proxy with a tool the upstream's tool_use event + // will reference. The stub caller returns a fixed + // text result. + proxy := &mockServerProxier{ + tools: []*mcp.Tool{ + { + Client: stubToolCaller{}, + ID: "test_tool", + Name: "test_tool", + ServerName: "coder", + Logger: slog.Make(), + }, + }, + } + interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy) + + req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) + w := httptest.NewRecorder() + err = interceptor.ProcessRequest(w, req) + if tc.expectedErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") + body := w.Body.String() + assert.Contains(t, body, tc.expectedBodyContains, "response body") + if tc.expectErrorAsSSEEvent { + assert.Contains(t, body, "event: error", "error must be relayed as an SSE event") + } + + seenKeysMu.Lock() + defer seenKeysMu.Unlock() + assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") + assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") + }) + } +} diff --git a/aibridge/internal/integrationtest/keypool_failover_test.go b/aibridge/internal/integrationtest/keypool_failover_test.go new file mode 100644 index 0000000000000..bab2552a28af2 --- /dev/null +++ b/aibridge/internal/integrationtest/keypool_failover_test.go @@ -0,0 +1,128 @@ +package integrationtest //nolint:testpackage // tests unexported internals + +import ( + "fmt" + "io" + "net/http" + "net/http/httptest" + "sync" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/sjson" + + "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/fixtures" + "github.com/coder/coder/v2/aibridge/keypool" + "github.com/coder/coder/v2/aibridge/provider" + "github.com/coder/quartz" +) + +// TestAnthropic_KeyFailover verifies that a pool's key state +// persists across distinct client requests: a key marked +// temporary on request 1 is still skipped on request 2 without +// a wasted upstream attempt. +func TestAnthropic_KeyFailover(t *testing.T) { + t.Parallel() + + fix := fixtures.Parse(t, fixtures.AntSimple) + + tests := []struct { + name string + streaming bool + successBody []byte + successCType string + }{ + { + name: "blocking", + streaming: false, + successBody: fix.NonStreaming(), + successCType: "application/json", + }, + { + name: "streaming", + streaming: true, + successBody: fix.Streaming(), + successCType: "text/event-stream", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + require.NoError(t, err) + + var requestCount atomic.Int32 + var seenKeysMu sync.Mutex + var seenKeys []string + + // Mock upstream: k0 always returns 429, k1 returns + // the per-test success body. + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + requestCount.Add(1) + key := r.Header.Get("X-Api-Key") + seenKeysMu.Lock() + seenKeys = append(seenKeys, key) + seenKeysMu.Unlock() + _, _ = io.Copy(io.Discard, r.Body) + + switch key { + case "k0": + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Retry-After", "60") + w.WriteHeader(http.StatusTooManyRequests) + _, _ = fmt.Fprint(w, `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`) + case "k1": + w.Header().Set("Content-Type", tc.successCType) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(tc.successBody) + default: + w.WriteHeader(http.StatusInternalServerError) + } + })) + t.Cleanup(upstream.Close) + + bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL, + withCustomProvider(provider.NewAnthropic(config.Anthropic{ + BaseURL: upstream.URL, + KeyPool: pool, + }, nil)), + ) + + requestBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming) + require.NoError(t, err) + + // Request 1: walker starts at k0, fails over to k1 + // after 429. + resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, requestBody) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusOK, resp.StatusCode) + + // Request 2: walker skips the now-temporary k0 and + // goes straight to k1 (1 upstream call, not 2). + resp, err = bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, requestBody) + require.NoError(t, err) + _, _ = io.Copy(io.Discard, resp.Body) + require.NoError(t, resp.Body.Close()) + require.Equal(t, http.StatusOK, resp.StatusCode) + + seenKeysMu.Lock() + defer seenKeysMu.Unlock() + // Request 1: 2 calls (k0 then k1). Request 2: 1 call (k1). + assert.Equal(t, int32(3), requestCount.Load(), "upstream request count") + assert.Equal(t, []string{"k0", "k1", "k1"}, seenKeys, "seen keys") + + // Pool state persists: k0 temporary, k1 valid. + assert.Equal(t, []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateValid, + }, pool.PoolState(), "key states") + }) + } +} diff --git a/aibridge/keypool/headers.go b/aibridge/keypool/headers.go new file mode 100644 index 0000000000000..a7626433672d6 --- /dev/null +++ b/aibridge/keypool/headers.go @@ -0,0 +1,37 @@ +package keypool + +import ( + "net/http" + "strconv" + "strings" + "time" +) + +// ParseRetryAfter extracts the cooldown duration from response +// headers. It prefers the OpenAI-specific "retry-after-ms" +// header (milliseconds) over the standard "Retry-After" header +// (seconds). Returns zero if neither header is present or +// parseable. The HTTP-date form of "Retry-After" is not parsed. +func ParseRetryAfter(resp *http.Response) time.Duration { + if resp == nil { + return 0 + } + + // OpenAI convention: millisecond precision. + if val := resp.Header.Get("retry-after-ms"); val != "" { + ms, err := strconv.ParseFloat(strings.TrimSpace(val), 64) + if err == nil && ms > 0 { + return time.Duration(ms * float64(time.Millisecond)) + } + } + + // Standard header: seconds. + if val := resp.Header.Get("Retry-After"); val != "" { + seconds, err := strconv.Atoi(strings.TrimSpace(val)) + if err == nil && seconds > 0 { + return time.Duration(seconds) * time.Second + } + } + + return 0 +} diff --git a/aibridge/keypool/headers_test.go b/aibridge/keypool/headers_test.go new file mode 100644 index 0000000000000..853450c68a383 --- /dev/null +++ b/aibridge/keypool/headers_test.go @@ -0,0 +1,110 @@ +package keypool_test + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/coder/coder/v2/aibridge/keypool" +) + +func TestParseRetryAfter(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + headers map[string]string + nilResponse bool + expected time.Duration + }{ + // nil response. + { + name: "nil_response", + nilResponse: true, + expected: 0, + }, + // No headers set. + { + name: "no_headers", + headers: nil, + expected: 0, + }, + // retry-after-ms (OpenAI, preferred). + { + name: "openai_retry_after_ms", + headers: map[string]string{"retry-after-ms": "2500"}, + expected: 2500 * time.Millisecond, + }, + { + name: "whitespace_trimmed_ms", + headers: map[string]string{"retry-after-ms": " 1500 "}, + expected: 1500 * time.Millisecond, + }, + { + name: "negative_ms_returns_zero", + headers: map[string]string{"retry-after-ms": "-100"}, + expected: 0, + }, + // Retry-After (standard, seconds). + { + name: "standard_retry_after_seconds", + headers: map[string]string{"Retry-After": "60"}, + expected: 60 * time.Second, + }, + { + name: "whitespace_trimmed_seconds", + headers: map[string]string{"Retry-After": " 30 "}, + expected: 30 * time.Second, + }, + { + name: "zero_seconds_returns_zero", + headers: map[string]string{"Retry-After": "0"}, + expected: 0, + }, + { + name: "negative_seconds_returns_zero", + headers: map[string]string{"Retry-After": "-5"}, + expected: 0, + }, + // Both headers set: precedence and fallback. + { + name: "prefers_retry_after_ms_over_standard", + headers: map[string]string{ + "retry-after-ms": "1500", + "Retry-After": "30", + }, + expected: 1500 * time.Millisecond, + }, + { + name: "falls_back_to_standard_when_ms_invalid", + headers: map[string]string{"retry-after-ms": "invalid", "Retry-After": "10"}, + expected: 10 * time.Second, + }, + { + name: "zero_ms_falls_back_to_standard", + headers: map[string]string{"retry-after-ms": "0", "Retry-After": "5"}, + expected: 5 * time.Second, + }, + { + name: "zero_ms_and_zero_seconds_return_zero", + headers: map[string]string{"retry-after-ms": "0", "Retry-After": "0"}, + expected: 0, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + var resp *http.Response + if !tc.nilResponse { + resp = &http.Response{Header: make(http.Header)} + for key, val := range tc.headers { + resp.Header.Set(key, val) + } + } + assert.Equal(t, tc.expected, keypool.ParseRetryAfter(resp)) + }) + } +} diff --git a/aibridge/keypool/keymark.go b/aibridge/keypool/keymark.go new file mode 100644 index 0000000000000..8ce6e50813086 --- /dev/null +++ b/aibridge/keypool/keymark.go @@ -0,0 +1,48 @@ +package keypool + +import ( + "context" + "net/http" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/aibridge/utils" +) + +// MarkKeyOnStatus marks key based on a key-specific HTTP +// status code (429 for temporary, 401 or 403 for permanent). +// Returns true if the status was a key-specific failover +// trigger so callers can retry with the next key. +func MarkKeyOnStatus( + ctx context.Context, + key *Key, + statusCode int, + resp *http.Response, + logger slog.Logger, + providerName string, +) bool { + switch statusCode { + case http.StatusTooManyRequests: + cooldown := ParseRetryAfter(resp) + if key.MarkTemporary(cooldown) { + logger.Warn(ctx, "key marked temporary", + slog.F("provider", providerName), + slog.F("api_key_hint", utils.MaskSecret(key.Value())), + slog.F("status", statusCode), + slog.F("cooldown", cooldown)) + } + return true + case http.StatusUnauthorized, http.StatusForbidden: + if key.MarkPermanent() { + logger.Error(ctx, "key marked permanent", + slog.F("provider", providerName), + slog.F("api_key_hint", utils.MaskSecret(key.Value())), + slog.F("status", statusCode)) + } + return true + default: + logger.Debug(ctx, "status is not a key failover trigger", + slog.F("provider", providerName), + slog.F("status", statusCode)) + return false + } +} diff --git a/aibridge/keypool/keymark_test.go b/aibridge/keypool/keymark_test.go new file mode 100644 index 0000000000000..07072228b53c8 --- /dev/null +++ b/aibridge/keypool/keymark_test.go @@ -0,0 +1,125 @@ +package keypool_test + +import ( + "context" + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/aibridge/keypool" + "github.com/coder/quartz" +) + +func TestMarkKeyOnStatus(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + statusCode int + headers map[string]string + expectedReturn bool + expectedState keypool.KeyState + expectedCooldown time.Duration + }{ + { + // 429 with standard Retry-After header (seconds). + name: "429_with_retry_after_seconds", + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"Retry-After": "5"}, + expectedReturn: true, + expectedState: keypool.KeyStateTemporary, + expectedCooldown: 5 * time.Second, + }, + { + // 429 with retry-after-ms header (milliseconds). + name: "429_with_retry_after_ms", + statusCode: http.StatusTooManyRequests, + headers: map[string]string{"retry-after-ms": "1500"}, + expectedReturn: true, + expectedState: keypool.KeyStateTemporary, + expectedCooldown: 1500 * time.Millisecond, + }, + { + // 429 without headers falls back to default cooldown. + name: "429_no_headers_uses_default", + statusCode: http.StatusTooManyRequests, + expectedReturn: true, + expectedState: keypool.KeyStateTemporary, + expectedCooldown: 60 * time.Second, + }, + { + name: "401_marks_permanent", + statusCode: http.StatusUnauthorized, + expectedReturn: true, + expectedState: keypool.KeyStatePermanent, + }, + { + name: "403_marks_permanent", + statusCode: http.StatusForbidden, + expectedReturn: true, + expectedState: keypool.KeyStatePermanent, + }, + { + name: "200_does_not_mark", + statusCode: http.StatusOK, + expectedReturn: false, + expectedState: keypool.KeyStateValid, + }, + { + name: "500_does_not_mark", + statusCode: http.StatusInternalServerError, + expectedReturn: false, + expectedState: keypool.KeyStateValid, + }, + { + // 529 is the Anthropic overloaded status, handled by + // the circuit breaker, not key failover. + name: "529_does_not_mark", + statusCode: 529, + expectedReturn: false, + expectedState: keypool.KeyStateValid, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + clk := quartz.NewMock(t) + pool, err := keypool.New([]string{"key-0"}, clk) + require.NoError(t, err) + key, err := pool.Walker().Next() + require.NoError(t, err) + + resp := &http.Response{Header: make(http.Header)} + for k, v := range tc.headers { + resp.Header.Set(k, v) + } + + got := keypool.MarkKeyOnStatus( + context.Background(), + key, + tc.statusCode, + resp, + // 401 and 403 cases legitimately log at error + // level when marking a key permanent. + slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + "test", + ) + + assert.Equal(t, tc.expectedReturn, got) + assert.Equal(t, tc.expectedState, key.State()) + + // Verify cooldown was set to the expected duration: + // advancing by exactly that amount returns the key + // to valid. + if tc.expectedCooldown > 0 { + clk.Advance(tc.expectedCooldown) + assert.Equal(t, keypool.KeyStateValid, key.State()) + } + }) + } +} diff --git a/aibridge/keypool/keypool.go b/aibridge/keypool/keypool.go index 02fd980027c40..e6221209c4b16 100644 --- a/aibridge/keypool/keypool.go +++ b/aibridge/keypool/keypool.go @@ -1,6 +1,7 @@ package keypool import ( + "fmt" "sync" "time" @@ -15,11 +16,24 @@ var ( // ErrDuplicateKey is returned when the input contains // duplicate key values. ErrDuplicateKey = xerrors.New("duplicate key") - // ErrAllKeysExhausted is returned when the walker has visited - // every key in the pool and none are available. - ErrAllKeysExhausted = xerrors.New("all keys exhausted") ) +// ErrPermanentExhaustion is returned when every key in the +// pool has been permanently marked unavailable. +var ErrPermanentExhaustion = xerrors.New("all keys permanently unavailable") + +// TransientExhaustionError is returned when no key is currently +// available but at least one will recover. RetryAfter is the +// soonest remaining cooldown across the pool, or 0 if a key +// just became valid mid-walk. +type TransientExhaustionError struct { + RetryAfter time.Duration +} + +func (e *TransientExhaustionError) Error() string { + return fmt.Sprintf("all keys exhausted (retry after %s)", e.RetryAfter) +} + // KeyState represents the current state of a key in the pool. type KeyState int @@ -101,6 +115,23 @@ func (k *Key) State() KeyState { return KeyStateValid } +// remainingCooldown returns the duration until the key's +// cooldown expires for temporary keys. Returns 0 for keys +// that are valid (no active cooldown) or permanent. +func (k *Key) remainingCooldown() time.Duration { + k.mu.RLock() + defer k.mu.RUnlock() + + if k.permanent { + return 0 + } + now := k.clock.Now() + if now.Before(k.cooldownUntil) { + return k.cooldownUntil.Sub(now) + } + return 0 +} + // MarkTemporary marks the key as temporarily unavailable with // the specified cooldown duration. Returns true if this call // transitions the key to temporary. @@ -146,6 +177,47 @@ func (k *Key) MarkPermanent() bool { return true } +// exhaustionError returns ErrPermanentExhaustion if every key +// is permanently unavailable, or *TransientExhaustionError if +// at least one key is temporarily unavailable. When multiple +// keys are temporary, the smallest remaining cooldown is used +// as the retry-after. +func (p *Pool) exhaustionError() error { + var retryAfter time.Duration + var hasCooldown bool + for i := range p.keys { + switch p.keys[i].State() { + // Recoverable now: signal transient with zero retry-after. + case KeyStateValid: + return &TransientExhaustionError{} + // Recoverable later: track soonest remaining cooldown. + case KeyStateTemporary: + cooldown := p.keys[i].remainingCooldown() + if !hasCooldown || cooldown < retryAfter { + retryAfter = cooldown + hasCooldown = true + } + // Permanent: keep walking to confirm error type. + default: + } + } + if hasCooldown { + return &TransientExhaustionError{RetryAfter: retryAfter} + } + return ErrPermanentExhaustion +} + +// PoolState returns a snapshot of each key's state in the pool's +// original order. The result reflects the state at call time and +// is not updated after. Use Walker for the failover iteration path. +func (p *Pool) PoolState() []KeyState { + states := make([]KeyState, len(p.keys)) + for i := range p.keys { + states[i] = p.keys[i].State() + } + return states +} + // Walker traverses a Pool for a single request. Each request // creates its own walker so that it can independently iterate // through keys without interfering with other requests. @@ -162,14 +234,15 @@ func (p *Pool) Walker() *Walker { return &Walker{pool: p, pos: 0} } -// Next returns a Key handle for the next available key. This is -// a read-only operation; it does not modify the pool state. +// Next returns a Key handle for the next available key without +// modifying the pool state. // -// Returns ErrAllKeysExhausted when no more keys are available. +// Returns *TransientExhaustionError or ErrPermanentExhaustion +// when no more keys are available. func (w *Walker) Next() (*Key, error) { pool := w.pool if pool == nil { - return nil, ErrAllKeysExhausted + return nil, ErrPermanentExhaustion } for i := w.pos; i < len(pool.keys); i++ { @@ -183,5 +256,5 @@ func (w *Walker) Next() (*Key, error) { } // No keys available. - return nil, ErrAllKeysExhausted + return nil, pool.exhaustionError() } diff --git a/aibridge/keypool/keypool_test.go b/aibridge/keypool/keypool_test.go index 7fa9790bc4f10..bb544b7345248 100644 --- a/aibridge/keypool/keypool_test.go +++ b/aibridge/keypool/keypool_test.go @@ -1,6 +1,8 @@ package keypool_test import ( + "errors" + "sync" "testing" "time" @@ -49,7 +51,8 @@ func TestNewKeyPool(t *testing.T) { // No more keys available. _, err = walker.Next() - require.ErrorIs(t, err, keypool.ErrAllKeysExhausted) + var transient *keypool.TransientExhaustionError + require.ErrorAs(t, err, &transient, "expected transient exhaustion: walker returned all valid keys, none marked permanent") }) } } @@ -282,19 +285,21 @@ func TestWalkerNext(t *testing.T) { t.Parallel() tests := []struct { - name string - keys []string - setup func(t *testing.T, pool *keypool.Pool) - advance time.Duration - expectValid []string + name string + keys []string + setup func(t *testing.T, pool *keypool.Pool) + advance time.Duration + expectedValid []string + expectedErr error }{ { // Given: key-0: valid, key-1: valid, key-2: valid. // Then: key-0: valid, key-1: valid, key-2: valid. - name: "all_keys_valid", - keys: []string{"key-0", "key-1", "key-2"}, - setup: func(_ *testing.T, _ *keypool.Pool) {}, - expectValid: []string{"key-0", "key-1", "key-2"}, + name: "all_keys_valid", + keys: []string{"key-0", "key-1", "key-2"}, + setup: func(_ *testing.T, _ *keypool.Pool) {}, + expectedValid: []string{"key-0", "key-1", "key-2"}, + expectedErr: &keypool.TransientExhaustionError{}, }, { // Given: key-0: temporary, key-1: valid, key-2: valid. @@ -306,7 +311,8 @@ func TestWalkerNext(t *testing.T) { require.NoError(t, err) key.MarkTemporary(60 * time.Second) }, - expectValid: []string{"key-1", "key-2"}, + expectedValid: []string{"key-1", "key-2"}, + expectedErr: &keypool.TransientExhaustionError{}, }, { // Given: key-0: permanent, key-1: permanent, key-2: valid. @@ -322,7 +328,8 @@ func TestWalkerNext(t *testing.T) { require.NoError(t, err) key1.MarkPermanent() }, - expectValid: []string{"key-2"}, + expectedValid: []string{"key-2"}, + expectedErr: &keypool.TransientExhaustionError{}, }, { // Given: key-0: temporary (30s), key-1: valid. @@ -335,8 +342,9 @@ func TestWalkerNext(t *testing.T) { require.NoError(t, err) key.MarkTemporary(30 * time.Second) }, - advance: 35 * time.Second, - expectValid: []string{"key-0", "key-1"}, + advance: 35 * time.Second, + expectedValid: []string{"key-0", "key-1"}, + expectedErr: &keypool.TransientExhaustionError{}, }, { // Given: key-0: temporary (zero, default 60s), key-1: valid. @@ -349,8 +357,9 @@ func TestWalkerNext(t *testing.T) { require.NoError(t, err) key.MarkTemporary(0) }, - advance: 50 * time.Second, - expectValid: []string{"key-1"}, + advance: 50 * time.Second, + expectedValid: []string{"key-1"}, + expectedErr: &keypool.TransientExhaustionError{}, }, { // Given: key-0: temporary (zero, default 60s), key-1: valid. @@ -363,8 +372,9 @@ func TestWalkerNext(t *testing.T) { require.NoError(t, err) key.MarkTemporary(0) }, - advance: 65 * time.Second, - expectValid: []string{"key-0", "key-1"}, + advance: 65 * time.Second, + expectedValid: []string{"key-0", "key-1"}, + expectedErr: &keypool.TransientExhaustionError{}, }, { // Given: key-0: temporary (negative, default 60s), key-1: valid. @@ -377,13 +387,14 @@ func TestWalkerNext(t *testing.T) { require.NoError(t, err) key.MarkTemporary(-10 * time.Second) }, - advance: 65 * time.Second, - expectValid: []string{"key-0", "key-1"}, + advance: 65 * time.Second, + expectedValid: []string{"key-0", "key-1"}, + expectedErr: &keypool.TransientExhaustionError{}, }, { // Given: key-0: temporary (60s), then marked again with shorter cooldown (10s). // When: 15s pass (past 10s, but not 60s). - // Then: key-0: temporary. + // Then: key-0: temporary, 45s remaining. name: "shorter_cooldown_preserves_longer_not_expired", keys: []string{"key-0"}, setup: func(t *testing.T, pool *keypool.Pool) { @@ -392,8 +403,9 @@ func TestWalkerNext(t *testing.T) { key.MarkTemporary(60 * time.Second) key.MarkTemporary(10 * time.Second) }, - advance: 15 * time.Second, - expectValid: []string{}, + advance: 15 * time.Second, + expectedValid: []string{}, + expectedErr: &keypool.TransientExhaustionError{RetryAfter: 45 * time.Second}, }, { // Given: key-0: temporary (60s), then marked again with shorter cooldown (10s). @@ -407,8 +419,30 @@ func TestWalkerNext(t *testing.T) { key.MarkTemporary(60 * time.Second) key.MarkTemporary(10 * time.Second) }, - advance: 65 * time.Second, - expectValid: []string{"key-0"}, + advance: 65 * time.Second, + expectedValid: []string{"key-0"}, + expectedErr: &keypool.TransientExhaustionError{}, + }, + { + // Given: key-0: temporary (60s), key-1: temporary (10s), key-2: temporary (30s). + // Then: key-0: temporary, key-1: temporary, key-2: temporary. + // Smallest remaining cooldown is reported on exhaustion. + name: "smallest_cooldown_across_temporary_keys", + keys: []string{"key-0", "key-1", "key-2"}, + setup: func(t *testing.T, pool *keypool.Pool) { + walker := pool.Walker() + key0, err := walker.Next() + require.NoError(t, err) + key0.MarkTemporary(60 * time.Second) + key1, err := walker.Next() + require.NoError(t, err) + key1.MarkTemporary(10 * time.Second) + key2, err := walker.Next() + require.NoError(t, err) + key2.MarkTemporary(30 * time.Second) + }, + expectedValid: []string{}, + expectedErr: &keypool.TransientExhaustionError{RetryAfter: 10 * time.Second}, }, { // Given: key-0: temporary, key-1: temporary. @@ -424,7 +458,8 @@ func TestWalkerNext(t *testing.T) { require.NoError(t, err) key1.MarkTemporary(60 * time.Second) }, - expectValid: []string{}, + expectedValid: []string{}, + expectedErr: &keypool.TransientExhaustionError{RetryAfter: 60 * time.Second}, }, { // Given: key-0: permanent, key-1: permanent. @@ -440,7 +475,8 @@ func TestWalkerNext(t *testing.T) { require.NoError(t, err) key1.MarkPermanent() }, - expectValid: []string{}, + expectedValid: []string{}, + expectedErr: keypool.ErrPermanentExhaustion, }, { // Given: key-0: permanent, key-1: temporary, key-2: permanent. @@ -459,7 +495,8 @@ func TestWalkerNext(t *testing.T) { require.NoError(t, err) key2.MarkPermanent() }, - expectValid: []string{}, + expectedValid: []string{}, + expectedErr: &keypool.TransientExhaustionError{RetryAfter: 60 * time.Second}, }, } @@ -478,7 +515,7 @@ func TestWalkerNext(t *testing.T) { } walker := pool.Walker() - for _, expectedKey := range tc.expectValid { + for _, expectedKey := range tc.expectedValid { key, err := walker.Next() require.NoError(t, err) assert.Equal(t, expectedKey, key.Value()) @@ -486,7 +523,93 @@ func TestWalkerNext(t *testing.T) { // After all expected keys, the walker should be exhausted. _, err = walker.Next() - require.ErrorIs(t, err, keypool.ErrAllKeysExhausted) + var wantTransient *keypool.TransientExhaustionError + if errors.As(tc.expectedErr, &wantTransient) { + var got *keypool.TransientExhaustionError + require.ErrorAs(t, err, &got) + assert.Equal(t, wantTransient.RetryAfter, got.RetryAfter) + } else { + require.ErrorIs(t, err, tc.expectedErr) + } + }) + } +} + +// TestKeyConcurrent exercises the documented concurrent-safety +// contract by hammering a single key with concurrent Mark calls +// and asserting the resulting state honors the pool's invariants. +func TestKeyConcurrent(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + // run is called concurrently from numGoroutines, each + // with its own index. + run func(idx int, key *keypool.Key) + // verify asserts the final state. May advance the clock. + verify func(t *testing.T, key *keypool.Key, clk *quartz.Mock) + }{ + { + // Half of the goroutines mark the key as temporary + // with 60s, the other half with 10s. The longer + // cooldown must win regardless of ordering. + name: "longer_cooldown_wins", + run: func(idx int, key *keypool.Key) { + if idx%2 == 0 { + key.MarkTemporary(60 * time.Second) + } else { + key.MarkTemporary(10 * time.Second) + } + }, + verify: func(t *testing.T, key *keypool.Key, clk *quartz.Mock) { + // At 50s the 60s cooldown is still active. + clk.Advance(50 * time.Second) + assert.Equal(t, keypool.KeyStateTemporary, key.State()) + // At 65s the 60s cooldown has expired. + clk.Advance(15 * time.Second) + assert.Equal(t, keypool.KeyStateValid, key.State()) + }, + }, + { + // Half of the goroutines mark the key as permanent, + // the other half mark it as temporary. Permanent is + // terminal: any permanent call wins. + name: "permanent_wins_over_temporary", + run: func(idx int, key *keypool.Key) { + if idx%2 == 0 { + key.MarkPermanent() + } else { + key.MarkTemporary(60 * time.Second) + } + }, + verify: func(t *testing.T, key *keypool.Key, _ *quartz.Mock) { + assert.Equal(t, keypool.KeyStatePermanent, key.State()) + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + clk := quartz.NewMock(t) + pool, err := keypool.New([]string{"key-0"}, clk) + require.NoError(t, err) + key, err := pool.Walker().Next() + require.NoError(t, err) + + const numGoroutines = 10 + var wg sync.WaitGroup + for r := range numGoroutines { + wg.Add(1) + go func(r int) { + defer wg.Done() + tc.run(r, key) + }(r) + } + wg.Wait() + + tc.verify(t, key, clk) }) } } diff --git a/aibridge/provider/anthropic.go b/aibridge/provider/anthropic.go index 269c669a16b5f..48f376fbce9f3 100644 --- a/aibridge/provider/anthropic.go +++ b/aibridge/provider/anthropic.go @@ -15,8 +15,10 @@ import ( "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/intercept/messages" + "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/tracing" "github.com/coder/coder/v2/aibridge/utils" + "github.com/coder/quartz" ) // anthropicForwardHeaders lists headers from incoming requests that should be @@ -55,6 +57,24 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi if cfg.BaseURL == "" { cfg.BaseURL = "https://api.anthropic.com/" } + // Resolve centralized key configuration into KeyPool. + // Precedence: + // 1. cfg.KeyPool (explicit, highest priority). + // 2. cfg.Key (legacy single key). + // After this block cfg.Key is empty so it can only carry a + // BYOK X-Api-Key set per interception in CreateInterceptor. + // TODO(ssncferreira): simplify auth field resolution per + // https://github.com/coder/aibridge/issues/266. + if cfg.KeyPool == nil && cfg.Key != "" { + // keypool.New only fails on empty or duplicate keys, + // neither possible with a single non-empty key. + pool, err := keypool.New([]string{cfg.Key}, quartz.NewReal()) + if err != nil { + panic(fmt.Sprintf("anthropic provider: build single-key pool: %s", err)) + } + cfg.KeyPool = pool + } + cfg.Key = "" if cfg.CircuitBreaker != nil { cfg.CircuitBreaker.IsFailure = anthropicIsFailure cfg.CircuitBreaker.OpenErrorResponse = anthropicOpenErrorResponse @@ -119,29 +139,41 @@ func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tr // Any Coder-specific authentication has already been stripped. // // In centralized mode neither Authorization nor X-Api-Key is - // present, so cfg keeps the centralized key unchanged. + // present, so cfg keeps the KeyPool from provider construction + // and the failover loop walks it. // - // In BYOK mode the user's LLM credentials survive intact. - // If X-Api-Key is present the user has a personal API key; - // overwrite the centralized key with it. If Authorization is - // present the user authenticated directly with provider; - // set BYOKBearerToken and clear the centralized key. - // When both are present, X-Api-Key takes priority to match - // claude-code behavior. + // In BYOK mode the user's LLM credentials survive intact and + // failover is disabled by clearing cfg.KeyPool. If X-Api-Key is + // present the user has a personal API key, populate cfg.Key. + // If Authorization is present the user authenticated directly + // with the provider, populate cfg.BYOKBearerToken. When both + // are present, X-Api-Key takes priority to match claude-code + // behavior. + // + // TODO(ssncferreira): consolidate auth field handling per + // https://github.com/coder/aibridge/issues/266. credKind := intercept.CredentialKindCentralized - credSecret := cfg.Key + var credSecret string authHeaderName := p.AuthHeader() if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" { cfg.Key = apiKey + cfg.KeyPool = nil authHeaderName = "X-Api-Key" credKind = intercept.CredentialKindBYOK credSecret = apiKey } else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" { cfg.BYOKBearerToken = token - cfg.Key = "" + cfg.KeyPool = nil authHeaderName = "Authorization" credKind = intercept.CredentialKindBYOK credSecret = token + } else if cfg.KeyPool != nil { + // Centralized: use the first key as a placeholder hint. + // TODO(ssncferreira): record the actually-used key in + // the interception record to reflect failover. + if k, err := cfg.KeyPool.Walker().Next(); err == nil { + credSecret = k.Value() + } } cred := intercept.NewCredentialInfo(credKind, credSecret) @@ -175,7 +207,16 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) { return } - headers.Set(p.AuthHeader(), p.cfg.Key) + // Centralized: pull a single key from the pool. No failover + // or exhaustion handling here. + // TODO(ssncferreira): replace with RoundTripper-based auth + // in the upstack passthrough PR. + if p.cfg.KeyPool == nil { + return + } + if key, err := p.cfg.KeyPool.Walker().Next(); err == nil { + headers.Set(p.AuthHeader(), key.Value()) + } } func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker { diff --git a/aibridge/provider/anthropic_test.go b/aibridge/provider/anthropic_test.go index a4ea0c21e2433..7ebf4495d9b3d 100644 --- a/aibridge/provider/anthropic_test.go +++ b/aibridge/provider/anthropic_test.go @@ -13,6 +13,8 @@ import ( "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/internal/testutil" + "github.com/coder/coder/v2/aibridge/keypool" + "github.com/coder/quartz" ) func TestAnthropic_TypeAndName(t *testing.T) { @@ -49,6 +51,70 @@ func TestAnthropic_TypeAndName(t *testing.T) { } } +func TestNewAnthropic_KeyResolution(t *testing.T) { + t.Parallel() + + pool, err := keypool.New([]string{"pool-key-0", "pool-key-1"}, quartz.NewMock(t)) + require.NoError(t, err) + + tests := []struct { + name string + cfg config.Anthropic + expectedKeys []string + }{ + { + // Legacy single-key path: NewAnthropic builds a + // pool containing just that key. + name: "key_creates_keypool", + cfg: config.Anthropic{Key: "legacy-key"}, + expectedKeys: []string{"legacy-key"}, + }, + { + // Caller supplies the pool directly. + name: "keypool_passed_directly", + cfg: config.Anthropic{KeyPool: pool}, + expectedKeys: []string{"pool-key-0", "pool-key-1"}, + }, + { + // Both set: KeyPool wins, Key is ignored. + name: "keypool_takes_precedence_over_key", + cfg: config.Anthropic{Key: "legacy-key", KeyPool: pool}, + expectedKeys: []string{"pool-key-0", "pool-key-1"}, + }, + { + // Neither set: no centralized auth available. BYOK + // auth is set per-request in CreateInterceptor. + name: "neither_set_no_centralized_auth", + cfg: config.Anthropic{}, + expectedKeys: nil, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + p := NewAnthropic(tc.cfg, nil) + + if tc.expectedKeys == nil { + assert.Nil(t, p.cfg.KeyPool, "expected no KeyPool") + return + } + + require.NotNil(t, p.cfg.KeyPool) + walker := p.cfg.KeyPool.Walker() + var got []string + for { + key, err := walker.Next() + if err != nil { + break + } + got = append(got, key.Value()) + } + assert.Equal(t, tc.expectedKeys, got) + }) + } +} + func TestAnthropic_CreateInterceptor(t *testing.T) { t.Parallel() diff --git a/enterprise/cli/aibridged.go b/enterprise/cli/aibridged.go index 8c02db7d55bf2..87f32f5d54686 100644 --- a/enterprise/cli/aibridged.go +++ b/enterprise/cli/aibridged.go @@ -10,10 +10,12 @@ import ( "github.com/coder/coder/v2/aibridge" "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/coderd/tracing" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/enterprise/aibridged" "github.com/coder/coder/v2/enterprise/coderd" + "github.com/coder/quartz" ) func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*aibridged.Server, error) { @@ -88,16 +90,24 @@ func buildProviders(cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) { } // Add legacy Anthropic provider if configured. Bedrock credentials - // alone are sufficient — an Anthropic API key is not required when + // alone are sufficient, an Anthropic API key is not required when // using AWS Bedrock. if cfg.LegacyAnthropic.Key.String() != "" || getBedrockConfig(cfg.LegacyBedrock) != nil { if _, conflict := usedNames[aibridge.ProviderAnthropic]; conflict { return nil, xerrors.Errorf("legacy CODER_AIBRIDGE_ANTHROPIC_KEY conflicts with indexed provider named %q; remove one or the other", aibridge.ProviderAnthropic) } + var pool *keypool.Pool + if key := cfg.LegacyAnthropic.Key.String(); key != "" { + var err error + pool, err = keypool.New([]string{key}, quartz.NewReal()) + if err != nil { + return nil, xerrors.Errorf("create legacy anthropic key pool: %w", err) + } + } providers = append(providers, aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{ Name: aibridge.ProviderAnthropic, BaseURL: cfg.LegacyAnthropic.BaseURL.String(), - Key: cfg.LegacyAnthropic.Key.String(), + KeyPool: pool, CircuitBreaker: cbConfig, SendActorHeaders: cfg.SendActorHeaders.Value(), }, getBedrockConfig(cfg.LegacyBedrock))) @@ -110,14 +120,13 @@ func buildProviders(cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) { if name == "" { name = p.Type } - // Currently, only the first key is used, if any. - // TODO(ssncferreira): pass a keypool.Pool instead. - var key string - if len(p.Keys) > 0 { - key = p.Keys[0] - } switch p.Type { case aibridge.ProviderOpenAI: + // TODO(ssncferreira): pass a keypool.Pool instead. + var key string + if len(p.Keys) > 0 { + key = p.Keys[0] + } providers = append(providers, aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{ Name: name, BaseURL: p.BaseURL, @@ -127,10 +136,18 @@ func buildProviders(cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) { SendActorHeaders: cfg.SendActorHeaders.Value(), })) case aibridge.ProviderAnthropic: + var pool *keypool.Pool + if len(p.Keys) > 0 { + var err error + pool, err = keypool.New(p.Keys, quartz.NewReal()) + if err != nil { + return nil, xerrors.Errorf("create anthropic key pool for provider %q: %w", name, err) + } + } providers = append(providers, aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{ Name: name, BaseURL: p.BaseURL, - Key: key, + KeyPool: pool, APIDumpDir: p.DumpDir, CircuitBreaker: cbConfig, SendActorHeaders: cfg.SendActorHeaders.Value(), From 669e1af0b43e0c1c41a6386ec278f4da39496dae Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 11:49:42 +0000 Subject: [PATCH 02/14] fix: avoid silent 200 on unknown pre-stream errors --- aibridge/intercept/messages/streaming.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aibridge/intercept/messages/streaming.go b/aibridge/intercept/messages/streaming.go index 0f3b605e21f51..4f0bf20a4b41c 100644 --- a/aibridge/intercept/messages/streaming.go +++ b/aibridge/intercept/messages/streaming.go @@ -551,8 +551,11 @@ newStream: if currentKey != nil && i.markKeyOnError(ctx, currentKey, stream.Err()) { continue newStream } - // Non-key error: relay it. - respErr := getErrorResponse(stream.Err()) + // Non-key error: relay it. Use mapStreamError so that + // unknown upstream errors (TCP reset, DNS failure, TLS + // error, deadline exceeded) are wrapped in a generic + // response instead of producing a silent HTTP 200. + respErr := i.mapStreamError(ctx, logger, stream.Err(), lastErr) if respErr != nil { interceptionErr = respErr if events.IsStreaming() { From e2d7ae64c01ca9faacd75df96a85a5612e0145b3 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 17:05:05 +0000 Subject: [PATCH 03/14] fix: set 502 status on unknown stream error responses --- aibridge/intercept/messages/streaming.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aibridge/intercept/messages/streaming.go b/aibridge/intercept/messages/streaming.go index 4f0bf20a4b41c..eada7a0a613ac 100644 --- a/aibridge/intercept/messages/streaming.go +++ b/aibridge/intercept/messages/streaming.go @@ -614,11 +614,11 @@ func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Lo // into known types (i.e. [shared.OverloadedError]). // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 // All it does is wrap the payload in an error - which is all we can return, currently. - return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), string(constant.ValueOf[constant.Error]()), 0, 0) + return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0) } if lastErr != nil { logger.Warn(ctx, "stream processing failed", slog.Error(lastErr)) - return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), string(constant.ValueOf[constant.Error]()), 0, 0) + return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0) } return nil } From 51b7b532ad69f78b9e44cca7b3d1aebc050a5fec Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 17:23:32 +0000 Subject: [PATCH 04/14] fix: read key state and cooldown under one lock --- aibridge/keypool/keypool.go | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/aibridge/keypool/keypool.go b/aibridge/keypool/keypool.go index e6221209c4b16..1fefd691693d9 100644 --- a/aibridge/keypool/keypool.go +++ b/aibridge/keypool/keypool.go @@ -115,21 +115,20 @@ func (k *Key) State() KeyState { return KeyStateValid } -// remainingCooldown returns the duration until the key's -// cooldown expires for temporary keys. Returns 0 for keys -// that are valid (no active cooldown) or permanent. -func (k *Key) remainingCooldown() time.Duration { +// stateAndCooldown returns the key's state and remaining +// cooldown as a single atomic snapshot. +func (k *Key) stateAndCooldown() (KeyState, time.Duration) { k.mu.RLock() defer k.mu.RUnlock() if k.permanent { - return 0 + return KeyStatePermanent, 0 } now := k.clock.Now() if now.Before(k.cooldownUntil) { - return k.cooldownUntil.Sub(now) + return KeyStateTemporary, k.cooldownUntil.Sub(now) } - return 0 + return KeyStateValid, 0 } // MarkTemporary marks the key as temporarily unavailable with @@ -186,13 +185,13 @@ func (p *Pool) exhaustionError() error { var retryAfter time.Duration var hasCooldown bool for i := range p.keys { - switch p.keys[i].State() { + state, cooldown := p.keys[i].stateAndCooldown() + switch state { // Recoverable now: signal transient with zero retry-after. case KeyStateValid: return &TransientExhaustionError{} // Recoverable later: track soonest remaining cooldown. case KeyStateTemporary: - cooldown := p.keys[i].remainingCooldown() if !hasCooldown || cooldown < retryAfter { retryAfter = cooldown hasCooldown = true From 45bdb2318248537eb6670dd4e79c8ee490f37c43 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 17:34:40 +0000 Subject: [PATCH 05/14] test: assert Retry-After and error envelope shape --- aibridge/intercept/messages/base_test.go | 1 + aibridge/intercept/messages/blocking_test.go | 3 +++ 2 files changed, 4 insertions(+) diff --git a/aibridge/intercept/messages/base_test.go b/aibridge/intercept/messages/base_test.go index 9d3c388de8084..c296ad95d3a12 100644 --- a/aibridge/intercept/messages/base_test.go +++ b/aibridge/intercept/messages/base_test.go @@ -1190,6 +1190,7 @@ func TestWriteUpstreamError(t *testing.T) { assert.Equal(t, tc.expectStatus, w.Code, "status code") assert.Equal(t, "application/json", w.Header().Get("Content-Type"), "Content-Type header") assert.Equal(t, tc.expectRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + assert.Contains(t, w.Body.String(), `"type":"error"`, "outer error envelope") if tc.expectBodyContains != "" { assert.Contains(t, w.Body.String(), tc.expectBodyContains, "response body") } diff --git a/aibridge/intercept/messages/blocking_test.go b/aibridge/intercept/messages/blocking_test.go index 51d77fb0d0e15..b0fd608281024 100644 --- a/aibridge/intercept/messages/blocking_test.go +++ b/aibridge/intercept/messages/blocking_test.go @@ -292,6 +292,7 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { expectedRequestCount int32 expectedSeenKeys []string expectedStatusCode int + expectedRetryAfter string expectedKeyStates []keypool.KeyState }{ { @@ -355,6 +356,7 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { expectedRequestCount: 3, expectedSeenKeys: []string{"k0", "k0", "k1"}, expectedStatusCode: http.StatusTooManyRequests, + expectedRetryAfter: "3", expectedKeyStates: []keypool.KeyState{ keypool.KeyStateTemporary, keypool.KeyStateTemporary, @@ -442,6 +444,7 @@ func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") + assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") seenKeysMu.Lock() defer seenKeysMu.Unlock() From f3dbcc0b434f9c22ff75557440801acce0a24bfc Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 17:51:40 +0000 Subject: [PATCH 06/14] fix: propagate upstream Retry-After on BYOK errors --- aibridge/intercept/messages/base.go | 2 +- aibridge/intercept/messages/blocking_test.go | 4 +++- aibridge/intercept/messages/streaming_test.go | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/aibridge/intercept/messages/base.go b/aibridge/intercept/messages/base.go index 387228220ef58..b5f4c8c34223b 100644 --- a/aibridge/intercept/messages/base.go +++ b/aibridge/intercept/messages/base.go @@ -581,7 +581,7 @@ func getErrorResponse(err error) *responseError { errType = string(detail.Type) } - return newErrorResponse(msg, errType, apierr.StatusCode, 0) + return newErrorResponse(msg, errType, apierr.StatusCode, keypool.ParseRetryAfter(apierr.Response)) } var _ error = &responseError{} diff --git a/aibridge/intercept/messages/blocking_test.go b/aibridge/intercept/messages/blocking_test.go index b0fd608281024..475ad607f42e2 100644 --- a/aibridge/intercept/messages/blocking_test.go +++ b/aibridge/intercept/messages/blocking_test.go @@ -184,7 +184,8 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { }, { // Given: BYOK with a single key returning 429. - // Then: 1 request, 429 response, no failover. + // Then: 1 request, 429 response, no failover, upstream + // Retry-After propagated to the client. name: "byok_no_failover", byokKey: "user-byok", responses: map[string]upstreamResponse{ @@ -202,6 +203,7 @@ func TestBlockingInterception_KeyFailover(t *testing.T) { }, expectedRequestCount: 1, expectedStatusCode: http.StatusTooManyRequests, + expectedRetryAfter: "5", }, } diff --git a/aibridge/intercept/messages/streaming_test.go b/aibridge/intercept/messages/streaming_test.go index bbcbd9f187053..6f573b380846a 100644 --- a/aibridge/intercept/messages/streaming_test.go +++ b/aibridge/intercept/messages/streaming_test.go @@ -209,7 +209,8 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { }, { // Given: BYOK with a single key returning 429. - // Then: 1 request, 429 response, no failover. + // Then: 1 request, 429 response, no failover, upstream + // Retry-After propagated to the client. name: "byok_no_failover", byokKey: "user-byok", responses: map[string]upstreamResponse{ @@ -227,6 +228,7 @@ func TestStreamingInterception_KeyFailover(t *testing.T) { }, expectedRequestCount: 1, expectedStatusCode: http.StatusTooManyRequests, + expectedRetryAfter: "5", }, } From d443e50de9438dcaf145beebaaaf3c8ba7fdd56e Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 17:56:14 +0000 Subject: [PATCH 07/14] fix: log actual cooldown applied to temporary keys --- aibridge/keypool/keymark.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aibridge/keypool/keymark.go b/aibridge/keypool/keymark.go index 8ce6e50813086..5b5fd37985734 100644 --- a/aibridge/keypool/keymark.go +++ b/aibridge/keypool/keymark.go @@ -23,6 +23,9 @@ func MarkKeyOnStatus( switch statusCode { case http.StatusTooManyRequests: cooldown := ParseRetryAfter(resp) + if cooldown <= 0 { + cooldown = defaultCooldown + } if key.MarkTemporary(cooldown) { logger.Warn(ctx, "key marked temporary", slog.F("provider", providerName), From 20def6044d81dac56dbe0fab0202a8e41f959d8d Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 18:13:16 +0000 Subject: [PATCH 08/14] chore: flag missing auth validation --- aibridge/intercept/messages/base.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/aibridge/intercept/messages/base.go b/aibridge/intercept/messages/base.go index b5f4c8c34223b..08ee1213d4f62 100644 --- a/aibridge/intercept/messages/base.go +++ b/aibridge/intercept/messages/base.go @@ -209,6 +209,9 @@ func (i *interceptionBase) isSmallFastModel() bool { // calls. BYOK auth is set here. Centralized auth is set // per-attempt by the failover loop. func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) { + // TODO(ssncferreira): validate auth is configured per + // https://github.com/coder/aibridge/issues/266. + // BYOK auth. if i.cfg.KeyPool == nil { if i.cfg.BYOKBearerToken != "" { From d5062ce885cfc9c01a423915f9f7ea7bca482370 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 18:41:48 +0000 Subject: [PATCH 09/14] fix: emit one upstream span per attempt in blocking --- aibridge/intercept/messages/blocking.go | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/aibridge/intercept/messages/blocking.go b/aibridge/intercept/messages/blocking.go index 7251d0c07c98a..9deb8ca43b15a 100644 --- a/aibridge/intercept/messages/blocking.go +++ b/aibridge/intercept/messages/blocking.go @@ -341,17 +341,25 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req return nil } -func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (_ *anthropic.Message, outErr error) { - ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) - defer tracing.EndSpanErr(span, &outErr) - +// newMessage routes between BYOK (single attempt) and centralized +// failover. +func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, error) { // BYOK: single attempt, no failover. if i.cfg.KeyPool == nil { - return svc.New(ctx, anthropic.MessageNewParams{}, i.withBody()) + return i.newMessageWithKey(ctx, svc) } return i.newMessageWithKeyFailover(ctx, svc) } +// newMessageWithKey performs a single upstream call. +func (i *BlockingInterception) newMessageWithKey(ctx context.Context, svc anthropic.MessageService, extraOpts ...option.RequestOption) (_ *anthropic.Message, outErr error) { + _, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...)) + defer tracing.EndSpanErr(span, &outErr) + + opts := append([]option.RequestOption{i.withBody()}, extraOpts...) + return svc.New(ctx, anthropic.MessageNewParams{}, opts...) +} + // newMessageWithKeyFailover walks the centralized key pool, // trying each key until one succeeds or the pool is exhausted. // Keys are marked temporary on 429 and permanent on 401/403. @@ -368,8 +376,7 @@ func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, sv return nil, err } - msg, err := svc.New(ctx, anthropic.MessageNewParams{}, - i.withBody(), + msg, err := i.newMessageWithKey(ctx, svc, option.WithAPIKey(key.Value()), // Disable SDK retries because the failover loop // handles retries via key rotation. From 7966cabceae87f1cbb6d80aecf0023579bc1b770 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 18:51:44 +0000 Subject: [PATCH 10/14] refactor: simplify blocking failover control flow --- aibridge/intercept/messages/blocking.go | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/aibridge/intercept/messages/blocking.go b/aibridge/intercept/messages/blocking.go index 9deb8ca43b15a..ad1a79e8838e5 100644 --- a/aibridge/intercept/messages/blocking.go +++ b/aibridge/intercept/messages/blocking.go @@ -382,14 +382,12 @@ func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, sv // handles retries via key rotation. option.WithMaxRetries(0), ) - if err == nil { - return msg, nil - } - // Mark the key based on the upstream response. - if !i.markKeyOnError(ctx, key, err) { - // Not a key-specific failure: return without - // trying another key. - return nil, err + // Key-specific failure: try the next key. + if i.markKeyOnError(ctx, key, err) { + continue } + // Either success (msg, nil) or a non-key error (nil, err): + // nothing to retry, return as-is. + return msg, err } } From 410391809d2c0ce14f0d224b69afc99cdbbb0915 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 19:06:33 +0000 Subject: [PATCH 11/14] fix: record unrecoverable stream errors --- aibridge/intercept/messages/streaming.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/aibridge/intercept/messages/streaming.go b/aibridge/intercept/messages/streaming.go index eada7a0a613ac..47ebeebf24e77 100644 --- a/aibridge/intercept/messages/streaming.go +++ b/aibridge/intercept/messages/streaming.go @@ -535,7 +535,8 @@ newStream: // Mid-stream error or logical error: events have // already streamed for this iteration, so the // error is relayed as an SSE event. - if respErr := i.mapStreamError(ctx, logger, stream.Err(), lastErr); respErr != nil { + streamErr := stream.Err() + if respErr := i.mapStreamError(ctx, logger, streamErr, lastErr); respErr != nil { interceptionErr = respErr payload, err := i.marshal(respErr) if err != nil { @@ -543,6 +544,11 @@ newStream: } else if err := events.Send(streamCtx, payload); err != nil { logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload)) } + } else if streamErr != nil { + // Unrecoverable (e.g., broken pipe, context + // canceled): can't relay to the client, but record + // the error so it isn't silently swallowed. + interceptionErr = streamErr } } else { // Pre-stream failure of this iteration. For From 83b1d159d19094947b91ac87a69c0dba56286d09 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Tue, 5 May 2026 19:22:07 +0000 Subject: [PATCH 12/14] refactor: simplify MarkKeyOnStatus signature and clarify PoolState use --- aibridge/intercept/messages/base.go | 2 +- aibridge/intercept/messages/base_test.go | 8 ++++---- aibridge/keypool/keymark.go | 11 +++++++---- aibridge/keypool/keymark_test.go | 6 ++++-- aibridge/keypool/keypool.go | 4 ++-- 5 files changed, 18 insertions(+), 13 deletions(-) diff --git a/aibridge/intercept/messages/base.go b/aibridge/intercept/messages/base.go index 08ee1213d4f62..e13b1ee606ef7 100644 --- a/aibridge/intercept/messages/base.go +++ b/aibridge/intercept/messages/base.go @@ -532,7 +532,7 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, return false } return keypool.MarkKeyOnStatus( - ctx, key, apiErr.StatusCode, apiErr.Response, + ctx, key, apiErr.Response, i.logger, i.providerName, ) } diff --git a/aibridge/intercept/messages/base_test.go b/aibridge/intercept/messages/base_test.go index c296ad95d3a12..dd03815be0223 100644 --- a/aibridge/intercept/messages/base_test.go +++ b/aibridge/intercept/messages/base_test.go @@ -1086,28 +1086,28 @@ func TestMarkKeyOnError(t *testing.T) { { // Rate-limited: temporary cooldown. name: "429_marks_temporary", - err: &anthropic.Error{StatusCode: http.StatusTooManyRequests}, + err: &anthropic.Error{StatusCode: http.StatusTooManyRequests, Response: &http.Response{StatusCode: http.StatusTooManyRequests}}, expectedReturn: true, expectedState: keypool.KeyStateTemporary, }, { // Auth failure: mark permanent. name: "401_marks_permanent", - err: &anthropic.Error{StatusCode: http.StatusUnauthorized}, + err: &anthropic.Error{StatusCode: http.StatusUnauthorized, Response: &http.Response{StatusCode: http.StatusUnauthorized}}, expectedReturn: true, expectedState: keypool.KeyStatePermanent, }, { // Auth forbidden: mark permanent. name: "403_marks_permanent", - err: &anthropic.Error{StatusCode: http.StatusForbidden}, + err: &anthropic.Error{StatusCode: http.StatusForbidden, Response: &http.Response{StatusCode: http.StatusForbidden}}, expectedReturn: true, expectedState: keypool.KeyStatePermanent, }, { // Server errors are not key-specific. name: "500_does_not_mark", - err: &anthropic.Error{StatusCode: http.StatusInternalServerError}, + err: &anthropic.Error{StatusCode: http.StatusInternalServerError, Response: &http.Response{StatusCode: http.StatusInternalServerError}}, expectedReturn: false, expectedState: keypool.KeyStateValid, }, diff --git a/aibridge/keypool/keymark.go b/aibridge/keypool/keymark.go index 5b5fd37985734..0f97543d7bd1a 100644 --- a/aibridge/keypool/keymark.go +++ b/aibridge/keypool/keymark.go @@ -9,17 +9,20 @@ import ( ) // MarkKeyOnStatus marks key based on a key-specific HTTP -// status code (429 for temporary, 401 or 403 for permanent). -// Returns true if the status was a key-specific failover -// trigger so callers can retry with the next key. +// status code from resp (429 for temporary, 401 or 403 for +// permanent). Returns true if the status was a key-specific +// failover trigger so callers can retry with the next key. func MarkKeyOnStatus( ctx context.Context, key *Key, - statusCode int, resp *http.Response, logger slog.Logger, providerName string, ) bool { + if resp == nil { + return false + } + statusCode := resp.StatusCode switch statusCode { case http.StatusTooManyRequests: cooldown := ParseRetryAfter(resp) diff --git a/aibridge/keypool/keymark_test.go b/aibridge/keypool/keymark_test.go index 07072228b53c8..9e4631409d2c7 100644 --- a/aibridge/keypool/keymark_test.go +++ b/aibridge/keypool/keymark_test.go @@ -94,7 +94,10 @@ func TestMarkKeyOnStatus(t *testing.T) { key, err := pool.Walker().Next() require.NoError(t, err) - resp := &http.Response{Header: make(http.Header)} + resp := &http.Response{ + StatusCode: tc.statusCode, + Header: make(http.Header), + } for k, v := range tc.headers { resp.Header.Set(k, v) } @@ -102,7 +105,6 @@ func TestMarkKeyOnStatus(t *testing.T) { got := keypool.MarkKeyOnStatus( context.Background(), key, - tc.statusCode, resp, // 401 and 403 cases legitimately log at error // level when marking a key permanent. diff --git a/aibridge/keypool/keypool.go b/aibridge/keypool/keypool.go index 1fefd691693d9..48a50d2a0efe8 100644 --- a/aibridge/keypool/keypool.go +++ b/aibridge/keypool/keypool.go @@ -207,8 +207,8 @@ func (p *Pool) exhaustionError() error { } // PoolState returns a snapshot of each key's state in the pool's -// original order. The result reflects the state at call time and -// is not updated after. Use Walker for the failover iteration path. +// original order, used by tests and other diagnostic callers. Use +// Walker for the failover iteration path. func (p *Pool) PoolState() []KeyState { states := make([]KeyState, len(p.keys)) for i := range p.keys { From 07054416931232e2e60a92c7661296a1b35e62f1 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Thu, 7 May 2026 10:34:24 +0000 Subject: [PATCH 13/14] refactor: rename keypool exhaustion errors and helpers --- aibridge/intercept/messages/base.go | 16 +++++------- aibridge/intercept/messages/base_test.go | 26 ++++--------------- aibridge/intercept/messages/blocking.go | 2 +- aibridge/intercept/messages/streaming.go | 22 ++++++++-------- aibridge/keypool/keypool.go | 28 ++++++++++----------- aibridge/keypool/keypool_test.go | 32 ++++++++++++------------ 6 files changed, 52 insertions(+), 74 deletions(-) diff --git a/aibridge/intercept/messages/base.go b/aibridge/intercept/messages/base.go index e13b1ee606ef7..14c276370e2a5 100644 --- a/aibridge/intercept/messages/base.go +++ b/aibridge/intercept/messages/base.go @@ -537,15 +537,11 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, ) } -// For centralized requests, mapExhaustionError translates a -// keypool exhaustion error into a developer-facing responseError -// shaped for the Anthropic API. Returns nil if err is not an -// exhaustion error. -func (i *interceptionBase) mapExhaustionError(err error) *responseError { - if i.cfg.KeyPool == nil { - return nil - } - var transient *keypool.TransientExhaustionError +// processKeyPoolError translates a keypool exhaustion error +// into a developer-facing responseError shaped for the Anthropic +// API. Returns nil if err is not an exhaustion error. +func processKeyPoolError(err error) *responseError { + var transient *keypool.TransientKeyPoolError switch { case errors.As(err, &transient): return newErrorResponse( @@ -554,7 +550,7 @@ func (i *interceptionBase) mapExhaustionError(err error) *responseError { http.StatusTooManyRequests, transient.RetryAfter, ) - case errors.Is(err, keypool.ErrPermanentExhaustion): + case errors.Is(err, keypool.ErrPermanentKeyPool): return newErrorResponse( "all configured keys failed authentication", string(constant.ValueOf[constant.APIError]()), diff --git a/aibridge/intercept/messages/base_test.go b/aibridge/intercept/messages/base_test.go index dd03815be0223..a6accceeaa00b 100644 --- a/aibridge/intercept/messages/base_test.go +++ b/aibridge/intercept/messages/base_test.go @@ -998,42 +998,34 @@ func TestFilterBedrockBetaFlags(t *testing.T) { } } -func TestMapExhaustionError(t *testing.T) { +func TestProcessKeyPoolError(t *testing.T) { t.Parallel() tests := []struct { name string - nilKeyPool bool err error expectedNil bool expectedStatus int expectedRetryAfter time.Duration }{ - { - // BYOK or no centralized pool: never maps. - name: "nil_keypool_returns_nil", - nilKeyPool: true, - err: &keypool.TransientExhaustionError{}, - expectedNil: true, - }, { // Transient with valid keys present: 429, no Retry-After. name: "transient_zero_retry_after", - err: &keypool.TransientExhaustionError{}, + err: &keypool.TransientKeyPoolError{}, expectedStatus: http.StatusTooManyRequests, expectedRetryAfter: 0, }, { // Transient with cooldown: 429, Retry-After set. name: "transient_with_retry_after", - err: &keypool.TransientExhaustionError{RetryAfter: 5 * time.Second}, + err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second}, expectedStatus: http.StatusTooManyRequests, expectedRetryAfter: 5 * time.Second, }, { // Permanent: 502 api_error. name: "permanent_returns_502", - err: keypool.ErrPermanentExhaustion, + err: keypool.ErrPermanentKeyPool, expectedStatus: http.StatusBadGateway, }, { @@ -1047,15 +1039,7 @@ func TestMapExhaustionError(t *testing.T) { for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - cfg := config.Anthropic{} - if !tc.nilKeyPool { - pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t)) - require.NoError(t, err) - cfg.KeyPool = pool - } - base := &interceptionBase{cfg: cfg} - - got := base.mapExhaustionError(tc.err) + got := processKeyPoolError(tc.err) if tc.expectedNil { require.Nil(t, got) return diff --git a/aibridge/intercept/messages/blocking.go b/aibridge/intercept/messages/blocking.go index ad1a79e8838e5..7956acf088d7d 100644 --- a/aibridge/intercept/messages/blocking.go +++ b/aibridge/intercept/messages/blocking.go @@ -114,7 +114,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req // The failover loop may return a keypool exhaustion // error. Check before the SDK-error path. - if keyErr := i.mapExhaustionError(err); keyErr != nil { + if keyErr := processKeyPoolError(err); keyErr != nil { i.writeUpstreamError(w, keyErr) return xerrors.Errorf("key pool exhausted: %w", err) } diff --git a/aibridge/intercept/messages/streaming.go b/aibridge/intercept/messages/streaming.go index 47ebeebf24e77..6143da2710f0a 100644 --- a/aibridge/intercept/messages/streaming.go +++ b/aibridge/intercept/messages/streaming.go @@ -175,23 +175,21 @@ newStream: var currentKey *keypool.Key if walker != nil { key, err := walker.Next() - if err != nil { + if respErr := processKeyPoolError(err); respErr != nil { // Pool exhausted in this iteration. Relay the // error to the client: as an SSE event if events // have already been sent, or by direct write // otherwise. - if respErr := i.mapExhaustionError(err); respErr != nil { - interceptionErr = respErr - if events.IsStreaming() { - payload, mErr := i.marshal(respErr) - if mErr != nil { - logger.Warn(ctx, "failed to marshal exhaustion error", slog.Error(mErr)) - } else if sErr := events.Send(streamCtx, payload); sErr != nil { - logger.Warn(ctx, "failed to relay exhaustion error", slog.Error(sErr)) - } - } else { - i.writeUpstreamError(w, respErr) + interceptionErr = respErr + if events.IsStreaming() { + payload, mErr := i.marshal(respErr) + if mErr != nil { + logger.Warn(ctx, "failed to marshal exhaustion error", slog.Error(mErr)) + } else if sErr := events.Send(streamCtx, payload); sErr != nil { + logger.Warn(ctx, "failed to relay exhaustion error", slog.Error(sErr)) } + } else { + i.writeUpstreamError(w, respErr) } break } diff --git a/aibridge/keypool/keypool.go b/aibridge/keypool/keypool.go index 48a50d2a0efe8..a2791f031deee 100644 --- a/aibridge/keypool/keypool.go +++ b/aibridge/keypool/keypool.go @@ -18,19 +18,19 @@ var ( ErrDuplicateKey = xerrors.New("duplicate key") ) -// ErrPermanentExhaustion is returned when every key in the +// ErrPermanentKeyPool is returned when every key in the // pool has been permanently marked unavailable. -var ErrPermanentExhaustion = xerrors.New("all keys permanently unavailable") +var ErrPermanentKeyPool = xerrors.New("all keys permanently unavailable") -// TransientExhaustionError is returned when no key is currently +// TransientKeyPoolError is returned when no key is currently // available but at least one will recover. RetryAfter is the // soonest remaining cooldown across the pool, or 0 if a key // just became valid mid-walk. -type TransientExhaustionError struct { +type TransientKeyPoolError struct { RetryAfter time.Duration } -func (e *TransientExhaustionError) Error() string { +func (e *TransientKeyPoolError) Error() string { return fmt.Sprintf("all keys exhausted (retry after %s)", e.RetryAfter) } @@ -176,12 +176,12 @@ func (k *Key) MarkPermanent() bool { return true } -// exhaustionError returns ErrPermanentExhaustion if every key -// is permanently unavailable, or *TransientExhaustionError if +// keyPoolError returns ErrPermanentKeyPool if every key +// is permanently unavailable, or *TransientKeyPoolError if // at least one key is temporarily unavailable. When multiple // keys are temporary, the smallest remaining cooldown is used // as the retry-after. -func (p *Pool) exhaustionError() error { +func (p *Pool) keyPoolError() error { var retryAfter time.Duration var hasCooldown bool for i := range p.keys { @@ -189,7 +189,7 @@ func (p *Pool) exhaustionError() error { switch state { // Recoverable now: signal transient with zero retry-after. case KeyStateValid: - return &TransientExhaustionError{} + return &TransientKeyPoolError{} // Recoverable later: track soonest remaining cooldown. case KeyStateTemporary: if !hasCooldown || cooldown < retryAfter { @@ -201,9 +201,9 @@ func (p *Pool) exhaustionError() error { } } if hasCooldown { - return &TransientExhaustionError{RetryAfter: retryAfter} + return &TransientKeyPoolError{RetryAfter: retryAfter} } - return ErrPermanentExhaustion + return ErrPermanentKeyPool } // PoolState returns a snapshot of each key's state in the pool's @@ -236,12 +236,12 @@ func (p *Pool) Walker() *Walker { // Next returns a Key handle for the next available key without // modifying the pool state. // -// Returns *TransientExhaustionError or ErrPermanentExhaustion +// Returns *TransientKeyPoolError or ErrPermanentKeyPool // when no more keys are available. func (w *Walker) Next() (*Key, error) { pool := w.pool if pool == nil { - return nil, ErrPermanentExhaustion + return nil, ErrPermanentKeyPool } for i := w.pos; i < len(pool.keys); i++ { @@ -255,5 +255,5 @@ func (w *Walker) Next() (*Key, error) { } // No keys available. - return nil, pool.exhaustionError() + return nil, pool.keyPoolError() } diff --git a/aibridge/keypool/keypool_test.go b/aibridge/keypool/keypool_test.go index bb544b7345248..0dc4cbdc240e6 100644 --- a/aibridge/keypool/keypool_test.go +++ b/aibridge/keypool/keypool_test.go @@ -51,7 +51,7 @@ func TestNewKeyPool(t *testing.T) { // No more keys available. _, err = walker.Next() - var transient *keypool.TransientExhaustionError + var transient *keypool.TransientKeyPoolError require.ErrorAs(t, err, &transient, "expected transient exhaustion: walker returned all valid keys, none marked permanent") }) } @@ -299,7 +299,7 @@ func TestWalkerNext(t *testing.T) { keys: []string{"key-0", "key-1", "key-2"}, setup: func(_ *testing.T, _ *keypool.Pool) {}, expectedValid: []string{"key-0", "key-1", "key-2"}, - expectedErr: &keypool.TransientExhaustionError{}, + expectedErr: &keypool.TransientKeyPoolError{}, }, { // Given: key-0: temporary, key-1: valid, key-2: valid. @@ -312,7 +312,7 @@ func TestWalkerNext(t *testing.T) { key.MarkTemporary(60 * time.Second) }, expectedValid: []string{"key-1", "key-2"}, - expectedErr: &keypool.TransientExhaustionError{}, + expectedErr: &keypool.TransientKeyPoolError{}, }, { // Given: key-0: permanent, key-1: permanent, key-2: valid. @@ -329,7 +329,7 @@ func TestWalkerNext(t *testing.T) { key1.MarkPermanent() }, expectedValid: []string{"key-2"}, - expectedErr: &keypool.TransientExhaustionError{}, + expectedErr: &keypool.TransientKeyPoolError{}, }, { // Given: key-0: temporary (30s), key-1: valid. @@ -344,7 +344,7 @@ func TestWalkerNext(t *testing.T) { }, advance: 35 * time.Second, expectedValid: []string{"key-0", "key-1"}, - expectedErr: &keypool.TransientExhaustionError{}, + expectedErr: &keypool.TransientKeyPoolError{}, }, { // Given: key-0: temporary (zero, default 60s), key-1: valid. @@ -359,7 +359,7 @@ func TestWalkerNext(t *testing.T) { }, advance: 50 * time.Second, expectedValid: []string{"key-1"}, - expectedErr: &keypool.TransientExhaustionError{}, + expectedErr: &keypool.TransientKeyPoolError{}, }, { // Given: key-0: temporary (zero, default 60s), key-1: valid. @@ -374,7 +374,7 @@ func TestWalkerNext(t *testing.T) { }, advance: 65 * time.Second, expectedValid: []string{"key-0", "key-1"}, - expectedErr: &keypool.TransientExhaustionError{}, + expectedErr: &keypool.TransientKeyPoolError{}, }, { // Given: key-0: temporary (negative, default 60s), key-1: valid. @@ -389,7 +389,7 @@ func TestWalkerNext(t *testing.T) { }, advance: 65 * time.Second, expectedValid: []string{"key-0", "key-1"}, - expectedErr: &keypool.TransientExhaustionError{}, + expectedErr: &keypool.TransientKeyPoolError{}, }, { // Given: key-0: temporary (60s), then marked again with shorter cooldown (10s). @@ -405,7 +405,7 @@ func TestWalkerNext(t *testing.T) { }, advance: 15 * time.Second, expectedValid: []string{}, - expectedErr: &keypool.TransientExhaustionError{RetryAfter: 45 * time.Second}, + expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 45 * time.Second}, }, { // Given: key-0: temporary (60s), then marked again with shorter cooldown (10s). @@ -421,7 +421,7 @@ func TestWalkerNext(t *testing.T) { }, advance: 65 * time.Second, expectedValid: []string{"key-0"}, - expectedErr: &keypool.TransientExhaustionError{}, + expectedErr: &keypool.TransientKeyPoolError{}, }, { // Given: key-0: temporary (60s), key-1: temporary (10s), key-2: temporary (30s). @@ -442,7 +442,7 @@ func TestWalkerNext(t *testing.T) { key2.MarkTemporary(30 * time.Second) }, expectedValid: []string{}, - expectedErr: &keypool.TransientExhaustionError{RetryAfter: 10 * time.Second}, + expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 10 * time.Second}, }, { // Given: key-0: temporary, key-1: temporary. @@ -459,7 +459,7 @@ func TestWalkerNext(t *testing.T) { key1.MarkTemporary(60 * time.Second) }, expectedValid: []string{}, - expectedErr: &keypool.TransientExhaustionError{RetryAfter: 60 * time.Second}, + expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second}, }, { // Given: key-0: permanent, key-1: permanent. @@ -476,7 +476,7 @@ func TestWalkerNext(t *testing.T) { key1.MarkPermanent() }, expectedValid: []string{}, - expectedErr: keypool.ErrPermanentExhaustion, + expectedErr: keypool.ErrPermanentKeyPool, }, { // Given: key-0: permanent, key-1: temporary, key-2: permanent. @@ -496,7 +496,7 @@ func TestWalkerNext(t *testing.T) { key2.MarkPermanent() }, expectedValid: []string{}, - expectedErr: &keypool.TransientExhaustionError{RetryAfter: 60 * time.Second}, + expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second}, }, } @@ -523,9 +523,9 @@ func TestWalkerNext(t *testing.T) { // After all expected keys, the walker should be exhausted. _, err = walker.Next() - var wantTransient *keypool.TransientExhaustionError + var wantTransient *keypool.TransientKeyPoolError if errors.As(tc.expectedErr, &wantTransient) { - var got *keypool.TransientExhaustionError + var got *keypool.TransientKeyPoolError require.ErrorAs(t, err, &got) assert.Equal(t, wantTransient.RetryAfter, got.RetryAfter) } else { From 26829421681d70549a392a5ef9a37ff37128da19 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Thu, 7 May 2026 11:17:25 +0000 Subject: [PATCH 14/14] refactor: update log level --- aibridge/keypool/keymark.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aibridge/keypool/keymark.go b/aibridge/keypool/keymark.go index 0f97543d7bd1a..9b00bb400ac49 100644 --- a/aibridge/keypool/keymark.go +++ b/aibridge/keypool/keymark.go @@ -30,7 +30,7 @@ func MarkKeyOnStatus( cooldown = defaultCooldown } if key.MarkTemporary(cooldown) { - logger.Warn(ctx, "key marked temporary", + logger.Info(ctx, "key marked temporary", slog.F("provider", providerName), slog.F("api_key_hint", utils.MaskSecret(key.Value())), slog.F("status", statusCode), @@ -39,7 +39,7 @@ func MarkKeyOnStatus( return true case http.StatusUnauthorized, http.StatusForbidden: if key.MarkPermanent() { - logger.Error(ctx, "key marked permanent", + logger.Warn(ctx, "key marked permanent", slog.F("provider", providerName), slog.F("api_key_hint", utils.MaskSecret(key.Value())), slog.F("status", statusCode))