Merge pull request #10657 from hashicorp/b-alloc-exec-closing

Handle `nomad exec` termination events in order
This commit is contained in:
Mahmood Ali
2021-05-25 14:50:58 -04:00
committed by GitHub
6 changed files with 497 additions and 397 deletions

View File

@@ -2,16 +2,10 @@ package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
)
var (
@@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context,
stdin io.Reader, stdout, stderr io.Writer,
terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) {
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
s := &execSession{
client: a.client,
alloc: alloc,
task: task,
tty: tty,
command: command,
errCh := make(chan error, 4)
stdin: stdin,
stdout: stdout,
stderr: stderr,
sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q)
select {
case err := <-errCh:
return -2, err
default:
terminalSizeCh: terminalSizeCh,
q: q,
}
// Errors resulting from sending input (in goroutines) are silently dropped.
// To mitigate this, extra care is needed to distinguish between actual send errors
// and from send errors due to command terminating and our race to detect failures.
// If we have an actual network failure or send a bad input, we'd get an
// error in the reading side of websocket.
go func() {
bytes := make([]byte, 2048)
for {
if ctx.Err() != nil {
return
}
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
n, err := stdin.Read(bytes)
// always send data if we read some
if n != 0 {
input.Stdin.Data = bytes[:n]
sender(&input)
}
// then handle error
if err == io.EOF {
// if n != 0, send data and we'll get n = 0 on next read
if n == 0 {
input.Stdin.Close = true
sender(&input)
return
}
} else if err != nil {
errCh <- err
return
}
}
}()
// forwarding terminal size
go func() {
for {
resizeInput := ExecStreamingInput{}
select {
case <-ctx.Done():
return
case size, ok := <-terminalSizeCh:
if !ok {
return
}
resizeInput.TTYSize = &size
sender(&resizeInput)
}
}
}()
// send a heartbeat every 10 seconds
go func() {
for {
select {
case <-ctx.Done():
return
// heartbeat message
case <-time.After(10 * time.Second):
sender(&execStreamingInputHeartbeat)
}
}
}()
for {
select {
case err := <-errCh:
// drop websocket code, not relevant to user
if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" {
return -2, errors.New(wsErr.Text)
}
return -2, err
case <-ctx.Done():
return -2, ctx.Err()
case frame, ok := <-output:
if !ok {
return -2, errors.New("disconnected without receiving the exit code")
}
switch {
case frame.Stdout != nil:
if len(frame.Stdout.Data) != 0 {
stdout.Write(frame.Stdout.Data)
}
// don't really do anything if stdout is closing
case frame.Stderr != nil:
if len(frame.Stderr.Data) != 0 {
stderr.Write(frame.Stderr.Data)
}
// don't really do anything if stderr is closing
case frame.Exited && frame.Result != nil:
return frame.Result.ExitCode, nil
default:
// noop - heartbeat
}
}
}
}
func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string,
errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) {
nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q)
if q == nil {
q = &QueryOptions{}
}
if q.Params == nil {
q.Params = make(map[string]string)
}
commandBytes, err := json.Marshal(command)
if err != nil {
errCh <- fmt.Errorf("failed to marshal command: %s", err)
return nil, nil
}
q.Params["tty"] = strconv.FormatBool(tty)
q.Params["task"] = task
q.Params["command"] = string(commandBytes)
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID)
var conn *websocket.Conn
if nodeClient != nil {
conn, _, _ = nodeClient.websocket(reqPath, q)
}
if conn == nil {
conn, _, err = a.client.websocket(reqPath, q)
if err != nil {
errCh <- err
return nil, nil
}
}
// Create the output channel
frames := make(chan *ExecStreamingOutput, 10)
go func() {
defer conn.Close()
for ctx.Err() == nil {
// Decode the next frame
var frame ExecStreamingOutput
err := conn.ReadJSON(&frame)
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
close(frames)
return
} else if err != nil {
errCh <- err
return
}
frames <- &frame
}
}()
var sendLock sync.Mutex
send := func(v *ExecStreamingInput) error {
sendLock.Lock()
defer sendLock.Unlock()
return conn.WriteJSON(v)
}
return send, frames
return s.run(ctx)
}
func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) {

236
api/allocations_exec.go Normal file
View File

@@ -0,0 +1,236 @@
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
)
// execSession bundles everything needed to run one exec request against an
// allocation's task over a websocket connection.
type execSession struct {
	// client is used to look up a node-local client and to open the
	// websocket connection.
	client *Client
	// alloc is the target allocation; its NodeID and ID are used when
	// connecting.
	alloc *Allocation

	// task, tty and command are sent as the "task", "tty" and "command"
	// query parameters of the exec request (command is JSON-encoded).
	task    string
	tty     bool
	command []string

	// stdin is read and forwarded to the remote side; data from stdout and
	// stderr frames received from the remote side is written to stdout and
	// stderr respectively.
	stdin          io.Reader
	stdout, stderr io.Writer

	// terminalSizeCh delivers terminal size updates that are forwarded to
	// the remote side.
	terminalSizeCh <-chan TerminalSize

	// q holds the query options for the request; may be nil.
	q *QueryOptions
}
// run executes the session: it opens the websocket connection, starts the
// sending and receiving goroutines, and blocks until the first terminal event.
//
// It returns the exit code reported by the remote side on success. On any
// failure (connection error, context cancellation, receive error, or stdin
// read error) it returns -2 together with the error.
func (s *execSession) run(ctx context.Context) (exitCode int, err error) {
	// The derived context is cancelled on return so the helper goroutines
	// that watch it stop once the session is over.
	ctx, stop := context.WithCancel(ctx)
	defer stop()

	ws, err := s.startConnection()
	if err != nil {
		return -2, err
	}
	defer ws.Close()

	inputErrs := s.startTransmit(ctx, ws)
	exited, outputErrs := s.startReceiving(ctx, ws)

	// Every branch terminates the session, so a single select suffices.
	select {
	case code := <-exited:
		return code, nil

	case <-ctx.Done():
		return -2, ctx.Err()

	case e := <-inputErrs:
		return -2, fmt.Errorf("failed to send input: %w", e)

	case e := <-outputErrs:
		// drop websocket code, not relevant to user
		closeErr, isClose := e.(*websocket.CloseError)
		if isClose && closeErr.Text != "" {
			return -2, errors.New(closeErr.Text)
		}
		return -2, e
	}
}
// startConnection opens the websocket connection used for the exec session.
//
// It first tries to connect to the allocation's node directly; any failure of
// that attempt is ignored and the connection is retried through s.client
// (server-side forwarding). The only error from the node lookup that aborts
// the attempt is NodeDownErr.
//
// The "tty", "task" and "command" request parameters are written into the
// query options' Params map. NOTE: when s.q is non-nil, this populates the
// caller-provided QueryOptions in place.
//
// It returns the established connection, or a non-nil error when the node is
// down, the command cannot be JSON-encoded, or the fallback connection fails.
func (s *execSession) startConnection() (*websocket.Conn, error) {
	// First, attempt to connect to the node directly, but may fail due to network isolation
	// and network errors. Fallback to using server-side forwarding instead.
	nodeClient, err := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q)
	if err == NodeDownErr {
		return nil, NodeDownErr
	}

	q := s.q
	if q == nil {
		q = &QueryOptions{}
	}
	if q.Params == nil {
		q.Params = make(map[string]string)
	}

	commandBytes, err := json.Marshal(s.command)
	if err != nil {
		// %w (lowercase) is the wrapping verb; the previous %W was not a
		// valid verb and produced a malformed, unwrapped message.
		return nil, fmt.Errorf("failed to marshal command: %w", err)
	}

	q.Params["tty"] = strconv.FormatBool(s.tty)
	q.Params["task"] = s.task
	q.Params["command"] = string(commandBytes)

	reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID)

	var conn *websocket.Conn

	if nodeClient != nil {
		// Errors from the direct attempt are intentionally discarded: a nil
		// conn triggers the fallback below.
		conn, _, _ = nodeClient.websocket(reqPath, q)
	}

	if conn == nil {
		conn, _, err = s.client.websocket(reqPath, q)
		if err != nil {
			return nil, err
		}
	}

	return conn, nil
}
// startTransmit launches the goroutines that push client-side input over the
// websocket: stdin data, terminal size updates, and a heartbeat every 10
// seconds. The returned channel carries stdin read errors other than io.EOF.
func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error {
	// FIXME: Handle websocket send errors.
	// Currently, websocket write failures are dropped. As sending and
	// receiving are running concurrently, it's expected that some send
	// requests may fail with connection errors when connection closes.
	// Connection errors should surface in the receive paths already,
	// but I'm unsure about one-sided communication errors.
	var writeMu sync.Mutex
	// write serializes writes from the three goroutines below.
	write := func(msg *ExecStreamingInput) {
		writeMu.Lock()
		defer writeMu.Unlock()
		conn.WriteJSON(msg)
	}

	failures := make(chan error, 4)

	// stdin forwarding
	go func() {
		buf := make([]byte, 2048)
		for ctx.Err() == nil {
			msg := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
			n, readErr := s.stdin.Read(buf)

			// Whatever was read is sent before the error is examined.
			if n != 0 {
				msg.Stdin.Data = buf[:n]
				write(&msg)
			}

			switch {
			case readErr == nil:
				continue
			case readErr != io.EOF:
				failures <- readErr
				return
			case n == 0:
				// EOF with no data: tell the remote side stdin is closed.
				msg.Stdin.Close = true
				write(&msg)
				return
			}
			// EOF with data: loop again, the next read yields n == 0.
		}
	}()

	// terminal size forwarding
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case size, open := <-s.terminalSizeCh:
				if !open {
					return
				}
				write(&ExecStreamingInput{TTYSize: &size})
			}
		}
	}()

	// heartbeat: one message every 10 seconds until the context ends
	go func() {
		for {
			select {
			case <-time.After(10 * time.Second):
				write(&execStreamingInputHeartbeat)
			case <-ctx.Done():
				return
			}
		}
	}()

	return failures
}
// startReceiving launches a goroutine that reads output frames from the
// websocket until the context ends, a read fails, or an exit frame arrives.
//
// Stdout and stderr data is written to s.stdout and s.stderr as it arrives.
// The first returned channel receives the exit code once the remote side
// reports it; the second receives the read error that ended the loop. A
// normal websocket closure before the exit frame is reported as an error.
func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) {
	codes := make(chan int, 1)
	failures := make(chan error, 1)

	go func() {
		for ctx.Err() == nil {
			// Decode the next frame
			var out ExecStreamingOutput
			if err := conn.ReadJSON(&out); err != nil {
				if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
					err = fmt.Errorf("websocket closed before receiving exit code: %w", err)
				}
				failures <- err
				return
			}

			if out.Stdout != nil {
				// nothing to do when stdout is merely closing
				if data := out.Stdout.Data; len(data) != 0 {
					s.stdout.Write(data)
				}
				continue
			}

			if out.Stderr != nil {
				// nothing to do when stderr is merely closing
				if data := out.Stderr.Data; len(data) != 0 {
					s.stderr.Write(data)
				}
				continue
			}

			if out.Exited && out.Result != nil {
				codes <- out.Result.ExitCode
				return
			}

			// anything else is a heartbeat: ignore it
		}
	}()

	return codes, failures
}

View File

@@ -515,13 +515,6 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
go forwardExecInput(encoder, ws, errCh)
for {
select {
case <-ctx.Done():
errCh <- nil
return
default:
}
var res cstructs.StreamErrWrapper
err := decoder.Decode(&res)
if isClosedError(err) {

View File

@@ -7,7 +7,6 @@ import (
"io"
"reflect"
"regexp"
"strings"
"testing"
"time"
@@ -90,13 +89,7 @@ func (tc *NomadExecE2ETest) TestExecBasicResponses(f *framework.F) {
stdin, &stdout, &stderr,
resizeCh, nil)
// TODO: Occasionally, we get "Unexpected EOF" error, but with the correct output.
// investigate why
if err != nil && strings.Contains(err.Error(), io.ErrUnexpectedEOF.Error()) {
f.T().Logf("got unexpected EOF error, ignoring: %v", err)
} else {
assert.NoError(t, err)
}
assert.NoError(t, err)
assert.Equal(t, c.ExitCode, exitCode)

View File

@@ -2,16 +2,10 @@ package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
)
var (
@@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context,
stdin io.Reader, stdout, stderr io.Writer,
terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) {
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
s := &execSession{
client: a.client,
alloc: alloc,
task: task,
tty: tty,
command: command,
errCh := make(chan error, 4)
stdin: stdin,
stdout: stdout,
stderr: stderr,
sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q)
select {
case err := <-errCh:
return -2, err
default:
terminalSizeCh: terminalSizeCh,
q: q,
}
// Errors resulting from sending input (in goroutines) are silently dropped.
// To mitigate this, extra care is needed to distinguish between actual send errors
// and from send errors due to command terminating and our race to detect failures.
// If we have an actual network failure or send a bad input, we'd get an
// error in the reading side of websocket.
go func() {
bytes := make([]byte, 2048)
for {
if ctx.Err() != nil {
return
}
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
n, err := stdin.Read(bytes)
// always send data if we read some
if n != 0 {
input.Stdin.Data = bytes[:n]
sender(&input)
}
// then handle error
if err == io.EOF {
// if n != 0, send data and we'll get n = 0 on next read
if n == 0 {
input.Stdin.Close = true
sender(&input)
return
}
} else if err != nil {
errCh <- err
return
}
}
}()
// forwarding terminal size
go func() {
for {
resizeInput := ExecStreamingInput{}
select {
case <-ctx.Done():
return
case size, ok := <-terminalSizeCh:
if !ok {
return
}
resizeInput.TTYSize = &size
sender(&resizeInput)
}
}
}()
// send a heartbeat every 10 seconds
go func() {
for {
select {
case <-ctx.Done():
return
// heartbeat message
case <-time.After(10 * time.Second):
sender(&execStreamingInputHeartbeat)
}
}
}()
for {
select {
case err := <-errCh:
// drop websocket code, not relevant to user
if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" {
return -2, errors.New(wsErr.Text)
}
return -2, err
case <-ctx.Done():
return -2, ctx.Err()
case frame, ok := <-output:
if !ok {
return -2, errors.New("disconnected without receiving the exit code")
}
switch {
case frame.Stdout != nil:
if len(frame.Stdout.Data) != 0 {
stdout.Write(frame.Stdout.Data)
}
// don't really do anything if stdout is closing
case frame.Stderr != nil:
if len(frame.Stderr.Data) != 0 {
stderr.Write(frame.Stderr.Data)
}
// don't really do anything if stderr is closing
case frame.Exited && frame.Result != nil:
return frame.Result.ExitCode, nil
default:
// noop - heartbeat
}
}
}
}
func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string,
errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) {
nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q)
if q == nil {
q = &QueryOptions{}
}
if q.Params == nil {
q.Params = make(map[string]string)
}
commandBytes, err := json.Marshal(command)
if err != nil {
errCh <- fmt.Errorf("failed to marshal command: %s", err)
return nil, nil
}
q.Params["tty"] = strconv.FormatBool(tty)
q.Params["task"] = task
q.Params["command"] = string(commandBytes)
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID)
var conn *websocket.Conn
if nodeClient != nil {
conn, _, _ = nodeClient.websocket(reqPath, q)
}
if conn == nil {
conn, _, err = a.client.websocket(reqPath, q)
if err != nil {
errCh <- err
return nil, nil
}
}
// Create the output channel
frames := make(chan *ExecStreamingOutput, 10)
go func() {
defer conn.Close()
for ctx.Err() == nil {
// Decode the next frame
var frame ExecStreamingOutput
err := conn.ReadJSON(&frame)
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
close(frames)
return
} else if err != nil {
errCh <- err
return
}
frames <- &frame
}
}()
var sendLock sync.Mutex
send := func(v *ExecStreamingInput) error {
sendLock.Lock()
defer sendLock.Unlock()
return conn.WriteJSON(v)
}
return send, frames
return s.run(ctx)
}
func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) {

View File

@@ -0,0 +1,236 @@
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
)
// execSession bundles everything needed to run one exec request against an
// allocation's task over a websocket connection.
type execSession struct {
	// client is used to look up a node-local client and to open the
	// websocket connection.
	client *Client
	// alloc is the target allocation; its NodeID and ID are used when
	// connecting.
	alloc *Allocation

	// task, tty and command are sent as the "task", "tty" and "command"
	// query parameters of the exec request (command is JSON-encoded).
	task    string
	tty     bool
	command []string

	// stdin is read and forwarded to the remote side; data from stdout and
	// stderr frames received from the remote side is written to stdout and
	// stderr respectively.
	stdin          io.Reader
	stdout, stderr io.Writer

	// terminalSizeCh delivers terminal size updates that are forwarded to
	// the remote side.
	terminalSizeCh <-chan TerminalSize

	// q holds the query options for the request; may be nil.
	q *QueryOptions
}
// run executes the session: it opens the websocket connection, starts the
// sending and receiving goroutines, and blocks until the first terminal event.
//
// It returns the exit code reported by the remote side on success. On any
// failure (connection error, context cancellation, receive error, or stdin
// read error) it returns -2 together with the error.
func (s *execSession) run(ctx context.Context) (exitCode int, err error) {
	// The derived context is cancelled on return so the helper goroutines
	// that watch it stop once the session is over.
	ctx, stop := context.WithCancel(ctx)
	defer stop()

	ws, err := s.startConnection()
	if err != nil {
		return -2, err
	}
	defer ws.Close()

	inputErrs := s.startTransmit(ctx, ws)
	exited, outputErrs := s.startReceiving(ctx, ws)

	// Every branch terminates the session, so a single select suffices.
	select {
	case code := <-exited:
		return code, nil

	case <-ctx.Done():
		return -2, ctx.Err()

	case e := <-inputErrs:
		return -2, fmt.Errorf("failed to send input: %w", e)

	case e := <-outputErrs:
		// drop websocket code, not relevant to user
		closeErr, isClose := e.(*websocket.CloseError)
		if isClose && closeErr.Text != "" {
			return -2, errors.New(closeErr.Text)
		}
		return -2, e
	}
}
// startConnection opens the websocket connection used for the exec session.
//
// It first tries to connect to the allocation's node directly; any failure of
// that attempt is ignored and the connection is retried through s.client
// (server-side forwarding). The only error from the node lookup that aborts
// the attempt is NodeDownErr.
//
// The "tty", "task" and "command" request parameters are written into the
// query options' Params map. NOTE: when s.q is non-nil, this populates the
// caller-provided QueryOptions in place.
//
// It returns the established connection, or a non-nil error when the node is
// down, the command cannot be JSON-encoded, or the fallback connection fails.
func (s *execSession) startConnection() (*websocket.Conn, error) {
	// First, attempt to connect to the node directly, but may fail due to network isolation
	// and network errors. Fallback to using server-side forwarding instead.
	nodeClient, err := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q)
	if err == NodeDownErr {
		return nil, NodeDownErr
	}

	q := s.q
	if q == nil {
		q = &QueryOptions{}
	}
	if q.Params == nil {
		q.Params = make(map[string]string)
	}

	commandBytes, err := json.Marshal(s.command)
	if err != nil {
		// %w (lowercase) is the wrapping verb; the previous %W was not a
		// valid verb and produced a malformed, unwrapped message.
		return nil, fmt.Errorf("failed to marshal command: %w", err)
	}

	q.Params["tty"] = strconv.FormatBool(s.tty)
	q.Params["task"] = s.task
	q.Params["command"] = string(commandBytes)

	reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID)

	var conn *websocket.Conn

	if nodeClient != nil {
		// Errors from the direct attempt are intentionally discarded: a nil
		// conn triggers the fallback below.
		conn, _, _ = nodeClient.websocket(reqPath, q)
	}

	if conn == nil {
		conn, _, err = s.client.websocket(reqPath, q)
		if err != nil {
			return nil, err
		}
	}

	return conn, nil
}
// startTransmit launches the goroutines that push client-side input over the
// websocket: stdin data, terminal size updates, and a heartbeat every 10
// seconds. The returned channel carries stdin read errors other than io.EOF.
func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error {
	// FIXME: Handle websocket send errors.
	// Currently, websocket write failures are dropped. As sending and
	// receiving are running concurrently, it's expected that some send
	// requests may fail with connection errors when connection closes.
	// Connection errors should surface in the receive paths already,
	// but I'm unsure about one-sided communication errors.
	var writeMu sync.Mutex
	// write serializes writes from the three goroutines below.
	write := func(msg *ExecStreamingInput) {
		writeMu.Lock()
		defer writeMu.Unlock()
		conn.WriteJSON(msg)
	}

	failures := make(chan error, 4)

	// stdin forwarding
	go func() {
		buf := make([]byte, 2048)
		for ctx.Err() == nil {
			msg := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
			n, readErr := s.stdin.Read(buf)

			// Whatever was read is sent before the error is examined.
			if n != 0 {
				msg.Stdin.Data = buf[:n]
				write(&msg)
			}

			switch {
			case readErr == nil:
				continue
			case readErr != io.EOF:
				failures <- readErr
				return
			case n == 0:
				// EOF with no data: tell the remote side stdin is closed.
				msg.Stdin.Close = true
				write(&msg)
				return
			}
			// EOF with data: loop again, the next read yields n == 0.
		}
	}()

	// terminal size forwarding
	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case size, open := <-s.terminalSizeCh:
				if !open {
					return
				}
				write(&ExecStreamingInput{TTYSize: &size})
			}
		}
	}()

	// heartbeat: one message every 10 seconds until the context ends
	go func() {
		for {
			select {
			case <-time.After(10 * time.Second):
				write(&execStreamingInputHeartbeat)
			case <-ctx.Done():
				return
			}
		}
	}()

	return failures
}
// startReceiving launches a goroutine that reads output frames from the
// websocket until the context ends, a read fails, or an exit frame arrives.
//
// Stdout and stderr data is written to s.stdout and s.stderr as it arrives.
// The first returned channel receives the exit code once the remote side
// reports it; the second receives the read error that ended the loop. A
// normal websocket closure before the exit frame is reported as an error.
func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) {
	codes := make(chan int, 1)
	failures := make(chan error, 1)

	go func() {
		for ctx.Err() == nil {
			// Decode the next frame
			var out ExecStreamingOutput
			if err := conn.ReadJSON(&out); err != nil {
				if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
					err = fmt.Errorf("websocket closed before receiving exit code: %w", err)
				}
				failures <- err
				return
			}

			if out.Stdout != nil {
				// nothing to do when stdout is merely closing
				if data := out.Stdout.Data; len(data) != 0 {
					s.stdout.Write(data)
				}
				continue
			}

			if out.Stderr != nil {
				// nothing to do when stderr is merely closing
				if data := out.Stderr.Data; len(data) != 0 {
					s.stderr.Write(data)
				}
				continue
			}

			if out.Exited && out.Result != nil {
				codes <- out.Result.ExitCode
				return
			}

			// anything else is a heartbeat: ignore it
		}
	}()

	return codes, failures
}