Plugin client's handle plugin dying

This PR plumbs the plugins done ctx through the base and driver plugin
clients (device already had it). Further, it adds generic handling of
gRPC stream errors.
This commit is contained in:
Alex Dadgar
2018-11-12 17:09:27 -08:00
parent 0200000d53
commit 9d42f4d039
19 changed files with 165 additions and 110 deletions

View File

@@ -1,6 +1,7 @@
package exec
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -20,7 +21,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (

View File

@@ -1,6 +1,7 @@
package java
import (
"context"
"fmt"
"os"
"os/exec"
@@ -23,7 +24,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (

View File

@@ -16,7 +16,6 @@ import (
"github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
netctx "golang.org/x/net/context"
)
const (
@@ -232,7 +231,7 @@ func (d *Driver) Capabilities() (*drivers.Capabilities, error) {
return capabilities, nil
}
func (d *Driver) Fingerprint(ctx netctx.Context) (<-chan *drivers.Fingerprint, error) {
func (d *Driver) Fingerprint(ctx context.Context) (<-chan *drivers.Fingerprint, error) {
ch := make(chan *drivers.Fingerprint)
go d.handleFingerprint(ctx, ch)
return ch, nil
@@ -365,7 +364,7 @@ func (d *Driver) StartTask(cfg *drivers.TaskConfig) (*drivers.TaskHandle, *cstru
}
func (d *Driver) WaitTask(ctx netctx.Context, taskID string) (<-chan *drivers.ExitResult, error) {
func (d *Driver) WaitTask(ctx context.Context, taskID string) (<-chan *drivers.ExitResult, error) {
handle, ok := d.tasks.Get(taskID)
if !ok {
return nil, drivers.ErrTaskNotFound
@@ -430,7 +429,7 @@ func (d *Driver) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) {
return nil, nil
}
func (d *Driver) TaskEvents(ctx netctx.Context) (<-chan *drivers.TaskEvent, error) {
func (d *Driver) TaskEvents(ctx context.Context) (<-chan *drivers.TaskEvent, error) {
return d.eventer.TaskEvents(ctx)
}

View File

@@ -1,6 +1,7 @@
package qemu
import (
"context"
"errors"
"fmt"
"net"
@@ -25,7 +26,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (

View File

@@ -1,6 +1,7 @@
package rawexec
import (
"context"
"fmt"
"os"
"path/filepath"
@@ -22,7 +23,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
"golang.org/x/net/context"
)
const (

View File

@@ -4,6 +4,7 @@ package rkt
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
@@ -36,7 +37,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/plugins/shared/loader"
rktv1 "github.com/rkt/rkt/api/v1"
"golang.org/x/net/context"
)
const (

View File

@@ -3,17 +3,16 @@
package rkt
import (
"bytes"
"context"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"sync"
"testing"
"time"
"os"
"bytes"
"github.com/hashicorp/hcl2/hcl"
ctestutil "github.com/hashicorp/nomad/client/testutil"
"github.com/hashicorp/nomad/helper/testlog"
@@ -26,7 +25,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/hashicorp/nomad/testutil"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
)
var _ drivers.DriverPlugin = (*Driver)(nil)

View File

@@ -1,12 +1,12 @@
package eventer
import (
"context"
"sync"
"time"
hclog "github.com/hashicorp/go-hclog"
"github.com/hashicorp/nomad/plugins/drivers"
"golang.org/x/net/context"
)
var (

View File

@@ -12,10 +12,13 @@ import (
// gRPC to communicate to the remote plugin.
type BasePluginClient struct {
Client proto.BasePluginClient
// DoneCtx is closed when the plugin exits
DoneCtx context.Context
}
func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
presp, err := b.Client.PluginInfo(context.Background(), &proto.PluginInfoRequest{})
presp, err := b.Client.PluginInfo(b.DoneCtx, &proto.PluginInfoRequest{})
if err != nil {
return nil, err
}
@@ -41,7 +44,7 @@ func (b *BasePluginClient) PluginInfo() (*PluginInfoResponse, error) {
}
func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
presp, err := b.Client.ConfigSchema(context.Background(), &proto.ConfigSchemaRequest{})
presp, err := b.Client.ConfigSchema(b.DoneCtx, &proto.ConfigSchemaRequest{})
if err != nil {
return nil, err
}
@@ -51,7 +54,7 @@ func (b *BasePluginClient) ConfigSchema() (*hclspec.Spec, error) {
func (b *BasePluginClient) SetConfig(data []byte, config *ClientAgentConfig) error {
// Send the config
_, err := b.Client.SetConfig(context.Background(), &proto.SetConfigRequest{
_, err := b.Client.SetConfig(b.DoneCtx, &proto.SetConfigRequest{
MsgpackConfig: data,
NomadConfig: config.toProto(),
})

View File

@@ -51,7 +51,10 @@ func (p *PluginBase) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error
}
func (p *PluginBase) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
return &BasePluginClient{Client: proto.NewBasePluginClient(c)}, nil
return &BasePluginClient{
Client: proto.NewBasePluginClient(c),
DoneCtx: ctx,
}, nil
}
// MsgpackHandle is a shared handle for encoding/decoding of structs

View File

@@ -9,9 +9,7 @@ import (
"github.com/golang/protobuf/ptypes"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/device/proto"
netctx "golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/hashicorp/nomad/plugins/shared"
)
// devicePluginClient implements the client side of a remote device plugin, using
@@ -49,28 +47,33 @@ func (d *devicePluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerpri
// the gRPC stream to a channel. Exits either when context is cancelled or the
// stream has an error.
func (d *devicePluginClient) handleFingerprint(
ctx netctx.Context,
ctx context.Context,
stream proto.DevicePlugin_FingerprintClient,
out chan *FingerprintResponse) {
defer close(out)
for {
resp, err := stream.Recv()
if err != nil {
if err != io.EOF {
out <- &FingerprintResponse{
Error: d.handleStreamErr(err, ctx),
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
close(out)
return
}
// Send the response
out <- &FingerprintResponse{
f := &FingerprintResponse{
Devices: convertProtoDeviceGroups(resp.GetDeviceGroup()),
}
select {
case <-ctx.Done():
return
case out <- f:
}
}
}
@@ -116,69 +119,32 @@ func (d *devicePluginClient) Stats(ctx context.Context, interval time.Duration)
// the gRPC stream to a channel. Exits either when context is cancelled or the
// stream has an error.
func (d *devicePluginClient) handleStats(
ctx netctx.Context,
ctx context.Context,
stream proto.DevicePlugin_StatsClient,
out chan *StatsResponse) {
defer close(out)
for {
resp, err := stream.Recv()
if err != nil {
if err != io.EOF {
out <- &StatsResponse{
Error: d.handleStreamErr(err, ctx),
Error: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
close(out)
return
}
// Send the response
out <- &StatsResponse{
s := &StatsResponse{
Groups: convertProtoDeviceGroupsStats(resp.GetGroups()),
}
}
}
// handleStreamErr is used to handle a non io.EOF error in a stream. It handles
// detecting if the plugin has shutdown
func (d *devicePluginClient) handleStreamErr(err error, ctx context.Context) error {
if err == nil {
return nil
}
// Determine if the error is because the plugin shutdown
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
// Potentially wait a little before returning an error so we can detect
// the exit
select {
case <-d.doneCtx.Done():
err = base.ErrPluginShutdown
case <-ctx.Done():
err = ctx.Err()
// There is no guarantee that the select will choose the
// doneCtx first so we have to double check
select {
case <-d.doneCtx.Done():
err = base.ErrPluginShutdown
default:
}
case <-time.After(3 * time.Second):
// Its okay to wait a while since the connection isn't available and
// on local host it is likely shutting down. It is not expected for
// this to ever reach even close to 3 seconds.
return
case out <- s:
}
// It is an error we don't know how to handle, so return it
return err
}
// Context was cancelled
if errStatus := status.FromContextError(ctx.Err()); errStatus.Code() == codes.Canceled {
return context.Canceled
}
return err
}

View File

@@ -31,7 +31,8 @@ func (p *PluginDevice) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker
doneCtx: ctx,
client: proto.NewDevicePluginClient(c),
BasePluginClient: &base.BasePluginClient{
Client: bproto.NewBasePluginClient(c),
Client: bproto.NewBasePluginClient(c),
DoneCtx: ctx,
},
}, nil
}

View File

@@ -1,18 +1,19 @@
package drivers
import (
"context"
"errors"
"fmt"
"io"
"time"
"github.com/LK4D4/joincontext"
"github.com/golang/protobuf/ptypes"
hclog "github.com/hashicorp/go-hclog"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/drivers/proto"
"github.com/hashicorp/nomad/plugins/shared"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"golang.org/x/net/context"
)
var _ DriverPlugin = &driverPluginClient{}
@@ -22,12 +23,15 @@ type driverPluginClient struct {
client proto.DriverClient
logger hclog.Logger
// doneCtx is closed when the plugin exits
doneCtx context.Context
}
func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {
req := &proto.TaskConfigSchemaRequest{}
resp, err := d.client.TaskConfigSchema(context.Background(), req)
resp, err := d.client.TaskConfigSchema(d.doneCtx, req)
if err != nil {
return nil, err
}
@@ -38,7 +42,7 @@ func (d *driverPluginClient) TaskConfigSchema() (*hclspec.Spec, error) {
func (d *driverPluginClient) Capabilities() (*Capabilities, error) {
req := &proto.CapabilitiesRequest{}
resp, err := d.client.Capabilities(context.Background(), req)
resp, err := d.client.Capabilities(d.doneCtx, req)
if err != nil {
return nil, err
}
@@ -67,12 +71,15 @@ func (d *driverPluginClient) Capabilities() (*Capabilities, error) {
func (d *driverPluginClient) Fingerprint(ctx context.Context) (<-chan *Fingerprint, error) {
req := &proto.FingerprintRequest{}
// Join the passed context and the shutdown context
ctx, _ = joincontext.Join(ctx, d.doneCtx)
stream, err := d.client.Fingerprint(ctx, req)
if err != nil {
return nil, err
}
ch := make(chan *Fingerprint)
ch := make(chan *Fingerprint, 1)
go d.handleFingerprint(ctx, ch, stream)
return ch, nil
@@ -82,17 +89,18 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin
defer close(ch)
for {
pb, err := stream.Recv()
if err == io.EOF {
return
}
if err != nil {
select {
case <-ctx.Done():
case ch <- &Fingerprint{Err: fmt.Errorf("error from RPC stream: %v", err)}:
if err != io.EOF {
d.logger.Error("error receiving stream from Fingerprint driver RPC", "error", err)
ch <- &Fingerprint{
Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
return
}
f := &Fingerprint{
Attributes: pb.Attributes,
Health: healthStateFromProto(pb.Health),
@@ -112,7 +120,7 @@ func (d *driverPluginClient) handleFingerprint(ctx context.Context, ch chan *Fin
func (d *driverPluginClient) RecoverTask(h *TaskHandle) error {
req := &proto.RecoverTaskRequest{Handle: taskHandleToProto(h)}
_, err := d.client.RecoverTask(context.Background(), req)
_, err := d.client.RecoverTask(d.doneCtx, req)
return err
}
@@ -124,7 +132,7 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
Task: taskConfigToProto(c),
}
resp, err := d.client.StartTask(context.Background(), req)
resp, err := d.client.StartTask(d.doneCtx, req)
if err != nil {
return nil, nil, err
}
@@ -150,6 +158,10 @@ func (d *driverPluginClient) StartTask(c *TaskConfig) (*TaskHandle, *cstructs.Dr
// the same task without issue.
func (d *driverPluginClient) WaitTask(ctx context.Context, id string) (<-chan *ExitResult, error) {
ch := make(chan *ExitResult)
// Join the passed context and the shutdown context
ctx, _ = joincontext.Join(ctx, d.doneCtx)
go d.handleWaitTask(ctx, id, ch)
return ch, nil
}
@@ -186,7 +198,7 @@ func (d *driverPluginClient) StopTask(taskID string, timeout time.Duration, sign
Signal: signal,
}
_, err := d.client.StopTask(context.Background(), req)
_, err := d.client.StopTask(d.doneCtx, req)
return err
}
@@ -199,7 +211,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
Force: force,
}
_, err := d.client.DestroyTask(context.Background(), req)
_, err := d.client.DestroyTask(d.doneCtx, req)
return err
}
@@ -207,7 +219,7 @@ func (d *driverPluginClient) DestroyTask(taskID string, force bool) error {
func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {
req := &proto.InspectTaskRequest{TaskId: taskID}
resp, err := d.client.InspectTask(context.Background(), req)
resp, err := d.client.InspectTask(d.doneCtx, req)
if err != nil {
return nil, err
}
@@ -238,7 +250,7 @@ func (d *driverPluginClient) InspectTask(taskID string) (*TaskStatus, error) {
func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsage, error) {
req := &proto.TaskStatsRequest{TaskId: taskID}
resp, err := d.client.TaskStats(context.Background(), req)
resp, err := d.client.TaskStats(d.doneCtx, req)
if err != nil {
return nil, err
}
@@ -255,28 +267,36 @@ func (d *driverPluginClient) TaskStats(taskID string) (*cstructs.TaskResourceUsa
// tasks such as lifecycle events, terminal errors, etc.
func (d *driverPluginClient) TaskEvents(ctx context.Context) (<-chan *TaskEvent, error) {
req := &proto.TaskEventsRequest{}
// Join the passed context and the shutdown context
ctx, _ = joincontext.Join(ctx, d.doneCtx)
stream, err := d.client.TaskEvents(ctx, req)
if err != nil {
return nil, err
}
ch := make(chan *TaskEvent)
go d.handleTaskEvents(ch, stream)
ch := make(chan *TaskEvent, 1)
go d.handleTaskEvents(ctx, ch, stream)
return ch, nil
}
func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
func (d *driverPluginClient) handleTaskEvents(ctx context.Context, ch chan *TaskEvent, stream proto.Driver_TaskEventsClient) {
defer close(ch)
for {
ev, err := stream.Recv()
if err == io.EOF {
break
}
if err != nil {
d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err)
ch <- &TaskEvent{Err: err}
break
if err != io.EOF {
d.logger.Error("error receiving stream from TaskEvents driver RPC", "error", err)
ch <- &TaskEvent{
Err: shared.HandleStreamErr(err, ctx, d.doneCtx),
}
}
// End the stream
return
}
timestamp, _ := ptypes.Timestamp(ev.Timestamp)
event := &TaskEvent{
TaskID: ev.TaskId,
@@ -284,7 +304,11 @@ func (d *driverPluginClient) handleTaskEvents(ch chan *TaskEvent, stream proto.D
Message: ev.Message,
Timestamp: timestamp,
}
ch <- event
select {
case <-ctx.Done():
return
case ch <- event:
}
}
}
@@ -294,7 +318,7 @@ func (d *driverPluginClient) SignalTask(taskID string, signal string) error {
TaskId: taskID,
Signal: signal,
}
_, err := d.client.SignalTask(context.Background(), req)
_, err := d.client.SignalTask(d.doneCtx, req)
return err
}
@@ -309,7 +333,7 @@ func (d *driverPluginClient) ExecTask(taskID string, cmd []string, timeout time.
Timeout: ptypes.DurationProto(timeout),
}
resp, err := d.client.ExecTask(context.Background(), req)
resp, err := d.client.ExecTask(d.doneCtx, req)
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,7 @@
package drivers
import (
"context"
"fmt"
"path/filepath"
"sort"
@@ -14,7 +15,6 @@ import (
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/zclconf/go-cty/cty"
"github.com/zclconf/go-cty/cty/msgpack"
"golang.org/x/net/context"
)
// DriverPlugin is the interface with drivers will implement. It is also

View File

@@ -38,9 +38,11 @@ func (p *PluginDriver) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) err
func (p *PluginDriver) GRPCClient(ctx context.Context, broker *plugin.GRPCBroker, c *grpc.ClientConn) (interface{}, error) {
return &driverPluginClient{
BasePluginClient: &base.BasePluginClient{
Client: baseproto.NewBasePluginClient(c),
DoneCtx: ctx,
Client: baseproto.NewBasePluginClient(c),
},
client: proto.NewDriverClient(c),
logger: p.logger,
client: proto.NewDriverClient(c),
logger: p.logger,
doneCtx: ctx,
}, nil
}

View File

@@ -2,6 +2,7 @@ package drivers
import (
"bytes"
"context"
"sync"
"testing"
"time"
@@ -10,7 +11,6 @@ import (
"github.com/hashicorp/nomad/nomad/structs"
"github.com/stretchr/testify/require"
"github.com/ugorji/go/codec"
"golang.org/x/net/context"
)
type testDriverState struct {

View File

@@ -4,13 +4,12 @@ import (
"fmt"
"io"
"golang.org/x/net/context"
"github.com/golang/protobuf/ptypes"
hclog "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/plugins/drivers/proto"
context "golang.org/x/net/context"
)
type driverPluginServer struct {

View File

@@ -1,16 +1,13 @@
package drivers
import (
"context"
"fmt"
"io/ioutil"
"path/filepath"
"runtime"
"time"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
"golang.org/x/net/context"
hclog "github.com/hashicorp/go-hclog"
plugin "github.com/hashicorp/go-plugin"
"github.com/hashicorp/nomad/client/allocdir"
@@ -21,6 +18,8 @@ import (
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/plugins/base"
"github.com/hashicorp/nomad/plugins/shared/hclspec"
"github.com/mitchellh/go-testing-interface"
"github.com/stretchr/testify/require"
)
type DriverHarness struct {

View File

@@ -0,0 +1,61 @@
package shared
import (
"context"
"time"
"github.com/hashicorp/nomad/plugins/base"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
// HandleStreamErr is used to handle a non io.EOF error in a stream. It handles
// detecting if the plugin has shutdown via the passeed pluginCtx. The
// parameters are:
// - err: the error returned from the streaming RPC
// - reqCtx: the context passed to the streaming request
// - pluginCtx: the plugins done ctx used to detect the plugin dying
//
// The return values are:
// - base.ErrPluginShutdown if the error is because the plugin shutdown
// - context.Canceled if the reqCtx is canceled
// - The original error
func HandleStreamErr(err error, reqCtx, pluginCtx context.Context) error {
if err == nil {
return nil
}
// Determine if the error is because the plugin shutdown
if errStatus, ok := status.FromError(err); ok && errStatus.Code() == codes.Unavailable {
// Potentially wait a little before returning an error so we can detect
// the exit
select {
case <-pluginCtx.Done():
err = base.ErrPluginShutdown
case <-reqCtx.Done():
err = reqCtx.Err()
// There is no guarantee that the select will choose the
// doneCtx first so we have to double check
select {
case <-pluginCtx.Done():
err = base.ErrPluginShutdown
default:
}
case <-time.After(3 * time.Second):
// Its okay to wait a while since the connection isn't available and
// on local host it is likely shutting down. It is not expected for
// this to ever reach even close to 3 seconds.
}
// It is an error we don't know how to handle, so return it
return err
}
// Context was cancelled
if errStatus := status.FromContextError(reqCtx.Err()); errStatus.Code() == codes.Canceled {
return context.Canceled
}
return err
}