@@ -17,7 +17,6 @@ import (
1717 "go.opentelemetry.io/otel/trace"
1818
1919 "cdr.dev/slog/v3"
20- "github.com/coder/coder/v2/aibridge/config"
2120 aibcontext "github.com/coder/coder/v2/aibridge/context"
2221 "github.com/coder/coder/v2/aibridge/intercept"
2322 "github.com/coder/coder/v2/aibridge/intercept/apidump"
@@ -29,56 +28,46 @@ import (
2928)
3029
3130type interceptionBase struct {
32- id uuid.UUID
33- providerName string
34- req * ChatCompletionNewParamsWrapper
35- cfg config.OpenAI
31+ id uuid.UUID
32+ req * ChatCompletionNewParamsWrapper
33+
34+ cfg intercept.Config
35+ cred intercept.Credential
3636
3737 // clientHeaders are the original HTTP headers from the client request.
38- clientHeaders http.Header
39- authHeaderName string
38+ clientHeaders http.Header
4039
4140 logger slog.Logger
4241 tracer trace.Tracer
4342
44- recorder recorder.Recorder
45- mcpProxy mcp.ServerProxier
46- credential intercept.CredentialInfo
43+ recorder recorder.Recorder
44+ mcpProxy mcp.ServerProxier
4745}
4846
4947// newCompletionsService builds the SDK service used for upstream
5048// calls. BYOK auth is set here. Centralized auth is set
5149// per-attempt by the failover loop.
5250func (i * interceptionBase ) newCompletionsService () openai.ChatCompletionService {
53- // TODO(ssncferreira): validate auth is configured per
54- // https://github.com/coder/aibridge/issues/266.
55-
5651 var opts []option.RequestOption
57- // BYOK auth.
58- if i .cfg .KeyPool == nil {
59- opts = append (opts , option .WithAPIKey (i .cfg .Key ))
52+ // BYOK sets its key here; centralized injects per-attempt in the failover
53+ // loop. The OpenAI SDK presents the key as an Authorization bearer.
54+ if byok , ok := intercept .AsBYOK (i .cred ); ok {
55+ opts = append (opts , option .WithAPIKey (byok .Secret ))
6056 }
6157 opts = append (opts , option .WithBaseURL (i .cfg .BaseURL ))
6258
63- // Add extra headers if configured.
64- // Some providers require additional headers that are not added by the SDK.
65- // TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192
66- for key , value := range i .cfg .ExtraHeaders {
67- opts = append (opts , option .WithHeader (key , value ))
68- }
69-
7059 // Forward client headers to upstream. This middleware runs after the SDK
7160 // has built the request, and replaces the outgoing headers with the sanitized
7261 // client headers plus provider auth.
7362 if i .clientHeaders != nil {
7463 opts = append (opts , option .WithMiddleware (func (req * http.Request , next option.MiddlewareNext ) (* http.Response , error ) {
75- req .Header = intercept .BuildUpstreamHeaders (req .Header , i .clientHeaders , i .authHeaderName )
64+ req .Header = intercept .BuildUpstreamHeaders (req .Header , i .clientHeaders , i .cred . AuthHeader () )
7665 return next (req )
7766 }))
7867 }
7968
8069 // Add API dump middleware if configured
81- if mw := apidump .NewBridgeMiddleware (i .cfg .APIDumpDir , i .providerName , i .Model (), i .id , i .logger , quartz .NewReal ()); mw != nil {
70+ if mw := apidump .NewBridgeMiddleware (i .cfg .APIDumpDir , i .cfg . ProviderName , i .Model (), i .id , i .logger , quartz .NewReal ()); mw != nil {
8271 opts = append (opts , option .WithMiddleware (mw ))
8372 }
8473
@@ -89,8 +78,8 @@ func (i *interceptionBase) ID() uuid.UUID {
8978 return i .id
9079}
9180
92- func (i * interceptionBase ) Credential () intercept.CredentialInfo {
93- return i .credential
81+ func (i * interceptionBase ) Credential () intercept.Credential {
82+ return i .cred
9483}
9584
9685func (i * interceptionBase ) Setup (logger slog.Logger , rec recorder.Recorder , mcpProxy mcp.ServerProxier ) {
@@ -117,7 +106,7 @@ func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool)
117106 attribute .String (tracing .RequestPath , r .URL .Path ),
118107 attribute .String (tracing .InterceptionID , i .id .String ()),
119108 attribute .String (tracing .InitiatorID , aibcontext .ActorIDFromContext (r .Context ())),
120- attribute .String (tracing .Provider , i .providerName ),
109+ attribute .String (tracing .Provider , i .cfg . ProviderName ),
121110 attribute .String (tracing .Model , i .Model ()),
122111 attribute .Bool (tracing .Streaming , streaming ),
123112 }
@@ -219,14 +208,15 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *int
219208// code. Returns true if the status was a key-specific failover
220209// trigger so callers can retry with the next key.
221210func (i * interceptionBase ) markKeyOnError (ctx context.Context , key * keypool.Key , err error ) bool {
222- if i .cfg .KeyPool == nil {
211+ centralized , ok := intercept .AsCentralized (i .cred )
212+ if ! ok {
223213 return false
224214 }
225215 var apiErr * openai.Error
226216 if ! errors .As (err , & apiErr ) {
227217 return false
228218 }
229- return i . cfg . KeyPool .MarkKeyOnStatus (
219+ return centralized . Pool .MarkKeyOnStatus (
230220 ctx , key , apiErr .Response , i .logger ,
231221 )
232222}
0 commit comments