From e126bf39cc56bce6d26311f5a17219927bf65fe1 Mon Sep 17 00:00:00 2001 From: Susana Cardoso Ferreira Date: Wed, 3 Jun 2026 18:52:15 +0000 Subject: [PATCH] refactor(aibridge): consolidate key failover interceptor tests --- .../chatcompletions/blocking_internal_test.go | 483 --------------- .../streaming_internal_test.go | 512 ---------------- aibridge/intercept/keyfailover_test.go | 494 +++++++++++++++ .../messages/blocking_internal_test.go | 479 --------------- .../messages/streaming_internal_test.go | 570 ------------------ aibridge/intercept/responses/base.go | 5 + .../responses/blocking_internal_test.go | 473 --------------- .../responses/streaming_internal_test.go | 520 ---------------- .../internal/testutil/mockserverproxier.go | 19 +- aibridge/internal/testutil/mockupstream.go | 18 + 10 files changed, 535 insertions(+), 3038 deletions(-) delete mode 100644 aibridge/intercept/chatcompletions/blocking_internal_test.go create mode 100644 aibridge/intercept/keyfailover_test.go delete mode 100644 aibridge/intercept/messages/blocking_internal_test.go delete mode 100644 aibridge/intercept/messages/streaming_internal_test.go delete mode 100644 aibridge/intercept/responses/blocking_internal_test.go delete mode 100644 aibridge/intercept/responses/streaming_internal_test.go diff --git a/aibridge/intercept/chatcompletions/blocking_internal_test.go b/aibridge/intercept/chatcompletions/blocking_internal_test.go deleted file mode 100644 index 088fa4177dce6..0000000000000 --- a/aibridge/intercept/chatcompletions/blocking_internal_test.go +++ /dev/null @@ -1,483 +0,0 @@ -package chatcompletions - -import ( - "io" - "net/http" - "net/http/httptest" - "sync" - "sync/atomic" - "testing" - - "github.com/google/uuid" - "github.com/openai/openai-go/v3" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - "cdr.dev/slog/v3" - "github.com/coder/coder/v2/aibridge/config" - "github.com/coder/coder/v2/aibridge/intercept" - "github.com/coder/coder/v2/aibridge/internal/testutil" - "github.com/coder/coder/v2/aibridge/keypool" - "github.com/coder/coder/v2/aibridge/mcp" - "github.com/coder/coder/v2/aibridge/utils" - "github.com/coder/quartz" -) - -// OpenAI-shaped response bodies. -const ( - successBody = `{"id":"chatcmpl-01","object":"chat.completion","created":1234567890,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"Hello!"},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}` - toolUseBody = `{"id":"chatcmpl-01","object":"chat.completion","created":1234567890,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":null,"tool_calls":[{"id":"call_01","type":"function","function":{"name":"test_tool","arguments":"{}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}}` - textCompleteBody = `{"id":"chatcmpl-02","object":"chat.completion","created":1234567890,"model":"gpt-4","choices":[{"index":0,"message":{"role":"assistant","content":"done"},"finish_reason":"stop"}],"usage":{"prompt_tokens":15,"completion_tokens":3,"total_tokens":18}}` - rateLimitBody = `{"error":{"message":"Rate limit exceeded","type":"rate_limit_error","code":"rate_limit_exceeded"}}` - authErrorBody = `{"error":{"message":"Invalid API key","type":"invalid_request_error","code":"invalid_api_key"}}` - serverErrorBody = `{"error":{"message":"Internal server error","type":"server_error","code":"internal_error"}}` -) - -type upstreamResponse struct { - statusCode int - body string - headers map[string]string -} - -// newRequestParams builds a minimal chat-completions request -// for tests. -func newRequestParams(stream bool) *ChatCompletionNewParamsWrapper { - return &ChatCompletionNewParamsWrapper{ - ChatCompletionNewParams: openai.ChatCompletionNewParams{ - Model: "gpt-4", - Messages: []openai.ChatCompletionMessageParamUnion{ - openai.UserMessage("hi"), - }, - }, - Stream: stream, - } -} - -func TestBlockingInterception_KeyFailover(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - // Centralized pool keys. Empty when byokKey is set. - keys []string - // BYOK key. Empty when keys is set. - byokKey string - // Scripted upstream responses keyed by bearer token. - responses map[string]upstreamResponse - expectedRequestCount int32 - expectedStatusCode int - expectedRetryAfter string - // Expected key states after the request, by index in keys. - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: last - // attempted key for centralized, user key from initial request for BYOK. - expectedCredentialHint string - }{ - { - // Given: 1 valid key returning 200. - // Then: 1 request, 200 response, key remains valid. - name: "single_valid_key", - keys: []string{"k0-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 returns 429, key-1 returns 200. - // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. - name: "failover_after_429", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 401, key-1 returns 200. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_401", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 403, key-1 returns 200. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_403", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s. - // Then: 3 requests, 429 response with smallest Retry-After, - // all keys temporary. - name: "all_keys_rate_limited", - keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - "k2-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "10"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "3", - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k2-long-key"), - }, - { - // Given: 2 keys; both return 401. - // Then: 2 requests, 502 api_error response, both keys permanent. - name: "all_keys_unauthorized", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusBadGateway, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStatePermanent, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 500. - // Then: 1 request, 500 response, both keys remain valid. - name: "server_error_no_failover", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusInternalServerError, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: BYOK with a single key returning 429. - // Then: 1 request, 429 response, no failover, upstream - // Retry-After propagated to the client. - name: "byok_no_failover", - byokKey: "user-byok", - responses: map[string]upstreamResponse{ - "user-byok": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{ - "Retry-After": "5", - // BYOK doesn't set MaxRetries(0); - // suppress SDK retries to test a - // single attempt. - "x-should-retry": "false", - }, - body: rateLimitBody, - }, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "5", - expectedCredentialHint: utils.MaskSecret("user-byok"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Mock upstream: counts requests and returns - // scripted responses keyed by bearer token. An - // unmapped key falls through to 500 so misconfigured - // cases surface via the status assertion. - var requestCount atomic.Int32 - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount.Add(1) - _, _ = io.Copy(io.Discard, r.Body) - resp, ok := tc.responses[utils.ExtractBearerToken(r.Header.Get("Authorization"))] - if !ok { - resp = upstreamResponse{statusCode: http.StatusInternalServerError} - } - w.Header().Set("Content-Type", "application/json") - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - cfg := config.OpenAI{BaseURL: upstream.URL + "/"} - var pool *keypool.Pool - credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") - if len(tc.keys) > 0 { - var err error - pool, err = keypool.New(tc.keys, quartz.NewMock(t)) - require.NoError(t, err) - cfg.KeyPool = pool - } else if tc.byokKey != "" { - cfg.Key = tc.byokKey - credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) - } - - interceptor := NewBlockingInterceptor( - uuid.New(), - newRequestParams(false), - config.ProviderOpenAI, - cfg, - http.Header{}, - "Authorization", - otel.Tracer("blocking_test"), - credInfo, - ) - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) - - req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) - w := httptest.NewRecorder() - err := interceptor.ProcessRequest(w, req) - if tc.expectedStatusCode == http.StatusOK { - require.NoError(t, err) - } else { - require.Error(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") - assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") - if pool != nil { - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - } - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - }) - } -} - -// TestBlockingInterception_AgenticLoopFailover covers the -// scenarios that span an agentic-loop continuation: the initial -// client request and the subsequent tool-call continuation can -// each fail over independently. Each iteration gets its own -// walker. -func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - // Scripted upstream responses consumed in order of - // upstream request. - responses []upstreamResponse - expectedRequestCount int32 - expectedSeenKeys []string - expectedStatusCode int - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: hint of the - // last attempted key across all agentic-loop iterations. - expectedCredentialHint string - }{ - { - // Given: 2 keys; both upstream calls succeed on key-0. - // Then: 2 requests, 200 response, both keys remain valid. - name: "happy_path", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, body: toolUseBody}, - {statusCode: http.StatusOK, body: textCompleteBody}, - }, - expectedRequestCount: 2, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then 429s - // during the agentic continuation, key-1 succeeds. - // Then: 3 requests, 200 response, key-0 temporary, - // key-1 valid. - name: "agentic_failover_to_k1", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, body: toolUseBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - {statusCode: http.StatusOK, body: textCompleteBody}, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then both - // keys 429 during the agentic continuation. - // Then: 3 requests, 429 response with smallest - // Retry-After, both keys temporary. - name: "agentic_all_keys_fail", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, body: toolUseBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedStatusCode: http.StatusTooManyRequests, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - var seenKeysMu sync.Mutex - var seenKeys []string - - // Mock upstream: returns scripted responses in order, - // records each request's bearer token for assertions. - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - idx := int(requestCount.Add(1)) - 1 - seenKeysMu.Lock() - seenKeys = append(seenKeys, utils.ExtractBearerToken(r.Header.Get("Authorization"))) - seenKeysMu.Unlock() - _, _ = io.Copy(io.Discard, r.Body) - - if idx >= len(tc.responses) { - w.WriteHeader(http.StatusInternalServerError) - return - } - resp := tc.responses[idx] - w.Header().Set("Content-Type", "application/json") - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) - require.NoError(t, err) - - cfg := config.OpenAI{ - BaseURL: upstream.URL + "/", - KeyPool: pool, - } - - interceptor := NewBlockingInterceptor( - uuid.New(), - newRequestParams(false), - config.ProviderOpenAI, - cfg, - http.Header{}, - "Authorization", - otel.Tracer("blocking_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), - ) - - // Mock proxy with a tool the upstream's tool_use - // response will reference. - proxy := &testutil.MockServerProxier{ - Tools: []*mcp.Tool{ - { - Client: testutil.StubToolCaller{}, - ID: "test_tool", - Name: "test_tool", - ServerName: "coder", - Logger: slog.Make(), - }, - }, - } - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy) - - req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedStatusCode == http.StatusOK { - require.NoError(t, err) - } else { - require.Error(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") - - seenKeysMu.Lock() - defer seenKeysMu.Unlock() - assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - }) - } -} diff --git a/aibridge/intercept/chatcompletions/streaming_internal_test.go b/aibridge/intercept/chatcompletions/streaming_internal_test.go index b836e7b7ccb4d..2feea1a709e96 100644 --- a/aibridge/intercept/chatcompletions/streaming_internal_test.go +++ b/aibridge/intercept/chatcompletions/streaming_internal_test.go @@ -1,12 +1,8 @@ package chatcompletions import ( - "io" "net/http" "net/http/httptest" - "strings" - "sync" - "sync/atomic" "testing" "github.com/google/uuid" @@ -20,10 +16,6 @@ import ( "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/internal/testutil" - "github.com/coder/coder/v2/aibridge/keypool" - "github.com/coder/coder/v2/aibridge/mcp" - "github.com/coder/coder/v2/aibridge/utils" - "github.com/coder/quartz" ) // Test that when the upstream provider returns an error before streaming starts, @@ -116,507 +108,3 @@ func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) { }) } } - -// OpenAI-shaped SSE body for a successful streaming response. -const streamingSuccessBody = `data: {"id":"chatcmpl-01","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"Hello"},"finish_reason":null}]} - -data: {"id":"chatcmpl-01","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"content":"!"},"finish_reason":null}]} - -data: {"id":"chatcmpl-01","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}} - -data: [DONE] - -` - -func TestStreamingInterception_KeyFailover(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - // Centralized pool keys. Empty when byokKey is set. - keys []string - // BYOK key. Empty when keys is set. - byokKey string - // Scripted upstream responses keyed by bearer token. - responses map[string]upstreamResponse - expectedRequestCount int32 - expectedStatusCode int - expectedRetryAfter string - // Expected key states after the request, by index in keys. - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: last - // attempted key for centralized, user key from initial request for BYOK. - expectedCredentialHint string - }{ - { - // Given: 1 valid key returning a successful stream. - // Then: 1 request, 200 response, key remains valid. - name: "single_valid_key", - keys: []string{"k0-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 returns 429 pre-stream, key-1 - // streams successfully. - // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. - name: "failover_after_429", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 401 pre-stream, key-1 - // streams successfully. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_401", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 403 pre-stream, key-1 streams. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_403", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 3 keys; all return 429 pre-stream with - // cooldowns 5s, 3s, 10s. - // Then: 3 requests, 429 response with smallest - // Retry-After, all keys temporary. - name: "all_keys_rate_limited", - keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - "k2-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "10"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "3", - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k2-long-key"), - }, - { - // Given: 2 keys; both return 401 pre-stream. - // Then: 2 requests, 502 api_error response, both keys permanent. - name: "all_keys_unauthorized", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusBadGateway, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStatePermanent, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 500 pre-stream. - // Then: 1 request, 500 response, both keys remain valid. - name: "server_error_no_failover", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusInternalServerError, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: BYOK with a single key returning 429. - // Then: 1 request, 429 response, no failover, upstream - // Retry-After propagated to the client. - name: "byok_no_failover", - byokKey: "user-byok", - responses: map[string]upstreamResponse{ - "user-byok": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{ - "Retry-After": "5", - // BYOK doesn't set MaxRetries(0); - // suppress SDK retries to test a - // single attempt. - "x-should-retry": "false", - }, - body: rateLimitBody, - }, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "5", - expectedCredentialHint: utils.MaskSecret("user-byok"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Mock upstream: counts requests and returns - // scripted responses keyed by bearer token. An - // unmapped key falls through to 500 so misconfigured - // cases surface via the status assertion. - var requestCount atomic.Int32 - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount.Add(1) - _, _ = io.Copy(io.Discard, r.Body) - resp, ok := tc.responses[utils.ExtractBearerToken(r.Header.Get("Authorization"))] - if !ok { - resp = upstreamResponse{statusCode: http.StatusInternalServerError} - } - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - cfg := config.OpenAI{BaseURL: upstream.URL + "/"} - var pool *keypool.Pool - credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") - if len(tc.keys) > 0 { - var err error - pool, err = keypool.New(tc.keys, quartz.NewMock(t)) - require.NoError(t, err) - cfg.KeyPool = pool - } else if tc.byokKey != "" { - cfg.Key = tc.byokKey - credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) - } - - interceptor := NewStreamingInterceptor( - uuid.New(), - newRequestParams(true), - config.ProviderOpenAI, - cfg, - http.Header{}, - "Authorization", - otel.Tracer("streaming_test"), - credInfo, - ) - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) - - req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) - w := httptest.NewRecorder() - err := interceptor.ProcessRequest(w, req) - if tc.expectedStatusCode == http.StatusOK { - require.NoError(t, err) - } else { - require.Error(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") - assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") - if pool != nil { - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - } - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - }) - } -} - -// SSE bodies covering an agentic-continuation flow. -const ( - // First response: a tool_calls delta referencing the - // injected "test_tool". Triggers the agentic continuation - // loop. - toolUseStreamBody = `data: {"id":"chatcmpl-01","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_01","type":"function","function":{"name":"test_tool","arguments":""}}]},"finish_reason":null}]} - -data: {"id":"chatcmpl-01","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{}"}}]},"finish_reason":null}]} - -data: {"id":"chatcmpl-01","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":10,"completion_tokens":5,"total_tokens":15}} - -data: [DONE] - -` - - // Second response (after the tool result is sent back): - // a plain text completion that ends the loop. - textStreamBody = `data: {"id":"chatcmpl-02","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{"role":"assistant","content":"done"},"finish_reason":null}]} - -data: {"id":"chatcmpl-02","object":"chat.completion.chunk","created":1234567890,"model":"gpt-4","choices":[{"index":0,"delta":{},"finish_reason":"stop"}],"usage":{"prompt_tokens":15,"completion_tokens":3,"total_tokens":18}} - -data: [DONE] - -` -) - -// TestStreamingInterception_AgenticLoopFailover covers the -// scenarios that span an agentic-loop continuation: the initial -// client request and the subsequent tool-call continuation can -// each fail over independently. Each iteration gets its own -// walker. -func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { - t.Parallel() - - sseHeaders := map[string]string{"Content-Type": "text/event-stream"} - - tests := []struct { - name string - // Scripted upstream responses consumed in order of - // upstream request. - responses []upstreamResponse - expectedRequestCount int32 - expectedSeenKeys []string - // Substring expected in the response body. Either a - // success marker (e.g. "done") or an error marker - // (e.g. "rate_limit_error"). - expectedBodyContains string - // True when the error must be relayed as an SSE event. - expectErrorAsSSEEvent bool - // True when ProcessRequest is expected to return an - // error (e.g. all keys exhausted). - expectedErr bool - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: hint of the - // last attempted key across all agentic-loop iterations. - expectedCredentialHint string - }{ - { - // Given: 2 keys; both upstream calls succeed on key-0. - // Then: 2 requests, success body, both keys remain valid. - name: "happy_path", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, - {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, - }, - expectedRequestCount: 2, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, - expectedBodyContains: "done", - expectErrorAsSSEEvent: false, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then 429s - // during the agentic continuation, key-1 succeeds. - // Then: 3 requests, success body, key-0 temporary, - // key-1 valid. - name: "agentic_failover_to_k1", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedBodyContains: "done", - expectErrorAsSSEEvent: false, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then both - // keys 429 during the agentic continuation. - // Then: 3 requests, error injected as SSE event, both - // keys temporary. - name: "agentic_all_keys_fail", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedBodyContains: "all configured keys are rate-limited", - expectErrorAsSSEEvent: true, - expectedErr: true, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - var seenKeysMu sync.Mutex - var seenKeys []string - - // Mock upstream: returns scripted responses in order, - // records each request's bearer token for assertions. - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - idx := int(requestCount.Add(1)) - 1 - seenKeysMu.Lock() - seenKeys = append(seenKeys, utils.ExtractBearerToken(r.Header.Get("Authorization"))) - seenKeysMu.Unlock() - _, _ = io.Copy(io.Discard, r.Body) - - if idx >= len(tc.responses) { - w.WriteHeader(http.StatusInternalServerError) - return - } - resp := tc.responses[idx] - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) - require.NoError(t, err) - - cfg := config.OpenAI{ - BaseURL: upstream.URL + "/", - KeyPool: pool, - } - - interceptor := NewStreamingInterceptor( - uuid.New(), - newRequestParams(true), - config.ProviderOpenAI, - cfg, - http.Header{}, - "Authorization", - otel.Tracer("streaming_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), - ) - - // Mock proxy with a tool the upstream's tool_calls - // chunks will reference. The stub caller returns a - // fixed text result. - proxy := &testutil.MockServerProxier{ - Tools: []*mcp.Tool{ - { - Client: testutil.StubToolCaller{}, - ID: "test_tool", - Name: "test_tool", - ServerName: "coder", - Logger: slog.Make(), - }, - }, - } - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy) - - req := httptest.NewRequest(http.MethodPost, "/v1/chat/completions", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - body := w.Body.String() - assert.Contains(t, body, tc.expectedBodyContains, "response body") - if tc.expectErrorAsSSEEvent { - // SSE was opened before the failure, so the body - // must start with stream chunks, not a direct - // HTTP error body. - assert.True(t, strings.HasPrefix(body, "data: "), "body must start with SSE chunks") - } - - seenKeysMu.Lock() - defer seenKeysMu.Unlock() - assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - }) - } -} diff --git a/aibridge/intercept/keyfailover_test.go b/aibridge/intercept/keyfailover_test.go new file mode 100644 index 0000000000000..52ada03fb11b0 --- /dev/null +++ b/aibridge/intercept/keyfailover_test.go @@ -0,0 +1,494 @@ +package intercept_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/tidwall/sjson" + "go.opentelemetry.io/otel" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/fixtures" + "github.com/coder/coder/v2/aibridge/intercept" + "github.com/coder/coder/v2/aibridge/intercept/chatcompletions" + "github.com/coder/coder/v2/aibridge/intercept/messages" + "github.com/coder/coder/v2/aibridge/intercept/responses" + "github.com/coder/coder/v2/aibridge/internal/testutil" + "github.com/coder/coder/v2/aibridge/keypool" + "github.com/coder/coder/v2/aibridge/utils" + "github.com/coder/quartz" +) + +// interceptorCase parameterizes the failover tests over the interceptors. It +// captures the per-API differences (request shape, auth header, and route) so a +// single set of scenarios runs against every one. +type interceptorCase struct { + // name labels the subtest. + name string + // path is the route the interceptor handles. + path string + // authHeader is the header the upstream key is carried in. It is also used + // to read the key back off a recorded upstream request. + authHeader string + // fixture returns the txtar fixture for the given mode. When agentic is true + // it returns the injected-tool fixture, whose first response calls a tool and + // whose second is the final answer, otherwise the simple success fixture. + fixture func(streaming, agentic bool) []byte + // agenticStreamErrorEvent is the SSE marker a mid-loop pool exhaustion + // produces once the agentic stream has started. It is empty for responses, + // which buffers agentic events and writes the error status directly instead, + // like the blocking path. + agenticStreamErrorEvent string + // streamDoneEvent is the terminal SSE event a completed streaming response + // emits. A successful agentic continuation streams the final response, so its + // presence confirms that response reached the client. + streamDoneEvent string + // newInterceptor builds an interceptor pointed at upstreamURL. pool is the + // centralized key pool, or nil for BYOK, in which case byokKey is the + // user-supplied key. + newInterceptor func(t *testing.T, streaming bool, upstreamURL string, reqBody []byte, pool *keypool.Pool, byokKey string) intercept.Interceptor +} + +// keyFromHeader reads the API key an upstream request carried in the named auth +// header. +func keyFromHeader(name string, h http.Header) string { + if name == "Authorization" { + return utils.ExtractBearerToken(h.Get(name)) + } + return h.Get(name) +} + +// interceptorCases is the set of interceptors the failover tests run against, +// one entry per supported API. +var interceptorCases = []interceptorCase{ + { + name: "messages", + path: "/v1/messages", + authHeader: "X-Api-Key", + fixture: func(_, agentic bool) []byte { + if agentic { + return fixtures.AntSingleInjectedTool + } + return fixtures.AntSimple + }, + agenticStreamErrorEvent: "event: error", + streamDoneEvent: "event: message_stop", + newInterceptor: func(t *testing.T, streaming bool, upstreamURL string, reqBody []byte, pool *keypool.Pool, byokKey string) intercept.Interceptor { + cfg := config.Anthropic{BaseURL: upstreamURL + "/"} + cred := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") + if pool != nil { + cfg.KeyPool = pool + } else if byokKey != "" { + cfg.Key = byokKey + cred = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, byokKey) + } + + payload, err := messages.NewRequestPayload(reqBody) + require.NoError(t, err) + + id, tracer := uuid.New(), otel.Tracer("keyfailover") + if streaming { + return messages.NewStreamingInterceptor(id, payload, config.ProviderAnthropic, cfg, nil, http.Header{}, "X-Api-Key", tracer, cred) + } + return messages.NewBlockingInterceptor(id, payload, config.ProviderAnthropic, cfg, nil, http.Header{}, "X-Api-Key", tracer, cred) + }, + }, + { + name: "chatcompletions", + path: "/v1/chat/completions", + authHeader: "Authorization", + fixture: func(_, agentic bool) []byte { + if agentic { + return fixtures.OaiChatSingleInjectedTool + } + return fixtures.OaiChatSimple + }, + agenticStreamErrorEvent: `data: {"error"`, + streamDoneEvent: "data: [DONE]", + newInterceptor: func(t *testing.T, streaming bool, upstreamURL string, reqBody []byte, pool *keypool.Pool, byokKey string) intercept.Interceptor { + cfg := config.OpenAI{BaseURL: upstreamURL + "/"} + cred := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") + if pool != nil { + cfg.KeyPool = pool + } else if byokKey != "" { + cfg.Key = byokKey + cred = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, byokKey) + } + + var req chatcompletions.ChatCompletionNewParamsWrapper + require.NoError(t, json.Unmarshal(reqBody, &req)) + + id, tracer := uuid.New(), otel.Tracer("keyfailover") + if streaming { + return chatcompletions.NewStreamingInterceptor(id, &req, config.ProviderOpenAI, cfg, http.Header{}, "Authorization", tracer, cred) + } + return chatcompletions.NewBlockingInterceptor(id, &req, config.ProviderOpenAI, cfg, http.Header{}, "Authorization", tracer, cred) + }, + }, + { + name: "responses", + path: "/v1/responses", + authHeader: "Authorization", + fixture: func(streaming, agentic bool) []byte { + switch { + case streaming && agentic: + return fixtures.OaiResponsesStreamingSingleInjectedTool + case streaming: + return fixtures.OaiResponsesStreamingSimple + case agentic: + return fixtures.OaiResponsesBlockingSingleInjectedTool + default: + return fixtures.OaiResponsesBlockingSimple + } + }, + streamDoneEvent: "event: response.completed", + newInterceptor: func(t *testing.T, streaming bool, upstreamURL string, reqBody []byte, pool *keypool.Pool, byokKey string) intercept.Interceptor { + cfg := config.OpenAI{BaseURL: upstreamURL + "/"} + cred := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") + if pool != nil { + cfg.KeyPool = pool + } else if byokKey != "" { + cfg.Key = byokKey + cred = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, byokKey) + } + + payload, err := responses.NewRequestPayload(reqBody) + require.NoError(t, err) + + id, tracer := uuid.New(), otel.Tracer("keyfailover") + if streaming { + return responses.NewStreamingInterceptor(id, payload, config.ProviderOpenAI, cfg, http.Header{}, "Authorization", tracer, cred) + } + return responses.NewBlockingInterceptor(id, payload, config.ProviderOpenAI, cfg, http.Header{}, "Authorization", tracer, cred) + }, + }, +} + +// TestInterception_KeyFailover verifies that, within a single interception, the +// centralized key pool fails over across keys (temporary on 429, permanent on +// 401/403) and reports exhaustion, for every interceptor in both blocking and +// streaming mode. +func TestInterception_KeyFailover(t *testing.T) { + t.Parallel() + + const ( + k0, k1, k2 = "k0-long-key", "k1-long-key", "k2-long-key" + byokKey = "user-byok-key" + ) + errResp := testutil.NewErrorResponse + + tests := []struct { + name string + keys []string + byokKey string + // responses builds the upstream responses in call order. success is the + // interceptor's fixture success response, so each case only specifies + // the error responses that drive failover. + responses func(success testutil.UpstreamResponse) []testutil.UpstreamResponse + expectedStatus int + expectedRetryAfter string + expectedKeyStates []keypool.KeyState + expectedSeenKeys []string + expectedBodyContains string + }{ + { + // One valid key succeeds on the first attempt. + name: "single_valid_key", + keys: []string{k0}, + responses: func(s testutil.UpstreamResponse) []testutil.UpstreamResponse { return []testutil.UpstreamResponse{s} }, + expectedStatus: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, + expectedSeenKeys: []string{k0}, + }, + { + // A 429 marks the key temporary and fails over to the next one. + name: "failover_after_429", + keys: []string{k0, k1}, + responses: func(s testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{errResp(http.StatusTooManyRequests, "5"), s} + }, + expectedStatus: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateTemporary, keypool.KeyStateValid}, + expectedSeenKeys: []string{k0, k1}, + }, + { + // A 401 marks the key permanent and fails over to the next one. + name: "failover_after_401", + keys: []string{k0, k1}, + responses: func(s testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{errResp(http.StatusUnauthorized, ""), s} + }, + expectedStatus: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStatePermanent, keypool.KeyStateValid}, + expectedSeenKeys: []string{k0, k1}, + }, + { + // A 403 marks the key permanent and fails over to the next one. + name: "failover_after_403", + keys: []string{k0, k1}, + responses: func(s testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{errResp(http.StatusForbidden, ""), s} + }, + expectedStatus: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStatePermanent, keypool.KeyStateValid}, + expectedSeenKeys: []string{k0, k1}, + }, + { + // Every key is rate-limited, so the pool is exhausted and the + // smallest remaining cooldown is reported. + name: "all_keys_rate_limited", + keys: []string{k0, k1, k2}, + responses: func(testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{ + errResp(http.StatusTooManyRequests, "5"), + errResp(http.StatusTooManyRequests, "3"), + errResp(http.StatusTooManyRequests, "10"), + } + }, + expectedStatus: http.StatusTooManyRequests, + expectedRetryAfter: "3", + expectedBodyContains: "all configured keys are rate-limited", + expectedKeyStates: []keypool.KeyState{ + keypool.KeyStateTemporary, + keypool.KeyStateTemporary, + keypool.KeyStateTemporary, + }, + expectedSeenKeys: []string{k0, k1, k2}, + }, + { + // Every key is unauthorized, so the pool is permanently exhausted. + name: "all_keys_unauthorized", + keys: []string{k0, k1}, + responses: func(testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{ + errResp(http.StatusUnauthorized, ""), + errResp(http.StatusUnauthorized, ""), + } + }, + expectedStatus: http.StatusBadGateway, + expectedKeyStates: []keypool.KeyState{keypool.KeyStatePermanent, keypool.KeyStatePermanent}, + expectedSeenKeys: []string{k0, k1}, + }, + { + // A 500 is not a key-specific failure, so it does not fail over. + name: "server_error_no_failover", + keys: []string{k0, k1}, + responses: func(testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{errResp(http.StatusInternalServerError, "")} + }, + expectedStatus: http.StatusInternalServerError, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid, keypool.KeyStateValid}, + expectedSeenKeys: []string{k0}, + }, + { + // BYOK requests carry a user key and never fail over. + name: "byok_no_failover", + byokKey: byokKey, + responses: func(testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{errResp(http.StatusTooManyRequests, "5")} + }, + expectedStatus: http.StatusTooManyRequests, + expectedRetryAfter: "5", + expectedSeenKeys: []string{byokKey}, + }, + } + + for _, ic := range interceptorCases { + for _, mode := range []string{"blocking", "streaming"} { + streaming := mode == "streaming" + for _, tc := range tests { + t.Run(ic.name+"/"+mode+"/"+tc.name, func(t *testing.T) { + t.Parallel() + + var pool *keypool.Pool + if len(tc.keys) > 0 { + var err error + pool, err = keypool.New(tc.keys, quartz.NewMock(t)) + require.NoError(t, err) + } + + fixture := fixtures.Parse(t, ic.fixture(streaming, false)) + reqBody := fixture.Request() + if streaming { + var err error + reqBody, err = sjson.SetBytes(reqBody, "stream", true) + require.NoError(t, err) + } + upstream := testutil.NewMockUpstream(t.Context(), t, tc.responses(testutil.NewFixtureResponse(fixture))...) + + interceptor := ic.newInterceptor(t, streaming, upstream.URL, reqBody, pool, tc.byokKey) + interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) + + req := httptest.NewRequest(http.MethodPost, ic.path, nil) + w := httptest.NewRecorder() + err := interceptor.ProcessRequest(w, req) + if tc.expectedStatus == http.StatusOK { + require.NoError(t, err) + } else { + require.Error(t, err) + } + + assert.Equal(t, tc.expectedStatus, w.Code, "response status code") + assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + if pool != nil { + assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") + } + + var seenKeys []string + for _, r := range upstream.ReceivedRequests() { + seenKeys = append(seenKeys, keyFromHeader(ic.authHeader, r.Header)) + } + assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") + + if len(tc.expectedSeenKeys) > 0 { + assert.Equal(t, utils.MaskSecret(tc.expectedSeenKeys[len(tc.expectedSeenKeys)-1]), + interceptor.Credential().Hint, "credential hint") + } + if tc.expectedBodyContains != "" { + assert.Contains(t, w.Body.String(), tc.expectedBodyContains, "response body") + } + }) + } + } + } +} + +// TestInterception_AgenticLoopFailover covers the scenarios that span an +// agentic-loop continuation: the initial client request and the subsequent +// tool-call continuation can each fail over independently, in both blocking and +// streaming mode. Each iteration gets its own walker. +func TestInterception_AgenticLoopFailover(t *testing.T) { + t.Parallel() + + const k0, k1 = "k0-long-key", "k1-long-key" + errResp := testutil.NewErrorResponse + + tests := []struct { + name string + keys []string + // responses builds the upstream responses in call order. toolCall is the + // tool_use response and final is the response after the tool result. + responses func(toolCall, final testutil.UpstreamResponse) []testutil.UpstreamResponse + expectedStatus int + expectedRetryAfter string + expectedKeyStates []keypool.KeyState + expectedSeenKeys []string + expectedBodyContains string + // expectErr is true when ProcessRequest returns an error because the + // pool is exhausted. + expectErr bool + }{ + { + // Both upstream calls succeed on the first key. + name: "happy_path", + keys: []string{k0, k1}, + responses: func(toolCall, final testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{toolCall, final} + }, + expectedStatus: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid, keypool.KeyStateValid}, + expectedSeenKeys: []string{k0, k0}, + }, + { + // The continuation is rate-limited on the first key and fails over + // to the second. + name: "agentic_failover_to_k1", + keys: []string{k0, k1}, + responses: func(toolCall, final testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{toolCall, errResp(http.StatusTooManyRequests, "5"), final} + }, + expectedStatus: http.StatusOK, + expectedKeyStates: []keypool.KeyState{keypool.KeyStateTemporary, keypool.KeyStateValid}, + expectedSeenKeys: []string{k0, k0, k1}, + }, + { + // The continuation is rate-limited on every key, exhausting the pool. + name: "agentic_all_keys_fail", + keys: []string{k0, k1}, + responses: func(toolCall, _ testutil.UpstreamResponse) []testutil.UpstreamResponse { + return []testutil.UpstreamResponse{ + toolCall, + errResp(http.StatusTooManyRequests, "5"), + errResp(http.StatusTooManyRequests, "3"), + } + }, + expectedStatus: http.StatusTooManyRequests, + expectedRetryAfter: "3", + expectedBodyContains: "all configured keys are rate-limited", + expectedKeyStates: []keypool.KeyState{keypool.KeyStateTemporary, keypool.KeyStateTemporary}, + expectedSeenKeys: []string{k0, k0, k1}, + expectErr: true, + }, + } + + for _, ic := range interceptorCases { + for _, mode := range []string{"blocking", "streaming"} { + streaming := mode == "streaming" + for _, tc := range tests { + t.Run(ic.name+"/"+mode+"/"+tc.name, func(t *testing.T) { + t.Parallel() + + pool, err := keypool.New(tc.keys, quartz.NewMock(t)) + require.NoError(t, err) + + fixture := fixtures.Parse(t, ic.fixture(streaming, true)) + reqBody := fixture.Request() + if streaming { + reqBody, err = sjson.SetBytes(reqBody, "stream", true) + require.NoError(t, err) + } + toolCall, final := testutil.NewFixtureResponse(fixture), testutil.NewFixtureToolResponse(fixture) + upstream := testutil.NewMockUpstream(t.Context(), t, tc.responses(toolCall, final)...) + + interceptor := ic.newInterceptor(t, streaming, upstream.URL, reqBody, pool, "") + interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, &testutil.MockServerProxier{ResolveAnyTool: true}) + + req := httptest.NewRequest(http.MethodPost, ic.path, nil) + w := httptest.NewRecorder() + err = interceptor.ProcessRequest(w, req) + if tc.expectErr { + require.Error(t, err) + } else { + require.NoError(t, err) + } + + // Once streaming has started, exhaustion is relayed as an SSE + // error event under a 200. + wantStatus, wantRetryAfter := tc.expectedStatus, tc.expectedRetryAfter + if streaming && tc.expectErr && ic.agenticStreamErrorEvent != "" { + wantStatus, wantRetryAfter = http.StatusOK, "" + } + assert.Equal(t, wantStatus, w.Code, "response status code") + assert.Equal(t, wantRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") + if streaming && tc.expectErr && ic.agenticStreamErrorEvent != "" { + assert.Contains(t, w.Body.String(), ic.agenticStreamErrorEvent, "exhaustion relayed as SSE event") + } + if streaming && !tc.expectErr { + assert.Contains(t, w.Body.String(), ic.streamDoneEvent, "final response streamed to client") + } + assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") + + var seenKeys []string + for _, r := range upstream.ReceivedRequests() { + seenKeys = append(seenKeys, keyFromHeader(ic.authHeader, r.Header)) + } + assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") + + if len(tc.expectedSeenKeys) > 0 { + assert.Equal(t, utils.MaskSecret(tc.expectedSeenKeys[len(tc.expectedSeenKeys)-1]), + interceptor.Credential().Hint, "credential hint") + } + if tc.expectedBodyContains != "" { + assert.Contains(t, w.Body.String(), tc.expectedBodyContains, "response body") + } + }) + } + } + } +} diff --git a/aibridge/intercept/messages/blocking_internal_test.go b/aibridge/intercept/messages/blocking_internal_test.go deleted file mode 100644 index e5c3c9f6ce4a2..0000000000000 --- a/aibridge/intercept/messages/blocking_internal_test.go +++ /dev/null @@ -1,479 +0,0 @@ -package messages - -import ( - "io" - "net/http" - "net/http/httptest" - "sync" - "sync/atomic" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - "cdr.dev/slog/v3" - "github.com/coder/coder/v2/aibridge/config" - "github.com/coder/coder/v2/aibridge/intercept" - "github.com/coder/coder/v2/aibridge/internal/testutil" - "github.com/coder/coder/v2/aibridge/keypool" - "github.com/coder/coder/v2/aibridge/mcp" - "github.com/coder/coder/v2/aibridge/utils" - "github.com/coder/quartz" -) - -// Common request and Anthropic-shaped response bodies. -const ( - requestBody = `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}` - successBody = `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-opus-4-5","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}` - toolUseBody = `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_01","name":"test_tool","input":{}}],"model":"claude-opus-4-5","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}` - rateLimitBody = `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}` - authErrorBody = `{"type":"error","error":{"type":"authentication_error","message":"invalid key"}}` - serverErrorBody = `{"type":"error","error":{"type":"api_error","message":"server error"}}` -) - -type upstreamResponse struct { - statusCode int - body string - headers map[string]string -} - -func TestBlockingInterception_KeyFailover(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - // Centralized pool keys. Empty when byokKey is set. - keys []string - // BYOK key. Empty when keys is set. - byokKey string - // Scripted upstream responses keyed by X-Api-Key. - responses map[string]upstreamResponse - expectedRequestCount int32 - expectedStatusCode int - expectedRetryAfter string - // Expected key states after the request, by index in keys. - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: last - // attempted key for centralized, user key from initial request for BYOK. - expectedCredentialHint string - }{ - { - // Given: 1 valid key returning 200. - // Then: 1 request, 200 response, key remains valid. - name: "single_valid_key", - keys: []string{"k0-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 returns 429, key-1 returns 200. - // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. - name: "failover_after_429", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 401, key-1 returns 200. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_401", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 403, key-1 returns 200. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_403", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s. - // Then: 3 requests, 429 response with smallest Retry-After, - // all keys temporary. - name: "all_keys_rate_limited", - keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - "k2-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "10"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "3", - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k2-long-key"), - }, - { - // Given: 2 keys; both return 401. - // Then: 2 requests, 502 api_error response, both keys permanent. - name: "all_keys_unauthorized", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusBadGateway, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStatePermanent, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 500. - // Then: 1 request, 500 response, both keys remain valid. - name: "server_error_no_failover", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusInternalServerError, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: BYOK with a single key returning 429. - // Then: 1 request, 429 response, no failover, upstream - // Retry-After propagated to the client. - name: "byok_no_failover", - byokKey: "user-byok", - responses: map[string]upstreamResponse{ - "user-byok": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{ - "Retry-After": "5", - // BYOK doesn't set MaxRetries(0); - // suppress SDK retries to test a - // single attempt. - "x-should-retry": "false", - }, - body: rateLimitBody, - }, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "5", - expectedCredentialHint: utils.MaskSecret("user-byok"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Mock upstream: counts requests and returns - // scripted responses keyed by X-Api-Key. An unmapped - // key falls through to 500 so misconfigured cases - // surface via the status assertion. - var requestCount atomic.Int32 - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount.Add(1) - _, _ = io.Copy(io.Discard, r.Body) - resp, ok := tc.responses[r.Header.Get("X-Api-Key")] - if !ok { - resp = upstreamResponse{statusCode: http.StatusInternalServerError} - } - w.Header().Set("Content-Type", "application/json") - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - cfg := config.Anthropic{BaseURL: upstream.URL + "/"} - var pool *keypool.Pool - credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") - if len(tc.keys) > 0 { - var err error - pool, err = keypool.New(tc.keys, quartz.NewMock(t)) - require.NoError(t, err) - cfg.KeyPool = pool - } else if tc.byokKey != "" { - cfg.Key = tc.byokKey - credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) - } - - payload, err := NewRequestPayload([]byte(requestBody)) - require.NoError(t, err) - - interceptor := NewBlockingInterceptor( - uuid.New(), - payload, - config.ProviderAnthropic, - cfg, - nil, - http.Header{}, - "X-Api-Key", - otel.Tracer("blocking_test"), - credInfo, - ) - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedStatusCode == http.StatusOK { - require.NoError(t, err) - } else { - require.Error(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") - assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - if pool != nil { - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - } - }) - } -} - -// TestBlockingInterception_AgenticLoopFailover covers the -// scenarios that span an agentic-loop continuation: the initial -// client request and the subsequent tool-call continuation can -// each fail over independently. Each iteration gets its own -// walker. -func TestBlockingInterception_AgenticLoopFailover(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - // Scripted upstream responses consumed in order of - // upstream request. - responses []upstreamResponse - expectedRequestCount int32 - expectedSeenKeys []string - expectedStatusCode int - expectedRetryAfter string - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: hint of the - // last attempted key across all agentic-loop iterations. - expectedCredentialHint string - }{ - { - // Given: 2 keys; both upstream calls succeed on key-0. - // Then: 2 requests, 200 response, both keys remain valid. - name: "happy_path", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, body: toolUseBody}, - {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then 429s - // during the agentic continuation, key-1 succeeds. - // Then: 3 requests, 200 response, key-0 temporary, - // key-1 valid. - name: "agentic_failover_to_k1", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, body: toolUseBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then both - // keys 429 during the agentic continuation. - // Then: 3 requests, 429 response with smallest - // Retry-After, both keys temporary. - name: "agentic_all_keys_fail", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, body: toolUseBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "3", - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - var seenKeysMu sync.Mutex - var seenKeys []string - - // Mock upstream: returns scripted responses in order, - // records each request's X-Api-Key for assertions. - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - idx := int(requestCount.Add(1)) - 1 - seenKeysMu.Lock() - seenKeys = append(seenKeys, r.Header.Get("X-Api-Key")) - seenKeysMu.Unlock() - _, _ = io.Copy(io.Discard, r.Body) - - if idx >= len(tc.responses) { - w.WriteHeader(http.StatusInternalServerError) - return - } - resp := tc.responses[idx] - w.Header().Set("Content-Type", "application/json") - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) - require.NoError(t, err) - - cfg := config.Anthropic{ - BaseURL: upstream.URL + "/", - KeyPool: pool, - } - - payload, err := NewRequestPayload([]byte(requestBody)) - require.NoError(t, err) - - interceptor := NewBlockingInterceptor( - uuid.New(), - payload, - config.ProviderAnthropic, - cfg, - nil, - http.Header{}, - "X-Api-Key", - otel.Tracer("blocking_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), - ) - - // Mock proxy with a tool the upstream's tool_use - // response will reference. - proxy := &testutil.MockServerProxier{ - Tools: []*mcp.Tool{ - { - Client: testutil.StubToolCaller{}, - ID: "test_tool", - Name: "test_tool", - ServerName: "coder", - Logger: slog.Make(), - }, - }, - } - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedStatusCode == http.StatusOK { - require.NoError(t, err) - } else { - require.Error(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") - assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - - seenKeysMu.Lock() - defer seenKeysMu.Unlock() - assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - }) - } -} diff --git a/aibridge/intercept/messages/streaming_internal_test.go b/aibridge/intercept/messages/streaming_internal_test.go deleted file mode 100644 index 4445df7bb489a..0000000000000 --- a/aibridge/intercept/messages/streaming_internal_test.go +++ /dev/null @@ -1,570 +0,0 @@ -package messages - -import ( - "io" - "net/http" - "net/http/httptest" - "sync" - "sync/atomic" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - "cdr.dev/slog/v3" - "github.com/coder/coder/v2/aibridge/config" - "github.com/coder/coder/v2/aibridge/intercept" - "github.com/coder/coder/v2/aibridge/internal/testutil" - "github.com/coder/coder/v2/aibridge/keypool" - "github.com/coder/coder/v2/aibridge/mcp" - "github.com/coder/coder/v2/aibridge/utils" - "github.com/coder/quartz" -) - -// Anthropic-shaped SSE body for a successful streaming response. -const streamingSuccessBody = `event: message_start -data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}} - -event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}} - -event: content_block_stop -data: {"type":"content_block_stop","index":0} - -event: message_delta -data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":5}} - -event: message_stop -data: {"type":"message_stop"} -` - -func TestStreamingInterception_KeyFailover(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - // Centralized pool keys. Empty when byokKey is set. - keys []string - // BYOK key. Empty when keys is set. - byokKey string - // Scripted upstream responses keyed by X-Api-Key. - responses map[string]upstreamResponse - expectedRequestCount int32 - expectedStatusCode int - expectedRetryAfter string - // Expected key states after the request, by index in keys. - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: last - // attempted key for centralized, user key from initial request for BYOK. - expectedCredentialHint string - }{ - { - // Given: 1 valid key returning a successful stream. - // Then: 1 request, 200 response, key remains valid. - name: "single_valid_key", - keys: []string{"k0-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 returns 429 pre-stream, key-1 - // streams successfully. - // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. - name: "failover_after_429", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 401 pre-stream, key-1 - // streams successfully. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_401", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 403 pre-stream, key-1 streams. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_403", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 3 keys; all return 429 pre-stream with - // cooldowns 5s, 3s, 10s. - // Then: 3 requests, 429 response with smallest - // Retry-After, all keys temporary. - name: "all_keys_rate_limited", - keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - "k2-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "10"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "3", - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k2-long-key"), - }, - { - // Given: 2 keys; both return 401 pre-stream. - // Then: 2 requests, 502 api_error response, both keys permanent. - name: "all_keys_unauthorized", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusBadGateway, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStatePermanent, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 500 pre-stream. - // Then: 1 request, 500 response, both keys remain valid. - name: "server_error_no_failover", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusInternalServerError, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: BYOK with a single key returning 429. - // Then: 1 request, 429 response, no failover, upstream - // Retry-After propagated to the client. - name: "byok_no_failover", - byokKey: "user-byok", - responses: map[string]upstreamResponse{ - "user-byok": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{ - "Retry-After": "5", - // BYOK doesn't set MaxRetries(0); - // suppress SDK retries to test a - // single attempt. - "x-should-retry": "false", - }, - body: rateLimitBody, - }, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "5", - expectedCredentialHint: utils.MaskSecret("user-byok"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Mock upstream: counts requests and returns - // scripted responses keyed by X-Api-Key. An unmapped - // key falls through to 500 so misconfigured cases - // surface via the status assertion. - var requestCount atomic.Int32 - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount.Add(1) - _, _ = io.Copy(io.Discard, r.Body) - resp, ok := tc.responses[r.Header.Get("X-Api-Key")] - if !ok { - resp = upstreamResponse{statusCode: http.StatusInternalServerError} - } - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - cfg := config.Anthropic{BaseURL: upstream.URL + "/"} - var pool *keypool.Pool - credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") - if len(tc.keys) > 0 { - var err error - pool, err = keypool.New(tc.keys, quartz.NewMock(t)) - require.NoError(t, err) - cfg.KeyPool = pool - } else if tc.byokKey != "" { - cfg.Key = tc.byokKey - credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) - } - - payload, err := NewRequestPayload([]byte(requestBody)) - require.NoError(t, err) - - interceptor := NewStreamingInterceptor( - uuid.New(), - payload, - config.ProviderAnthropic, - cfg, - nil, - http.Header{}, - "X-Api-Key", - otel.Tracer("streaming_test"), - credInfo, - ) - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedStatusCode == http.StatusOK { - require.NoError(t, err) - } else { - require.Error(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") - assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") - // No prior iteration streamed, so errors must be a - // direct HTTP response, not an SSE event. - assert.NotContains(t, w.Body.String(), "event: error", "error must not be relayed as an SSE event") - if pool != nil { - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - } - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - }) - } -} - -// SSE bodies covering an agentic-continuation flow. -const ( - // First response: a tool_use block referencing the injected - // "test_tool". Triggers the agentic continuation loop. - toolUseStreamBody = `event: message_start -data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}} - -event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"test_tool","input":{}}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{}"}} - -event: content_block_stop -data: {"type":"content_block_stop","index":0} - -event: message_delta -data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":5}} - -event: message_stop -data: {"type":"message_stop"} - -` - - // Second response (after the tool result is sent back): - // a plain text completion that ends the loop. - textStreamBody = `event: message_start -data: {"type":"message_start","message":{"id":"msg_02","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":15,"output_tokens":1}}} - -event: content_block_start -data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}} - -event: content_block_delta -data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"done"}} - -event: content_block_stop -data: {"type":"content_block_stop","index":0} - -event: message_delta -data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":3}} - -event: message_stop -data: {"type":"message_stop"} - -` -) - -// TestStreamingInterception_AgenticLoopFailover covers the -// scenarios that span an agentic-loop continuation: the initial -// client request and the subsequent tool-call continuation can -// each fail over independently. Each iteration gets its own -// walker. -func TestStreamingInterception_AgenticLoopFailover(t *testing.T) { - t.Parallel() - - sseHeaders := map[string]string{"Content-Type": "text/event-stream"} - - tests := []struct { - name string - // Scripted upstream responses consumed in order of - // upstream request. - responses []upstreamResponse - expectedRequestCount int32 - expectedSeenKeys []string - // Substring expected in the response body. Either a - // success marker (e.g. "done") or an error marker - // (e.g. "rate_limit_error"). - expectedBodyContains string - // True when the error must be relayed as an SSE event. - expectErrorAsSSEEvent bool - // True when ProcessRequest is expected to return an - // error (e.g. all keys exhausted). - expectedErr bool - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: hint of the - // last attempted key across all agentic-loop iterations. - expectedCredentialHint string - }{ - { - // Given: 2 keys; both upstream calls succeed on key-0. - // Then: 2 requests, success body, both keys remain valid. - name: "happy_path", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, - {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, - }, - expectedRequestCount: 2, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, - expectedBodyContains: "done", - expectErrorAsSSEEvent: false, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then 429s - // during the agentic continuation, key-1 succeeds. - // Then: 3 requests, success body, key-0 temporary, - // key-1 valid. - name: "agentic_failover_to_k1", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedBodyContains: "done", - expectErrorAsSSEEvent: false, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then both - // keys 429 during the agentic continuation. - // Then: 3 requests, error injected as SSE event, both - // keys temporary. - // - // Known flake: race in eventstream.IsStreaming() can - // produce a malformed response on the all-keys-exhausted - // path. See https://github.com/coder/internal/issues/1524. - name: "agentic_all_keys_fail", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedBodyContains: "all configured keys are rate-limited", - expectErrorAsSSEEvent: true, - expectedErr: true, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - var seenKeysMu sync.Mutex - var seenKeys []string - - // Mock upstream: returns scripted responses in order, - // records each request's X-Api-Key for assertions. - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - idx := int(requestCount.Add(1)) - 1 - seenKeysMu.Lock() - seenKeys = append(seenKeys, r.Header.Get("X-Api-Key")) - seenKeysMu.Unlock() - _, _ = io.Copy(io.Discard, r.Body) - - if idx >= len(tc.responses) { - w.WriteHeader(http.StatusInternalServerError) - return - } - resp := tc.responses[idx] - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) - require.NoError(t, err) - - cfg := config.Anthropic{ - BaseURL: upstream.URL + "/", - KeyPool: pool, - } - - payload, err := NewRequestPayload([]byte(requestBody)) - require.NoError(t, err) - - interceptor := NewStreamingInterceptor( - uuid.New(), - payload, - config.ProviderAnthropic, - cfg, - nil, - http.Header{}, - "X-Api-Key", - otel.Tracer("streaming_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), - ) - - // Mock proxy with a tool the upstream's tool_use event - // will reference. The stub caller returns a fixed - // text result. - proxy := &testutil.MockServerProxier{ - Tools: []*mcp.Tool{ - { - Client: testutil.StubToolCaller{}, - ID: "test_tool", - Name: "test_tool", - ServerName: "coder", - Logger: slog.Make(), - }, - }, - } - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy) - - req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - body := w.Body.String() - assert.Contains(t, body, tc.expectedBodyContains, "response body") - if tc.expectErrorAsSSEEvent { - assert.Contains(t, body, "event: error", "error must be relayed as an SSE event") - } - - seenKeysMu.Lock() - defer seenKeysMu.Unlock() - assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - }) - } -} diff --git a/aibridge/intercept/responses/base.go b/aibridge/intercept/responses/base.go index 4be82c64b0b29..426f4a279d2a6 100644 --- a/aibridge/intercept/responses/base.go +++ b/aibridge/intercept/responses/base.go @@ -437,6 +437,11 @@ func (r *responseCopier) forwardResp(w http.ResponseWriter) error { } w.Header().Set("Content-Type", r.responseHeaders.Get("Content-Type")) + // Preserve the upstream retry-after header so clients can honor it on + // rate-limited or unavailable responses. + if retryAfter := r.responseHeaders.Get("Retry-After"); retryAfter != "" { + w.Header().Set("Retry-After", retryAfter) + } w.WriteHeader(r.responseStatus) b, err := r.readAll() diff --git a/aibridge/intercept/responses/blocking_internal_test.go b/aibridge/intercept/responses/blocking_internal_test.go deleted file mode 100644 index 69e682b23c474..0000000000000 --- a/aibridge/intercept/responses/blocking_internal_test.go +++ /dev/null @@ -1,473 +0,0 @@ -package responses - -import ( - "io" - "net/http" - "net/http/httptest" - "sync" - "sync/atomic" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - "cdr.dev/slog/v3" - "github.com/coder/coder/v2/aibridge/config" - "github.com/coder/coder/v2/aibridge/intercept" - "github.com/coder/coder/v2/aibridge/internal/testutil" - "github.com/coder/coder/v2/aibridge/keypool" - "github.com/coder/coder/v2/aibridge/mcp" - "github.com/coder/coder/v2/aibridge/utils" - "github.com/coder/quartz" -) - -// OpenAI Responses API request and response bodies. -const ( - requestBody = `{"input":"hi","model":"gpt-4o-mini"}` - successBody = `{"id":"resp_01","object":"response","status":"completed","model":"gpt-4o-mini","output":[{"type":"message","id":"msg_01","role":"assistant","content":[{"type":"output_text","text":"Hello!"}]}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}` - toolUseBody = `{"id":"resp_01","object":"response","status":"completed","model":"gpt-4o-mini","output":[{"type":"function_call","id":"fc_01","call_id":"call_01","name":"test_tool","arguments":"{}","status":"completed"}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}}` - textCompleteBody = `{"id":"resp_02","object":"response","status":"completed","model":"gpt-4o-mini","output":[{"type":"message","id":"msg_02","role":"assistant","content":[{"type":"output_text","text":"done"}]}],"usage":{"input_tokens":15,"output_tokens":3,"total_tokens":18}}` - rateLimitBody = `{"error":{"message":"Rate limit exceeded","type":"rate_limit_error","code":"rate_limit_exceeded"}}` - authErrorBody = `{"error":{"message":"Invalid API key","type":"invalid_request_error","code":"invalid_api_key"}}` - serverErrorBody = `{"error":{"message":"Internal server error","type":"server_error","code":"internal_error"}}` -) - -type upstreamResponse struct { - statusCode int - body string - headers map[string]string -} - -func TestBlockingResponsesInterceptor_KeyFailover(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - // Centralized pool keys. Empty when byokKey is set. - keys []string - // BYOK key. Empty when keys is set. - byokKey string - // Scripted upstream responses keyed by bearer token. - responses map[string]upstreamResponse - expectedRequestCount int32 - expectedStatusCode int - expectedRetryAfter string - // Expected key states after the request, by index in keys. - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: last - // attempted key for centralized, user key from initial request for BYOK. - expectedCredentialHint string - }{ - { - // Given: 1 valid key returning 200. - // Then: 1 request, 200 response, key remains valid. - name: "single_valid_key", - keys: []string{"k0-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 returns 429, key-1 returns 200. - // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. - name: "failover_after_429", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 401, key-1 returns 200. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_401", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 403, key-1 returns 200. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_403", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusOK, body: successBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s. - // Then: 3 requests, 429 response with smallest Retry-After, - // all keys temporary. - name: "all_keys_rate_limited", - keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - "k2-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "10"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "3", - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k2-long-key"), - }, - { - // Given: 2 keys; both return 401. - // Then: 2 requests, 502 api_error response, both keys permanent. - name: "all_keys_unauthorized", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusBadGateway, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStatePermanent, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 500. - // Then: 1 request, 500 response, both keys remain valid. - name: "server_error_no_failover", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusInternalServerError, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: BYOK with a single key returning 429. - // Then: 1 request, 429 response, no failover. - name: "byok_no_failover", - byokKey: "user-byok", - responses: map[string]upstreamResponse{ - "user-byok": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{ - "Retry-After": "5", - // BYOK doesn't set MaxRetries(0); - // suppress SDK retries to test a - // single attempt. - "x-should-retry": "false", - }, - body: rateLimitBody, - }, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedCredentialHint: utils.MaskSecret("user-byok"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Mock upstream: counts requests and returns - // scripted responses keyed by bearer token. An - // unmapped key falls through to 500 so misconfigured - // cases surface via the status assertion. - var requestCount atomic.Int32 - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount.Add(1) - _, _ = io.Copy(io.Discard, r.Body) - resp, ok := tc.responses[utils.ExtractBearerToken(r.Header.Get("Authorization"))] - if !ok { - resp = upstreamResponse{statusCode: http.StatusInternalServerError} - } - w.Header().Set("Content-Type", "application/json") - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - cfg := config.OpenAI{BaseURL: upstream.URL + "/"} - credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") - var pool *keypool.Pool - if len(tc.keys) > 0 { - var err error - pool, err = keypool.New(tc.keys, quartz.NewMock(t)) - require.NoError(t, err) - cfg.KeyPool = pool - } else if tc.byokKey != "" { - cfg.Key = tc.byokKey - credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) - } - - payload, err := NewRequestPayload([]byte(requestBody)) - require.NoError(t, err) - - interceptor := NewBlockingInterceptor( - uuid.New(), - payload, - config.ProviderOpenAI, - cfg, - http.Header{}, - "Authorization", - otel.Tracer("blocking_test"), - credInfo, - ) - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) - - req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedStatusCode == http.StatusOK { - require.NoError(t, err) - } else { - require.Error(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") - assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - if pool != nil { - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - } - }) - } -} - -// TestBlockingResponsesInterceptor_AgenticLoopFailover covers -// the scenarios that span an agentic-loop continuation: the -// initial client request and the subsequent tool-call -// continuation can each fail over independently. Each iteration -// gets its own walker. -func TestBlockingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - // Scripted upstream responses consumed in order of - // upstream request. - responses []upstreamResponse - expectedRequestCount int32 - expectedSeenKeys []string - expectedStatusCode int - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: hint of the - // last attempted key across all agentic-loop iterations. - expectedCredentialHint string - }{ - { - // Given: 2 keys; both upstream calls succeed on key-0. - // Then: 2 requests, 200 response, both keys remain valid. - name: "happy_path", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, body: toolUseBody}, - {statusCode: http.StatusOK, body: textCompleteBody}, - }, - expectedRequestCount: 2, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then 429s - // during the agentic continuation, key-1 succeeds. - // Then: 3 requests, 200 response, key-0 temporary, - // key-1 valid. - name: "agentic_failover_to_k1", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, body: toolUseBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - {statusCode: http.StatusOK, body: textCompleteBody}, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then both - // keys 429 during the agentic continuation. - // Then: 3 requests, 429 response with smallest - // Retry-After, both keys temporary. - name: "agentic_all_keys_fail", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, body: toolUseBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedStatusCode: http.StatusTooManyRequests, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - var seenKeysMu sync.Mutex - var seenKeys []string - - // Mock upstream: returns scripted responses in order, - // records each request's bearer token for assertions. - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - idx := int(requestCount.Add(1)) - 1 - seenKeysMu.Lock() - seenKeys = append(seenKeys, utils.ExtractBearerToken(r.Header.Get("Authorization"))) - seenKeysMu.Unlock() - _, _ = io.Copy(io.Discard, r.Body) - - if idx >= len(tc.responses) { - w.WriteHeader(http.StatusInternalServerError) - return - } - resp := tc.responses[idx] - w.Header().Set("Content-Type", "application/json") - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) - require.NoError(t, err) - - cfg := config.OpenAI{ - BaseURL: upstream.URL + "/", - KeyPool: pool, - } - - payload, err := NewRequestPayload([]byte(requestBody)) - require.NoError(t, err) - - interceptor := NewBlockingInterceptor( - uuid.New(), - payload, - config.ProviderOpenAI, - cfg, - http.Header{}, - "Authorization", - otel.Tracer("blocking_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), - ) - - // Mock proxy with a tool the upstream's function_call - // response will reference. - proxy := &testutil.MockServerProxier{ - Tools: []*mcp.Tool{ - { - Client: testutil.StubToolCaller{}, - ID: "test_tool", - Name: "test_tool", - ServerName: "coder", - Logger: slog.Make(), - }, - }, - } - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy) - - req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedStatusCode == http.StatusOK { - require.NoError(t, err) - } else { - require.Error(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - - seenKeysMu.Lock() - defer seenKeysMu.Unlock() - assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - }) - } -} diff --git a/aibridge/intercept/responses/streaming_internal_test.go b/aibridge/intercept/responses/streaming_internal_test.go deleted file mode 100644 index 7c49140bfa2e6..0000000000000 --- a/aibridge/intercept/responses/streaming_internal_test.go +++ /dev/null @@ -1,520 +0,0 @@ -package responses - -import ( - "io" - "net/http" - "net/http/httptest" - "sync" - "sync/atomic" - "testing" - - "github.com/google/uuid" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel" - - "cdr.dev/slog/v3" - "github.com/coder/coder/v2/aibridge/config" - "github.com/coder/coder/v2/aibridge/intercept" - "github.com/coder/coder/v2/aibridge/internal/testutil" - "github.com/coder/coder/v2/aibridge/keypool" - "github.com/coder/coder/v2/aibridge/mcp" - "github.com/coder/coder/v2/aibridge/utils" - "github.com/coder/quartz" -) - -// Streaming request body for the OpenAI Responses API. -const streamingRequestBody = `{"input":"hi","model":"gpt-4o-mini","stream":true}` - -// OpenAI Responses API SSE body for a successful streaming response. -const streamingSuccessBody = `event: response.created -data: {"type":"response.created","response":{"id":"resp_01","object":"response","status":"in_progress"},"sequence_number":0} - -event: response.completed -data: {"type":"response.completed","response":{"id":"resp_01","object":"response","status":"completed","model":"gpt-4o-mini","output":[{"type":"message","id":"msg_01","role":"assistant","content":[{"type":"output_text","text":"Hello!"}]}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}},"sequence_number":1} - -` - -func TestStreamingResponsesInterceptor_KeyFailover(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - // Centralized pool keys. Empty when byokKey is set. - keys []string - // BYOK key. Empty when keys is set. - byokKey string - // Scripted upstream responses keyed by bearer token. - responses map[string]upstreamResponse - expectedRequestCount int32 - expectedStatusCode int - expectedRetryAfter string - // Expected key states after the request, by index in keys. - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: last - // attempted key for centralized, user key from initial request for BYOK. - expectedCredentialHint string - }{ - { - // Given: 1 valid key returning a successful stream. - // Then: 1 request, 200 response, key remains valid. - name: "single_valid_key", - keys: []string{"k0-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid}, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 returns 429 pre-stream, key-1 - // streams successfully. - // Then: 2 requests, 200 response, key-0 temporary, key-1 valid. - name: "failover_after_429", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 401 pre-stream, key-1 - // streams successfully. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_401", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 403 pre-stream, key-1 streams. - // Then: 2 requests, 200 response, key-0 permanent, key-1 valid. - name: "failover_after_403", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusForbidden, body: authErrorBody}, - "k1-long-key": { - statusCode: http.StatusOK, - headers: map[string]string{"Content-Type": "text/event-stream"}, - body: streamingSuccessBody, - }, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusOK, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 3 keys; all return 429 pre-stream with - // cooldowns 5s, 3s, 10s. - // Then: 3 requests, 429 response with smallest - // Retry-After, all keys temporary. - name: "all_keys_rate_limited", - keys: []string{"k0-long-key", "k1-long-key", "k2-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - "k1-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - "k2-long-key": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "10"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedStatusCode: http.StatusTooManyRequests, - expectedRetryAfter: "3", - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k2-long-key"), - }, - { - // Given: 2 keys; both return 401 pre-stream. - // Then: 2 requests, 502 api_error response, both keys permanent. - name: "all_keys_unauthorized", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - "k1-long-key": {statusCode: http.StatusUnauthorized, body: authErrorBody}, - }, - expectedRequestCount: 2, - expectedStatusCode: http.StatusBadGateway, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStatePermanent, - keypool.KeyStatePermanent, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 returns 500 pre-stream. - // Then: 1 request, 500 response, both keys remain valid. - name: "server_error_no_failover", - keys: []string{"k0-long-key", "k1-long-key"}, - responses: map[string]upstreamResponse{ - "k0-long-key": {statusCode: http.StatusInternalServerError, body: serverErrorBody}, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusInternalServerError, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: BYOK with a single key returning 429. - // Then: 1 request, 429 response, no failover. - name: "byok_no_failover", - byokKey: "user-byok", - responses: map[string]upstreamResponse{ - "user-byok": { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{ - "Retry-After": "5", - // BYOK doesn't set MaxRetries(0); - // suppress SDK retries to test a - // single attempt. - "x-should-retry": "false", - }, - body: rateLimitBody, - }, - }, - expectedRequestCount: 1, - expectedStatusCode: http.StatusTooManyRequests, - expectedCredentialHint: utils.MaskSecret("user-byok"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - // Mock upstream: counts requests and returns - // scripted responses keyed by bearer token. An - // unmapped key falls through to 500 so misconfigured - // cases surface via the status assertion. - var requestCount atomic.Int32 - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - requestCount.Add(1) - _, _ = io.Copy(io.Discard, r.Body) - resp, ok := tc.responses[utils.ExtractBearerToken(r.Header.Get("Authorization"))] - if !ok { - resp = upstreamResponse{statusCode: http.StatusInternalServerError} - } - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - cfg := config.OpenAI{BaseURL: upstream.URL + "/"} - credInfo := intercept.NewCredentialInfo(intercept.CredentialKindCentralized, "") - var pool *keypool.Pool - if len(tc.keys) > 0 { - var err error - pool, err = keypool.New(tc.keys, quartz.NewMock(t)) - require.NoError(t, err) - cfg.KeyPool = pool - } else if tc.byokKey != "" { - cfg.Key = tc.byokKey - credInfo = intercept.NewCredentialInfo(intercept.CredentialKindBYOK, tc.byokKey) - } - - payload, err := NewRequestPayload([]byte(streamingRequestBody)) - require.NoError(t, err) - - interceptor := NewStreamingInterceptor( - uuid.New(), - payload, - config.ProviderOpenAI, - cfg, - http.Header{}, - "Authorization", - otel.Tracer("streaming_test"), - credInfo, - ) - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil) - - req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedStatusCode == http.StatusOK { - require.NoError(t, err) - } else { - require.Error(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code") - assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header") - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - if pool != nil { - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - } - }) - } -} - -// SSE bodies covering an agentic-continuation flow. -const ( - // First response: a function_call output referencing the - // injected "test_tool". Triggers the agentic continuation - // loop. - toolUseStreamBody = `event: response.created -data: {"type":"response.created","response":{"id":"resp_01","object":"response","status":"in_progress"},"sequence_number":0} - -event: response.completed -data: {"type":"response.completed","response":{"id":"resp_01","object":"response","status":"completed","model":"gpt-4o-mini","output":[{"type":"function_call","id":"fc_01","call_id":"call_01","name":"test_tool","arguments":"{}","status":"completed"}],"usage":{"input_tokens":10,"output_tokens":5,"total_tokens":15}},"sequence_number":1} - -` - - // Second response (after the tool result is sent back): - // a plain text message that ends the loop. - textStreamBody = `event: response.created -data: {"type":"response.created","response":{"id":"resp_02","object":"response","status":"in_progress"},"sequence_number":0} - -event: response.completed -data: {"type":"response.completed","response":{"id":"resp_02","object":"response","status":"completed","model":"gpt-4o-mini","output":[{"type":"message","id":"msg_02","role":"assistant","content":[{"type":"output_text","text":"done"}]}],"usage":{"input_tokens":15,"output_tokens":3,"total_tokens":18}},"sequence_number":1} - -` -) - -// TestStreamingResponsesInterceptor_AgenticLoopFailover covers -// the scenarios that span an agentic-loop continuation: the -// initial client request and the subsequent tool-call -// continuation can each fail over independently. Each iteration -// gets its own walker. -func TestStreamingResponsesInterceptor_AgenticLoopFailover(t *testing.T) { - t.Parallel() - - sseHeaders := map[string]string{"Content-Type": "text/event-stream"} - - tests := []struct { - name string - // Scripted upstream responses consumed in order of - // upstream request. - responses []upstreamResponse - expectedRequestCount int32 - expectedSeenKeys []string - // Substring expected in the response body. Either a - // success marker (e.g. "done") or an error marker - // (e.g. "rate_limit_error"). - expectedBodyContains string - // True when ProcessRequest is expected to return an - // error (e.g. all keys exhausted). - expectedErr bool - expectedKeyStates []keypool.KeyState - // Expected credential hint after ProcessRequest: hint of the - // last attempted key across all agentic-loop iterations. - expectedCredentialHint string - }{ - { - // Given: 2 keys; both upstream calls succeed on key-0. - // Then: 2 requests, success body, both keys remain valid. - name: "happy_path", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, - {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, - }, - expectedRequestCount: 2, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key"}, - expectedBodyContains: "done", - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateValid, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k0-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then 429s - // during the agentic continuation, key-1 succeeds. - // Then: 3 requests, success body, key-0 temporary, - // key-1 valid. - name: "agentic_failover_to_k1", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - {statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody}, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedBodyContains: "done", - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateValid, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - { - // Given: 2 keys; key-0 succeeds initially, then both - // keys 429 during the agentic continuation. - // Then: 3 requests, error injected as SSE event, both - // keys temporary. - name: "agentic_all_keys_fail", - responses: []upstreamResponse{ - {statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody}, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "5"}, - body: rateLimitBody, - }, - { - statusCode: http.StatusTooManyRequests, - headers: map[string]string{"Retry-After": "3"}, - body: rateLimitBody, - }, - }, - expectedRequestCount: 3, - expectedSeenKeys: []string{"k0-long-key", "k0-long-key", "k1-long-key"}, - expectedBodyContains: "all configured keys are rate-limited", - expectedErr: true, - expectedKeyStates: []keypool.KeyState{ - keypool.KeyStateTemporary, - keypool.KeyStateTemporary, - }, - expectedCredentialHint: utils.MaskSecret("k1-long-key"), - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - var requestCount atomic.Int32 - var seenKeysMu sync.Mutex - var seenKeys []string - - // Mock upstream: returns scripted responses in order, - // records each request's bearer token for assertions. - upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - idx := int(requestCount.Add(1)) - 1 - seenKeysMu.Lock() - seenKeys = append(seenKeys, utils.ExtractBearerToken(r.Header.Get("Authorization"))) - seenKeysMu.Unlock() - _, _ = io.Copy(io.Discard, r.Body) - - if idx >= len(tc.responses) { - w.WriteHeader(http.StatusInternalServerError) - return - } - resp := tc.responses[idx] - for hk, hv := range resp.headers { - w.Header().Set(hk, hv) - } - w.WriteHeader(resp.statusCode) - _, _ = w.Write([]byte(resp.body)) - })) - t.Cleanup(upstream.Close) - - pool, err := keypool.New([]string{"k0-long-key", "k1-long-key"}, quartz.NewMock(t)) - require.NoError(t, err) - - cfg := config.OpenAI{ - BaseURL: upstream.URL + "/", - KeyPool: pool, - } - - payload, err := NewRequestPayload([]byte(streamingRequestBody)) - require.NoError(t, err) - - interceptor := NewStreamingInterceptor( - uuid.New(), - payload, - config.ProviderOpenAI, - cfg, - http.Header{}, - "Authorization", - otel.Tracer("streaming_test"), - intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""), - ) - - // Mock proxy with a tool the upstream's function_call - // response will reference. The stub caller returns a - // fixed text result. - proxy := &testutil.MockServerProxier{ - Tools: []*mcp.Tool{ - { - Client: testutil.StubToolCaller{}, - ID: "test_tool", - Name: "test_tool", - ServerName: "coder", - Logger: slog.Make(), - }, - }, - } - interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy) - - req := httptest.NewRequest(http.MethodPost, "/v1/responses", nil) - w := httptest.NewRecorder() - err = interceptor.ProcessRequest(w, req) - if tc.expectedErr { - require.Error(t, err) - } else { - require.NoError(t, err) - } - - assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count") - body := w.Body.String() - assert.Contains(t, body, tc.expectedBodyContains, "response body") - assert.Equal(t, tc.expectedCredentialHint, interceptor.Credential().Hint, "credential hint") - - seenKeysMu.Lock() - defer seenKeysMu.Unlock() - assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys") - assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states") - }) - } -} diff --git a/aibridge/internal/testutil/mockserverproxier.go b/aibridge/internal/testutil/mockserverproxier.go index 04f78330b3a46..b962e825e7459 100644 --- a/aibridge/internal/testutil/mockserverproxier.go +++ b/aibridge/internal/testutil/mockserverproxier.go @@ -5,13 +5,21 @@ import ( mcpgo "github.com/mark3labs/mcp-go/mcp" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/aibridge/mcp" ) // MockServerProxier is a test [mcp.ServerProxier] that injects a fixed set of -// tools. +// tools. When ResolveAnyTool is set, GetTool resolves any unregistered tool to a +// stub, so callers that only need the tool loop to proceed need not register +// each tool the fixture might call. type MockServerProxier struct { Tools []*mcp.Tool + // ResolveAnyTool makes GetTool return a stub tool, backed by a + // StubToolCaller, for any id not present in Tools. Use it to exercise + // injected-tool agentic loops where the test does not need to validate which + // tool was called. + ResolveAnyTool bool } func (*MockServerProxier) Init(context.Context) error { @@ -32,6 +40,15 @@ func (m *MockServerProxier) GetTool(id string) *mcp.Tool { return t } } + if m.ResolveAnyTool { + return &mcp.Tool{ + Client: StubToolCaller{}, + ID: id, + Name: id, + ServerName: "coder", + Logger: slog.Make(), + } + } return nil } diff --git a/aibridge/internal/testutil/mockupstream.go b/aibridge/internal/testutil/mockupstream.go index 6ecf5b0df1a24..242bfde34574d 100644 --- a/aibridge/internal/testutil/mockupstream.go +++ b/aibridge/internal/testutil/mockupstream.go @@ -65,6 +65,24 @@ func NewFixtureToolResponse(fix fixtures.Fixture) UpstreamResponse { return resp } +// NewErrorResponse returns an UpstreamResponse that replays a raw HTTP error +// response with the given status code and optional Retry-After header. SDK +// auto-retries are disabled via x-should-retry. +func NewErrorResponse(status int, retryAfter string) UpstreamResponse { + body := fmt.Sprintf(`{"error":{"message":%q}}`, http.StatusText(status)) + + raw := fmt.Sprintf("HTTP/1.1 %d %s\r\n", status, http.StatusText(status)) + if retryAfter != "" { + raw += fmt.Sprintf("Retry-After: %s\r\n", retryAfter) + } + raw += "x-should-retry: false\r\n" + raw += "Content-Type: application/json\r\n" + raw += fmt.Sprintf("Content-Length: %d\r\n\r\n%s", len(body), body) + + rawBytes := []byte(raw) + return UpstreamResponse{Streaming: rawBytes, Blocking: rawBytes} +} + // ReceivedRequest captures the details of a single request handled by MockUpstream. type ReceivedRequest struct { Method string