mirror of
https://github.com/kemko/xc.git
synced 2026-01-01 15:55:43 +03:00
project move
This commit is contained in:
79
remote/copy.go
Normal file
79
remote/copy.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/kr/pty"
|
||||
"github.com/npat-efault/poller"
|
||||
"github.com/viert/xc/log"
|
||||
)
|
||||
|
||||
func (w *Worker) copy(task *Task) int {
|
||||
var err error
|
||||
var n int
|
||||
|
||||
cmd := createSCPCmd(task.Hostname, task.LocalFilename, task.RemoteFilename, task.RecursiveCopy)
|
||||
cmd.Env = append(os.Environ(), environment...)
|
||||
|
||||
ptmx, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
return ErrTerminalError
|
||||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
fd, err := poller.NewFD(int(ptmx.Fd()))
|
||||
if err != nil {
|
||||
return ErrTerminalError
|
||||
}
|
||||
defer fd.Close()
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
taskForceStopped := false
|
||||
for {
|
||||
if w.forceStopped() {
|
||||
taskForceStopped = true
|
||||
break
|
||||
}
|
||||
fd.SetReadDeadline(time.Now().Add(pollDeadline))
|
||||
n, err = fd.Read(buf)
|
||||
if err != nil {
|
||||
if err != poller.ErrTimeout {
|
||||
// EOF, done
|
||||
break
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
w.data <- &Message{buf[:n], MTDebug, task.Hostname, 0}
|
||||
buf = make([]byte, bufferSize)
|
||||
}
|
||||
|
||||
exitCode := 0
|
||||
if taskForceStopped {
|
||||
cmd.Process.Kill()
|
||||
exitCode = ErrForceStop
|
||||
log.Debugf("WRK[%d]: Task on %s was force stopped", w.id, task.Hostname)
|
||||
}
|
||||
|
||||
err = cmd.Wait()
|
||||
|
||||
if !taskForceStopped {
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
ws := exitErr.Sys().(syscall.WaitStatus)
|
||||
exitCode = ws.ExitStatus()
|
||||
} else {
|
||||
// MacOS hack
|
||||
exitCode = ErrMacOsExit
|
||||
}
|
||||
}
|
||||
log.Debugf("WRK[%d]: Task on %s exit code is %d", w.id, task.Hostname, exitCode)
|
||||
}
|
||||
return exitCode
|
||||
}
|
||||
78
remote/distribute.go
Normal file
78
remote/distribute.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/viert/xc/log"
|
||||
pb "gopkg.in/cheggaaa/pb.v1"
|
||||
)
|
||||
|
||||
// Distribute distributes a given local file or directory to a number of hosts
|
||||
func Distribute(hosts []string, localFilename string, remoteFilename string, recursive bool) *ExecResult {
|
||||
var (
|
||||
wg sync.WaitGroup
|
||||
bar *pb.ProgressBar
|
||||
sigs chan os.Signal
|
||||
r *ExecResult
|
||||
t *Task
|
||||
running int
|
||||
)
|
||||
|
||||
r = newExecResult()
|
||||
running = len(hosts)
|
||||
if currentProgressBar {
|
||||
bar = pb.StartNew(running)
|
||||
}
|
||||
|
||||
sigs = make(chan os.Signal, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT)
|
||||
defer signal.Reset()
|
||||
|
||||
go func() {
|
||||
for _, host := range hosts {
|
||||
t = &Task{
|
||||
Hostname: host,
|
||||
LocalFilename: localFilename,
|
||||
RemoteFilename: remoteFilename,
|
||||
RecursiveCopy: recursive,
|
||||
Cmd: "",
|
||||
WG: &wg,
|
||||
}
|
||||
pool.AddTask(t)
|
||||
}
|
||||
wg.Wait()
|
||||
}()
|
||||
|
||||
for running > 0 {
|
||||
select {
|
||||
case d := <-pool.Data:
|
||||
switch d.Type {
|
||||
case MTDebug:
|
||||
if currentDebug {
|
||||
log.Debugf("DATASTREAM @ %s\n%v\n[%v]", d.Hostname, d.Data, string(d.Data))
|
||||
}
|
||||
case MTCopyFinished:
|
||||
running--
|
||||
if currentProgressBar {
|
||||
bar.Increment()
|
||||
}
|
||||
r.Codes[d.Hostname] = d.StatusCode
|
||||
if d.StatusCode == 0 {
|
||||
r.SuccessHosts = append(r.SuccessHosts, d.Hostname)
|
||||
} else {
|
||||
r.ErrorHosts = append(r.ErrorHosts, d.Hostname)
|
||||
}
|
||||
}
|
||||
case <-sigs:
|
||||
r.ForceStoppedHosts = pool.ForceStopAllTasks()
|
||||
}
|
||||
}
|
||||
|
||||
if currentProgressBar {
|
||||
bar.Finish()
|
||||
}
|
||||
return r
|
||||
}
|
||||
238
remote/executer.go
Normal file
238
remote/executer.go
Normal file
@@ -0,0 +1,238 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
"github.com/viert/xc/log"
|
||||
"github.com/viert/xc/term"
|
||||
pb "gopkg.in/cheggaaa/pb.v1"
|
||||
)
|
||||
|
||||
const (
|
||||
stdoutWriteRetry = 25
|
||||
)
|
||||
|
||||
// ExecResult is a struct with execution results
|
||||
type ExecResult struct {
|
||||
Codes map[string]int
|
||||
Outputs map[string][]string
|
||||
|
||||
SuccessHosts []string
|
||||
ErrorHosts []string
|
||||
ForceStoppedHosts int
|
||||
}
|
||||
|
||||
func newExecResult() *ExecResult {
|
||||
return &ExecResult{
|
||||
Codes: make(map[string]int),
|
||||
Outputs: make(map[string][]string),
|
||||
SuccessHosts: make([]string, 0),
|
||||
ErrorHosts: make([]string, 0),
|
||||
ForceStoppedHosts: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// Print prints ExecResults in a nice way
|
||||
func (r *ExecResult) Print() {
|
||||
msg := fmt.Sprintf(" Hosts processed: %d, success: %d, error: %d ",
|
||||
len(r.SuccessHosts)+len(r.ErrorHosts), len(r.SuccessHosts), len(r.ErrorHosts))
|
||||
h := term.HR(len(msg))
|
||||
fmt.Println(term.Green(h))
|
||||
fmt.Println(term.Green(msg))
|
||||
fmt.Println(term.Green(h))
|
||||
}
|
||||
|
||||
// PrintOutputMap prints collapsed-style output
|
||||
func (r *ExecResult) PrintOutputMap() {
|
||||
for output, hosts := range r.Outputs {
|
||||
msg := fmt.Sprintf(" %d host(s): %s ", len(hosts), strings.Join(hosts, ","))
|
||||
tableWidth := len(msg) + 2
|
||||
termWidth := term.GetTerminalWidth()
|
||||
if tableWidth > termWidth {
|
||||
tableWidth = termWidth
|
||||
}
|
||||
fmt.Println(term.Blue(term.HR(tableWidth)))
|
||||
fmt.Println(term.Blue(msg))
|
||||
fmt.Println(term.Blue(term.HR(tableWidth)))
|
||||
fmt.Println(output)
|
||||
}
|
||||
}
|
||||
|
||||
func enqueue(local string, remote string, hosts []string) {
|
||||
// This is in a goroutine because of decreasing the task channel size.
|
||||
// If there is a number of hosts greater than pool.dataSizeQueue (i.e. 1024)
|
||||
// this loop will actually block on reaching the limit until some tasks are
|
||||
// processed and some space in the queue is released.
|
||||
//
|
||||
// To avoid blocking on task generation this loop was moved into a goroutine
|
||||
var wg sync.WaitGroup
|
||||
for _, host := range hosts {
|
||||
// remoteFile should include hostname for the case we have
|
||||
// a number of aliases pointing to one server. With the same
|
||||
// remote filename the first task finished removes the file
|
||||
// while other tasks on the same server try to remove it afterwards and fail
|
||||
remoteFilename := fmt.Sprintf("%s.%s.sh", remote, host)
|
||||
task := &Task{
|
||||
Hostname: host,
|
||||
LocalFilename: local,
|
||||
RemoteFilename: remoteFilename,
|
||||
Cmd: remoteFilename,
|
||||
WG: &wg,
|
||||
}
|
||||
pool.AddTask(task)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
// RunParallel runs cmd on hosts in parallel mode
|
||||
func RunParallel(hosts []string, cmd string) *ExecResult {
|
||||
r := newExecResult()
|
||||
if len(hosts) == 0 {
|
||||
return r
|
||||
}
|
||||
|
||||
local, remote, err := prepareTempFiles(cmd)
|
||||
if err != nil {
|
||||
term.Errorf("Error creating temporary file: %s\n", err)
|
||||
return r
|
||||
}
|
||||
defer os.Remove(local)
|
||||
|
||||
running := len(hosts)
|
||||
copied := 0
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT)
|
||||
defer signal.Reset()
|
||||
|
||||
go enqueue(local, remote, hosts)
|
||||
|
||||
for running > 0 {
|
||||
select {
|
||||
case d := <-pool.Data:
|
||||
switch d.Type {
|
||||
case MTData:
|
||||
log.Debugf("MSG@%s[DATA](%d): %s", d.Hostname, d.StatusCode, string(d.Data))
|
||||
if !bytes.HasSuffix(d.Data, []byte{'\n'}) {
|
||||
d.Data = append(d.Data, '\n')
|
||||
}
|
||||
if currentPrependHostnames {
|
||||
fmt.Printf("%s: ", term.Blue(d.Hostname))
|
||||
}
|
||||
fmt.Print(string(d.Data))
|
||||
writeHostOutput(d.Hostname, d.Data)
|
||||
case MTDebug:
|
||||
if currentDebug {
|
||||
log.Debugf("DATASTREAM @ %s\n%v\n[%v]", d.Hostname, d.Data, string(d.Data))
|
||||
}
|
||||
case MTCopyFinished:
|
||||
log.Debugf("MSG@%s[COPYFIN](%d): %s", d.Hostname, d.StatusCode, string(d.Data))
|
||||
if d.StatusCode == 0 {
|
||||
copied++
|
||||
}
|
||||
case MTExecFinished:
|
||||
log.Debugf("MSG@%s[EXECFIN](%d): %s", d.Hostname, d.StatusCode, string(d.Data))
|
||||
r.Codes[d.Hostname] = d.StatusCode
|
||||
if d.StatusCode == 0 {
|
||||
r.SuccessHosts = append(r.SuccessHosts, d.Hostname)
|
||||
} else {
|
||||
r.ErrorHosts = append(r.ErrorHosts, d.Hostname)
|
||||
}
|
||||
running--
|
||||
}
|
||||
case <-sigs:
|
||||
fmt.Println()
|
||||
r.ForceStoppedHosts = pool.ForceStopAllTasks()
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// RunCollapse runs cmd on hosts in collapse mode
|
||||
func RunCollapse(hosts []string, cmd string) *ExecResult {
|
||||
var bar *pb.ProgressBar
|
||||
r := newExecResult()
|
||||
if len(hosts) == 0 {
|
||||
return r
|
||||
}
|
||||
|
||||
local, remote, err := prepareTempFiles(cmd)
|
||||
if err != nil {
|
||||
term.Errorf("Error creating temporary file: %s\n", err)
|
||||
return r
|
||||
}
|
||||
defer os.Remove(local)
|
||||
|
||||
running := len(hosts)
|
||||
copied := 0
|
||||
outputs := make(map[string]string)
|
||||
|
||||
if currentProgressBar {
|
||||
bar = pb.StartNew(running)
|
||||
}
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
signal.Notify(sigs, syscall.SIGINT)
|
||||
defer signal.Reset()
|
||||
|
||||
go enqueue(local, remote, hosts)
|
||||
|
||||
for running > 0 {
|
||||
select {
|
||||
case d := <-pool.Data:
|
||||
switch d.Type {
|
||||
case MTData:
|
||||
outputs[d.Hostname] += string(d.Data)
|
||||
logData := make([]byte, len(d.Data))
|
||||
copy(logData, d.Data)
|
||||
if !bytes.HasSuffix(d.Data, []byte{'\n'}) {
|
||||
logData = append(d.Data, '\n')
|
||||
}
|
||||
writeHostOutput(d.Hostname, logData)
|
||||
case MTDebug:
|
||||
if currentDebug {
|
||||
log.Debugf("DATASTREAM @ %s\n%v\n[%v]", d.Hostname, d.Data, string(d.Data))
|
||||
}
|
||||
case MTCopyFinished:
|
||||
if d.StatusCode == 0 {
|
||||
copied++
|
||||
}
|
||||
case MTExecFinished:
|
||||
if currentProgressBar {
|
||||
bar.Increment()
|
||||
}
|
||||
r.Codes[d.Hostname] = d.StatusCode
|
||||
if d.StatusCode == 0 {
|
||||
r.SuccessHosts = append(r.SuccessHosts, d.Hostname)
|
||||
} else {
|
||||
r.ErrorHosts = append(r.ErrorHosts, d.Hostname)
|
||||
}
|
||||
running--
|
||||
}
|
||||
case <-sigs:
|
||||
fmt.Println()
|
||||
r.ForceStoppedHosts = pool.ForceStopAllTasks()
|
||||
}
|
||||
}
|
||||
|
||||
if currentProgressBar {
|
||||
bar.Finish()
|
||||
}
|
||||
|
||||
for k, v := range outputs {
|
||||
_, found := r.Outputs[v]
|
||||
if !found {
|
||||
r.Outputs[v] = make([]string, 0)
|
||||
}
|
||||
r.Outputs[v] = append(r.Outputs[v], k)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
91
remote/pool.go
Normal file
91
remote/pool.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"github.com/viert/xc/log"
|
||||
)
|
||||
|
||||
const (
|
||||
dataQueueSize = 1024
|
||||
)
|
||||
|
||||
// Pool is a class representing a worker pool
|
||||
type Pool struct {
|
||||
workers []*Worker
|
||||
queue chan *Task
|
||||
Data chan *Message
|
||||
}
|
||||
|
||||
// NewPool creates a new worker pool of a given size
|
||||
func NewPool(size int) *Pool {
|
||||
|
||||
p := &Pool{
|
||||
workers: make([]*Worker, size),
|
||||
queue: make(chan *Task, dataQueueSize),
|
||||
Data: make(chan *Message, dataQueueSize),
|
||||
}
|
||||
|
||||
for i := 0; i < size; i++ {
|
||||
p.workers[i] = NewWorker(p.queue, p.Data)
|
||||
}
|
||||
log.Debugf("Remote execution pool created with %d workers", size)
|
||||
log.Debugf("Data Queue Size is %d", dataQueueSize)
|
||||
return p
|
||||
}
|
||||
|
||||
// ForceStopAllTasks removes all pending tasks and force stops those in progress
|
||||
func (p *Pool) ForceStopAllTasks() int {
|
||||
// Remove all pending tasks from the queue
|
||||
log.Debug("Force stopping all tasks")
|
||||
i := 0
|
||||
rmvLoop:
|
||||
for {
|
||||
select {
|
||||
case <-p.queue:
|
||||
i++
|
||||
continue
|
||||
default:
|
||||
break rmvLoop
|
||||
}
|
||||
}
|
||||
log.Debugf("%d queued (and not yet started) tasks removed from the queue", i)
|
||||
|
||||
stopped := 0
|
||||
for _, wrk := range p.workers {
|
||||
if wrk.ForceStop() {
|
||||
log.Debugf("Worker %d was running a task so force stopped", wrk.ID())
|
||||
stopped++
|
||||
}
|
||||
}
|
||||
return stopped
|
||||
}
|
||||
|
||||
// Close shuts down the pool itself and all its workers
|
||||
func (p *Pool) Close() {
|
||||
log.Debug("Closing remote execution pool")
|
||||
p.ForceStopAllTasks()
|
||||
close(p.queue) // this should make all the workers step out of range loop on queue chan and shut down
|
||||
log.Debug("Closing the task queue")
|
||||
close(p.Data)
|
||||
}
|
||||
|
||||
// AddTask adds a task to the pool queue
|
||||
func (p *Pool) AddTask(task *Task) {
|
||||
if task.WG != nil {
|
||||
task.WG.Add(1)
|
||||
}
|
||||
p.queue <- task
|
||||
}
|
||||
|
||||
// AddTaskHostlist creates multiple tasks to be run on a multiple hosts
|
||||
func (p *Pool) AddTaskHostlist(task *Task, hosts []string) {
|
||||
for _, host := range hosts {
|
||||
t := &Task{
|
||||
Hostname: host,
|
||||
LocalFilename: task.LocalFilename,
|
||||
RemoteFilename: task.RemoteFilename,
|
||||
Cmd: task.Cmd,
|
||||
WG: task.WG,
|
||||
}
|
||||
p.AddTask(t)
|
||||
}
|
||||
}
|
||||
136
remote/remote.go
Normal file
136
remote/remote.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
pool *Pool
|
||||
currentUser string
|
||||
currentPassword string
|
||||
currentRaise RaiseType
|
||||
currentProgressBar bool
|
||||
currentPrependHostnames bool
|
||||
currentRemoteTmpdir string
|
||||
currentDebug bool
|
||||
outputFile *os.File
|
||||
|
||||
noneInterpreter string
|
||||
suInterpreter string
|
||||
sudoInterpreter string
|
||||
)
|
||||
|
||||
// Initialize initializes new execution pool
|
||||
func Initialize(numThreads int, username string) {
|
||||
pool = NewPool(numThreads)
|
||||
SetUser(username)
|
||||
SetPassword("")
|
||||
SetRaise(RTNone)
|
||||
}
|
||||
|
||||
// SetInterpreter sets none-raise interpreter
|
||||
func SetInterpreter(interpreter string) {
|
||||
noneInterpreter = interpreter
|
||||
}
|
||||
|
||||
// SetSudoInterpreter sets sudo-raise interpreter
|
||||
func SetSudoInterpreter(interpreter string) {
|
||||
sudoInterpreter = interpreter
|
||||
}
|
||||
|
||||
// SetSuInterpreter sets su-raise interpreter
|
||||
func SetSuInterpreter(interpreter string) {
|
||||
suInterpreter = interpreter
|
||||
}
|
||||
|
||||
// SetUser sets executer username
|
||||
func SetUser(username string) {
|
||||
currentUser = username
|
||||
}
|
||||
|
||||
// SetRaise sets executer raise type
|
||||
func SetRaise(raise RaiseType) {
|
||||
currentRaise = raise
|
||||
}
|
||||
|
||||
// SetPassword sets executer password
|
||||
func SetPassword(password string) {
|
||||
currentPassword = password
|
||||
}
|
||||
|
||||
// SetProgressBar sets current progressbar mode
|
||||
func SetProgressBar(pbar bool) {
|
||||
currentProgressBar = pbar
|
||||
}
|
||||
|
||||
// SetRemoteTmpdir sets current remote temp directory
|
||||
func SetRemoteTmpdir(tmpDir string) {
|
||||
currentRemoteTmpdir = tmpDir
|
||||
}
|
||||
|
||||
// SetDebug sets current debug mode
|
||||
func SetDebug(debug bool) {
|
||||
currentDebug = debug
|
||||
}
|
||||
|
||||
// SetPrependHostnames sets current prepend_hostnames value for parallel mode
|
||||
func SetPrependHostnames(prependHostnames bool) {
|
||||
currentPrependHostnames = prependHostnames
|
||||
}
|
||||
|
||||
// SetConnectTimeout sets the ssh connect timeout in sshOptions
|
||||
func SetConnectTimeout(timeout int) {
|
||||
sshOptions["ConnectTimeout"] = fmt.Sprintf("%d", timeout)
|
||||
}
|
||||
|
||||
// SetOutputFile sets output file for every command.
|
||||
// if it's nil, no output will be written to files
|
||||
func SetOutputFile(f *os.File) {
|
||||
outputFile = f
|
||||
}
|
||||
|
||||
// SetNumThreads recreates the execution pool with the given number of threads
|
||||
func SetNumThreads(numThreads int) {
|
||||
if len(pool.workers) == numThreads {
|
||||
return
|
||||
}
|
||||
pool.Close()
|
||||
pool = NewPool(numThreads)
|
||||
}
|
||||
|
||||
func prepareTempFiles(cmd string) (string, string, error) {
|
||||
f, err := ioutil.TempFile("", "xc.")
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
remoteFilename := filepath.Join(currentRemoteTmpdir, filepath.Base(f.Name()))
|
||||
io.WriteString(f, "#!/bin/bash\n\n")
|
||||
io.WriteString(f, fmt.Sprintf("nohup bash -c \"sleep 1; rm -f $0\" >/dev/null 2>&1 </dev/null &\n")) // self-destroy
|
||||
io.WriteString(f, cmd+"\n") // run command
|
||||
f.Chmod(0755)
|
||||
|
||||
return f.Name(), remoteFilename, nil
|
||||
}
|
||||
|
||||
// WriteOutput writes output to a user-defined logfile
|
||||
// prepending with the current datetime
|
||||
func WriteOutput(message string) {
|
||||
if outputFile == nil {
|
||||
return
|
||||
}
|
||||
tm := time.Now().Format("2006-01-02 15:04:05")
|
||||
message = fmt.Sprintf("[%s] %s", tm, message)
|
||||
outputFile.Write([]byte(message))
|
||||
}
|
||||
|
||||
func writeHostOutput(host string, data []byte) {
|
||||
message := fmt.Sprintf("%s: %s", host, string(data))
|
||||
WriteOutput(message)
|
||||
}
|
||||
129
remote/runcmd.go
Normal file
129
remote/runcmd.go
Normal file
@@ -0,0 +1,129 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"os/exec"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/kr/pty"
|
||||
"github.com/npat-efault/poller"
|
||||
"github.com/viert/xc/log"
|
||||
)
|
||||
|
||||
func (w *Worker) runcmd(task *Task) int {
|
||||
var err error
|
||||
var n int
|
||||
var passwordSent bool
|
||||
|
||||
passwordSent = currentRaise == RTNone
|
||||
cmd := createSSHCmd(task.Hostname, task.Cmd)
|
||||
cmd.Env = append(os.Environ(), environment...)
|
||||
|
||||
ptmx, err := pty.Start(cmd)
|
||||
if err != nil {
|
||||
return ErrTerminalError
|
||||
}
|
||||
defer ptmx.Close()
|
||||
|
||||
fd, err := poller.NewFD(int(ptmx.Fd()))
|
||||
if err != nil {
|
||||
return ErrTerminalError
|
||||
}
|
||||
defer fd.Close()
|
||||
|
||||
buf := make([]byte, bufferSize)
|
||||
taskForceStopped := false
|
||||
shouldSkipEcho := false
|
||||
msgCount := 0
|
||||
|
||||
execLoop:
|
||||
for {
|
||||
if w.forceStopped() {
|
||||
taskForceStopped = true
|
||||
break
|
||||
}
|
||||
|
||||
fd.SetReadDeadline(time.Now().Add(pollDeadline))
|
||||
n, err = fd.Read(buf)
|
||||
if err != nil {
|
||||
if err != poller.ErrTimeout {
|
||||
// EOF, done
|
||||
break
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if n == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
w.data <- &Message{buf, MTDebug, task.Hostname, -1}
|
||||
msgCount++
|
||||
|
||||
chunks := bytes.SplitAfter(buf[:n], []byte{'\n'})
|
||||
for _, chunk := range chunks {
|
||||
// Trying to find Password prompt in first 5 chunks of data from server
|
||||
if msgCount < 5 {
|
||||
if !passwordSent && exPasswdPrompt.Match(chunk) {
|
||||
ptmx.Write([]byte(currentPassword + "\n"))
|
||||
passwordSent = true
|
||||
shouldSkipEcho = true
|
||||
continue
|
||||
}
|
||||
if shouldSkipEcho && exEcho.Match(chunk) {
|
||||
shouldSkipEcho = false
|
||||
continue
|
||||
}
|
||||
if passwordSent && exWrongPassword.Match(chunk) {
|
||||
w.data <- &Message{[]byte("sudo: Authentication failure\n"), MTData, task.Hostname, -1}
|
||||
taskForceStopped = true
|
||||
break execLoop
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if len(chunk) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if exConnectionClosed.Match(chunk) {
|
||||
continue
|
||||
}
|
||||
|
||||
if exLostConnection.Match(chunk) {
|
||||
continue
|
||||
}
|
||||
|
||||
// avoiding passing loop variable further as it's going to change its contents
|
||||
data := make([]byte, len(chunk))
|
||||
copy(data, chunk)
|
||||
w.data <- &Message{data, MTData, task.Hostname, -1}
|
||||
}
|
||||
}
|
||||
|
||||
exitCode := 0
|
||||
if taskForceStopped {
|
||||
cmd.Process.Kill()
|
||||
exitCode = ErrForceStop
|
||||
log.Debugf("WRK[%d]: Task on %s was force stopped", w.id, task.Hostname)
|
||||
}
|
||||
|
||||
err = cmd.Wait()
|
||||
|
||||
if !taskForceStopped {
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
ws := exitErr.Sys().(syscall.WaitStatus)
|
||||
exitCode = ws.ExitStatus()
|
||||
} else {
|
||||
// MacOS hack
|
||||
exitCode = ErrMacOsExit
|
||||
}
|
||||
}
|
||||
log.Debugf("WRK[%d]: Task on %s exit code is %d", w.id, task.Hostname, exitCode)
|
||||
}
|
||||
return exitCode
|
||||
}
|
||||
286
remote/serial.go
Normal file
286
remote/serial.go
Normal file
@@ -0,0 +1,286 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/kr/pty"
|
||||
"github.com/npat-efault/poller"
|
||||
"github.com/viert/xc/log"
|
||||
"github.com/viert/xc/term"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
)
|
||||
|
||||
var (
|
||||
passwordSent = false
|
||||
shouldSkipEcho = false
|
||||
)
|
||||
|
||||
func forwardUserInput(in *poller.FD, out *os.File, stopped *bool) {
|
||||
inBuf := make([]byte, bufferSize)
|
||||
// processing stdin
|
||||
for {
|
||||
deadline := time.Now().Add(pollDeadline)
|
||||
in.SetReadDeadline(deadline)
|
||||
n, err := in.Read(inBuf)
|
||||
if n > 0 {
|
||||
// copy stdin to process ptmx
|
||||
out.Write(inBuf[:n])
|
||||
inBuf = make([]byte, bufferSize)
|
||||
}
|
||||
if err != nil {
|
||||
if err != poller.ErrTimeout {
|
||||
break
|
||||
}
|
||||
}
|
||||
if *stopped {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func interceptProcessOutput(in []byte, ptmx *os.File) (out []byte, err error) {
|
||||
out = []byte{}
|
||||
err = nil
|
||||
|
||||
if exConnectionClosed.Match(in) {
|
||||
log.Debug("Connection closed message catched")
|
||||
return
|
||||
}
|
||||
|
||||
if exLostConnection.Match(in) {
|
||||
log.Debug("Lost connection message catched")
|
||||
return
|
||||
}
|
||||
|
||||
if !passwordSent && exPasswdPrompt.Match(in) {
|
||||
ptmx.Write([]byte(currentPassword + "\n"))
|
||||
passwordSent = true
|
||||
shouldSkipEcho = true
|
||||
log.Debug("Password sent")
|
||||
return
|
||||
}
|
||||
|
||||
if shouldSkipEcho && exEcho.Match(in) {
|
||||
log.Debug("Echo skipped")
|
||||
shouldSkipEcho = false
|
||||
return
|
||||
}
|
||||
|
||||
if passwordSent && exWrongPassword.Match(in) {
|
||||
log.Debug("Authentication error while raising privileges")
|
||||
err = fmt.Errorf("auth_error")
|
||||
return
|
||||
}
|
||||
|
||||
out = in
|
||||
return
|
||||
}
|
||||
|
||||
func runAtHost(host string, cmd *exec.Cmd, r *ExecResult) {
|
||||
var (
|
||||
ptmx *os.File
|
||||
si *poller.FD
|
||||
buf []byte
|
||||
err error
|
||||
|
||||
stopped = false
|
||||
)
|
||||
|
||||
passwordSent = false
|
||||
shouldSkipEcho = false
|
||||
|
||||
sigs := make(chan os.Signal, 1)
|
||||
signal.Notify(sigs, syscall.SIGWINCH)
|
||||
defer signal.Reset()
|
||||
|
||||
ptmx, err = pty.Start(cmd)
|
||||
if err != nil {
|
||||
term.Errorf("Error creating PTY: %s\n", err)
|
||||
r.ErrorHosts = append(r.ErrorHosts, host)
|
||||
r.Codes[host] = ErrTerminalError
|
||||
return
|
||||
}
|
||||
pty.InheritSize(os.Stdin, ptmx)
|
||||
defer ptmx.Close()
|
||||
|
||||
stdinBackup, err := syscall.Dup(int(os.Stdin.Fd()))
|
||||
if err != nil {
|
||||
term.Errorf("Error duplicating stdin descriptor: %s\n", err)
|
||||
r.ErrorHosts = append(r.ErrorHosts, host)
|
||||
r.Codes[host] = ErrTerminalError
|
||||
return
|
||||
}
|
||||
|
||||
stdinState, err := terminal.MakeRaw(int(os.Stdin.Fd()))
|
||||
if err != nil {
|
||||
term.Errorf("Error setting stdin to raw mode: %s\n", err)
|
||||
r.ErrorHosts = append(r.ErrorHosts, host)
|
||||
r.Codes[host] = ErrTerminalError
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
terminal.Restore(int(os.Stdin.Fd()), stdinState)
|
||||
}()
|
||||
|
||||
si, err = poller.NewFD(int(os.Stdin.Fd()))
|
||||
if err != nil {
|
||||
term.Errorf("Error initializing poller: %s\n", err)
|
||||
r.ErrorHosts = append(r.ErrorHosts, host)
|
||||
r.Codes[host] = ErrTerminalError
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
log.Debug("Setting stdin back to blocking mode")
|
||||
si.Close()
|
||||
syscall.Dup2(stdinBackup, int(os.Stdin.Fd()))
|
||||
syscall.SetNonblock(int(os.Stdin.Fd()), false)
|
||||
}()
|
||||
|
||||
buf = make([]byte, bufferSize)
|
||||
go forwardUserInput(si, ptmx, &stopped)
|
||||
|
||||
for {
|
||||
n, err := ptmx.Read(buf)
|
||||
if n > 0 {
|
||||
// TODO random stuff with intercepting and omitting data
|
||||
data, err := interceptProcessOutput(buf[:n], ptmx)
|
||||
if err != nil {
|
||||
// auth error, can't proceed
|
||||
raise := "su"
|
||||
if currentRaise == RTSudo {
|
||||
raise = "sudo"
|
||||
}
|
||||
log.Debugf("Wrong %s password\n", raise)
|
||||
term.Errorf("Wrong %s password\n", raise)
|
||||
r.ErrorHosts = append(r.ErrorHosts, host)
|
||||
r.Codes[host] = ErrAuthenticationError
|
||||
break
|
||||
}
|
||||
|
||||
if len(data) > 0 {
|
||||
// copy stdin to process ptmx
|
||||
_, err = os.Stdout.Write(data)
|
||||
if err != nil {
|
||||
count := stdoutWriteRetry
|
||||
for os.IsTimeout(err) && count > 0 {
|
||||
time.Sleep(time.Millisecond)
|
||||
_, err = os.Stdout.Write(data)
|
||||
count--
|
||||
}
|
||||
if err != nil {
|
||||
log.Debugf("error writing to stdout not resolved in %d steps", stdoutWriteRetry)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil && err != poller.ErrTimeout {
|
||||
stopped = true
|
||||
break
|
||||
}
|
||||
|
||||
select {
|
||||
case <-sigs:
|
||||
pty.InheritSize(os.Stdin, ptmx)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// RunSerial runs cmd on hosts in serial mode
|
||||
func RunSerial(hosts []string, argv string, delay int) *ExecResult {
|
||||
var (
|
||||
err error
|
||||
cmd *exec.Cmd
|
||||
local string
|
||||
remotePrefix string
|
||||
remoteCmd string
|
||||
sigs = make(chan os.Signal, 1)
|
||||
)
|
||||
r := newExecResult()
|
||||
|
||||
if argv != "" {
|
||||
local, remotePrefix, err = prepareTempFiles(argv)
|
||||
if err != nil {
|
||||
term.Errorf("Error creating tempfile: %s\n", err)
|
||||
return r
|
||||
}
|
||||
defer os.Remove(local)
|
||||
}
|
||||
|
||||
execLoop:
|
||||
for i, host := range hosts {
|
||||
msg := term.HR(7) + " " + host + " " + term.HR(36-len(host))
|
||||
fmt.Println(term.Blue(msg))
|
||||
|
||||
if argv != "" {
|
||||
remoteCmd = fmt.Sprintf("%s.%s.sh", remotePrefix, host)
|
||||
cmd = createSCPCmd(host, local, remoteCmd, false)
|
||||
log.Debugf("Created SCP command: %v", cmd)
|
||||
signal.Notify(sigs, syscall.SIGINT)
|
||||
err = cmd.Run()
|
||||
signal.Reset()
|
||||
if err != nil {
|
||||
term.Errorf("Error copying tempfile: %s\n", err)
|
||||
r.ErrorHosts = append(r.ErrorHosts, host)
|
||||
r.Codes[host] = ErrCopyFailed
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
cmd = createSSHCmd(host, remoteCmd)
|
||||
log.Debugf("Created SSH command: %v", cmd)
|
||||
|
||||
runAtHost(host, cmd, r)
|
||||
|
||||
exitCode := 0
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
if exitErr, ok := err.(*exec.ExitError); ok {
|
||||
ws := exitErr.Sys().(syscall.WaitStatus)
|
||||
exitCode = ws.ExitStatus()
|
||||
} else {
|
||||
// MacOS hack
|
||||
exitCode = ErrMacOsExit
|
||||
}
|
||||
}
|
||||
|
||||
r.Codes[host] = exitCode
|
||||
if exitCode != 0 {
|
||||
r.ErrorHosts = append(r.ErrorHosts, host)
|
||||
} else {
|
||||
r.SuccessHosts = append(r.SuccessHosts, host)
|
||||
}
|
||||
|
||||
// no delay after the last host
|
||||
if delay > 0 && i != len(hosts)-1 {
|
||||
log.Debugf("Delay %d secs", delay)
|
||||
timer := time.After(time.Duration(delay) * time.Second)
|
||||
signal.Notify(sigs, syscall.SIGINT)
|
||||
timeLoop:
|
||||
for {
|
||||
select {
|
||||
case <-sigs:
|
||||
log.Debugf("Delay interrupted by ^C")
|
||||
signal.Reset()
|
||||
break execLoop
|
||||
case <-timer:
|
||||
log.Debugf("Delay finished")
|
||||
signal.Reset()
|
||||
break timeLoop
|
||||
default:
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
68
remote/ssh.go
Normal file
68
remote/ssh.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
"github.com/viert/xc/log"
|
||||
)
|
||||
|
||||
var (
|
||||
sshOptions = map[string]string{
|
||||
"PasswordAuthentication": "no",
|
||||
"PubkeyAuthentication": "yes",
|
||||
"StrictHostKeyChecking": "no",
|
||||
"TCPKeepAlive": "yes",
|
||||
"ServerAliveCountMax": "12",
|
||||
"ServerAliveInterval": "5",
|
||||
}
|
||||
)
|
||||
|
||||
func sshOpts() (params []string) {
|
||||
params = make([]string, 0)
|
||||
for opt, value := range sshOptions {
|
||||
option := fmt.Sprintf("%s=%s", opt, value)
|
||||
params = append(params, "-o", option)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func createSCPCmd(host string, local string, remote string, recursive bool) *exec.Cmd {
|
||||
params := []string{}
|
||||
if recursive {
|
||||
params = []string{"-r"}
|
||||
}
|
||||
params = append(params, sshOpts()...)
|
||||
remoteExpr := fmt.Sprintf("%s@%s:%s", currentUser, host, remote)
|
||||
params = append(params, local, remoteExpr)
|
||||
log.Debugf("Created command scp %v", params)
|
||||
return exec.Command("scp", params...)
|
||||
}
|
||||
|
||||
func createSSHCmd(host string, argv string) *exec.Cmd {
|
||||
params := []string{
|
||||
"-tt",
|
||||
"-l",
|
||||
currentUser,
|
||||
}
|
||||
params = append(params, sshOpts()...)
|
||||
params = append(params, host)
|
||||
params = append(params, getInterpreter()...)
|
||||
if argv != "" {
|
||||
params = append(params, "-c", argv)
|
||||
}
|
||||
log.Debugf("Created command ssh %v", params)
|
||||
return exec.Command("ssh", params...)
|
||||
}
|
||||
|
||||
func getInterpreter() []string {
|
||||
switch currentRaise {
|
||||
case RTSudo:
|
||||
return strings.Split(sudoInterpreter, " ")
|
||||
case RTSu:
|
||||
return strings.Split(suInterpreter, " ")
|
||||
default:
|
||||
return strings.Split(noneInterpreter, " ")
|
||||
}
|
||||
}
|
||||
174
remote/worker.go
Normal file
174
remote/worker.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package remote
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/viert/xc/log"
|
||||
)
|
||||
|
||||
// RaiseType enum
|
||||
type RaiseType int
|
||||
|
||||
// Raise types
|
||||
const (
|
||||
RTNone RaiseType = iota
|
||||
RTSu
|
||||
RTSudo
|
||||
)
|
||||
|
||||
// Task type represents a worker task descriptor
|
||||
type Task struct {
|
||||
Hostname string
|
||||
LocalFilename string
|
||||
RemoteFilename string
|
||||
RecursiveCopy bool
|
||||
Cmd string
|
||||
WG *sync.WaitGroup
|
||||
}
|
||||
|
||||
// MessageType describes a type of worker message
|
||||
type MessageType int
|
||||
|
||||
// Message represents a worker message
|
||||
type Message struct {
|
||||
Data []byte
|
||||
Type MessageType
|
||||
Hostname string
|
||||
StatusCode int
|
||||
}
|
||||
|
||||
// Enum of OutputTypes
|
||||
const (
|
||||
MTData MessageType = iota
|
||||
MTDebug
|
||||
MTCopyFinished
|
||||
MTExecFinished
|
||||
)
|
||||
|
||||
// Custom error codes
|
||||
const (
|
||||
ErrMacOsExit = 32500 + iota
|
||||
ErrForceStop
|
||||
ErrCopyFailed
|
||||
ErrTerminalError
|
||||
ErrAuthenticationError
|
||||
)
|
||||
|
||||
const (
|
||||
pollDeadline = 50 * time.Millisecond
|
||||
bufferSize = 4096
|
||||
)
|
||||
|
||||
// Worker type represents a worker object
|
||||
type Worker struct {
|
||||
id int
|
||||
queue chan *Task
|
||||
data chan *Message
|
||||
stop chan bool
|
||||
busy bool
|
||||
}
|
||||
|
||||
var (
|
||||
wrkseq = 1
|
||||
environment = []string{"LC_ALL=en_US.UTF-8", "LANG=en_US.UTF-8"}
|
||||
|
||||
// remote expressions to catch
|
||||
exConnectionClosed = regexp.MustCompile(`([Ss]hared\s+)?[Cc]onnection\s+to\s+.+\s+closed\.?[\n\r]+`)
|
||||
exPasswdPrompt = regexp.MustCompile(`[Pp]assword`)
|
||||
exWrongPassword = regexp.MustCompile(`[Ss]orry.+try.+again\.?`)
|
||||
exPermissionDenied = regexp.MustCompile(`[Pp]ermission\s+denied`)
|
||||
exLostConnection = regexp.MustCompile(`[Ll]ost\sconnection`)
|
||||
exEcho = regexp.MustCompile(`^[\n\r]+$`)
|
||||
)
|
||||
|
||||
// NewWorker creates a new worker
|
||||
func NewWorker(queue chan *Task, data chan *Message) *Worker {
|
||||
w := &Worker{
|
||||
id: wrkseq,
|
||||
queue: queue,
|
||||
data: data,
|
||||
stop: make(chan bool, 1),
|
||||
busy: false,
|
||||
}
|
||||
wrkseq++
|
||||
go w.run()
|
||||
return w
|
||||
}
|
||||
|
||||
// ID is a worker id getter
|
||||
func (w *Worker) ID() int {
|
||||
return w.id
|
||||
}
|
||||
|
||||
func (w *Worker) run() {
|
||||
var result int
|
||||
|
||||
log.Debugf("WRK[%d] Started", w.id)
|
||||
for task := range w.queue {
|
||||
// Every task consists of copying part and executing part
|
||||
// It may contain both or just one of them
|
||||
// If there are both parts, worker copies data and then runs
|
||||
// the given command immediately. This behaviour is handy for runscript
|
||||
// command when the script is being copied to a remote server
|
||||
// and called right after it.
|
||||
|
||||
w.busy = true
|
||||
log.Debugf("WRK[%d] Got a task for host %s by worker", w.id, task.Hostname)
|
||||
|
||||
// does the task have anything to copy?
|
||||
if task.RemoteFilename != "" && task.LocalFilename != "" {
|
||||
result = w.copy(task)
|
||||
log.Debugf("WRK[%d] Copy on %s, status=%d", w.id, task.Hostname, result)
|
||||
w.data <- &Message{nil, MTCopyFinished, task.Hostname, result}
|
||||
if result != 0 {
|
||||
log.Debugf("WRK[%d] Copy on %s, result != 0, catching", w.id, task.Hostname)
|
||||
// if copying failed we can't proceed further with the task if there's anything to run
|
||||
if task.Cmd != "" {
|
||||
log.Debugf("WRK[%d] Copy on %s, result != 0, task.Cmd == \"%s\", sending ExecFinished", w.id, task.Hostname, task.Cmd)
|
||||
w.data <- &Message{nil, MTExecFinished, task.Hostname, ErrCopyFailed}
|
||||
}
|
||||
w.busy = false
|
||||
if task.WG != nil {
|
||||
task.WG.Done()
|
||||
}
|
||||
// next task
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// does the task have anything to run?
|
||||
if task.Cmd != "" {
|
||||
log.Debugf("WRK[%d] runcmd(%s) at %s", task.Cmd, task.Hostname)
|
||||
result = w.runcmd(task)
|
||||
w.data <- &Message{nil, MTExecFinished, task.Hostname, result}
|
||||
}
|
||||
|
||||
if task.WG != nil {
|
||||
task.WG.Done()
|
||||
}
|
||||
|
||||
w.busy = false
|
||||
}
|
||||
log.Debugf("WRK[%d] Task queue has closed, worker is exiting", w.id)
|
||||
}
|
||||
|
||||
// ForceStop stops the current task execution and returns true
|
||||
// if any task were actually executed at the moment of calling ForceStop
|
||||
func (w *Worker) ForceStop() bool {
|
||||
if w.busy {
|
||||
w.stop <- true
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (w *Worker) forceStopped() bool {
|
||||
select {
|
||||
case <-w.stop:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user