diff --git a/command/agent/alloc_endpoint.go b/command/agent/alloc_endpoint.go index 72e6108c2..dd06f283d 100644 --- a/command/agent/alloc_endpoint.go +++ b/command/agent/alloc_endpoint.go @@ -1,14 +1,21 @@ package agent import ( + "context" "encoding/json" "fmt" + "io" + "net" "net/http" + "strconv" "strings" "github.com/golang/snappy" + "github.com/gorilla/websocket" cstructs "github.com/hashicorp/nomad/client/structs" "github.com/hashicorp/nomad/nomad/structs" + "github.com/hashicorp/nomad/plugins/drivers" + "github.com/ugorji/go/codec" ) const ( @@ -129,6 +136,8 @@ func (s *HTTPServer) ClientAllocRequest(resp http.ResponseWriter, req *http.Requ switch tokens[1] { case "stats": return s.allocStats(allocID, resp, req) + case "exec": + return s.allocExec(allocID, resp, req) case "snapshot": if s.agent.client == nil { return nil, clientNotRunning @@ -347,3 +356,187 @@ func (s *HTTPServer) allocStats(allocID string, resp http.ResponseWriter, req *h return reply.Stats, rpcErr } + +func (s *HTTPServer) allocExec(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) { + // Build the request and parse the ACL token + task := req.URL.Query().Get("task") + cmdJsonStr := req.URL.Query().Get("command") + var command []string + err := json.Unmarshal([]byte(cmdJsonStr), &command) + if err != nil { + // this shouldn't happen, []string is always be serializable to json + return nil, fmt.Errorf("failed to marshal command into json: %v", err) + } + + ttyB := false + if tty := req.URL.Query().Get("tty"); tty != "" { + ttyB, err = strconv.ParseBool(tty) + if err != nil { + return nil, fmt.Errorf("tty value is not a boolean: %v", err) + } + } + + args := cstructs.AllocExecRequest{ + AllocID: allocID, + Task: task, + Cmd: command, + Tty: ttyB, + } + s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions) + + conn, err := s.wsUpgrader.Upgrade(resp, req, nil) + if err != nil { + return nil, fmt.Errorf("failed to upgrade connection: %v", err) + } + + return s.execStreamImpl(conn, &args) +} + +func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest) (interface{}, error) { + allocID := args.AllocID + method := "Allocations.Exec" + + // Get the correct handler + localClient, remoteClient, localServer := s.rpcHandlerForAlloc(allocID) + var handler structs.StreamingRpcHandler + var handlerErr error + if localClient { + handler, handlerErr = s.agent.Client().StreamingRpcHandler(method) + } else if remoteClient { + handler, handlerErr = s.agent.Client().RemoteStreamingRpcHandler(method) + } else if localServer { + handler, handlerErr = s.agent.Server().StreamingRpcHandler(method) + } + + if handlerErr != nil { + return nil, CodedError(500, handlerErr.Error()) + } + + // Create a pipe connecting the (possibly remote) handler to the http response + httpPipe, handlerPipe := net.Pipe() + decoder := codec.NewDecoder(httpPipe, structs.MsgpackHandle) + encoder := codec.NewEncoder(httpPipe, structs.MsgpackHandle) + + // Create a goroutine that closes the pipe if the connection closes. + ctx, cancel := context.WithCancel(context.Background()) + go func() { + <-ctx.Done() + httpPipe.Close() + + // don't close ws - wait to drain messages + }() + + // Create a channel that decodes the results + errCh := make(chan HTTPCodedError, 2) + + // stream response + go func() { + defer cancel() + + // Send the request + if err := encoder.Encode(args); err != nil { + errCh <- CodedError(500, err.Error()) + return + } + + go forwardExecInput(encoder, ws, errCh) + + for { + select { + case <-ctx.Done(): + errCh <- nil + return + default: + } + + var res cstructs.StreamErrWrapper + err := decoder.Decode(&res) + if isClosedError(err) { + ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + errCh <- nil + return + } + + if err != nil { + errCh <- CodedError(500, err.Error()) + return + } + decoder.Reset(httpPipe) + + if err := res.Error; err != nil { + code := 500 + if err.Code != nil { + code = int(*err.Code) + } + errCh <- CodedError(code, err.Error()) + return + } + + if err := ws.WriteMessage(websocket.TextMessage, res.Payload); err != nil { + errCh <- CodedError(500, err.Error()) + return + } + } + }() + + // start streaming request to streaming RPC - returns when streaming completes or errors + handler(handlerPipe) + // stop streaming background goroutines for streaming - but not websocket activity + cancel() + // retreieve any error and/or wait until goroutine stop and close errCh connection before + // closing websocket connection + codedErr := <-errCh + + if isClosedError(codedErr) { + codedErr = nil + } else if codedErr != nil { + ws.WriteMessage(websocket.CloseMessage, + websocket.FormatCloseMessage(toWsCode(codedErr.Code()), codedErr.Error())) + } + ws.Close() + + return nil, codedErr +} + +func toWsCode(httpCode int) int { + switch httpCode { + case 500: + return websocket.CloseInternalServerErr + default: + // placeholder error code + return websocket.ClosePolicyViolation + } +} + +func isClosedError(err error) bool { + if err == nil { + return false + } + + return err == io.EOF || + err == io.ErrClosedPipe || + strings.Contains(err.Error(), "closed") || + strings.Contains(err.Error(), "EOF") +} + +// forwardExecInput forwards exec input (e.g. stdin) from websocket connection +// to the streaming RPC connection to client +func forwardExecInput(encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) { + for { + sf := &drivers.ExecTaskStreamingRequestMsg{} + err := ws.ReadJSON(sf) + if err == io.EOF { + return + } + + if err != nil { + errCh <- CodedError(500, err.Error()) + return + } + + err = encoder.Encode(sf) + if err != nil { + errCh <- CodedError(500, err.Error()) + } + } +} diff --git a/command/agent/http.go b/command/agent/http.go index e1e33fa34..1bb673a2a 100644 --- a/command/agent/http.go +++ b/command/agent/http.go @@ -15,6 +15,7 @@ import ( "github.com/NYTimes/gziphandler" assetfs "github.com/elazarl/go-bindata-assetfs" + "github.com/gorilla/websocket" log "github.com/hashicorp/go-hclog" "github.com/hashicorp/nomad/helper/tlsutil" "github.com/hashicorp/nomad/nomad/structs" @@ -54,6 +55,8 @@ type HTTPServer struct { listenerCh chan struct{} logger log.Logger Addr string + + wsUpgrader *websocket.Upgrader } // NewHTTPServer starts new HTTP server over the agent @@ -85,6 +88,11 @@ func NewHTTPServer(agent *Agent, config *Config) (*HTTPServer, error) { // Create the mux mux := http.NewServeMux() + wsUpgrader := &websocket.Upgrader{ + ReadBufferSize: 2048, + WriteBufferSize: 2048, + } + // Create the server srv := &HTTPServer{ agent: agent, @@ -93,6 +101,7 @@ func NewHTTPServer(agent *Agent, config *Config) (*HTTPServer, error) { listenerCh: make(chan struct{}), logger: agent.httpLogger, Addr: ln.Addr().String(), + wsUpgrader: wsUpgrader, } srv.registerHandlers(config.EnableDebug)