diff --git a/client/getter/getter.go b/client/getter/getter.go index 22a3f369e..83ee0d695 100644 --- a/client/getter/getter.go +++ b/client/getter/getter.go @@ -8,6 +8,7 @@ import ( "sync" gg "github.com/hashicorp/go-getter" + "github.com/hashicorp/nomad/client/driver/env" "github.com/hashicorp/nomad/nomad/structs" ) @@ -45,8 +46,9 @@ func getClient(src, dst string) *gg.Client { } // getGetterUrl returns the go-getter URL to download the artifact. -func getGetterUrl(artifact *structs.TaskArtifact) (string, error) { - u, err := url.Parse(artifact.GetterSource) +func getGetterUrl(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact) (string, error) { + taskEnv.Build() + u, err := url.Parse(taskEnv.ReplaceEnv(artifact.GetterSource)) if err != nil { return "", fmt.Errorf("failed to parse source URL %q: %v", artifact.GetterSource, err) } @@ -54,15 +56,17 @@ func getGetterUrl(artifact *structs.TaskArtifact) (string, error) { // Build the url q := u.Query() for k, v := range artifact.GetterOptions { - q.Add(k, v) + q.Add(k, taskEnv.ReplaceEnv(v)) } u.RawQuery = q.Encode() return u.String(), nil } // GetArtifact downloads an artifact into the specified task directory. -func GetArtifact(artifact *structs.TaskArtifact, taskDir string, logger *log.Logger) error { - url, err := getGetterUrl(artifact) +func GetArtifact(taskEnv *env.TaskEnvironment, artifact *structs.TaskArtifact, + taskDir string, logger *log.Logger) error { + + url, err := getGetterUrl(taskEnv, artifact) if err != nil { return err } diff --git a/client/getter/getter_test.go b/client/getter/getter_test.go index 208f9efe5..24db5f1c2 100644 --- a/client/getter/getter_test.go +++ b/client/getter/getter_test.go @@ -12,6 +12,8 @@ import ( "strings" "testing" + "github.com/hashicorp/nomad/client/driver/env" + "github.com/hashicorp/nomad/nomad/mock" "github.com/hashicorp/nomad/nomad/structs" ) @@ -37,8 +39,9 @@ func TestGetArtifact_FileAndChecksum(t *testing.T) { } // Download the artifact + taskEnv := env.NewTaskEnvironment(mock.Node()) logger := log.New(os.Stderr, "", log.LstdFlags) - if err := GetArtifact(artifact, taskDir, logger); err != nil { + if err := GetArtifact(taskEnv, artifact, taskDir, logger); err != nil { t.Fatalf("GetArtifact failed: %v", err) } @@ -72,8 +75,9 @@ func TestGetArtifact_File_RelativeDest(t *testing.T) { } // Download the artifact + taskEnv := env.NewTaskEnvironment(mock.Node()) logger := log.New(os.Stderr, "", log.LstdFlags) - if err := GetArtifact(artifact, taskDir, logger); err != nil { + if err := GetArtifact(taskEnv, artifact, taskDir, logger); err != nil { t.Fatalf("GetArtifact failed: %v", err) } @@ -83,6 +87,24 @@ func TestGetArtifact_File_RelativeDest(t *testing.T) { } } +func TestGetGetterUrl_Interprolation(t *testing.T) { + // Create the artifact + artifact := &structs.TaskArtifact{ + GetterSource: "${NOMAD_META_ARTIFACT}", + } + + url := "foo.com" + taskEnv := env.NewTaskEnvironment(mock.Node()).SetTaskMeta(map[string]string{"artifact": url}) + act, err := getGetterUrl(taskEnv, artifact) + if err != nil { + t.Fatalf("getGetterUrl() failed: %v", err) + } + + if act != url { + t.Fatalf("getGetterUrl() returned %q; want %q", act, url) + } +} + func TestGetArtifact_InvalidChecksum(t *testing.T) { // Create the test server hosting the file to download ts := httptest.NewServer(http.FileServer(http.Dir(filepath.Dir("./test-fixtures/")))) @@ -105,8 +127,9 @@ func TestGetArtifact_InvalidChecksum(t *testing.T) { } // Download the artifact and expect an error + taskEnv := env.NewTaskEnvironment(mock.Node()) logger := log.New(os.Stderr, "", log.LstdFlags) - if err := GetArtifact(artifact, taskDir, logger); err == nil { + if err := GetArtifact(taskEnv, artifact, taskDir, logger); err == nil { t.Fatalf("GetArtifact should have failed") } } @@ -171,8 +194,9 @@ func TestGetArtifact_Archive(t *testing.T) { }, } + taskEnv := env.NewTaskEnvironment(mock.Node()) logger := log.New(os.Stderr, "", log.LstdFlags) - if err := GetArtifact(artifact, taskDir, logger); err != nil { + if err := GetArtifact(taskEnv, artifact, taskDir, logger); err != nil { t.Fatalf("GetArtifact failed: %v", err) } diff --git a/client/task_runner.go b/client/task_runner.go index 99400f6be..701d071fe 100644 --- a/client/task_runner.go +++ b/client/task_runner.go @@ -16,6 +16,7 @@ import ( "github.com/hashicorp/nomad/client/getter" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/client/driver/env" cstructs "github.com/hashicorp/nomad/client/driver/structs" ) @@ -43,6 +44,7 @@ type TaskRunner struct { restartTracker *RestartTracker task *structs.Task + taskEnv *env.TaskEnvironment updateCh chan *structs.Allocation handle driver.DriverHandle handleLock sync.Mutex @@ -188,18 +190,29 @@ func (r *TaskRunner) setState(state string, event *structs.TaskEvent) { r.updater(r.task.Name, state, event) } -// createDriver makes a driver for the task -func (r *TaskRunner) createDriver() (driver.Driver, error) { +// setTaskEnv sets the task environment. It returns an error if it could not be +// created. +func (r *TaskRunner) setTaskEnv() error { taskEnv, err := driver.GetTaskEnv(r.ctx.AllocDir, r.config.Node, r.task, r.alloc) if err != nil { - err = fmt.Errorf("failed to create driver '%s' for alloc %s: %v", - r.task.Driver, r.alloc.ID, err) - r.logger.Printf("[ERR] client: %s", err) - return nil, err + return err + } + r.taskEnv = taskEnv + return nil +} +// createDriver makes a driver for the task +func (r *TaskRunner) createDriver() (driver.Driver, error) { + if r.taskEnv == nil { + if err := r.setTaskEnv(); err != nil { + err := fmt.Errorf("failed to create driver '%s' for alloc %s: %v", + r.task.Driver, r.alloc.ID, err) + r.logger.Printf("[ERR] client: %s", err) + return nil, err + } } - driverCtx := driver.NewDriverContext(r.task.Name, r.config, r.config.Node, r.logger, taskEnv) + driverCtx := driver.NewDriverContext(r.task.Name, r.config, r.config.Node, r.logger, r.taskEnv) driver, err := driver.NewDriver(r.task.Driver, driverCtx) if err != nil { err = fmt.Errorf("failed to create driver '%s' for alloc %s: %v", @@ -223,6 +236,13 @@ func (r *TaskRunner) Run() { return } + if err := r.setTaskEnv(); err != nil { + r.setState( + structs.TaskStateDead, + structs.NewTaskEvent(structs.TaskDriverFailure).SetDriverError(err)) + return + } + r.run() return } @@ -277,7 +297,7 @@ func (r *TaskRunner) run() { } for _, artifact := range r.task.Artifacts { - if err := getter.GetArtifact(artifact, taskDir, r.logger); err != nil { + if err := getter.GetArtifact(r.taskEnv, artifact, taskDir, r.logger); err != nil { r.setState(structs.TaskStateDead, structs.NewTaskEvent(structs.TaskArtifactDownloadFailed).SetDownloadError(err)) r.restartTracker.SetStartError(cstructs.NewRecoverableError(err, true)) diff --git a/nomad/structs/structs.go b/nomad/structs/structs.go index 844e0a077..380bdeaef 100644 --- a/nomad/structs/structs.go +++ b/nomad/structs/structs.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "net/url" "path/filepath" "reflect" "regexp" @@ -1980,11 +1979,6 @@ func (ta *TaskArtifact) Validate() error { var mErr multierror.Error if ta.GetterSource == "" { mErr.Errors = append(mErr.Errors, fmt.Errorf("source must be specified")) - } else { - _, err := url.Parse(ta.GetterSource) - if err != nil { - mErr.Errors = append(mErr.Errors, fmt.Errorf("invalid source URL %q: %v", ta.GetterSource, err)) - } } // Verify the destination doesn't escape the tasks directory