diff --git a/README.md b/README.md index dc063f22c..4f4074ea2 100644 --- a/README.md +++ b/README.md @@ -1092,6 +1092,12 @@ The following sets of tools are available: - `repo`: Repository name (string, required) - `title`: PR title (string, required) +- **get_pull_request_metadata_batch** - Get batch pull request metadata + - **Required OAuth Scopes**: `repo` + - `owner`: Repository owner (string, required) + - `pullNumbers`: Explicit pull request numbers to hydrate. Accepts up to 25 items. (integer[], required) + - `repo`: Repository name (string, required) + - **list_pull_requests** - List pull requests - **Required OAuth Scopes**: `repo` - `base`: Filter by base branch (string, optional) diff --git a/pkg/github/__toolsnaps__/get_pull_request_metadata_batch.snap b/pkg/github/__toolsnaps__/get_pull_request_metadata_batch.snap new file mode 100644 index 000000000..965593b0d --- /dev/null +++ b/pkg/github/__toolsnaps__/get_pull_request_metadata_batch.snap @@ -0,0 +1,36 @@ +{ + "annotations": { + "readOnlyHint": true, + "title": "Get batch pull request metadata" + }, + "description": "Get metadata for an explicit list of pull requests in a GitHub repository. Returns partial success with per-PR errors when some requested pull requests cannot be hydrated.", + "inputSchema": { + "properties": { + "owner": { + "description": "Repository owner", + "type": "string" + }, + "pullNumbers": { + "description": "Explicit pull request numbers to hydrate. Accepts up to 25 items.", + "items": { + "minimum": 1, + "type": "integer" + }, + "maxItems": 25, + "minItems": 1, + "type": "array" + }, + "repo": { + "description": "Repository name", + "type": "string" + } + }, + "required": [ + "owner", + "repo", + "pullNumbers" + ], + "type": "object" + }, + "name": "get_pull_request_metadata_batch" +} \ No newline at end of file diff --git a/pkg/github/pullrequests.go b/pkg/github/pullrequests.go index ae7d04331..184bdfbc6 100644 --- a/pkg/github/pullrequests.go +++ b/pkg/github/pullrequests.go @@ -161,28 +161,40 @@ Possible options: } func GetPullRequest(ctx context.Context, client *github.Client, deps ToolDependencies, owner, repo string, pullNumber int) (*mcp.CallToolResult, error) { + minimalPR, toolErr, err := getMinimalPullRequest(ctx, client, deps, owner, repo, pullNumber) + if toolErr != nil || err != nil { + return toolErr, err + } + + return MarshalledTextResult(minimalPR), nil +} + +func getMinimalPullRequest(ctx context.Context, client *github.Client, deps ToolDependencies, owner, repo string, pullNumber int) (MinimalPullRequest, *mcp.CallToolResult, error) { cache, err := deps.GetRepoAccessCache(ctx) if err != nil { - return nil, fmt.Errorf("failed to get repo access cache: %w", err) + return MinimalPullRequest{}, nil, fmt.Errorf("failed to get repo access cache: %w", err) } ff := deps.GetFlags(ctx) pr, resp, err := client.PullRequests.Get(ctx, owner, repo, pullNumber) if err != nil { - return ghErrors.NewGitHubAPIErrorResponse(ctx, + return MinimalPullRequest{}, ghErrors.NewGitHubAPIErrorResponse(ctx, "failed to get pull request", resp, err, ), nil } + if resp == nil { + return MinimalPullRequest{}, nil, fmt.Errorf("missing GitHub response") + } defer func() { _ = resp.Body.Close() }() if resp.StatusCode != http.StatusOK { body, err := io.ReadAll(resp.Body) if err != nil { - return nil, fmt.Errorf("failed to read response body: %w", err) + return MinimalPullRequest{}, nil, fmt.Errorf("failed to read response body: %w", err) } - return ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get pull request", resp, body), nil + return MinimalPullRequest{}, ghErrors.NewGitHubAPIStatusErrorResponse(ctx, "failed to get pull request", resp, body), nil } // sanitize title/body on response @@ -197,24 +209,22 @@ func GetPullRequest(ctx context.Context, client *github.Client, deps ToolDepende if ff.LockdownMode { if cache == nil { - return nil, fmt.Errorf("lockdown cache is not configured") + return MinimalPullRequest{}, nil, fmt.Errorf("lockdown cache is not configured") } login := pr.GetUser().GetLogin() if login != "" { isSafeContent, err := cache.IsSafeContent(ctx, login, owner, repo) if err != nil { - return nil, fmt.Errorf("failed to check content removal: %w", err) + return MinimalPullRequest{}, nil, fmt.Errorf("failed to check content removal: %w", err) } if !isSafeContent { - return utils.NewToolResultError("access to pull request is restricted by lockdown mode"), nil + return MinimalPullRequest{}, utils.NewToolResultError("access to pull request is restricted by lockdown mode"), nil } } } - minimalPR := convertToMinimalPullRequest(pr) - - return MarshalledTextResult(minimalPR), nil + return convertToMinimalPullRequest(pr), nil, nil } func GetPullRequestDiff(ctx context.Context, client *github.Client, owner, repo string, pullNumber int) (*mcp.CallToolResult, error) { diff --git a/pkg/github/pullrequests_batch_metadata.go b/pkg/github/pullrequests_batch_metadata.go new file mode 100644 index 000000000..4841b0e1a --- /dev/null +++ b/pkg/github/pullrequests_batch_metadata.go @@ -0,0 +1,172 @@ +package github + +import ( + "context" + "fmt" + + "github.com/github/github-mcp-server/pkg/ifc" + "github.com/github/github-mcp-server/pkg/inventory" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/github/github-mcp-server/pkg/utils" + "github.com/google/go-github/v87/github" + "github.com/google/jsonschema-go/jsonschema" + "github.com/modelcontextprotocol/go-sdk/mcp" + + "github.com/github/github-mcp-server/pkg/scopes" +) + +const maxPullRequestMetadataBatchSize = 25 + +type batchPullRequestMetadataError struct { + PullNumber int `json:"pull_number"` + Message string `json:"message"` +} + +type batchPullRequestMetadataResponse struct { + PullRequests []MinimalPullRequest `json:"pull_requests"` + Errors []batchPullRequestMetadataError `json:"errors,omitempty"` +} + +func GetPullRequestMetadataBatch(t translations.TranslationHelperFunc) inventory.ServerTool { + schema := &jsonschema.Schema{ + Type: "object", + Properties: map[string]*jsonschema.Schema{ + "owner": { + Type: "string", + Description: "Repository owner", + }, + "repo": { + Type: "string", + Description: "Repository name", + }, + "pullNumbers": { + Type: "array", + Description: fmt.Sprintf("Explicit pull request numbers to hydrate. Accepts up to %d items.", maxPullRequestMetadataBatchSize), + MinItems: jsonschema.Ptr(1), + MaxItems: jsonschema.Ptr(maxPullRequestMetadataBatchSize), + Items: &jsonschema.Schema{ + Type: "integer", + Minimum: jsonschema.Ptr(1.0), + }, + }, + }, + Required: []string{"owner", "repo", "pullNumbers"}, + } + + return NewTool( + ToolsetMetadataPullRequests, + mcp.Tool{ + Name: "get_pull_request_metadata_batch", + Description: t("TOOL_GET_PULL_REQUEST_METADATA_BATCH_DESCRIPTION", "Get metadata for an explicit list of pull requests in a GitHub repository. Returns partial success with per-PR errors when some requested pull requests cannot be hydrated."), + Annotations: &mcp.ToolAnnotations{ + Title: t("TOOL_GET_PULL_REQUEST_METADATA_BATCH_USER_TITLE", "Get batch pull request metadata"), + ReadOnlyHint: true, + }, + InputSchema: schema, + }, + []scopes.Scope{scopes.Repo}, + func(ctx context.Context, deps ToolDependencies, _ *mcp.CallToolRequest, args map[string]any) (*mcp.CallToolResult, any, error) { + owner, err := RequiredParam[string](args, "owner") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + repo, err := RequiredParam[string](args, "repo") + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + pullNumbers, err := requiredPullNumberBatchParam(args, "pullNumbers", maxPullRequestMetadataBatchSize) + if err != nil { + return utils.NewToolResultError(err.Error()), nil, nil + } + + client, err := deps.GetClient(ctx) + if err != nil { + return utils.NewToolResultErrorFromErr("failed to get GitHub client", err), nil, nil + } + + attachIFC := func(r *mcp.CallToolResult) *mcp.CallToolResult { + return attachRepoVisibilityIFCLabel(ctx, deps, client, owner, repo, r, ifc.LabelListIssues) + } + + result := batchPullRequestMetadataResponse{ + PullRequests: make([]MinimalPullRequest, 0, len(pullNumbers)), + Errors: make([]batchPullRequestMetadataError, 0), + } + + for _, pullNumber := range pullNumbers { + pr, err := fetchMinimalPullRequest(ctx, client, deps, owner, repo, pullNumber) + if err != nil { + result.Errors = append(result.Errors, batchPullRequestMetadataError{ + PullNumber: pullNumber, + Message: err.Error(), + }) + continue + } + + result.PullRequests = append(result.PullRequests, pr) + } + + return attachIFC(MarshalledTextResult(result)), nil, nil + }, + ) +} + +func requiredPullNumberBatchParam(args map[string]any, key string, maxItems int) ([]int, error) { + raw, ok := args[key] + if !ok { + return nil, fmt.Errorf("missing required parameter: %s", key) + } + + values, ok := raw.([]any) + if !ok { + return nil, fmt.Errorf("parameter %s could not be coerced to []int, is %T", key, raw) + } + if len(values) == 0 { + return nil, fmt.Errorf("parameter %s must contain at least one pull request number", key) + } + if len(values) > maxItems { + return nil, fmt.Errorf("parameter %s exceeds the maximum batch size of %d", key, maxItems) + } + + pullNumbers := make([]int, 0, len(values)) + seen := make(map[int]struct{}, len(values)) + for i, value := range values { + number, ok := value.(float64) + if !ok { + return nil, fmt.Errorf("parameter %s element %d is not a number, is %T", key, i, value) + } + if number < 1 || number != float64(int(number)) { + return nil, fmt.Errorf("parameter %s element %d must be a positive integer", key, i) + } + intNumber := int(number) + if _, ok := seen[intNumber]; ok { + continue + } + seen[intNumber] = struct{}{} + pullNumbers = append(pullNumbers, intNumber) + } + + return pullNumbers, nil +} + +func fetchMinimalPullRequest(ctx context.Context, client *github.Client, deps ToolDependencies, owner, repo string, pullNumber int) (MinimalPullRequest, error) { + minimalPR, toolErr, err := getMinimalPullRequest(ctx, client, deps, owner, repo, pullNumber) + if toolErr != nil { + return MinimalPullRequest{}, fmt.Errorf("%s", getErrorResultText(toolErr)) + } + if err != nil { + return MinimalPullRequest{}, err + } + return minimalPR, nil +} + +func getErrorResultText(result *mcp.CallToolResult) string { + if result == nil || len(result.Content) == 0 { + return "failed to get pull request" + } + text, ok := result.Content[0].(*mcp.TextContent) + if !ok { + return "failed to get pull request" + } + return text.Text +} diff --git a/pkg/github/pullrequests_batch_metadata_test.go b/pkg/github/pullrequests_batch_metadata_test.go new file mode 100644 index 000000000..81cf95fb5 --- /dev/null +++ b/pkg/github/pullrequests_batch_metadata_test.go @@ -0,0 +1,265 @@ +package github + +import ( + "context" + "encoding/json" + "net/http" + "testing" + "time" + + "github.com/github/github-mcp-server/internal/githubv4mock" + "github.com/github/github-mcp-server/internal/toolsnaps" + "github.com/github/github-mcp-server/pkg/translations" + "github.com/google/go-github/v87/github" + "github.com/google/jsonschema-go/jsonschema" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_GetPullRequestMetadataBatch(t *testing.T) { + serverTool := GetPullRequestMetadataBatch(translations.NullTranslationHelper) + tool := serverTool.Tool + require.NoError(t, toolsnaps.Test(tool.Name, tool)) + + assert.Equal(t, "get_pull_request_metadata_batch", tool.Name) + schema := tool.InputSchema.(*jsonschema.Schema) + assert.Contains(t, schema.Properties, "owner") + assert.Contains(t, schema.Properties, "repo") + assert.Contains(t, schema.Properties, "pullNumbers") + assert.ElementsMatch(t, schema.Required, []string{"owner", "repo", "pullNumbers"}) + + pr42 := &github.PullRequest{ + Number: github.Ptr(42), + Title: github.Ptr("Release prep"), + State: github.Ptr("closed"), + Merged: github.Ptr(true), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/42"), + MergedAt: &github.Timestamp{Time: time.Date(2026, time.June, 12, 12, 0, 0, 0, time.UTC)}, + User: &github.User{ + Login: github.Ptr("octocat"), + }, + Labels: []*github.Label{{Name: github.Ptr("release")}}, + } + pr18 := &github.PullRequest{ + Number: github.Ptr(18), + Title: github.Ptr("Changelog fix"), + State: github.Ptr("closed"), + Merged: github.Ptr(true), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/18"), + MergedAt: &github.Timestamp{Time: time.Date(2026, time.June, 10, 12, 0, 0, 0, time.UTC)}, + User: &github.User{ + Login: github.Ptr("hubot"), + }, + Labels: []*github.Label{{Name: github.Ptr("docs")}}, + } + + tests := []struct { + name string + mockedClient *http.Client + requestArgs map[string]any + expectError bool + expectedErrMsg string + lockdownEnabled bool + validateResult func(t *testing.T, textContent string) + }{ + { + name: "successful metadata batch preserves input order", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/repos/owner/repo/pulls/42": + mockResponse(t, http.StatusOK, pr42).ServeHTTP(w, r) + case "/repos/owner/repo/pulls/18": + mockResponse(t, http.StatusOK, pr18).ServeHTTP(w, r) + default: + http.NotFound(w, r) + } + }, + }), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{float64(42), float64(18)}, + }, + validateResult: func(t *testing.T, textContent string) { + var result batchPullRequestMetadataResponse + require.NoError(t, json.Unmarshal([]byte(textContent), &result)) + assert.Len(t, result.PullRequests, 2) + assert.Empty(t, result.Errors) + assert.Equal(t, 42, result.PullRequests[0].Number) + assert.Equal(t, "Release prep", result.PullRequests[0].Title) + assert.Equal(t, 18, result.PullRequests[1].Number) + assert.Equal(t, "Changelog fix", result.PullRequests[1].Title) + }, + }, + { + name: "partial failures are returned without failing the batch", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/repos/owner/repo/pulls/42": + mockResponse(t, http.StatusOK, pr42).ServeHTTP(w, r) + case "/repos/owner/repo/pulls/999": + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message":"Not Found"}`)) + default: + http.NotFound(w, r) + } + }, + }), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{float64(42), float64(999)}, + }, + validateResult: func(t *testing.T, textContent string) { + var result batchPullRequestMetadataResponse + require.NoError(t, json.Unmarshal([]byte(textContent), &result)) + assert.Len(t, result.PullRequests, 1) + assert.Equal(t, 42, result.PullRequests[0].Number) + assert.Len(t, result.Errors, 1) + assert.Equal(t, 999, result.Errors[0].PullNumber) + assert.Contains(t, result.Errors[0].Message, "failed to get pull request") + }, + }, + { + name: "duplicate pull numbers are deduplicated before hydration", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: func(w http.ResponseWriter, r *http.Request) { + mockResponse(t, http.StatusOK, pr42).ServeHTTP(w, r) + }, + }), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{float64(42), float64(42), float64(42)}, + }, + validateResult: func(t *testing.T, textContent string) { + var result batchPullRequestMetadataResponse + require.NoError(t, json.Unmarshal([]byte(textContent), &result)) + assert.Len(t, result.PullRequests, 1) + assert.Empty(t, result.Errors) + assert.Equal(t, 42, result.PullRequests[0].Number) + }, + }, + { + name: "lockdown enabled still allows collaborator-authored pull requests", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: func(w http.ResponseWriter, r *http.Request) { + mockResponse(t, http.StatusOK, &github.PullRequest{ + Number: github.Ptr(7), + Title: github.Ptr("Trusted PR"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/7"), + User: &github.User{Login: github.Ptr("maintainer")}, + }).ServeHTTP(w, r) + }, + }), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{float64(7)}, + }, + lockdownEnabled: true, + validateResult: func(t *testing.T, textContent string) { + var result batchPullRequestMetadataResponse + require.NoError(t, json.Unmarshal([]byte(textContent), &result)) + assert.Len(t, result.PullRequests, 1) + assert.Empty(t, result.Errors) + assert.Equal(t, 7, result.PullRequests[0].Number) + }, + }, + { + name: "lockdown enabled reports restricted pull requests as per-item errors", + mockedClient: MockHTTPClientWithHandlers(map[string]http.HandlerFunc{ + GetReposPullsByOwnerByRepoByPullNumber: func(w http.ResponseWriter, r *http.Request) { + mockResponse(t, http.StatusOK, &github.PullRequest{ + Number: github.Ptr(8), + Title: github.Ptr("Untrusted PR"), + State: github.Ptr("open"), + HTMLURL: github.Ptr("https://github.com/owner/repo/pull/8"), + User: &github.User{Login: github.Ptr("external-user")}, + }).ServeHTTP(w, r) + }, + }), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{float64(8)}, + }, + lockdownEnabled: true, + validateResult: func(t *testing.T, textContent string) { + var result batchPullRequestMetadataResponse + require.NoError(t, json.Unmarshal([]byte(textContent), &result)) + assert.Empty(t, result.PullRequests) + assert.Len(t, result.Errors, 1) + assert.Equal(t, 8, result.Errors[0].PullNumber) + assert.Contains(t, result.Errors[0].Message, "restricted by lockdown mode") + }, + }, + { + name: "empty pullNumbers fails validation", + mockedClient: githubv4mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": []any{}, + }, + expectError: true, + expectedErrMsg: "must contain at least one pull request number", + }, + { + name: "oversized pullNumbers fails validation", + mockedClient: githubv4mock.NewMockedHTTPClient(), + requestArgs: map[string]any{ + "owner": "owner", + "repo": "repo", + "pullNumbers": oversizedPullRequestArgs(maxPullRequestMetadataBatchSize + 1), + }, + expectError: true, + expectedErrMsg: "exceeds the maximum batch size", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + client := mustNewGHClient(t, tc.mockedClient) + var repoAccessClient *github.Client + if tc.lockdownEnabled { + repoAccessClient = mockRESTPermissionServer(t, "read", map[string]string{ + "maintainer": "write", + "external-user": "read", + }) + } + deps := BaseDeps{ + Client: client, + RepoAccessCache: stubRepoAccessCache(repoAccessClient, 5*time.Minute), + Flags: stubFeatureFlags(map[string]bool{"lockdown-mode": tc.lockdownEnabled}), + } + handler := serverTool.Handler(deps) + + request := createMCPRequest(tc.requestArgs) + result, err := handler(ContextWithDeps(context.Background(), deps), &request) + require.NoError(t, err) + + if tc.expectError { + require.True(t, result.IsError) + text := getErrorResult(t, result) + assert.Contains(t, text.Text, tc.expectedErrMsg) + return + } + + require.False(t, result.IsError) + text := getTextResult(t, result) + tc.validateResult(t, text.Text) + }) + } +} + +func oversizedPullRequestArgs(count int) []any { + values := make([]any, 0, count) + for i := range count { + values = append(values, float64(i+1)) + } + return values +} diff --git a/pkg/github/tools.go b/pkg/github/tools.go index 906fa777d..6a3ea6cbf 100644 --- a/pkg/github/tools.go +++ b/pkg/github/tools.go @@ -221,6 +221,7 @@ func AllTools(t translations.TranslationHelperFunc) []inventory.ServerTool { // Pull request tools PullRequestRead(t), + GetPullRequestMetadataBatch(t), ListPullRequests(t), SearchPullRequests(t), MergePullRequest(t),