exec: api: handle closing errors differently

refactor the api handling of `nomad exec`, and ensure that we process
all received events before handling websocket closing.

The exit code should be the last message received, and we ought to
ignore any websocket close error we receive afterwards.

Previously, we used two channels: one for websocket frames and another
for handling errors. This raised the possibility that we processed the
error before processing the frames, resulting into an "unexpected EOF"
error.
This commit is contained in:
Mahmood Ali
2021-05-24 14:52:00 -04:00
parent ab4b42f4f4
commit a15a61759e
4 changed files with 486 additions and 382 deletions

View File

@@ -2,16 +2,10 @@ package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
)
var (
@@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context,
stdin io.Reader, stdout, stderr io.Writer,
terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) {
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
s := &execSession{
client: a.client,
alloc: alloc,
task: task,
tty: tty,
command: command,
errCh := make(chan error, 4)
stdin: stdin,
stdout: stdout,
stderr: stderr,
sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q)
select {
case err := <-errCh:
return -2, err
default:
terminalSizeCh: terminalSizeCh,
q: q,
}
// Errors resulting from sending input (in goroutines) are silently dropped.
// To mitigate this, extra care is needed to distinguish between actual send errors
// and from send errors due to command terminating and our race to detect failures.
// If we have an actual network failure or send a bad input, we'd get an
// error in the reading side of websocket.
go func() {
bytes := make([]byte, 2048)
for {
if ctx.Err() != nil {
return
}
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
n, err := stdin.Read(bytes)
// always send data if we read some
if n != 0 {
input.Stdin.Data = bytes[:n]
sender(&input)
}
// then handle error
if err == io.EOF {
// if n != 0, send data and we'll get n = 0 on next read
if n == 0 {
input.Stdin.Close = true
sender(&input)
return
}
} else if err != nil {
errCh <- err
return
}
}
}()
// forwarding terminal size
go func() {
for {
resizeInput := ExecStreamingInput{}
select {
case <-ctx.Done():
return
case size, ok := <-terminalSizeCh:
if !ok {
return
}
resizeInput.TTYSize = &size
sender(&resizeInput)
}
}
}()
// send a heartbeat every 10 seconds
go func() {
for {
select {
case <-ctx.Done():
return
// heartbeat message
case <-time.After(10 * time.Second):
sender(&execStreamingInputHeartbeat)
}
}
}()
for {
select {
case err := <-errCh:
// drop websocket code, not relevant to user
if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" {
return -2, errors.New(wsErr.Text)
}
return -2, err
case <-ctx.Done():
return -2, ctx.Err()
case frame, ok := <-output:
if !ok {
return -2, errors.New("disconnected without receiving the exit code")
}
switch {
case frame.Stdout != nil:
if len(frame.Stdout.Data) != 0 {
stdout.Write(frame.Stdout.Data)
}
// don't really do anything if stdout is closing
case frame.Stderr != nil:
if len(frame.Stderr.Data) != 0 {
stderr.Write(frame.Stderr.Data)
}
// don't really do anything if stderr is closing
case frame.Exited && frame.Result != nil:
return frame.Result.ExitCode, nil
default:
// noop - heartbeat
}
}
}
}
func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string,
errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) {
nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q)
if q == nil {
q = &QueryOptions{}
}
if q.Params == nil {
q.Params = make(map[string]string)
}
commandBytes, err := json.Marshal(command)
if err != nil {
errCh <- fmt.Errorf("failed to marshal command: %s", err)
return nil, nil
}
q.Params["tty"] = strconv.FormatBool(tty)
q.Params["task"] = task
q.Params["command"] = string(commandBytes)
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID)
var conn *websocket.Conn
if nodeClient != nil {
conn, _, _ = nodeClient.websocket(reqPath, q)
}
if conn == nil {
conn, _, err = a.client.websocket(reqPath, q)
if err != nil {
errCh <- err
return nil, nil
}
}
// Create the output channel
frames := make(chan *ExecStreamingOutput, 10)
go func() {
defer conn.Close()
for ctx.Err() == nil {
// Decode the next frame
var frame ExecStreamingOutput
err := conn.ReadJSON(&frame)
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
close(frames)
return
} else if err != nil {
errCh <- err
return
}
frames <- &frame
}
}()
var sendLock sync.Mutex
send := func(v *ExecStreamingInput) error {
sendLock.Lock()
defer sendLock.Unlock()
return conn.WriteJSON(v)
}
return send, frames
return s.run(ctx)
}
func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) {

231
api/allocations_exec.go Normal file
View File

@@ -0,0 +1,231 @@
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
)
type execSession struct {
client *Client
alloc *Allocation
task string
tty bool
command []string
stdin io.Reader
stdout io.Writer
stderr io.Writer
terminalSizeCh <-chan TerminalSize
q *QueryOptions
}
func (s *execSession) run(ctx context.Context) (exitCode int, err error) {
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
conn, err := s.startConnection()
if err != nil {
return -2, err
}
defer conn.Close()
sendErrCh := s.startTransmit(ctx, conn)
exitCh, recvErrCh := s.startReceiving(ctx, conn)
for {
select {
case <-ctx.Done():
return -2, ctx.Err()
case exitCode := <-exitCh:
return exitCode, nil
case recvErr := <-recvErrCh:
// drop websocket code, not relevant to user
if wsErr, ok := recvErr.(*websocket.CloseError); ok && wsErr.Text != "" {
return -2, errors.New(wsErr.Text)
}
return -2, recvErr
case sendErr := <-sendErrCh:
return -2, fmt.Errorf("failed to send input: %w", sendErr)
}
}
}
func (s *execSession) startConnection() (*websocket.Conn, error) {
nodeClient, _ := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q)
q := s.q
if q == nil {
q = &QueryOptions{}
}
if q.Params == nil {
q.Params = make(map[string]string)
}
commandBytes, err := json.Marshal(s.command)
if err != nil {
return nil, fmt.Errorf("failed to marshal command: %W", err)
}
q.Params["tty"] = strconv.FormatBool(s.tty)
q.Params["task"] = s.task
q.Params["command"] = string(commandBytes)
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID)
var conn *websocket.Conn
if nodeClient != nil {
conn, _, _ = nodeClient.websocket(reqPath, q)
}
if conn == nil {
conn, _, err = s.client.websocket(reqPath, q)
if err != nil {
return nil, err
}
}
return conn, nil
}
func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error {
// FIXME: Handle websocket send errors.
// Currently, websocket write failures are dropped. As sending and
// receiving are running concurrently, it's expected that some send
// requests may fail with connection errors when connection closes.
// Connection errors should surface in the receive paths already,
// but I'm unsure about one-sided communication errors.
var sendLock sync.Mutex
send := func(v *ExecStreamingInput) {
sendLock.Lock()
defer sendLock.Unlock()
conn.WriteJSON(v)
}
errCh := make(chan error, 4)
// propagate stdin
go func() {
bytes := make([]byte, 2048)
for {
if ctx.Err() != nil {
return
}
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
n, err := s.stdin.Read(bytes)
// always send data if we read some
if n != 0 {
input.Stdin.Data = bytes[:n]
send(&input)
}
// then handle error
if err == io.EOF {
// if n != 0, send data and we'll get n = 0 on next read
if n == 0 {
input.Stdin.Close = true
send(&input)
return
}
} else if err != nil {
errCh <- err
return
}
}
}()
// propagate terminal sizing updates
go func() {
for {
resizeInput := ExecStreamingInput{}
select {
case <-ctx.Done():
return
case size, ok := <-s.terminalSizeCh:
if !ok {
return
}
resizeInput.TTYSize = &size
send(&resizeInput)
}
}
}()
// send a heartbeat every 10 seconds
go func() {
for {
select {
case <-ctx.Done():
return
// heartbeat message
case <-time.After(10 * time.Second):
send(&execStreamingInputHeartbeat)
}
}
}()
return errCh
}
func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) {
exitCodeCh := make(chan int, 1)
errCh := make(chan error, 1)
go func() {
for ctx.Err() == nil {
// Decode the next frame
var frame ExecStreamingOutput
err := conn.ReadJSON(&frame)
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
errCh <- fmt.Errorf("websocket closed before receiving exit code: %w", err)
return
} else if err != nil {
errCh <- err
return
}
switch {
case frame.Stdout != nil:
if len(frame.Stdout.Data) != 0 {
s.stdout.Write(frame.Stdout.Data)
}
// don't really do anything if stdout is closing
case frame.Stderr != nil:
if len(frame.Stderr.Data) != 0 {
s.stderr.Write(frame.Stderr.Data)
}
// don't really do anything if stderr is closing
case frame.Exited && frame.Result != nil:
exitCodeCh <- frame.Result.ExitCode
return
default:
// noop - heartbeat
}
}
}()
return exitCodeCh, errCh
}

View File

@@ -2,16 +2,10 @@ package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"sort"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
)
var (
@@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context,
stdin io.Reader, stdout, stderr io.Writer,
terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) {
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
s := &execSession{
client: a.client,
alloc: alloc,
task: task,
tty: tty,
command: command,
errCh := make(chan error, 4)
stdin: stdin,
stdout: stdout,
stderr: stderr,
sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q)
select {
case err := <-errCh:
return -2, err
default:
terminalSizeCh: terminalSizeCh,
q: q,
}
// Errors resulting from sending input (in goroutines) are silently dropped.
// To mitigate this, extra care is needed to distinguish between actual send errors
// and from send errors due to command terminating and our race to detect failures.
// If we have an actual network failure or send a bad input, we'd get an
// error in the reading side of websocket.
go func() {
bytes := make([]byte, 2048)
for {
if ctx.Err() != nil {
return
}
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
n, err := stdin.Read(bytes)
// always send data if we read some
if n != 0 {
input.Stdin.Data = bytes[:n]
sender(&input)
}
// then handle error
if err == io.EOF {
// if n != 0, send data and we'll get n = 0 on next read
if n == 0 {
input.Stdin.Close = true
sender(&input)
return
}
} else if err != nil {
errCh <- err
return
}
}
}()
// forwarding terminal size
go func() {
for {
resizeInput := ExecStreamingInput{}
select {
case <-ctx.Done():
return
case size, ok := <-terminalSizeCh:
if !ok {
return
}
resizeInput.TTYSize = &size
sender(&resizeInput)
}
}
}()
// send a heartbeat every 10 seconds
go func() {
for {
select {
case <-ctx.Done():
return
// heartbeat message
case <-time.After(10 * time.Second):
sender(&execStreamingInputHeartbeat)
}
}
}()
for {
select {
case err := <-errCh:
// drop websocket code, not relevant to user
if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" {
return -2, errors.New(wsErr.Text)
}
return -2, err
case <-ctx.Done():
return -2, ctx.Err()
case frame, ok := <-output:
if !ok {
return -2, errors.New("disconnected without receiving the exit code")
}
switch {
case frame.Stdout != nil:
if len(frame.Stdout.Data) != 0 {
stdout.Write(frame.Stdout.Data)
}
// don't really do anything if stdout is closing
case frame.Stderr != nil:
if len(frame.Stderr.Data) != 0 {
stderr.Write(frame.Stderr.Data)
}
// don't really do anything if stderr is closing
case frame.Exited && frame.Result != nil:
return frame.Result.ExitCode, nil
default:
// noop - heartbeat
}
}
}
}
func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string,
errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) {
nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q)
if q == nil {
q = &QueryOptions{}
}
if q.Params == nil {
q.Params = make(map[string]string)
}
commandBytes, err := json.Marshal(command)
if err != nil {
errCh <- fmt.Errorf("failed to marshal command: %s", err)
return nil, nil
}
q.Params["tty"] = strconv.FormatBool(tty)
q.Params["task"] = task
q.Params["command"] = string(commandBytes)
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID)
var conn *websocket.Conn
if nodeClient != nil {
conn, _, _ = nodeClient.websocket(reqPath, q)
}
if conn == nil {
conn, _, err = a.client.websocket(reqPath, q)
if err != nil {
errCh <- err
return nil, nil
}
}
// Create the output channel
frames := make(chan *ExecStreamingOutput, 10)
go func() {
defer conn.Close()
for ctx.Err() == nil {
// Decode the next frame
var frame ExecStreamingOutput
err := conn.ReadJSON(&frame)
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
close(frames)
return
} else if err != nil {
errCh <- err
return
}
frames <- &frame
}
}()
var sendLock sync.Mutex
send := func(v *ExecStreamingInput) error {
sendLock.Lock()
defer sendLock.Unlock()
return conn.WriteJSON(v)
}
return send, frames
return s.run(ctx)
}
func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) {

View File

@@ -0,0 +1,231 @@
package api
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"strconv"
"sync"
"time"
"github.com/gorilla/websocket"
)
type execSession struct {
client *Client
alloc *Allocation
task string
tty bool
command []string
stdin io.Reader
stdout io.Writer
stderr io.Writer
terminalSizeCh <-chan TerminalSize
q *QueryOptions
}
func (s *execSession) run(ctx context.Context) (exitCode int, err error) {
ctx, cancelFn := context.WithCancel(ctx)
defer cancelFn()
conn, err := s.startConnection()
if err != nil {
return -2, err
}
defer conn.Close()
sendErrCh := s.startTransmit(ctx, conn)
exitCh, recvErrCh := s.startReceiving(ctx, conn)
for {
select {
case <-ctx.Done():
return -2, ctx.Err()
case exitCode := <-exitCh:
return exitCode, nil
case recvErr := <-recvErrCh:
// drop websocket code, not relevant to user
if wsErr, ok := recvErr.(*websocket.CloseError); ok && wsErr.Text != "" {
return -2, errors.New(wsErr.Text)
}
return -2, recvErr
case sendErr := <-sendErrCh:
return -2, fmt.Errorf("failed to send input: %w", sendErr)
}
}
}
func (s *execSession) startConnection() (*websocket.Conn, error) {
nodeClient, _ := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q)
q := s.q
if q == nil {
q = &QueryOptions{}
}
if q.Params == nil {
q.Params = make(map[string]string)
}
commandBytes, err := json.Marshal(s.command)
if err != nil {
return nil, fmt.Errorf("failed to marshal command: %W", err)
}
q.Params["tty"] = strconv.FormatBool(s.tty)
q.Params["task"] = s.task
q.Params["command"] = string(commandBytes)
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID)
var conn *websocket.Conn
if nodeClient != nil {
conn, _, _ = nodeClient.websocket(reqPath, q)
}
if conn == nil {
conn, _, err = s.client.websocket(reqPath, q)
if err != nil {
return nil, err
}
}
return conn, nil
}
func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error {
// FIXME: Handle websocket send errors.
// Currently, websocket write failures are dropped. As sending and
// receiving are running concurrently, it's expected that some send
// requests may fail with connection errors when connection closes.
// Connection errors should surface in the receive paths already,
// but I'm unsure about one-sided communication errors.
var sendLock sync.Mutex
send := func(v *ExecStreamingInput) {
sendLock.Lock()
defer sendLock.Unlock()
conn.WriteJSON(v)
}
errCh := make(chan error, 4)
// propagate stdin
go func() {
bytes := make([]byte, 2048)
for {
if ctx.Err() != nil {
return
}
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
n, err := s.stdin.Read(bytes)
// always send data if we read some
if n != 0 {
input.Stdin.Data = bytes[:n]
send(&input)
}
// then handle error
if err == io.EOF {
// if n != 0, send data and we'll get n = 0 on next read
if n == 0 {
input.Stdin.Close = true
send(&input)
return
}
} else if err != nil {
errCh <- err
return
}
}
}()
// propagate terminal sizing updates
go func() {
for {
resizeInput := ExecStreamingInput{}
select {
case <-ctx.Done():
return
case size, ok := <-s.terminalSizeCh:
if !ok {
return
}
resizeInput.TTYSize = &size
send(&resizeInput)
}
}
}()
// send a heartbeat every 10 seconds
go func() {
for {
select {
case <-ctx.Done():
return
// heartbeat message
case <-time.After(10 * time.Second):
send(&execStreamingInputHeartbeat)
}
}
}()
return errCh
}
func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) {
exitCodeCh := make(chan int, 1)
errCh := make(chan error, 1)
go func() {
for ctx.Err() == nil {
// Decode the next frame
var frame ExecStreamingOutput
err := conn.ReadJSON(&frame)
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
errCh <- fmt.Errorf("websocket closed before receiving exit code: %w", err)
return
} else if err != nil {
errCh <- err
return
}
switch {
case frame.Stdout != nil:
if len(frame.Stdout.Data) != 0 {
s.stdout.Write(frame.Stdout.Data)
}
// don't really do anything if stdout is closing
case frame.Stderr != nil:
if len(frame.Stderr.Data) != 0 {
s.stderr.Write(frame.Stderr.Data)
}
// don't really do anything if stderr is closing
case frame.Exited && frame.Result != nil:
exitCodeCh <- frame.Result.ExitCode
return
default:
// noop - heartbeat
}
}
}()
return exitCodeCh, errCh
}