From db8935294a4bdfb8952d241018304082e0b0c06d Mon Sep 17 00:00:00 2001 From: pyama Date: Thu, 28 May 2026 08:43:43 +0900 Subject: [PATCH 1/3] feat: add GitHub App authentication support Add native GitHub App authentication as an alternative to Personal Access Tokens. The server can now authenticate using App ID, private key, and installation ID to automatically generate and refresh installation tokens. - Add `pkg/github/appauth` package with JWT generation and installation token management using only the standard library - Auto-refresh tokens before expiry (5-minute buffer on 1-hour tokens) - Support private key via env var (GITHUB_APP_PRIVATE_KEY) or file path (GITHUB_APP_PRIVATE_KEY_PATH) - Handle literal `\n` in env var PEM keys - Add comprehensive tests (13 tests covering key parsing, JWT generation, token caching, refresh, round-trip, and error handling) Closes #1333 --- README.md | 74 ++++++++ cmd/github-mcp-server/main.go | 57 +++++- go.sum | 2 + internal/ghmcp/server.go | 75 ++++++-- pkg/github/appauth/appauth.go | 233 ++++++++++++++++++++++++ pkg/github/appauth/appauth_test.go | 282 +++++++++++++++++++++++++++++ 6 files changed, 708 insertions(+), 15 deletions(-) create mode 100644 pkg/github/appauth/appauth.go create mode 100644 pkg/github/appauth/appauth_test.go diff --git a/README.md b/README.md index bec45b5da3..da4562f50a 100644 --- a/README.md +++ b/README.md @@ -239,6 +239,80 @@ To keep your GitHub PAT secure and reusable across different MCP hosts: +### GitHub App Authentication + +As an alternative to Personal Access Tokens, the MCP server supports authenticating as a [GitHub App](https://docs.github.com/en/apps) installation. This is useful for organizations that want to grant scoped, short-lived access without relying on individual PATs. + +The server automatically generates JWTs, fetches installation tokens, and refreshes them before expiry (installation tokens are valid for 1 hour). + +#### Required Environment Variables + +| Variable | Description | +|---|---| +| `GITHUB_APP_ID` | The GitHub App ID | +| `GITHUB_APP_INSTALLATION_ID` | The installation ID of the GitHub App | +| `GITHUB_APP_PRIVATE_KEY` | The PEM-encoded private key (inline, `\n` for newlines) | +| `GITHUB_APP_PRIVATE_KEY_PATH` | Path to the private key file (alternative to inline) | + +Either `GITHUB_APP_PRIVATE_KEY` or `GITHUB_APP_PRIVATE_KEY_PATH` must be set, but not both. When all three required variables (`GITHUB_APP_ID`, `GITHUB_APP_INSTALLATION_ID`, and a private key) are set, the server uses GitHub App authentication instead of a PAT. `GITHUB_PERSONAL_ACCESS_TOKEN` is not required in this case. + +#### Example: Using a private key file + +```bash +export GITHUB_APP_ID=12345 +export GITHUB_APP_INSTALLATION_ID=67890 +export GITHUB_APP_PRIVATE_KEY_PATH=/path/to/private-key.pem +github-mcp-server stdio +``` + +#### Example: Using an inline private key + +```bash +export GITHUB_APP_ID=12345 +export GITHUB_APP_INSTALLATION_ID=67890 +export GITHUB_APP_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----" +github-mcp-server stdio +``` + +#### Example: Docker with GitHub App authentication + +```bash +docker run -i --rm \ + -e GITHUB_APP_ID=12345 \ + -e GITHUB_APP_INSTALLATION_ID=67890 \ + -e GITHUB_APP_PRIVATE_KEY_PATH=/key/private-key.pem \ + -v /path/to/private-key.pem:/key/private-key.pem:ro \ + ghcr.io/github/github-mcp-server +``` + +#### Example: VS Code configuration + +```json +{ + "mcp": { + "servers": { + "github": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", "GITHUB_APP_ID", + "-e", "GITHUB_APP_INSTALLATION_ID", + "-e", "GITHUB_APP_PRIVATE_KEY", + "ghcr.io/github/github-mcp-server" + ], + "env": { + "GITHUB_APP_ID": "12345", + "GITHUB_APP_INSTALLATION_ID": "67890", + "GITHUB_APP_PRIVATE_KEY": "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----" + } + } + } + } +} +``` + ### GitHub Enterprise Server and Enterprise Cloud with data residency (ghe.com) The flag `--gh-host` and the environment variable `GITHUB_HOST` can be used to set diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 558fdb9980..374a861fa6 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "os" + "strconv" "strings" "time" @@ -34,8 +35,16 @@ var ( Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, RunE: func(_ *cobra.Command, _ []string) error { token := viper.GetString("personal_access_token") - if token == "" { - return errors.New("GITHUB_PERSONAL_ACCESS_TOKEN not set") + + // Parse GitHub App authentication config + appID, privateKey, installationID, err := parseAppAuthConfig() + if err != nil { + return err + } + useAppAuth := appID != 0 && len(privateKey) > 0 && installationID != 0 + + if token == "" && !useAppAuth { + return errors.New("GITHUB_PERSONAL_ACCESS_TOKEN not set (or configure GitHub App auth with GITHUB_APP_ID, GITHUB_APP_PRIVATE_KEY/GITHUB_APP_PRIVATE_KEY_PATH, and GITHUB_APP_INSTALLATION_ID)") } // If you're wondering why we're not using viper.GetStringSlice("toolsets"), @@ -94,6 +103,9 @@ var ( InsidersMode: viper.GetBool("insiders"), ExcludeTools: excludeTools, RepoAccessCacheTTL: &ttl, + AppID: appID, + PrivateKey: privateKey, + InstallationID: installationID, } return ghmcp.RunStdioServer(stdioServerConfig) }, @@ -235,3 +247,44 @@ func wordSepNormalizeFunc(_ *pflag.FlagSet, name string) pflag.NormalizedName { } return pflag.NormalizedName(name) } + +// parseAppAuthConfig reads GitHub App authentication config from environment variables. +// Returns (0, nil, 0, nil) when no App auth is configured. +func parseAppAuthConfig() (appID int64, privateKey []byte, installationID int64, err error) { + appIDStr := viper.GetString("app_id") + installationIDStr := viper.GetString("app_installation_id") + privateKeyStr := viper.GetString("app_private_key") + privateKeyPath := viper.GetString("app_private_key_path") + + // If none are set, App auth is not configured + if appIDStr == "" && installationIDStr == "" && privateKeyStr == "" && privateKeyPath == "" { + return 0, nil, 0, nil + } + + // If some but not all are set, that's a configuration error + if appIDStr == "" || installationIDStr == "" || (privateKeyStr == "" && privateKeyPath == "") { + return 0, nil, 0, errors.New("incomplete GitHub App auth config: GITHUB_APP_ID, GITHUB_APP_INSTALLATION_ID, and GITHUB_APP_PRIVATE_KEY or GITHUB_APP_PRIVATE_KEY_PATH are all required") + } + + appID, err = strconv.ParseInt(appIDStr, 10, 64) + if err != nil { + return 0, nil, 0, fmt.Errorf("invalid GITHUB_APP_ID: %w", err) + } + + installationID, err = strconv.ParseInt(installationIDStr, 10, 64) + if err != nil { + return 0, nil, 0, fmt.Errorf("invalid GITHUB_APP_INSTALLATION_ID: %w", err) + } + + if privateKeyStr != "" { + // Environment variables often use literal "\n" instead of actual newlines + privateKey = []byte(strings.ReplaceAll(privateKeyStr, `\n`, "\n")) + } else { + privateKey, err = os.ReadFile(privateKeyPath) + if err != nil { + return 0, nil, 0, fmt.Errorf("failed to read private key from %s: %w", privateKeyPath, err) + } + } + + return appID, privateKey, installationID, nil +} diff --git a/go.sum b/go.sum index c0e9f09552..12cd565ab7 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 38106b6d9a..f4dfc368f2 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -14,6 +14,7 @@ import ( "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/github/appauth" "github.com/github/github-mcp-server/pkg/http/transport" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" @@ -40,7 +41,8 @@ type githubClients struct { } // createGitHubClients creates all the GitHub API clients needed by the server. -func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolver) (*githubClients, error) { +// If authTransport is non-nil, it is used for authentication instead of cfg.Token. +func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolver, authTransport http.RoundTripper) (*githubClients, error) { restURL, err := apiHost.BaseRESTURL(context.Background()) if err != nil { return nil, fmt.Errorf("failed to get base REST URL: %w", err) @@ -61,30 +63,46 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv return nil, fmt.Errorf("failed to get Raw URL: %w", err) } + // Determine the base transport for REST and GraphQL clients + baseTransport := http.RoundTripper(http.DefaultTransport) + if authTransport != nil { + baseTransport = authTransport + } + // Construct REST client restUATransport := &transport.UserAgentTransport{ - Transport: http.DefaultTransport, + Transport: baseTransport, Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } - restClient, err := gogithub.NewClient( + restClientOpts := []gogithub.Option{ gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), - gogithub.WithAuthToken(cfg.Token), gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), - ) + } + if authTransport == nil { + restClientOpts = append(restClientOpts, gogithub.WithAuthToken(cfg.Token)) + } + restClient, err := gogithub.NewClient(restClientOpts...) if err != nil { return nil, fmt.Errorf("failed to create REST client: %w", err) } // Construct GraphQL client // We use NewEnterpriseClient unconditionally since we already parsed the API host - gqlHTTPClient := &http.Client{ - Transport: &transport.BearerAuthTransport{ + var gqlTransport http.RoundTripper + if authTransport != nil { + // Auth transport already sets the Authorization header + gqlTransport = &transport.GraphQLFeaturesTransport{ + Transport: authTransport, + } + } else { + gqlTransport = &transport.BearerAuthTransport{ Transport: &transport.GraphQLFeaturesTransport{ Transport: http.DefaultTransport, }, Token: cfg.Token, - }, + } } + gqlHTTPClient := &http.Client{Transport: gqlTransport} gqlClient := githubv4.NewEnterpriseClient(graphQLURL.String(), gqlHTTPClient) @@ -116,13 +134,13 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv }, nil } -func NewStdioMCPServer(ctx context.Context, cfg github.MCPServerConfig) (*mcp.Server, error) { +func NewStdioMCPServer(ctx context.Context, cfg github.MCPServerConfig, authTransport http.RoundTripper) (*mcp.Server, error) { apiHost, err := utils.NewAPIHost(cfg.Host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) } - clients, err := createGitHubClients(cfg, apiHost) + clients, err := createGitHubClients(cfg, apiHost, authTransport) if err != nil { return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } @@ -238,6 +256,13 @@ type StdioServerConfig struct { // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. RepoAccessCacheTTL *time.Duration + + // GitHub App authentication (alternative to Token) + // When AppID, PrivateKey, and InstallationID are all set, the server + // authenticates as a GitHub App installation instead of using a PAT. + AppID int64 + PrivateKey []byte + InstallationID int64 } // RunStdioServer is not concurrent safe. @@ -264,11 +289,35 @@ func RunStdioServer(cfg StdioServerConfig) error { logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) + // Set up GitHub App authentication transport if configured + var appAuthTransport http.RoundTripper + if cfg.AppID != 0 && len(cfg.PrivateKey) > 0 && cfg.InstallationID != 0 { + apiHost, err := utils.NewAPIHost(cfg.Host) + if err != nil { + return fmt.Errorf("failed to parse API host for app auth: %w", err) + } + baseURL, err := apiHost.BaseRESTURL(ctx) + if err != nil { + return fmt.Errorf("failed to get base REST URL for app auth: %w", err) + } + tr, err := appauth.NewTransport(http.DefaultTransport, appauth.Config{ + AppID: cfg.AppID, + PrivateKey: cfg.PrivateKey, + InstallationID: cfg.InstallationID, + BaseURL: baseURL.String(), + }) + if err != nil { + return fmt.Errorf("failed to create GitHub App auth transport: %w", err) + } + appAuthTransport = tr + logger.Info("using GitHub App authentication", "appID", cfg.AppID, "installationID", cfg.InstallationID) + } + // Fetch token scopes for scope-based tool filtering (PAT tokens only) // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. // Fine-grained PATs and other token types don't support this, so we skip filtering. var tokenScopes []string - if strings.HasPrefix(cfg.Token, "ghp_") { + if appAuthTransport == nil && strings.HasPrefix(cfg.Token, "ghp_") { fetchedScopes, err := fetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) if err != nil { logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) @@ -276,7 +325,7 @@ func RunStdioServer(cfg StdioServerConfig) error { tokenScopes = fetchedScopes logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) } - } else { + } else if appAuthTransport == nil { logger.Debug("skipping scope filtering for non-PAT token") } @@ -296,7 +345,7 @@ func RunStdioServer(cfg StdioServerConfig) error { Logger: logger, RepoAccessTTL: cfg.RepoAccessCacheTTL, TokenScopes: tokenScopes, - }) + }, appAuthTransport) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) } diff --git a/pkg/github/appauth/appauth.go b/pkg/github/appauth/appauth.go new file mode 100644 index 0000000000..79b938a44d --- /dev/null +++ b/pkg/github/appauth/appauth.go @@ -0,0 +1,233 @@ +package appauth + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// Config holds the configuration for GitHub App authentication. +type Config struct { + // AppID is the GitHub App ID. + AppID int64 + + // PrivateKey is the PEM-encoded RSA private key for the GitHub App. + PrivateKey []byte + + // InstallationID is the installation ID of the GitHub App. + InstallationID int64 + + // BaseURL is the base URL for the GitHub API (e.g., "https://api.github.com"). + // If empty, defaults to "https://api.github.com". + BaseURL string +} + +// Transport is an http.RoundTripper that authenticates requests using +// a GitHub App installation token. It automatically generates JWTs and +// fetches/refreshes installation tokens as needed. +type Transport struct { + config Config + key *rsa.PrivateKey + base http.RoundTripper + + mu sync.Mutex + token string + exp time.Time +} + +type installationToken struct { + Token string `json:"token"` + ExpiresAt time.Time `json:"expires_at"` +} + +// NewTransport creates a new Transport that authenticates using a GitHub App +// installation token. The transport automatically handles JWT generation and +// installation token refresh. +func NewTransport(base http.RoundTripper, cfg Config) (*Transport, error) { + key, err := parsePrivateKey(cfg.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + if base == nil { + base = http.DefaultTransport + } + if cfg.BaseURL == "" { + cfg.BaseURL = "https://api.github.com" + } + return &Transport{ + config: cfg, + key: key, + base: base, + }, nil +} + +// RoundTrip implements http.RoundTripper. It adds the installation token +// to the Authorization header, refreshing it if necessary. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := t.installationToken(req.Context()) + if err != nil { + return nil, fmt.Errorf("failed to get installation token: %w", err) + } + req2 := req.Clone(req.Context()) + req2.Header.Set("Authorization", "Bearer "+token) + return t.base.RoundTrip(req2) +} + +// Token returns the current installation token, refreshing if necessary. +func (t *Transport) Token(ctx context.Context) (string, error) { + return t.installationToken(ctx) +} + +func (t *Transport) installationToken(ctx context.Context) (string, error) { + t.mu.Lock() + defer t.mu.Unlock() + + // Refresh if the token expires within 5 minutes + if t.token != "" && time.Now().Add(5*time.Minute).Before(t.exp) { + return t.token, nil + } + + jwtToken, err := t.generateJWT() + if err != nil { + return "", fmt.Errorf("failed to generate JWT: %w", err) + } + + tok, err := t.fetchInstallationToken(ctx, jwtToken) + if err != nil { + return "", err + } + + t.token = tok.Token + t.exp = tok.ExpiresAt + return t.token, nil +} + +// generateJWT creates a signed JWT for GitHub App authentication using RS256. +func (t *Transport) generateJWT() (string, error) { + now := time.Now().Add(-30 * time.Second) // allow 30s clock drift + + header := map[string]string{ + "alg": "RS256", + "typ": "JWT", + } + payload := map[string]any{ + "iat": now.Unix(), + "exp": now.Add(10 * time.Minute).Unix(), + "iss": fmt.Sprintf("%d", t.config.AppID), + } + + headerJSON, err := json.Marshal(header) + if err != nil { + return "", fmt.Errorf("failed to marshal JWT header: %w", err) + } + payloadJSON, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal JWT payload: %w", err) + } + + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON) + + signingInput := headerB64 + "." + payloadB64 + hash := sha256.Sum256([]byte(signingInput)) + sig, err := rsa.SignPKCS1v15(rand.Reader, t.key, crypto.SHA256, hash[:]) + if err != nil { + return "", fmt.Errorf("failed to sign JWT: %w", err) + } + sigB64 := base64.RawURLEncoding.EncodeToString(sig) + + return signingInput + "." + sigB64, nil +} + +func (t *Transport) fetchInstallationToken(ctx context.Context, jwtToken string) (*installationToken, error) { + url := fmt.Sprintf("%s/app/installations/%d/access_tokens", t.config.BaseURL, t.config.InstallationID) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+jwtToken) + req.Header.Set("Accept", "application/vnd.github+json") + + resp, err := t.base.RoundTrip(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch installation token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to create installation token (status %d): %s", resp.StatusCode, body) + } + + var tok installationToken + if err := json.NewDecoder(resp.Body).Decode(&tok); err != nil { + return nil, fmt.Errorf("failed to decode installation token response: %w", err) + } + return &tok, nil +} + +func parsePrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block") + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + case "PRIVATE KEY": + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("expected RSA private key, got %T", key) + } + return rsaKey, nil + default: + return nil, fmt.Errorf("unsupported PEM block type: %s", block.Type) + } +} + +// VerifyJWT parses and verifies a JWT token using the given RSA public key. +// Returns the claims map. This is used only for testing. +func VerifyJWT(tokenString string, pubKey *rsa.PublicKey) (map[string]any, error) { + parts := strings.SplitN(tokenString, ".", 3) + if len(parts) != 3 { + return nil, fmt.Errorf("invalid JWT: expected 3 parts, got %d", len(parts)) + } + + signingInput := parts[0] + "." + parts[1] + sig, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return nil, fmt.Errorf("failed to decode signature: %w", err) + } + + hash := sha256.Sum256([]byte(signingInput)) + if err := rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash[:], sig); err != nil { + return nil, fmt.Errorf("invalid signature: %w", err) + } + + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to decode payload: %w", err) + } + var claims map[string]any + if err := json.Unmarshal(payloadJSON, &claims); err != nil { + return nil, fmt.Errorf("failed to unmarshal claims: %w", err) + } + return claims, nil +} diff --git a/pkg/github/appauth/appauth_test.go b/pkg/github/appauth/appauth_test.go new file mode 100644 index 0000000000..a5e960d20a --- /dev/null +++ b/pkg/github/appauth/appauth_test.go @@ -0,0 +1,282 @@ +package appauth + +import ( + "context" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/json" + "encoding/pem" + "fmt" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func generateTestKey(t *testing.T) (*rsa.PrivateKey, []byte) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + return key, pemBytes +} + +func TestParsePrivateKey_PKCS1(t *testing.T) { + _, pemBytes := generateTestKey(t) + key, err := parsePrivateKey(pemBytes) + require.NoError(t, err) + assert.NotNil(t, key) +} + +func TestParsePrivateKey_PKCS8(t *testing.T) { + rsaKey, _ := generateTestKey(t) + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(rsaKey) + require.NoError(t, err) + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Bytes, + }) + + key, err := parsePrivateKey(pemBytes) + require.NoError(t, err) + assert.NotNil(t, key) +} + +func TestParsePrivateKey_InvalidPEM(t *testing.T) { + _, err := parsePrivateKey([]byte("not a pem")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode PEM block") +} + +func TestParsePrivateKey_UnsupportedType(t *testing.T) { + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: []byte("fake"), + }) + _, err := parsePrivateKey(pemBytes) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported PEM block type") +} + +func TestNewTransport_InvalidKey(t *testing.T) { + _, err := NewTransport(nil, Config{ + AppID: 123, + PrivateKey: []byte("invalid"), + InstallationID: 456, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse private key") +} + +func TestNewTransport_DefaultBaseURL(t *testing.T) { + _, pemBytes := generateTestKey(t) + tr, err := NewTransport(nil, Config{ + AppID: 123, + PrivateKey: pemBytes, + InstallationID: 456, + }) + require.NoError(t, err) + assert.Equal(t, "https://api.github.com", tr.config.BaseURL) +} + +func TestNewTransport_CustomBaseURL(t *testing.T) { + _, pemBytes := generateTestKey(t) + tr, err := NewTransport(nil, Config{ + AppID: 123, + PrivateKey: pemBytes, + InstallationID: 456, + BaseURL: "https://github.example.com/api/v3", + }) + require.NoError(t, err) + assert.Equal(t, "https://github.example.com/api/v3", tr.config.BaseURL) +} + +func TestTransport_GenerateJWT(t *testing.T) { + key, pemBytes := generateTestKey(t) + tr, err := NewTransport(nil, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + }) + require.NoError(t, err) + + jwtToken, err := tr.generateJWT() + require.NoError(t, err) + + claims, err := VerifyJWT(jwtToken, &key.PublicKey) + require.NoError(t, err) + + assert.Equal(t, "12345", claims["iss"]) + + iat := int64(claims["iat"].(float64)) + exp := int64(claims["exp"].(float64)) + assert.InDelta(t, time.Now().Unix(), iat, 60) + assert.InDelta(t, time.Now().Add(10*time.Minute).Unix(), exp, 60) +} + +func TestTransport_FetchInstallationToken(t *testing.T) { + key, pemBytes := generateTestKey(t) + installationID := int64(67890) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPath := fmt.Sprintf("/app/installations/%d/access_tokens", installationID) + assert.Equal(t, expectedPath, r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + + authHeader := r.Header.Get("Authorization") + assert.True(t, len(authHeader) > 7) + jwtToken := authHeader[7:] // strip "Bearer " + + _, err := VerifyJWT(jwtToken, &key.PublicKey) + assert.NoError(t, err) + + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(installationToken{ + Token: "ghs_test_token_123", + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + })) + defer server.Close() + + tr, err := NewTransport(server.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: installationID, + BaseURL: server.URL, + }) + require.NoError(t, err) + + token, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_test_token_123", token) +} + +func TestTransport_TokenCaching(t *testing.T) { + _, pemBytes := generateTestKey(t) + var callCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + callCount.Add(1) + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(installationToken{ + Token: "ghs_cached_token", + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + })) + defer server.Close() + + tr, err := NewTransport(server.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + BaseURL: server.URL, + }) + require.NoError(t, err) + + token1, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_cached_token", token1) + + token2, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_cached_token", token2) + + assert.Equal(t, int32(1), callCount.Load()) +} + +func TestTransport_TokenRefresh(t *testing.T) { + _, pemBytes := generateTestKey(t) + var callCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + count := callCount.Add(1) + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(installationToken{ + Token: fmt.Sprintf("ghs_token_%d", count), + ExpiresAt: time.Now().Add(1 * time.Minute), // expires soon, within 5min refresh window + }) + })) + defer server.Close() + + tr, err := NewTransport(server.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + BaseURL: server.URL, + }) + require.NoError(t, err) + + token1, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_token_1", token1) + + // Token expires within 5 minutes, so next call should refresh + token2, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_token_2", token2) + assert.Equal(t, int32(2), callCount.Load()) +} + +func TestTransport_RoundTrip(t *testing.T) { + _, pemBytes := generateTestKey(t) + + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/app/installations/67890/access_tokens" { + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(installationToken{ + Token: "ghs_roundtrip_token", + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + return + } + assert.Equal(t, "Bearer ghs_roundtrip_token", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok": true}`)) + })) + defer tokenServer.Close() + + tr, err := NewTransport(tokenServer.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + BaseURL: tokenServer.URL, + }) + require.NoError(t, err) + + client := &http.Client{Transport: tr} + resp, err := client.Get(tokenServer.URL + "/repos/owner/repo") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestTransport_FetchError(t *testing.T) { + _, pemBytes := generateTestKey(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Bad credentials"}`)) + })) + defer server.Close() + + tr, err := NewTransport(server.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + BaseURL: server.URL, + }) + require.NoError(t, err) + + _, err = tr.Token(context.Background()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create installation token") + assert.Contains(t, err.Error(), "Bad credentials") +} From f3ad678ceb6510c02bf127ccc3b499ff868a1620 Mon Sep 17 00:00:00 2001 From: pyama Date: Thu, 28 May 2026 08:50:21 +0900 Subject: [PATCH 2/3] fix: use direct NewClient calls instead of gogithub.Option slice gogithub.Option is not an exported type, so we cannot store options in a slice. Use separate NewClient calls for each auth path instead. --- internal/ghmcp/server.go | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index f4dfc368f2..d7d960fe7d 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -74,14 +74,19 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv Transport: baseTransport, Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } - restClientOpts := []gogithub.Option{ - gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), - gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), - } - if authTransport == nil { - restClientOpts = append(restClientOpts, gogithub.WithAuthToken(cfg.Token)) + var restClient *gogithub.Client + if authTransport != nil { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) + } else { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), + gogithub.WithAuthToken(cfg.Token), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) } - restClient, err := gogithub.NewClient(restClientOpts...) if err != nil { return nil, fmt.Errorf("failed to create REST client: %w", err) } From 6799ef5bb1deb57a8c51c5a9c08c69c6b40e5eea Mon Sep 17 00:00:00 2001 From: pyama Date: Thu, 28 May 2026 13:55:38 +0900 Subject: [PATCH 3/3] fix: address Copilot review feedback - Move VerifyJWT to test file to avoid exporting test-only helpers - Use RWMutex with double-check pattern to avoid blocking reads during token refresh - Add UserAgentTransport to GraphQL path when using App auth for consistency with REST - Make GITHUB_APP_PRIVATE_KEY and GITHUB_APP_PRIVATE_KEY_PATH mutually exclusive (return error when both are set) - Only replace literal \n when the private key has no real newlines to avoid corrupting correctly-passed keys - Use safer JWT lifetime (iat=now-30s, exp=now+9m) to stay well within GitHub's 10-minute maximum - Document that base transport must not inject its own Authorization header --- README.md | 2 +- cmd/github-mcp-server/main.go | 13 ++++++-- internal/ghmcp/server.go | 10 ++++-- pkg/github/appauth/appauth.go | 53 ++++++++++-------------------- pkg/github/appauth/appauth_test.go | 39 ++++++++++++++++++++-- 5 files changed, 72 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index da4562f50a..f70bc69b66 100644 --- a/README.md +++ b/README.md @@ -254,7 +254,7 @@ The server automatically generates JWTs, fetches installation tokens, and refres | `GITHUB_APP_PRIVATE_KEY` | The PEM-encoded private key (inline, `\n` for newlines) | | `GITHUB_APP_PRIVATE_KEY_PATH` | Path to the private key file (alternative to inline) | -Either `GITHUB_APP_PRIVATE_KEY` or `GITHUB_APP_PRIVATE_KEY_PATH` must be set, but not both. When all three required variables (`GITHUB_APP_ID`, `GITHUB_APP_INSTALLATION_ID`, and a private key) are set, the server uses GitHub App authentication instead of a PAT. `GITHUB_PERSONAL_ACCESS_TOKEN` is not required in this case. +Either `GITHUB_APP_PRIVATE_KEY` or `GITHUB_APP_PRIVATE_KEY_PATH` must be set, but not both (they are mutually exclusive). When all three required variables (`GITHUB_APP_ID`, `GITHUB_APP_INSTALLATION_ID`, and a private key) are set, the server uses GitHub App authentication instead of a PAT. `GITHUB_PERSONAL_ACCESS_TOKEN` is not required in this case. #### Example: Using a private key file diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 374a861fa6..6e99c37ae2 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -266,6 +266,10 @@ func parseAppAuthConfig() (appID int64, privateKey []byte, installationID int64, return 0, nil, 0, errors.New("incomplete GitHub App auth config: GITHUB_APP_ID, GITHUB_APP_INSTALLATION_ID, and GITHUB_APP_PRIVATE_KEY or GITHUB_APP_PRIVATE_KEY_PATH are all required") } + if privateKeyStr != "" && privateKeyPath != "" { + return 0, nil, 0, errors.New("GITHUB_APP_PRIVATE_KEY and GITHUB_APP_PRIVATE_KEY_PATH are mutually exclusive") + } + appID, err = strconv.ParseInt(appIDStr, 10, 64) if err != nil { return 0, nil, 0, fmt.Errorf("invalid GITHUB_APP_ID: %w", err) @@ -277,8 +281,13 @@ func parseAppAuthConfig() (appID int64, privateKey []byte, installationID int64, } if privateKeyStr != "" { - // Environment variables often use literal "\n" instead of actual newlines - privateKey = []byte(strings.ReplaceAll(privateKeyStr, `\n`, "\n")) + // Environment variables often use literal "\n" instead of actual newlines. + // Only replace when the value has no real newlines to avoid corrupting + // keys that were correctly passed with actual newlines. + if !strings.Contains(privateKeyStr, "\n") { + privateKeyStr = strings.ReplaceAll(privateKeyStr, `\n`, "\n") + } + privateKey = []byte(privateKeyStr) } else { privateKey, err = os.ReadFile(privateKeyPath) if err != nil { diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index d7d960fe7d..9632d7de28 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -95,9 +95,13 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv // We use NewEnterpriseClient unconditionally since we already parsed the API host var gqlTransport http.RoundTripper if authTransport != nil { - // Auth transport already sets the Authorization header - gqlTransport = &transport.GraphQLFeaturesTransport{ - Transport: authTransport, + // Auth transport already sets the Authorization header. + // Wrap with UserAgentTransport for consistency with the REST path. + gqlTransport = &transport.UserAgentTransport{ + Transport: &transport.GraphQLFeaturesTransport{ + Transport: authTransport, + }, + Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } } else { gqlTransport = &transport.BearerAuthTransport{ diff --git a/pkg/github/appauth/appauth.go b/pkg/github/appauth/appauth.go index 79b938a44d..16842b4446 100644 --- a/pkg/github/appauth/appauth.go +++ b/pkg/github/appauth/appauth.go @@ -13,7 +13,6 @@ import ( "fmt" "io" "net/http" - "strings" "sync" "time" ) @@ -42,7 +41,7 @@ type Transport struct { key *rsa.PrivateKey base http.RoundTripper - mu sync.Mutex + mu sync.RWMutex token string exp time.Time } @@ -55,6 +54,8 @@ type installationToken struct { // NewTransport creates a new Transport that authenticates using a GitHub App // installation token. The transport automatically handles JWT generation and // installation token refresh. +// The base transport must not inject its own Authorization header, as this +// transport sets it for both installation token requests and API requests. func NewTransport(base http.RoundTripper, cfg Config) (*Transport, error) { key, err := parsePrivateKey(cfg.PrivateKey) if err != nil { @@ -91,10 +92,20 @@ func (t *Transport) Token(ctx context.Context) (string, error) { } func (t *Transport) installationToken(ctx context.Context) (string, error) { + // Fast path: read lock to check cached token + t.mu.RLock() + if t.token != "" && time.Now().Add(5*time.Minute).Before(t.exp) { + token := t.token + t.mu.RUnlock() + return token, nil + } + t.mu.RUnlock() + + // Slow path: write lock to refresh t.mu.Lock() defer t.mu.Unlock() - // Refresh if the token expires within 5 minutes + // Double-check after acquiring write lock if t.token != "" && time.Now().Add(5*time.Minute).Before(t.exp) { return t.token, nil } @@ -116,15 +127,15 @@ func (t *Transport) installationToken(ctx context.Context) (string, error) { // generateJWT creates a signed JWT for GitHub App authentication using RS256. func (t *Transport) generateJWT() (string, error) { - now := time.Now().Add(-30 * time.Second) // allow 30s clock drift + now := time.Now() header := map[string]string{ "alg": "RS256", "typ": "JWT", } payload := map[string]any{ - "iat": now.Unix(), - "exp": now.Add(10 * time.Minute).Unix(), + "iat": now.Add(-30 * time.Second).Unix(), // allow 30s clock drift + "exp": now.Add(9 * time.Minute).Unix(), // well within GitHub's 10-minute maximum "iss": fmt.Sprintf("%d", t.config.AppID), } @@ -201,33 +212,3 @@ func parsePrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) { return nil, fmt.Errorf("unsupported PEM block type: %s", block.Type) } } - -// VerifyJWT parses and verifies a JWT token using the given RSA public key. -// Returns the claims map. This is used only for testing. -func VerifyJWT(tokenString string, pubKey *rsa.PublicKey) (map[string]any, error) { - parts := strings.SplitN(tokenString, ".", 3) - if len(parts) != 3 { - return nil, fmt.Errorf("invalid JWT: expected 3 parts, got %d", len(parts)) - } - - signingInput := parts[0] + "." + parts[1] - sig, err := base64.RawURLEncoding.DecodeString(parts[2]) - if err != nil { - return nil, fmt.Errorf("failed to decode signature: %w", err) - } - - hash := sha256.Sum256([]byte(signingInput)) - if err := rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash[:], sig); err != nil { - return nil, fmt.Errorf("invalid signature: %w", err) - } - - payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) - if err != nil { - return nil, fmt.Errorf("failed to decode payload: %w", err) - } - var claims map[string]any - if err := json.Unmarshal(payloadJSON, &claims); err != nil { - return nil, fmt.Errorf("failed to unmarshal claims: %w", err) - } - return claims, nil -} diff --git a/pkg/github/appauth/appauth_test.go b/pkg/github/appauth/appauth_test.go index a5e960d20a..150fe8da1d 100644 --- a/pkg/github/appauth/appauth_test.go +++ b/pkg/github/appauth/appauth_test.go @@ -2,14 +2,18 @@ package appauth import ( "context" + "crypto" "crypto/rand" "crypto/rsa" + "crypto/sha256" "crypto/x509" + "encoding/base64" "encoding/json" "encoding/pem" "fmt" "net/http" "net/http/httptest" + "strings" "sync/atomic" "testing" "time" @@ -18,6 +22,35 @@ import ( "github.com/stretchr/testify/require" ) +// verifyJWT parses and verifies a JWT token using the given RSA public key. +func verifyJWT(tokenString string, pubKey *rsa.PublicKey) (map[string]any, error) { + parts := strings.SplitN(tokenString, ".", 3) + if len(parts) != 3 { + return nil, fmt.Errorf("invalid JWT: expected 3 parts, got %d", len(parts)) + } + + signingInput := parts[0] + "." + parts[1] + sig, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return nil, fmt.Errorf("failed to decode signature: %w", err) + } + + hash := sha256.Sum256([]byte(signingInput)) + if err := rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash[:], sig); err != nil { + return nil, fmt.Errorf("invalid signature: %w", err) + } + + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to decode payload: %w", err) + } + var claims map[string]any + if err := json.Unmarshal(payloadJSON, &claims); err != nil { + return nil, fmt.Errorf("failed to unmarshal claims: %w", err) + } + return claims, nil +} + func generateTestKey(t *testing.T) (*rsa.PrivateKey, []byte) { t.Helper() key, err := rsa.GenerateKey(rand.Reader, 2048) @@ -111,7 +144,7 @@ func TestTransport_GenerateJWT(t *testing.T) { jwtToken, err := tr.generateJWT() require.NoError(t, err) - claims, err := VerifyJWT(jwtToken, &key.PublicKey) + claims, err := verifyJWT(jwtToken, &key.PublicKey) require.NoError(t, err) assert.Equal(t, "12345", claims["iss"]) @@ -119,7 +152,7 @@ func TestTransport_GenerateJWT(t *testing.T) { iat := int64(claims["iat"].(float64)) exp := int64(claims["exp"].(float64)) assert.InDelta(t, time.Now().Unix(), iat, 60) - assert.InDelta(t, time.Now().Add(10*time.Minute).Unix(), exp, 60) + assert.InDelta(t, time.Now().Add(9*time.Minute).Unix(), exp, 60) } func TestTransport_FetchInstallationToken(t *testing.T) { @@ -135,7 +168,7 @@ func TestTransport_FetchInstallationToken(t *testing.T) { assert.True(t, len(authHeader) > 7) jwtToken := authHeader[7:] // strip "Bearer " - _, err := VerifyJWT(jwtToken, &key.PublicKey) + _, err := verifyJWT(jwtToken, &key.PublicKey) assert.NoError(t, err) w.WriteHeader(http.StatusCreated)