diff --git a/README.md b/README.md index bec45b5da..f70bc69b6 100644 --- a/README.md +++ b/README.md @@ -239,6 +239,80 @@ To keep your GitHub PAT secure and reusable across different MCP hosts: +### GitHub App Authentication + +As an alternative to Personal Access Tokens, the MCP server supports authenticating as a [GitHub App](https://docs.github.com/en/apps) installation. This is useful for organizations that want to grant scoped, short-lived access without relying on individual PATs. + +The server automatically generates JWTs, fetches installation tokens, and refreshes them before expiry (installation tokens are valid for 1 hour). + +#### Required Environment Variables + +| Variable | Description | +|---|---| +| `GITHUB_APP_ID` | The GitHub App ID | +| `GITHUB_APP_INSTALLATION_ID` | The installation ID of the GitHub App | +| `GITHUB_APP_PRIVATE_KEY` | The PEM-encoded private key (inline, `\n` for newlines) | +| `GITHUB_APP_PRIVATE_KEY_PATH` | Path to the private key file (alternative to inline) | + +Either `GITHUB_APP_PRIVATE_KEY` or `GITHUB_APP_PRIVATE_KEY_PATH` must be set, but not both (they are mutually exclusive). When all three required variables (`GITHUB_APP_ID`, `GITHUB_APP_INSTALLATION_ID`, and a private key) are set, the server uses GitHub App authentication instead of a PAT. `GITHUB_PERSONAL_ACCESS_TOKEN` is not required in this case. + +#### Example: Using a private key file + +```bash +export GITHUB_APP_ID=12345 +export GITHUB_APP_INSTALLATION_ID=67890 +export GITHUB_APP_PRIVATE_KEY_PATH=/path/to/private-key.pem +github-mcp-server stdio +``` + +#### Example: Using an inline private key + +```bash +export GITHUB_APP_ID=12345 +export GITHUB_APP_INSTALLATION_ID=67890 +export GITHUB_APP_PRIVATE_KEY="-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----" +github-mcp-server stdio +``` + +#### Example: Docker with GitHub App authentication + +```bash +docker run -i --rm \ + -e GITHUB_APP_ID=12345 \ + -e GITHUB_APP_INSTALLATION_ID=67890 \ + -e GITHUB_APP_PRIVATE_KEY_PATH=/key/private-key.pem \ + -v /path/to/private-key.pem:/key/private-key.pem:ro \ + ghcr.io/github/github-mcp-server +``` + +#### Example: VS Code configuration + +```json +{ + "mcp": { + "servers": { + "github": { + "command": "docker", + "args": [ + "run", + "-i", + "--rm", + "-e", "GITHUB_APP_ID", + "-e", "GITHUB_APP_INSTALLATION_ID", + "-e", "GITHUB_APP_PRIVATE_KEY", + "ghcr.io/github/github-mcp-server" + ], + "env": { + "GITHUB_APP_ID": "12345", + "GITHUB_APP_INSTALLATION_ID": "67890", + "GITHUB_APP_PRIVATE_KEY": "-----BEGIN RSA PRIVATE KEY-----\nMIIE...\n-----END RSA PRIVATE KEY-----" + } + } + } + } +} +``` + ### GitHub Enterprise Server and Enterprise Cloud with data residency (ghe.com) The flag `--gh-host` and the environment variable `GITHUB_HOST` can be used to set diff --git a/cmd/github-mcp-server/main.go b/cmd/github-mcp-server/main.go index 558fdb998..6e99c37ae 100644 --- a/cmd/github-mcp-server/main.go +++ b/cmd/github-mcp-server/main.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "os" + "strconv" "strings" "time" @@ -34,8 +35,16 @@ var ( Long: `Start a server that communicates via standard input/output streams using JSON-RPC messages.`, RunE: func(_ *cobra.Command, _ []string) error { token := viper.GetString("personal_access_token") - if token == "" { - return errors.New("GITHUB_PERSONAL_ACCESS_TOKEN not set") + + // Parse GitHub App authentication config + appID, privateKey, installationID, err := parseAppAuthConfig() + if err != nil { + return err + } + useAppAuth := appID != 0 && len(privateKey) > 0 && installationID != 0 + + if token == "" && !useAppAuth { + return errors.New("GITHUB_PERSONAL_ACCESS_TOKEN not set (or configure GitHub App auth with GITHUB_APP_ID, GITHUB_APP_PRIVATE_KEY/GITHUB_APP_PRIVATE_KEY_PATH, and GITHUB_APP_INSTALLATION_ID)") } // If you're wondering why we're not using viper.GetStringSlice("toolsets"), @@ -94,6 +103,9 @@ var ( InsidersMode: viper.GetBool("insiders"), ExcludeTools: excludeTools, RepoAccessCacheTTL: &ttl, + AppID: appID, + PrivateKey: privateKey, + InstallationID: installationID, } return ghmcp.RunStdioServer(stdioServerConfig) }, @@ -235,3 +247,53 @@ func wordSepNormalizeFunc(_ *pflag.FlagSet, name string) pflag.NormalizedName { } return pflag.NormalizedName(name) } + +// parseAppAuthConfig reads GitHub App authentication config from environment variables. +// Returns (0, nil, 0, nil) when no App auth is configured. +func parseAppAuthConfig() (appID int64, privateKey []byte, installationID int64, err error) { + appIDStr := viper.GetString("app_id") + installationIDStr := viper.GetString("app_installation_id") + privateKeyStr := viper.GetString("app_private_key") + privateKeyPath := viper.GetString("app_private_key_path") + + // If none are set, App auth is not configured + if appIDStr == "" && installationIDStr == "" && privateKeyStr == "" && privateKeyPath == "" { + return 0, nil, 0, nil + } + + // If some but not all are set, that's a configuration error + if appIDStr == "" || installationIDStr == "" || (privateKeyStr == "" && privateKeyPath == "") { + return 0, nil, 0, errors.New("incomplete GitHub App auth config: GITHUB_APP_ID, GITHUB_APP_INSTALLATION_ID, and GITHUB_APP_PRIVATE_KEY or GITHUB_APP_PRIVATE_KEY_PATH are all required") + } + + if privateKeyStr != "" && privateKeyPath != "" { + return 0, nil, 0, errors.New("GITHUB_APP_PRIVATE_KEY and GITHUB_APP_PRIVATE_KEY_PATH are mutually exclusive") + } + + appID, err = strconv.ParseInt(appIDStr, 10, 64) + if err != nil { + return 0, nil, 0, fmt.Errorf("invalid GITHUB_APP_ID: %w", err) + } + + installationID, err = strconv.ParseInt(installationIDStr, 10, 64) + if err != nil { + return 0, nil, 0, fmt.Errorf("invalid GITHUB_APP_INSTALLATION_ID: %w", err) + } + + if privateKeyStr != "" { + // Environment variables often use literal "\n" instead of actual newlines. + // Only replace when the value has no real newlines to avoid corrupting + // keys that were correctly passed with actual newlines. + if !strings.Contains(privateKeyStr, "\n") { + privateKeyStr = strings.ReplaceAll(privateKeyStr, `\n`, "\n") + } + privateKey = []byte(privateKeyStr) + } else { + privateKey, err = os.ReadFile(privateKeyPath) + if err != nil { + return 0, nil, 0, fmt.Errorf("failed to read private key from %s: %w", privateKeyPath, err) + } + } + + return appID, privateKey, installationID, nil +} diff --git a/go.sum b/go.sum index c0e9f0955..12cd565ab 100644 --- a/go.sum +++ b/go.sum @@ -11,6 +11,8 @@ github.com/go-chi/chi/v5 v5.2.5 h1:Eg4myHZBjyvJmAFjFvWgrqDTXFyOzjj7YIm3L3mu6Ug= github.com/go-chi/chi/v5 v5.2.5/go.mod h1:X7Gx4mteadT3eDOMTsXzmI4/rwUpOwBHLpAfupzFJP0= github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPEgAXnvj1Ro= github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXeUI= +github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.3.1 h1:kYf81DTWFe7t+1VvL7eS+jKFVWaUnK9cB1qbwn63YCY= github.com/golang-jwt/jwt/v5 v5.3.1/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= diff --git a/internal/ghmcp/server.go b/internal/ghmcp/server.go index 38106b6d9..9632d7de2 100644 --- a/internal/ghmcp/server.go +++ b/internal/ghmcp/server.go @@ -14,6 +14,7 @@ import ( "github.com/github/github-mcp-server/pkg/errors" "github.com/github/github-mcp-server/pkg/github" + "github.com/github/github-mcp-server/pkg/github/appauth" "github.com/github/github-mcp-server/pkg/http/transport" "github.com/github/github-mcp-server/pkg/inventory" "github.com/github/github-mcp-server/pkg/lockdown" @@ -40,7 +41,8 @@ type githubClients struct { } // createGitHubClients creates all the GitHub API clients needed by the server. -func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolver) (*githubClients, error) { +// If authTransport is non-nil, it is used for authentication instead of cfg.Token. +func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolver, authTransport http.RoundTripper) (*githubClients, error) { restURL, err := apiHost.BaseRESTURL(context.Background()) if err != nil { return nil, fmt.Errorf("failed to get base REST URL: %w", err) @@ -61,30 +63,55 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv return nil, fmt.Errorf("failed to get Raw URL: %w", err) } + // Determine the base transport for REST and GraphQL clients + baseTransport := http.RoundTripper(http.DefaultTransport) + if authTransport != nil { + baseTransport = authTransport + } + // Construct REST client restUATransport := &transport.UserAgentTransport{ - Transport: http.DefaultTransport, + Transport: baseTransport, Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), } - restClient, err := gogithub.NewClient( - gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), - gogithub.WithAuthToken(cfg.Token), - gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), - ) + var restClient *gogithub.Client + if authTransport != nil { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) + } else { + restClient, err = gogithub.NewClient( + gogithub.WithHTTPClient(&http.Client{Transport: restUATransport}), + gogithub.WithAuthToken(cfg.Token), + gogithub.WithEnterpriseURLs(restURL.String(), uploadURL.String()), + ) + } if err != nil { return nil, fmt.Errorf("failed to create REST client: %w", err) } // Construct GraphQL client // We use NewEnterpriseClient unconditionally since we already parsed the API host - gqlHTTPClient := &http.Client{ - Transport: &transport.BearerAuthTransport{ + var gqlTransport http.RoundTripper + if authTransport != nil { + // Auth transport already sets the Authorization header. + // Wrap with UserAgentTransport for consistency with the REST path. + gqlTransport = &transport.UserAgentTransport{ + Transport: &transport.GraphQLFeaturesTransport{ + Transport: authTransport, + }, + Agent: fmt.Sprintf("github-mcp-server/%s", cfg.Version), + } + } else { + gqlTransport = &transport.BearerAuthTransport{ Transport: &transport.GraphQLFeaturesTransport{ Transport: http.DefaultTransport, }, Token: cfg.Token, - }, + } } + gqlHTTPClient := &http.Client{Transport: gqlTransport} gqlClient := githubv4.NewEnterpriseClient(graphQLURL.String(), gqlHTTPClient) @@ -116,13 +143,13 @@ func createGitHubClients(cfg github.MCPServerConfig, apiHost utils.APIHostResolv }, nil } -func NewStdioMCPServer(ctx context.Context, cfg github.MCPServerConfig) (*mcp.Server, error) { +func NewStdioMCPServer(ctx context.Context, cfg github.MCPServerConfig, authTransport http.RoundTripper) (*mcp.Server, error) { apiHost, err := utils.NewAPIHost(cfg.Host) if err != nil { return nil, fmt.Errorf("failed to parse API host: %w", err) } - clients, err := createGitHubClients(cfg, apiHost) + clients, err := createGitHubClients(cfg, apiHost, authTransport) if err != nil { return nil, fmt.Errorf("failed to create GitHub clients: %w", err) } @@ -238,6 +265,13 @@ type StdioServerConfig struct { // RepoAccessCacheTTL overrides the default TTL for repository access cache entries. RepoAccessCacheTTL *time.Duration + + // GitHub App authentication (alternative to Token) + // When AppID, PrivateKey, and InstallationID are all set, the server + // authenticates as a GitHub App installation instead of using a PAT. + AppID int64 + PrivateKey []byte + InstallationID int64 } // RunStdioServer is not concurrent safe. @@ -264,11 +298,35 @@ func RunStdioServer(cfg StdioServerConfig) error { logger := slog.New(slogHandler) logger.Info("starting server", "version", cfg.Version, "host", cfg.Host, "readOnly", cfg.ReadOnly, "lockdownEnabled", cfg.LockdownMode) + // Set up GitHub App authentication transport if configured + var appAuthTransport http.RoundTripper + if cfg.AppID != 0 && len(cfg.PrivateKey) > 0 && cfg.InstallationID != 0 { + apiHost, err := utils.NewAPIHost(cfg.Host) + if err != nil { + return fmt.Errorf("failed to parse API host for app auth: %w", err) + } + baseURL, err := apiHost.BaseRESTURL(ctx) + if err != nil { + return fmt.Errorf("failed to get base REST URL for app auth: %w", err) + } + tr, err := appauth.NewTransport(http.DefaultTransport, appauth.Config{ + AppID: cfg.AppID, + PrivateKey: cfg.PrivateKey, + InstallationID: cfg.InstallationID, + BaseURL: baseURL.String(), + }) + if err != nil { + return fmt.Errorf("failed to create GitHub App auth transport: %w", err) + } + appAuthTransport = tr + logger.Info("using GitHub App authentication", "appID", cfg.AppID, "installationID", cfg.InstallationID) + } + // Fetch token scopes for scope-based tool filtering (PAT tokens only) // Only classic PATs (ghp_ prefix) return OAuth scopes via X-OAuth-Scopes header. // Fine-grained PATs and other token types don't support this, so we skip filtering. var tokenScopes []string - if strings.HasPrefix(cfg.Token, "ghp_") { + if appAuthTransport == nil && strings.HasPrefix(cfg.Token, "ghp_") { fetchedScopes, err := fetchTokenScopesForHost(ctx, cfg.Token, cfg.Host) if err != nil { logger.Warn("failed to fetch token scopes, continuing without scope filtering", "error", err) @@ -276,7 +334,7 @@ func RunStdioServer(cfg StdioServerConfig) error { tokenScopes = fetchedScopes logger.Info("token scopes fetched for filtering", "scopes", tokenScopes) } - } else { + } else if appAuthTransport == nil { logger.Debug("skipping scope filtering for non-PAT token") } @@ -296,7 +354,7 @@ func RunStdioServer(cfg StdioServerConfig) error { Logger: logger, RepoAccessTTL: cfg.RepoAccessCacheTTL, TokenScopes: tokenScopes, - }) + }, appAuthTransport) if err != nil { return fmt.Errorf("failed to create MCP server: %w", err) } diff --git a/pkg/github/appauth/appauth.go b/pkg/github/appauth/appauth.go new file mode 100644 index 000000000..16842b444 --- /dev/null +++ b/pkg/github/appauth/appauth.go @@ -0,0 +1,214 @@ +package appauth + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "io" + "net/http" + "sync" + "time" +) + +// Config holds the configuration for GitHub App authentication. +type Config struct { + // AppID is the GitHub App ID. + AppID int64 + + // PrivateKey is the PEM-encoded RSA private key for the GitHub App. + PrivateKey []byte + + // InstallationID is the installation ID of the GitHub App. + InstallationID int64 + + // BaseURL is the base URL for the GitHub API (e.g., "https://api.github.com"). + // If empty, defaults to "https://api.github.com". + BaseURL string +} + +// Transport is an http.RoundTripper that authenticates requests using +// a GitHub App installation token. It automatically generates JWTs and +// fetches/refreshes installation tokens as needed. +type Transport struct { + config Config + key *rsa.PrivateKey + base http.RoundTripper + + mu sync.RWMutex + token string + exp time.Time +} + +type installationToken struct { + Token string `json:"token"` + ExpiresAt time.Time `json:"expires_at"` +} + +// NewTransport creates a new Transport that authenticates using a GitHub App +// installation token. The transport automatically handles JWT generation and +// installation token refresh. +// The base transport must not inject its own Authorization header, as this +// transport sets it for both installation token requests and API requests. +func NewTransport(base http.RoundTripper, cfg Config) (*Transport, error) { + key, err := parsePrivateKey(cfg.PrivateKey) + if err != nil { + return nil, fmt.Errorf("failed to parse private key: %w", err) + } + if base == nil { + base = http.DefaultTransport + } + if cfg.BaseURL == "" { + cfg.BaseURL = "https://api.github.com" + } + return &Transport{ + config: cfg, + key: key, + base: base, + }, nil +} + +// RoundTrip implements http.RoundTripper. It adds the installation token +// to the Authorization header, refreshing it if necessary. +func (t *Transport) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := t.installationToken(req.Context()) + if err != nil { + return nil, fmt.Errorf("failed to get installation token: %w", err) + } + req2 := req.Clone(req.Context()) + req2.Header.Set("Authorization", "Bearer "+token) + return t.base.RoundTrip(req2) +} + +// Token returns the current installation token, refreshing if necessary. +func (t *Transport) Token(ctx context.Context) (string, error) { + return t.installationToken(ctx) +} + +func (t *Transport) installationToken(ctx context.Context) (string, error) { + // Fast path: read lock to check cached token + t.mu.RLock() + if t.token != "" && time.Now().Add(5*time.Minute).Before(t.exp) { + token := t.token + t.mu.RUnlock() + return token, nil + } + t.mu.RUnlock() + + // Slow path: write lock to refresh + t.mu.Lock() + defer t.mu.Unlock() + + // Double-check after acquiring write lock + if t.token != "" && time.Now().Add(5*time.Minute).Before(t.exp) { + return t.token, nil + } + + jwtToken, err := t.generateJWT() + if err != nil { + return "", fmt.Errorf("failed to generate JWT: %w", err) + } + + tok, err := t.fetchInstallationToken(ctx, jwtToken) + if err != nil { + return "", err + } + + t.token = tok.Token + t.exp = tok.ExpiresAt + return t.token, nil +} + +// generateJWT creates a signed JWT for GitHub App authentication using RS256. +func (t *Transport) generateJWT() (string, error) { + now := time.Now() + + header := map[string]string{ + "alg": "RS256", + "typ": "JWT", + } + payload := map[string]any{ + "iat": now.Add(-30 * time.Second).Unix(), // allow 30s clock drift + "exp": now.Add(9 * time.Minute).Unix(), // well within GitHub's 10-minute maximum + "iss": fmt.Sprintf("%d", t.config.AppID), + } + + headerJSON, err := json.Marshal(header) + if err != nil { + return "", fmt.Errorf("failed to marshal JWT header: %w", err) + } + payloadJSON, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("failed to marshal JWT payload: %w", err) + } + + headerB64 := base64.RawURLEncoding.EncodeToString(headerJSON) + payloadB64 := base64.RawURLEncoding.EncodeToString(payloadJSON) + + signingInput := headerB64 + "." + payloadB64 + hash := sha256.Sum256([]byte(signingInput)) + sig, err := rsa.SignPKCS1v15(rand.Reader, t.key, crypto.SHA256, hash[:]) + if err != nil { + return "", fmt.Errorf("failed to sign JWT: %w", err) + } + sigB64 := base64.RawURLEncoding.EncodeToString(sig) + + return signingInput + "." + sigB64, nil +} + +func (t *Transport) fetchInstallationToken(ctx context.Context, jwtToken string) (*installationToken, error) { + url := fmt.Sprintf("%s/app/installations/%d/access_tokens", t.config.BaseURL, t.config.InstallationID) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+jwtToken) + req.Header.Set("Accept", "application/vnd.github+json") + + resp, err := t.base.RoundTrip(req) + if err != nil { + return nil, fmt.Errorf("failed to fetch installation token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusCreated { + body, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("failed to create installation token (status %d): %s", resp.StatusCode, body) + } + + var tok installationToken + if err := json.NewDecoder(resp.Body).Decode(&tok); err != nil { + return nil, fmt.Errorf("failed to decode installation token response: %w", err) + } + return &tok, nil +} + +func parsePrivateKey(pemBytes []byte) (*rsa.PrivateKey, error) { + block, _ := pem.Decode(pemBytes) + if block == nil { + return nil, fmt.Errorf("failed to decode PEM block") + } + + switch block.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(block.Bytes) + case "PRIVATE KEY": + key, err := x509.ParsePKCS8PrivateKey(block.Bytes) + if err != nil { + return nil, err + } + rsaKey, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, fmt.Errorf("expected RSA private key, got %T", key) + } + return rsaKey, nil + default: + return nil, fmt.Errorf("unsupported PEM block type: %s", block.Type) + } +} diff --git a/pkg/github/appauth/appauth_test.go b/pkg/github/appauth/appauth_test.go new file mode 100644 index 000000000..150fe8da1 --- /dev/null +++ b/pkg/github/appauth/appauth_test.go @@ -0,0 +1,315 @@ +package appauth + +import ( + "context" + "crypto" + "crypto/rand" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// verifyJWT parses and verifies a JWT token using the given RSA public key. +func verifyJWT(tokenString string, pubKey *rsa.PublicKey) (map[string]any, error) { + parts := strings.SplitN(tokenString, ".", 3) + if len(parts) != 3 { + return nil, fmt.Errorf("invalid JWT: expected 3 parts, got %d", len(parts)) + } + + signingInput := parts[0] + "." + parts[1] + sig, err := base64.RawURLEncoding.DecodeString(parts[2]) + if err != nil { + return nil, fmt.Errorf("failed to decode signature: %w", err) + } + + hash := sha256.Sum256([]byte(signingInput)) + if err := rsa.VerifyPKCS1v15(pubKey, crypto.SHA256, hash[:], sig); err != nil { + return nil, fmt.Errorf("invalid signature: %w", err) + } + + payloadJSON, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + return nil, fmt.Errorf("failed to decode payload: %w", err) + } + var claims map[string]any + if err := json.Unmarshal(payloadJSON, &claims); err != nil { + return nil, fmt.Errorf("failed to unmarshal claims: %w", err) + } + return claims, nil +} + +func generateTestKey(t *testing.T) (*rsa.PrivateKey, []byte) { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }) + return key, pemBytes +} + +func TestParsePrivateKey_PKCS1(t *testing.T) { + _, pemBytes := generateTestKey(t) + key, err := parsePrivateKey(pemBytes) + require.NoError(t, err) + assert.NotNil(t, key) +} + +func TestParsePrivateKey_PKCS8(t *testing.T) { + rsaKey, _ := generateTestKey(t) + pkcs8Bytes, err := x509.MarshalPKCS8PrivateKey(rsaKey) + require.NoError(t, err) + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "PRIVATE KEY", + Bytes: pkcs8Bytes, + }) + + key, err := parsePrivateKey(pemBytes) + require.NoError(t, err) + assert.NotNil(t, key) +} + +func TestParsePrivateKey_InvalidPEM(t *testing.T) { + _, err := parsePrivateKey([]byte("not a pem")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to decode PEM block") +} + +func TestParsePrivateKey_UnsupportedType(t *testing.T) { + pemBytes := pem.EncodeToMemory(&pem.Block{ + Type: "EC PRIVATE KEY", + Bytes: []byte("fake"), + }) + _, err := parsePrivateKey(pemBytes) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unsupported PEM block type") +} + +func TestNewTransport_InvalidKey(t *testing.T) { + _, err := NewTransport(nil, Config{ + AppID: 123, + PrivateKey: []byte("invalid"), + InstallationID: 456, + }) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to parse private key") +} + +func TestNewTransport_DefaultBaseURL(t *testing.T) { + _, pemBytes := generateTestKey(t) + tr, err := NewTransport(nil, Config{ + AppID: 123, + PrivateKey: pemBytes, + InstallationID: 456, + }) + require.NoError(t, err) + assert.Equal(t, "https://api.github.com", tr.config.BaseURL) +} + +func TestNewTransport_CustomBaseURL(t *testing.T) { + _, pemBytes := generateTestKey(t) + tr, err := NewTransport(nil, Config{ + AppID: 123, + PrivateKey: pemBytes, + InstallationID: 456, + BaseURL: "https://github.example.com/api/v3", + }) + require.NoError(t, err) + assert.Equal(t, "https://github.example.com/api/v3", tr.config.BaseURL) +} + +func TestTransport_GenerateJWT(t *testing.T) { + key, pemBytes := generateTestKey(t) + tr, err := NewTransport(nil, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + }) + require.NoError(t, err) + + jwtToken, err := tr.generateJWT() + require.NoError(t, err) + + claims, err := verifyJWT(jwtToken, &key.PublicKey) + require.NoError(t, err) + + assert.Equal(t, "12345", claims["iss"]) + + iat := int64(claims["iat"].(float64)) + exp := int64(claims["exp"].(float64)) + assert.InDelta(t, time.Now().Unix(), iat, 60) + assert.InDelta(t, time.Now().Add(9*time.Minute).Unix(), exp, 60) +} + +func TestTransport_FetchInstallationToken(t *testing.T) { + key, pemBytes := generateTestKey(t) + installationID := int64(67890) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expectedPath := fmt.Sprintf("/app/installations/%d/access_tokens", installationID) + assert.Equal(t, expectedPath, r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) + + authHeader := r.Header.Get("Authorization") + assert.True(t, len(authHeader) > 7) + jwtToken := authHeader[7:] // strip "Bearer " + + _, err := verifyJWT(jwtToken, &key.PublicKey) + assert.NoError(t, err) + + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(installationToken{ + Token: "ghs_test_token_123", + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + })) + defer server.Close() + + tr, err := NewTransport(server.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: installationID, + BaseURL: server.URL, + }) + require.NoError(t, err) + + token, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_test_token_123", token) +} + +func TestTransport_TokenCaching(t *testing.T) { + _, pemBytes := generateTestKey(t) + var callCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + callCount.Add(1) + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(installationToken{ + Token: "ghs_cached_token", + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + })) + defer server.Close() + + tr, err := NewTransport(server.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + BaseURL: server.URL, + }) + require.NoError(t, err) + + token1, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_cached_token", token1) + + token2, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_cached_token", token2) + + assert.Equal(t, int32(1), callCount.Load()) +} + +func TestTransport_TokenRefresh(t *testing.T) { + _, pemBytes := generateTestKey(t) + var callCount atomic.Int32 + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + count := callCount.Add(1) + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(installationToken{ + Token: fmt.Sprintf("ghs_token_%d", count), + ExpiresAt: time.Now().Add(1 * time.Minute), // expires soon, within 5min refresh window + }) + })) + defer server.Close() + + tr, err := NewTransport(server.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + BaseURL: server.URL, + }) + require.NoError(t, err) + + token1, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_token_1", token1) + + // Token expires within 5 minutes, so next call should refresh + token2, err := tr.Token(context.Background()) + require.NoError(t, err) + assert.Equal(t, "ghs_token_2", token2) + assert.Equal(t, int32(2), callCount.Load()) +} + +func TestTransport_RoundTrip(t *testing.T) { + _, pemBytes := generateTestKey(t) + + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/app/installations/67890/access_tokens" { + w.WriteHeader(http.StatusCreated) + _ = json.NewEncoder(w).Encode(installationToken{ + Token: "ghs_roundtrip_token", + ExpiresAt: time.Now().Add(1 * time.Hour), + }) + return + } + assert.Equal(t, "Bearer ghs_roundtrip_token", r.Header.Get("Authorization")) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"ok": true}`)) + })) + defer tokenServer.Close() + + tr, err := NewTransport(tokenServer.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + BaseURL: tokenServer.URL, + }) + require.NoError(t, err) + + client := &http.Client{Transport: tr} + resp, err := client.Get(tokenServer.URL + "/repos/owner/repo") + require.NoError(t, err) + defer resp.Body.Close() + assert.Equal(t, http.StatusOK, resp.StatusCode) +} + +func TestTransport_FetchError(t *testing.T) { + _, pemBytes := generateTestKey(t) + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message": "Bad credentials"}`)) + })) + defer server.Close() + + tr, err := NewTransport(server.Client().Transport, Config{ + AppID: 12345, + PrivateKey: pemBytes, + InstallationID: 67890, + BaseURL: server.URL, + }) + require.NoError(t, err) + + _, err = tr.Token(context.Background()) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create installation token") + assert.Contains(t, err.Error(), "Bad credentials") +}