From 20cc4cefb8d090f720512ac6ce5eaece13765fef Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 16 Jun 2026 19:05:00 +0200 Subject: [PATCH 1/3] feat(oauth): wire stdio OAuth 2.1 login into the server Connect the internal/oauth core library to the stdio MCP server so users can authenticate with an OAuth App or GitHub App client ID instead of a static personal access token. - BearerAuthTransport gains a TokenProvider that is consulted per request, letting the lazily-acquired, auto-refreshing OAuth token take effect without rebuilding the client. - createGitHubClients uses BearerAuthTransport (and skips go-github's WithAuthToken, which would pin a static token) when a TokenProvider is set. - RunStdioServer starts without a token and installs receiving middleware that runs the authorization flow on the first tool call, surfacing the auth URL or device code via elicitation (or a tool result as a fallback). - Tool filtering uses the requested OAuth scopes; the default supported set hides nothing, while a narrower --oauth-scopes both narrows the grant and filters tools accordingly. - A sessionPrompter adapts the MCP server session to oauth.Prompter, keeping the authorization URL off the model's context. - New stdio flags: --oauth-client-id/-client-secret/-scopes/-callback-port. This is stdio-only and deliberately does not touch MCP-HTTP auth. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- cmd/github-mcp-server/main.go | 42 +++- internal/ghmcp/oauth.go | 128 ++++++++++++ internal/ghmcp/oauth_test.go | 329 ++++++++++++++++++++++++++++++ internal/ghmcp/server.go | 71 +++++-- pkg/github/server.go | 5 + pkg/http/transport/bearer.go | 12 +- pkg/http/transport/bearer_test.go | 164 +++++++++++++++ 7 files changed, 736 insertions(+), 15 deletions(-) create mode 100644 internal/ghmcp/oauth.go create mode 100644 internal/ghmcp/oauth_test.go create mode 100644 pkg/http/transport/bearer_test.go diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 604556692c..b329b5012d 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -8,8 +8,10 @@ import ( "time" "github.com/github/github-mcp-server/internal/ghmcp" + "github.com/github/github-mcp-server/internal/oauth" "github.com/github/github-mcp-server/pkg/github" ghhttp "github.com/github/github-mcp-server/pkg/http" + ghoauth "github.com/github/github-mcp-server/pkg/http/oauth" "github.com/spf13/cobra" "github.com/spf13/pflag" "github.com/spf13/viper" @@ -34,8 +36,9 @@ var ( Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, RunE: func(_ *cobra.Command, _ []string) error { token := viper.GetString("personal_access_token") - if token == "" { - return errors.New("GITHUB_PERSONAL_ACCESS_TOKEN not set") + oauthClientID := viper.GetString("oauth-client-id") + if token == "" && oauthClientID == "" { + return errors.New("authentication required: set GITHUB_PERSONAL_ACCESS_TOKEN, or pass --oauth-client-id to log in via OAuth") } // If you're wondering why we're not using viper.GetStringSlice("toolsets"), @@ -95,6 +98,29 @@ var ( ExcludeTools: excludeTools, RepoAccessCacheTTL: &ttl, } + + // When no static token is provided, log in via OAuth using the given + // client. The requested scopes default to the full supported set + // (which filters out no tools); an explicit, narrower --oauth-scopes + // both narrows the grant and hides tools needing other scopes. + if token == "" { + scopes := ghoauth.SupportedScopes + if viper.IsSet("oauth-scopes") { + if err := viper.UnmarshalKey("oauth-scopes", &scopes); err != nil { + return fmt.Errorf("failed to unmarshal oauth-scopes: %w", err) + } + } + oauthConfig := oauth.NewGitHubConfig( + oauthClientID, + viper.GetString("oauth-client-secret"), + scopes, + viper.GetString("host"), + viper.GetInt("oauth-callback-port"), + ) + stdioServerConfig.OAuthManager = oauth.NewManager(oauthConfig, nil) + stdioServerConfig.OAuthScopes = scopes + } + return ghmcp.RunStdioServer(stdioServerConfig) }, } @@ -183,6 +209,14 @@ func init() { rootCmd.PersistentFlags().Bool("insiders", false, "Enable insiders features") rootCmd.PersistentFlags().Duration("repo-access-cache-ttl", 5*time.Minute, "Override the repo access cache TTL (e.g. 1m, 0s to disable)") + // stdio-specific OAuth flags. Provide --oauth-client-id (instead of a token) + // to log in via the browser-based OAuth flow on first use. Works for both + // OAuth Apps and GitHub Apps. + stdioCmd.Flags().String("oauth-client-id", "", "OAuth App or GitHub App client ID, enabling interactive OAuth login when no token is set") + stdioCmd.Flags().String("oauth-client-secret", "", "OAuth client secret, if the app requires one (it is a public, non-confidential credential for distributed clients)") + stdioCmd.Flags().StringSlice("oauth-scopes", nil, "Comma-separated OAuth scopes to request; also filters tools to those scopes. Defaults to the full supported set") + stdioCmd.Flags().Int("oauth-callback-port", 0, "Fixed local port for the OAuth callback server. Defaults to a random port; set a fixed port when mapping it through Docker") + // HTTP-specific flags httpCmd.Flags().Int("port", 8082, "HTTP server port") httpCmd.Flags().String("listen-host", "", "Host the HTTP server binds to (e.g. 127.0.0.1). Empty binds to all interfaces.") @@ -205,6 +239,10 @@ func init() { _ = viper.BindPFlag("lockdown-mode", rootCmd.PersistentFlags().Lookup("lockdown-mode")) _ = viper.BindPFlag("insiders", rootCmd.PersistentFlags().Lookup("insiders")) _ = viper.BindPFlag("repo-access-cache-ttl", rootCmd.PersistentFlags().Lookup("repo-access-cache-ttl")) + _ = viper.BindPFlag("oauth-client-id", stdioCmd.Flags().Lookup("oauth-client-id")) + _ = viper.BindPFlag("oauth-client-secret", stdioCmd.Flags().Lookup("oauth-client-secret")) + _ = viper.BindPFlag("oauth-scopes", stdioCmd.Flags().Lookup("oauth-scopes")) + _ = viper.BindPFlag("oauth-callback-port", stdioCmd.Flags().Lookup("oauth-callback-port")) _ = viper.BindPFlag("port", httpCmd.Flags().Lookup("port")) _ = viper.BindPFlag("listen-host", httpCmd.Flags().Lookup("listen-host")) _ = viper.BindPFlag("base-url", httpCmd.Flags().Lookup("base-url")) diff --git a/internal/ghmcp/oauth.go b/internal/ghmcp/oauth.go new file mode 100644 index 0000000000..6a1d388956 --- /dev/null +++ b/internal/ghmcp/oauth.go @@ -0,0 +1,128 @@ +package ghmcp + +import ( + "context" + "crypto/rand" + "fmt" + "log/slog" + + "github.com/github/github-mcp-server/internal/oauth" + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +// sessionPrompter adapts an MCP server session to oauth.Prompter, presenting +// authorization prompts to the user via elicitation. Keeping the prompt on the +// MCP control channel (rather than a tool result) keeps the authorization URL +// and any session-bound state out of the model's context. +type sessionPrompter struct { + session *mcp.ServerSession +} + +// elicitationCaps returns the client's declared elicitation capabilities, or nil +// if the client did not advertise any. +func (p *sessionPrompter) elicitationCaps() *mcp.ElicitationCapabilities { + params := p.session.InitializeParams() + if params == nil || params.Capabilities == nil { + return nil + } + return params.Capabilities.Elicitation +} + +// CanPromptURL reports whether the client supports URL-mode elicitation. +func (p *sessionPrompter) CanPromptURL() bool { + caps := p.elicitationCaps() + return caps != nil && caps.URL != nil +} + +// PromptURL presents the authorization URL via URL-mode elicitation and blocks +// until the user acknowledges, declines, or ctx is done. +func (p *sessionPrompter) PromptURL(ctx context.Context, prompt oauth.Prompt) error { + res, err := p.session.Elicit(ctx, &mcp.ElicitParams{ + Mode: "url", + Message: prompt.Message, + URL: prompt.URL, + ElicitationID: rand.Text(), + }) + if err != nil { + return err + } + if res.Action != "accept" { + return oauth.ErrPromptDeclined + } + return nil +} + +// CanPromptForm reports whether the client supports form-mode elicitation. The +// SDK treats a client that advertises neither form nor URL capabilities as +// supporting forms, for backward compatibility, so we mirror that here. +func (p *sessionPrompter) CanPromptForm() bool { + caps := p.elicitationCaps() + if caps == nil { + return false + } + return caps.Form != nil || caps.URL == nil +} + +// PromptForm presents a textual acknowledgement (used to display a device code +// when URL elicitation is unavailable) and blocks until the user responds. +func (p *sessionPrompter) PromptForm(ctx context.Context, prompt oauth.Prompt) error { + res, err := p.session.Elicit(ctx, &mcp.ElicitParams{ + Mode: "form", + Message: prompt.Message, + }) + if err != nil { + return err + } + if res.Action != "accept" { + return oauth.ErrPromptDeclined + } + return nil +} + +// oauthAuthenticator is the subset of *oauth.Manager that the middleware needs. +// Depending on the interface (rather than the concrete manager) lets the +// middleware be exercised with a deterministic fake, since driving the real +// manager to its branches would require standing up live GitHub flows. +type oauthAuthenticator interface { + HasToken() bool + Authenticate(ctx context.Context, prompter oauth.Prompter) (*oauth.Outcome, error) +} + +// createOAuthMiddleware returns receiving middleware that authorizes the session +// lazily, on the first tool call. Authorization is deferred until here (rather +// than at startup) because the prompts depend on an initialized session whose +// elicitation capabilities are known. +// +// When a token is already available the call proceeds untouched. Otherwise the +// flow runs: secure channels (browser, URL elicitation) block until the token +// arrives and then the call proceeds; the last-resort channel returns the +// instruction to the user as a tool result and asks them to retry. +func createOAuthMiddleware(mgr oauthAuthenticator, logger *slog.Logger) func(next mcp.MethodHandler) mcp.MethodHandler { + return func(next mcp.MethodHandler) mcp.MethodHandler { + return func(ctx context.Context, method string, request mcp.Request) (mcp.Result, error) { + if method != "tools/call" || mgr.HasToken() { + return next(ctx, method, request) + } + + callReq, ok := request.(*mcp.CallToolRequest) + if !ok { + return next(ctx, method, request) + } + + outcome, err := mgr.Authenticate(ctx, &sessionPrompter{session: callReq.Session}) + if err != nil { + return nil, fmt.Errorf("github authorization failed: %w", err) + } + if outcome != nil && outcome.UserAction != nil { + logger.Info("surfacing github authorization instructions to user") + return &mcp.CallToolResult{ + Content: []mcp.Content{&mcp.TextContent{Text: outcome.UserAction.Message}}, + }, nil + } + return next(ctx, method, request) + } + } +} + +// ensure sessionPrompter satisfies the Prompter contract. +var _ oauth.Prompter = (*sessionPrompter)(nil) diff --git a/internal/ghmcp/oauth_test.go b/internal/ghmcp/oauth_test.go new file mode 100644 index 0000000000..4f370cf7bc --- /dev/null +++ b/internal/ghmcp/oauth_test.go @@ -0,0 +1,329 @@ +package ghmcp + +import ( + "context" + "io" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + + "github.com/github/github-mcp-server/internal/oauth" + "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/http/headers" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/modelcontextprotocol/go-sdk/mcp" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func discardLogger() *slog.Logger { + return slog.New(slog.NewTextHandler(io.Discard, nil)) +} + +// probeToolName is the name of the throwaway tool the harness registers; its +// handler runs a probe closure against a sessionPrompter so the adapter can be +// exercised against a real, fully-negotiated server session from the client side. +const probeToolName = "probe" + +// runProbe stands up an in-memory MCP client/server pair, registers a tool whose +// handler runs probe against a sessionPrompter wrapping the live server session, +// and returns the text the probe produced. The client is configured with the +// given capabilities and elicitation handler so the adapter sees a real, +// fully-negotiated session rather than a hand-built fake. +func runProbe( + t *testing.T, + clientCaps *mcp.ClientCapabilities, + elicitationHandler func(context.Context, *mcp.ElicitRequest) (*mcp.ElicitResult, error), + probe func(context.Context, *sessionPrompter) string, +) string { + t.Helper() + + server := mcp.NewServer(&mcp.Implementation{Name: "test-server", Version: "v0.0.1"}, nil) + mcp.AddTool(server, &mcp.Tool{Name: probeToolName}, func(ctx context.Context, req *mcp.CallToolRequest, _ struct{}) (*mcp.CallToolResult, any, error) { + text := probe(ctx, &sessionPrompter{session: req.Session}) + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: text}}}, nil, nil + }) + + st, ct := mcp.NewInMemoryTransports() + + ss, err := server.Connect(context.Background(), st, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = ss.Close() }) + + client := mcp.NewClient(&mcp.Implementation{Name: "test-client", Version: "v0.0.1"}, &mcp.ClientOptions{ + Capabilities: clientCaps, + ElicitationHandler: elicitationHandler, + }) + cs, err := client.Connect(context.Background(), ct, nil) + require.NoError(t, err) + t.Cleanup(func() { _ = cs.Close() }) + + res, err := cs.CallTool(context.Background(), &mcp.CallToolParams{Name: probeToolName}) + require.NoError(t, err) + require.Len(t, res.Content, 1) + text, ok := res.Content[0].(*mcp.TextContent) + require.True(t, ok, "probe result should be text content") + return text.Text +} + +func TestSessionPrompterCapabilities(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + caps *mcp.ClientCapabilities + wantURL bool + wantForm bool + }{ + { + name: "no elicitation advertised", + caps: &mcp.ClientCapabilities{}, + wantURL: false, + wantForm: false, + }, + { + name: "url only", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{URL: &mcp.URLElicitationCapabilities{}}}, + wantURL: true, + wantForm: false, + }, + { + name: "form only", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{Form: &mcp.FormElicitationCapabilities{}}}, + wantURL: false, + wantForm: true, + }, + { + name: "url and form", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{URL: &mcp.URLElicitationCapabilities{}, Form: &mcp.FormElicitationCapabilities{}}}, + wantURL: true, + wantForm: true, + }, + { + name: "empty elicitation capability implies form for backward compatibility", + caps: &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{}}, + wantURL: false, + wantForm: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + got := runProbe(t, tc.caps, nil, func(_ context.Context, p *sessionPrompter) string { + if p.CanPromptURL() { + if p.CanPromptForm() { + return "url+form" + } + return "url" + } + if p.CanPromptForm() { + return "form" + } + return "none" + }) + + want := "none" + switch { + case tc.wantURL && tc.wantForm: + want = "url+form" + case tc.wantURL: + want = "url" + case tc.wantForm: + want = "form" + } + assert.Equal(t, want, got) + }) + } +} + +func TestSessionPrompterPromptActions(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + action string + wantDecline bool + }{ + {name: "accept", action: "accept", wantDecline: false}, + {name: "decline", action: "decline", wantDecline: true}, + {name: "cancel", action: "cancel", wantDecline: true}, + } + + caps := &mcp.ClientCapabilities{Elicitation: &mcp.ElicitationCapabilities{ + URL: &mcp.URLElicitationCapabilities{}, + Form: &mcp.FormElicitationCapabilities{}, + }} + + for _, tc := range tests { + // URL and form modes share the accept/decline mapping; cover both. + for _, mode := range []string{"url", "form"} { + t.Run(tc.name+"/"+mode, func(t *testing.T) { + t.Parallel() + + handler := func(_ context.Context, _ *mcp.ElicitRequest) (*mcp.ElicitResult, error) { + return &mcp.ElicitResult{Action: tc.action}, nil + } + + got := runProbe(t, caps, handler, func(ctx context.Context, p *sessionPrompter) string { + var err error + if mode == "url" { + err = p.PromptURL(ctx, oauth.Prompt{Message: "msg", URL: "https://example.com/auth"}) + } else { + err = p.PromptForm(ctx, oauth.Prompt{Message: "msg"}) + } + if err == nil { + return "ok" + } + if err == oauth.ErrPromptDeclined { + return "declined" + } + return "error: " + err.Error() + }) + + if tc.wantDecline { + assert.Equal(t, "declined", got) + } else { + assert.Equal(t, "ok", got) + } + }) + } + } +} + +// fakeAuthenticator is a deterministic stand-in for *oauth.Manager that lets the +// middleware be tested at each branch without standing up live GitHub flows. +type fakeAuthenticator struct { + hasToken bool + outcome *oauth.Outcome + err error + authCalls int + lastPrompter oauth.Prompter +} + +func (f *fakeAuthenticator) HasToken() bool { return f.hasToken } + +func (f *fakeAuthenticator) Authenticate(_ context.Context, prompter oauth.Prompter) (*oauth.Outcome, error) { + f.authCalls++ + f.lastPrompter = prompter + return f.outcome, f.err +} + +func TestCreateOAuthMiddleware(t *testing.T) { + t.Parallel() + + const nextText = "handler-ran" + newNext := func(called *bool) mcp.MethodHandler { + return func(_ context.Context, _ string, _ mcp.Request) (mcp.Result, error) { + *called = true + return &mcp.CallToolResult{Content: []mcp.Content{&mcp.TextContent{Text: nextText}}}, nil + } + } + + t.Run("non tool call passes through without authenticating", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: false} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + _, err := mw(newNext(&called))(context.Background(), "initialize", &mcp.InitializeRequest{}) + require.NoError(t, err) + assert.True(t, called, "next should run") + assert.Zero(t, fake.authCalls, "authentication must not run for non tool calls") + }) + + t.Run("existing token short circuits authentication", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: true} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + _, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.NoError(t, err) + assert.True(t, called, "next should run") + assert.Zero(t, fake.authCalls, "authentication must be skipped when a token already exists") + }) + + t.Run("successful authentication proceeds to handler", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: false, outcome: nil, err: nil} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + res, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.NoError(t, err) + assert.Equal(t, 1, fake.authCalls) + assert.True(t, called, "next should run once authorized") + callRes, ok := res.(*mcp.CallToolResult) + require.True(t, ok) + require.Len(t, callRes.Content, 1) + assert.Equal(t, nextText, callRes.Content[0].(*mcp.TextContent).Text) + }) + + t.Run("pending user action is surfaced as a tool result", func(t *testing.T) { + t.Parallel() + const message = "Open https://example.com/auth to authorize, then retry." + fake := &fakeAuthenticator{hasToken: false, outcome: &oauth.Outcome{UserAction: &oauth.UserAction{Message: message}}} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + res, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.NoError(t, err) + assert.False(t, called, "next must not run while the user still needs to authorize") + callRes, ok := res.(*mcp.CallToolResult) + require.True(t, ok) + require.Len(t, callRes.Content, 1) + assert.Equal(t, message, callRes.Content[0].(*mcp.TextContent).Text) + }) + + t.Run("authentication error is returned", func(t *testing.T) { + t.Parallel() + fake := &fakeAuthenticator{hasToken: false, err: assert.AnError} + var called bool + mw := createOAuthMiddleware(fake, discardLogger()) + _, err := mw(newNext(&called))(context.Background(), "tools/call", &mcp.CallToolRequest{}) + require.Error(t, err) + assert.ErrorIs(t, err, assert.AnError) + assert.False(t, called, "next must not run when authentication fails") + }) +} + +// TestCreateGitHubClientsTokenProvider proves the OAuth wiring: when a +// TokenProvider is configured the REST client authenticates with the provider's +// current token on every request (and never pins a stale one), which is what the +// lazy, refreshing OAuth token depends on. +func TestCreateGitHubClientsTokenProvider(t *testing.T) { + t.Parallel() + + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get(headers.AuthorizationHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + current := "" + apiHost, err := utils.NewAPIHost(server.URL) + require.NoError(t, err) + + clients, err := createGitHubClients(github.MCPServerConfig{ + Version: "test", + TokenProvider: func() string { return current }, + }, apiHost) + require.NoError(t, err) + + do := func() { + resp, err := clients.rest.Client().Get(server.URL) + require.NoError(t, err) + defer resp.Body.Close() + } + + do() + assert.Equal(t, "Bearer", gotAuth, "no token before authorization") + + current = "oauth-token" + do() + assert.Equal(t, "Bearer oauth-token", gotAuth, "provider token used once available") + + current = "refreshed-token" + do() + assert.Equal(t, "Bearer refreshed-token", gotAuth, "refreshed provider token used") +} diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index a37c4d940d..2364b02688 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -12,6 +12,7 @@ import ( "syscall" "time" + "github.com/github/github-mcp-server/internal/oauth" "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" "github.com/github/github-mcp-server/pkg/http/transport" @@ -61,16 +62,30 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv return nil, fmt.Errorf("failed to get Raw URL: %w", err) } - // Construct REST client + // Construct REST client. When a TokenProvider is configured (OAuth), we + // authenticate via BearerAuthTransport and skip go-github's WithAuthToken: + // the latter installs its own round tripper that would pin the static token + // and shadow the dynamic one. restUATransport := &transport.UserAgentTransport{ Transport: http.DefaultTransport, Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } - restClient, err := gogithub.NewClient( - gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), - gogithub.WithAuthToken(cfg.Token), - gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), - ) + var restClient *gogithub.Client + if cfg.TokenProvider != nil { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: &transport.BearerAuthTransport{ + Transport: restUATransport, + TokenProvider: cfg.TokenProvider, + }}), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) + } else { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), + gogithub.WithAuthToken(cfg.Token), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) + } if err != nil { return nil, fmt.Errorf("failed to create REST client: %w", err) } @@ -82,7 +97,8 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv Transport: &transport.GraphQLFeaturesTransport{ Transport: http.DefaultTransport, }, - Token: cfg.Token, + Token: cfg.Token, + TokenProvider: cfg.TokenProvider, }, } @@ -229,6 +245,18 @@ type StdioServerConfig struct { // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. RepoAccessCacheTTL *time.Duration + + // OAuthManager, when non-nil, enables OAuth 2.1 login for stdio mode. The + // server starts without a token and runs the authorization flow on the + // first tool call (see createOAuthMiddleware). It is mutually exclusive with + // a static Token. + OAuthManager *oauth.Manager + + // OAuthScopes are the scopes requested during OAuth login. They double as + // the scope set for tool filtering: tools requiring a scope outside this set + // are hidden. The default set is the full supported list, which hides + // nothing; an explicit, narrower list filters accordingly. + OAuthScopes []string } // RunStdioServer is not concurrent safe. @@ -255,11 +283,13 @@ func RunStdioServer(cfg StdioServerConfig) error { logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) - // Fetch token scopes for scope-based tool filtering (PAT tokens only) - // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. - // Fine-grained PATs and other token types don't support this, so we skip filtering. + // Determine the scope set used to filter tools. Classic PATs expose their + // granted scopes via the API; OAuth uses the requested scopes (the default + // set hides nothing, a narrower explicit set filters accordingly). Other + // token types don't advertise scopes, so filtering is skipped. var tokenScopes []string - if strings.HasPrefix(cfg.Token, "ghp_") { + switch { + case strings.HasPrefix(cfg.Token, "ghp_"): fetchedScopes, err := fetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) if err != nil { logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) @@ -267,10 +297,20 @@ func RunStdioServer(cfg StdioServerConfig) error { tokenScopes = fetchedScopes logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) } - } else { + case cfg.OAuthManager != nil: + tokenScopes = cfg.OAuthScopes + logger.Info("using requested OAuth scopes for tool filtering", "scopes", tokenScopes) + default: logger.Debug("skipping scope filtering for non-PAT token") } + // For OAuth, the token is resolved lazily: empty until the user authorizes + // on the first tool call, then refreshed for the rest of the session. + var tokenProvider func() string + if cfg.OAuthManager != nil { + tokenProvider = cfg.OAuthManager.AccessToken + } + ghServer, err := NewStdioMCPServer(ctx, github.MCPServerConfig{ Version: cfg.Version, Host: cfg.Host, @@ -287,11 +327,18 @@ func RunStdioServer(cfg StdioServerConfig) error { Logger: logger, RepoAccessTTL: cfg.RepoAccessCacheTTL, TokenScopes: tokenScopes, + TokenProvider: tokenProvider, }) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) } + // With OAuth, intercept tool calls to run the authorization flow on first + // use, before the handler tries to call GitHub with an empty token. + if cfg.OAuthManager != nil { + ghServer.AddReceivingMiddleware(createOAuthMiddleware(cfg.OAuthManager, logger)) + } + if cfg.ExportTranslations { // Once server is initialized, all translations are loaded dumpTranslations() diff --git a/pkg/github/server.go b/pkg/github/server.go index 7ec5837c3a..627cc678b2 100644 --- a/pkg/github/server.go +++ b/pkg/github/server.go @@ -68,6 +68,11 @@ type MCPServerConfig struct { // This is used for PAT scope filtering where we can't issue scope challenges. TokenScopes []string + // TokenProvider, when non-nil, supplies the GitHub token for each API + // request instead of the static Token. It backs OAuth login, where the + // token is obtained lazily on first use and refreshed thereafter. + TokenProvider func() string + // Additional server options to apply ServerOptions []MCPServerOption } diff --git a/pkg/http/transport/bearer.go b/pkg/http/transport/bearer.go index 66922bbdaa..9be3fd5342 100644 --- a/pkg/http/transport/bearer.go +++ b/pkg/http/transport/bearer.go @@ -11,11 +11,21 @@ import ( type BearerAuthTransport struct { Transport http.RoundTripper Token string + + // TokenProvider, when non-nil, supplies the bearer token for each request + // and takes precedence over Token. It backs OAuth, where the token is + // obtained after the client is built and is refreshed over the session's + // lifetime. It may return an empty string before authorization completes. + TokenProvider func() string } func (t *BearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, error) { req = req.Clone(req.Context()) - req.Header.Set(headers.AuthorizationHeader, "Bearer "+t.Token) + token := t.Token + if t.TokenProvider != nil { + token = t.TokenProvider() + } + req.Header.Set(headers.AuthorizationHeader, "Bearer "+token) // Check for GraphQL-Features in context and add header if present if features := ghcontext.GetGraphQLFeatures(req.Context()); len(features) > 0 { diff --git a/pkg/http/transport/bearer_test.go b/pkg/http/transport/bearer_test.go new file mode 100644 index 0000000000..550144b866 --- /dev/null +++ b/pkg/http/transport/bearer_test.go @@ -0,0 +1,164 @@ +package transport + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + ghcontext "github.com/github/github-mcp-server/pkg/context" + "github.com/github/github-mcp-server/pkg/http/headers" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestBearerAuthTransport(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + token string + tokenProvider func() string + wantAuth string + }{ + { + name: "static token", + token: "static-token", + wantAuth: "Bearer static-token", + }, + { + name: "token provider takes precedence over static token", + token: "static-token", + tokenProvider: func() string { return "provided-token" }, + wantAuth: "Bearer provided-token", + }, + { + name: "token provider with empty static token", + tokenProvider: func() string { return "provided-token" }, + wantAuth: "Bearer provided-token", + }, + { + name: "token provider may return empty before authorization", + tokenProvider: func() string { return "" }, + wantAuth: "Bearer", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get(headers.AuthorizationHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: tc.token, + TokenProvider: tc.tokenProvider, + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, tc.wantAuth, gotAuth) + }) + } +} + +// TestBearerAuthTransport_TokenProviderResolvedPerRequest verifies that the +// token provider is consulted on every request, so a token that arrives (or is +// refreshed) after the transport is constructed takes effect without rebuilding +// the client. This is the property OAuth relies on. +func TestBearerAuthTransport_TokenProviderResolvedPerRequest(t *testing.T) { + t.Parallel() + + var gotAuth string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotAuth = r.Header.Get(headers.AuthorizationHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + current := "" + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + TokenProvider: func() string { return current }, + } + + do := func() { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + } + + do() + assert.Equal(t, "Bearer", gotAuth, "no token yet before authorization") + + current = "first-token" + do() + assert.Equal(t, "Bearer first-token", gotAuth, "token picked up once available") + + current = "refreshed-token" + do() + assert.Equal(t, "Bearer refreshed-token", gotAuth, "refreshed token picked up") +} + +func TestBearerAuthTransport_PassesGraphQLFeaturesHeader(t *testing.T) { + t.Parallel() + + var gotFeatures string + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + gotFeatures = r.Header.Get(headers.GraphQLFeaturesHeader) + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: "token", + } + + ctx := ghcontext.WithGraphQLFeatures(context.Background(), "feature1", "feature2") + req, err := http.NewRequestWithContext(ctx, http.MethodPost, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, "feature1, feature2", gotFeatures) +} + +func TestBearerAuthTransport_DoesNotMutateOriginalRequest(t *testing.T) { + t.Parallel() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer server.Close() + + rt := &BearerAuthTransport{ + Transport: http.DefaultTransport, + Token: "token", + } + + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Empty(t, req.Header.Get(headers.AuthorizationHeader), "original request must not be mutated") +} From 622d429004373b3f508ed1ccbef164a1c5973e84 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Tue, 16 Jun 2026 19:14:33 +0200 Subject: [PATCH 2/3] =?UTF-8?q?refactor(oauth):=20address=20review=20?= =?UTF-8?q?=E2=80=94=20omit=20empty=20bearer=20header,=20guard=20token/oau?= =?UTF-8?q?th?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - BearerAuthTransport omits the Authorization header entirely when the token is empty (pre-authorization) rather than sending an empty "Bearer " value. - RunStdioServer rejects the ambiguous combination of a static Token and an OAuthManager up front, enforcing the documented mutual exclusivity. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- internal/ghmcp/oauth_test.go | 18 +++++++++++++++++- internal/ghmcp/server.go | 7 +++++++ pkg/http/transport/bearer.go | 6 +++++- pkg/http/transport/bearer_test.go | 4 ++-- 4 files changed, 31 insertions(+), 4 deletions(-) diff --git a/internal/ghmcp/oauth_test.go b/internal/ghmcp/oauth_test.go index 4f370cf7bc..826b000185 100644 --- a/internal/ghmcp/oauth_test.go +++ b/internal/ghmcp/oauth_test.go @@ -286,6 +286,22 @@ func TestCreateOAuthMiddleware(t *testing.T) { }) } +// TestRunStdioServerRejectsTokenAndOAuth verifies the mutually-exclusive guard: +// supplying both a static token and an OAuth manager is rejected before the +// server starts, rather than silently preferring one for auth and the other for +// scope filtering. +func TestRunStdioServerRejectsTokenAndOAuth(t *testing.T) { + t.Parallel() + + mgr := oauth.NewManager(oauth.NewGitHubConfig("client-id", "", nil, "", 0), discardLogger()) + err := RunStdioServer(StdioServerConfig{ + Token: "ghp_static", + OAuthManager: mgr, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "mutually exclusive") +} + // TestCreateGitHubClientsTokenProvider proves the OAuth wiring: when a // TokenProvider is configured the REST client authenticates with the provider's // current token on every request (and never pins a stale one), which is what the @@ -317,7 +333,7 @@ func TestCreateGitHubClientsTokenProvider(t *testing.T) { } do() - assert.Equal(t, "Bearer", gotAuth, "no token before authorization") + assert.Equal(t, "", gotAuth, "no auth header before authorization") current = "oauth-token" do() diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 2364b02688..1bf84453c8 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -261,6 +261,13 @@ type StdioServerConfig struct { // RunStdioServer is not concurrent safe. func RunStdioServer(cfg StdioServerConfig) error { + // OAuth login and a static token are mutually exclusive: they would + // disagree on how the token is sourced (lazy provider vs. static) and on + // scope filtering, so reject the ambiguous combination up front. + if cfg.OAuthManager != nil && cfg.Token != "" { + return fmt.Errorf("OAuthManager and a static Token are mutually exclusive: provide one or the other") + } + // Create app context ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) defer stop() diff --git a/pkg/http/transport/bearer.go b/pkg/http/transport/bearer.go index 9be3fd5342..0c12ddfc91 100644 --- a/pkg/http/transport/bearer.go +++ b/pkg/http/transport/bearer.go @@ -25,7 +25,11 @@ func (t *BearerAuthTransport) RoundTrip(req *http.Request) (*http.Response, erro if t.TokenProvider != nil { token = t.TokenProvider() } - req.Header.Set(headers.AuthorizationHeader, "Bearer "+token) + // Before OAuth authorization completes the token is empty; send an + // unauthenticated request rather than an empty "Bearer " header. + if token != "" { + req.Header.Set(headers.AuthorizationHeader, "Bearer "+token) + } // Check for GraphQL-Features in context and add header if present if features := ghcontext.GetGraphQLFeatures(req.Context()); len(features) > 0 { diff --git a/pkg/http/transport/bearer_test.go b/pkg/http/transport/bearer_test.go index 550144b866..76ef8686cd 100644 --- a/pkg/http/transport/bearer_test.go +++ b/pkg/http/transport/bearer_test.go @@ -41,7 +41,7 @@ func TestBearerAuthTransport(t *testing.T) { { name: "token provider may return empty before authorization", tokenProvider: func() string { return "" }, - wantAuth: "Bearer", + wantAuth: "", }, } @@ -103,7 +103,7 @@ func TestBearerAuthTransport_TokenProviderResolvedPerRequest(t *testing.T) { } do() - assert.Equal(t, "Bearer", gotAuth, "no token yet before authorization") + assert.Equal(t, "", gotAuth, "no auth header before authorization") current = "first-token" do() From 2b4d5e60a67c7bc54821102e80fe8204fa314bb0 Mon Sep 17 00:00:00 2001 From: Sam Morrow Date: Thu, 18 Jun 2026 10:57:54 +0200 Subject: [PATCH 3/3] docs(oauth): clarify SupportedScopes is the stdio default and tool filter Document that stdio OAuth login requests these scopes by default and then filters the exposed tools to the scopes actually granted, so a tool whose required scope is absent from this list is hidden under default OAuth even though a PAT carrying that scope would expose it. Keep the list in sync with tool scope requirements when scopes change. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- pkg/http/oauth/oauth.go | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pkg/http/oauth/oauth.go b/pkg/http/oauth/oauth.go index ffa7669a9d..f7ffe67e6b 100644 --- a/pkg/http/oauth/oauth.go +++ b/pkg/http/oauth/oauth.go @@ -19,7 +19,13 @@ const ( OAuthProtectedResourcePrefix = "/.well-known/oauth-protected-resource" ) -// SupportedScopes lists all OAuth scopes that may be required by MCP tools. +// SupportedScopes lists every OAuth scope that an MCP tool may require. It is the +// source of truth in two places: HTTP mode advertises it as scopes_supported in +// the protected-resource metadata, and stdio OAuth login requests it by default +// and then filters the exposed tools to the granted scopes. A tool whose required +// scope is absent here is therefore hidden under default OAuth even though a PAT +// carrying that scope would expose it, so keep this list in sync with tool scope +// requirements when scopes change. var SupportedScopes = []string{ "repo", "read:org",