mirror of
https://github.com/kemko/nomad.git
synced 2026-01-01 16:05:42 +03:00
this is the CE side of an Enterprise-only feature.
a job trying to use this in CE will fail to validate.
to enable daily-scheduled execution entirely client-side,
a job may now contain:
task "name" {
schedule {
cron {
start = "0 12 * * * *" # may not include "," or "/"
end = "0 16" # partial cron, with only {minute} {hour}
timezone = "EST" # anything in your tzdb
}
}
...
and everything about the allocation will be placed as usual,
but if outside the specified schedule, the taskrunner will block
on the client, waiting on the schedule start, before proceeding
with the task driver execution, etc.
this includes a taskrunner hook, which watches for the end of
the schedule, at which point it will kill the task.
then, if the restart policy allows it, a new task will start and again block
waiting for start, and so on.
this also includes all the plumbing required to pipe API calls
through from command->api->agent->server->client, so that
tasks can be force-run, force-paused, or resume the schedule
on demand.
838 lines
24 KiB
Go
838 lines
24 KiB
Go
// Copyright (c) HashiCorp, Inc.
|
||
// SPDX-License-Identifier: BUSL-1.1
|
||
|
||
package agent
|
||
|
||
import (
|
||
"context"
|
||
"encoding/json"
|
||
"errors"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"slices"
|
||
"strconv"
|
||
"strings"
|
||
|
||
"github.com/golang/snappy"
|
||
"github.com/gorilla/websocket"
|
||
"github.com/hashicorp/go-msgpack/v2/codec"
|
||
cstructs "github.com/hashicorp/nomad/client/structs"
|
||
"github.com/hashicorp/nomad/nomad/structs"
|
||
"github.com/hashicorp/nomad/plugins/drivers"
|
||
)
|
||
|
||
// Error messages shared by the allocation HTTP handlers in this file.
const (
	// allocNotFoundErr is the message used when an allocation cannot be
	// located.
	allocNotFoundErr = "allocation not found"

	// resourceNotFoundErr is the message used when a request path does not
	// map onto a known allocation sub-resource.
	resourceNotFoundErr = "resource not found"
)
|
||
|
||
func (s *HTTPServer) AllocsRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
if req.Method != http.MethodGet {
|
||
return nil, CodedError(405, ErrInvalidMethod)
|
||
}
|
||
|
||
args := structs.AllocListRequest{}
|
||
if s.parse(resp, req, &args.Region, &args.QueryOptions) {
|
||
return nil, nil
|
||
}
|
||
|
||
// Parse resources and task_states field selection
|
||
resources, err := parseBool(req, "resources")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
taskStates, err := parseBool(req, "task_states")
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if resources != nil || taskStates != nil {
|
||
args.Fields = structs.NewAllocStubFields()
|
||
if resources != nil {
|
||
args.Fields.Resources = *resources
|
||
}
|
||
if taskStates != nil {
|
||
args.Fields.TaskStates = *taskStates
|
||
}
|
||
}
|
||
|
||
var out structs.AllocListResponse
|
||
if err := s.agent.RPC("Alloc.List", &args, &out); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
setMeta(resp, &out.QueryMeta)
|
||
if out.Allocations == nil {
|
||
out.Allocations = make([]*structs.AllocListStub, 0)
|
||
}
|
||
for _, alloc := range out.Allocations {
|
||
alloc.SetEventDisplayMessages()
|
||
}
|
||
return out.Allocations, nil
|
||
}
|
||
|
||
func (s *HTTPServer) AllocSpecificRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
reqSuffix := strings.TrimPrefix(req.URL.Path, "/v1/allocation/")
|
||
|
||
// tokenize the suffix of the path to get the alloc id and find the action
|
||
// invoked on the alloc id
|
||
tokens := strings.Split(reqSuffix, "/")
|
||
if len(tokens) > 2 || len(tokens) < 1 {
|
||
return nil, CodedError(404, resourceNotFoundErr)
|
||
}
|
||
allocID := tokens[0]
|
||
|
||
if len(tokens) == 1 {
|
||
return s.allocGet(allocID, resp, req)
|
||
}
|
||
|
||
switch tokens[1] {
|
||
case "checks":
|
||
return s.allocChecks(allocID, resp, req)
|
||
case "stop":
|
||
return s.allocStop(allocID, resp, req)
|
||
case "services":
|
||
return s.allocServiceRegistrations(resp, req, allocID)
|
||
}
|
||
|
||
return nil, CodedError(404, resourceNotFoundErr)
|
||
}
|
||
|
||
func (s *HTTPServer) allocGet(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
if req.Method != http.MethodGet {
|
||
return nil, CodedError(405, ErrInvalidMethod)
|
||
}
|
||
|
||
args := structs.AllocSpecificRequest{
|
||
AllocID: allocID,
|
||
}
|
||
if s.parse(resp, req, &args.Region, &args.QueryOptions) {
|
||
return nil, nil
|
||
}
|
||
|
||
var out structs.SingleAllocResponse
|
||
if err := s.agent.RPC("Alloc.GetAlloc", &args, &out); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
setMeta(resp, &out.QueryMeta)
|
||
if out.Alloc == nil {
|
||
return nil, CodedError(404, "alloc not found")
|
||
}
|
||
|
||
// Decode the payload if there is any
|
||
|
||
alloc := out.Alloc
|
||
if alloc.Job != nil && len(alloc.Job.Payload) != 0 {
|
||
decoded, err := snappy.Decode(nil, alloc.Job.Payload)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
alloc = alloc.Copy()
|
||
alloc.Job.Payload = decoded
|
||
}
|
||
alloc.SetEventDisplayMessages()
|
||
|
||
// Handle 0.12 ports upgrade path
|
||
alloc = alloc.Copy()
|
||
alloc.AllocatedResources.Canonicalize()
|
||
|
||
return alloc, nil
|
||
}
|
||
|
||
func (s *HTTPServer) allocStop(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
if !(req.Method == "POST" || req.Method == "PUT") {
|
||
return nil, CodedError(405, ErrInvalidMethod)
|
||
}
|
||
|
||
noShutdownDelay := false
|
||
if noShutdownDelayQS := req.URL.Query().Get("no_shutdown_delay"); noShutdownDelayQS != "" {
|
||
var err error
|
||
noShutdownDelay, err = strconv.ParseBool(noShutdownDelayQS)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("no_shutdown_delay value is not a boolean: %v", err)
|
||
}
|
||
}
|
||
|
||
sr := &structs.AllocStopRequest{
|
||
AllocID: allocID,
|
||
NoShutdownDelay: noShutdownDelay,
|
||
}
|
||
s.parseWriteRequest(req, &sr.WriteRequest)
|
||
|
||
var out structs.AllocStopResponse
|
||
rpcErr := s.agent.RPC("Alloc.Stop", &sr, &out)
|
||
|
||
if rpcErr != nil {
|
||
if structs.IsErrUnknownAllocation(rpcErr) {
|
||
rpcErr = CodedError(404, allocNotFoundErr)
|
||
}
|
||
return nil, rpcErr
|
||
}
|
||
|
||
setIndex(resp, out.Index)
|
||
return &out, nil
|
||
}
|
||
|
||
// allocServiceRegistrations returns a list of all service registrations
|
||
// assigned to the job identifier. It is callable via the
|
||
// /v1/allocation/:alloc_id/services HTTP API and uses the
|
||
// structs.AllocServiceRegistrationsRPCMethod RPC method.
|
||
func (s *HTTPServer) allocServiceRegistrations(
|
||
resp http.ResponseWriter, req *http.Request, allocID string) (interface{}, error) {
|
||
|
||
// The endpoint only supports GET requests.
|
||
if req.Method != http.MethodGet {
|
||
return nil, CodedError(http.StatusMethodNotAllowed, ErrInvalidMethod)
|
||
}
|
||
|
||
// Set up the request args and parse this to ensure the query options are
|
||
// set.
|
||
args := structs.AllocServiceRegistrationsRequest{AllocID: allocID}
|
||
if s.parse(resp, req, &args.Region, &args.QueryOptions) {
|
||
return nil, nil
|
||
}
|
||
|
||
// Perform the RPC request.
|
||
var reply structs.AllocServiceRegistrationsResponse
|
||
if err := s.agent.RPC(structs.AllocServiceRegistrationsRPCMethod, &args, &reply); err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
setMeta(resp, &reply.QueryMeta)
|
||
|
||
if reply.Services == nil {
|
||
return nil, CodedError(http.StatusNotFound, allocNotFoundErr)
|
||
}
|
||
return reply.Services, nil
|
||
}
|
||
|
||
func (s *HTTPServer) ClientAllocRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
reqSuffix := strings.TrimPrefix(req.URL.Path, "/v1/client/allocation/")
|
||
|
||
// tokenize the suffix of the path to get the alloc id and find the action
|
||
// invoked on the alloc id
|
||
tokens := strings.Split(reqSuffix, "/")
|
||
if len(tokens) != 2 {
|
||
return nil, CodedError(404, resourceNotFoundErr)
|
||
}
|
||
allocID := tokens[0]
|
||
switch tokens[1] {
|
||
case "checks":
|
||
return s.allocChecks(allocID, resp, req)
|
||
case "stats":
|
||
return s.allocStats(allocID, resp, req)
|
||
case "exec":
|
||
return s.allocExec(allocID, resp, req)
|
||
case "snapshot":
|
||
if s.agent.Client() == nil {
|
||
return nil, clientNotRunning
|
||
}
|
||
return s.allocSnapshot(allocID, resp, req)
|
||
case "restart":
|
||
return s.allocRestart(allocID, resp, req)
|
||
case "gc":
|
||
return s.allocGC(allocID, resp, req)
|
||
case "signal":
|
||
return s.allocSignal(allocID, resp, req)
|
||
case "pause":
|
||
return s.allocPause(allocID, resp, req)
|
||
}
|
||
|
||
return nil, CodedError(404, resourceNotFoundErr)
|
||
}
|
||
|
||
func (s *HTTPServer) ClientGCRequest(resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
|
||
// Build the request and get the requested Node ID
|
||
args := structs.NodeSpecificRequest{}
|
||
s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
|
||
parseNode(req, &args.NodeID)
|
||
|
||
// Determine the handler to use
|
||
useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForNode(args.NodeID)
|
||
|
||
// Make the RPC
|
||
var reply structs.GenericResponse
|
||
var rpcErr error
|
||
if useLocalClient {
|
||
rpcErr = s.agent.Client().ClientRPC("Allocations.GarbageCollectAll", &args, &reply)
|
||
} else if useClientRPC {
|
||
rpcErr = s.agent.Client().RPC("ClientAllocations.GarbageCollectAll", &args, &reply)
|
||
} else if useServerRPC {
|
||
rpcErr = s.agent.Server().RPC("ClientAllocations.GarbageCollectAll", &args, &reply)
|
||
} else {
|
||
rpcErr = CodedError(400, "No local Node and node_id not provided")
|
||
}
|
||
|
||
if rpcErr != nil {
|
||
if structs.IsErrNoNodeConn(rpcErr) {
|
||
rpcErr = CodedError(404, rpcErr.Error())
|
||
}
|
||
}
|
||
|
||
return nil, rpcErr
|
||
}
|
||
|
||
func (s *HTTPServer) allocRestart(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
// Build the request and parse the ACL token
|
||
args := structs.AllocRestartRequest{
|
||
AllocID: allocID,
|
||
TaskName: "",
|
||
}
|
||
s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
|
||
|
||
// Explicitly parse the body separately to disallow overriding AllocID in req Body.
|
||
var reqBody struct {
|
||
TaskName string
|
||
AllTasks bool
|
||
}
|
||
err := json.NewDecoder(req.Body).Decode(&reqBody)
|
||
if err != nil && err != io.EOF {
|
||
return nil, err
|
||
}
|
||
if reqBody.TaskName != "" {
|
||
args.TaskName = reqBody.TaskName
|
||
}
|
||
if reqBody.AllTasks {
|
||
args.AllTasks = reqBody.AllTasks
|
||
}
|
||
|
||
// Determine the handler to use
|
||
useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForAlloc(allocID)
|
||
|
||
// Make the RPC
|
||
var reply structs.GenericResponse
|
||
var rpcErr error
|
||
if useLocalClient {
|
||
rpcErr = s.agent.Client().ClientRPC("Allocations.Restart", &args, &reply)
|
||
} else if useClientRPC {
|
||
rpcErr = s.agent.Client().RPC("ClientAllocations.Restart", &args, &reply)
|
||
} else if useServerRPC {
|
||
rpcErr = s.agent.Server().RPC("ClientAllocations.Restart", &args, &reply)
|
||
} else {
|
||
rpcErr = CodedError(400, "No local Node and node_id not provided")
|
||
}
|
||
|
||
if rpcErr != nil {
|
||
if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) {
|
||
rpcErr = CodedError(404, rpcErr.Error())
|
||
}
|
||
}
|
||
|
||
return reply, rpcErr
|
||
}
|
||
|
||
func (s *HTTPServer) allocGC(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
// Build the request and parse the ACL token
|
||
args := structs.AllocSpecificRequest{
|
||
AllocID: allocID,
|
||
}
|
||
s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
|
||
|
||
// Determine the handler to use
|
||
useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForAlloc(allocID)
|
||
|
||
// Make the RPC
|
||
var reply structs.GenericResponse
|
||
var rpcErr error
|
||
if useLocalClient {
|
||
rpcErr = s.agent.Client().ClientRPC("Allocations.GarbageCollect", &args, &reply)
|
||
} else if useClientRPC {
|
||
rpcErr = s.agent.Client().RPC("ClientAllocations.GarbageCollect", &args, &reply)
|
||
} else if useServerRPC {
|
||
rpcErr = s.agent.Server().RPC("ClientAllocations.GarbageCollect", &args, &reply)
|
||
} else {
|
||
rpcErr = CodedError(400, "No local Node and node_id not provided")
|
||
}
|
||
|
||
if rpcErr != nil {
|
||
if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) {
|
||
rpcErr = CodedError(404, rpcErr.Error())
|
||
}
|
||
}
|
||
|
||
return nil, rpcErr
|
||
}
|
||
|
||
func (s *HTTPServer) allocSignal(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
if !(req.Method == "POST" || req.Method == "PUT") {
|
||
return nil, CodedError(405, ErrInvalidMethod)
|
||
}
|
||
|
||
// Build the request and parse the ACL token
|
||
args := structs.AllocSignalRequest{}
|
||
err := decodeBody(req, &args)
|
||
if err != nil {
|
||
return nil, CodedError(400, fmt.Sprintf("Failed to decode body: %v", err))
|
||
}
|
||
s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
|
||
args.AllocID = allocID
|
||
|
||
// Determine the handler to use
|
||
useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForAlloc(allocID)
|
||
|
||
// Make the RPC
|
||
var reply structs.GenericResponse
|
||
var rpcErr error
|
||
if useLocalClient {
|
||
rpcErr = s.agent.Client().ClientRPC("Allocations.Signal", &args, &reply)
|
||
} else if useClientRPC {
|
||
rpcErr = s.agent.Client().RPC("ClientAllocations.Signal", &args, &reply)
|
||
} else if useServerRPC {
|
||
rpcErr = s.agent.Server().RPC("ClientAllocations.Signal", &args, &reply)
|
||
} else {
|
||
rpcErr = CodedError(400, "No local Node and node_id not provided")
|
||
}
|
||
|
||
if rpcErr != nil {
|
||
if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) {
|
||
rpcErr = CodedError(404, rpcErr.Error())
|
||
}
|
||
}
|
||
|
||
return reply, rpcErr
|
||
}
|
||
|
||
func (s *HTTPServer) allocPause(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
switch req.Method {
|
||
case http.MethodPost, http.MethodPut:
|
||
return s.allocPauseSet(allocID, resp, req)
|
||
case http.MethodGet:
|
||
return s.allocPauseGet(allocID, resp, req)
|
||
default:
|
||
return nil, CodedError(http.StatusMethodNotAllowed, ErrInvalidMethod)
|
||
}
|
||
}
|
||
|
||
func (s *HTTPServer) allocPauseGet(allocID string, resp http.ResponseWriter, req *http.Request) (any, error) {
|
||
// Build the request and parse the ACL token
|
||
task := req.URL.Query().Get("task")
|
||
args := structs.AllocGetPauseStateRequest{
|
||
AllocID: allocID,
|
||
Task: task,
|
||
}
|
||
s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
|
||
|
||
// Determine the handler to use
|
||
useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForAlloc(allocID)
|
||
|
||
// Make the RPC
|
||
var reply structs.AllocGetPauseStateResponse
|
||
var rpcErr error
|
||
if useLocalClient {
|
||
rpcErr = s.agent.Client().ClientRPC("Allocations.GetPauseState", &args, &reply)
|
||
} else if useClientRPC {
|
||
rpcErr = s.agent.Client().RPC("ClientAllocations.GetPauseState", &args, &reply)
|
||
} else if useServerRPC {
|
||
rpcErr = s.agent.Server().RPC("ClientAllocations.GetPauseState", &args, &reply)
|
||
} else {
|
||
rpcErr = CodedError(400, "No local Node and node_id not provided")
|
||
}
|
||
|
||
if rpcErr != nil {
|
||
if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) {
|
||
rpcErr = CodedError(404, rpcErr.Error())
|
||
}
|
||
}
|
||
|
||
return reply, rpcErr
|
||
}
|
||
|
||
func (s *HTTPServer) allocPauseSet(allocID string, resp http.ResponseWriter, req *http.Request) (any, error) {
|
||
// Build the request and parse the ACL token
|
||
args := structs.AllocPauseRequest{
|
||
AllocID: allocID,
|
||
}
|
||
s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
|
||
|
||
// Explicitly parse the body separately to disallow overriding the allocID
|
||
var reqBody struct {
|
||
Task string
|
||
ScheduleState string
|
||
}
|
||
err := json.NewDecoder(req.Body).Decode(&reqBody)
|
||
if err != nil && err != io.EOF {
|
||
return nil, err
|
||
}
|
||
args.Task = reqBody.Task
|
||
|
||
switch reqBody.ScheduleState {
|
||
case "pause":
|
||
args.ScheduleState = structs.TaskScheduleStateForcePause
|
||
case "run":
|
||
args.ScheduleState = structs.TaskScheduleStateForceRun
|
||
case "scheduled":
|
||
args.ScheduleState = structs.TaskScheduleStateSchedResume
|
||
default:
|
||
return nil, CodedError(400, "Not a valid task schedule state")
|
||
}
|
||
|
||
// Determine the handler to use
|
||
useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForAlloc(allocID)
|
||
|
||
// Make the RPC
|
||
var reply structs.GenericResponse
|
||
var rpcErr error
|
||
if useLocalClient {
|
||
rpcErr = s.agent.Client().ClientRPC("Allocations.SetPauseState", &args, &reply)
|
||
} else if useClientRPC {
|
||
rpcErr = s.agent.Client().RPC("ClientAllocations.SetPauseState", &args, &reply)
|
||
} else if useServerRPC {
|
||
rpcErr = s.agent.Server().RPC("ClientAllocations.SetPauseState", &args, &reply)
|
||
} else {
|
||
rpcErr = CodedError(400, "No local Node and node_id not provided")
|
||
}
|
||
|
||
if rpcErr != nil {
|
||
if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) {
|
||
rpcErr = CodedError(404, rpcErr.Error())
|
||
}
|
||
}
|
||
|
||
return reply, rpcErr
|
||
}
|
||
|
||
func (s *HTTPServer) allocSnapshot(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
var secret string
|
||
s.parseToken(req, &secret)
|
||
if !s.agent.Client().ValidateMigrateToken(allocID, secret) {
|
||
return nil, structs.ErrPermissionDenied
|
||
}
|
||
|
||
allocFS, err := s.agent.Client().GetAllocFS(allocID)
|
||
if err != nil {
|
||
return nil, fmt.Errorf(allocNotFoundErr)
|
||
}
|
||
if err := allocFS.Snapshot(resp); err != nil {
|
||
return nil, fmt.Errorf("error making snapshot: %v", err)
|
||
}
|
||
return nil, nil
|
||
}
|
||
|
||
func (s *HTTPServer) allocStats(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
|
||
// Build the request and parse the ACL token
|
||
task := req.URL.Query().Get("task")
|
||
args := cstructs.AllocStatsRequest{
|
||
AllocID: allocID,
|
||
Task: task,
|
||
}
|
||
s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
|
||
|
||
// Determine the handler to use
|
||
useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForAlloc(allocID)
|
||
|
||
// Make the RPC
|
||
var reply cstructs.AllocStatsResponse
|
||
var rpcErr error
|
||
if useLocalClient {
|
||
rpcErr = s.agent.Client().ClientRPC("Allocations.Stats", &args, &reply)
|
||
} else if useClientRPC {
|
||
rpcErr = s.agent.Client().RPC("ClientAllocations.Stats", &args, &reply)
|
||
} else if useServerRPC {
|
||
rpcErr = s.agent.Server().RPC("ClientAllocations.Stats", &args, &reply)
|
||
} else {
|
||
rpcErr = CodedError(400, "No local Node and node_id not provided")
|
||
}
|
||
|
||
if rpcErr != nil {
|
||
if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) {
|
||
rpcErr = CodedError(404, rpcErr.Error())
|
||
}
|
||
}
|
||
|
||
return reply.Stats, rpcErr
|
||
}
|
||
|
||
func (s *HTTPServer) allocChecks(allocID string, resp http.ResponseWriter, req *http.Request) (any, error) {
|
||
// Build the request and parse the ACL token
|
||
args := cstructs.AllocChecksRequest{
|
||
AllocID: allocID,
|
||
}
|
||
s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
|
||
|
||
// Determine the handler to use
|
||
useLocalClient, useClientRPC, useServerRPC := s.rpcHandlerForAlloc(allocID)
|
||
|
||
// Make the RPC
|
||
var reply cstructs.AllocChecksResponse
|
||
var rpcErr error
|
||
switch {
|
||
case useLocalClient:
|
||
rpcErr = s.agent.Client().ClientRPC("Allocations.Checks", &args, &reply)
|
||
case useClientRPC:
|
||
rpcErr = s.agent.Client().RPC("ClientAllocations.Checks", &args, &reply)
|
||
case useServerRPC:
|
||
rpcErr = s.agent.Server().RPC("ClientAllocations.Checks", &args, &reply)
|
||
default:
|
||
rpcErr = CodedError(400, "No local Node and node_id not provided")
|
||
}
|
||
|
||
if rpcErr != nil {
|
||
if structs.IsErrNoNodeConn(rpcErr) || structs.IsErrUnknownAllocation(rpcErr) {
|
||
rpcErr = CodedError(404, rpcErr.Error())
|
||
}
|
||
}
|
||
|
||
return reply.Results, rpcErr
|
||
}
|
||
|
||
func (s *HTTPServer) allocExec(allocID string, resp http.ResponseWriter, req *http.Request) (interface{}, error) {
|
||
// Build the request and parse the ACL token
|
||
task := req.URL.Query().Get("task")
|
||
cmdJsonStr := req.URL.Query().Get("command")
|
||
var command []string
|
||
err := json.Unmarshal([]byte(cmdJsonStr), &command)
|
||
if err != nil {
|
||
// this shouldn't happen, []string is always be serializable to json
|
||
return nil, fmt.Errorf("failed to marshal command into json: %v", err)
|
||
}
|
||
|
||
ttyB := false
|
||
if tty := req.URL.Query().Get("tty"); tty != "" {
|
||
ttyB, err = strconv.ParseBool(tty)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("tty value is not a boolean: %v", err)
|
||
}
|
||
}
|
||
|
||
args := cstructs.AllocExecRequest{
|
||
AllocID: allocID,
|
||
Task: task,
|
||
Cmd: command,
|
||
Tty: ttyB,
|
||
}
|
||
s.parse(resp, req, &args.QueryOptions.Region, &args.QueryOptions)
|
||
|
||
conn, err := s.wsUpgrader.Upgrade(resp, req, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to upgrade connection: %v", err)
|
||
}
|
||
|
||
if err := readWsHandshake(conn.ReadJSON, req, &args.QueryOptions); err != nil {
|
||
conn.WriteMessage(websocket.CloseMessage,
|
||
websocket.FormatCloseMessage(toWsCode(400), err.Error()))
|
||
return nil, err
|
||
}
|
||
|
||
return s.execStream(conn, &args)
|
||
}
|
||
|
||
// readWsHandshake reads the websocket handshake message and sets
|
||
// query authentication token, if request requires a handshake
|
||
func readWsHandshake(readFn func(interface{}) error, req *http.Request, q *structs.QueryOptions) error {
|
||
|
||
// Avoid handshake if request doesn't require one
|
||
if hv := req.URL.Query().Get("ws_handshake"); hv == "" {
|
||
return nil
|
||
} else if h, err := strconv.ParseBool(hv); err != nil {
|
||
return fmt.Errorf("ws_handshake value is not a boolean: %v", err)
|
||
} else if !h {
|
||
return nil
|
||
}
|
||
|
||
var h wsHandshakeMessage
|
||
err := readFn(&h)
|
||
if err != nil {
|
||
return err
|
||
}
|
||
|
||
supportedWSHandshakeVersion := 1
|
||
if h.Version != supportedWSHandshakeVersion {
|
||
return fmt.Errorf("unexpected handshake value: %v", h.Version)
|
||
}
|
||
|
||
q.AuthToken = h.AuthToken
|
||
return nil
|
||
}
|
||
|
||
// wsHandshakeMessage is the JSON message read from the websocket when a
// handshake is requested.
type wsHandshakeMessage struct {
	// Version is the handshake protocol version sent by the peer.
	Version int `json:"version"`

	// AuthToken is the token copied into the request's query options.
	AuthToken string `json:"auth_token"`
}
|
||
|
||
// execStream finds the appropriate RPC handler and then runs the bidirectional
|
||
// websocket-to-RPC stream
|
||
func (s *HTTPServer) execStream(ws *websocket.Conn, args *cstructs.AllocExecRequest) (any, error) {
|
||
allocID := args.AllocID
|
||
method := "Allocations.Exec"
|
||
|
||
// Get the correct handler
|
||
localClient, remoteClient, localServer := s.rpcHandlerForAlloc(allocID)
|
||
var handler structs.StreamingRpcHandler
|
||
var handlerErr error
|
||
if localClient {
|
||
handler, handlerErr = s.agent.Client().StreamingRpcHandler(method)
|
||
} else if remoteClient {
|
||
handler, handlerErr = s.agent.Client().RemoteStreamingRpcHandler(method)
|
||
} else if localServer {
|
||
handler, handlerErr = s.agent.Server().StreamingRpcHandler(method)
|
||
}
|
||
|
||
if handlerErr != nil {
|
||
return nil, CodedError(500, handlerErr.Error())
|
||
}
|
||
|
||
return s.execStreamImpl(ws, args, handler)
|
||
}
|
||
|
||
// execStreamImpl is called by execStream with the appropriate RPC handler and
|
||
// then runs the bidirectional websocket-to-RPC stream.
|
||
func (s *HTTPServer) execStreamImpl(ws *websocket.Conn, args *cstructs.AllocExecRequest, handler structs.StreamingRpcHandler) (any, error) {
|
||
|
||
// Create a pipe connecting the (possibly remote) handler to the http response
|
||
httpPipe, handlerPipe := net.Pipe()
|
||
decoder := codec.NewDecoder(httpPipe, structs.MsgpackHandle)
|
||
encoder := codec.NewEncoder(httpPipe, structs.MsgpackHandle)
|
||
|
||
// Create a goroutine that closes the pipe if the connection closes.
|
||
ctx, cancel := context.WithCancel(context.Background())
|
||
go func() {
|
||
<-ctx.Done()
|
||
httpPipe.Close()
|
||
|
||
// don't close ws - wait to drain messages
|
||
}()
|
||
|
||
// Create a channel for the final result
|
||
resultCh := make(chan HTTPCodedError, 1)
|
||
|
||
// stream response back to the websocket: this should be the only goroutine
|
||
// that writes to this websocket connection
|
||
go func() {
|
||
defer cancel()
|
||
errCh := make(chan HTTPCodedError, 2)
|
||
|
||
// Send the request
|
||
if err := encoder.Encode(args); err != nil {
|
||
resultCh <- s.execStreamHandleError(ws, CodedError(500, err.Error()))
|
||
return
|
||
}
|
||
|
||
// only start this after we've tried to send the initial args
|
||
go forwardExecInput(ctx, encoder, ws, errCh)
|
||
|
||
for {
|
||
select {
|
||
case codedErr := <-errCh:
|
||
resultCh <- s.execStreamHandleError(ws, codedErr)
|
||
return
|
||
default:
|
||
}
|
||
|
||
var res cstructs.StreamErrWrapper
|
||
err := decoder.Decode(&res)
|
||
if err != nil {
|
||
errCh <- CodedError(500, err.Error())
|
||
continue
|
||
}
|
||
decoder.Reset(httpPipe)
|
||
|
||
if err := res.Error; err != nil {
|
||
code := 500
|
||
if err.Code != nil {
|
||
code = int(*err.Code)
|
||
}
|
||
errCh <- CodedError(code, err.Error())
|
||
continue
|
||
}
|
||
if err := ws.WriteMessage(websocket.TextMessage, res.Payload); err != nil {
|
||
errCh <- CodedError(500, err.Error())
|
||
continue
|
||
}
|
||
}
|
||
}()
|
||
|
||
// start streaming request to streaming RPC - returns when streaming
|
||
// completes or errors
|
||
handler(handlerPipe)
|
||
|
||
// stop streaming background goroutines for streaming - but not websocket
|
||
// activity
|
||
cancel()
|
||
|
||
// retrieve any error and/or wait until goroutine stop and close errCh
|
||
// connection before closing websocket connection
|
||
result := <-resultCh
|
||
ws.Close()
|
||
return nil, result
|
||
}
|
||
|
||
// execStreamHandleError writes a CloseMessage to the websocket if we get an
|
||
// error that isn't a "close error" caused by the RPC pipe finishing up. Note
|
||
// that this should *only* ever be called in the same goroutine as we're
|
||
// streaming the responses
|
||
func (s *HTTPServer) execStreamHandleError(ws *websocket.Conn, codedErr HTTPCodedError) HTTPCodedError {
|
||
// we won't return an error on ws close, but at least make it available in
|
||
// the logs so we can trace spurious disconnects
|
||
s.logger.Trace("alloc exec channel closed with error", "error", codedErr)
|
||
|
||
if isClosedError(codedErr) {
|
||
return nil // we're intentionally throwing this error away
|
||
} else if codedErr != nil {
|
||
ws.WriteMessage(websocket.CloseMessage,
|
||
websocket.FormatCloseMessage(toWsCode(codedErr.Code()), codedErr.Error()))
|
||
return codedErr
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func toWsCode(httpCode int) int {
|
||
switch httpCode {
|
||
case 500:
|
||
return websocket.CloseInternalServerErr
|
||
default:
|
||
// placeholder error code
|
||
return websocket.ClosePolicyViolation
|
||
}
|
||
}
|
||
|
||
// isClosedError checks if the websocket "error" is one of the benign "close"
// conditions: io.EOF, io.ErrClosedPipe, or a message containing one of the
// known close fragments. A nil error is not a close error.
func isClosedError(err error) bool {
	if err == nil {
		return false
	}
	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrClosedPipe) {
		return true
	}

	// Fall back to matching on the message for errors that do not wrap the
	// sentinels above.
	msg := err.Error()
	benign := []string{
		"closed", // msgpack decode error [pos 0]: io: read/write on closed pipe"
		"EOF",
		"close 1000", // CLOSE_NORMAL
		"close 1001", // CLOSE_GOING_AWAY
		"close 1005", // CLOSED_NO_STATUS
	}
	return slices.ContainsFunc(benign, func(fragment string) bool {
		return strings.Contains(msg, fragment)
	})
}
|
||
|
||
// forwardExecInput forwards exec input (e.g. stdin) from websocket connection
|
||
// to the streaming RPC connection to client
|
||
func forwardExecInput(ctx context.Context, encoder *codec.Encoder, ws *websocket.Conn, errCh chan<- HTTPCodedError) {
|
||
for {
|
||
select {
|
||
case <-ctx.Done():
|
||
return
|
||
default:
|
||
}
|
||
|
||
sf := &drivers.ExecTaskStreamingRequestMsg{}
|
||
err := ws.ReadJSON(sf)
|
||
if err == io.EOF {
|
||
return
|
||
}
|
||
|
||
if err != nil {
|
||
errCh <- CodedError(500, err.Error())
|
||
return
|
||
}
|
||
|
||
err = encoder.Encode(sf)
|
||
if err != nil {
|
||
errCh <- CodedError(500, err.Error())
|
||
}
|
||
}
|
||
}
|