Skip to content

Commit 4315c09

Browse files
committed
Ignore non-pre-authenticated hosts when parsing git remotes
This ensures that while having git remotes to point to either `github.com` or authenticated GHE instances, adding another git remote pointing to an unrelated host won't change the remote resolution in any way, even if the unrelated remote is called `upstream` or `github` (and thus normally took precedence).
1 parent c095a4b commit 4315c09

File tree

4 files changed

+137
-44
lines changed

4 files changed

+137
-44
lines changed

git/remote.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,15 @@ var remoteRE = regexp.MustCompile(`(.+)\s+(.+)\s+\((push|fetch)\)`)
1414
// RemoteSet is a slice of git remotes
1515
type RemoteSet []*Remote
1616

17+
func NewRemote(name string, u string) *Remote {
18+
pu, _ := url.Parse(u)
19+
return &Remote{
20+
Name: name,
21+
FetchURL: pu,
22+
PushURL: pu,
23+
}
24+
}
25+
1726
// Remote is a parsed git remote
1827
type Remote struct {
1928
Name string

pkg/cmd/factory/default.go

Lines changed: 6 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ import (
66
"net/http"
77
"os"
88

9-
"github.com/cli/cli/context"
109
"github.com/cli/cli/git"
1110
"github.com/cli/cli/internal/config"
1211
"github.com/cli/cli/internal/ghrepo"
@@ -31,7 +30,11 @@ func New(appVersion string) *cmdutil.Factory {
3130
return cachedConfig, configError
3231
}
3332

34-
remotesFunc := remotesResolver()
33+
rr := &remoteResolver{
34+
readRemotes: git.Remotes,
35+
getConfig: configFunc,
36+
}
37+
remotesFunc := rr.Resolver()
3538

3639
return &cmdutil.Factory{
3740
IOStreams: io,
@@ -51,7 +54,7 @@ func New(appVersion string) *cmdutil.Factory {
5154
if err != nil {
5255
return nil, err
5356
}
54-
return remotes.FindByName("upstream", "github", "origin", "*")
57+
return remotes[0], nil
5558
},
5659
Branch: func() (string, error) {
5760
currentBranch, err := git.CurrentBranch()
@@ -62,44 +65,3 @@ func New(appVersion string) *cmdutil.Factory {
6265
},
6366
}
6467
}
65-
66-
// TODO: pass in a Config instance to parse remotes based on pre-authenticated hostnames
67-
func remotesResolver() func() (context.Remotes, error) {
68-
var cachedRemotes context.Remotes
69-
var remotesError error
70-
71-
return func() (context.Remotes, error) {
72-
if cachedRemotes != nil || remotesError != nil {
73-
return cachedRemotes, remotesError
74-
}
75-
76-
gitRemotes, err := git.Remotes()
77-
if err != nil {
78-
remotesError = err
79-
return nil, err
80-
}
81-
if len(gitRemotes) == 0 {
82-
remotesError = errors.New("no git remotes found")
83-
return nil, remotesError
84-
}
85-
86-
sshTranslate := git.ParseSSHConfig().Translator()
87-
resolvedRemotes := context.TranslateRemotes(gitRemotes, sshTranslate)
88-
89-
// determine hostname by looking at the primary remotes
90-
var hostname string
91-
if mainRemote, err := resolvedRemotes.FindByName("upstream", "github", "origin", "*"); err == nil {
92-
hostname = mainRemote.RepoHost()
93-
}
94-
95-
// filter the rest of the remotes to just that hostname
96-
cachedRemotes = context.Remotes{}
97-
for _, r := range resolvedRemotes {
98-
if r.RepoHost() != hostname {
99-
continue
100-
}
101-
cachedRemotes = append(cachedRemotes, r)
102-
}
103-
return cachedRemotes, nil
104-
}
105-
}

pkg/cmd/factory/remote_resolver.go

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
package factory
2+
3+
import (
4+
"errors"
5+
"net/url"
6+
"sort"
7+
8+
"github.com/cli/cli/context"
9+
"github.com/cli/cli/git"
10+
"github.com/cli/cli/internal/config"
11+
"github.com/cli/cli/internal/ghinstance"
12+
)
13+
14+
type remoteResolver struct {
15+
readRemotes func() (git.RemoteSet, error)
16+
getConfig func() (config.Config, error)
17+
urlTranslator func(*url.URL) *url.URL
18+
}
19+
20+
func (rr *remoteResolver) Resolver() func() (context.Remotes, error) {
21+
var cachedRemotes context.Remotes
22+
var remotesError error
23+
24+
return func() (context.Remotes, error) {
25+
if cachedRemotes != nil || remotesError != nil {
26+
return cachedRemotes, remotesError
27+
}
28+
29+
gitRemotes, err := rr.readRemotes()
30+
if err != nil {
31+
remotesError = err
32+
return nil, err
33+
}
34+
if len(gitRemotes) == 0 {
35+
remotesError = errors.New("no git remotes found")
36+
return nil, remotesError
37+
}
38+
39+
sshTranslate := rr.urlTranslator
40+
if sshTranslate == nil {
41+
sshTranslate = git.ParseSSHConfig().Translator()
42+
}
43+
resolvedRemotes := context.TranslateRemotes(gitRemotes, sshTranslate)
44+
45+
cfg, err := rr.getConfig()
46+
if err != nil {
47+
return nil, err
48+
}
49+
50+
knownHosts := map[string]bool{}
51+
knownHosts[ghinstance.Default()] = true
52+
if authenticatedHosts, err := cfg.Hosts(); err == nil {
53+
for _, h := range authenticatedHosts {
54+
knownHosts[h] = true
55+
}
56+
}
57+
58+
// filter remotes to only those sharing a single, known hostname
59+
var hostname string
60+
cachedRemotes = context.Remotes{}
61+
sort.Sort(resolvedRemotes)
62+
for _, r := range resolvedRemotes {
63+
if hostname == "" {
64+
if !knownHosts[r.RepoHost()] {
65+
continue
66+
}
67+
hostname = r.RepoHost()
68+
} else if r.RepoHost() != hostname {
69+
continue
70+
}
71+
cachedRemotes = append(cachedRemotes, r)
72+
}
73+
74+
if len(cachedRemotes) == 0 {
75+
remotesError = errors.New("none of the git remotes point to a known GitHub host")
76+
return nil, remotesError
77+
}
78+
return cachedRemotes, nil
79+
}
80+
}
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
package factory
2+
3+
import (
4+
"net/url"
5+
"testing"
6+
7+
"github.com/MakeNowJust/heredoc"
8+
"github.com/cli/cli/git"
9+
"github.com/cli/cli/internal/config"
10+
"github.com/stretchr/testify/assert"
11+
"github.com/stretchr/testify/require"
12+
)
13+
14+
func Test_remoteResolver(t *testing.T) {
15+
rr := &remoteResolver{
16+
readRemotes: func() (git.RemoteSet, error) {
17+
return git.RemoteSet{
18+
git.NewRemote("fork", "https://example.org/ghe-owner/ghe-fork.git"),
19+
git.NewRemote("origin", "https://github.com/owner/repo.git"),
20+
git.NewRemote("upstream", "https://example.org/ghe-owner/ghe-repo.git"),
21+
}, nil
22+
},
23+
getConfig: func() (config.Config, error) {
24+
return config.NewFromString(heredoc.Doc(`
25+
hosts:
26+
example.org:
27+
oauth_token: GHETOKEN
28+
`)), nil
29+
},
30+
urlTranslator: func(u *url.URL) *url.URL {
31+
return u
32+
},
33+
}
34+
35+
resolver := rr.Resolver()
36+
remotes, err := resolver()
37+
require.NoError(t, err)
38+
require.Equal(t, 2, len(remotes))
39+
40+
assert.Equal(t, "upstream", remotes[0].Name)
41+
assert.Equal(t, "fork", remotes[1].Name)
42+
}

0 commit comments

Comments
 (0)