Skip to content

Commit b853cc6

Browse files
committed
Support multiple SSH keys for the same host
1 parent 5160e9f commit b853cc6

File tree

3 files changed

+44
-65
lines changed

3 files changed

+44
-65
lines changed

pkg/credentials/gitcreds/creds.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func flags(fs *flag.FlagSet) {
4040
basicConfig = basicGitConfig{entries: make(map[string]basicEntry)}
4141
fs.Var(&basicConfig, basicAuthFlag, "List of secret=url pairs.")
4242

43-
sshConfig = sshGitConfig{entries: make(map[string]sshEntry)}
43+
sshConfig = sshGitConfig{entries: make(map[string][]sshEntry)}
4444
fs.Var(&sshConfig, sshFlag, "List of secret=url pairs.")
4545
}
4646

pkg/credentials/gitcreds/creds_test.go

Lines changed: 16 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -220,11 +220,11 @@ func TestSSHFlagHandling(t *testing.T) {
220220

221221
expectedSSHConfig := fmt.Sprintf(`Host github.com
222222
HostName github.com
223-
IdentityFile %s
224223
Port 22
225-
`, filepath.Join(os.Getenv("HOME"), ".ssh", "id_foo"))
226-
if string(b) != expectedSSHConfig {
227-
t.Errorf("got: %v, wanted: %v", string(b), expectedSSHConfig)
224+
IdentityFile %s/.ssh/id_foo
225+
`, credentials.VolumePath)
226+
if d := cmp.Diff(expectedSSHConfig, string(b)); d != "" {
227+
t.Errorf("ssh_config diff: %s", d)
228228
}
229229

230230
b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "known_hosts"))
@@ -283,8 +283,10 @@ func TestSSHFlagHandlingThrice(t *testing.T) {
283283
fs := flag.NewFlagSet("test", flag.ContinueOnError)
284284
flags(fs)
285285
err := fs.Parse([]string{
286+
// Two secrets target github.com, and both will end up in the
287+
// ssh config.
286288
"-ssh-git=foo=github.com",
287-
"-ssh-git=bar=gitlab.com",
289+
"-ssh-git=bar=github.com",
288290
"-ssh-git=baz=gitlab.example.com:2222",
289291
})
290292
if err != nil {
@@ -303,21 +305,16 @@ func TestSSHFlagHandlingThrice(t *testing.T) {
303305

304306
expectedSSHConfig := fmt.Sprintf(`Host github.com
305307
HostName github.com
306-
IdentityFile %s
307-
Port 22
308-
Host gitlab.com
309-
HostName gitlab.com
310-
IdentityFile %s
311308
Port 22
309+
IdentityFile %s/.ssh/id_foo
310+
IdentityFile %s/.ssh/id_bar
312311
Host gitlab.example.com
313312
HostName gitlab.example.com
314-
IdentityFile %s
315313
Port 2222
316-
`, filepath.Join(os.Getenv("HOME"), ".ssh", "id_foo"),
317-
filepath.Join(os.Getenv("HOME"), ".ssh", "id_bar"),
318-
filepath.Join(os.Getenv("HOME"), ".ssh", "id_baz"))
319-
if string(b) != expectedSSHConfig {
320-
t.Errorf("got: %v, wanted: %v", string(b), expectedSSHConfig)
314+
IdentityFile %s/.ssh/id_baz
315+
`, credentials.VolumePath, credentials.VolumePath, credentials.VolumePath)
316+
if d := cmp.Diff(expectedSSHConfig, string(b)); d != "" {
317+
t.Errorf("ssh_config diff: %s", d)
321318
}
322319

323320
b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "known_hosts"))
@@ -327,8 +324,8 @@ Host gitlab.example.com
327324
expectedSSHKnownHosts := `ssh-rsa aaaa
328325
ssh-rsa bbbb
329326
ssh-rsa cccc`
330-
if string(b) != expectedSSHKnownHosts {
331-
t.Errorf("got: %v, wanted: %v", string(b), expectedSSHKnownHosts)
327+
if d := cmp.Diff(expectedSSHKnownHosts, string(b)); d != "" {
328+
t.Errorf("known_hosts diff: %s", d)
332329
}
333330

334331
b, err = ioutil.ReadFile(filepath.Join(credentials.VolumePath, ".ssh", "id_foo"))
@@ -370,31 +367,12 @@ func TestSSHFlagHandlingMissingFiles(t *testing.T) {
370367
}
371368
// No ssh-privatekey files yields an error.
372369

373-
cfg := sshGitConfig{entries: make(map[string]sshEntry)}
370+
cfg := sshGitConfig{entries: make(map[string][]sshEntry)}
374371
if err := cfg.Set("not-found=github.com"); err == nil {
375372
t.Error("Set(); got success, wanted error.")
376373
}
377374
}
378375

379-
func TestSSHFlagHandlingURLCollision(t *testing.T) {
380-
credentials.VolumePath, _ = ioutil.TempDir("", "")
381-
dir := credentials.VolumeName("foo")
382-
if err := os.MkdirAll(dir, os.ModePerm); err != nil {
383-
t.Fatalf("os.MkdirAll(%s) = %v", dir, err)
384-
}
385-
if err := ioutil.WriteFile(filepath.Join(dir, corev1.SSHAuthPrivateKey), []byte("bar"), 0777); err != nil {
386-
t.Fatalf("ioutil.WriteFile(ssh-privatekey) = %v", err)
387-
}
388-
389-
cfg := sshGitConfig{entries: make(map[string]sshEntry)}
390-
if err := cfg.Set("foo=github.com"); err != nil {
391-
t.Fatalf("First Set() = %v", err)
392-
}
393-
if err := cfg.Set("bar=github.com"); err == nil {
394-
t.Error("Second Set(); got success, wanted error.")
395-
}
396-
}
397-
398376
func TestBasicMalformedValues(t *testing.T) {
399377
tests := []string{
400378
"bar=baz=blah",

pkg/credentials/gitcreds/ssh.go

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ const sshKnownHosts = "known_hosts"
3636
// As the flag is read, this status is populated.
3737
// sshGitConfig implements flag.Value
3838
type sshGitConfig struct {
39-
entries map[string]sshEntry
39+
entries map[string][]sshEntry
4040
// The order we see things, for iterating over the above.
4141
order []string
4242
}
@@ -48,8 +48,9 @@ func (dc *sshGitConfig) String() string {
4848
}
4949
var urls []string
5050
for _, k := range dc.order {
51-
v := dc.entries[k]
52-
urls = append(urls, fmt.Sprintf("%s=%s", v.secret, k))
51+
for _, e := range dc.entries[k] {
52+
urls = append(urls, fmt.Sprintf("%s=%s", e.secretName, k))
53+
}
5354
}
5455
return strings.Join(urls, ",")
5556
}
@@ -59,19 +60,17 @@ func (dc *sshGitConfig) Set(value string) error {
5960
if len(parts) != 2 {
6061
return xerrors.Errorf("Expect entries of the form secret=url, got: %v", value)
6162
}
62-
secret := parts[0]
63+
secretName := parts[0]
6364
url := parts[1]
6465

65-
if _, ok := dc.entries[url]; ok {
66-
return xerrors.Errorf("Multiple entries for url: %v", url)
67-
}
68-
69-
e, err := newSshEntry(url, secret)
66+
e, err := newSshEntry(url, secretName)
7067
if err != nil {
7168
return err
7269
}
73-
dc.entries[url] = *e
74-
dc.order = append(dc.order, url)
70+
if _, exists := dc.entries[url]; !exists {
71+
dc.order = append(dc.order, url)
72+
}
73+
dc.entries[url] = append(dc.entries[url], *e)
7574
return nil
7675
}
7776

@@ -82,7 +81,7 @@ func (dc *sshGitConfig) Write() error {
8281
}
8382

8483
// Walk each of the entries and for each do three things:
85-
// 1. Write out: ~/.ssh/id_{secret} with the secret key
84+
// 1. Write out: ~/.ssh/id_{secretName} with the secret key
8685
// 2. Compute its part of "~/.ssh/config"
8786
// 3. Compute its part of "~/.ssh/known_hosts"
8887
var configEntries []string
@@ -95,17 +94,19 @@ func (dc *sshGitConfig) Write() error {
9594
host = k
9695
port = defaultPort
9796
}
98-
v := dc.entries[k]
99-
if err := v.Write(sshDir); err != nil {
100-
return err
101-
}
102-
configEntries = append(configEntries, fmt.Sprintf(`Host %s
97+
configEntry := fmt.Sprintf(`Host %s
10398
HostName %s
104-
IdentityFile %s
10599
Port %s
106-
`, host, host, v.path(sshDir), port))
107-
108-
knownHosts = append(knownHosts, v.knownHosts)
100+
`, host, host, port)
101+
for _, e := range dc.entries[k] {
102+
if err := e.Write(sshDir); err != nil {
103+
return err
104+
}
105+
configEntry += fmt.Sprintf(` IdentityFile %s
106+
`, e.path(sshDir))
107+
knownHosts = append(knownHosts, e.knownHosts)
108+
}
109+
configEntries = append(configEntries, configEntry)
109110
}
110111
configPath := filepath.Join(sshDir, "config")
111112
configContent := strings.Join(configEntries, "")
@@ -118,13 +119,13 @@ func (dc *sshGitConfig) Write() error {
118119
}
119120

120121
type sshEntry struct {
121-
secret string
122+
secretName string
122123
privateKey string
123124
knownHosts string
124125
}
125126

126127
func (be *sshEntry) path(sshDir string) string {
127-
return filepath.Join(sshDir, "id_"+be.secret)
128+
return filepath.Join(sshDir, "id_"+be.secretName)
128129
}
129130

130131
func sshKeyScan(domain string) ([]byte, error) {
@@ -142,8 +143,8 @@ func (be *sshEntry) Write(sshDir string) error {
142143
return ioutil.WriteFile(be.path(sshDir), []byte(be.privateKey), 0600)
143144
}
144145

145-
func newSshEntry(u, secret string) (*sshEntry, error) {
146-
secretPath := credentials.VolumeName(secret)
146+
func newSshEntry(u, secretName string) (*sshEntry, error) {
147+
secretPath := credentials.VolumeName(secretName)
147148

148149
pk, err := ioutil.ReadFile(filepath.Join(secretPath, corev1.SSHAuthPrivateKey))
149150
if err != nil {
@@ -161,7 +162,7 @@ func newSshEntry(u, secret string) (*sshEntry, error) {
161162
knownHosts := string(kh)
162163

163164
return &sshEntry{
164-
secret: secret,
165+
secretName: secretName,
165166
privateKey: privateKey,
166167
knownHosts: knownHosts,
167168
}, nil

0 commit comments

Comments
 (0)