From 637aff781988d61b68e90fee2f8745f90d4387dd Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Sat, 11 Mar 2017 15:11:40 -0800 Subject: [PATCH] Handle git ssh artifacts This PR adds handling for downloading git artifacts using ssh with the format git@github.com:hashicorp/go-getter.git Fixes https://github.com/hashicorp/nomad/issues/2430 --- client/getter/getter.go | 29 +++++++++-- client/getter/getter_test.go | 99 ++++++++++++++++++++++++++++++++++++ client/task_runner.go | 2 +- 3 files changed, 126 insertions(+), 4 deletions(-) diff --git a/client/getter/getter.go b/client/getter/getter.go index 48ca4d0b4..6bd3d53fb 100644 --- a/client/getter/getter.go +++ b/client/getter/getter.go @@ -4,6 +4,7 @@ import ( "fmt" "net/url" "path/filepath" + "strings" "sync" gg "github.com/hashicorp/go-getter" @@ -21,6 +22,11 @@ var ( supported = []string{"http", "https", "s3", "hg", "git"} ) +const ( + // gitSSHPrefix is the prefix for dowwnloading via git using ssh + gitSSHPrefix = "git@github.com:" +) + // getClient returns a client that is suitable for Nomad downloading artifacts. func getClient(src, dst string) *gg.Client { lock.Lock() @@ -47,7 +53,17 @@ func getClient(src, dst string) *gg.Client { // getGetterUrl returns the go-getter URL to download the artifact. func getGetterUrl(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact) (string, error) { taskEnv.Build() - u, err := url.Parse(taskEnv.ReplaceEnv(artifact.GetterSource)) + source := taskEnv.ReplaceEnv(artifact.GetterSource) + + // Handle an invalid URL when given a go-getter url such as + // git@github.com:hashicorp/nomad.git + gitSSH := false + if strings.HasPrefix(source, gitSSHPrefix) { + gitSSH = true + source = source[len(gitSSHPrefix):] + } + + u, err := url.Parse(source) if err != nil { return "", fmt.Errorf("failed to parse source URL %q: %v", artifact.GetterSource, err) } @@ -58,7 +74,14 @@ func getGetterUrl(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact) q.Add(k, taskEnv.ReplaceEnv(v)) } u.RawQuery = q.Encode() - return u.String(), nil + + // Add the prefix back + url := u.String() + if gitSSH { + url = fmt.Sprintf("%s%s", gitSSHPrefix, url) + } + + return url, nil } // GetArtifact downloads an artifact into the specified task directory. @@ -71,7 +94,7 @@ func GetArtifact(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact, t // Download the artifact dest := filepath.Join(taskDir, artifact.RelativeDest) if err := getClient(url, dest).Get(); err != nil { - return fmt.Errorf("GET error: %v", err) + return structs.NewRecoverableError(fmt.Errorf("GET error: %v", err), true) } return nil diff --git a/client/getter/getter_test.go b/client/getter/getter_test.go index 4fb4bdb97..becbd946c 100644 --- a/client/getter/getter_test.go +++ b/client/getter/getter_test.go @@ -204,3 +204,102 @@ func TestGetArtifact_Archive(t *testing.T) { } checkContents(taskDir, expected, t) } + +func TestGetGetterUrl_Queries(t *testing.T) { + taskEnv := env.NewTaskEnvironment(mock.Node()) + cases := []struct { + name string + artifact *structs.TaskArtifact + output string + }{ + { + name: "adds query parameters", + artifact: &structs.TaskArtifact{ + GetterSource: "https://foo.com?test=1", + GetterOptions: map[string]string{ + "foo": "bar", + "bam": "boom", + }, + }, + output: "https://foo.com?bam=boom&foo=bar&test=1", + }, + { + name: "git without http", + artifact: &structs.TaskArtifact{ + GetterSource: "github.com/hashicorp/nomad", + GetterOptions: map[string]string{ + "ref": "abcd1234", + }, + }, + output: "github.com/hashicorp/nomad?ref=abcd1234", + }, + { + name: "git using ssh", + artifact: &structs.TaskArtifact{ + GetterSource: "git@github.com:hashicorp/nomad?sshkey=1", + GetterOptions: map[string]string{ + "ref": "abcd1234", + }, + }, + output: "git@github.com:hashicorp/nomad?ref=abcd1234&sshkey=1", + }, + { + name: "s3 scheme 1", + artifact: &structs.TaskArtifact{ + GetterSource: "s3::https://s3.amazonaws.com/bucket/foo", + GetterOptions: map[string]string{ + "aws_access_key_id": "abcd1234", + }, + }, + output: "s3::https://s3.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234", + }, + { + name: "s3 scheme 2", + artifact: &structs.TaskArtifact{ + GetterSource: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo", + GetterOptions: map[string]string{ + "aws_access_key_id": "abcd1234", + }, + }, + output: "s3::https://s3-eu-west-1.amazonaws.com/bucket/foo?aws_access_key_id=abcd1234", + }, + { + name: "s3 scheme 3", + artifact: &structs.TaskArtifact{ + GetterSource: "bucket.s3.amazonaws.com/foo", + GetterOptions: map[string]string{ + "aws_access_key_id": "abcd1234", + }, + }, + output: "bucket.s3.amazonaws.com/foo?aws_access_key_id=abcd1234", + }, + { + name: "s3 scheme 4", + artifact: &structs.TaskArtifact{ + GetterSource: "bucket.s3-eu-west-1.amazonaws.com/foo/bar", + GetterOptions: map[string]string{ + "aws_access_key_id": "abcd1234", + }, + }, + output: "bucket.s3-eu-west-1.amazonaws.com/foo/bar?aws_access_key_id=abcd1234", + }, + { + name: "local file", + artifact: &structs.TaskArtifact{ + GetterSource: "/foo/bar", + }, + output: "/foo/bar", + }, + } + + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + act, err := getGetterUrl(taskEnv, c.artifact) + if err != nil { + t.Fatalf("want %q; got err %v", c.output, err) + } else if act != c.output { + t.Fatalf("want %q; got %q", c.output, act) + } + }) + } +} diff --git a/client/task_runner.go b/client/task_runner.go index 36c0d7646..f9d15767c 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -796,7 +796,7 @@ func (r *TaskRunner) prestart(resultCh chan bool) { r.logger.Printf("[DEBUG] client: %v", wrapped) r.setState(structs.TaskStatePending, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(wrapped)) - r.restartTracker.SetStartError(structs.NewRecoverableError(wrapped, true)) + r.restartTracker.SetStartError(structs.NewRecoverableError(wrapped, structs.IsRecoverable(err))) goto RESTART } }