Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions coderd/aibridge/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,27 @@ func SourceFromContext(ctx context.Context) Source {
return src
}

type delegatedAPIKeyIDCtxKey struct{}

// WithDelegatedAPIKeyID returns a copy of ctx carrying an API key ID on whose
// behalf the request is being made. The in-process aibridge transport requires
// this on every RoundTrip and rejects calls whose context lacks it.
//
// The caller is responsible for having established that the user owning this
// key authorized the request: aibridged validates only that the key exists,
// has not expired, and belongs to a non-deleted, non-system user. It does not
// verify the key secret, because the caller never has it.
func WithDelegatedAPIKeyID(ctx context.Context, id string) context.Context {
return context.WithValue(ctx, delegatedAPIKeyIDCtxKey{}, id)
}

// DelegatedAPIKeyIDFromContext returns the API key ID attached by
// [WithDelegatedAPIKeyID] and whether a non-empty value was set.
func DelegatedAPIKeyIDFromContext(ctx context.Context) (string, bool) {
id, ok := ctx.Value(delegatedAPIKeyIDCtxKey{}).(string)
return id, ok && id != ""
}

// TransportFactory returns an [http.RoundTripper] that dispatches an aibridge
// request in-process for a given ai_providers row.
//
Expand Down
214 changes: 214 additions & 0 deletions coderd/aibridged/aibridged_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,220 @@ func TestServeHTTP_FailureModes(t *testing.T) {
}
}

// When the request context carries a delegated API key ID (set by the
// in-process transport on behalf of a trusted caller like chatd), the handler
// must authenticate via the key_id field, skipping the header-based key
// extraction entirely. Validation succeeds or fails exactly as it would for a
// real API key. Delegation is orthogonal to BYOK: in BYOK mode the user's own
// LLM credentials must still be forwarded upstream while the Coder governance
// token is stripped.
func TestServeHTTP_DelegatedAPIKey(t *testing.T) {
t.Parallel()

const testKeyID = "abcdef1234"

tests := []struct {
name string
reqHeaders map[string]string
applyMocks func(t *testing.T, client *mock.MockDRPCClient, pool *mock.MockPooler, mockH *mockHandler)
expectStatus int
expectHandled bool
expectPresent map[string]string
expectAbsent []string
}{
{
name: "valid centralized",
applyMocks: func(t *testing.T, client *mock.MockDRPCClient, pool *mock.MockPooler, mockH *mockHandler) {
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, in *proto.IsAuthorizedRequest) (*proto.IsAuthorizedResponse, error) {
assert.Equal(t, testKeyID, in.GetKeyId(), "handler must use KeyId for delegated requests")
assert.Empty(t, in.GetKey(), "handler must not set Key for delegated requests")
return &proto.IsAuthorizedResponse{
OwnerId: uuid.NewString(),
ApiKeyId: testKeyID,
Username: "u",
}, nil
})
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)
},
expectStatus: http.StatusOK,
expectHandled: true,
expectAbsent: []string{
"Authorization",
"X-Api-Key",
agplaibridge.HeaderCoderToken,
},
},
{
name: "valid BYOK preserves user credentials",
reqHeaders: map[string]string{
// Marks BYOK; this header must be stripped before
// forwarding upstream.
agplaibridge.HeaderCoderToken: "should-not-be-present",
// The user's own LLM credential; must be preserved.
"Authorization": "Bearer sk-ant-oat01-user-token",
},
applyMocks: func(_ *testing.T, client *mock.MockDRPCClient, pool *mock.MockPooler, mockH *mockHandler) {
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).Return(&proto.IsAuthorizedResponse{
OwnerId: uuid.NewString(),
ApiKeyId: testKeyID,
Username: "u",
}, nil)
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)
},
expectStatus: http.StatusOK,
expectHandled: true,
expectPresent: map[string]string{
"Authorization": "Bearer sk-ant-oat01-user-token",
},
expectAbsent: []string{
agplaibridge.HeaderCoderToken,
},
},
{
name: "invalid",
applyMocks: func(_ *testing.T, client *mock.MockDRPCClient, _ *mock.MockPooler, _ *mockHandler) {
client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).Return(nil, xerrors.New("unknown key"))
},
expectStatus: http.StatusForbidden,
},
}

for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()

srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
mockH := &mockHandler{}
tc.applyMocks(t, client, pool, mockH)

ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/openai/v1/chat/completions", nil)
require.NoError(t, err)
for k, v := range tc.reqHeaders {
req.Header.Set(k, v)
}

rw := httptest.NewRecorder()
srv.ServeHTTP(rw, req)

require.Equal(t, tc.expectStatus, rw.Code)
if tc.expectHandled {
require.NotNil(t, mockH.headersReceived, "downstream handler must be invoked")
for h, v := range tc.expectPresent {
require.Equal(t, v, mockH.headersReceived.Get(h), "header %q must be preserved", h)
}
for _, h := range tc.expectAbsent {
require.Empty(t, mockH.headersReceived.Get(h), "header %q must be stripped", h)
}
} else {
require.Nil(t, mockH.headersReceived, "downstream handler must not be invoked on auth failure")
}
})
}
}

// End-to-end: a real transport factory wired to a real server, with BYOK in
// effect. The delegated key ID identifies the user (no Coder token over the
// wire) while the user's own LLM credentials in Authorization must flow
// through to the downstream handler. The Coder governance token, if set by
// the caller, must be stripped.
func TestServeHTTP_DelegatedAPIKey_BYOK_Integration(t *testing.T) {
t.Parallel()

const (
testKeyID = "abcdef1234"
// nolint:gosec // Fake LLM credential for assertion comparison.
userLLMToken = "Bearer sk-ant-oat01-user-byok-token"
)

srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
mockH := &mockHandler{}

client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, in *proto.IsAuthorizedRequest) (*proto.IsAuthorizedResponse, error) {
assert.Equal(t, testKeyID, in.GetKeyId(), "delegated identity must be carried in KeyId")
assert.Empty(t, in.GetKey(), "Key must not be set on delegated requests")
return &proto.IsAuthorizedResponse{
OwnerId: uuid.NewString(),
ApiKeyId: testKeyID,
Username: "u",
}, nil
})
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)

factory := aibridged.NewTransportFactory(srv)
rt, err := factory.TransportFor(uuid.New(), agplaibridge.SourceAgents)
require.NoError(t, err)

ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/anthropic/v1/messages", nil)
require.NoError(t, err)
// HeaderCoderToken marks the request as BYOK. Its value is irrelevant on
// the delegated path (identity comes from context) and it must be
// stripped before forwarding upstream.
req.Header.Set(agplaibridge.HeaderCoderToken, "ignored-on-delegated-path")
// The user's own LLM credential; must reach the downstream handler.
req.Header.Set("Authorization", userLLMToken)

resp, err := rt.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)

require.NotNil(t, mockH.headersReceived, "downstream handler must be invoked")
require.Equal(t, userLLMToken, mockH.headersReceived.Get("Authorization"),
"user's BYOK credential must be preserved end-to-end")
require.Empty(t, mockH.headersReceived.Get(agplaibridge.HeaderCoderToken),
"Coder governance token must be stripped before forwarding upstream")
}

// End-to-end: a real transport factory wired to a real server. Verifies the
// delegated key ID survives the in-memory round-trip and is treated as the
// authoritative caller identity by the handler, without any HTTP-layer header
// extraction.
func TestServeHTTP_DelegatedAPIKey_Integration(t *testing.T) {
t.Parallel()

const testKeyID = "abcdef1234"

srv, client, pool := newTestServer(t)
conn := &mockDRPCConn{}
client.EXPECT().DRPCConn().AnyTimes().Return(conn)
mockH := &mockHandler{}

client.EXPECT().IsAuthorized(gomock.Any(), gomock.Any()).DoAndReturn(
func(_ context.Context, in *proto.IsAuthorizedRequest) (*proto.IsAuthorizedResponse, error) {
assert.Equal(t, testKeyID, in.GetKeyId())
assert.Empty(t, in.GetKey())
return &proto.IsAuthorizedResponse{
OwnerId: uuid.NewString(),
ApiKeyId: testKeyID,
Username: "u",
}, nil
})
pool.EXPECT().Acquire(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Return(mockH, nil)

factory := aibridged.NewTransportFactory(srv)
rt, err := factory.TransportFor(uuid.New(), agplaibridge.SourceAgents)
require.NoError(t, err)

ctx := agplaibridge.WithDelegatedAPIKeyID(testutil.Context(t, testutil.WaitShort), testKeyID)
req, err := http.NewRequestWithContext(ctx, http.MethodPost, "http://aibridge/openai/v1/chat/completions", nil)
require.NoError(t, err)

resp, err := rt.RoundTrip(req)
require.NoError(t, err)
defer resp.Body.Close()

require.Equal(t, http.StatusOK, resp.StatusCode)
require.NotNil(t, mockH.headersReceived, "downstream handler must observe the delegated request")
}

func TestServeHTTP_StripCoderToken(t *testing.T) {
t.Parallel()

Expand Down
81 changes: 57 additions & 24 deletions coderd/aibridged/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,33 +56,56 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
authMode = "byok"
}

key := strings.TrimSpace(agplaibridge.ExtractAuthToken(r.Header))
if key == "" {
// Some clients (e.g. Claude) send a HEAD request
// without credentials to check connectivity.
if r.Method == http.MethodHead {
logger.Info(ctx, "unauthenticated HEAD request")
} else {
logger.Warn(ctx, "no auth key provided")
// When the request arrived via the in-process transport, the caller
// has placed a delegated API key ID on the context. We trust that the
// caller already established the user's identity and only validate
// liveness; the caller does not have (and cannot send) the key secret.
// Delegation is orthogonal to BYOK: a delegated request still carries
// the user's own LLM credentials in Authorization/X-Api-Key when BYOK
// is in effect.
var (
authReq *proto.IsAuthorizedRequest
sessionKey string
delegated bool
)
if delegatedID, ok := agplaibridge.DelegatedAPIKeyIDFromContext(ctx); ok {
Comment thread
dannykopping marked this conversation as resolved.
authReq = &proto.IsAuthorizedRequest{KeyId: delegatedID}
delegated = true
// SessionKey is consumed only by the injected MCP path, which is
// not available to delegated callers (they have no secret).
Comment thread
dannykopping marked this conversation as resolved.
} else {
key := strings.TrimSpace(agplaibridge.ExtractAuthToken(r.Header))
if key == "" {
// Some clients (e.g. Claude) send a HEAD request
// without credentials to check connectivity.
if r.Method == http.MethodHead {
logger.Info(ctx, "unauthenticated HEAD request")
} else {
logger.Warn(ctx, "no auth key provided")
}
http.Error(rw, ErrNoAuthKey.Error(), http.StatusBadRequest)
return
}
http.Error(rw, ErrNoAuthKey.Error(), http.StatusBadRequest)
return
authReq = &proto.IsAuthorizedRequest{Key: key}
sessionKey = key
}

// Strip every header that may carry the Coder token so it is
// never forwarded to upstream providers. After stripping, the
// aibridge library can treat the request as a normal LLM API call
// with no Coder-specific information.
// Strip every header that may carry the Coder token so it is never
// forwarded to upstream providers. Runs for both header-auth and
// delegated requests: a delegated caller may forward the user's BYOK
// headers, and we still want to scrub any Coder-specific credentials
// that may have leaked through. After stripping, the aibridge library
// can treat the request as a normal LLM API call with no
// Coder-specific information.
if byok {
// In BYOK mode the token is in X-Coder-AI-Governance-Token;
// Authorization and X-Api-Key carry the user's own LLM credentials
// and must be preserved.
// In BYOK mode the Coder token is in X-Coder-AI-Governance-Token;
// Authorization and X-Api-Key carry the user's own LLM
// credentials and must be preserved.
r.Header.Del(agplaibridge.HeaderCoderToken)
} else {
// In centralized mode the token may be in Authorization (the
// documented path) or X-Api-Key (legacy clients that set
// ANTHROPIC_API_KEY to their Coder token). Both are
// stripped.
// In centralized mode the Coder token may be in Authorization
// (the documented path) or X-Api-Key (legacy clients that set
// ANTHROPIC_API_KEY to their Coder token). Both are stripped.
r.Header.Del("Authorization")
r.Header.Del("X-Api-Key")
}
Expand All @@ -94,9 +117,19 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
return
}

resp, err := client.IsAuthorized(ctx, &proto.IsAuthorizedRequest{Key: key})
// Attach auth attributes used by all log lines below. "source" is the
// transport origin (e.g., "agents" for in-process callers, empty for
// network callers); "auth_delegated" distinguishes header-based from
// context-delegated authentication.
logger = logger.With(
slog.F("source", string(agplaibridge.SourceFromContext(ctx))),
slog.F("auth_mode", authMode),
slog.F("auth_delegated", delegated),
)

resp, err := client.IsAuthorized(ctx, authReq)
if err != nil {
logger.Warn(ctx, "key authorization check failed", slog.Error(err), slog.F("auth_mode", authMode))
logger.Warn(ctx, "key authorization check failed", slog.Error(err))
http.Error(rw, ErrUnauthorized.Error(), http.StatusForbidden)
return
}
Expand All @@ -118,7 +151,7 @@ func (s *Server) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
}

handler, err := s.GetRequestHandler(ctx, Request{
SessionKey: key,
SessionKey: sessionKey,
APIKeyID: resp.ApiKeyId,
InitiatorID: id,
})
Expand Down
18 changes: 13 additions & 5 deletions coderd/aibridged/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,20 @@ func (m *MCPProxyFactory) retrieveMCPServerConfigs(ctx context.Context, req Requ
proxiers := make(map[string]mcp.ServerProxier, len(mcpSrvCfgs.GetExternalAuthMcpConfigs())+1) // Extra one for Coder MCP server.

if mcpSrvCfgs.GetCoderMcpConfig() != nil {
// Setup the Coder MCP server proxy.
coderMCPProxy, err := m.newStreamableHTTPServerProxy(mcpSrvCfgs.GetCoderMcpConfig(), req.SessionKey) // The session key is used to auth against our internal MCP server.
if err != nil {
m.logger.Warn(ctx, "failed to create MCP server proxy", slog.F("mcp_server_id", mcpSrvCfgs.GetCoderMcpConfig().GetId()), slog.Error(err))
// Delegated callers (e.g., chatd) do not hold the user's API key
// secret and so cannot authenticate against the Coder MCP server.
// Skip the proxy in that case rather than attempting a connection
// with an empty bearer token, which will fail upstream.
if req.SessionKey == "" {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P3 [CRF-12] The SessionKey == "" guard was added to fix CRF-4, but no test exercises this branch. The integration tests use pool mocks that bypass MCP; mcp_internal_test.go only tests regex compilation; pool_test.go always passes SessionKey: "key".

If someone reverts this guard, delegated requests would silently attempt MCP proxy creation with empty bearer tokens. The upstream server would reject (401), and the proxy factory would log a warning and proceed without the Coder MCP server. Not catastrophic, but a regression in fix-commit behavior that would go undetected.

(Bisky P3)

🤖

m.logger.Debug(ctx, "skipping Coder MCP server proxy: no session key (delegated request)", slog.F("mcp_server_id", mcpSrvCfgs.GetCoderMcpConfig().GetId()))
} else {
proxiers[InternalMCPServerID] = coderMCPProxy
// Setup the Coder MCP server proxy.
coderMCPProxy, err := m.newStreamableHTTPServerProxy(mcpSrvCfgs.GetCoderMcpConfig(), req.SessionKey) // The session key is used to auth against our internal MCP server.
if err != nil {
m.logger.Warn(ctx, "failed to create MCP server proxy", slog.F("mcp_server_id", mcpSrvCfgs.GetCoderMcpConfig().GetId()), slog.Error(err))
} else {
proxiers[InternalMCPServerID] = coderMCPProxy
}
}
}

Expand Down
Loading
Loading