diff --git a/aibridge/internal/testutil/mockprovider.go b/aibridge/internal/testutil/mockprovider.go index 0c56cf2c9eee9..0fd85d2863637 100644 --- a/aibridge/internal/testutil/mockprovider.go +++ b/aibridge/internal/testutil/mockprovider.go @@ -13,12 +13,11 @@ import ( ) type MockProvider struct { - NameStr string - URL string - Bridged []string - Passthrough []string - InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) - InjectAuthHeaderFunc func(h *http.Header) + NameStr string + URL string + Bridged []string + Passthrough []string + InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) } func (m *MockProvider) Type() string { return m.NameStr } @@ -28,11 +27,6 @@ func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", func (m *MockProvider) BridgedRoutes() []string { return m.Bridged } func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough } func (*MockProvider) AuthHeader() string { return "Authorization" } -func (m *MockProvider) InjectAuthHeader(h *http.Header) { - if m.InjectAuthHeaderFunc != nil { - m.InjectAuthHeaderFunc(h) - } -} func (*MockProvider) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig { return keypool.KeyFailoverConfig{} diff --git a/aibridge/passthrough_test.go b/aibridge/passthrough_test.go index 6679feb932528..246becc0d903e 100644 --- a/aibridge/passthrough_test.go +++ b/aibridge/passthrough_test.go @@ -321,6 +321,7 @@ func TestPassthrough_KeyFailover(t *testing.T) { extractKey func(*http.Request) string setBYOK func(*http.Request, string) newProvider func(baseURL string, pool *keypool.Pool) provider.Provider + byokOnly bool }{ { name: "anthropic", @@ -353,6 +354,22 @@ func TestPassthrough_KeyFailover(t *testing.T) { return provider.NewOpenAI(cfg) }, }, + { + // Copilot is always BYOK and returns an empty KeyFailoverConfig, + // which makes the KeyFailoverTransport short-circuit. Only the + // BYOK scenario applies, centralized-pool cases are skipped below. + name: "copilot", + extractKey: func(r *http.Request) string { + return strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + }, + setBYOK: func(r *http.Request, key string) { + r.Header.Set("Authorization", "Bearer "+key) + }, + newProvider: func(baseURL string, _ *keypool.Pool) provider.Provider { + return provider.NewCopilot(config.Copilot{BaseURL: baseURL}) + }, + byokOnly: true, + }, } tests := []struct { @@ -516,6 +533,10 @@ func TestPassthrough_KeyFailover(t *testing.T) { for _, prov := range providers { for _, tc := range tests { + // BYOK-only providers do not exercise centralized-pool scenarios. + if prov.byokOnly && tc.byokKey == "" { + continue + } t.Run(prov.name+"/"+tc.name, func(t *testing.T) { t.Parallel() diff --git a/aibridge/provider/anthropic.go b/aibridge/provider/anthropic.go index 64128cee5fab9..c03ba19762f75 100644 --- a/aibridge/provider/anthropic.go +++ b/aibridge/provider/anthropic.go @@ -198,29 +198,6 @@ func (*Anthropic) AuthHeader() string { return "X-Api-Key" } -func (p *Anthropic) InjectAuthHeader(headers *http.Header) { - if headers == nil { - headers = &http.Header{} - } - - // BYOK: if the request already carries user-supplied credentials, - // do not overwrite them with the centralized key. - if headers.Get("X-Api-Key") != "" || headers.Get("Authorization") != "" { - return - } - - // Centralized: pull a single key from the pool. No failover - // or exhaustion handling here. - // TODO(ssncferreira): replace with RoundTripper-based auth - // in the upstack passthrough PR. - if p.cfg.KeyPool == nil { - return - } - if key, err := p.cfg.KeyPool.Walker().Next(); err == nil { - headers.Set(p.AuthHeader(), key.Value()) - } -} - func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig { name := p.Name() return keypool.KeyFailoverConfig{ diff --git a/aibridge/provider/anthropic_test.go b/aibridge/provider/anthropic_test.go index 7ebf4495d9b3d..3c7378ae71081 100644 --- a/aibridge/provider/anthropic_test.go +++ b/aibridge/provider/anthropic_test.go @@ -317,60 +317,6 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) { } } -func TestAnthropic_InjectAuthHeader(t *testing.T) { - t.Parallel() - - provider := NewAnthropic(config.Anthropic{Key: "centralized-key"}, nil) - - tests := []struct { - name string - presetHeaders map[string]string - wantXApiKey string - wantAuthorization string - }{ - { - name: "when no auth headers are provided, inject centralized key", - presetHeaders: map[string]string{}, - wantXApiKey: "centralized-key", - }, - { - name: "when X-Api-Key header is provided, use it", - presetHeaders: map[string]string{"X-Api-Key": "user-api-key"}, - wantXApiKey: "user-api-key", - }, - { - name: "when Authorization header is provided, use it", - presetHeaders: map[string]string{"Authorization": "Bearer user-access-token"}, - wantAuthorization: "Bearer user-access-token", - }, - { - name: "when both headers are provided, keep both", - presetHeaders: map[string]string{ - "Authorization": "Bearer user-access-token", - "X-Api-Key": "user-api-key", - }, - wantXApiKey: "user-api-key", - wantAuthorization: "Bearer user-access-token", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - headers := http.Header{} - for k, v := range tc.presetHeaders { - headers.Set(k, v) - } - - provider.InjectAuthHeader(&headers) - - assert.Equal(t, tc.wantXApiKey, headers.Get("X-Api-Key")) - assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization")) - }) - } -} - func TestExtractAnthropicHeaders(t *testing.T) { t.Parallel() diff --git a/aibridge/provider/copilot.go b/aibridge/provider/copilot.go index b68513ecec84a..1186e8b253f6d 100644 --- a/aibridge/provider/copilot.go +++ b/aibridge/provider/copilot.go @@ -107,12 +107,6 @@ func (*Copilot) AuthHeader() string { return "Authorization" } -// InjectAuthHeader is a no-op for Copilot. -// Copilot uses per-user tokens passed in the original Authorization header, -// rather than a global key configured at the provider level. -// The original Authorization header flows through untouched from the client. -func (*Copilot) InjectAuthHeader(_ *http.Header) {} - // KeyFailoverConfig returns a config with a nil Pool, which makes // the KeyFailoverTransport short-circuit. Copilot is always BYOK. func (*Copilot) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig { diff --git a/aibridge/provider/copilot_test.go b/aibridge/provider/copilot_test.go index cd30a833500d8..3d68881379dc3 100644 --- a/aibridge/provider/copilot_test.go +++ b/aibridge/provider/copilot_test.go @@ -51,39 +51,6 @@ func TestCopilot_TypeAndName(t *testing.T) { } } -func TestCopilot_InjectAuthHeader(t *testing.T) { - t.Parallel() - - // Copilot uses per-user key passed in the Authorization header, - // so InjectAuthHeader should not modify any headers. - provider := NewCopilot(config.Copilot{}) - - t.Run("ExistingHeaders_Unchanged", func(t *testing.T) { - t.Parallel() - - headers := http.Header{} - headers.Set("Authorization", "Bearer user-token") - headers.Set("X-Custom-Header", "custom-value") - - provider.InjectAuthHeader(&headers) - - assert.Equal(t, "Bearer user-token", headers.Get("Authorization"), - "Authorization header should remain unchanged") - assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"), - "other headers should remain unchanged") - }) - - t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) { - t.Parallel() - - headers := http.Header{} - - provider.InjectAuthHeader(&headers) - - assert.Empty(t, headers, "no headers should be added") - }) -} - func TestCopilot_CreateInterceptor(t *testing.T) { t.Parallel() diff --git a/aibridge/provider/openai.go b/aibridge/provider/openai.go index 894dda7194106..3f6f131de6a37 100644 --- a/aibridge/provider/openai.go +++ b/aibridge/provider/openai.go @@ -197,29 +197,6 @@ func (*OpenAI) AuthHeader() string { return "Authorization" } -func (p *OpenAI) InjectAuthHeader(headers *http.Header) { - if headers == nil { - headers = &http.Header{} - } - - // BYOK: if the request already carries user-supplied credentials, - // do not overwrite them with the centralized key. - if headers.Get("Authorization") != "" { - return - } - - // Centralized: pull a single key from the pool. No failover - // or exhaustion handling here. - // TODO(ssncferreira): replace with RoundTripper-based auth - // in the upstack passthrough PR. - if p.cfg.KeyPool == nil { - return - } - if key, err := p.cfg.KeyPool.Walker().Next(); err == nil { - headers.Set(p.AuthHeader(), "Bearer "+key.Value()) - } -} - func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig { name := p.Name() return keypool.KeyFailoverConfig{ diff --git a/aibridge/provider/openai_test.go b/aibridge/provider/openai_test.go index d739a2dc20082..6aaa0b56441a5 100644 --- a/aibridge/provider/openai_test.go +++ b/aibridge/provider/openai_test.go @@ -325,44 +325,6 @@ func TestOpenAI_CreateInterceptor(t *testing.T) { } } -func TestOpenAI_InjectAuthHeader(t *testing.T) { - t.Parallel() - - provider := NewOpenAI(config.OpenAI{Key: "centralized-key"}) - - tests := []struct { - name string - presetHeaders map[string]string - wantAuthorization string - }{ - { - name: "when no Authorization header is provided, inject centralized key", - presetHeaders: map[string]string{}, - wantAuthorization: "Bearer centralized-key", - }, - { - name: "when Authorization header is provided, do not overwrite it", - presetHeaders: map[string]string{"Authorization": "Bearer user-token"}, - wantAuthorization: "Bearer user-token", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - headers := http.Header{} - for k, v := range tc.presetHeaders { - headers.Set(k, v) - } - - provider.InjectAuthHeader(&headers) - - assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization")) - }) - } -} - func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) { provider := NewOpenAI(config.OpenAI{ BaseURL: "https://api.openai.com/v1/", diff --git a/aibridge/provider/provider.go b/aibridge/provider/provider.go index 587dfd85ce014..7520333b53b61 100644 --- a/aibridge/provider/provider.go +++ b/aibridge/provider/provider.go @@ -77,10 +77,6 @@ type Provider interface { // AuthHeader returns the name of the header which the provider expects to find its authentication // token in. AuthHeader() string - // InjectAuthHeader allows [Provider]s to set its authentication header. - // TODO(ssncferreira): remove. Auth is now applied per-attempt by - // KeyFailoverTransport (see [Provider.KeyFailoverConfig]). - InjectAuthHeader(*http.Header) // KeyFailoverConfig returns the per-provider configuration for // automatic key failover on passthrough routes. KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig