diff --git a/client/allocrunner/taskrunner/stats_hook.go b/client/allocrunner/taskrunner/stats_hook.go
index 01733e495..7343df5bc 100644
--- a/client/allocrunner/taskrunner/stats_hook.go
+++ b/client/allocrunner/taskrunner/stats_hook.go
@@ -2,13 +2,13 @@ package taskrunner

 import (
 	"context"
-	"strings"
 	"sync"
 	"time"

 	hclog "github.com/hashicorp/go-hclog"

 	"github.com/hashicorp/nomad/client/allocrunner/interfaces"
 	cstructs "github.com/hashicorp/nomad/client/structs"
+	bstructs "github.com/hashicorp/nomad/plugins/base/structs"
 )

 // StatsUpdater is the interface required by the StatsHook to update stats.
@@ -99,12 +99,13 @@ func (h *statsHook) collectResourceUsageStats(handle interfaces.DriverStats, sto
 			return
 		}

-		//XXX This is a net/rpc specific error
 		// We do not log when the plugin is shutdown as this is simply a
 		// race between the stopCollection channel being closed and calling
 		// Stats on the handle.
-		if !strings.Contains(err.Error(), "connection is shut down") {
+		if err != bstructs.ErrPluginShutdown {
 			h.logger.Debug("error fetching stats of task", "error", err)
+		} else {
+			// TODO(alex) this breaks if the handle dies
 		}

 		continue
diff --git a/client/allocrunner/taskrunner/task_runner.go b/client/allocrunner/taskrunner/task_runner.go
index 88eedf3c8..69e434f02 100644
--- a/client/allocrunner/taskrunner/task_runner.go
+++ b/client/allocrunner/taskrunner/task_runner.go
@@ -28,9 +28,10 @@ import (
 	"github.com/hashicorp/nomad/client/vaultclient"
 	"github.com/hashicorp/nomad/helper/uuid"
 	"github.com/hashicorp/nomad/nomad/structs"
+	bstructs "github.com/hashicorp/nomad/plugins/base/structs"
 	"github.com/hashicorp/nomad/plugins/drivers"
-	"github.com/hashicorp/nomad/plugins/shared"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
+	"github.com/hashicorp/nomad/plugins/shared/hclutils"
 )

 const (
@@ -408,8 +409,10 @@ MAIN:
 		}

 		// Grab the result proxy and wait for task to exit
+	WAIT:
 		{
 			handle := tr.getDriverHandle()
+			result = nil

 			// Do *not* use tr.killCtx here as it would cause
 			// Wait() to unblock before the task exits when Kill()
@@ -419,12 +422,15 @@ MAIN:
 			} else {
 				select {
 				case result = <-resultCh:
-					// WaitCh returned a result
-					tr.handleTaskExitResult(result)
 				case <-tr.ctx.Done():
 					// TaskRunner was told to exit immediately
 					return
 				}
+
+				// WaitCh returned a result
+				if retryWait := tr.handleTaskExitResult(result); retryWait {
+					goto WAIT
+				}
 			}
 		}

@@ -467,9 +473,33 @@ MAIN:
 	tr.logger.Debug("task run loop exiting")
 }

-func (tr *TaskRunner) handleTaskExitResult(result *drivers.ExitResult) {
+// TODO(alex) is this a good return type? Should these be separate methods?
+func (tr *TaskRunner) handleTaskExitResult(result *drivers.ExitResult) (retryWait bool) {
 	if result == nil {
-		return
+		return false
+	}
+
+	if result.Err == bstructs.ErrPluginShutdown {
+		tr.logger.Warn("driver plugin has shutdown; attempting to recover task")
+
+		// Initialize a new driver handle
+		if err := tr.initDriver(); err != nil {
+			tr.logger.Error("failed to initialize driver after it exited unexpectedly", "error", err)
+			return false
+		}
+
+		// Try to restore the handle
+		tr.stateLock.RLock()
+		h := tr.localState.TaskHandle
+		net := tr.localState.DriverNetwork
+		tr.stateLock.RUnlock()
+		if !tr.restoreHandle(h, net) {
+			tr.logger.Error("failed to restore handle on driver after it exited unexpectedly")
+			return false
+		}
+
+		tr.logger.Info("task successfully recovered on driver")
+		return true
 	}

 	event := structs.NewTaskEvent(structs.TaskTerminated).
@@ -483,6 +513,8 @@ func (tr *TaskRunner) handleTaskExitResult(result *drivers.ExitResult) {
 	if result.OOMKilled && !tr.clientConfig.DisableTaggedMetrics {
 		metrics.IncrCounterWithLabels([]string{"client", "allocs", "oom_killed"}, 1, tr.baseLabels)
 	}
+
+	return false
 }

 // handleUpdates runs update hooks when triggerUpdateCh is ticked and exits
@@ -530,7 +562,6 @@ func (tr *TaskRunner) shouldRestart() (bool, time.Duration) {

 // runDriver runs the driver and waits for it to exit
 func (tr *TaskRunner) runDriver() error {
-	// TODO(nickethier): make sure this uses alloc.AllocatedResources once #4750 is rebased
 	taskConfig := tr.buildTaskConfig()

 	// Build hcl context variables
@@ -556,10 +587,10 @@ func (tr *TaskRunner) runDriver() error {

 	evalCtx := &hcl.EvalContext{
 		Variables: vars,
-		Functions: shared.GetStdlibFuncs(),
+		Functions: hclutils.GetStdlibFuncs(),
 	}

-	val, diag := shared.ParseHclInterface(tr.task.Config, tr.taskSchema, evalCtx)
+	val, diag := hclutils.ParseHclInterface(tr.task.Config, tr.taskSchema, evalCtx)
 	if diag.HasErrors() {
 		return multierror.Append(errors.New("failed to parse config"), diag.Errs()...)
 	}
@@ -568,8 +599,6 @@ func (tr *TaskRunner) runDriver() error {
 		return fmt.Errorf("failed to encode driver config: %v", err)
 	}

-	//XXX Evaluate and encode driver config
-
 	// If there's already a task handle (eg from a Restore) there's nothing
 	// to do except update state.
 	if tr.getDriverHandle() != nil {
@@ -586,7 +615,22 @@ func (tr *TaskRunner) runDriver() error {
 	// Start the job if there's no existing handle (or if RecoverTask failed)
 	handle, net, err := tr.driver.StartTask(taskConfig)
 	if err != nil {
-		return fmt.Errorf("driver start failed: %v", err)
+		// The plugin has died, try relaunching it
+		if err == bstructs.ErrPluginShutdown {
+			tr.logger.Info("failed to start task because plugin shutdown unexpectedly; attempting to recover")
+			if err := tr.initDriver(); err != nil {
+				tr.logger.Error("failed to initialize driver after it exited unexpectedly", "error", err)
+				return fmt.Errorf("driver exited and couldn't be started again: %v", err)
+			}
+
+			handle, net, err = tr.driver.StartTask(taskConfig)
+			if err != nil {
+				tr.logger.Error("failed to start task after driver exited unexpectedly", "error", err)
+				return fmt.Errorf("driver start failed: %v", err)
+			}
+		} else {
+			return fmt.Errorf("driver start failed: %v", err)
+		}
 	}

 	tr.stateLock.Lock()
@@ -732,19 +776,20 @@ func (tr *TaskRunner) Restore() error {
 	return nil
 }

+// TODO(alex) Is the return optimal?
 // restoreHandle ensures a TaskHandle is valid by calling Driver.RecoverTask
 // and sets the driver handle. If the TaskHandle is not valid, DestroyTask is
 // called.
-func (tr *TaskRunner) restoreHandle(taskHandle *drivers.TaskHandle, net *cstructs.DriverNetwork) {
+func (tr *TaskRunner) restoreHandle(taskHandle *drivers.TaskHandle, net *cstructs.DriverNetwork) (success bool) {
 	// Ensure handle is well-formed
 	if taskHandle.Config == nil {
-		return
+		return true
 	}

 	if err := tr.driver.RecoverTask(taskHandle); err != nil {
 		if tr.TaskState().State != structs.TaskStateRunning {
 			// RecoverTask should fail if the Task wasn't running
-			return
+			return true
 		}

 		tr.logger.Error("error recovering task; cleaning up",
@@ -760,14 +805,15 @@ func (tr *TaskRunner) restoreHandle(taskHandle *drivers.TaskHandle, net *cstruct
 					"error", err, "task_id", taskHandle.Config.ID)
 			}
+			return false
 		}

-		return
+		return true
 	}

 	// Update driver handle on task runner
 	tr.setDriverHandle(NewDriverHandle(tr.driver, taskHandle.Config.ID, tr.Task(), net))
-	return
+	return true
 }

 // UpdateState sets the task runners allocation state and triggers a server
diff --git a/client/client.go b/client/client.go
index 7abadd425..df1b75eae 100644
--- a/client/client.go
+++ b/client/client.go
@@ -622,9 +622,6 @@ func (c *Client) Shutdown() error {
 	}
 	c.logger.Info("shutting down")

-	// Shutdown the plugin managers
-	c.pluginManagers.Shutdown()
-
 	// Stop renewing tokens and secrets
 	if c.vaultClient != nil {
 		c.vaultClient.Stop()
@@ -649,6 +646,9 @@ func (c *Client) Shutdown() error {
 	}
 	arGroup.Wait()

+	// Shutdown the plugin managers
+	c.pluginManagers.Shutdown()
+
 	c.shutdown = true
 	close(c.shutdownCh)
diff --git a/client/devicemanager/instance.go b/client/devicemanager/instance.go
index f834062e1..0837abf10 100644
--- a/client/devicemanager/instance.go
+++ b/client/devicemanager/instance.go
@@ -10,6 +10,7 @@ import (
 	multierror "github.com/hashicorp/go-multierror"
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/base"
+	bstructs "github.com/hashicorp/nomad/plugins/base/structs"
 	"github.com/hashicorp/nomad/plugins/device"
 	"github.com/hashicorp/nomad/plugins/shared/loader"
 	"github.com/hashicorp/nomad/plugins/shared/singleton"
@@ -363,7 +364,7 @@ START:

 	// Handle any errors
 	if fresp.Error != nil {
-		if fresp.Error == base.ErrPluginShutdown {
+		if fresp.Error == bstructs.ErrPluginShutdown {
 			i.logger.Error("plugin exited unexpectedly")
 			goto START
 		}
@@ -488,7 +489,7 @@ START:

 	// Handle any errors
 	if sresp.Error != nil {
-		if sresp.Error == base.ErrPluginShutdown {
+		if sresp.Error == bstructs.ErrPluginShutdown {
 			i.logger.Error("plugin exited unexpectedly")
 			goto START
 		}
diff --git a/client/pluginmanager/drivermanager/instance.go b/client/pluginmanager/drivermanager/instance.go
index 47866b122..fe42c196e 100644
--- a/client/pluginmanager/drivermanager/instance.go
+++ b/client/pluginmanager/drivermanager/instance.go
@@ -9,6 +9,7 @@ import (
 	log "github.com/hashicorp/go-hclog"
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/base"
+	bstructs "github.com/hashicorp/nomad/plugins/base/structs"
 	"github.com/hashicorp/nomad/plugins/drivers"
 	"github.com/hashicorp/nomad/plugins/shared/loader"
 	"github.com/hashicorp/nomad/plugins/shared/singleton"
@@ -448,7 +449,7 @@ func (i *instanceManager) handleEvents() {

 // handleEvent looks up the event handler(s) for the event and runs them
 func (i *instanceManager) handleEvent(ev *drivers.TaskEvent) {
 	// Do not emit that the plugin is shutdown
-	if ev.Err != nil && ev.Err == base.ErrPluginShutdown {
+	if ev.Err != nil && ev.Err == bstructs.ErrPluginShutdown {
 		return
 	}
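Note: the wait/recover/retry control flow added to the task runner above is easier to see in isolation. Below is a minimal, runnable sketch of the same shape, using stand-in types and a fake waitOnce in place of the real TaskRunner and driver handle; only the goto-based retry loop mirrors the change itself.

package main

import (
	"errors"
	"fmt"
)

// errPluginShutdown stands in for bstructs.ErrPluginShutdown.
var errPluginShutdown = errors.New("plugin is shut down")

// waitOnce simulates one Wait() round; the first call reports a plugin crash.
func waitOnce(calls *int) error {
	*calls++
	if *calls == 1 {
		return errPluginShutdown
	}
	return nil // task exited normally
}

// handleExit mirrors handleTaskExitResult: it returns true when the caller
// should re-enter the wait after recovering the driver handle.
func handleExit(err error) (retryWait bool) {
	if err == errPluginShutdown {
		// The real code re-initializes the driver and restores the
		// task handle here before asking the caller to wait again.
		fmt.Println("driver plugin shutdown; recovered task, waiting again")
		return true
	}
	return false
}

func main() {
	calls := 0
WAIT:
	err := waitOnce(&calls)
	if handleExit(err) {
		goto WAIT
	}
	fmt.Println("task exited after", calls, "wait(s)")
}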
diff --git a/drivers/exec/driver_test.go b/drivers/exec/driver_test.go
index eee5c426e..dec127da4 100644
--- a/drivers/exec/driver_test.go
+++ b/drivers/exec/driver_test.go
@@ -21,8 +21,8 @@ import (
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/drivers"
 	dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils"
-	"github.com/hashicorp/nomad/plugins/shared"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
+	"github.com/hashicorp/nomad/plugins/shared/hclutils"
 	"github.com/hashicorp/nomad/testutil"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/sys/unix"
@@ -606,11 +606,11 @@ touch: cannot touch '/tmp/task-path-ro/testfile-from-ro': Read-only file system`

 func encodeDriverHelper(require *require.Assertions, task *drivers.TaskConfig, taskConfig map[string]interface{}) {
 	evalCtx := &hcl.EvalContext{
-		Functions: shared.GetStdlibFuncs(),
+		Functions: hclutils.GetStdlibFuncs(),
 	}
 	spec, diag := hclspec.Convert(taskConfigSpec)
 	require.False(diag.HasErrors())
-	taskConfigCtyVal, diag := shared.ParseHclInterface(taskConfig, spec, evalCtx)
+	taskConfigCtyVal, diag := hclutils.ParseHclInterface(taskConfig, spec, evalCtx)
 	require.False(diag.HasErrors())
 	err := task.EncodeDriverConfig(taskConfigCtyVal)
 	require.Nil(err)
diff --git a/drivers/java/driver_test.go b/drivers/java/driver_test.go
index 89966c571..ce0ab2221 100644
--- a/drivers/java/driver_test.go
+++ b/drivers/java/driver_test.go
@@ -19,8 +19,8 @@ import (
 	"github.com/hashicorp/nomad/helper/uuid"
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/drivers"
-	"github.com/hashicorp/nomad/plugins/shared"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
+	"github.com/hashicorp/nomad/plugins/shared/hclutils"
 	"github.com/hashicorp/nomad/testutil"
 	"github.com/stretchr/testify/require"
 )
@@ -272,11 +272,11 @@ func encodeDriverHelper(t *testing.T, task *drivers.TaskConfig, taskConfig map[s
 	t.Helper()

 	evalCtx := &hcl.EvalContext{
-		Functions: shared.GetStdlibFuncs(),
+		Functions: hclutils.GetStdlibFuncs(),
 	}
 	spec, diag := hclspec.Convert(taskConfigSpec)
 	require.False(t, diag.HasErrors())
-	taskConfigCtyVal, diag := shared.ParseHclInterface(taskConfig, spec, evalCtx)
+	taskConfigCtyVal, diag := hclutils.ParseHclInterface(taskConfig, spec, evalCtx)
 	require.Empty(t, diag.Errs())
 	err := task.EncodeDriverConfig(taskConfigCtyVal)
 	require.Nil(t, err)
diff --git a/drivers/lxc/driver_test.go b/drivers/lxc/driver_test.go
index 8a79510a1..313c9e480 100644
--- a/drivers/lxc/driver_test.go
+++ b/drivers/lxc/driver_test.go
@@ -19,8 +19,8 @@ import (
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/drivers"
 	dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils"
-	"github.com/hashicorp/nomad/plugins/shared"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
+	"github.com/hashicorp/nomad/plugins/shared/hclutils"
 	"github.com/hashicorp/nomad/testutil"
 	"github.com/stretchr/testify/require"
 	lxc "gopkg.in/lxc/go-lxc.v2"
@@ -269,11 +269,11 @@ func requireLXC(t *testing.T) {

 func encodeDriverHelper(require *require.Assertions, task *drivers.TaskConfig, taskConfig map[string]interface{}) {
 	evalCtx := &hcl.EvalContext{
-		Functions: shared.GetStdlibFuncs(),
+		Functions: hclutils.GetStdlibFuncs(),
 	}
 	spec, diag := hclspec.Convert(taskConfigSpec)
 	require.False(diag.HasErrors())
-	taskConfigCtyVal, diag := shared.ParseHclInterface(taskConfig, spec, evalCtx)
+	taskConfigCtyVal, diag := hclutils.ParseHclInterface(taskConfig, spec, evalCtx)
 	require.False(diag.HasErrors())
 	err := task.EncodeDriverConfig(taskConfigCtyVal)
 	require.Nil(err)
diff --git a/drivers/qemu/driver_test.go b/drivers/qemu/driver_test.go
index 53901f0b7..d0e204c52 100644
--- a/drivers/qemu/driver_test.go
+++ b/drivers/qemu/driver_test.go
@@ -17,8 +17,8 @@ import (
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/drivers"
 	dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils"
-	"github.com/hashicorp/nomad/plugins/shared"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
+	"github.com/hashicorp/nomad/plugins/shared/hclutils"
 	pstructs "github.com/hashicorp/nomad/plugins/shared/structs"
 	"github.com/hashicorp/nomad/testutil"
 	"github.com/stretchr/testify/require"
@@ -203,11 +203,11 @@ func TestQemuDriver_GetMonitorPathNewQemu(t *testing.T) {

 //encodeDriverhelper sets up the task config spec and encodes qemu specific driver configuration
 func encodeDriverHelper(require *require.Assertions, task *drivers.TaskConfig, taskConfig map[string]interface{}) {
 	evalCtx := &hcl.EvalContext{
-		Functions: shared.GetStdlibFuncs(),
+		Functions: hclutils.GetStdlibFuncs(),
 	}
 	spec, diag := hclspec.Convert(taskConfigSpec)
 	require.False(diag.HasErrors(), diag.Error())
-	taskConfigCtyVal, diag := shared.ParseHclInterface(taskConfig, spec, evalCtx)
+	taskConfigCtyVal, diag := hclutils.ParseHclInterface(taskConfig, spec, evalCtx)
 	require.False(diag.HasErrors(), diag.Error())
 	err := task.EncodeDriverConfig(taskConfigCtyVal)
 	require.Nil(err)
diff --git a/drivers/rawexec/driver_test.go b/drivers/rawexec/driver_test.go
index b65becafa..186475747 100644
--- a/drivers/rawexec/driver_test.go
+++ b/drivers/rawexec/driver_test.go
@@ -20,8 +20,8 @@ import (
 	basePlug "github.com/hashicorp/nomad/plugins/base"
 	"github.com/hashicorp/nomad/plugins/drivers"
 	dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils"
-	"github.com/hashicorp/nomad/plugins/shared"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
+	"github.com/hashicorp/nomad/plugins/shared/hclutils"
 	pstructs "github.com/hashicorp/nomad/plugins/shared/structs"
 	"github.com/hashicorp/nomad/testutil"
 	"github.com/stretchr/testify/require"
@@ -500,11 +500,11 @@ func TestRawExecDriver_Exec(t *testing.T) {

 func encodeDriverHelper(require *require.Assertions, task *drivers.TaskConfig, taskConfig map[string]interface{}) {
 	evalCtx := &hcl.EvalContext{
-		Functions: shared.GetStdlibFuncs(),
+		Functions: hclutils.GetStdlibFuncs(),
 	}
 	spec, diag := hclspec.Convert(taskConfigSpec)
 	require.False(diag.HasErrors())
-	taskConfigCtyVal, diag := shared.ParseHclInterface(taskConfig, spec, evalCtx)
+	taskConfigCtyVal, diag := hclutils.ParseHclInterface(taskConfig, spec, evalCtx)
 	require.False(diag.HasErrors())
 	err := task.EncodeDriverConfig(taskConfigCtyVal)
 	require.Nil(err)
diff --git a/drivers/rkt/driver_test.go b/drivers/rkt/driver_test.go
index 95a75e371..2df70aa97 100644
--- a/drivers/rkt/driver_test.go
+++ b/drivers/rkt/driver_test.go
@@ -22,8 +22,8 @@ import (
 	basePlug "github.com/hashicorp/nomad/plugins/base"
 	"github.com/hashicorp/nomad/plugins/drivers"
 	dtestutil "github.com/hashicorp/nomad/plugins/drivers/testutils"
-	"github.com/hashicorp/nomad/plugins/shared"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
+	"github.com/hashicorp/nomad/plugins/shared/hclutils"
 	"github.com/hashicorp/nomad/testutil"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/sys/unix"
@@ -874,11 +874,11 @@ func TestRktDriver_Stats(t *testing.T) {

 func encodeDriverHelper(require *require.Assertions, task *drivers.TaskConfig, taskConfig map[string]interface{}) {
 	evalCtx := &hcl.EvalContext{
-		Functions: shared.GetStdlibFuncs(),
+		Functions: hclutils.GetStdlibFuncs(),
 	}
 	spec, diag := hclspec.Convert(taskConfigSpec)
 	require.False(diag.HasErrors())
-	taskConfigCtyVal, diag := shared.ParseHclInterface(taskConfig, spec, evalCtx)
+	taskConfigCtyVal, diag := hclutils.ParseHclInterface(taskConfig, spec, evalCtx)
 	if diag.HasErrors() {
 		fmt.Println("conversion error", diag.Error())
 	}
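Note: the six driver test diffs above make the same two-line substitution in near-identical encodeDriverHelper copies. For reference, a consolidated sketch of that helper — a hypothetical shared form, not part of this changeset; the hcl import alias is inferred from elsewhere in this diff:

package testutils // hypothetical shared location for the repeated helper

import (
	"testing"

	hcl "github.com/hashicorp/hcl2/hcl"
	"github.com/hashicorp/nomad/plugins/drivers"
	"github.com/hashicorp/nomad/plugins/shared/hclspec"
	"github.com/hashicorp/nomad/plugins/shared/hclutils"
	"github.com/stretchr/testify/require"
)

// EncodeDriverConfig converts a map-based task config into the driver's
// encoded representation, mirroring the per-driver helpers changed above.
func EncodeDriverConfig(t *testing.T, taskConfigSpec *hclspec.Spec, task *drivers.TaskConfig, taskConfig map[string]interface{}) {
	t.Helper()

	// The stdlib functions moved from plugins/shared to plugins/shared/hclutils.
	evalCtx := &hcl.EvalContext{
		Functions: hclutils.GetStdlibFuncs(),
	}

	// Convert the driver's HCL spec, then parse the raw map against it.
	spec, diag := hclspec.Convert(taskConfigSpec)
	require.False(t, diag.HasErrors())

	taskConfigCtyVal, diag := hclutils.ParseHclInterface(taskConfig, spec, evalCtx)
	require.False(t, diag.HasErrors())

	// Encode the resulting cty value into the TaskConfig for the driver.
	require.NoError(t, task.EncodeDriverConfig(taskConfigCtyVal))
}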
diff --git a/plugins/base/client.go b/plugins/base/client.go
index 80bc7ef4b..9dadaeed4 100644
--- a/plugins/base/client.go
+++ b/plugins/base/client.go
@@ -5,6 +5,7 @@ import (
 	"fmt"

 	"github.com/hashicorp/nomad/plugins/base/proto"
+	"github.com/hashicorp/nomad/plugins/shared/grpcutils"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
 )

@@ -20,7 +21,7 @@ type BasePluginClient struct {
 func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
 	presp, err := b.Client.PluginInfo(b.DoneCtx, &proto.PluginInfoRequest{})
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleGrpcErr(err, b.DoneCtx)
 	}

 	var ptype string
@@ -46,7 +47,7 @@ func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
 func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
 	presp, err := b.Client.ConfigSchema(b.DoneCtx, &proto.ConfigSchemaRequest{})
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleGrpcErr(err, b.DoneCtx)
 	}

 	return presp.GetSpec(), nil
@@ -60,5 +61,5 @@ func (b *BasePluginClient) SetConfig(c *Config) error {
 		PluginApiVersion: c.ApiVersion,
 	})

-	return err
+	return grpcutils.HandleGrpcErr(err, b.DoneCtx)
 }
diff --git a/plugins/base/plugin.go b/plugins/base/plugin.go
index 411c79662..f511a3d45 100644
--- a/plugins/base/plugin.go
+++ b/plugins/base/plugin.go
@@ -3,7 +3,6 @@ package base
 import (
 	"bytes"
 	"context"
-	"errors"
 	"reflect"

 	plugin "github.com/hashicorp/go-plugin"
@@ -30,9 +29,6 @@ var (
 		MagicCookieKey:   "NOMAD_PLUGIN_MAGIC_COOKIE",
 		MagicCookieValue: "e4327c2e01eabfd75a8a67adb114fb34a757d57eee7728d857a8cec6e91a7255",
 	}
-
-	// ErrPluginShutdown is returned when the plugin has shutdown.
-	ErrPluginShutdown = errors.New("plugin is shut down")
 )

 // PluginBase is wraps a BasePlugin and implements go-plugins GRPCPlugin
diff --git a/plugins/base/structs/errors.go b/plugins/base/structs/errors.go
new file mode 100644
index 000000000..0a5a7a6d6
--- /dev/null
+++ b/plugins/base/structs/errors.go
@@ -0,0 +1,12 @@
+package structs
+
+import "errors"
+
+const (
+	errPluginShutdown = "plugin is shut down"
+)
+
+var (
+	// ErrPluginShutdown is returned when the plugin has shutdown.
+	ErrPluginShutdown = errors.New(errPluginShutdown)
+)
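Note: callers such as the stats hook and the instance managers compare against ErrPluginShutdown with ==, so it matters that this is a single shared sentinel value rather than a matching message string. A small runnable illustration of that property, using a local stand-in for the sentinel:

package main

import (
	"errors"
	"fmt"
)

// ErrPluginShutdown is a local stand-in for bstructs.ErrPluginShutdown.
var ErrPluginShutdown = errors.New("plugin is shut down")

func main() {
	// A freshly constructed error with the same message is NOT equal to
	// the sentinel: == compares identity, not text. This is why the gRPC
	// error mapping later in this diff must return the shared value itself.
	lookalike := errors.New("plugin is shut down")
	fmt.Println(lookalike == ErrPluginShutdown)         // false
	fmt.Println(ErrPluginShutdown == ErrPluginShutdown) // true
}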
diff --git a/plugins/device/client.go b/plugins/device/client.go
index ffbb80166..4dc187453 100644
--- a/plugins/device/client.go
+++ b/plugins/device/client.go
@@ -9,7 +9,7 @@ import (
 	"github.com/golang/protobuf/ptypes"
 	"github.com/hashicorp/nomad/plugins/base"
 	"github.com/hashicorp/nomad/plugins/device/proto"
-	"github.com/hashicorp/nomad/plugins/shared"
+	"github.com/hashicorp/nomad/plugins/shared/grpcutils"
 )

 // devicePluginClient implements the client side of a remote device plugin, using
@@ -30,12 +30,12 @@ type devicePluginClient struct {
 // cancelled, the error will be propogated.
 func (d *devicePluginClient) Fingerprint(ctx context.Context) (<-chan *FingerprintResponse, error) {
 	// Join the passed context and the shutdown context
-	ctx, _ = joincontext.Join(ctx, d.doneCtx)
+	joinedCtx, _ := joincontext.Join(ctx, d.doneCtx)

 	var req proto.FingerprintRequest
-	stream, err := d.client.Fingerprint(ctx, &req)
+	stream, err := d.client.Fingerprint(joinedCtx, &req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleReqCtxGrpcErr(err, ctx, d.doneCtx)
 	}

 	out := make(chan *FingerprintResponse, 1)
@@ -47,7 +47,7 @@ func (d *devicePluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri
 // the gRPC stream to a channel. Exits either when context is cancelled or the
 // stream has an error.
 func (d *devicePluginClient) handleFingerprint(
-	ctx context.Context,
+	reqCtx context.Context,
 	stream proto.DevicePlugin_FingerprintClient,
 	out chan *FingerprintResponse) {

@@ -57,7 +57,7 @@ func (d *devicePluginClient) handleFingerprint(
 		if err != nil {
 			if err != io.EOF {
 				out <- &FingerprintResponse{
-					Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
+					Error: grpcutils.HandleReqCtxGrpcErr(err, reqCtx, d.doneCtx),
 				}
 			}

@@ -70,7 +70,7 @@ func (d *devicePluginClient) handleFingerprint(
 			Devices: convertProtoDeviceGroups(resp.GetDeviceGroup()),
 		}
 		select {
-		case <-ctx.Done():
+		case <-reqCtx.Done():
 			return
 		case out <- f:
 		}
@@ -86,7 +86,7 @@ func (d *devicePluginClient) Reserve(deviceIDs []string) (*ContainerReservation,
 	// Make the request
 	resp, err := d.client.Reserve(d.doneCtx, req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
 	}

 	// Convert the response
@@ -100,14 +100,14 @@ func (d *devicePluginClient) Reserve(deviceIDs []string) (*ContainerReservation,
 // propogated.
 func (d *devicePluginClient) Stats(ctx context.Context, interval time.Duration) (<-chan *StatsResponse, error) {
 	// Join the passed context and the shutdown context
-	ctx, _ = joincontext.Join(ctx, d.doneCtx)
+	joinedCtx, _ := joincontext.Join(ctx, d.doneCtx)

 	req := proto.StatsRequest{
 		CollectionInterval: ptypes.DurationProto(interval),
 	}
-	stream, err := d.client.Stats(ctx, &req)
+	stream, err := d.client.Stats(joinedCtx, &req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleReqCtxGrpcErr(err, ctx, d.doneCtx)
 	}

 	out := make(chan *StatsResponse, 1)
@@ -119,7 +119,7 @@ func (d *devicePluginClient) Stats(ctx context.Context, interval time.Duration)
 // the gRPC stream to a channel. Exits either when context is cancelled or the
 // stream has an error.
 func (d *devicePluginClient) handleStats(
-	ctx context.Context,
+	reqCtx context.Context,
 	stream proto.DevicePlugin_StatsClient,
 	out chan *StatsResponse) {

@@ -129,7 +129,7 @@ func (d *devicePluginClient) handleStats(
 		if err != nil {
 			if err != io.EOF {
 				out <- &StatsResponse{
-					Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
+					Error: grpcutils.HandleReqCtxGrpcErr(err, reqCtx, d.doneCtx),
 				}
 			}

@@ -142,7 +142,7 @@ func (d *devicePluginClient) handleStats(
 			Groups: convertProtoDeviceGroupsStats(resp.GetGroups()),
 		}
 		select {
-		case <-ctx.Done():
+		case <-reqCtx.Done():
 			return
 		case out <- s:
 		}
diff --git a/plugins/drivers/client.go b/plugins/drivers/client.go
index 2b01ffa93..98adc8632 100644
--- a/plugins/drivers/client.go
+++ b/plugins/drivers/client.go
@@ -12,7 +12,7 @@ import (
 	"github.com/hashicorp/nomad/nomad/structs"
 	"github.com/hashicorp/nomad/plugins/base"
 	"github.com/hashicorp/nomad/plugins/drivers/proto"
-	"github.com/hashicorp/nomad/plugins/shared"
+	"github.com/hashicorp/nomad/plugins/shared/grpcutils"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
 	pstructs "github.com/hashicorp/nomad/plugins/shared/structs"
 	sproto "github.com/hashicorp/nomad/plugins/shared/structs/proto"
@@ -35,7 +35,7 @@ func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {

 	resp, err := d.client.TaskConfigSchema(d.doneCtx, req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
 	}

 	return resp.Spec, nil
@@ -46,7 +46,7 @@ func (d *driverPluginClient) Capabilities() (*Capabilities, error) {

 	resp, err := d.client.Capabilities(d.doneCtx, req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
 	}

 	caps := &Capabilities{}
@@ -74,11 +74,11 @@ func (d *driverPluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri
 	req := &proto.FingerprintRequest{}

 	// Join the passed context and the shutdown context
-	ctx, _ = joincontext.Join(ctx, d.doneCtx)
+	joinedCtx, _ := joincontext.Join(ctx, d.doneCtx)

-	stream, err := d.client.Fingerprint(ctx, req)
+	stream, err := d.client.Fingerprint(joinedCtx, req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleReqCtxGrpcErr(err, ctx, d.doneCtx)
 	}

 	ch := make(chan *Fingerprint, 1)
@@ -87,14 +87,14 @@ func (d *driverPluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri
 	return ch, nil
 }

-func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fingerprint, stream proto.Driver_FingerprintClient) {
+func (d *driverPluginClient) handleFingerprint(reqCtx context.Context, ch chan *Fingerprint, stream proto.Driver_FingerprintClient) {
 	defer close(ch)
 	for {
 		pb, err := stream.Recv()
 		if err != nil {
 			if err != io.EOF {
 				ch <- &Fingerprint{
-					Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
+					Err: grpcutils.HandleReqCtxGrpcErr(err, reqCtx, d.doneCtx),
 				}
 			}

@@ -109,7 +109,7 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin
 		}

 		select {
-		case <-ctx.Done():
+		case <-reqCtx.Done():
 			return
 		case ch <- f:
 		}
@@ -122,7 +122,7 @@ func (d *driverPluginClient) RecoverTask(h *TaskHandle) error {
 	req := &proto.RecoverTaskRequest{Handle: taskHandleToProto(h)}

 	_, err := d.client.RecoverTask(d.doneCtx, req)
-	return err
+	return grpcutils.HandleGrpcErr(err, d.doneCtx)
 }

 // StartTask starts execution of a task with the given TaskConfig. A TaskHandle
@@ -141,7 +141,7 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
 			return nil, nil, structs.NewRecoverableError(err, rec.Recoverable)
 		}
 	}
-		return nil, nil, err
+		return nil, nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
 	}

 	var net *cstructs.DriverNetwork
@@ -165,10 +165,6 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
 // the same task without issue.
 func (d *driverPluginClient) WaitTask(ctx context.Context, id string) (<-chan *ExitResult, error) {
 	ch := make(chan *ExitResult)
-
-	// Join the passed context and the shutdown context
-	ctx, _ = joincontext.Join(ctx, d.doneCtx)
-
 	go d.handleWaitTask(ctx, id, ch)
 	return ch, nil
 }
@@ -180,9 +176,12 @@ func (d *driverPluginClient) handleWaitTask(ctx context.Context, id string, ch c
 		TaskId: id,
 	}

-	resp, err := d.client.WaitTask(ctx, req)
+	// Join the passed context and the shutdown context
+	joinedCtx, _ := joincontext.Join(ctx, d.doneCtx)
+
+	resp, err := d.client.WaitTask(joinedCtx, req)
 	if err != nil {
-		result.Err = err
+		result.Err = grpcutils.HandleReqCtxGrpcErr(err, ctx, d.doneCtx)
 	} else {
 		result.ExitCode = int(resp.Result.ExitCode)
 		result.Signal = int(resp.Result.Signal)
@@ -206,7 +205,7 @@ func (d *driverPluginClient) StopTask(taskID string, timeout time.Duration, sign
 	}

 	_, err := d.client.StopTask(d.doneCtx, req)
-	return err
+	return grpcutils.HandleGrpcErr(err, d.doneCtx)
 }

 // DestroyTask removes the task from the driver's in memory state. The task
@@ -219,7 +218,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
 	}

 	_, err := d.client.DestroyTask(d.doneCtx, req)
-	return err
+	return grpcutils.HandleGrpcErr(err, d.doneCtx)
 }

 // InspectTask returns status information for a task
@@ -228,7 +227,7 @@ func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {

 	resp, err := d.client.InspectTask(d.doneCtx, req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
 	}

 	status, err := taskStatusFromProto(resp.Task)
@@ -259,7 +258,7 @@ func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsa

 	resp, err := d.client.TaskStats(d.doneCtx, req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
 	}

 	stats, err := TaskStatsFromProto(resp.Stats)
@@ -276,11 +275,11 @@ func (d *driverPluginClient) TaskEvents(ctx context.Context) (<-chan *TaskEvent,
 	req := &proto.TaskEventsRequest{}

 	// Join the passed context and the shutdown context
-	ctx, _ = joincontext.Join(ctx, d.doneCtx)
+	joinedCtx, _ := joincontext.Join(ctx, d.doneCtx)

-	stream, err := d.client.TaskEvents(ctx, req)
+	stream, err := d.client.TaskEvents(joinedCtx, req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleReqCtxGrpcErr(err, ctx, d.doneCtx)
 	}

 	ch := make(chan *TaskEvent, 1)
@@ -288,14 +287,14 @@ func (d *driverPluginClient) TaskEvents(ctx context.Context) (<-chan *TaskEvent,
 	return ch, nil
 }

-func (d *driverPluginClient) handleTaskEvents(ctx context.Context, ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
+func (d *driverPluginClient) handleTaskEvents(reqCtx context.Context, ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
 	defer close(ch)
 	for {
 		ev, err := stream.Recv()
 		if err != nil {
 			if err != io.EOF {
 				ch <- &TaskEvent{
-					Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
+					Err: grpcutils.HandleReqCtxGrpcErr(err, reqCtx, d.doneCtx),
 				}
 			}

@@ -313,7 +312,7 @@ func (d *driverPluginClient) handleTaskEvents(ctx context.Context, ch chan *Task
 			Timestamp: timestamp,
 		}
 		select {
-		case <-ctx.Done():
+		case <-reqCtx.Done():
 			return
 		case ch <- event:
 		}
@@ -327,7 +326,7 @@ func (d *driverPluginClient) SignalTask(taskID string, signal string) error {
 		Signal: signal,
 	}
 	_, err := d.client.SignalTask(d.doneCtx, req)
-	return err
+	return grpcutils.HandleGrpcErr(err, d.doneCtx)
 }

 // ExecTask will run the given command within the execution context of the task.
@@ -343,7 +342,7 @@ func (d *driverPluginClient) ExecTask(taskID string, cmd []string, timeout time.

 	resp, err := d.client.ExecTask(d.doneCtx, req)
 	if err != nil {
-		return nil, err
+		return nil, grpcutils.HandleGrpcErr(err, d.doneCtx)
 	}

 	result := &ExecTaskResult{
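Note: a pattern repeated through both client files above is that the caller's context and the plugin's done context are joined only for issuing the RPC, while the two originals are kept separate so the error handler can distinguish a cancelled request from a dead plugin. A condensed, runnable sketch of that split — classify stands in for grpcutils.HandleReqCtxGrpcErr, and the joincontext call mirrors its use in the diff (the second return value is discarded, as there):

package main

import (
	"context"
	"errors"
	"fmt"

	"github.com/LK4D4/joincontext"
)

var errPluginShutdown = errors.New("plugin is shut down")

// classify inspects the two source contexts rather than the joined one,
// which is exactly why the joined context is never passed to it.
func classify(err error, reqCtx, pluginCtx context.Context) error {
	if err == nil {
		return nil
	}
	select {
	case <-pluginCtx.Done():
		return errPluginShutdown // the plugin died
	default:
	}
	if reqCtx.Err() != nil {
		return context.Canceled // the caller gave up
	}
	return err
}

func main() {
	reqCtx := context.Background()
	pluginCtx, shutdown := context.WithCancel(context.Background())

	// The RPC sees a context that cancels when either source does.
	joinedCtx, _ := joincontext.Join(reqCtx, pluginCtx)

	shutdown() // simulate the plugin dying mid-call
	<-joinedCtx.Done()

	fmt.Println(classify(joinedCtx.Err(), reqCtx, pluginCtx)) // plugin is shut down
}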
diff --git a/plugins/shared/cmd/launcher/command/device.go b/plugins/shared/cmd/launcher/command/device.go
index 01855da7b..fc8b6b7af 100644
--- a/plugins/shared/cmd/launcher/command/device.go
+++ b/plugins/shared/cmd/launcher/command/device.go
@@ -18,8 +18,8 @@ import (
 	"github.com/hashicorp/hcl2/hcldec"
 	"github.com/hashicorp/nomad/plugins/base"
 	"github.com/hashicorp/nomad/plugins/device"
-	"github.com/hashicorp/nomad/plugins/shared"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
+	"github.com/hashicorp/nomad/plugins/shared/hclutils"
 	"github.com/kr/pretty"
 	"github.com/mitchellh/cli"
 	"github.com/zclconf/go-cty/cty/msgpack"
@@ -198,10 +198,10 @@ func (c *Device) setConfig(spec hcldec.Spec, apiVersion string, config []byte, n
 	c.logger.Trace("raw hcl config", "config", hclog.Fmt("% #v", pretty.Formatter(configVal)))

 	ctx := &hcl2.EvalContext{
-		Functions: shared.GetStdlibFuncs(),
+		Functions: hclutils.GetStdlibFuncs(),
 	}

-	val, diag := shared.ParseHclInterface(configVal, spec, ctx)
+	val, diag := hclutils.ParseHclInterface(configVal, spec, ctx)
 	if diag.HasErrors() {
 		errStr := "failed to parse config"
 		for _, err := range diag.Errs() {
diff --git a/plugins/shared/grpc_utils.go b/plugins/shared/grpc_utils.go
deleted file mode 100644
index 34fb33a87..000000000
--- a/plugins/shared/grpc_utils.go
+++ /dev/null
@@ -1,61 +0,0 @@
-package shared
-
-import (
-	"context"
-	"time"
-
-	"github.com/hashicorp/nomad/plugins/base"
-	"google.golang.org/grpc/codes"
-	"google.golang.org/grpc/status"
-)
-
-// HandleStreamErr is used to handle a non io.EOF error in a stream. It handles
-// detecting if the plugin has shutdown via the passeed pluginCtx. The
-// parameters are:
-// - err: the error returned from the streaming RPC
-// - reqCtx: the context passed to the streaming request
-// - pluginCtx: the plugins done ctx used to detect the plugin dying
-//
-// The return values are:
-// - base.ErrPluginShutdown if the error is because the plugin shutdown
-// - context.Canceled if the reqCtx is canceled
-// - The original error
-func HandleStreamErr(err error, reqCtx, pluginCtx context.Context) error {
-	if err == nil {
-		return nil
-	}
-
-	// Determine if the error is because the plugin shutdown
-	if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
-		// Potentially wait a little before returning an error so we can detect
-		// the exit
-		select {
-		case <-pluginCtx.Done():
-			err = base.ErrPluginShutdown
-		case <-reqCtx.Done():
-			err = reqCtx.Err()
-
-			// There is no guarantee that the select will choose the
-			// doneCtx first so we have to double check
-			select {
-			case <-pluginCtx.Done():
-				err = base.ErrPluginShutdown
-			default:
-			}
-		case <-time.After(3 * time.Second):
-			// Its okay to wait a while since the connection isn't available and
-			// on local host it is likely shutting down. It is not expected for
-			// this to ever reach even close to 3 seconds.
-		}
-
-		// It is an error we don't know how to handle, so return it
-		return err
-	}
-
-	// Context was cancelled
-	if errStatus := status.FromContextError(reqCtx.Err()); errStatus.Code() == codes.Canceled {
-		return context.Canceled
-	}
-
-	return err
-}
diff --git a/plugins/shared/grpcutils/utils.go b/plugins/shared/grpcutils/utils.go
new file mode 100644
index 000000000..001cf4ad3
--- /dev/null
+++ b/plugins/shared/grpcutils/utils.go
@@ -0,0 +1,105 @@
+package grpcutils
+
+import (
+	"context"
+	"time"
+
+	"github.com/hashicorp/nomad/plugins/base/structs"
+	"google.golang.org/grpc/codes"
+	"google.golang.org/grpc/status"
+)
+
+// HandleReqCtxGrpcErr is used to handle a non io.EOF error in a GRPC request
+// where a user supplied context is used. It handles detecting if the plugin
+// has shutdown via the passed pluginCtx. The parameters are:
+// - err: the error returned from the streaming RPC
+// - reqCtx: the user context passed to the request
+// - pluginCtx: the plugins done ctx used to detect the plugin dying
+//
+// The return values are:
+// - ErrPluginShutdown if the error is because the plugin shutdown
+// - context.Canceled if the reqCtx is canceled
+// - The original error
+func HandleReqCtxGrpcErr(err error, reqCtx, pluginCtx context.Context) error {
+	if err == nil {
+		return nil
+	}
+
+	// Determine if the error is because the plugin shutdown
+	if errStatus, ok := status.FromError(err); ok &&
+		(errStatus.Code() == codes.Unavailable || errStatus.Code() == codes.Canceled) {
+		// Potentially wait a little before returning an error so we can detect
+		// the exit
+		select {
+		case <-pluginCtx.Done():
+			err = structs.ErrPluginShutdown
+		case <-reqCtx.Done():
+			err = reqCtx.Err()
+
+			// There is no guarantee that the select will choose the
+			// doneCtx first so we have to double check
+			select {
+			case <-pluginCtx.Done():
+				err = structs.ErrPluginShutdown
+			default:
+			}
+		case <-time.After(3 * time.Second):
+			// It's okay to wait a while since the connection isn't available and
+			// on local host it is likely shutting down. It is not expected for
+			// this to ever reach even close to 3 seconds.
+		}
+
+		// It is an error we don't know how to handle, so return it
+		return err
+	}
+
+	// Context was cancelled
+	if errStatus := status.FromContextError(reqCtx.Err()); errStatus.Code() == codes.Canceled {
+		return context.Canceled
+	}
+
+	return err
+}
+
+// HandleGrpcErr is used to handle errors made to a remote gRPC plugin. It
+// handles detecting if the plugin has shutdown via the passed pluginCtx. The
+// parameters are:
+// - err: the error returned from the streaming RPC
+// - pluginCtx: the plugins done ctx used to detect the plugin dying
+//
+// The return values are:
+// - ErrPluginShutdown if the error is because the plugin shutdown
+// - The original error
+func HandleGrpcErr(err error, pluginCtx context.Context) error {
+	if err == nil {
+		return nil
+	}
+
+	if errStatus := status.FromContextError(pluginCtx.Err()); errStatus.Code() == codes.Canceled {
+		// See if the plugin shutdown
+		select {
+		case <-pluginCtx.Done():
+			err = structs.ErrPluginShutdown
+		default:
+		}
+	}
+
+	// Determine if the error is because the plugin shutdown
+	if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
+		// Potentially wait a little before returning an error so we can detect
+		// the exit
+		select {
+		case <-pluginCtx.Done():
+			err = structs.ErrPluginShutdown
+		case <-time.After(3 * time.Second):
+			// It's okay to wait a while since the connection isn't available and
+			// on local host it is likely shutting down. It is not expected for
+			// this to ever reach even close to 3 seconds.
+		}
+
+		// It is an error we don't know how to handle, so return it
+		return err
+	}
+
+	return err
+}
diff --git a/plugins/shared/util.go b/plugins/shared/hclutils/util.go
similarity index 99%
rename from plugins/shared/util.go
rename to plugins/shared/hclutils/util.go
index 9152915b4..86a8d2e6c 100644
--- a/plugins/shared/util.go
+++ b/plugins/shared/hclutils/util.go
@@ -1,4 +1,4 @@
-package shared
+package hclutils

 import (
 	"bytes"
diff --git a/plugins/shared/util_test.go b/plugins/shared/hclutils/util_test.go
similarity index 99%
rename from plugins/shared/util_test.go
rename to plugins/shared/hclutils/util_test.go
index 3dc2488ce..bfbb7c0a6 100644
--- a/plugins/shared/util_test.go
+++ b/plugins/shared/hclutils/util_test.go
@@ -1,4 +1,4 @@
-package shared
+package hclutils

 import (
 	"testing"
diff --git a/plugins/shared/loader/init.go b/plugins/shared/loader/init.go
index 89f09198c..7af5c8e53 100644
--- a/plugins/shared/loader/init.go
+++ b/plugins/shared/loader/init.go
@@ -13,8 +13,8 @@ import (
 	hcl2 "github.com/hashicorp/hcl2/hcl"
 	"github.com/hashicorp/nomad/nomad/structs/config"
 	"github.com/hashicorp/nomad/plugins/base"
-	"github.com/hashicorp/nomad/plugins/shared"
 	"github.com/hashicorp/nomad/plugins/shared/hclspec"
+	"github.com/hashicorp/nomad/plugins/shared/hclutils"
 	"github.com/zclconf/go-cty/cty/msgpack"
 )

@@ -22,7 +22,7 @@ var (
 	// configParseCtx is the context used to parse a plugin's configuration
 	// stanza
 	configParseCtx = &hcl2.EvalContext{
-		Functions: shared.GetStdlibFuncs(),
+		Functions: hclutils.GetStdlibFuncs(),
 	}
 )

@@ -467,7 +467,7 @@ func (l *PluginLoader) validePluginConfig(id PluginID, info *pluginInfo) error {
 	}

 	// Parse the config using the spec
-	val, diag := shared.ParseHclInterface(info.config, spec, configParseCtx)
+	val, diag := hclutils.ParseHclInterface(info.config, spec, configParseCtx)
 	if diag.HasErrors() {
 		multierror.Append(&mErr, diag.Errs()...)
 		return multierror.Prefix(&mErr, "failed parsing config:")