From 9d42f4d0398229db2907506c38e251dd5da48e1f Mon Sep 17 00:00:00 2001 From: Alex Dadgar Date: Mon, 12 Nov 2018 17:09:27 -0800 Subject: [PATCH] Plugin client's handle plugin dying This PR plumbs the plugins done ctx through the base and driver plugin clients (device already had it). Further, it adds generic handling of gRPC stream errors. --- drivers/exec/driver.go | 2 +- drivers/java/driver.go | 2 +- drivers/mock/driver.go | 7 ++- drivers/qemu/driver.go | 2 +- drivers/rawexec/driver.go | 2 +- drivers/rkt/driver.go | 2 +- drivers/rkt/driver_test.go | 8 ++- drivers/shared/eventer/eventer.go | 2 +- plugins/base/client.go | 9 ++-- plugins/base/plugin.go | 5 +- plugins/device/client.go | 66 ++++++------------------- plugins/device/plugin.go | 3 +- plugins/drivers/client.go | 82 ++++++++++++++++++++----------- plugins/drivers/driver.go | 2 +- plugins/drivers/plugin.go | 8 +-- plugins/drivers/plugin_test.go | 2 +- plugins/drivers/server.go | 3 +- plugins/drivers/testing.go | 7 ++- plugins/shared/grpc_utils.go | 61 +++++++++++++++++++++++ 19 files changed, 165 insertions(+), 110 deletions(-) create mode 100644 plugins/shared/grpc_utils.go diff --git a/drivers/exec/driver.go b/drivers/exec/driver.go index d0dde400b..a5ccbc94e 100644 --- a/drivers/exec/driver.go +++ b/drivers/exec/driver.go @@ -1,6 +1,7 @@ package exec import ( + "context" "fmt" "os" "path/filepath" @@ -20,7 +21,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - "golang.org/x/net/context" ) const ( diff --git a/drivers/java/driver.go b/drivers/java/driver.go index dce0467f8..5a68b5ef9 100644 --- a/drivers/java/driver.go +++ b/drivers/java/driver.go @@ -1,6 +1,7 @@ package java import ( + "context" "fmt" "os" "os/exec" @@ -23,7 +24,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - "golang.org/x/net/context" ) const ( diff --git a/drivers/mock/driver.go b/drivers/mock/driver.go index 734ae61f8..6fb2d1b52 100644 --- a/drivers/mock/driver.go +++ b/drivers/mock/driver.go @@ -16,7 +16,6 @@ import ( "github.com/hashicorp/nomad/plugins/drivers" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - netctx "golang.org/x/net/context" ) const ( @@ -232,7 +231,7 @@ func (d *Driver) Capabilities() (*drivers.Capabilities, error) { return capabilities, nil } -func (d *Driver) Fingerprint(ctx netctx.Context) (<-chan *drivers.Fingerprint, error) { +func (d *Driver) Fingerprint(ctx context.Context) (<-chan *drivers.Fingerprint, error) { ch := make(chan *drivers.Fingerprint) go d.handleFingerprint(ctx, ch) return ch, nil @@ -365,7 +364,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru } -func (d *Driver) WaitTask(ctx netctx.Context, taskID string) (<-chan *drivers.ExitResult, error) { +func (d *Driver) WaitTask(ctx context.Context, taskID string) (<-chan *drivers.ExitResult, error) { handle, ok := d.tasks.Get(taskID) if !ok { return nil, drivers.ErrTaskNotFound @@ -430,7 +429,7 @@ func (d *Driver) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) { return nil, nil } -func (d *Driver) TaskEvents(ctx netctx.Context) (<-chan *drivers.TaskEvent, error) { +func (d *Driver) TaskEvents(ctx context.Context) (<-chan *drivers.TaskEvent, error) { return d.eventer.TaskEvents(ctx) } diff --git a/drivers/qemu/driver.go b/drivers/qemu/driver.go index 363771515..c8a7ac50c 100644 --- a/drivers/qemu/driver.go +++ b/drivers/qemu/driver.go @@ -1,6 +1,7 @@ package qemu import ( + "context" "errors" "fmt" "net" @@ -25,7 +26,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - "golang.org/x/net/context" ) const ( diff --git a/drivers/rawexec/driver.go b/drivers/rawexec/driver.go index 018c9adfa..da5eb17ab 100644 --- a/drivers/rawexec/driver.go +++ b/drivers/rawexec/driver.go @@ -1,6 +1,7 @@ package rawexec import ( + "context" "fmt" "os" "path/filepath" @@ -22,7 +23,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" - "golang.org/x/net/context" ) const ( diff --git a/drivers/rkt/driver.go b/drivers/rkt/driver.go index 4ffa69f8d..74cac57b3 100644 --- a/drivers/rkt/driver.go +++ b/drivers/rkt/driver.go @@ -4,6 +4,7 @@ package rkt import ( "bytes" + "context" "encoding/json" "fmt" "io/ioutil" @@ -36,7 +37,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/plugins/shared/loader" rktv1 "github.com/rkt/rkt/api/v1" - "golang.org/x/net/context" ) const ( diff --git a/drivers/rkt/driver_test.go b/drivers/rkt/driver_test.go index 7847cb72a..0edfec535 100644 --- a/drivers/rkt/driver_test.go +++ b/drivers/rkt/driver_test.go @@ -3,17 +3,16 @@ package rkt import ( + "bytes" + "context" "fmt" "io/ioutil" + "os" "path/filepath" "sync" "testing" "time" - "os" - - "bytes" - "github.com/hashicorp/hcl2/hcl" ctestutil "github.com/hashicorp/nomad/client/testutil" "github.com/hashicorp/nomad/helper/testlog" @@ -26,7 +25,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/hashicorp/nomad/testutil" "github.com/stretchr/testify/require" - "golang.org/x/net/context" ) var _ drivers.DriverPlugin = (*Driver)(nil) diff --git a/drivers/shared/eventer/eventer.go b/drivers/shared/eventer/eventer.go index a68a20162..1e7674ee4 100644 --- a/drivers/shared/eventer/eventer.go +++ b/drivers/shared/eventer/eventer.go @@ -1,12 +1,12 @@ package eventer import ( + "context" "sync" "time" hclog "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/plugins/drivers" - "golang.org/x/net/context" ) var ( diff --git a/plugins/base/client.go b/plugins/base/client.go index f5476cef7..6baf9a07d 100644 --- a/plugins/base/client.go +++ b/plugins/base/client.go @@ -12,10 +12,13 @@ import ( // gRPC to communicate to the remote plugin. type BasePluginClient struct { Client proto.BasePluginClient + + // DoneCtx is closed when the plugin exits + DoneCtx context.Context } func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) { - presp, err := b.Client.PluginInfo(context.Background(), &proto.PluginInfoRequest{}) + presp, err := b.Client.PluginInfo(b.DoneCtx, &proto.PluginInfoRequest{}) if err != nil { return nil, err } @@ -41,7 +44,7 @@ func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) { } func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) { - presp, err := b.Client.ConfigSchema(context.Background(), &proto.ConfigSchemaRequest{}) + presp, err := b.Client.ConfigSchema(b.DoneCtx, &proto.ConfigSchemaRequest{}) if err != nil { return nil, err } @@ -51,7 +54,7 @@ func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) { func (b *BasePluginClient) SetConfig(data []byte, config *ClientAgentConfig) error { // Send the config - _, err := b.Client.SetConfig(context.Background(), &proto.SetConfigRequest{ + _, err := b.Client.SetConfig(b.DoneCtx, &proto.SetConfigRequest{ MsgpackConfig: data, NomadConfig: config.toProto(), }) diff --git a/plugins/base/plugin.go b/plugins/base/plugin.go index a386d2c45..411c79662 100644 --- a/plugins/base/plugin.go +++ b/plugins/base/plugin.go @@ -51,7 +51,10 @@ func (p *PluginBase) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error } func (p *PluginBase) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { - return &BasePluginClient{Client: proto.NewBasePluginClient(c)}, nil + return &BasePluginClient{ + Client: proto.NewBasePluginClient(c), + DoneCtx: ctx, + }, nil } // MsgpackHandle is a shared handle for encoding/decoding of structs diff --git a/plugins/device/client.go b/plugins/device/client.go index d20146e75..ffbb80166 100644 --- a/plugins/device/client.go +++ b/plugins/device/client.go @@ -9,9 +9,7 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/device/proto" - netctx "golang.org/x/net/context" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" + "github.com/hashicorp/nomad/plugins/shared" ) // devicePluginClient implements the client side of a remote device plugin, using @@ -49,28 +47,33 @@ func (d *devicePluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri // the gRPC stream to a channel. Exits either when context is cancelled or the // stream has an error. func (d *devicePluginClient) handleFingerprint( - ctx netctx.Context, + ctx context.Context, stream proto.DevicePlugin_FingerprintClient, out chan *FingerprintResponse) { + defer close(out) for { resp, err := stream.Recv() if err != nil { if err != io.EOF { out <- &FingerprintResponse{ - Error: d.handleStreamErr(err, ctx), + Error: shared.HandleStreamErr(err, ctx, d.doneCtx), } } // End the stream - close(out) return } // Send the response - out <- &FingerprintResponse{ + f := &FingerprintResponse{ Devices: convertProtoDeviceGroups(resp.GetDeviceGroup()), } + select { + case <-ctx.Done(): + return + case out <- f: + } } } @@ -116,69 +119,32 @@ func (d *devicePluginClient) Stats(ctx context.Context, interval time.Duration) // the gRPC stream to a channel. Exits either when context is cancelled or the // stream has an error. func (d *devicePluginClient) handleStats( - ctx netctx.Context, + ctx context.Context, stream proto.DevicePlugin_StatsClient, out chan *StatsResponse) { + defer close(out) for { resp, err := stream.Recv() if err != nil { if err != io.EOF { out <- &StatsResponse{ - Error: d.handleStreamErr(err, ctx), + Error: shared.HandleStreamErr(err, ctx, d.doneCtx), } } // End the stream - close(out) return } // Send the response - out <- &StatsResponse{ + s := &StatsResponse{ Groups: convertProtoDeviceGroupsStats(resp.GetGroups()), } - } -} - -// handleStreamErr is used to handle a non io.EOF error in a stream. It handles -// detecting if the plugin has shutdown -func (d *devicePluginClient) handleStreamErr(err error, ctx context.Context) error { - if err == nil { - return nil - } - - // Determine if the error is because the plugin shutdown - if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable { - // Potentially wait a little before returning an error so we can detect - // the exit select { - case <-d.doneCtx.Done(): - err = base.ErrPluginShutdown case <-ctx.Done(): - err = ctx.Err() - - // There is no guarantee that the select will choose the - // doneCtx first so we have to double check - select { - case <-d.doneCtx.Done(): - err = base.ErrPluginShutdown - default: - } - case <-time.After(3 * time.Second): - // Its okay to wait a while since the connection isn't available and - // on local host it is likely shutting down. It is not expected for - // this to ever reach even close to 3 seconds. + return + case out <- s: } - - // It is an error we don't know how to handle, so return it - return err } - - // Context was cancelled - if errStatus := status.FromContextError(ctx.Err()); errStatus.Code() == codes.Canceled { - return context.Canceled - } - - return err } diff --git a/plugins/device/plugin.go b/plugins/device/plugin.go index f03733857..65ec19540 100644 --- a/plugins/device/plugin.go +++ b/plugins/device/plugin.go @@ -31,7 +31,8 @@ func (p *PluginDevice) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker doneCtx: ctx, client: proto.NewDevicePluginClient(c), BasePluginClient: &base.BasePluginClient{ - Client: bproto.NewBasePluginClient(c), + Client: bproto.NewBasePluginClient(c), + DoneCtx: ctx, }, }, nil } diff --git a/plugins/drivers/client.go b/plugins/drivers/client.go index 1fb60ccde..6974f8dcf 100644 --- a/plugins/drivers/client.go +++ b/plugins/drivers/client.go @@ -1,18 +1,19 @@ package drivers import ( + "context" "errors" - "fmt" "io" "time" + "github.com/LK4D4/joincontext" "github.com/golang/protobuf/ptypes" hclog "github.com/hashicorp/go-hclog" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/drivers/proto" + "github.com/hashicorp/nomad/plugins/shared" "github.com/hashicorp/nomad/plugins/shared/hclspec" - "golang.org/x/net/context" ) var _ DriverPlugin = &driverPluginClient{} @@ -22,12 +23,15 @@ type driverPluginClient struct { client proto.DriverClient logger hclog.Logger + + // doneCtx is closed when the plugin exits + doneCtx context.Context } func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) { req := &proto.TaskConfigSchemaRequest{} - resp, err := d.client.TaskConfigSchema(context.Background(), req) + resp, err := d.client.TaskConfigSchema(d.doneCtx, req) if err != nil { return nil, err } @@ -38,7 +42,7 @@ func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) { func (d *driverPluginClient) Capabilities() (*Capabilities, error) { req := &proto.CapabilitiesRequest{} - resp, err := d.client.Capabilities(context.Background(), req) + resp, err := d.client.Capabilities(d.doneCtx, req) if err != nil { return nil, err } @@ -67,12 +71,15 @@ func (d *driverPluginClient) Capabilities() (*Capabilities, error) { func (d *driverPluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerprint, error) { req := &proto.FingerprintRequest{} + // Join the passed context and the shutdown context + ctx, _ = joincontext.Join(ctx, d.doneCtx) + stream, err := d.client.Fingerprint(ctx, req) if err != nil { return nil, err } - ch := make(chan *Fingerprint) + ch := make(chan *Fingerprint, 1) go d.handleFingerprint(ctx, ch, stream) return ch, nil @@ -82,17 +89,18 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin defer close(ch) for { pb, err := stream.Recv() - if err == io.EOF { - return - } if err != nil { - select { - case <-ctx.Done(): - case ch <- &Fingerprint{Err: fmt.Errorf("error from RPC stream: %v", err)}: + if err != io.EOF { d.logger.Error("error receiving stream from Fingerprint driver RPC", "error", err) + ch <- &Fingerprint{ + Err: shared.HandleStreamErr(err, ctx, d.doneCtx), + } } + + // End the stream return } + f := &Fingerprint{ Attributes: pb.Attributes, Health: healthStateFromProto(pb.Health), @@ -112,7 +120,7 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin func (d *driverPluginClient) RecoverTask(h *TaskHandle) error { req := &proto.RecoverTaskRequest{Handle: taskHandleToProto(h)} - _, err := d.client.RecoverTask(context.Background(), req) + _, err := d.client.RecoverTask(d.doneCtx, req) return err } @@ -124,7 +132,7 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr Task: taskConfigToProto(c), } - resp, err := d.client.StartTask(context.Background(), req) + resp, err := d.client.StartTask(d.doneCtx, req) if err != nil { return nil, nil, err } @@ -150,6 +158,10 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr // the same task without issue. func (d *driverPluginClient) WaitTask(ctx context.Context, id string) (<-chan *ExitResult, error) { ch := make(chan *ExitResult) + + // Join the passed context and the shutdown context + ctx, _ = joincontext.Join(ctx, d.doneCtx) + go d.handleWaitTask(ctx, id, ch) return ch, nil } @@ -186,7 +198,7 @@ func (d *driverPluginClient) StopTask(taskID string, timeout time.Duration, sign Signal: signal, } - _, err := d.client.StopTask(context.Background(), req) + _, err := d.client.StopTask(d.doneCtx, req) return err } @@ -199,7 +211,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error { Force: force, } - _, err := d.client.DestroyTask(context.Background(), req) + _, err := d.client.DestroyTask(d.doneCtx, req) return err } @@ -207,7 +219,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error { func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) { req := &proto.InspectTaskRequest{TaskId: taskID} - resp, err := d.client.InspectTask(context.Background(), req) + resp, err := d.client.InspectTask(d.doneCtx, req) if err != nil { return nil, err } @@ -238,7 +250,7 @@ func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) { func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) { req := &proto.TaskStatsRequest{TaskId: taskID} - resp, err := d.client.TaskStats(context.Background(), req) + resp, err := d.client.TaskStats(d.doneCtx, req) if err != nil { return nil, err } @@ -255,28 +267,36 @@ func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsa // tasks such as lifecycle events, terminal errors, etc. func (d *driverPluginClient) TaskEvents(ctx context.Context) (<-chan *TaskEvent, error) { req := &proto.TaskEventsRequest{} + + // Join the passed context and the shutdown context + ctx, _ = joincontext.Join(ctx, d.doneCtx) + stream, err := d.client.TaskEvents(ctx, req) if err != nil { return nil, err } - ch := make(chan *TaskEvent) - go d.handleTaskEvents(ch, stream) + ch := make(chan *TaskEvent, 1) + go d.handleTaskEvents(ctx, ch, stream) return ch, nil } -func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) { +func (d *driverPluginClient) handleTaskEvents(ctx context.Context, ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) { defer close(ch) for { ev, err := stream.Recv() - if err == io.EOF { - break - } if err != nil { - d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err) - ch <- &TaskEvent{Err: err} - break + if err != io.EOF { + d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err) + ch <- &TaskEvent{ + Err: shared.HandleStreamErr(err, ctx, d.doneCtx), + } + } + + // End the stream + return } + timestamp, _ := ptypes.Timestamp(ev.Timestamp) event := &TaskEvent{ TaskID: ev.TaskId, @@ -284,7 +304,11 @@ func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.D Message: ev.Message, Timestamp: timestamp, } - ch <- event + select { + case <-ctx.Done(): + return + case ch <- event: + } } } @@ -294,7 +318,7 @@ func (d *driverPluginClient) SignalTask(taskID string, signal string) error { TaskId: taskID, Signal: signal, } - _, err := d.client.SignalTask(context.Background(), req) + _, err := d.client.SignalTask(d.doneCtx, req) return err } @@ -309,7 +333,7 @@ func (d *driverPluginClient) ExecTask(taskID string, cmd []string, timeout time. Timeout: ptypes.DurationProto(timeout), } - resp, err := d.client.ExecTask(context.Background(), req) + resp, err := d.client.ExecTask(d.doneCtx, req) if err != nil { return nil, err } diff --git a/plugins/drivers/driver.go b/plugins/drivers/driver.go index 2bb7267c4..458635f6d 100644 --- a/plugins/drivers/driver.go +++ b/plugins/drivers/driver.go @@ -1,6 +1,7 @@ package drivers import ( + "context" "fmt" "path/filepath" "sort" @@ -14,7 +15,6 @@ import ( "github.com/hashicorp/nomad/plugins/shared/hclspec" "github.com/zclconf/go-cty/cty" "github.com/zclconf/go-cty/cty/msgpack" - "golang.org/x/net/context" ) // DriverPlugin is the interface with drivers will implement. It is also diff --git a/plugins/drivers/plugin.go b/plugins/drivers/plugin.go index b485c8836..67165cb8a 100644 --- a/plugins/drivers/plugin.go +++ b/plugins/drivers/plugin.go @@ -38,9 +38,11 @@ func (p *PluginDriver) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) err func (p *PluginDriver) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) { return &driverPluginClient{ BasePluginClient: &base.BasePluginClient{ - Client: baseproto.NewBasePluginClient(c), + DoneCtx: ctx, + Client: baseproto.NewBasePluginClient(c), }, - client: proto.NewDriverClient(c), - logger: p.logger, + client: proto.NewDriverClient(c), + logger: p.logger, + doneCtx: ctx, }, nil } diff --git a/plugins/drivers/plugin_test.go b/plugins/drivers/plugin_test.go index 3409124e8..0bb01ed9f 100644 --- a/plugins/drivers/plugin_test.go +++ b/plugins/drivers/plugin_test.go @@ -2,6 +2,7 @@ package drivers import ( "bytes" + "context" "sync" "testing" "time" @@ -10,7 +11,6 @@ import ( "github.com/hashicorp/nomad/nomad/structs" "github.com/stretchr/testify/require" "github.com/ugorji/go/codec" - "golang.org/x/net/context" ) type testDriverState struct { diff --git a/plugins/drivers/server.go b/plugins/drivers/server.go index 4ad385e24..bbe73e73d 100644 --- a/plugins/drivers/server.go +++ b/plugins/drivers/server.go @@ -4,13 +4,12 @@ import ( "fmt" "io" - "golang.org/x/net/context" - "github.com/golang/protobuf/ptypes" hclog "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/plugins/drivers/proto" + context "golang.org/x/net/context" ) type driverPluginServer struct { diff --git a/plugins/drivers/testing.go b/plugins/drivers/testing.go index 000c81b92..3bace5fb1 100644 --- a/plugins/drivers/testing.go +++ b/plugins/drivers/testing.go @@ -1,16 +1,13 @@ package drivers import ( + "context" "fmt" "io/ioutil" "path/filepath" "runtime" "time" - "github.com/mitchellh/go-testing-interface" - "github.com/stretchr/testify/require" - "golang.org/x/net/context" - hclog "github.com/hashicorp/go-hclog" plugin "github.com/hashicorp/go-plugin" "github.com/hashicorp/nomad/client/allocdir" @@ -21,6 +18,8 @@ import ( "github.com/hashicorp/nomad/helper/uuid" "github.com/hashicorp/nomad/plugins/base" "github.com/hashicorp/nomad/plugins/shared/hclspec" + "github.com/mitchellh/go-testing-interface" + "github.com/stretchr/testify/require" ) type DriverHarness struct { diff --git a/plugins/shared/grpc_utils.go b/plugins/shared/grpc_utils.go new file mode 100644 index 000000000..34fb33a87 --- /dev/null +++ b/plugins/shared/grpc_utils.go @@ -0,0 +1,61 @@ +package shared + +import ( + "context" + "time" + + "github.com/hashicorp/nomad/plugins/base" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +// HandleStreamErr is used to handle a non io.EOF error in a stream. It handles +// detecting if the plugin has shutdown via the passeed pluginCtx. The +// parameters are: +// - err: the error returned from the streaming RPC +// - reqCtx: the context passed to the streaming request +// - pluginCtx: the plugins done ctx used to detect the plugin dying +// +// The return values are: +// - base.ErrPluginShutdown if the error is because the plugin shutdown +// - context.Canceled if the reqCtx is canceled +// - The original error +func HandleStreamErr(err error, reqCtx, pluginCtx context.Context) error { + if err == nil { + return nil + } + + // Determine if the error is because the plugin shutdown + if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable { + // Potentially wait a little before returning an error so we can detect + // the exit + select { + case <-pluginCtx.Done(): + err = base.ErrPluginShutdown + case <-reqCtx.Done(): + err = reqCtx.Err() + + // There is no guarantee that the select will choose the + // doneCtx first so we have to double check + select { + case <-pluginCtx.Done(): + err = base.ErrPluginShutdown + default: + } + case <-time.After(3 * time.Second): + // Its okay to wait a while since the connection isn't available and + // on local host it is likely shutting down. It is not expected for + // this to ever reach even close to 3 seconds. + } + + // It is an error we don't know how to handle, so return it + return err + } + + // Context was cancelled + if errStatus := status.FromContextError(reqCtx.Err()); errStatus.Code() == codes.Canceled { + return context.Canceled + } + + return err +}