Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions aibridge/internal/testutil/mockprovider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@ import (
)

type MockProvider struct {
NameStr string
URL string
Bridged []string
Passthrough []string
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
InjectAuthHeaderFunc func(h *http.Header)
NameStr string
URL string
Bridged []string
Passthrough []string
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
}

func (m *MockProvider) Type() string { return m.NameStr }
Expand All @@ -28,11 +27,6 @@ func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s",
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
func (*MockProvider) AuthHeader() string { return "Authorization" }
func (m *MockProvider) InjectAuthHeader(h *http.Header) {
if m.InjectAuthHeaderFunc != nil {
m.InjectAuthHeaderFunc(h)
}
}

func (*MockProvider) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
return keypool.KeyFailoverConfig{}
Expand Down
21 changes: 21 additions & 0 deletions aibridge/passthrough_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ func TestPassthrough_KeyFailover(t *testing.T) {
extractKey func(*http.Request) string
setBYOK func(*http.Request, string)
newProvider func(baseURL string, pool *keypool.Pool) provider.Provider
byokOnly bool
}{
{
name: "anthropic",
Expand Down Expand Up @@ -353,6 +354,22 @@ func TestPassthrough_KeyFailover(t *testing.T) {
return provider.NewOpenAI(cfg)
},
},
{
// Copilot is always BYOK and returns an empty KeyFailoverConfig,
// which makes the KeyFailoverTransport short-circuit. Only the
// BYOK scenario applies, centralized-pool cases are skipped below.
name: "copilot",
extractKey: func(r *http.Request) string {
return strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
},
setBYOK: func(r *http.Request, key string) {
r.Header.Set("Authorization", "Bearer "+key)
},
newProvider: func(baseURL string, _ *keypool.Pool) provider.Provider {
return provider.NewCopilot(config.Copilot{BaseURL: baseURL})
},
byokOnly: true,
},
}

tests := []struct {
Expand Down Expand Up @@ -516,6 +533,10 @@ func TestPassthrough_KeyFailover(t *testing.T) {

for _, prov := range providers {
for _, tc := range tests {
// BYOK-only providers do not exercise centralized-pool scenarios.
if prov.byokOnly && tc.byokKey == "" {
continue
}
t.Run(prov.name+"/"+tc.name, func(t *testing.T) {
t.Parallel()

Expand Down
23 changes: 0 additions & 23 deletions aibridge/provider/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,29 +198,6 @@ func (*Anthropic) AuthHeader() string {
return "X-Api-Key"
}

func (p *Anthropic) InjectAuthHeader(headers *http.Header) {
if headers == nil {
headers = &http.Header{}
}

// BYOK: if the request already carries user-supplied credentials,
// do not overwrite them with the centralized key.
if headers.Get("X-Api-Key") != "" || headers.Get("Authorization") != "" {
return
}

// Centralized: pull a single key from the pool. No failover
// or exhaustion handling here.
// TODO(ssncferreira): replace with RoundTripper-based auth
// in the upstack passthrough PR.
if p.cfg.KeyPool == nil {
return
}
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
headers.Set(p.AuthHeader(), key.Value())
}
}

func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
name := p.Name()
return keypool.KeyFailoverConfig{
Expand Down
54 changes: 0 additions & 54 deletions aibridge/provider/anthropic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -317,60 +317,6 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) {
}
}

func TestAnthropic_InjectAuthHeader(t *testing.T) {
t.Parallel()

provider := NewAnthropic(config.Anthropic{Key: "centralized-key"}, nil)

tests := []struct {
name string
presetHeaders map[string]string
wantXApiKey string
wantAuthorization string
}{
{
name: "when no auth headers are provided, inject centralized key",
presetHeaders: map[string]string{},
wantXApiKey: "centralized-key",
},
{
name: "when X-Api-Key header is provided, use it",
presetHeaders: map[string]string{"X-Api-Key": "user-api-key"},
wantXApiKey: "user-api-key",
},
{
name: "when Authorization header is provided, use it",
presetHeaders: map[string]string{"Authorization": "Bearer user-access-token"},
wantAuthorization: "Bearer user-access-token",
},
{
name: "when both headers are provided, keep both",
presetHeaders: map[string]string{
"Authorization": "Bearer user-access-token",
"X-Api-Key": "user-api-key",
},
wantXApiKey: "user-api-key",
wantAuthorization: "Bearer user-access-token",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

headers := http.Header{}
for k, v := range tc.presetHeaders {
headers.Set(k, v)
}

provider.InjectAuthHeader(&headers)

assert.Equal(t, tc.wantXApiKey, headers.Get("X-Api-Key"))
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
})
}
}

func TestExtractAnthropicHeaders(t *testing.T) {
t.Parallel()

Expand Down
6 changes: 0 additions & 6 deletions aibridge/provider/copilot.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,12 +107,6 @@ func (*Copilot) AuthHeader() string {
return "Authorization"
}

// InjectAuthHeader is a no-op for Copilot.
// Copilot uses per-user tokens passed in the original Authorization header,
// rather than a global key configured at the provider level.
// The original Authorization header flows through untouched from the client.
func (*Copilot) InjectAuthHeader(_ *http.Header) {}

// KeyFailoverConfig returns a config with a nil Pool, which makes
// the KeyFailoverTransport short-circuit. Copilot is always BYOK.
func (*Copilot) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
Expand Down
33 changes: 0 additions & 33 deletions aibridge/provider/copilot_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,39 +51,6 @@ func TestCopilot_TypeAndName(t *testing.T) {
}
}

func TestCopilot_InjectAuthHeader(t *testing.T) {
t.Parallel()

// Copilot uses per-user key passed in the Authorization header,
// so InjectAuthHeader should not modify any headers.
provider := NewCopilot(config.Copilot{})

t.Run("ExistingHeaders_Unchanged", func(t *testing.T) {
t.Parallel()

headers := http.Header{}
headers.Set("Authorization", "Bearer user-token")
headers.Set("X-Custom-Header", "custom-value")

provider.InjectAuthHeader(&headers)

assert.Equal(t, "Bearer user-token", headers.Get("Authorization"),
"Authorization header should remain unchanged")
assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"),
"other headers should remain unchanged")
})

t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) {
t.Parallel()

headers := http.Header{}

provider.InjectAuthHeader(&headers)

assert.Empty(t, headers, "no headers should be added")
})
}

func TestCopilot_CreateInterceptor(t *testing.T) {
t.Parallel()

Expand Down
23 changes: 0 additions & 23 deletions aibridge/provider/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,29 +197,6 @@ func (*OpenAI) AuthHeader() string {
return "Authorization"
}

func (p *OpenAI) InjectAuthHeader(headers *http.Header) {
if headers == nil {
headers = &http.Header{}
}

// BYOK: if the request already carries user-supplied credentials,
// do not overwrite them with the centralized key.
if headers.Get("Authorization") != "" {
return
}

// Centralized: pull a single key from the pool. No failover
// or exhaustion handling here.
// TODO(ssncferreira): replace with RoundTripper-based auth
// in the upstack passthrough PR.
if p.cfg.KeyPool == nil {
return
}
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
headers.Set(p.AuthHeader(), "Bearer "+key.Value())
}
}

func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
name := p.Name()
return keypool.KeyFailoverConfig{
Expand Down
38 changes: 0 additions & 38 deletions aibridge/provider/openai_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -325,44 +325,6 @@ func TestOpenAI_CreateInterceptor(t *testing.T) {
}
}

func TestOpenAI_InjectAuthHeader(t *testing.T) {
t.Parallel()

provider := NewOpenAI(config.OpenAI{Key: "centralized-key"})

tests := []struct {
name string
presetHeaders map[string]string
wantAuthorization string
}{
{
name: "when no Authorization header is provided, inject centralized key",
presetHeaders: map[string]string{},
wantAuthorization: "Bearer centralized-key",
},
{
name: "when Authorization header is provided, do not overwrite it",
presetHeaders: map[string]string{"Authorization": "Bearer user-token"},
wantAuthorization: "Bearer user-token",
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

headers := http.Header{}
for k, v := range tc.presetHeaders {
headers.Set(k, v)
}

provider.InjectAuthHeader(&headers)

assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
})
}
}

func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) {
provider := NewOpenAI(config.OpenAI{
BaseURL: "https://api.openai.com/v1/",
Expand Down
4 changes: 0 additions & 4 deletions aibridge/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,6 @@ type Provider interface {
// AuthHeader returns the name of the header which the provider expects to find its authentication
// token in.
AuthHeader() string
// InjectAuthHeader allows [Provider]s to set its authentication header.
// TODO(ssncferreira): remove. Auth is now applied per-attempt by
// KeyFailoverTransport (see [Provider.KeyFailoverConfig]).
InjectAuthHeader(*http.Header)
// KeyFailoverConfig returns the per-provider configuration for
// automatic key failover on passthrough routes.
KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig
Expand Down
Loading