server: server forwarding logic for nomad exec endpoint

This commit is contained in:
Mahmood Ali
2019-04-28 17:25:27 -04:00
parent 979a6a1778
commit a77a3ba9b0
4 changed files with 356 additions and 40 deletions

View File

@@ -2,11 +2,16 @@ package nomad
import (
"errors"
"fmt"
"io"
"net"
"time"
metrics "github.com/armon/go-metrics"
log "github.com/hashicorp/go-hclog"
cstructs "github.com/hashicorp/nomad/client/structs"
"github.com/hashicorp/nomad/helper"
"github.com/ugorji/go/codec"
"github.com/hashicorp/nomad/acl"
"github.com/hashicorp/nomad/nomad/structs"
@@ -19,6 +24,10 @@ type ClientAllocations struct {
logger log.Logger
}
func (a *ClientAllocations) register() {
a.srv.streamingRpcs.Register("Allocations.Exec", a.exec)
}
// GarbageCollectAll is used to garbage collect all allocations on a client.
func (a *ClientAllocations) GarbageCollectAll(args *structs.NodeSpecificRequest, reply *structs.GenericResponse) error {
// We only allow stale reads since the only potentially stale information is
@@ -287,3 +296,125 @@ func (a *ClientAllocations) Stats(args *cstructs.AllocStatsRequest, reply *cstru
// Make the RPC
return NodeRpc(state.Session, "Allocations.Stats", args, reply)
}
// exec is used to execute command in a running task
func (a *ClientAllocations) exec(conn io.ReadWriteCloser) {
defer conn.Close()
defer metrics.MeasureSince([]string{"nomad", "alloc", "exec"}, time.Now())
// Decode the arguments
var args cstructs.AllocExecRequest
decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
if err := decoder.Decode(&args); err != nil {
handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
return
}
// Check if we need to forward to a different region
if r := args.RequestRegion(); r != a.srv.Region() {
forwardRegionStreamingRpc(a.srv, conn, encoder, &args, "Allocations.Exec",
args.AllocID, &args.QueryOptions)
return
}
// Check node read permissions
if aclObj, err := a.srv.ResolveToken(args.AuthToken); err != nil {
handleStreamResultError(err, nil, encoder)
return
} else if aclObj != nil {
// client ultimately checks if AllocNodeExec is required
exec := aclObj.AllowNsOp(args.QueryOptions.Namespace, acl.NamespaceCapabilityAllocExec)
if !exec {
handleStreamResultError(structs.ErrPermissionDenied, nil, encoder)
return
}
}
// Verify the arguments.
if args.AllocID == "" {
handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder)
return
}
// Retrieve the allocation
snap, err := a.srv.State().Snapshot()
if err != nil {
handleStreamResultError(err, nil, encoder)
return
}
alloc, err := snap.AllocByID(nil, args.AllocID)
if err != nil {
handleStreamResultError(err, nil, encoder)
return
}
if alloc == nil {
handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder)
return
}
nodeID := alloc.NodeID
// Make sure Node is valid and new enough to support RPC
node, err := snap.NodeByID(nil, nodeID)
if err != nil {
handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
return
}
if node == nil {
err := fmt.Errorf("Unknown node %q", nodeID)
handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
return
}
if err := nodeSupportsRpc(node); err != nil {
handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
return
}
// Get the connection to the client either by forwarding to another server
// or creating a direct stream
var clientConn net.Conn
state, ok := a.srv.getNodeConn(nodeID)
if !ok {
// Determine the Server that has a connection to the node.
srv, err := a.srv.serverWithNodeConn(nodeID, a.srv.Region())
if err != nil {
var code *int64
if structs.IsErrNoNodeConn(err) {
code = helper.Int64ToPtr(404)
}
handleStreamResultError(err, code, encoder)
return
}
// Get a connection to the server
conn, err := a.srv.streamingRpc(srv, "Allocations.Exec")
if err != nil {
handleStreamResultError(err, nil, encoder)
return
}
clientConn = conn
} else {
stream, err := NodeStreamingRpc(state.Session, "Allocations.Exec")
if err != nil {
handleStreamResultError(err, nil, encoder)
return
}
clientConn = stream
}
defer clientConn.Close()
// Send the request.
outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle)
if err := outEncoder.Encode(args); err != nil {
handleStreamResultError(err, nil, encoder)
return
}
structs.Bridge(conn, clientConn)
return
}

View File

@@ -1,8 +1,13 @@
package nomad
import (
"encoding/json"
"fmt"
"io"
"net"
"strings"
"testing"
"time"
msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
"github.com/hashicorp/nomad/acl"
@@ -12,9 +17,12 @@ import (
"github.com/hashicorp/nomad/helper/uuid"
"github.com/hashicorp/nomad/nomad/mock"
"github.com/hashicorp/nomad/nomad/structs"
nstructs "github.com/hashicorp/nomad/nomad/structs"
"github.com/hashicorp/nomad/plugins/drivers"
"github.com/hashicorp/nomad/testutil"
"github.com/kr/pretty"
"github.com/stretchr/testify/require"
"github.com/ugorji/go/codec"
)
func TestClientAllocations_GarbageCollectAll_Local(t *testing.T) {
@@ -1040,3 +1048,179 @@ func TestClientAllocations_Restart_ACL(t *testing.T) {
})
}
}
// TestAlloc_ExecStreaming asserts that exec task requests are forwarded
// to appropriate server or remote regions
func TestAlloc_ExecStreaming(t *testing.T) {
t.Skip("try skipping")
t.Parallel()
////// Nomad clusters topology - not specific to test
localServer := TestServer(t, nil)
defer localServer.Shutdown()
remoteServer := TestServer(t, func(c *Config) {
c.DevDisableBootstrap = true
})
defer remoteServer.Shutdown()
remoteRegionServer := TestServer(t, func(c *Config) {
c.Region = "two"
})
defer remoteRegionServer.Shutdown()
TestJoin(t, localServer, remoteServer)
TestJoin(t, localServer, remoteRegionServer)
testutil.WaitForLeader(t, localServer.RPC)
testutil.WaitForLeader(t, remoteServer.RPC)
testutil.WaitForLeader(t, remoteRegionServer.RPC)
c, cleanup := client.TestClient(t, func(c *config.Config) {
c.Servers = []string{localServer.config.RPCAddr.String()}
})
defer cleanup()
// Wait for the client to connect
testutil.WaitForResult(func() (bool, error) {
nodes := remoteServer.connectedNodes()
return len(nodes) == 1, nil
}, func(err error) {
require.NoError(t, err, "failed to have a client")
})
// Force remove the connection locally in case it exists
remoteServer.nodeConnsLock.Lock()
delete(remoteServer.nodeConns, c.NodeID())
remoteServer.nodeConnsLock.Unlock()
///// Start task
a := mock.BatchAlloc()
a.NodeID = c.NodeID()
a.Job.Type = structs.JobTypeBatch
a.Job.TaskGroups[0].Count = 1
a.Job.TaskGroups[0].Tasks[0].Config = map[string]interface{}{
"run_for": "20s",
"exec_command": map[string]interface{}{
"run_for": "1ms",
"stdout_string": "expected output",
"exit_code": 3,
},
}
// Upsert the allocation
localState := localServer.State()
require.Nil(t, localState.UpsertJob(999, a.Job))
require.Nil(t, localState.UpsertAllocs(1003, []*structs.Allocation{a}))
remoteState := remoteServer.State()
require.Nil(t, remoteState.UpsertJob(999, a.Job))
require.Nil(t, remoteState.UpsertAllocs(1003, []*structs.Allocation{a}))
// Wait for the client to run the allocation
testutil.WaitForResult(func() (bool, error) {
alloc, err := localState.AllocByID(nil, a.ID)
if err != nil {
return false, err
}
if alloc == nil {
return false, fmt.Errorf("unknown alloc")
}
if alloc.ClientStatus != structs.AllocClientStatusRunning {
return false, fmt.Errorf("alloc client status: %v", alloc.ClientStatus)
}
return true, nil
}, func(err error) {
require.NoError(t, err, "task didn't start yet")
})
///////// Actually run query now
cases := []struct {
name string
rpc func(string) (structs.StreamingRpcHandler, error)
}{
{"client", c.StreamingRpcHandler},
{"local_server", localServer.StreamingRpcHandler},
{"remote_server", remoteServer.StreamingRpcHandler},
{"remote_region", remoteRegionServer.StreamingRpcHandler},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
// Make the request
req := &cstructs.AllocExecRequest{
AllocID: a.ID,
Task: a.Job.TaskGroups[0].Tasks[0].Name,
Tty: true,
Cmd: []string{"placeholder command"},
QueryOptions: nstructs.QueryOptions{Region: "global"},
}
// Get the handler
handler, err := tc.rpc("Allocations.Exec")
require.Nil(t, err)
// Create a pipe
p1, p2 := net.Pipe()
defer p1.Close()
defer p2.Close()
errCh := make(chan error)
frames := make(chan *drivers.ExecTaskStreamingResponseMsg)
// Start the handler
go handler(p2)
go decodeFrames(t, p1, frames, errCh)
// Send the request
encoder := codec.NewEncoder(p1, nstructs.MsgpackHandle)
require.Nil(t, encoder.Encode(req))
timeout := time.After(3 * time.Second)
OUTER:
for {
select {
case <-timeout:
require.FailNow(t, "timed out before getting exit code")
case err := <-errCh:
require.NoError(t, err)
case f := <-frames:
if f.Exited && f.Result != nil {
code := int(f.Result.ExitCode)
require.Equal(t, 3, code)
break OUTER
}
}
}
})
}
}
func decodeFrames(t *testing.T, p1 net.Conn, frames chan<- *drivers.ExecTaskStreamingResponseMsg, errCh chan<- error) {
// Start the decoder
decoder := codec.NewDecoder(p1, nstructs.MsgpackHandle)
for {
var msg cstructs.StreamErrWrapper
if err := decoder.Decode(&msg); err != nil {
if err == io.EOF || strings.Contains(err.Error(), "closed") {
return
}
t.Logf("received error decoding: %#v", err)
errCh <- fmt.Errorf("error decoding: %v", err)
return
}
if msg.Error != nil {
errCh <- msg.Error
continue
}
var frame drivers.ExecTaskStreamingResponseMsg
json.Unmarshal(msg.Payload, &frame)
t.Logf("received message: %#v", msg)
frames <- &frame
}
}

View File

@@ -33,7 +33,7 @@ func (f *FileSystem) register() {
// handleStreamResultError is a helper for sending an error with a potential
// error code. The transmission of the error is ignored if the error has been
// generated by the closing of the underlying transport.
func (f *FileSystem) handleStreamResultError(err error, code *int64, encoder *codec.Encoder) {
func handleStreamResultError(err error, code *int64, encoder *codec.Encoder) {
// Nothing to do as the conn is closed
if err == io.EOF || strings.Contains(err.Error(), "closed") {
return
@@ -48,7 +48,7 @@ func (f *FileSystem) handleStreamResultError(err error, code *int64, encoder *co
// forwardRegionStreamingRpc is used to make a streaming RPC to a different
// region. It looks up the allocation in the remote region to determine what
// remote server can route the request.
func (f *FileSystem) forwardRegionStreamingRpc(conn io.ReadWriteCloser,
func forwardRegionStreamingRpc(fsrv *Server, conn io.ReadWriteCloser,
encoder *codec.Encoder, args interface{}, method, allocID string, qo *structs.QueryOptions) {
// Request the allocation from the target region
allocReq := &structs.AllocSpecificRequest{
@@ -56,31 +56,31 @@ func (f *FileSystem) forwardRegionStreamingRpc(conn io.ReadWriteCloser,
QueryOptions: *qo,
}
var allocResp structs.SingleAllocResponse
if err := f.srv.forwardRegion(qo.RequestRegion(), "Alloc.GetAlloc", allocReq, &allocResp); err != nil {
f.handleStreamResultError(err, nil, encoder)
if err := fsrv.forwardRegion(qo.RequestRegion(), "Alloc.GetAlloc", allocReq, &allocResp); err != nil {
handleStreamResultError(err, nil, encoder)
return
}
if allocResp.Alloc == nil {
f.handleStreamResultError(structs.NewErrUnknownAllocation(allocID), helper.Int64ToPtr(404), encoder)
handleStreamResultError(structs.NewErrUnknownAllocation(allocID), helper.Int64ToPtr(404), encoder)
return
}
// Determine the Server that has a connection to the node.
srv, err := f.srv.serverWithNodeConn(allocResp.Alloc.NodeID, qo.RequestRegion())
srv, err := fsrv.serverWithNodeConn(allocResp.Alloc.NodeID, qo.RequestRegion())
if err != nil {
var code *int64
if structs.IsErrNoNodeConn(err) {
code = helper.Int64ToPtr(404)
}
f.handleStreamResultError(err, code, encoder)
handleStreamResultError(err, code, encoder)
return
}
// Get a connection to the server
srvConn, err := f.srv.streamingRpc(srv, method)
srvConn, err := fsrv.streamingRpc(srv, method)
if err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
defer srvConn.Close()
@@ -88,7 +88,7 @@ func (f *FileSystem) forwardRegionStreamingRpc(conn io.ReadWriteCloser,
// Send the request.
outEncoder := codec.NewEncoder(srvConn, structs.MsgpackHandle)
if err := outEncoder.Encode(args); err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
@@ -217,46 +217,46 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) {
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
if err := decoder.Decode(&args); err != nil {
f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
return
}
// Check if we need to forward to a different region
if r := args.RequestRegion(); r != f.srv.Region() {
f.forwardRegionStreamingRpc(conn, encoder, &args, "FileSystem.Stream",
forwardRegionStreamingRpc(f.srv, conn, encoder, &args, "FileSystem.Stream",
args.AllocID, &args.QueryOptions)
return
}
// Check node read permissions
if aclObj, err := f.srv.ResolveToken(args.AuthToken); err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
} else if aclObj != nil && !aclObj.AllowNsOp(args.Namespace, acl.NamespaceCapabilityReadFS) {
f.handleStreamResultError(structs.ErrPermissionDenied, nil, encoder)
handleStreamResultError(structs.ErrPermissionDenied, nil, encoder)
return
}
// Verify the arguments.
if args.AllocID == "" {
f.handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder)
handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder)
return
}
// Retrieve the allocation
snap, err := f.srv.State().Snapshot()
if err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
alloc, err := snap.AllocByID(nil, args.AllocID)
if err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
if alloc == nil {
f.handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder)
handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder)
return
}
nodeID := alloc.NodeID
@@ -264,18 +264,18 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) {
// Make sure Node is valid and new enough to support RPC
node, err := snap.NodeByID(nil, nodeID)
if err != nil {
f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
return
}
if node == nil {
err := fmt.Errorf("Unknown node %q", nodeID)
f.handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
return
}
if err := nodeSupportsRpc(node); err != nil {
f.handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
return
}
@@ -291,14 +291,14 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) {
if structs.IsErrNoNodeConn(err) {
code = helper.Int64ToPtr(404)
}
f.handleStreamResultError(err, code, encoder)
handleStreamResultError(err, code, encoder)
return
}
// Get a connection to the server
conn, err := f.srv.streamingRpc(srv, "FileSystem.Stream")
if err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
@@ -306,7 +306,7 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) {
} else {
stream, err := NodeStreamingRpc(state.Session, "FileSystem.Stream")
if err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
clientConn = stream
@@ -316,7 +316,7 @@ func (f *FileSystem) stream(conn io.ReadWriteCloser) {
// Send the request.
outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle)
if err := outEncoder.Encode(args); err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
@@ -335,50 +335,50 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) {
encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
if err := decoder.Decode(&args); err != nil {
f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
return
}
// Check if we need to forward to a different region
if r := args.RequestRegion(); r != f.srv.Region() {
f.forwardRegionStreamingRpc(conn, encoder, &args, "FileSystem.Logs",
forwardRegionStreamingRpc(f.srv, conn, encoder, &args, "FileSystem.Logs",
args.AllocID, &args.QueryOptions)
return
}
// Check node read permissions
if aclObj, err := f.srv.ResolveToken(args.AuthToken); err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
} else if aclObj != nil {
readfs := aclObj.AllowNsOp(args.QueryOptions.Namespace, acl.NamespaceCapabilityReadFS)
logs := aclObj.AllowNsOp(args.QueryOptions.Namespace, acl.NamespaceCapabilityReadLogs)
if !readfs && !logs {
f.handleStreamResultError(structs.ErrPermissionDenied, nil, encoder)
handleStreamResultError(structs.ErrPermissionDenied, nil, encoder)
return
}
}
// Verify the arguments.
if args.AllocID == "" {
f.handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder)
handleStreamResultError(errors.New("missing AllocID"), helper.Int64ToPtr(400), encoder)
return
}
// Retrieve the allocation
snap, err := f.srv.State().Snapshot()
if err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
alloc, err := snap.AllocByID(nil, args.AllocID)
if err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
if alloc == nil {
f.handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder)
handleStreamResultError(structs.NewErrUnknownAllocation(args.AllocID), helper.Int64ToPtr(404), encoder)
return
}
nodeID := alloc.NodeID
@@ -386,18 +386,18 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) {
// Make sure Node is valid and new enough to support RPC
node, err := snap.NodeByID(nil, nodeID)
if err != nil {
f.handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
handleStreamResultError(err, helper.Int64ToPtr(500), encoder)
return
}
if node == nil {
err := fmt.Errorf("Unknown node %q", nodeID)
f.handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
return
}
if err := nodeSupportsRpc(node); err != nil {
f.handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
handleStreamResultError(err, helper.Int64ToPtr(400), encoder)
return
}
@@ -413,14 +413,14 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) {
if structs.IsErrNoNodeConn(err) {
code = helper.Int64ToPtr(404)
}
f.handleStreamResultError(err, code, encoder)
handleStreamResultError(err, code, encoder)
return
}
// Get a connection to the server
conn, err := f.srv.streamingRpc(srv, "FileSystem.Logs")
if err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
@@ -428,7 +428,7 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) {
} else {
stream, err := NodeStreamingRpc(state.Session, "FileSystem.Logs")
if err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}
clientConn = stream
@@ -438,7 +438,7 @@ func (f *FileSystem) logs(conn io.ReadWriteCloser) {
// Send the request.
outEncoder := codec.NewEncoder(clientConn, structs.MsgpackHandle)
if err := outEncoder.Encode(args); err != nil {
f.handleStreamResultError(err, nil, encoder)
handleStreamResultError(err, nil, encoder)
return
}

View File

@@ -1027,6 +1027,7 @@ func (s *Server) setupRpcServer(server *rpc.Server, ctx *RPCContext) {
// Client endpoints
s.staticEndpoints.ClientStats = &ClientStats{srv: s, logger: s.logger.Named("client_stats")}
s.staticEndpoints.ClientAllocations = &ClientAllocations{srv: s, logger: s.logger.Named("client_allocs")}
s.staticEndpoints.ClientAllocations.register()
// Streaming endpoints
s.staticEndpoints.FileSystem = &FileSystem{srv: s, logger: s.logger.Named("client_fs")}