mirror of
https://github.com/kemko/nomad.git
synced 2026-01-06 10:25:42 +03:00
Merge pull request #10657 from hashicorp/b-alloc-exec-closing
Handle `nomad exec` termination events in order
This commit is contained in:
@@ -2,16 +2,10 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context,
|
||||
stdin io.Reader, stdout, stderr io.Writer,
|
||||
terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) {
|
||||
|
||||
ctx, cancelFn := context.WithCancel(ctx)
|
||||
defer cancelFn()
|
||||
s := &execSession{
|
||||
client: a.client,
|
||||
alloc: alloc,
|
||||
task: task,
|
||||
tty: tty,
|
||||
command: command,
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
|
||||
sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return -2, err
|
||||
default:
|
||||
terminalSizeCh: terminalSizeCh,
|
||||
q: q,
|
||||
}
|
||||
|
||||
// Errors resulting from sending input (in goroutines) are silently dropped.
|
||||
// To mitigate this, extra care is needed to distinguish between actual send errors
|
||||
// and from send errors due to command terminating and our race to detect failures.
|
||||
// If we have an actual network failure or send a bad input, we'd get an
|
||||
// error in the reading side of websocket.
|
||||
|
||||
go func() {
|
||||
|
||||
bytes := make([]byte, 2048)
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
|
||||
|
||||
n, err := stdin.Read(bytes)
|
||||
|
||||
// always send data if we read some
|
||||
if n != 0 {
|
||||
input.Stdin.Data = bytes[:n]
|
||||
sender(&input)
|
||||
}
|
||||
|
||||
// then handle error
|
||||
if err == io.EOF {
|
||||
// if n != 0, send data and we'll get n = 0 on next read
|
||||
if n == 0 {
|
||||
input.Stdin.Close = true
|
||||
sender(&input)
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// forwarding terminal size
|
||||
go func() {
|
||||
for {
|
||||
resizeInput := ExecStreamingInput{}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case size, ok := <-terminalSizeCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
resizeInput.TTYSize = &size
|
||||
sender(&resizeInput)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
// send a heartbeat every 10 seconds
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
// heartbeat message
|
||||
case <-time.After(10 * time.Second):
|
||||
sender(&execStreamingInputHeartbeat)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
// drop websocket code, not relevant to user
|
||||
if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" {
|
||||
return -2, errors.New(wsErr.Text)
|
||||
}
|
||||
return -2, err
|
||||
case <-ctx.Done():
|
||||
return -2, ctx.Err()
|
||||
case frame, ok := <-output:
|
||||
if !ok {
|
||||
return -2, errors.New("disconnected without receiving the exit code")
|
||||
}
|
||||
|
||||
switch {
|
||||
case frame.Stdout != nil:
|
||||
if len(frame.Stdout.Data) != 0 {
|
||||
stdout.Write(frame.Stdout.Data)
|
||||
}
|
||||
// don't really do anything if stdout is closing
|
||||
case frame.Stderr != nil:
|
||||
if len(frame.Stderr.Data) != 0 {
|
||||
stderr.Write(frame.Stderr.Data)
|
||||
}
|
||||
// don't really do anything if stderr is closing
|
||||
case frame.Exited && frame.Result != nil:
|
||||
return frame.Result.ExitCode, nil
|
||||
default:
|
||||
// noop - heartbeat
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string,
|
||||
errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) {
|
||||
nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q)
|
||||
|
||||
if q == nil {
|
||||
q = &QueryOptions{}
|
||||
}
|
||||
if q.Params == nil {
|
||||
q.Params = make(map[string]string)
|
||||
}
|
||||
|
||||
commandBytes, err := json.Marshal(command)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("failed to marshal command: %s", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
q.Params["tty"] = strconv.FormatBool(tty)
|
||||
q.Params["task"] = task
|
||||
q.Params["command"] = string(commandBytes)
|
||||
|
||||
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID)
|
||||
|
||||
var conn *websocket.Conn
|
||||
|
||||
if nodeClient != nil {
|
||||
conn, _, _ = nodeClient.websocket(reqPath, q)
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
conn, _, err = a.client.websocket(reqPath, q)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create the output channel
|
||||
frames := make(chan *ExecStreamingOutput, 10)
|
||||
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
for ctx.Err() == nil {
|
||||
|
||||
// Decode the next frame
|
||||
var frame ExecStreamingOutput
|
||||
err := conn.ReadJSON(&frame)
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
||||
close(frames)
|
||||
return
|
||||
} else if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
frames <- &frame
|
||||
}
|
||||
}()
|
||||
|
||||
var sendLock sync.Mutex
|
||||
send := func(v *ExecStreamingInput) error {
|
||||
sendLock.Lock()
|
||||
defer sendLock.Unlock()
|
||||
|
||||
return conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
return send, frames
|
||||
|
||||
return s.run(ctx)
|
||||
}
|
||||
|
||||
func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) {
|
||||
|
||||
236
api/allocations_exec.go
Normal file
236
api/allocations_exec.go
Normal file
@@ -0,0 +1,236 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type execSession struct {
|
||||
client *Client
|
||||
alloc *Allocation
|
||||
task string
|
||||
tty bool
|
||||
command []string
|
||||
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
stderr io.Writer
|
||||
|
||||
terminalSizeCh <-chan TerminalSize
|
||||
|
||||
q *QueryOptions
|
||||
}
|
||||
|
||||
func (s *execSession) run(ctx context.Context) (exitCode int, err error) {
|
||||
ctx, cancelFn := context.WithCancel(ctx)
|
||||
defer cancelFn()
|
||||
|
||||
conn, err := s.startConnection()
|
||||
if err != nil {
|
||||
return -2, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
sendErrCh := s.startTransmit(ctx, conn)
|
||||
exitCh, recvErrCh := s.startReceiving(ctx, conn)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return -2, ctx.Err()
|
||||
case exitCode := <-exitCh:
|
||||
return exitCode, nil
|
||||
case recvErr := <-recvErrCh:
|
||||
// drop websocket code, not relevant to user
|
||||
if wsErr, ok := recvErr.(*websocket.CloseError); ok && wsErr.Text != "" {
|
||||
return -2, errors.New(wsErr.Text)
|
||||
}
|
||||
|
||||
return -2, recvErr
|
||||
case sendErr := <-sendErrCh:
|
||||
return -2, fmt.Errorf("failed to send input: %w", sendErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *execSession) startConnection() (*websocket.Conn, error) {
|
||||
// First, attempt to connect to the node directly, but may fail due to network isolation
|
||||
// and network errors. Fallback to using server-side forwarding instead.
|
||||
nodeClient, err := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q)
|
||||
if err == NodeDownErr {
|
||||
return nil, NodeDownErr
|
||||
}
|
||||
|
||||
q := s.q
|
||||
if q == nil {
|
||||
q = &QueryOptions{}
|
||||
}
|
||||
if q.Params == nil {
|
||||
q.Params = make(map[string]string)
|
||||
}
|
||||
|
||||
commandBytes, err := json.Marshal(s.command)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal command: %W", err)
|
||||
}
|
||||
|
||||
q.Params["tty"] = strconv.FormatBool(s.tty)
|
||||
q.Params["task"] = s.task
|
||||
q.Params["command"] = string(commandBytes)
|
||||
|
||||
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID)
|
||||
|
||||
var conn *websocket.Conn
|
||||
|
||||
if nodeClient != nil {
|
||||
conn, _, _ = nodeClient.websocket(reqPath, q)
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
conn, _, err = s.client.websocket(reqPath, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error {
|
||||
|
||||
// FIXME: Handle websocket send errors.
|
||||
// Currently, websocket write failures are dropped. As sending and
|
||||
// receiving are running concurrently, it's expected that some send
|
||||
// requests may fail with connection errors when connection closes.
|
||||
// Connection errors should surface in the receive paths already,
|
||||
// but I'm unsure about one-sided communication errors.
|
||||
var sendLock sync.Mutex
|
||||
send := func(v *ExecStreamingInput) {
|
||||
sendLock.Lock()
|
||||
defer sendLock.Unlock()
|
||||
|
||||
conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
|
||||
// propagate stdin
|
||||
go func() {
|
||||
|
||||
bytes := make([]byte, 2048)
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
|
||||
|
||||
n, err := s.stdin.Read(bytes)
|
||||
|
||||
// always send data if we read some
|
||||
if n != 0 {
|
||||
input.Stdin.Data = bytes[:n]
|
||||
send(&input)
|
||||
}
|
||||
|
||||
// then handle error
|
||||
if err == io.EOF {
|
||||
// if n != 0, send data and we'll get n = 0 on next read
|
||||
if n == 0 {
|
||||
input.Stdin.Close = true
|
||||
send(&input)
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// propagate terminal sizing updates
|
||||
go func() {
|
||||
for {
|
||||
resizeInput := ExecStreamingInput{}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case size, ok := <-s.terminalSizeCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
resizeInput.TTYSize = &size
|
||||
send(&resizeInput)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
// send a heartbeat every 10 seconds
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
// heartbeat message
|
||||
case <-time.After(10 * time.Second):
|
||||
send(&execStreamingInputHeartbeat)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) {
|
||||
exitCodeCh := make(chan int, 1)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for ctx.Err() == nil {
|
||||
|
||||
// Decode the next frame
|
||||
var frame ExecStreamingOutput
|
||||
err := conn.ReadJSON(&frame)
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
||||
errCh <- fmt.Errorf("websocket closed before receiving exit code: %w", err)
|
||||
return
|
||||
} else if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case frame.Stdout != nil:
|
||||
if len(frame.Stdout.Data) != 0 {
|
||||
s.stdout.Write(frame.Stdout.Data)
|
||||
}
|
||||
// don't really do anything if stdout is closing
|
||||
case frame.Stderr != nil:
|
||||
if len(frame.Stderr.Data) != 0 {
|
||||
s.stderr.Write(frame.Stderr.Data)
|
||||
}
|
||||
// don't really do anything if stderr is closing
|
||||
case frame.Exited && frame.Result != nil:
|
||||
exitCodeCh <- frame.Result.ExitCode
|
||||
return
|
||||
default:
|
||||
// noop - heartbeat
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
return exitCodeCh, errCh
|
||||
}
|
||||
@@ -515,13 +515,6 @@ func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExec
|
||||
go forwardExecInput(encoder, ws, errCh)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
errCh <- nil
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
var res cstructs.StreamErrWrapper
|
||||
err := decoder.Decode(&res)
|
||||
if isClosedError(err) {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"io"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -90,13 +89,7 @@ func (tc *NomadExecE2ETest) TestExecBasicResponses(f *framework.F) {
|
||||
stdin, &stdout, &stderr,
|
||||
resizeCh, nil)
|
||||
|
||||
// TODO: Occasionally, we get "Unexpected EOF" error, but with the correct output.
|
||||
// investigate why
|
||||
if err != nil && strings.Contains(err.Error(), io.ErrUnexpectedEOF.Error()) {
|
||||
f.T().Logf("got unexpected EOF error, ignoring: %v", err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, c.ExitCode, exitCode)
|
||||
|
||||
|
||||
203
vendor/github.com/hashicorp/nomad/api/allocations.go
generated
vendored
203
vendor/github.com/hashicorp/nomad/api/allocations.go
generated
vendored
@@ -2,16 +2,10 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"sort"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -87,195 +81,22 @@ func (a *Allocations) Exec(ctx context.Context,
|
||||
stdin io.Reader, stdout, stderr io.Writer,
|
||||
terminalSizeCh <-chan TerminalSize, q *QueryOptions) (exitCode int, err error) {
|
||||
|
||||
ctx, cancelFn := context.WithCancel(ctx)
|
||||
defer cancelFn()
|
||||
s := &execSession{
|
||||
client: a.client,
|
||||
alloc: alloc,
|
||||
task: task,
|
||||
tty: tty,
|
||||
command: command,
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
stdin: stdin,
|
||||
stdout: stdout,
|
||||
stderr: stderr,
|
||||
|
||||
sender, output := a.execFrames(ctx, alloc, task, tty, command, errCh, q)
|
||||
|
||||
select {
|
||||
case err := <-errCh:
|
||||
return -2, err
|
||||
default:
|
||||
terminalSizeCh: terminalSizeCh,
|
||||
q: q,
|
||||
}
|
||||
|
||||
// Errors resulting from sending input (in goroutines) are silently dropped.
|
||||
// To mitigate this, extra care is needed to distinguish between actual send errors
|
||||
// and from send errors due to command terminating and our race to detect failures.
|
||||
// If we have an actual network failure or send a bad input, we'd get an
|
||||
// error in the reading side of websocket.
|
||||
|
||||
go func() {
|
||||
|
||||
bytes := make([]byte, 2048)
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
|
||||
|
||||
n, err := stdin.Read(bytes)
|
||||
|
||||
// always send data if we read some
|
||||
if n != 0 {
|
||||
input.Stdin.Data = bytes[:n]
|
||||
sender(&input)
|
||||
}
|
||||
|
||||
// then handle error
|
||||
if err == io.EOF {
|
||||
// if n != 0, send data and we'll get n = 0 on next read
|
||||
if n == 0 {
|
||||
input.Stdin.Close = true
|
||||
sender(&input)
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// forwarding terminal size
|
||||
go func() {
|
||||
for {
|
||||
resizeInput := ExecStreamingInput{}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case size, ok := <-terminalSizeCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
resizeInput.TTYSize = &size
|
||||
sender(&resizeInput)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
// send a heartbeat every 10 seconds
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
// heartbeat message
|
||||
case <-time.After(10 * time.Second):
|
||||
sender(&execStreamingInputHeartbeat)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case err := <-errCh:
|
||||
// drop websocket code, not relevant to user
|
||||
if wsErr, ok := err.(*websocket.CloseError); ok && wsErr.Text != "" {
|
||||
return -2, errors.New(wsErr.Text)
|
||||
}
|
||||
return -2, err
|
||||
case <-ctx.Done():
|
||||
return -2, ctx.Err()
|
||||
case frame, ok := <-output:
|
||||
if !ok {
|
||||
return -2, errors.New("disconnected without receiving the exit code")
|
||||
}
|
||||
|
||||
switch {
|
||||
case frame.Stdout != nil:
|
||||
if len(frame.Stdout.Data) != 0 {
|
||||
stdout.Write(frame.Stdout.Data)
|
||||
}
|
||||
// don't really do anything if stdout is closing
|
||||
case frame.Stderr != nil:
|
||||
if len(frame.Stderr.Data) != 0 {
|
||||
stderr.Write(frame.Stderr.Data)
|
||||
}
|
||||
// don't really do anything if stderr is closing
|
||||
case frame.Exited && frame.Result != nil:
|
||||
return frame.Result.ExitCode, nil
|
||||
default:
|
||||
// noop - heartbeat
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Allocations) execFrames(ctx context.Context, alloc *Allocation, task string, tty bool, command []string,
|
||||
errCh chan<- error, q *QueryOptions) (sendFn func(*ExecStreamingInput) error, output <-chan *ExecStreamingOutput) {
|
||||
nodeClient, _ := a.client.GetNodeClientWithTimeout(alloc.NodeID, ClientConnTimeout, q)
|
||||
|
||||
if q == nil {
|
||||
q = &QueryOptions{}
|
||||
}
|
||||
if q.Params == nil {
|
||||
q.Params = make(map[string]string)
|
||||
}
|
||||
|
||||
commandBytes, err := json.Marshal(command)
|
||||
if err != nil {
|
||||
errCh <- fmt.Errorf("failed to marshal command: %s", err)
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
q.Params["tty"] = strconv.FormatBool(tty)
|
||||
q.Params["task"] = task
|
||||
q.Params["command"] = string(commandBytes)
|
||||
|
||||
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", alloc.ID)
|
||||
|
||||
var conn *websocket.Conn
|
||||
|
||||
if nodeClient != nil {
|
||||
conn, _, _ = nodeClient.websocket(reqPath, q)
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
conn, _, err = a.client.websocket(reqPath, q)
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Create the output channel
|
||||
frames := make(chan *ExecStreamingOutput, 10)
|
||||
|
||||
go func() {
|
||||
defer conn.Close()
|
||||
for ctx.Err() == nil {
|
||||
|
||||
// Decode the next frame
|
||||
var frame ExecStreamingOutput
|
||||
err := conn.ReadJSON(&frame)
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
||||
close(frames)
|
||||
return
|
||||
} else if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
frames <- &frame
|
||||
}
|
||||
}()
|
||||
|
||||
var sendLock sync.Mutex
|
||||
send := func(v *ExecStreamingInput) error {
|
||||
sendLock.Lock()
|
||||
defer sendLock.Unlock()
|
||||
|
||||
return conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
return send, frames
|
||||
|
||||
return s.run(ctx)
|
||||
}
|
||||
|
||||
func (a *Allocations) Stats(alloc *Allocation, q *QueryOptions) (*AllocResourceUsage, error) {
|
||||
|
||||
236
vendor/github.com/hashicorp/nomad/api/allocations_exec.go
generated
vendored
Normal file
236
vendor/github.com/hashicorp/nomad/api/allocations_exec.go
generated
vendored
Normal file
@@ -0,0 +1,236 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type execSession struct {
|
||||
client *Client
|
||||
alloc *Allocation
|
||||
task string
|
||||
tty bool
|
||||
command []string
|
||||
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
stderr io.Writer
|
||||
|
||||
terminalSizeCh <-chan TerminalSize
|
||||
|
||||
q *QueryOptions
|
||||
}
|
||||
|
||||
func (s *execSession) run(ctx context.Context) (exitCode int, err error) {
|
||||
ctx, cancelFn := context.WithCancel(ctx)
|
||||
defer cancelFn()
|
||||
|
||||
conn, err := s.startConnection()
|
||||
if err != nil {
|
||||
return -2, err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
sendErrCh := s.startTransmit(ctx, conn)
|
||||
exitCh, recvErrCh := s.startReceiving(ctx, conn)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return -2, ctx.Err()
|
||||
case exitCode := <-exitCh:
|
||||
return exitCode, nil
|
||||
case recvErr := <-recvErrCh:
|
||||
// drop websocket code, not relevant to user
|
||||
if wsErr, ok := recvErr.(*websocket.CloseError); ok && wsErr.Text != "" {
|
||||
return -2, errors.New(wsErr.Text)
|
||||
}
|
||||
|
||||
return -2, recvErr
|
||||
case sendErr := <-sendErrCh:
|
||||
return -2, fmt.Errorf("failed to send input: %w", sendErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *execSession) startConnection() (*websocket.Conn, error) {
|
||||
// First, attempt to connect to the node directly, but may fail due to network isolation
|
||||
// and network errors. Fallback to using server-side forwarding instead.
|
||||
nodeClient, err := s.client.GetNodeClientWithTimeout(s.alloc.NodeID, ClientConnTimeout, s.q)
|
||||
if err == NodeDownErr {
|
||||
return nil, NodeDownErr
|
||||
}
|
||||
|
||||
q := s.q
|
||||
if q == nil {
|
||||
q = &QueryOptions{}
|
||||
}
|
||||
if q.Params == nil {
|
||||
q.Params = make(map[string]string)
|
||||
}
|
||||
|
||||
commandBytes, err := json.Marshal(s.command)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal command: %W", err)
|
||||
}
|
||||
|
||||
q.Params["tty"] = strconv.FormatBool(s.tty)
|
||||
q.Params["task"] = s.task
|
||||
q.Params["command"] = string(commandBytes)
|
||||
|
||||
reqPath := fmt.Sprintf("/v1/client/allocation/%s/exec", s.alloc.ID)
|
||||
|
||||
var conn *websocket.Conn
|
||||
|
||||
if nodeClient != nil {
|
||||
conn, _, _ = nodeClient.websocket(reqPath, q)
|
||||
}
|
||||
|
||||
if conn == nil {
|
||||
conn, _, err = s.client.websocket(reqPath, q)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (s *execSession) startTransmit(ctx context.Context, conn *websocket.Conn) <-chan error {
|
||||
|
||||
// FIXME: Handle websocket send errors.
|
||||
// Currently, websocket write failures are dropped. As sending and
|
||||
// receiving are running concurrently, it's expected that some send
|
||||
// requests may fail with connection errors when connection closes.
|
||||
// Connection errors should surface in the receive paths already,
|
||||
// but I'm unsure about one-sided communication errors.
|
||||
var sendLock sync.Mutex
|
||||
send := func(v *ExecStreamingInput) {
|
||||
sendLock.Lock()
|
||||
defer sendLock.Unlock()
|
||||
|
||||
conn.WriteJSON(v)
|
||||
}
|
||||
|
||||
errCh := make(chan error, 4)
|
||||
|
||||
// propagate stdin
|
||||
go func() {
|
||||
|
||||
bytes := make([]byte, 2048)
|
||||
for {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
input := ExecStreamingInput{Stdin: &ExecStreamingIOOperation{}}
|
||||
|
||||
n, err := s.stdin.Read(bytes)
|
||||
|
||||
// always send data if we read some
|
||||
if n != 0 {
|
||||
input.Stdin.Data = bytes[:n]
|
||||
send(&input)
|
||||
}
|
||||
|
||||
// then handle error
|
||||
if err == io.EOF {
|
||||
// if n != 0, send data and we'll get n = 0 on next read
|
||||
if n == 0 {
|
||||
input.Stdin.Close = true
|
||||
send(&input)
|
||||
return
|
||||
}
|
||||
} else if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// propagate terminal sizing updates
|
||||
go func() {
|
||||
for {
|
||||
resizeInput := ExecStreamingInput{}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case size, ok := <-s.terminalSizeCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
resizeInput.TTYSize = &size
|
||||
send(&resizeInput)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
// send a heartbeat every 10 seconds
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
// heartbeat message
|
||||
case <-time.After(10 * time.Second):
|
||||
send(&execStreamingInputHeartbeat)
|
||||
}
|
||||
|
||||
}
|
||||
}()
|
||||
|
||||
return errCh
|
||||
}
|
||||
|
||||
func (s *execSession) startReceiving(ctx context.Context, conn *websocket.Conn) (<-chan int, <-chan error) {
|
||||
exitCodeCh := make(chan int, 1)
|
||||
errCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
for ctx.Err() == nil {
|
||||
|
||||
// Decode the next frame
|
||||
var frame ExecStreamingOutput
|
||||
err := conn.ReadJSON(&frame)
|
||||
if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
|
||||
errCh <- fmt.Errorf("websocket closed before receiving exit code: %w", err)
|
||||
return
|
||||
} else if err != nil {
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
switch {
|
||||
case frame.Stdout != nil:
|
||||
if len(frame.Stdout.Data) != 0 {
|
||||
s.stdout.Write(frame.Stdout.Data)
|
||||
}
|
||||
// don't really do anything if stdout is closing
|
||||
case frame.Stderr != nil:
|
||||
if len(frame.Stderr.Data) != 0 {
|
||||
s.stderr.Write(frame.Stderr.Data)
|
||||
}
|
||||
// don't really do anything if stderr is closing
|
||||
case frame.Exited && frame.Result != nil:
|
||||
exitCodeCh <- frame.Result.ExitCode
|
||||
return
|
||||
default:
|
||||
// noop - heartbeat
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}()
|
||||
|
||||
return exitCodeCh, errCh
|
||||
}
|
||||
Reference in New Issue
Block a user