Handle git ssh artifacts

This PR adds handling for downloading git artifacts using ssh with the
format git@github.com:hashicorp/go-getter.git

Fixes https://github.com/hashicorp/nomad/issues/2430
This commit is contained in:
Alex Dadgar
2017-03-11 15:11:40 -08:00
parent ea18d6f309
commit 637aff7819
3 changed files with 126 additions and 4 deletions

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"net/url"
"path/filepath"
"strings"
"sync"
gg "github.com/hashicorp/go-getter"
@@ -21,6 +22,11 @@ var (
supported = []string{"http", "https", "s3", "hg", "git"}
)
const (
// gitSSHPrefix is the prefix for dowwnloading via git using ssh
gitSSHPrefix = "git@github.com:"
)
// getClient returns a client that is suitable for Nomad downloading artifacts.
func getClient(src, dst string) *gg.Client {
lock.Lock()
@@ -47,7 +53,17 @@ func getClient(src, dst string) *gg.Client {
// getGetterUrl returns the go-getter URL to download the artifact.
func getGetterUrl(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact) (string, error) {
taskEnv.Build()
u, err := url.Parse(taskEnv.ReplaceEnv(artifact.GetterSource))
source := taskEnv.ReplaceEnv(artifact.GetterSource)
// Handle an invalid URL when given a go-getter url such as
// git@github.com:hashicorp/nomad.git
gitSSH := false
if strings.HasPrefix(source, gitSSHPrefix) {
gitSSH = true
source = source[len(gitSSHPrefix):]
}
u, err := url.Parse(source)
if err != nil {
return "", fmt.Errorf("failed to parse source URL %q: %v", artifact.GetterSource, err)
}
@@ -58,7 +74,14 @@ func getGetterUrl(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact)
q.Add(k, taskEnv.ReplaceEnv(v))
}
u.RawQuery = q.Encode()
return u.String(), nil
// Add the prefix back
url := u.String()
if gitSSH {
url = fmt.Sprintf("%s%s", gitSSHPrefix, url)
}
return url, nil
}
// GetArtifact downloads an artifact into the specified task directory.
@@ -71,7 +94,7 @@ func GetArtifact(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact, t
// Download the artifact
dest := filepath.Join(taskDir, artifact.RelativeDest)
if err := getClient(url, dest).Get(); err != nil {
return fmt.Errorf("GET error: %v", err)
return structs.NewRecoverableError(fmt.Errorf("GET error: %v", err), true)
}
return nil

View File

@@ -204,3 +204,102 @@ func TestGetArtifact_Archive(t *testing.T) {
}
checkContents(taskDir, expected, t)
}
func TestGetGetterUrl_Queries(t *testing.T) {
taskEnv := env.NewTaskEnvironment(mock.Node())
cases := []struct {
name string
artifact *structs.TaskArtifact
output string
}{
{
name: "adds query parameters",
artifact: &structs.TaskArtifact{
GetterSource: "https://foo.com?test=1",
GetterOptions: map[string]string{
"foo": "bar",
"bam": "boom",
},
},
output: "https://foo.com?bam=boom&foo=bar&test=1",
},
{
name: "git without http",
artifact: &structs.TaskArtifact{
GetterSource: "github.com/hashicorp/nomad",
GetterOptions: map[string]string{
"ref": "abcd1234",
},
},
output: "github.com/hashicorp/nomad?ref=abcd1234",
},
{
name: "git using ssh",
artifact: &structs.TaskArtifact{
GetterSource: "git@github.com:hashicorp/nomad?sshkey=1",
GetterOptions: map[string]string{
"ref": "abcd1234",
},
},
output: "git@github.com:hashicorp/nomad?ref=abcd1234&sshkey=1",
},
{
name: "s3 scheme 1",
artifact: &structs.TaskArtifact{
GetterSource: "s3::https://s3.amazonaws.com/bucket/foo",
GetterOptions: map[string]string{
"aws_access_key_id": "abcd1234",
},
},
output: "s3::https://s3.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234",
},
{
name: "s3 scheme 2",
artifact: &structs.TaskArtifact{
GetterSource: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo",
GetterOptions: map[string]string{
"aws_access_key_id": "abcd1234",
},
},
output: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234",
},
{
name: "s3 scheme 3",
artifact: &structs.TaskArtifact{
GetterSource: "bucket.s3.amazonaws.com/foo",
GetterOptions: map[string]string{
"aws_access_key_id": "abcd1234",
},
},
output: "bucket.s3.amazonaws.com/foo?aws_access_key_id=abcd1234",
},
{
name: "s3 scheme 4",
artifact: &structs.TaskArtifact{
GetterSource: "bucket.s3-eu-west-1.amazonaws.com/foo/bar",
GetterOptions: map[string]string{
"aws_access_key_id": "abcd1234",
},
},
output: "bucket.s3-eu-west-1.amazonaws.com/foo/bar?aws_access_key_id=abcd1234",
},
{
name: "local file",
artifact: &structs.TaskArtifact{
GetterSource: "/foo/bar",
},
output: "/foo/bar",
},
}
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
act, err := getGetterUrl(taskEnv, c.artifact)
if err != nil {
t.Fatalf("want %q; got err %v", c.output, err)
} else if act != c.output {
t.Fatalf("want %q; got %q", c.output, act)
}
})
}
}

View File

@@ -796,7 +796,7 @@ func (r *TaskRunner) prestart(resultCh chan bool) {
r.logger.Printf("[DEBUG] client: %v", wrapped)
r.setState(structs.TaskStatePending,
structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped))
r.restartTracker.SetStartError(structs.NewRecoverableError(wrapped, true))
r.restartTracker.SetStartError(structs.NewRecoverableError(wrapped, structs.IsRecoverable(err)))
goto RESTART
}
}