Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 1 addition & 76 deletions aibridge/intercept/chatcompletions/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,10 @@ import (
"net/http"
"strconv"
"strings"
"time"

"github.com/google/uuid"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/shared"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"

Expand All @@ -27,7 +25,6 @@ import (
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
"github.com/coder/coder/v2/aibridge/utils"
"github.com/coder/quartz"
)

Expand Down Expand Up @@ -189,7 +186,7 @@ func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) {
}

// writeUpstreamError marshals and writes a given error.
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) {
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *intercept.ResponseError) {
if oaiErr == nil {
return
}
Expand Down Expand Up @@ -235,33 +232,6 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
)
}

// ProcessKeyPoolError maps a key pool exhaustion error onto an
// OpenAI-shaped *ResponseError that can be relayed to the caller.
//
// An error chain containing a *keypool.TransientKeyPoolError yields a
// 429 rate-limit error carrying that error's RetryAfter. An error
// matching keypool.ErrPermanentKeyPool yields a 502 API error with a
// zero retry delay. Any other error yields nil.
func ProcessKeyPoolError(err error) *ResponseError {
	// The transient case is evaluated before the permanent sentinel.
	var rateLimited *keypool.TransientKeyPoolError
	if errors.As(err, &rateLimited) {
		return newErrorResponse(
			"all configured keys are rate-limited",
			intercept.OpenAIErrTypeRateLimit,
			intercept.OpenAIErrCodeRateLimit,
			http.StatusTooManyRequests,
			rateLimited.RetryAfter,
		)
	}

	if errors.Is(err, keypool.ErrPermanentKeyPool) {
		return newErrorResponse(
			"all configured keys failed authentication",
			intercept.OpenAIErrTypeAPI,
			intercept.OpenAIErrCodeServer,
			http.StatusBadGateway,
			0,
		)
	}

	// Not a pool-exhaustion error.
	return nil
}

// hasInjectableTools reports whether an MCP proxy is set on the
// interception and that proxy currently lists at least one tool.
func (i *interceptionBase) hasInjectableTools() bool {
	// Without a proxy there is nothing to list.
	if i.mcpProxy == nil {
		return false
	}
	return len(i.mcpProxy.ListTools()) > 0
}
Expand Down Expand Up @@ -292,48 +262,3 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
return in.PromptTokens /* The aggregated number of text input tokens used, including cached tokens. */ -
in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */
}

// getErrorResponse converts err into a *ResponseError when its chain
// contains an *openai.Error. The upstream message, type, code, and
// status code are copied over, and the retry delay is whatever
// keypool.ParseRetryAfter derives from the upstream response. It
// returns nil for any error that is not an *openai.Error.
func getErrorResponse(err error) *ResponseError {
	var upstream *openai.Error
	if errors.As(err, &upstream) {
		retryAfter := keypool.ParseRetryAfter(upstream.Response)
		return newErrorResponse(upstream.Message, upstream.Type, upstream.Code, upstream.StatusCode, retryAfter)
	}
	return nil
}

// Compile-time check that *ResponseError implements error.
var _ error = (*ResponseError)(nil)

// ResponseError is an error value that also carries what is needed to
// relay it as an HTTP response: a JSON-serializable error object, a
// status code, and a retry delay.
type ResponseError struct {
	// ErrorObject is serialized under the "error" key of the JSON body.
	ErrorObject *shared.ErrorObject `json:"error"`
	// StatusCode is excluded from the JSON body.
	StatusCode int `json:"-"`
	// RetryAfter is excluded from the JSON body.
	RetryAfter time.Duration `json:"-"`
}

// newErrorResponse builds a *ResponseError whose error object holds
// msg, errType, and code, and whose StatusCode and RetryAfter are set
// from status and retryAfter.
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
	detail := &shared.ErrorObject{
		Message: msg,
		Type:    errType,
		Code:    code,
	}
	return &ResponseError{
		ErrorObject: detail,
		RetryAfter:  retryAfter,
		StatusCode:  status,
	}
}

// Error implements the error interface. It returns the message of the
// wrapped error object, or the empty string when no error object is
// set.
func (e *ResponseError) Error() string {
	if obj := e.ErrorObject; obj != nil {
		return obj.Message
	}
	return ""
}

// ToResponse serializes e to JSON and hands the body, together with
// e.StatusCode and e.RetryAfter, to utils.NewJSONErrorResponse to
// build the *http.Response. If marshaling fails, a fixed fallback
// error body is used instead.
func (e *ResponseError) ToResponse() *http.Response {
	// Body sent when e itself cannot be marshaled.
	const fallbackBody = `{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`

	payload, marshalErr := json.Marshal(e)
	if marshalErr != nil {
		payload = []byte(fallbackBody)
	}
	return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, payload)
}
72 changes: 10 additions & 62 deletions aibridge/intercept/chatcompletions/base_internal_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/utils"
"github.com/coder/quartz"
Expand Down Expand Up @@ -86,59 +87,6 @@ func TestScanForCorrelatingToolCallID(t *testing.T) {
}
}

// TestProcessKeyPoolError verifies how ProcessKeyPoolError maps key
// pool errors to a response status and retry delay, and that an
// unrelated error produces a nil result.
func TestProcessKeyPoolError(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name           string
		err            error
		wantNil        bool
		wantStatus     int
		wantRetryAfter time.Duration
	}

	cases := []testCase{
		{
			// Transient error with a zero RetryAfter: 429 and a zero retry delay.
			name:           "transient_zero_retry_after",
			err:            &keypool.TransientKeyPoolError{},
			wantStatus:     http.StatusTooManyRequests,
			wantRetryAfter: 0,
		},
		{
			// Transient error with a RetryAfter: 429 and the delay is propagated.
			name:           "transient_with_retry_after",
			err:            &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second},
			wantStatus:     http.StatusTooManyRequests,
			wantRetryAfter: 5 * time.Second,
		},
		{
			// Permanent exhaustion: 502.
			name:       "permanent_returns_502",
			err:        keypool.ErrPermanentKeyPool,
			wantStatus: http.StatusBadGateway,
		},
		{
			// An unrelated error is not translated.
			name:    "non_pool_exhaustion_error_returns_nil",
			err:     xerrors.New("some other error"),
			wantNil: true,
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			resp := ProcessKeyPoolError(tt.err)
			if tt.wantNil {
				require.Nil(t, resp)
				return
			}
			require.NotNil(t, resp)
			assert.Equal(t, tt.wantStatus, resp.StatusCode)
			assert.Equal(t, tt.wantRetryAfter, resp.RetryAfter)
		})
	}
}

func TestMarkKeyOnError(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -190,8 +138,8 @@ func TestMarkKeyOnError(t *testing.T) {
t.Parallel()
pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t))
require.NoError(t, err)
key, err := pool.Walker().Next()
require.NoError(t, err)
key, keyPoolErr := pool.Walker().Next()
require.Nil(t, keyPoolErr)

base := &interceptionBase{cfg: config.OpenAI{KeyPool: pool}, logger: slog.Make()}

Expand All @@ -207,7 +155,7 @@ func TestWriteUpstreamError(t *testing.T) {

tests := []struct {
name string
respErr *ResponseError
respErr *intercept.ResponseError
expectStatus int
// Empty string means the header should be absent.
expectRetryAfter string
Expand All @@ -217,42 +165,42 @@ func TestWriteUpstreamError(t *testing.T) {
{
// Standard error: status, code, and JSON body written.
name: "writes_status_and_body",
respErr: newErrorResponse("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0),
respErr: intercept.NewResponseError("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0),
expectStatus: http.StatusBadGateway,
expectBodyContains: `"upstream failed"`,
},
{
// OpenAI envelope: the code field round-trips into the body.
name: "writes_code_field",
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0),
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0),
expectStatus: http.StatusTooManyRequests,
expectBodyContains: `"rate_limit_exceeded"`,
},
{
// Whole-second retryAfter: emitted as integer seconds.
name: "retry_after_in_seconds",
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second),
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second),
expectStatus: http.StatusTooManyRequests,
expectRetryAfter: "60",
},
{
// 500ms rounds up to Retry-After: 1.
name: "retry_after_500ms_rounds_up_to_one",
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond),
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond),
expectStatus: http.StatusTooManyRequests,
expectRetryAfter: "1",
},
{
// 200ms rounds up to Retry-After: 1.
name: "retry_after_200ms_rounds_up_to_one",
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond),
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond),
expectStatus: http.StatusTooManyRequests,
expectRetryAfter: "1",
},
{
// Negative retryAfter: header omitted.
name: "negative_retry_after_omits_header",
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second),
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second),
expectStatus: http.StatusTooManyRequests,
expectRetryAfter: "",
},
Expand Down
15 changes: 9 additions & 6 deletions aibridge/intercept/chatcompletions/blocking.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package chatcompletions
import (
"context"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
Expand All @@ -19,6 +20,7 @@ import (
aibcontext "github.com/coder/coder/v2/aibridge/context"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/eventstream"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
Expand Down Expand Up @@ -224,12 +226,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req

// The failover loop may return a keypool exhaustion
// error. Check before the SDK-error path.
if keyErr := ProcessKeyPoolError(err); keyErr != nil {
i.writeUpstreamError(w, keyErr)
var keyPoolErr *keypool.Error
if errors.As(err, &keyPoolErr) {
i.writeUpstreamError(w, intercept.ResponseErrorFromKeyPool(keyPoolErr))
return xerrors.Errorf("key pool exhausted: %w", err)
}

if apiErr := getErrorResponse(err); apiErr != nil {
if apiErr := intercept.ResponseErrorFromAPIError(err); apiErr != nil {
i.writeUpstreamError(w, apiErr)
return xerrors.Errorf("openai API error: %w", err)
}
Expand Down Expand Up @@ -293,9 +296,9 @@ func (i *BlockingInterception) newChatCompletionWithKeyFailover(ctx context.Cont
// success, the last tried key on failure) in the upstack PR.
walker := i.cfg.KeyPool.Walker()
for {
key, err := walker.Next()
if err != nil {
return nil, err
key, keyPoolErr := walker.Next()
if keyPoolErr != nil {
return nil, keyPoolErr
}

requestOpts := append([]option.RequestOption{}, opts...)
Expand Down
15 changes: 8 additions & 7 deletions aibridge/intercept/chatcompletions/streaming.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,9 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
var opts []option.RequestOption
var currentKey *keypool.Key
if walker != nil {
key, err := walker.Next()
if respErr := ProcessKeyPoolError(err); respErr != nil {
key, keyPoolErr := walker.Next()
if keyPoolErr != nil {
respErr := intercept.ResponseErrorFromKeyPool(keyPoolErr)
// Pool exhausted in this iteration. Relay the
// error to the client: as an SSE event if events
// have already been sent, or by direct write
Expand Down Expand Up @@ -470,17 +471,17 @@ func (i *StreamingInterception) newStream(ctx context.Context, svc openai.ChatCo
}

// mapStreamError converts a mid-stream upstream error or
// processing error into a relayable responseError. Returns nil
// processing error into a relayable ResponseError. Returns nil
// when the error is unrecoverable, in which case nothing can be
// relayed back.
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *ResponseError {
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *intercept.ResponseError {
if streamErr != nil {
if eventstream.IsUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
// We can't reflect an error back if there's a connection error or the request context was canceled.
return nil
}
if oaiErr := getErrorResponse(streamErr); oaiErr != nil {
if oaiErr := intercept.ResponseErrorFromAPIError(streamErr); oaiErr != nil {
logger.Warn(ctx, "openai stream error", slog.Error(streamErr))
return oaiErr
}
Expand All @@ -489,11 +490,11 @@ func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Lo
// into known types (i.e. [shared.OverloadedError]).
// See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171
// All it does is wrap the payload in an error - which is all we can return, currently.
return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0)
return intercept.NewResponseError(fmt.Sprintf("unknown stream error: %s", streamErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0)
}
if lastErr != nil {
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0)
return intercept.NewResponseError(fmt.Sprintf("processing error: %s", lastErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0)
}
return nil
}
Expand Down
Loading
Loading