mirror of
https://github.com/kemko/nomad.git
synced 2026-01-06 18:35:44 +03:00
artifact: fix numerous go-getter security issues
Fix numerous go-getter security issues: - Add timeouts to http, git, and hg operations to prevent DoS - Add size limit to http to prevent resource exhaustion - Disable following symlinks in both artifacts and `job run` - Stop performing initial HEAD request to avoid file corruption on retries and DoS opportunities. **Approach** Since Nomad has no ability to differentiate a DoS-via-large-artifact vs a legitimate workload, all of the new limits are configurable at the client agent level. The max size of HTTP downloads is also exposed as a node attribute so that if some workloads have large artifacts they can specify a high limit in their jobspecs. In the future all of this plumbing could be extended to enable/disable specific getters or artifact downloading entirely on a per-node basis.
This commit is contained in:
committed by
Luiz Aoqui
parent
94abe338e9
commit
3968509886
@@ -182,6 +182,9 @@ type allocRunner struct {
|
||||
// serviceRegWrapper is the handler wrapper that is used by service hooks
|
||||
// to perform service and check registration and deregistration.
|
||||
serviceRegWrapper *wrapper.HandlerWrapper
|
||||
|
||||
// getter is an interface for retrieving artifacts.
|
||||
getter cinterfaces.ArtifactGetter
|
||||
}
|
||||
|
||||
// RPCer is the interface needed by hooks to make RPC calls.
|
||||
@@ -226,6 +229,7 @@ func NewAllocRunner(config *Config) (*allocRunner, error) {
|
||||
serversContactedCh: config.ServersContactedCh,
|
||||
rpcClient: config.RPCClient,
|
||||
serviceRegWrapper: config.ServiceRegWrapper,
|
||||
getter: config.Getter,
|
||||
}
|
||||
|
||||
// Create the logger based on the allocation ID
|
||||
@@ -280,6 +284,7 @@ func (ar *allocRunner) initTaskRunners(tasks []*structs.Task) error {
|
||||
StartConditionMetCtx: ar.taskHookCoordinator.startConditionForTask(task),
|
||||
ShutdownDelayCtx: ar.shutdownDelayCtx,
|
||||
ServiceRegWrapper: ar.serviceRegWrapper,
|
||||
Getter: ar.getter,
|
||||
}
|
||||
|
||||
if ar.cpusetManager != nil {
|
||||
|
||||
@@ -86,4 +86,7 @@ type Config struct {
|
||||
// ServiceRegWrapper is the handler wrapper that is used by service hooks
|
||||
// to perform service and check registration and deregistration.
|
||||
ServiceRegWrapper *wrapper.HandlerWrapper
|
||||
|
||||
// Getter is an interface for retrieving artifacts.
|
||||
Getter interfaces.ArtifactGetter
|
||||
}
|
||||
|
||||
@@ -7,8 +7,8 @@ import (
|
||||
|
||||
log "github.com/hashicorp/go-hclog"
|
||||
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
|
||||
"github.com/hashicorp/nomad/client/allocrunner/taskrunner/getter"
|
||||
ti "github.com/hashicorp/nomad/client/allocrunner/taskrunner/interfaces"
|
||||
ci "github.com/hashicorp/nomad/client/interfaces"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
)
|
||||
|
||||
@@ -16,11 +16,13 @@ import (
|
||||
type artifactHook struct {
|
||||
eventEmitter ti.EventEmitter
|
||||
logger log.Logger
|
||||
getter ci.ArtifactGetter
|
||||
}
|
||||
|
||||
func newArtifactHook(e ti.EventEmitter, logger log.Logger) *artifactHook {
|
||||
func newArtifactHook(e ti.EventEmitter, getter ci.ArtifactGetter, logger log.Logger) *artifactHook {
|
||||
h := &artifactHook{
|
||||
eventEmitter: e,
|
||||
getter: getter,
|
||||
}
|
||||
h.logger = logger.Named(h.Name())
|
||||
return h
|
||||
@@ -40,7 +42,7 @@ func (h *artifactHook) doWork(req *interfaces.TaskPrestartRequest, resp *interfa
|
||||
|
||||
h.logger.Debug("downloading artifact", "artifact", artifact.GetterSource, "aid", aid)
|
||||
//XXX add ctx to GetArtifact to allow cancelling long downloads
|
||||
if err := getter.GetArtifact(req.TaskEnv, artifact); err != nil {
|
||||
if err := h.getter.GetArtifact(req.TaskEnv, artifact); err != nil {
|
||||
|
||||
wrapped := structs.NewRecoverableError(
|
||||
fmt.Errorf("failed to download artifact %q: %v", artifact.GetterSource, err),
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/hashicorp/nomad/ci"
|
||||
"github.com/hashicorp/nomad/client/allocdir"
|
||||
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
|
||||
"github.com/hashicorp/nomad/client/allocrunner/taskrunner/getter"
|
||||
"github.com/hashicorp/nomad/client/taskenv"
|
||||
"github.com/hashicorp/nomad/helper"
|
||||
"github.com/hashicorp/nomad/helper/testlog"
|
||||
@@ -38,7 +39,7 @@ func TestTaskRunner_ArtifactHook_Recoverable(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
|
||||
me := &mockEmitter{}
|
||||
artifactHook := newArtifactHook(me, testlog.HCLogger(t))
|
||||
artifactHook := newArtifactHook(me, getter.TestDefaultGetter(t), testlog.HCLogger(t))
|
||||
|
||||
req := &interfaces.TaskPrestartRequest{
|
||||
TaskEnv: taskenv.NewEmptyTaskEnv(),
|
||||
@@ -71,7 +72,7 @@ func TestTaskRunner_ArtifactHook_PartialDone(t *testing.T) {
|
||||
ci.Parallel(t)
|
||||
|
||||
me := &mockEmitter{}
|
||||
artifactHook := newArtifactHook(me, testlog.HCLogger(t))
|
||||
artifactHook := newArtifactHook(me, getter.TestDefaultGetter(t), testlog.HCLogger(t))
|
||||
|
||||
// Create a source directory with 1 of the 2 artifacts
|
||||
srcdir := t.TempDir()
|
||||
@@ -159,7 +160,7 @@ func TestTaskRunner_ArtifactHook_ConcurrentDownloadSuccess(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
me := &mockEmitter{}
|
||||
artifactHook := newArtifactHook(me, testlog.HCLogger(t))
|
||||
artifactHook := newArtifactHook(me, getter.TestDefaultGetter(t), testlog.HCLogger(t))
|
||||
|
||||
// Create a source directory all 7 artifacts
|
||||
srcdir := t.TempDir()
|
||||
@@ -246,7 +247,7 @@ func TestTaskRunner_ArtifactHook_ConcurrentDownloadFailure(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
me := &mockEmitter{}
|
||||
artifactHook := newArtifactHook(me, testlog.HCLogger(t))
|
||||
artifactHook := newArtifactHook(me, getter.TestDefaultGetter(t), testlog.HCLogger(t))
|
||||
|
||||
// Create a source directory with 3 of the 4 artifacts
|
||||
srcdir := t.TempDir()
|
||||
|
||||
@@ -10,63 +10,132 @@ import (
|
||||
"github.com/hashicorp/go-cleanhttp"
|
||||
gg "github.com/hashicorp/go-getter"
|
||||
|
||||
"github.com/hashicorp/nomad/client/config"
|
||||
"github.com/hashicorp/nomad/client/interfaces"
|
||||
"github.com/hashicorp/nomad/nomad/structs"
|
||||
)
|
||||
|
||||
// httpClient is a shared HTTP client for use across all http/https Getter
|
||||
// instantiations. The HTTP client is designed to be thread-safe, and using a pooled
|
||||
// transport will help reduce excessive connections when clients are downloading lots
|
||||
// of artifacts.
|
||||
var httpClient = &http.Client{
|
||||
Transport: cleanhttp.DefaultPooledTransport(),
|
||||
}
|
||||
|
||||
const (
|
||||
// gitSSHPrefix is the prefix for downloading via git using ssh
|
||||
gitSSHPrefix = "git@github.com:"
|
||||
)
|
||||
|
||||
// EnvReplacer is an interface which can interpolate environment variables and
|
||||
// is usually satisfied by taskenv.TaskEnv.
|
||||
type EnvReplacer interface {
|
||||
ReplaceEnv(string) string
|
||||
ClientPath(string, bool) (string, bool)
|
||||
// Getter wraps go-getter calls in an artifact configuration.
|
||||
type Getter struct {
|
||||
// httpClient is a shared HTTP client for use across all http/https
|
||||
// Getter instantiations. The HTTP client is designed to be
|
||||
// thread-safe, and using a pooled transport will help reduce excessive
|
||||
// connections when clients are downloading lots of artifacts.
|
||||
httpClient *http.Client
|
||||
config *config.ArtifactConfig
|
||||
}
|
||||
|
||||
// NewGetter returns a new Getter instance. This function is called once per
|
||||
// client and shared across alloc and task runners.
|
||||
func NewGetter(config *config.ArtifactConfig) *Getter {
|
||||
return &Getter{
|
||||
httpClient: &http.Client{
|
||||
Transport: cleanhttp.DefaultPooledTransport(),
|
||||
},
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// GetArtifact downloads an artifact into the specified task directory.
|
||||
func (g *Getter) GetArtifact(taskEnv interfaces.EnvReplacer, artifact *structs.TaskArtifact) error {
|
||||
ggURL, err := getGetterUrl(taskEnv, artifact)
|
||||
if err != nil {
|
||||
return newGetError(artifact.GetterSource, err, false)
|
||||
}
|
||||
|
||||
dest, escapes := taskEnv.ClientPath(artifact.RelativeDest, true)
|
||||
// Verify the destination is still in the task sandbox after interpolation
|
||||
if escapes {
|
||||
return newGetError(artifact.RelativeDest,
|
||||
errors.New("artifact destination path escapes the alloc directory"),
|
||||
false)
|
||||
}
|
||||
|
||||
// Convert from string getter mode to go-getter const
|
||||
mode := gg.ClientModeAny
|
||||
switch artifact.GetterMode {
|
||||
case structs.GetterModeFile:
|
||||
mode = gg.ClientModeFile
|
||||
case structs.GetterModeDir:
|
||||
mode = gg.ClientModeDir
|
||||
}
|
||||
|
||||
headers := getHeaders(taskEnv, artifact.GetterHeaders)
|
||||
if err := g.getClient(ggURL, headers, mode, dest).Get(); err != nil {
|
||||
return newGetError(ggURL, err, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// getClient returns a client that is suitable for Nomad downloading artifacts.
|
||||
func getClient(src string, headers http.Header, mode gg.ClientMode, dst string) *gg.Client {
|
||||
func (g *Getter) getClient(src string, headers http.Header, mode gg.ClientMode, dst string) *gg.Client {
|
||||
return &gg.Client{
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
Mode: mode,
|
||||
Umask: 060000000,
|
||||
Getters: createGetters(headers),
|
||||
Getters: g.createGetters(headers),
|
||||
|
||||
// This will prevent copying or writing files through symlinks
|
||||
DisableSymlinks: true,
|
||||
}
|
||||
}
|
||||
|
||||
func createGetters(header http.Header) map[string]gg.Getter {
|
||||
func (g *Getter) createGetters(header http.Header) map[string]gg.Getter {
|
||||
httpGetter := &gg.HttpGetter{
|
||||
Netrc: true,
|
||||
Client: httpClient,
|
||||
Client: g.httpClient,
|
||||
Header: header,
|
||||
|
||||
// Do not support the custom X-Terraform-Get header and
|
||||
// associated logic.
|
||||
XTerraformGetDisabled: true,
|
||||
|
||||
// Disable HEAD requests as they can produce corrupt files when
|
||||
// retrying a download of a resource that has changed.
|
||||
// hashicorp/go-getter#219
|
||||
DoNotCheckHeadFirst: true,
|
||||
|
||||
// Read timeout for HTTP operations. Must be long enough to
|
||||
// accommodate large/slow downloads.
|
||||
ReadTimeout: g.config.HTTPReadTimeout,
|
||||
|
||||
// Maximum download size. Must be large enough to accommodate
|
||||
// large downloads.
|
||||
MaxBytes: g.config.HTTPMaxBytes,
|
||||
}
|
||||
|
||||
// Explicitly create fresh set of supported Getter for each Client, because
|
||||
// go-getter is not thread-safe. Use a shared HTTP client for http/https Getter,
|
||||
// with pooled transport which is thread-safe.
|
||||
//
|
||||
// If a getter type is not listed here, it is not supported (e.g. file).
|
||||
return map[string]gg.Getter{
|
||||
"git": new(gg.GitGetter),
|
||||
"gcs": new(gg.GCSGetter),
|
||||
"hg": new(gg.HgGetter),
|
||||
"s3": new(gg.S3Getter),
|
||||
"git": &gg.GitGetter{
|
||||
Timeout: g.config.GitTimeout,
|
||||
},
|
||||
"hg": &gg.HgGetter{
|
||||
Timeout: g.config.HgTimeout,
|
||||
},
|
||||
"gcs": &gg.GCSGetter{
|
||||
Timeout: g.config.GCSTimeout,
|
||||
},
|
||||
"s3": &gg.S3Getter{
|
||||
Timeout: g.config.S3Timeout,
|
||||
},
|
||||
"http": httpGetter,
|
||||
"https": httpGetter,
|
||||
}
|
||||
}
|
||||
|
||||
// getGetterUrl returns the go-getter URL to download the artifact.
|
||||
func getGetterUrl(taskEnv EnvReplacer, artifact *structs.TaskArtifact) (string, error) {
|
||||
func getGetterUrl(taskEnv interfaces.EnvReplacer, artifact *structs.TaskArtifact) (string, error) {
|
||||
source := taskEnv.ReplaceEnv(artifact.GetterSource)
|
||||
|
||||
// Handle an invalid URL when given a go-getter url such as
|
||||
@@ -98,7 +167,7 @@ func getGetterUrl(taskEnv EnvReplacer, artifact *structs.TaskArtifact) (string,
|
||||
return ggURL, nil
|
||||
}
|
||||
|
||||
func getHeaders(env EnvReplacer, m map[string]string) http.Header {
|
||||
func getHeaders(env interfaces.EnvReplacer, m map[string]string) http.Header {
|
||||
if len(m) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -110,38 +179,6 @@ func getHeaders(env EnvReplacer, m map[string]string) http.Header {
|
||||
return headers
|
||||
}
|
||||
|
||||
// GetArtifact downloads an artifact into the specified task directory.
|
||||
func GetArtifact(taskEnv EnvReplacer, artifact *structs.TaskArtifact) error {
|
||||
ggURL, err := getGetterUrl(taskEnv, artifact)
|
||||
if err != nil {
|
||||
return newGetError(artifact.GetterSource, err, false)
|
||||
}
|
||||
|
||||
dest, escapes := taskEnv.ClientPath(artifact.RelativeDest, true)
|
||||
// Verify the destination is still in the task sandbox after interpolation
|
||||
if escapes {
|
||||
return newGetError(artifact.RelativeDest,
|
||||
errors.New("artifact destination path escapes the alloc directory"),
|
||||
false)
|
||||
}
|
||||
|
||||
// Convert from string getter mode to go-getter const
|
||||
mode := gg.ClientModeAny
|
||||
switch artifact.GetterMode {
|
||||
case structs.GetterModeFile:
|
||||
mode = gg.ClientModeFile
|
||||
case structs.GetterModeDir:
|
||||
mode = gg.ClientModeDir
|
||||
}
|
||||
|
||||
headers := getHeaders(taskEnv, artifact.GetterHeaders)
|
||||
if err := getClient(ggURL, headers, mode, dest).Get(); err != nil {
|
||||
return newGetError(ggURL, err, true)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetError wraps the underlying artifact fetching error with the URL. It
|
||||
// implements the RecoverableError interface.
|
||||
type GetError struct {
|
||||
|
||||
@@ -13,7 +13,11 @@ import (
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
gg "github.com/hashicorp/go-getter"
|
||||
clientconfig "github.com/hashicorp/nomad/client/config"
|
||||
"github.com/hashicorp/nomad/client/interfaces"
|
||||
"github.com/hashicorp/nomad/client/taskenv"
|
||||
"github.com/hashicorp/nomad/helper"
|
||||
"github.com/hashicorp/nomad/nomad/mock"
|
||||
@@ -46,7 +50,7 @@ func (r noopReplacer) ClientPath(p string, join bool) (string, bool) {
|
||||
return path, escapes
|
||||
}
|
||||
|
||||
func noopTaskEnv(taskDir string) EnvReplacer {
|
||||
func noopTaskEnv(taskDir string) interfaces.EnvReplacer {
|
||||
return noopReplacer{
|
||||
taskDir: taskDir,
|
||||
}
|
||||
@@ -67,6 +71,51 @@ func (u upperReplacer) ClientPath(p string, join bool) (string, bool) {
|
||||
return path, escapes
|
||||
}
|
||||
|
||||
func TestGetter_getClient(t *testing.T) {
|
||||
getter := NewGetter(&clientconfig.ArtifactConfig{
|
||||
HTTPReadTimeout: time.Minute,
|
||||
HTTPMaxBytes: 100_000,
|
||||
GCSTimeout: 1 * time.Minute,
|
||||
GitTimeout: 2 * time.Minute,
|
||||
HgTimeout: 3 * time.Minute,
|
||||
S3Timeout: 4 * time.Minute,
|
||||
})
|
||||
client := getter.getClient("src", nil, gg.ClientModeAny, "dst")
|
||||
|
||||
t.Run("check symlink config", func(t *testing.T) {
|
||||
require.True(t, client.DisableSymlinks)
|
||||
})
|
||||
|
||||
t.Run("check http config", func(t *testing.T) {
|
||||
require.True(t, client.Getters["http"].(*gg.HttpGetter).XTerraformGetDisabled)
|
||||
require.Equal(t, time.Minute, client.Getters["http"].(*gg.HttpGetter).ReadTimeout)
|
||||
require.Equal(t, int64(100_000), client.Getters["http"].(*gg.HttpGetter).MaxBytes)
|
||||
})
|
||||
|
||||
t.Run("check https config", func(t *testing.T) {
|
||||
require.True(t, client.Getters["https"].(*gg.HttpGetter).XTerraformGetDisabled)
|
||||
require.Equal(t, time.Minute, client.Getters["https"].(*gg.HttpGetter).ReadTimeout)
|
||||
require.Equal(t, int64(100_000), client.Getters["https"].(*gg.HttpGetter).MaxBytes)
|
||||
})
|
||||
|
||||
t.Run("check gcs config", func(t *testing.T) {
|
||||
require.Equal(t, client.Getters["gcs"].(*gg.GCSGetter).Timeout, 1*time.Minute)
|
||||
})
|
||||
|
||||
t.Run("check git config", func(t *testing.T) {
|
||||
require.Equal(t, client.Getters["git"].(*gg.GitGetter).Timeout, 2*time.Minute)
|
||||
})
|
||||
|
||||
t.Run("check hg config", func(t *testing.T) {
|
||||
require.Equal(t, client.Getters["hg"].(*gg.HgGetter).Timeout, 3*time.Minute)
|
||||
})
|
||||
|
||||
t.Run("check s3 config", func(t *testing.T) {
|
||||
require.Equal(t, client.Getters["s3"].(*gg.S3Getter).Timeout, 4*time.Minute)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func TestGetArtifact_getHeaders(t *testing.T) {
|
||||
t.Run("nil", func(t *testing.T) {
|
||||
require.Nil(t, getHeaders(noopTaskEnv(""), nil))
|
||||
@@ -118,10 +167,12 @@ func TestGetArtifact_Headers(t *testing.T) {
|
||||
}
|
||||
|
||||
// Download the artifact.
|
||||
getter := TestDefaultGetter(t)
|
||||
taskEnv := upperReplacer{
|
||||
taskDir: taskDir,
|
||||
}
|
||||
err := GetArtifact(taskEnv, artifact)
|
||||
|
||||
err := getter.GetArtifact(taskEnv, artifact)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify artifact exists.
|
||||
@@ -151,7 +202,8 @@ func TestGetArtifact_FileAndChecksum(t *testing.T) {
|
||||
}
|
||||
|
||||
// Download the artifact
|
||||
if err := GetArtifact(noopTaskEnv(taskDir), artifact); err != nil {
|
||||
getter := TestDefaultGetter(t)
|
||||
if err := getter.GetArtifact(noopTaskEnv(taskDir), artifact); err != nil {
|
||||
t.Fatalf("GetArtifact failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -181,7 +233,8 @@ func TestGetArtifact_File_RelativeDest(t *testing.T) {
|
||||
}
|
||||
|
||||
// Download the artifact
|
||||
if err := GetArtifact(noopTaskEnv(taskDir), artifact); err != nil {
|
||||
getter := TestDefaultGetter(t)
|
||||
if err := getter.GetArtifact(noopTaskEnv(taskDir), artifact); err != nil {
|
||||
t.Fatalf("GetArtifact failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -211,7 +264,8 @@ func TestGetArtifact_File_EscapeDest(t *testing.T) {
|
||||
}
|
||||
|
||||
// attempt to download the artifact
|
||||
err := GetArtifact(noopTaskEnv(taskDir), artifact)
|
||||
getter := TestDefaultGetter(t)
|
||||
err := getter.GetArtifact(noopTaskEnv(taskDir), artifact)
|
||||
if err == nil || !strings.Contains(err.Error(), "escapes") {
|
||||
t.Fatalf("expected GetArtifact to disallow sandbox escape: %v", err)
|
||||
}
|
||||
@@ -257,7 +311,8 @@ func TestGetArtifact_InvalidChecksum(t *testing.T) {
|
||||
}
|
||||
|
||||
// Download the artifact and expect an error
|
||||
if err := GetArtifact(noopTaskEnv(taskDir), artifact); err == nil {
|
||||
getter := TestDefaultGetter(t)
|
||||
if err := getter.GetArtifact(noopTaskEnv(taskDir), artifact); err == nil {
|
||||
t.Fatalf("GetArtifact should have failed")
|
||||
}
|
||||
}
|
||||
@@ -318,7 +373,8 @@ func TestGetArtifact_Archive(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
if err := GetArtifact(noopTaskEnv(taskDir), artifact); err != nil {
|
||||
getter := TestDefaultGetter(t)
|
||||
if err := getter.GetArtifact(noopTaskEnv(taskDir), artifact); err != nil {
|
||||
t.Fatalf("GetArtifact failed: %v", err)
|
||||
}
|
||||
|
||||
@@ -349,7 +405,8 @@ func TestGetArtifact_Setuid(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
require.NoError(t, GetArtifact(noopTaskEnv(taskDir), artifact))
|
||||
getter := TestDefaultGetter(t)
|
||||
require.NoError(t, getter.GetArtifact(noopTaskEnv(taskDir), artifact))
|
||||
|
||||
var expected map[string]int
|
||||
|
||||
|
||||
18
client/allocrunner/taskrunner/getter/testing.go
Normal file
18
client/allocrunner/taskrunner/getter/testing.go
Normal file
@@ -0,0 +1,18 @@
|
||||
//go:build !release
|
||||
// +build !release
|
||||
|
||||
package getter
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
clientconfig "github.com/hashicorp/nomad/client/config"
|
||||
"github.com/hashicorp/nomad/nomad/structs/config"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDefaultGetter(t *testing.T) *Getter {
|
||||
getterConf, err := clientconfig.ArtifactConfigFromAgent(config.DefaultArtifactConfig())
|
||||
require.NoError(t, err)
|
||||
return NewGetter(getterConf)
|
||||
}
|
||||
@@ -244,6 +244,9 @@ type TaskRunner struct {
|
||||
// serviceRegWrapper is the handler wrapper that is used by service hooks
|
||||
// to perform service and check registration and deregistration.
|
||||
serviceRegWrapper *wrapper.HandlerWrapper
|
||||
|
||||
// getter is an interface for retrieving artifacts.
|
||||
getter cinterfaces.ArtifactGetter
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
@@ -309,6 +312,9 @@ type Config struct {
|
||||
// ServiceRegWrapper is the handler wrapper that is used by service hooks
|
||||
// to perform service and check registration and deregistration.
|
||||
ServiceRegWrapper *wrapper.HandlerWrapper
|
||||
|
||||
// Getter is an interface for retrieving artifacts.
|
||||
Getter cinterfaces.ArtifactGetter
|
||||
}
|
||||
|
||||
func NewTaskRunner(config *Config) (*TaskRunner, error) {
|
||||
@@ -367,6 +373,7 @@ func NewTaskRunner(config *Config) (*TaskRunner, error) {
|
||||
shutdownDelayCtx: config.ShutdownDelayCtx,
|
||||
shutdownDelayCancelFn: config.ShutdownDelayCancelFn,
|
||||
serviceRegWrapper: config.ServiceRegWrapper,
|
||||
getter: config.Getter,
|
||||
}
|
||||
|
||||
// Create the logger based on the allocation ID
|
||||
|
||||
@@ -64,7 +64,7 @@ func (tr *TaskRunner) initHooks() {
|
||||
newLogMonHook(tr, hookLogger),
|
||||
newDispatchHook(alloc, hookLogger),
|
||||
newVolumeHook(tr, hookLogger),
|
||||
newArtifactHook(tr, hookLogger),
|
||||
newArtifactHook(tr, tr.getter, hookLogger),
|
||||
newStatsHook(tr, tr.clientConfig.StatsCollectionInterval, hookLogger),
|
||||
newDeviceHook(tr.devicemanager, hookLogger),
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/hashicorp/nomad/ci"
|
||||
"github.com/hashicorp/nomad/client/allocdir"
|
||||
"github.com/hashicorp/nomad/client/allocrunner/interfaces"
|
||||
"github.com/hashicorp/nomad/client/allocrunner/taskrunner/getter"
|
||||
"github.com/hashicorp/nomad/client/config"
|
||||
consulapi "github.com/hashicorp/nomad/client/consul"
|
||||
"github.com/hashicorp/nomad/client/devicemanager"
|
||||
@@ -145,6 +146,7 @@ func testTaskRunnerConfig(t *testing.T, alloc *structs.Allocation, taskName stri
|
||||
ShutdownDelayCtx: shutdownDelayCtx,
|
||||
ShutdownDelayCancelFn: shutdownDelayCancelFn,
|
||||
ServiceRegWrapper: wrapperMock,
|
||||
Getter: getter.TestDefaultGetter(t),
|
||||
}
|
||||
|
||||
// Set the cgroup path getter if we are in v2 mode
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/hashicorp/nomad/client/allocrunner/taskrunner/getter"
|
||||
"github.com/hashicorp/nomad/client/allocwatcher"
|
||||
clientconfig "github.com/hashicorp/nomad/client/config"
|
||||
"github.com/hashicorp/nomad/client/consul"
|
||||
@@ -83,7 +84,9 @@ func testAllocRunnerConfig(t *testing.T, alloc *structs.Allocation) (*Config, fu
|
||||
CpusetManager: new(cgutil.NoopCpusetManager),
|
||||
ServersContactedCh: make(chan struct{}),
|
||||
ServiceRegWrapper: wrapper.NewHandlerWrapper(clientConf.Logger, consulRegMock, nomadRegMock),
|
||||
Getter: getter.TestDefaultGetter(t),
|
||||
}
|
||||
|
||||
return conf, cleanup
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user