diff --git a/aibridge/intercept/chatcompletions/base.go b/aibridge/intercept/chatcompletions/base.go index 2c42c9c9dfa73..87e72d8e81362 100644 --- a/aibridge/intercept/chatcompletions/base.go +++ b/aibridge/intercept/chatcompletions/base.go @@ -9,12 +9,10 @@ import ( "net/http" "strconv" "strings" - "time" "github.com/google/uuid" "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" - "github.com/openai/openai-go/v3/shared" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" @@ -27,7 +25,6 @@ import ( "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/recorder" "github.com/coder/coder/v2/aibridge/tracing" - "github.com/coder/coder/v2/aibridge/utils" "github.com/coder/quartz" ) @@ -189,7 +186,7 @@ func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) { } // writeUpstreamError marshals and writes a given error. -func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) { +func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *intercept.ResponseError) { if oaiErr == nil { return } @@ -235,33 +232,6 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, ) } -// ProcessKeyPoolError translates a keypool exhaustion error -// into a developer-facing responseError shaped for the OpenAI -// API. Returns nil if err is not an exhaustion error. -func ProcessKeyPoolError(err error) *ResponseError { - var transient *keypool.TransientKeyPoolError - switch { - case errors.As(err, &transient): - return newErrorResponse( - "all configured keys are rate-limited", - intercept.OpenAIErrTypeRateLimit, - intercept.OpenAIErrCodeRateLimit, - http.StatusTooManyRequests, - transient.RetryAfter, - ) - case errors.Is(err, keypool.ErrPermanentKeyPool): - return newErrorResponse( - "all configured keys failed authentication", - intercept.OpenAIErrTypeAPI, - intercept.OpenAIErrCodeServer, - http.StatusBadGateway, - 0, - ) - default: - return nil - } -} - func (i *interceptionBase) hasInjectableTools() bool { return i.mcpProxy != nil && len(i.mcpProxy.ListTools()) > 0 } @@ -292,48 +262,3 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 { return in.PromptTokens /* The aggregated number of text input tokens used, including cached tokens. */ - in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */ } - -func getErrorResponse(err error) *ResponseError { - var apiErr *openai.Error - if !errors.As(err, &apiErr) { - return nil - } - return newErrorResponse(apiErr.Message, apiErr.Type, apiErr.Code, apiErr.StatusCode, keypool.ParseRetryAfter(apiErr.Response)) -} - -var _ error = &ResponseError{} - -type ResponseError struct { - ErrorObject *shared.ErrorObject `json:"error"` - StatusCode int `json:"-"` - RetryAfter time.Duration `json:"-"` -} - -func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError { - return &ResponseError{ - ErrorObject: &shared.ErrorObject{ - Code: code, - Message: msg, - Type: errType, - }, - StatusCode: status, - RetryAfter: retryAfter, - } -} - -func (e *ResponseError) Error() string { - if e.ErrorObject == nil { - return "" - } - return e.ErrorObject.Message -} - -// ToResponse marshals e into an *http.Response shaped for the -// OpenAI API. -func (e *ResponseError) ToResponse() *http.Response { - body, err := json.Marshal(e) - if err != nil { - body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`) - } - return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body) -} diff --git a/aibridge/intercept/chatcompletions/base_internal_test.go b/aibridge/intercept/chatcompletions/base_internal_test.go index d673ec1fe4532..1af6054cfa1aa 100644 --- a/aibridge/intercept/chatcompletions/base_internal_test.go +++ b/aibridge/intercept/chatcompletions/base_internal_test.go @@ -14,6 +14,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/utils" "github.com/coder/quartz" @@ -86,59 +87,6 @@ func TestScanForCorrelatingToolCallID(t *testing.T) { } } -func TestProcessKeyPoolError(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - err error - expectedNil bool - expectedStatus int - expectedRetryAfter time.Duration - }{ - { - // Transient with valid keys present: 429, no Retry-After. - name: "transient_zero_retry_after", - err: &keypool.TransientKeyPoolError{}, - expectedStatus: http.StatusTooManyRequests, - expectedRetryAfter: 0, - }, - { - // Transient with cooldown: 429, Retry-After set. - name: "transient_with_retry_after", - err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second}, - expectedStatus: http.StatusTooManyRequests, - expectedRetryAfter: 5 * time.Second, - }, - { - // Permanent: 502 api_error. - name: "permanent_returns_502", - err: keypool.ErrPermanentKeyPool, - expectedStatus: http.StatusBadGateway, - }, - { - // Anything else: not a pool-exhaustion error. - name: "non_pool_exhaustion_error_returns_nil", - err: xerrors.New("some other error"), - expectedNil: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got := ProcessKeyPoolError(tc.err) - if tc.expectedNil { - require.Nil(t, got) - return - } - require.NotNil(t, got) - assert.Equal(t, tc.expectedStatus, got.StatusCode) - assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter) - }) - } -} - func TestMarkKeyOnError(t *testing.T) { t.Parallel() @@ -190,8 +138,8 @@ func TestMarkKeyOnError(t *testing.T) { t.Parallel() pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t)) require.NoError(t, err) - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) base := &interceptionBase{cfg: config.OpenAI{KeyPool: pool}, logger: slog.Make()} @@ -207,7 +155,7 @@ func TestWriteUpstreamError(t *testing.T) { tests := []struct { name string - respErr *ResponseError + respErr *intercept.ResponseError expectStatus int // Empty string means the header should be absent. expectRetryAfter string @@ -217,42 +165,42 @@ func TestWriteUpstreamError(t *testing.T) { { // Standard error: status, code, and JSON body written. name: "writes_status_and_body", - respErr: newErrorResponse("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0), + respErr: intercept.NewResponseError("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0), expectStatus: http.StatusBadGateway, expectBodyContains: `"upstream failed"`, }, { // OpenAI envelope: the code field round-trips into the body. name: "writes_code_field", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0), expectStatus: http.StatusTooManyRequests, expectBodyContains: `"rate_limit_exceeded"`, }, { // Whole-second retryAfter: emitted as integer seconds. name: "retry_after_in_seconds", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "60", }, { // 500ms rounds up to Retry-After: 1. name: "retry_after_500ms_rounds_up_to_one", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "1", }, { // 200ms rounds up to Retry-After: 1. name: "retry_after_200ms_rounds_up_to_one", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "1", }, { // Negative retryAfter: header omitted. name: "negative_retry_after_omits_header", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "", }, diff --git a/aibridge/intercept/chatcompletions/blocking.go b/aibridge/intercept/chatcompletions/blocking.go index ce18d3d654cec..95d065ce5b3ec 100644 --- a/aibridge/intercept/chatcompletions/blocking.go +++ b/aibridge/intercept/chatcompletions/blocking.go @@ -3,6 +3,7 @@ package chatcompletions import ( "context" "encoding/json" + "errors" "net/http" "strings" "time" @@ -19,6 +20,7 @@ import ( aibcontext "github.com/coder/coder/v2/aibridge/context" "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/intercept/eventstream" + "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/recorder" "github.com/coder/coder/v2/aibridge/tracing" @@ -224,12 +226,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req // The failover loop may return a keypool exhaustion // error. Check before the SDK-error path. - if keyErr := ProcessKeyPoolError(err); keyErr != nil { - i.writeUpstreamError(w, keyErr) + var keyPoolErr *keypool.Error + if errors.As(err, &keyPoolErr) { + i.writeUpstreamError(w, intercept.ResponseErrorFromKeyPool(keyPoolErr)) return xerrors.Errorf("key pool exhausted: %w", err) } - if apiErr := getErrorResponse(err); apiErr != nil { + if apiErr := intercept.ResponseErrorFromAPIError(err); apiErr != nil { i.writeUpstreamError(w, apiErr) return xerrors.Errorf("openai API error: %w", err) } @@ -293,9 +296,9 @@ func (i *BlockingInterception) newChatCompletionWithKeyFailover(ctx context.Cont // success, the last tried key on failure) in the upstack PR. walker := i.cfg.KeyPool.Walker() for { - key, err := walker.Next() - if err != nil { - return nil, err + key, keyPoolErr := walker.Next() + if keyPoolErr != nil { + return nil, keyPoolErr } requestOpts := append([]option.RequestOption{}, opts...) diff --git a/aibridge/intercept/chatcompletions/streaming.go b/aibridge/intercept/chatcompletions/streaming.go index dea799853aa2b..581ab49d034c1 100644 --- a/aibridge/intercept/chatcompletions/streaming.go +++ b/aibridge/intercept/chatcompletions/streaming.go @@ -143,8 +143,9 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re var opts []option.RequestOption var currentKey *keypool.Key if walker != nil { - key, err := walker.Next() - if respErr := ProcessKeyPoolError(err); respErr != nil { + key, keyPoolErr := walker.Next() + if keyPoolErr != nil { + respErr := intercept.ResponseErrorFromKeyPool(keyPoolErr) // Pool exhausted in this iteration. Relay the // error to the client: as an SSE event if events // have already been sent, or by direct write @@ -470,17 +471,17 @@ func (i *StreamingInterception) newStream(ctx context.Context, svc openai.ChatCo } // mapStreamError converts a mid-stream upstream error or -// processing error into a relayable responseError. Returns nil +// processing error into a relayable ResponseError. Returns nil // when the error is unrecoverable, in which case nothing can be // relayed back. -func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *ResponseError { +func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *intercept.ResponseError { if streamErr != nil { if eventstream.IsUnrecoverableError(streamErr) { logger.Debug(ctx, "stream terminated", slog.Error(streamErr)) // We can't reflect an error back if there's a connection error or the request context was canceled. return nil } - if oaiErr := getErrorResponse(streamErr); oaiErr != nil { + if oaiErr := intercept.ResponseErrorFromAPIError(streamErr); oaiErr != nil { logger.Warn(ctx, "openai stream error", slog.Error(streamErr)) return oaiErr } @@ -489,11 +490,11 @@ func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Lo // into known types (i.e. [shared.OverloadedError]). // See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171 // All it does is wrap the payload in an error - which is all we can return, currently. - return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0) + return intercept.NewResponseError(fmt.Sprintf("unknown stream error: %s", streamErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0) } if lastErr != nil { logger.Warn(ctx, "stream processing failed", slog.Error(lastErr)) - return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0) + return intercept.NewResponseError(fmt.Sprintf("processing error: %s", lastErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0) } return nil } diff --git a/aibridge/intercept/messages/base.go b/aibridge/intercept/messages/base.go index 3ccb422566996..e35e2a9726175 100644 --- a/aibridge/intercept/messages/base.go +++ b/aibridge/intercept/messages/base.go @@ -583,32 +583,36 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, ) } -// ProcessKeyPoolError translates a keypool exhaustion error -// into a developer-facing responseError shaped for the Anthropic -// API. Returns nil if err is not an exhaustion error. -func ProcessKeyPoolError(err error) *ResponseError { - var transient *keypool.TransientKeyPoolError - switch { - case errors.As(err, &transient): - return newErrorResponse( - "all configured keys are rate-limited", +// ResponseErrorFromKeyPool translates a *keypool.Error into +// a developer-facing ResponseError shaped for the Anthropic API. +func ResponseErrorFromKeyPool(keyPoolErr *keypool.Error) *ResponseError { + switch keyPoolErr.Kind { + case keypool.ErrorKindPermanent: + return newResponseError( + keyPoolErr.Error(), + string(constant.ValueOf[constant.APIError]()), + http.StatusBadGateway, + keyPoolErr.RetryAfter, + ) + case keypool.ErrorKindRateLimited: + return newResponseError( + keyPoolErr.Error(), string(constant.ValueOf[constant.RateLimitError]()), http.StatusTooManyRequests, - transient.RetryAfter, + keyPoolErr.RetryAfter, ) - case errors.Is(err, keypool.ErrPermanentKeyPool): - return newErrorResponse( - "all configured keys failed authentication", + default: + // Fall back to a generic 502. + return newResponseError( + keyPoolErr.Error(), string(constant.ValueOf[constant.APIError]()), http.StatusBadGateway, - 0, + keyPoolErr.RetryAfter, ) - default: - return nil } } -func getErrorResponse(err error) *ResponseError { +func responseErrorFromAPIError(err error) *ResponseError { var apierr *anthropic.Error if !errors.As(err, &apierr) { return nil @@ -626,7 +630,7 @@ func getErrorResponse(err error) *ResponseError { errType = string(detail.Type) } - return newErrorResponse(msg, errType, apierr.StatusCode, keypool.ParseRetryAfter(apierr.Response)) + return newResponseError(msg, errType, apierr.StatusCode, keypool.ParseRetryAfter(apierr.Response)) } var _ error = &ResponseError{} @@ -638,7 +642,7 @@ type ResponseError struct { RetryAfter time.Duration `json:"-"` } -func newErrorResponse(msg, errType string, status int, retryAfter time.Duration) *ResponseError { +func newResponseError(msg, errType string, status int, retryAfter time.Duration) *ResponseError { return &ResponseError{ ErrorResponse: &shared.ErrorResponse{ Error: shared.ErrorObjectUnion{ diff --git a/aibridge/intercept/messages/base_internal_test.go b/aibridge/intercept/messages/base_internal_test.go index 9605d97966501..ef130deca1cce 100644 --- a/aibridge/intercept/messages/base_internal_test.go +++ b/aibridge/intercept/messages/base_internal_test.go @@ -1061,52 +1061,41 @@ func TestFilterBedrockBetaFlags(t *testing.T) { } } -func TestProcessKeyPoolError(t *testing.T) { +func TestResponseErrorFromKeyPool(t *testing.T) { t.Parallel() tests := []struct { name string - err error - expectedNil bool + keyPoolErr *keypool.Error expectedStatus int expectedRetryAfter time.Duration }{ { - // Transient with valid keys present: 429, no Retry-After. - name: "transient_zero_retry_after", - err: &keypool.TransientKeyPoolError{}, + // Rate-limited with no cooldown: 429, no Retry-After. + name: "rate_limited_zero_retry_after", + keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, expectedStatus: http.StatusTooManyRequests, expectedRetryAfter: 0, }, { - // Transient with cooldown: 429, Retry-After set. - name: "transient_with_retry_after", - err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second}, + // Rate-limited with cooldown: 429, Retry-After set. + name: "rate_limited_with_retry_after", + keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second}, expectedStatus: http.StatusTooManyRequests, expectedRetryAfter: 5 * time.Second, }, { // Permanent: 502 api_error. name: "permanent_returns_502", - err: keypool.ErrPermanentKeyPool, + keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindPermanent}, expectedStatus: http.StatusBadGateway, }, - { - // Anything else: not a pool-exhaustion error. - name: "non_pool_exhaustion_error_returns_nil", - err: xerrors.New("some other error"), - expectedNil: true, - }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { t.Parallel() - got := ProcessKeyPoolError(tc.err) - if tc.expectedNil { - require.Nil(t, got) - return - } + got := ResponseErrorFromKeyPool(tc.keyPoolErr) require.NotNil(t, got) assert.Equal(t, tc.expectedStatus, got.StatusCode) assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter) @@ -1165,8 +1154,8 @@ func TestMarkKeyOnError(t *testing.T) { t.Parallel() pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t)) require.NoError(t, err) - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) base := &interceptionBase{cfg: config.Anthropic{KeyPool: pool}, logger: slog.Make()} @@ -1192,35 +1181,35 @@ func TestWriteUpstreamError(t *testing.T) { { // Standard error: status and JSON body written. name: "writes_status_and_body", - respErr: newErrorResponse("upstream failed", "api_error", http.StatusBadGateway, 0), + respErr: newResponseError("upstream failed", "api_error", http.StatusBadGateway, 0), expectStatus: http.StatusBadGateway, expectBodyContains: `"upstream failed"`, }, { // Whole-second retryAfter: emitted as integer seconds. name: "retry_after_in_seconds", - respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 60*time.Second), + respErr: newResponseError("rate limited", "rate_limit_error", http.StatusTooManyRequests, 60*time.Second), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "60", }, { // 500ms rounds up to Retry-After: 1. name: "retry_after_500ms_rounds_up_to_one", - respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 500*time.Millisecond), + respErr: newResponseError("rate limited", "rate_limit_error", http.StatusTooManyRequests, 500*time.Millisecond), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "1", }, { // 200ms rounds up to Retry-After: 1. name: "retry_after_200ms_rounds_up_to_one", - respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 200*time.Millisecond), + respErr: newResponseError("rate limited", "rate_limit_error", http.StatusTooManyRequests, 200*time.Millisecond), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "1", }, { // Negative retryAfter: header omitted. name: "negative_retry_after_omits_header", - respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, -1*time.Second), + respErr: newResponseError("rate limited", "rate_limit_error", http.StatusTooManyRequests, -1*time.Second), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "", }, diff --git a/aibridge/intercept/messages/blocking.go b/aibridge/intercept/messages/blocking.go index 670231790e7fb..e91f80feb9e6e 100644 --- a/aibridge/intercept/messages/blocking.go +++ b/aibridge/intercept/messages/blocking.go @@ -2,6 +2,7 @@ package messages import ( "context" + "errors" "fmt" "net/http" "time" @@ -20,6 +21,7 @@ import ( aibcontext "github.com/coder/coder/v2/aibridge/context" "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/intercept/eventstream" + "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/recorder" "github.com/coder/coder/v2/aibridge/tracing" @@ -114,12 +116,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req // The failover loop may return a keypool exhaustion // error. Check before the SDK-error path. - if keyErr := ProcessKeyPoolError(err); keyErr != nil { - i.writeUpstreamError(w, keyErr) + var keyPoolErr *keypool.Error + if errors.As(err, &keyPoolErr) { + i.writeUpstreamError(w, ResponseErrorFromKeyPool(keyPoolErr)) return xerrors.Errorf("key pool exhausted: %w", err) } - if antErr := getErrorResponse(err); antErr != nil { + if antErr := responseErrorFromAPIError(err); antErr != nil { i.writeUpstreamError(w, antErr) return xerrors.Errorf("anthropic API error: %w", err) } @@ -369,9 +372,9 @@ func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, sv // success, the last tried key on failure) in the upstack PR. walker := i.cfg.KeyPool.Walker() for { - key, err := walker.Next() - if err != nil { - return nil, err + key, keyPoolErr := walker.Next() + if keyPoolErr != nil { + return nil, keyPoolErr } msg, err := i.newMessageWithKey(ctx, svc, diff --git a/aibridge/intercept/messages/streaming.go b/aibridge/intercept/messages/streaming.go index 8186a6d010c5e..475f32c99c459 100644 --- a/aibridge/intercept/messages/streaming.go +++ b/aibridge/intercept/messages/streaming.go @@ -174,12 +174,13 @@ newStream: var streamOpts []option.RequestOption var currentKey *keypool.Key if walker != nil { - key, err := walker.Next() - if respErr := ProcessKeyPoolError(err); respErr != nil { + key, keyPoolErr := walker.Next() + if keyPoolErr != nil { // Pool exhausted in this iteration. Relay the // error to the client: as an SSE event if events // have already been sent, or by direct write // otherwise. + respErr := ResponseErrorFromKeyPool(keyPoolErr) interceptionErr = respErr if events.IsStreaming() { payload, mErr := i.marshal(respErr) @@ -607,7 +608,7 @@ func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Lo // We can't reflect an error back if there's a connection error or the request context was canceled. return nil } - if antErr := getErrorResponse(streamErr); antErr != nil { + if antErr := responseErrorFromAPIError(streamErr); antErr != nil { logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr)) return antErr } @@ -616,11 +617,11 @@ func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Lo // into known types (i.e. [shared.OverloadedError]). // See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174 // All it does is wrap the payload in an error - which is all we can return, currently. - return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0) + return newResponseError(fmt.Sprintf("unknown stream error: %s", streamErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0) } if lastErr != nil { logger.Warn(ctx, "stream processing failed", slog.Error(lastErr)) - return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0) + return newResponseError(fmt.Sprintf("processing error: %s", lastErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0) } return nil } diff --git a/aibridge/intercept/openai_errors.go b/aibridge/intercept/openai_errors.go index 92e13fd02f9c4..faf2e19e3e023 100644 --- a/aibridge/intercept/openai_errors.go +++ b/aibridge/intercept/openai_errors.go @@ -1,5 +1,18 @@ package intercept +import ( + "encoding/json" + "errors" + "net/http" + "time" + + "github.com/openai/openai-go/v3" + "github.com/openai/openai-go/v3/shared" + + "github.com/coder/coder/v2/aibridge/keypool" + "github.com/coder/coder/v2/aibridge/utils" +) + // OpenAI error type and code constants used by the chatcompletions // and responses interceptors. The OpenAI Go SDK does not expose // these as typed constants, so we define our own. @@ -12,3 +25,89 @@ const ( OpenAIErrCodeServer = "server_error" OpenAIErrCodeRateLimit = "rate_limit_exceeded" ) + +var _ error = &ResponseError{} + +// ResponseError is the OpenAI-shaped error envelope returned to +// clients. StatusCode and RetryAfter map to HTTP headers, not JSON +// fields. The chatcompletions and responses interceptors both +// use the same response error format. +type ResponseError struct { + ErrorObject *shared.ErrorObject `json:"error"` + StatusCode int `json:"-"` + RetryAfter time.Duration `json:"-"` +} + +// NewResponseError builds a ResponseError with the OpenAI-shaped +// envelope. errType and code should be one of the OpenAIErrType* +// and OpenAIErrCode* constants defined above. +func NewResponseError(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError { + return &ResponseError{ + ErrorObject: &shared.ErrorObject{ + Code: code, + Message: msg, + Type: errType, + }, + StatusCode: status, + RetryAfter: retryAfter, + } +} + +func (e *ResponseError) Error() string { + if e.ErrorObject == nil { + return "" + } + return e.ErrorObject.Message +} + +// ToResponse marshals e into an *http.Response shaped for the +// OpenAI API. +func (e *ResponseError) ToResponse() *http.Response { + body, err := json.Marshal(e) + if err != nil { + body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`) + } + return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body) +} + +// ResponseErrorFromKeyPool translates a *keypool.Error into +// a developer-facing ResponseError shaped for the OpenAI API. +func ResponseErrorFromKeyPool(keyPoolErr *keypool.Error) *ResponseError { + switch keyPoolErr.Kind { + case keypool.ErrorKindPermanent: + return NewResponseError( + keyPoolErr.Error(), + OpenAIErrTypeAPI, + OpenAIErrCodeServer, + http.StatusBadGateway, + keyPoolErr.RetryAfter, + ) + case keypool.ErrorKindRateLimited: + return NewResponseError( + keyPoolErr.Error(), + OpenAIErrTypeRateLimit, + OpenAIErrCodeRateLimit, + http.StatusTooManyRequests, + keyPoolErr.RetryAfter, + ) + default: + // Fall back to a generic 502. + return NewResponseError( + keyPoolErr.Error(), + OpenAIErrTypeAPI, + OpenAIErrCodeServer, + http.StatusBadGateway, + keyPoolErr.RetryAfter, + ) + } +} + +// ResponseErrorFromAPIError converts an OpenAI SDK error into a +// ResponseError. Returns nil if err is not an *openai.Error. +func ResponseErrorFromAPIError(err error) *ResponseError { + var apiErr *openai.Error + if !errors.As(err, &apiErr) { + return nil + } + return NewResponseError(apiErr.Message, apiErr.Type, apiErr.Code, apiErr.StatusCode, keypool.ParseRetryAfter(apiErr.Response)) +} diff --git a/aibridge/intercept/openai_errors_test.go b/aibridge/intercept/openai_errors_test.go new file mode 100644 index 0000000000000..9b49c1e43ab80 --- /dev/null +++ b/aibridge/intercept/openai_errors_test.go @@ -0,0 +1,55 @@ +package intercept_test + +import ( + "net/http" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/aibridge/intercept" + "github.com/coder/coder/v2/aibridge/keypool" +) + +func TestResponseErrorFromKeyPool(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + keyPoolErr *keypool.Error + expectedStatus int + expectedRetryAfter time.Duration + }{ + { + // Rate-limited with no cooldown: 429, no Retry-After. + name: "rate_limited_zero_retry_after", + keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, + expectedStatus: http.StatusTooManyRequests, + expectedRetryAfter: 0, + }, + { + // Rate-limited with cooldown: 429, Retry-After set. + name: "rate_limited_with_retry_after", + keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second}, + expectedStatus: http.StatusTooManyRequests, + expectedRetryAfter: 5 * time.Second, + }, + { + // Permanent: 502 api_error. + name: "permanent_returns_502", + keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindPermanent}, + expectedStatus: http.StatusBadGateway, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := intercept.ResponseErrorFromKeyPool(tc.keyPoolErr) + require.NotNil(t, got) + assert.Equal(t, tc.expectedStatus, got.StatusCode) + assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter) + }) + } +} diff --git a/aibridge/intercept/responses/base.go b/aibridge/intercept/responses/base.go index 33986e1acf374..4be82c64b0b29 100644 --- a/aibridge/intercept/responses/base.go +++ b/aibridge/intercept/responses/base.go @@ -19,7 +19,6 @@ import ( "github.com/openai/openai-go/v3" "github.com/openai/openai-go/v3/option" "github.com/openai/openai-go/v3/responses" - "github.com/openai/openai-go/v3/shared" "github.com/openai/openai-go/v3/shared/constant" "github.com/tidwall/gjson" "go.opentelemetry.io/otel/attribute" @@ -35,7 +34,6 @@ import ( "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/recorder" "github.com/coder/coder/v2/aibridge/tracing" - "github.com/coder/coder/v2/aibridge/utils" "github.com/coder/quartz" ) @@ -143,7 +141,7 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http. } // writeUpstreamError marshals and writes a given error. -func (i *responsesInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) { +func (i *responsesInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *intercept.ResponseError) { if oaiErr == nil { return } @@ -189,70 +187,6 @@ func (i *responsesInterceptionBase) markKeyOnError(ctx context.Context, key *key ) } -// ProcessKeyPoolError translates a keypool exhaustion error -// into a developer-facing ResponseError shaped for the OpenAI -// API. Returns nil if err is not an exhaustion error. -func ProcessKeyPoolError(err error) *ResponseError { - var transient *keypool.TransientKeyPoolError - switch { - case errors.As(err, &transient): - return newErrorResponse( - "all configured keys are rate-limited", - intercept.OpenAIErrTypeRateLimit, - intercept.OpenAIErrCodeRateLimit, - http.StatusTooManyRequests, - transient.RetryAfter, - ) - case errors.Is(err, keypool.ErrPermanentKeyPool): - return newErrorResponse( - "all configured keys failed authentication", - intercept.OpenAIErrTypeAPI, - intercept.OpenAIErrCodeServer, - http.StatusBadGateway, - 0, - ) - default: - return nil - } -} - -func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError { - return &ResponseError{ - ErrorObject: &shared.ErrorObject{ - Code: code, - Message: msg, - Type: errType, - }, - StatusCode: status, - RetryAfter: retryAfter, - } -} - -var _ error = &ResponseError{} - -type ResponseError struct { - ErrorObject *shared.ErrorObject `json:"error"` - StatusCode int `json:"-"` - RetryAfter time.Duration `json:"-"` -} - -func (e *ResponseError) Error() string { - if e.ErrorObject == nil { - return "" - } - return e.ErrorObject.Message -} - -// ToResponse marshals e into an *http.Response shaped for the -// OpenAI API. -func (e *ResponseError) ToResponse() *http.Response { - body, err := json.Marshal(e) - if err != nil { - body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`) - } - return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body) -} - // sendCustomErr sends custom responses.Error error to the client // it should only be called before any data is sent back to the client func (i *responsesInterceptionBase) sendCustomErr(ctx context.Context, w http.ResponseWriter, code int, err error) { diff --git a/aibridge/intercept/responses/base_internal_test.go b/aibridge/intercept/responses/base_internal_test.go index d836646708753..f2b92ea029f01 100644 --- a/aibridge/intercept/responses/base_internal_test.go +++ b/aibridge/intercept/responses/base_internal_test.go @@ -16,6 +16,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/internal/testutil" "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/recorder" @@ -390,59 +391,6 @@ func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) { require.True(t, mrw.writeHeaderCalled) } -func TestProcessKeyPoolError(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - err error - expectedNil bool - expectedStatus int - expectedRetryAfter time.Duration - }{ - { - // Transient with valid keys present: 429, no Retry-After. - name: "transient_zero_retry_after", - err: &keypool.TransientKeyPoolError{}, - expectedStatus: http.StatusTooManyRequests, - expectedRetryAfter: 0, - }, - { - // Transient with cooldown: 429, Retry-After set. - name: "transient_with_retry_after", - err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second}, - expectedStatus: http.StatusTooManyRequests, - expectedRetryAfter: 5 * time.Second, - }, - { - // Permanent: 502 api_error. - name: "permanent_returns_502", - err: keypool.ErrPermanentKeyPool, - expectedStatus: http.StatusBadGateway, - }, - { - // Anything else: not a pool-exhaustion error. - name: "non_pool_exhaustion_error_returns_nil", - err: xerrors.New("some other error"), - expectedNil: true, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got := ProcessKeyPoolError(tc.err) - if tc.expectedNil { - require.Nil(t, got) - return - } - require.NotNil(t, got) - assert.Equal(t, tc.expectedStatus, got.StatusCode) - assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter) - }) - } -} - func TestMarkKeyOnError(t *testing.T) { t.Parallel() @@ -494,8 +442,8 @@ func TestMarkKeyOnError(t *testing.T) { t.Parallel() pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t)) require.NoError(t, err) - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) base := &responsesInterceptionBase{cfg: config.OpenAI{KeyPool: pool}, logger: slog.Make()} @@ -511,7 +459,7 @@ func TestWriteUpstreamError(t *testing.T) { tests := []struct { name string - respErr *ResponseError + respErr *intercept.ResponseError expectStatus int // Empty string means the header should be absent. expectRetryAfter string @@ -521,42 +469,42 @@ func TestWriteUpstreamError(t *testing.T) { { // Standard error: status, code, and JSON body written. name: "writes_status_and_body", - respErr: newErrorResponse("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0), + respErr: intercept.NewResponseError("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0), expectStatus: http.StatusBadGateway, expectBodyContains: `"upstream failed"`, }, { // OpenAI envelope: the code field round-trips into the body. name: "writes_code_field", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0), expectStatus: http.StatusTooManyRequests, expectBodyContains: `"rate_limit_exceeded"`, }, { // Whole-second retryAfter: emitted as integer seconds. name: "retry_after_in_seconds", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "60", }, { // 500ms rounds up to Retry-After: 1. name: "retry_after_500ms_rounds_up_to_one", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "1", }, { // 200ms rounds up to Retry-After: 1. name: "retry_after_200ms_rounds_up_to_one", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "1", }, { // Negative retryAfter: header omitted. name: "negative_retry_after_omits_header", - respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second), + respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second), expectStatus: http.StatusTooManyRequests, expectRetryAfter: "", }, diff --git a/aibridge/intercept/responses/blocking.go b/aibridge/intercept/responses/blocking.go index d2dd7a3de8f4e..9726b6f750efc 100644 --- a/aibridge/intercept/responses/blocking.go +++ b/aibridge/intercept/responses/blocking.go @@ -17,6 +17,7 @@ import ( "github.com/coder/coder/v2/aibridge/config" aibcontext "github.com/coder/coder/v2/aibridge/context" "github.com/coder/coder/v2/aibridge/intercept" + "github.com/coder/coder/v2/aibridge/keypool" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/recorder" "github.com/coder/coder/v2/aibridge/tracing" @@ -103,8 +104,9 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r * // The failover loop may return a keypool exhaustion // error. Render it here. if upstreamErr != nil { - if keyErr := ProcessKeyPoolError(upstreamErr); keyErr != nil { - i.writeUpstreamError(w, keyErr) + var keyPoolErr *keypool.Error + if errors.As(upstreamErr, &keyPoolErr) { + i.writeUpstreamError(w, intercept.ResponseErrorFromKeyPool(keyPoolErr)) return xerrors.Errorf("key pool exhausted: %w", upstreamErr) } } @@ -174,9 +176,9 @@ func (i *BlockingResponsesInterceptor) newResponseWithKeyFailover(ctx context.Co // success, the last tried key on failure) in the upstack PR. walker := i.cfg.KeyPool.Walker() for { - key, err := walker.Next() - if err != nil { - return nil, err + key, keyPoolErr := walker.Next() + if keyPoolErr != nil { + return nil, keyPoolErr } requestOpts := append([]option.RequestOption{}, opts...) diff --git a/aibridge/intercept/responses/streaming.go b/aibridge/intercept/responses/streaming.go index 6d1e70c9ca9b9..2140c5e6c8670 100644 --- a/aibridge/intercept/responses/streaming.go +++ b/aibridge/intercept/responses/streaming.go @@ -134,14 +134,14 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r var currentKey *keypool.Key if walker != nil { - key, err := walker.Next() - if respErr := ProcessKeyPoolError(err); respErr != nil { + key, keyPoolErr := walker.Next() + if keyPoolErr != nil { // Pool exhausted: write the error directly. In // agentic mode the inner loop buffers events // instead of streaming them downstream, so the // SSE connection has not been opened yet. - i.writeUpstreamError(w, respErr) - return xerrors.Errorf("key pool exhausted: %w", err) + i.writeUpstreamError(w, intercept.ResponseErrorFromKeyPool(keyPoolErr)) + return xerrors.Errorf("key pool exhausted: %w", keyPoolErr) } currentKey = key opts = append(opts, diff --git a/aibridge/keypool/failover.go b/aibridge/keypool/failover.go index f650cab258a09..38dcd3b972e94 100644 --- a/aibridge/keypool/failover.go +++ b/aibridge/keypool/failover.go @@ -2,11 +2,10 @@ package keypool import ( "bytes" - "context" - "fmt" "io" "net/http" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/aibridge/utils" ) @@ -16,6 +15,9 @@ type KeyFailoverConfig struct { // Pool is the key pool to walk. Nil disables key failover. Pool *Pool + ProviderName string + Logger slog.Logger + // IsBYOK returns true when the request already carries // user-supplied auth. BYOK requests skip key failover. IsBYOK func(*http.Request) bool @@ -24,14 +26,9 @@ type KeyFailoverConfig struct { // in the format the provider expects. InjectAuthKey func(*http.Header, string) - // MarkKey marks the key based on the upstream response. - // Returns true when the response is a key-specific error, - // causing the walker to advance and retry with the next key. - MarkKey func(ctx context.Context, key *Key, resp *http.Response) bool - - // BuildExhaustedResponse returns the response sent to the - // client when the walker has no more keys to try. - BuildExhaustedResponse func(err error) *http.Response + // BuildKeyPoolResponse renders the response sent to the client + // when the walker has no more keys to try. + BuildKeyPoolResponse func(*Error) *http.Response } // keyFailoverTransport retries inner across the key pool on @@ -74,12 +71,12 @@ func (t *keyFailoverTransport) RoundTrip(req *http.Request) (*http.Response, err // Fresh walker per request, independent of other inflight requests. walker := t.config.Pool.Walker() for { - key, err := walker.Next() - if err != nil { - resp := t.config.BuildExhaustedResponse(err) + key, keyPoolErr := walker.Next() + if keyPoolErr != nil { + resp := t.config.BuildKeyPoolResponse(keyPoolErr) if resp == nil { - // Fallback if BuildExhaustedResponse returns nil. - body := []byte(fmt.Sprintf(`{"error":"key pool exhausted: %s"}`, err)) + // Fallback if BuildKeyPoolResponse returns nil. + body := []byte(`{"error":"key pool unavailable"}`) resp = utils.NewJSONErrorResponse(http.StatusBadGateway, 0, body) } return resp, nil @@ -97,8 +94,8 @@ func (t *keyFailoverTransport) RoundTrip(req *http.Request) (*http.Response, err // Transport-level error, not a key issue. return resp, rtErr } - // MarkKey returns true on key-specific failures (e.g. 401/403/429). - if t.config.MarkKey(req.Context(), key, resp) { + // MarkKeyOnStatus returns true on key-specific failures (e.g. 401/403/429). + if MarkKeyOnStatus(req.Context(), key, resp, t.config.Logger, t.config.ProviderName) { // Drain and retry with the next key. _, _ = io.Copy(io.Discard, resp.Body) _ = resp.Body.Close() diff --git a/aibridge/keypool/keymark_test.go b/aibridge/keypool/keymark_test.go index 9e4631409d2c7..228e576aa0d2c 100644 --- a/aibridge/keypool/keymark_test.go +++ b/aibridge/keypool/keymark_test.go @@ -91,8 +91,8 @@ func TestMarkKeyOnStatus(t *testing.T) { clk := quartz.NewMock(t) pool, err := keypool.New([]string{"key-0"}, clk) require.NoError(t, err) - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) resp := &http.Response{ StatusCode: tc.statusCode, diff --git a/aibridge/keypool/keypool.go b/aibridge/keypool/keypool.go index a2791f031deee..55d1712a935c8 100644 --- a/aibridge/keypool/keypool.go +++ b/aibridge/keypool/keypool.go @@ -10,6 +10,8 @@ import ( "github.com/coder/quartz" ) +// Configuration validation type errors. These surface when the +// pool is built from invalid input. var ( // ErrNoKeys is returned when the input is empty. ErrNoKeys = xerrors.New("no keys provided") @@ -18,20 +20,35 @@ var ( ErrDuplicateKey = xerrors.New("duplicate key") ) -// ErrPermanentKeyPool is returned when every key in the -// pool has been permanently marked unavailable. -var ErrPermanentKeyPool = xerrors.New("all keys permanently unavailable") +// ErrorKind classifies a runtime key-pool failure. +type ErrorKind int -// TransientKeyPoolError is returned when no key is currently -// available but at least one will recover. RetryAfter is the -// soonest remaining cooldown across the pool, or 0 if a key -// just became valid mid-walk. -type TransientKeyPoolError struct { +const ( + // ErrorKindRateLimited means no key is currently available + // but at least one key will recover after a cooldown. + ErrorKindRateLimited ErrorKind = iota + // ErrorKindPermanent means every key is permanently marked + // and no key can satisfy the request. + ErrorKindPermanent +) + +// Error is returned when no key is available for the +// current attempt. RetryAfter is the soonest remaining +// cooldown across the pool. +type Error struct { + Kind ErrorKind RetryAfter time.Duration } -func (e *TransientKeyPoolError) Error() string { - return fmt.Sprintf("all keys exhausted (retry after %s)", e.RetryAfter) +func (e *Error) Error() string { + switch e.Kind { + case ErrorKindPermanent: + return "all configured keys failed authentication" + case ErrorKindRateLimited: + return fmt.Sprintf("all configured keys are rate-limited (retry after %s)", e.RetryAfter) + default: + return "key pool error" + } } // KeyState represents the current state of a key in the pool. @@ -176,20 +193,21 @@ func (k *Key) MarkPermanent() bool { return true } -// keyPoolError returns ErrPermanentKeyPool if every key -// is permanently unavailable, or *TransientKeyPoolError if -// at least one key is temporarily unavailable. When multiple -// keys are temporary, the smallest remaining cooldown is used -// as the retry-after. -func (p *Pool) keyPoolError() error { +// keyPoolError returns an Error summarizing why no +// key is currently available. When at least one key is +// temporary, the smallest remaining cooldown is used as the +// retry-after. +func (p *Pool) keyPoolError() *Error { var retryAfter time.Duration var hasCooldown bool for i := range p.keys { state, cooldown := p.keys[i].stateAndCooldown() switch state { - // Recoverable now: signal transient with zero retry-after. + // Recoverable now: a key's cooldown expired between the walker's + // check and this scan. Return Retry-After: 0 to indicate that + // an immediate retry will succeed. case KeyStateValid: - return &TransientKeyPoolError{} + return &Error{Kind: ErrorKindRateLimited} // Recoverable later: track soonest remaining cooldown. case KeyStateTemporary: if !hasCooldown || cooldown < retryAfter { @@ -201,9 +219,9 @@ func (p *Pool) keyPoolError() error { } } if hasCooldown { - return &TransientKeyPoolError{RetryAfter: retryAfter} + return &Error{Kind: ErrorKindRateLimited, RetryAfter: retryAfter} } - return ErrPermanentKeyPool + return &Error{Kind: ErrorKindPermanent} } // PoolState returns a snapshot of each key's state in the pool's @@ -236,16 +254,10 @@ func (p *Pool) Walker() *Walker { // Next returns a Key handle for the next available key without // modifying the pool state. // -// Returns *TransientKeyPoolError or ErrPermanentKeyPool -// when no more keys are available. -func (w *Walker) Next() (*Key, error) { - pool := w.pool - if pool == nil { - return nil, ErrPermanentKeyPool - } - - for i := w.pos; i < len(pool.keys); i++ { - key := &pool.keys[i] +// Returns *Error when no more keys are available. +func (w *Walker) Next() (*Key, *Error) { + for i := w.pos; i < len(w.pool.keys); i++ { + key := &w.pool.keys[i] if key.State() != KeyStateValid { continue } @@ -255,5 +267,5 @@ func (w *Walker) Next() (*Key, error) { } // No keys available. - return nil, pool.keyPoolError() + return nil, w.pool.keyPoolError() } diff --git a/aibridge/keypool/keypool_test.go b/aibridge/keypool/keypool_test.go index 0dc4cbdc240e6..2029dafd688c2 100644 --- a/aibridge/keypool/keypool_test.go +++ b/aibridge/keypool/keypool_test.go @@ -1,7 +1,6 @@ package keypool_test import ( - "errors" "sync" "testing" "time" @@ -43,16 +42,15 @@ func TestNewKeyPool(t *testing.T) { // Verify all keys are returned in order and valid. walker := pool.Walker() for _, expected := range tc.expectedKeys { - key, err := walker.Next() - require.NoError(t, err) + key, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) assert.Equal(t, expected, key.Value()) assert.Equal(t, keypool.KeyStateValid, key.State()) } // No more keys available. - _, err = walker.Next() - var transient *keypool.TransientKeyPoolError - require.ErrorAs(t, err, &transient, "expected transient exhaustion: walker returned all valid keys, none marked permanent") + _, keyPoolErr := walker.Next() + require.Equal(t, &keypool.Error{Kind: keypool.ErrorKindRateLimited}, keyPoolErr, "expected rate-limited exhaustion: walker returned all valid keys, none marked permanent") }) } } @@ -69,8 +67,8 @@ func TestState(t *testing.T) { // Fresh key is valid. name: "fresh_key_is_valid", setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) return key }, expectedState: keypool.KeyStateValid, @@ -79,8 +77,8 @@ func TestState(t *testing.T) { // Active cooldown makes the key temporary. name: "active_cooldown_is_temporary", setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(60 * time.Second) return key }, @@ -90,8 +88,8 @@ func TestState(t *testing.T) { // Expired cooldown returns the key to valid. name: "expired_cooldown_is_valid", setup: func(t *testing.T, pool *keypool.Pool, clk *quartz.Mock) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(30 * time.Second) clk.Advance(35 * time.Second) return key @@ -102,8 +100,8 @@ func TestState(t *testing.T) { // Permanent key is permanent. name: "permanent_key", setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkPermanent() return key }, @@ -113,8 +111,8 @@ func TestState(t *testing.T) { // Permanent takes precedence over active cooldown. name: "permanent_with_cooldown_is_permanent", setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(60 * time.Second) key.MarkPermanent() return key @@ -152,8 +150,8 @@ func TestMarkTemporary(t *testing.T) { name: "valid_to_temporary", cooldown: 60 * time.Second, setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) return key }, expectedState: keypool.KeyStateTemporary, @@ -165,8 +163,8 @@ func TestMarkTemporary(t *testing.T) { name: "temporary_to_temporary_extends_cooldown", cooldown: 60 * time.Second, setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(10 * time.Second) return key }, @@ -179,8 +177,8 @@ func TestMarkTemporary(t *testing.T) { name: "temporary_to_temporary_keeps_longer_cooldown", cooldown: 10 * time.Second, setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(60 * time.Second) return key }, @@ -192,8 +190,8 @@ func TestMarkTemporary(t *testing.T) { name: "permanent_to_temporary_is_no_op", cooldown: 60 * time.Second, setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkPermanent() return key }, @@ -231,8 +229,8 @@ func TestMarkPermanent(t *testing.T) { // valid -> permanent: key becomes permanently unavailable. name: "valid_to_permanent", setup: func(t *testing.T, pool *keypool.Pool) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) return key }, expectedState: keypool.KeyStatePermanent, @@ -243,8 +241,8 @@ func TestMarkPermanent(t *testing.T) { // to auth failure. name: "temporary_to_permanent", setup: func(t *testing.T, pool *keypool.Pool) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(60 * time.Second) return key }, @@ -255,8 +253,8 @@ func TestMarkPermanent(t *testing.T) { // permanent -> permanent: no-op, already permanent. name: "permanent_to_permanent", setup: func(t *testing.T, pool *keypool.Pool) *keypool.Key { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkPermanent() return key }, @@ -290,7 +288,7 @@ func TestWalkerNext(t *testing.T) { setup func(t *testing.T, pool *keypool.Pool) advance time.Duration expectedValid []string - expectedErr error + expectedErr *keypool.Error }{ { // Given: key-0: valid, key-1: valid, key-2: valid. @@ -299,7 +297,7 @@ func TestWalkerNext(t *testing.T) { keys: []string{"key-0", "key-1", "key-2"}, setup: func(_ *testing.T, _ *keypool.Pool) {}, expectedValid: []string{"key-0", "key-1", "key-2"}, - expectedErr: &keypool.TransientKeyPoolError{}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, }, { // Given: key-0: temporary, key-1: valid, key-2: valid. @@ -307,12 +305,12 @@ func TestWalkerNext(t *testing.T) { name: "skips_temporary_keys", keys: []string{"key-0", "key-1", "key-2"}, setup: func(t *testing.T, pool *keypool.Pool) { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(60 * time.Second) }, expectedValid: []string{"key-1", "key-2"}, - expectedErr: &keypool.TransientKeyPoolError{}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, }, { // Given: key-0: permanent, key-1: permanent, key-2: valid. @@ -321,15 +319,15 @@ func TestWalkerNext(t *testing.T) { keys: []string{"key-0", "key-1", "key-2"}, setup: func(t *testing.T, pool *keypool.Pool) { walker := pool.Walker() - key0, err := walker.Next() - require.NoError(t, err) + key0, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key0.MarkPermanent() - key1, err := walker.Next() - require.NoError(t, err) + key1, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key1.MarkPermanent() }, expectedValid: []string{"key-2"}, - expectedErr: &keypool.TransientKeyPoolError{}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, }, { // Given: key-0: temporary (30s), key-1: valid. @@ -338,13 +336,13 @@ func TestWalkerNext(t *testing.T) { name: "expired_temporary_is_available", keys: []string{"key-0", "key-1"}, setup: func(t *testing.T, pool *keypool.Pool) { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(30 * time.Second) }, advance: 35 * time.Second, expectedValid: []string{"key-0", "key-1"}, - expectedErr: &keypool.TransientKeyPoolError{}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, }, { // Given: key-0: temporary (zero, default 60s), key-1: valid. @@ -353,13 +351,13 @@ func TestWalkerNext(t *testing.T) { name: "default_cooldown_not_expired", keys: []string{"key-0", "key-1"}, setup: func(t *testing.T, pool *keypool.Pool) { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(0) }, advance: 50 * time.Second, expectedValid: []string{"key-1"}, - expectedErr: &keypool.TransientKeyPoolError{}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, }, { // Given: key-0: temporary (zero, default 60s), key-1: valid. @@ -368,13 +366,13 @@ func TestWalkerNext(t *testing.T) { name: "default_cooldown_expired", keys: []string{"key-0", "key-1"}, setup: func(t *testing.T, pool *keypool.Pool) { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(0) }, advance: 65 * time.Second, expectedValid: []string{"key-0", "key-1"}, - expectedErr: &keypool.TransientKeyPoolError{}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, }, { // Given: key-0: temporary (negative, default 60s), key-1: valid. @@ -383,13 +381,13 @@ func TestWalkerNext(t *testing.T) { name: "negative_cooldown_uses_default", keys: []string{"key-0", "key-1"}, setup: func(t *testing.T, pool *keypool.Pool) { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(-10 * time.Second) }, advance: 65 * time.Second, expectedValid: []string{"key-0", "key-1"}, - expectedErr: &keypool.TransientKeyPoolError{}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, }, { // Given: key-0: temporary (60s), then marked again with shorter cooldown (10s). @@ -398,14 +396,14 @@ func TestWalkerNext(t *testing.T) { name: "shorter_cooldown_preserves_longer_not_expired", keys: []string{"key-0"}, setup: func(t *testing.T, pool *keypool.Pool) { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(60 * time.Second) key.MarkTemporary(10 * time.Second) }, advance: 15 * time.Second, expectedValid: []string{}, - expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 45 * time.Second}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 45 * time.Second}, }, { // Given: key-0: temporary (60s), then marked again with shorter cooldown (10s). @@ -414,14 +412,14 @@ func TestWalkerNext(t *testing.T) { name: "shorter_cooldown_preserves_longer_expired", keys: []string{"key-0"}, setup: func(t *testing.T, pool *keypool.Pool) { - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) key.MarkTemporary(60 * time.Second) key.MarkTemporary(10 * time.Second) }, advance: 65 * time.Second, expectedValid: []string{"key-0"}, - expectedErr: &keypool.TransientKeyPoolError{}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited}, }, { // Given: key-0: temporary (60s), key-1: temporary (10s), key-2: temporary (30s). @@ -431,18 +429,18 @@ func TestWalkerNext(t *testing.T) { keys: []string{"key-0", "key-1", "key-2"}, setup: func(t *testing.T, pool *keypool.Pool) { walker := pool.Walker() - key0, err := walker.Next() - require.NoError(t, err) + key0, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key0.MarkTemporary(60 * time.Second) - key1, err := walker.Next() - require.NoError(t, err) + key1, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key1.MarkTemporary(10 * time.Second) - key2, err := walker.Next() - require.NoError(t, err) + key2, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key2.MarkTemporary(30 * time.Second) }, expectedValid: []string{}, - expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 10 * time.Second}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 10 * time.Second}, }, { // Given: key-0: temporary, key-1: temporary. @@ -451,15 +449,15 @@ func TestWalkerNext(t *testing.T) { keys: []string{"key-0", "key-1"}, setup: func(t *testing.T, pool *keypool.Pool) { walker := pool.Walker() - key0, err := walker.Next() - require.NoError(t, err) + key0, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key0.MarkTemporary(60 * time.Second) - key1, err := walker.Next() - require.NoError(t, err) + key1, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key1.MarkTemporary(60 * time.Second) }, expectedValid: []string{}, - expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 60 * time.Second}, }, { // Given: key-0: permanent, key-1: permanent. @@ -468,15 +466,15 @@ func TestWalkerNext(t *testing.T) { keys: []string{"key-0", "key-1"}, setup: func(t *testing.T, pool *keypool.Pool) { walker := pool.Walker() - key0, err := walker.Next() - require.NoError(t, err) + key0, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key0.MarkPermanent() - key1, err := walker.Next() - require.NoError(t, err) + key1, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key1.MarkPermanent() }, expectedValid: []string{}, - expectedErr: keypool.ErrPermanentKeyPool, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindPermanent}, }, { // Given: key-0: permanent, key-1: temporary, key-2: permanent. @@ -485,18 +483,18 @@ func TestWalkerNext(t *testing.T) { keys: []string{"key-0", "key-1", "key-2"}, setup: func(t *testing.T, pool *keypool.Pool) { walker := pool.Walker() - key0, err := walker.Next() - require.NoError(t, err) + key0, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key0.MarkPermanent() - key1, err := walker.Next() - require.NoError(t, err) + key1, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key1.MarkTemporary(60 * time.Second) - key2, err := walker.Next() - require.NoError(t, err) + key2, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) key2.MarkPermanent() }, expectedValid: []string{}, - expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second}, + expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 60 * time.Second}, }, } @@ -516,21 +514,14 @@ func TestWalkerNext(t *testing.T) { walker := pool.Walker() for _, expectedKey := range tc.expectedValid { - key, err := walker.Next() - require.NoError(t, err) + key, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) assert.Equal(t, expectedKey, key.Value()) } // After all expected keys, the walker should be exhausted. - _, err = walker.Next() - var wantTransient *keypool.TransientKeyPoolError - if errors.As(tc.expectedErr, &wantTransient) { - var got *keypool.TransientKeyPoolError - require.ErrorAs(t, err, &got) - assert.Equal(t, wantTransient.RetryAfter, got.RetryAfter) - } else { - require.ErrorIs(t, err, tc.expectedErr) - } + _, keyPoolErr := walker.Next() + require.Equal(t, tc.expectedErr, keyPoolErr) }) } } @@ -595,8 +586,8 @@ func TestKeyConcurrent(t *testing.T) { clk := quartz.NewMock(t) pool, err := keypool.New([]string{"key-0"}, clk) require.NoError(t, err) - key, err := pool.Walker().Next() - require.NoError(t, err) + key, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) const numGoroutines = 10 var wg sync.WaitGroup @@ -628,29 +619,29 @@ func TestWalkerIndependence(t *testing.T) { walker := pool.Walker() // First attempt: get key-0. - key, err := walker.Next() - require.NoError(t, err) + key, keyPoolErr := walker.Next() + require.Nil(t, keyPoolErr) assert.Equal(t, "key-0", key.Value()) // Simulate 429: mark key-0 temporary. key.MarkTemporary(60 * time.Second) // Second attempt: walker advances to key-1. - key, err = walker.Next() - require.NoError(t, err) + key, keyPoolErr = walker.Next() + require.Nil(t, keyPoolErr) assert.Equal(t, "key-1", key.Value()) // Simulate 401: mark key-1 permanent. key.MarkPermanent() // Third attempt: walker advances to key-2. - key, err = walker.Next() - require.NoError(t, err) + key, keyPoolErr = walker.Next() + require.Nil(t, keyPoolErr) assert.Equal(t, "key-2", key.Value()) // A new walker should skip key-0 (temporary) and key-1 // (permanent), and return key-2. - key2, err := pool.Walker().Next() - require.NoError(t, err) + key2, keyPoolErr := pool.Walker().Next() + require.Nil(t, keyPoolErr) assert.Equal(t, "key-2", key2.Value()) } diff --git a/aibridge/provider/anthropic.go b/aibridge/provider/anthropic.go index 64128cee5fab9..5fdbebae163dc 100644 --- a/aibridge/provider/anthropic.go +++ b/aibridge/provider/anthropic.go @@ -1,7 +1,6 @@ package provider import ( - "context" "fmt" "io" "net/http" @@ -173,8 +172,8 @@ func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tr // Centralized: use the first key as a placeholder hint. // TODO(ssncferreira): record the actually-used key in // the interception record to reflect failover. - if k, err := cfg.KeyPool.Walker().Next(); err == nil { - credSecret = k.Value() + if key, keyPoolErr := cfg.KeyPool.Walker().Next(); keyPoolErr == nil { + credSecret = key.Value() } } @@ -222,20 +221,18 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) { } func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig { - name := p.Name() return keypool.KeyFailoverConfig{ - Pool: p.cfg.KeyPool, + Pool: p.cfg.KeyPool, + ProviderName: p.Name(), + Logger: logger, IsBYOK: func(r *http.Request) bool { return r.Header.Get("X-Api-Key") != "" || r.Header.Get("Authorization") != "" }, InjectAuthKey: func(h *http.Header, key string) { h.Set("X-Api-Key", key) }, - MarkKey: func(ctx context.Context, key *keypool.Key, resp *http.Response) bool { - return keypool.MarkKeyOnStatus(ctx, key, resp, logger, name) - }, - BuildExhaustedResponse: func(err error) *http.Response { - return messages.ProcessKeyPoolError(err).ToResponse() + BuildKeyPoolResponse: func(keyPoolErr *keypool.Error) *http.Response { + return messages.ResponseErrorFromKeyPool(keyPoolErr).ToResponse() }, } } diff --git a/aibridge/provider/openai.go b/aibridge/provider/openai.go index 894dda7194106..be53e612d1c3a 100644 --- a/aibridge/provider/openai.go +++ b/aibridge/provider/openai.go @@ -1,7 +1,6 @@ package provider import ( - "context" "encoding/json" "fmt" "io" @@ -146,8 +145,8 @@ func (p *OpenAI) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trace // Centralized: use the first key as a placeholder hint. // TODO(ssncferreira): record the actually-used key in // the interception record to reflect failover. - if k, err := cfg.KeyPool.Walker().Next(); err == nil { - credSecret = k.Value() + if key, keyPoolErr := cfg.KeyPool.Walker().Next(); keyPoolErr == nil { + credSecret = key.Value() } } cred := intercept.NewCredentialInfo(credKind, credSecret) @@ -221,20 +220,18 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) { } func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig { - name := p.Name() return keypool.KeyFailoverConfig{ - Pool: p.cfg.KeyPool, + Pool: p.cfg.KeyPool, + ProviderName: p.Name(), + Logger: logger, IsBYOK: func(r *http.Request) bool { return r.Header.Get("Authorization") != "" }, InjectAuthKey: func(h *http.Header, key string) { h.Set("Authorization", "Bearer "+key) }, - MarkKey: func(ctx context.Context, key *keypool.Key, resp *http.Response) bool { - return keypool.MarkKeyOnStatus(ctx, key, resp, logger, name) - }, - BuildExhaustedResponse: func(err error) *http.Response { - return chatcompletions.ProcessKeyPoolError(err).ToResponse() + BuildKeyPoolResponse: func(keyPoolErr *keypool.Error) *http.Response { + return intercept.ResponseErrorFromKeyPool(keyPoolErr).ToResponse() }, } }