project move

This commit is contained in:
Pavel Vorobyov
2019-09-24 11:04:48 +03:00
parent eb2a495406
commit 7e2dec0ef0
33 changed files with 5613 additions and 1 deletions

79
remote/copy.go Normal file
View File

@@ -0,0 +1,79 @@
package remote
import (
"os"
"os/exec"
"syscall"
"time"
"github.com/kr/pty"
"github.com/npat-efault/poller"
"github.com/viert/xc/log"
)
func (w *Worker) copy(task *Task) int {
var err error
var n int
cmd := createSCPCmd(task.Hostname, task.LocalFilename, task.RemoteFilename, task.RecursiveCopy)
cmd.Env = append(os.Environ(), environment...)
ptmx, err := pty.Start(cmd)
if err != nil {
return ErrTerminalError
}
defer ptmx.Close()
fd, err := poller.NewFD(int(ptmx.Fd()))
if err != nil {
return ErrTerminalError
}
defer fd.Close()
buf := make([]byte, bufferSize)
taskForceStopped := false
for {
if w.forceStopped() {
taskForceStopped = true
break
}
fd.SetReadDeadline(time.Now().Add(pollDeadline))
n, err = fd.Read(buf)
if err != nil {
if err != poller.ErrTimeout {
// EOF, done
break
} else {
continue
}
}
if n == 0 {
continue
}
w.data <- &Message{buf[:n], MTDebug, task.Hostname, 0}
buf = make([]byte, bufferSize)
}
exitCode := 0
if taskForceStopped {
cmd.Process.Kill()
exitCode = ErrForceStop
log.Debugf("WRK[%d]: Task on %s was force stopped", w.id, task.Hostname)
}
err = cmd.Wait()
if !taskForceStopped {
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
ws := exitErr.Sys().(syscall.WaitStatus)
exitCode = ws.ExitStatus()
} else {
// MacOS hack
exitCode = ErrMacOsExit
}
}
log.Debugf("WRK[%d]: Task on %s exit code is %d", w.id, task.Hostname, exitCode)
}
return exitCode
}

78
remote/distribute.go Normal file
View File

@@ -0,0 +1,78 @@
package remote
import (
"os"
"os/signal"
"sync"
"syscall"
"github.com/viert/xc/log"
pb "gopkg.in/cheggaaa/pb.v1"
)
// Distribute distributes a given local file or directory to a number of hosts
func Distribute(hosts []string, localFilename string, remoteFilename string, recursive bool) *ExecResult {
var (
wg sync.WaitGroup
bar *pb.ProgressBar
sigs chan os.Signal
r *ExecResult
t *Task
running int
)
r = newExecResult()
running = len(hosts)
if currentProgressBar {
bar = pb.StartNew(running)
}
sigs = make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT)
defer signal.Reset()
go func() {
for _, host := range hosts {
t = &Task{
Hostname: host,
LocalFilename: localFilename,
RemoteFilename: remoteFilename,
RecursiveCopy: recursive,
Cmd: "",
WG: &wg,
}
pool.AddTask(t)
}
wg.Wait()
}()
for running > 0 {
select {
case d := <-pool.Data:
switch d.Type {
case MTDebug:
if currentDebug {
log.Debugf("DATASTREAM @ %s\n%v\n[%v]", d.Hostname, d.Data, string(d.Data))
}
case MTCopyFinished:
running--
if currentProgressBar {
bar.Increment()
}
r.Codes[d.Hostname] = d.StatusCode
if d.StatusCode == 0 {
r.SuccessHosts = append(r.SuccessHosts, d.Hostname)
} else {
r.ErrorHosts = append(r.ErrorHosts, d.Hostname)
}
}
case <-sigs:
r.ForceStoppedHosts = pool.ForceStopAllTasks()
}
}
if currentProgressBar {
bar.Finish()
}
return r
}

238
remote/executer.go Normal file
View File

@@ -0,0 +1,238 @@
package remote
import (
"bytes"
"fmt"
"os"
"os/signal"
"strings"
"sync"
"syscall"
"github.com/viert/xc/log"
"github.com/viert/xc/term"
pb "gopkg.in/cheggaaa/pb.v1"
)
const (
stdoutWriteRetry = 25
)
// ExecResult is a struct with execution results
type ExecResult struct {
Codes map[string]int
Outputs map[string][]string
SuccessHosts []string
ErrorHosts []string
ForceStoppedHosts int
}
func newExecResult() *ExecResult {
return &ExecResult{
Codes: make(map[string]int),
Outputs: make(map[string][]string),
SuccessHosts: make([]string, 0),
ErrorHosts: make([]string, 0),
ForceStoppedHosts: 0,
}
}
// Print prints ExecResults in a nice way
func (r *ExecResult) Print() {
msg := fmt.Sprintf(" Hosts processed: %d, success: %d, error: %d ",
len(r.SuccessHosts)+len(r.ErrorHosts), len(r.SuccessHosts), len(r.ErrorHosts))
h := term.HR(len(msg))
fmt.Println(term.Green(h))
fmt.Println(term.Green(msg))
fmt.Println(term.Green(h))
}
// PrintOutputMap prints collapsed-style output
func (r *ExecResult) PrintOutputMap() {
for output, hosts := range r.Outputs {
msg := fmt.Sprintf(" %d host(s): %s ", len(hosts), strings.Join(hosts, ","))
tableWidth := len(msg) + 2
termWidth := term.GetTerminalWidth()
if tableWidth > termWidth {
tableWidth = termWidth
}
fmt.Println(term.Blue(term.HR(tableWidth)))
fmt.Println(term.Blue(msg))
fmt.Println(term.Blue(term.HR(tableWidth)))
fmt.Println(output)
}
}
func enqueue(local string, remote string, hosts []string) {
// This is in a goroutine because of decreasing the task channel size.
// If there is a number of hosts greater than pool.dataSizeQueue (i.e. 1024)
// this loop will actually block on reaching the limit until some tasks are
// processed and some space in the queue is released.
//
// To avoid blocking on task generation this loop was moved into a goroutine
var wg sync.WaitGroup
for _, host := range hosts {
// remoteFile should include hostname for the case we have
// a number of aliases pointing to one server. With the same
// remote filename the first task finished removes the file
// while other tasks on the same server try to remove it afterwards and fail
remoteFilename := fmt.Sprintf("%s.%s.sh", remote, host)
task := &Task{
Hostname: host,
LocalFilename: local,
RemoteFilename: remoteFilename,
Cmd: remoteFilename,
WG: &wg,
}
pool.AddTask(task)
}
wg.Wait()
}
// RunParallel runs cmd on hosts in parallel mode
func RunParallel(hosts []string, cmd string) *ExecResult {
r := newExecResult()
if len(hosts) == 0 {
return r
}
local, remote, err := prepareTempFiles(cmd)
if err != nil {
term.Errorf("Error creating temporary file: %s\n", err)
return r
}
defer os.Remove(local)
running := len(hosts)
copied := 0
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT)
defer signal.Reset()
go enqueue(local, remote, hosts)
for running > 0 {
select {
case d := <-pool.Data:
switch d.Type {
case MTData:
log.Debugf("MSG@%s[DATA](%d): %s", d.Hostname, d.StatusCode, string(d.Data))
if !bytes.HasSuffix(d.Data, []byte{'\n'}) {
d.Data = append(d.Data, '\n')
}
if currentPrependHostnames {
fmt.Printf("%s: ", term.Blue(d.Hostname))
}
fmt.Print(string(d.Data))
writeHostOutput(d.Hostname, d.Data)
case MTDebug:
if currentDebug {
log.Debugf("DATASTREAM @ %s\n%v\n[%v]", d.Hostname, d.Data, string(d.Data))
}
case MTCopyFinished:
log.Debugf("MSG@%s[COPYFIN](%d): %s", d.Hostname, d.StatusCode, string(d.Data))
if d.StatusCode == 0 {
copied++
}
case MTExecFinished:
log.Debugf("MSG@%s[EXECFIN](%d): %s", d.Hostname, d.StatusCode, string(d.Data))
r.Codes[d.Hostname] = d.StatusCode
if d.StatusCode == 0 {
r.SuccessHosts = append(r.SuccessHosts, d.Hostname)
} else {
r.ErrorHosts = append(r.ErrorHosts, d.Hostname)
}
running--
}
case <-sigs:
fmt.Println()
r.ForceStoppedHosts = pool.ForceStopAllTasks()
}
}
return r
}
// RunCollapse runs cmd on hosts in collapse mode
func RunCollapse(hosts []string, cmd string) *ExecResult {
var bar *pb.ProgressBar
r := newExecResult()
if len(hosts) == 0 {
return r
}
local, remote, err := prepareTempFiles(cmd)
if err != nil {
term.Errorf("Error creating temporary file: %s\n", err)
return r
}
defer os.Remove(local)
running := len(hosts)
copied := 0
outputs := make(map[string]string)
if currentProgressBar {
bar = pb.StartNew(running)
}
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGINT)
defer signal.Reset()
go enqueue(local, remote, hosts)
for running > 0 {
select {
case d := <-pool.Data:
switch d.Type {
case MTData:
outputs[d.Hostname] += string(d.Data)
logData := make([]byte, len(d.Data))
copy(logData, d.Data)
if !bytes.HasSuffix(d.Data, []byte{'\n'}) {
logData = append(d.Data, '\n')
}
writeHostOutput(d.Hostname, logData)
case MTDebug:
if currentDebug {
log.Debugf("DATASTREAM @ %s\n%v\n[%v]", d.Hostname, d.Data, string(d.Data))
}
case MTCopyFinished:
if d.StatusCode == 0 {
copied++
}
case MTExecFinished:
if currentProgressBar {
bar.Increment()
}
r.Codes[d.Hostname] = d.StatusCode
if d.StatusCode == 0 {
r.SuccessHosts = append(r.SuccessHosts, d.Hostname)
} else {
r.ErrorHosts = append(r.ErrorHosts, d.Hostname)
}
running--
}
case <-sigs:
fmt.Println()
r.ForceStoppedHosts = pool.ForceStopAllTasks()
}
}
if currentProgressBar {
bar.Finish()
}
for k, v := range outputs {
_, found := r.Outputs[v]
if !found {
r.Outputs[v] = make([]string, 0)
}
r.Outputs[v] = append(r.Outputs[v], k)
}
return r
}

91
remote/pool.go Normal file
View File

@@ -0,0 +1,91 @@
package remote
import (
"github.com/viert/xc/log"
)
const (
dataQueueSize = 1024
)
// Pool is a class representing a worker pool
type Pool struct {
workers []*Worker
queue chan *Task
Data chan *Message
}
// NewPool creates a new worker pool of a given size
func NewPool(size int) *Pool {
p := &Pool{
workers: make([]*Worker, size),
queue: make(chan *Task, dataQueueSize),
Data: make(chan *Message, dataQueueSize),
}
for i := 0; i < size; i++ {
p.workers[i] = NewWorker(p.queue, p.Data)
}
log.Debugf("Remote execution pool created with %d workers", size)
log.Debugf("Data Queue Size is %d", dataQueueSize)
return p
}
// ForceStopAllTasks removes all pending tasks and force stops those in progress
func (p *Pool) ForceStopAllTasks() int {
// Remove all pending tasks from the queue
log.Debug("Force stopping all tasks")
i := 0
rmvLoop:
for {
select {
case <-p.queue:
i++
continue
default:
break rmvLoop
}
}
log.Debugf("%d queued (and not yet started) tasks removed from the queue", i)
stopped := 0
for _, wrk := range p.workers {
if wrk.ForceStop() {
log.Debugf("Worker %d was running a task so force stopped", wrk.ID())
stopped++
}
}
return stopped
}
// Close shuts down the pool itself and all its workers
func (p *Pool) Close() {
log.Debug("Closing remote execution pool")
p.ForceStopAllTasks()
close(p.queue) // this should make all the workers step out of range loop on queue chan and shut down
log.Debug("Closing the task queue")
close(p.Data)
}
// AddTask adds a task to the pool queue
func (p *Pool) AddTask(task *Task) {
if task.WG != nil {
task.WG.Add(1)
}
p.queue <- task
}
// AddTaskHostlist creates multiple tasks to be run on a multiple hosts
func (p *Pool) AddTaskHostlist(task *Task, hosts []string) {
for _, host := range hosts {
t := &Task{
Hostname: host,
LocalFilename: task.LocalFilename,
RemoteFilename: task.RemoteFilename,
Cmd: task.Cmd,
WG: task.WG,
}
p.AddTask(t)
}
}

136
remote/remote.go Normal file
View File

@@ -0,0 +1,136 @@
package remote
import (
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"time"
)
var (
pool *Pool
currentUser string
currentPassword string
currentRaise RaiseType
currentProgressBar bool
currentPrependHostnames bool
currentRemoteTmpdir string
currentDebug bool
outputFile *os.File
noneInterpreter string
suInterpreter string
sudoInterpreter string
)
// Initialize initializes new execution pool
func Initialize(numThreads int, username string) {
pool = NewPool(numThreads)
SetUser(username)
SetPassword("")
SetRaise(RTNone)
}
// SetInterpreter sets none-raise interpreter
func SetInterpreter(interpreter string) {
noneInterpreter = interpreter
}
// SetSudoInterpreter sets sudo-raise interpreter
func SetSudoInterpreter(interpreter string) {
sudoInterpreter = interpreter
}
// SetSuInterpreter sets su-raise interpreter
func SetSuInterpreter(interpreter string) {
suInterpreter = interpreter
}
// SetUser sets executer username
func SetUser(username string) {
currentUser = username
}
// SetRaise sets executer raise type
func SetRaise(raise RaiseType) {
currentRaise = raise
}
// SetPassword sets executer password
func SetPassword(password string) {
currentPassword = password
}
// SetProgressBar sets current progressbar mode
func SetProgressBar(pbar bool) {
currentProgressBar = pbar
}
// SetRemoteTmpdir sets current remote temp directory
func SetRemoteTmpdir(tmpDir string) {
currentRemoteTmpdir = tmpDir
}
// SetDebug sets current debug mode
func SetDebug(debug bool) {
currentDebug = debug
}
// SetPrependHostnames sets current prepend_hostnames value for parallel mode
func SetPrependHostnames(prependHostnames bool) {
currentPrependHostnames = prependHostnames
}
// SetConnectTimeout sets the ssh connect timeout in sshOptions
func SetConnectTimeout(timeout int) {
sshOptions["ConnectTimeout"] = fmt.Sprintf("%d", timeout)
}
// SetOutputFile sets output file for every command.
// if it's nil, no output will be written to files
func SetOutputFile(f *os.File) {
outputFile = f
}
// SetNumThreads recreates the execution pool with the given number of threads
func SetNumThreads(numThreads int) {
if len(pool.workers) == numThreads {
return
}
pool.Close()
pool = NewPool(numThreads)
}
func prepareTempFiles(cmd string) (string, string, error) {
f, err := ioutil.TempFile("", "xc.")
if err != nil {
return "", "", err
}
defer f.Close()
remoteFilename := filepath.Join(currentRemoteTmpdir, filepath.Base(f.Name()))
io.WriteString(f, "#!/bin/bash\n\n")
io.WriteString(f, fmt.Sprintf("nohup bash -c \"sleep 1; rm -f $0\" >/dev/null 2>&1 </dev/null &\n")) // self-destroy
io.WriteString(f, cmd+"\n") // run command
f.Chmod(0755)
return f.Name(), remoteFilename, nil
}
// WriteOutput writes output to a user-defined logfile
// prepending with the current datetime
func WriteOutput(message string) {
if outputFile == nil {
return
}
tm := time.Now().Format("2006-01-02 15:04:05")
message = fmt.Sprintf("[%s] %s", tm, message)
outputFile.Write([]byte(message))
}
func writeHostOutput(host string, data []byte) {
message := fmt.Sprintf("%s: %s", host, string(data))
WriteOutput(message)
}

129
remote/runcmd.go Normal file
View File

@@ -0,0 +1,129 @@
package remote
import (
"bytes"
"os"
"os/exec"
"syscall"
"time"
"github.com/kr/pty"
"github.com/npat-efault/poller"
"github.com/viert/xc/log"
)
func (w *Worker) runcmd(task *Task) int {
var err error
var n int
var passwordSent bool
passwordSent = currentRaise == RTNone
cmd := createSSHCmd(task.Hostname, task.Cmd)
cmd.Env = append(os.Environ(), environment...)
ptmx, err := pty.Start(cmd)
if err != nil {
return ErrTerminalError
}
defer ptmx.Close()
fd, err := poller.NewFD(int(ptmx.Fd()))
if err != nil {
return ErrTerminalError
}
defer fd.Close()
buf := make([]byte, bufferSize)
taskForceStopped := false
shouldSkipEcho := false
msgCount := 0
execLoop:
for {
if w.forceStopped() {
taskForceStopped = true
break
}
fd.SetReadDeadline(time.Now().Add(pollDeadline))
n, err = fd.Read(buf)
if err != nil {
if err != poller.ErrTimeout {
// EOF, done
break
} else {
continue
}
}
if n == 0 {
continue
}
w.data <- &Message{buf, MTDebug, task.Hostname, -1}
msgCount++
chunks := bytes.SplitAfter(buf[:n], []byte{'\n'})
for _, chunk := range chunks {
// Trying to find Password prompt in first 5 chunks of data from server
if msgCount < 5 {
if !passwordSent && exPasswdPrompt.Match(chunk) {
ptmx.Write([]byte(currentPassword + "\n"))
passwordSent = true
shouldSkipEcho = true
continue
}
if shouldSkipEcho && exEcho.Match(chunk) {
shouldSkipEcho = false
continue
}
if passwordSent && exWrongPassword.Match(chunk) {
w.data <- &Message{[]byte("sudo: Authentication failure\n"), MTData, task.Hostname, -1}
taskForceStopped = true
break execLoop
}
}
if len(chunk) == 0 {
continue
}
if exConnectionClosed.Match(chunk) {
continue
}
if exLostConnection.Match(chunk) {
continue
}
// avoiding passing loop variable further as it's going to change its contents
data := make([]byte, len(chunk))
copy(data, chunk)
w.data <- &Message{data, MTData, task.Hostname, -1}
}
}
exitCode := 0
if taskForceStopped {
cmd.Process.Kill()
exitCode = ErrForceStop
log.Debugf("WRK[%d]: Task on %s was force stopped", w.id, task.Hostname)
}
err = cmd.Wait()
if !taskForceStopped {
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
ws := exitErr.Sys().(syscall.WaitStatus)
exitCode = ws.ExitStatus()
} else {
// MacOS hack
exitCode = ErrMacOsExit
}
}
log.Debugf("WRK[%d]: Task on %s exit code is %d", w.id, task.Hostname, exitCode)
}
return exitCode
}

286
remote/serial.go Normal file
View File

@@ -0,0 +1,286 @@
package remote
import (
"fmt"
"os"
"os/exec"
"os/signal"
"syscall"
"time"
"github.com/kr/pty"
"github.com/npat-efault/poller"
"github.com/viert/xc/log"
"github.com/viert/xc/term"
"golang.org/x/crypto/ssh/terminal"
)
var (
passwordSent = false
shouldSkipEcho = false
)
func forwardUserInput(in *poller.FD, out *os.File, stopped *bool) {
inBuf := make([]byte, bufferSize)
// processing stdin
for {
deadline := time.Now().Add(pollDeadline)
in.SetReadDeadline(deadline)
n, err := in.Read(inBuf)
if n > 0 {
// copy stdin to process ptmx
out.Write(inBuf[:n])
inBuf = make([]byte, bufferSize)
}
if err != nil {
if err != poller.ErrTimeout {
break
}
}
if *stopped {
break
}
}
}
func interceptProcessOutput(in []byte, ptmx *os.File) (out []byte, err error) {
out = []byte{}
err = nil
if exConnectionClosed.Match(in) {
log.Debug("Connection closed message catched")
return
}
if exLostConnection.Match(in) {
log.Debug("Lost connection message catched")
return
}
if !passwordSent && exPasswdPrompt.Match(in) {
ptmx.Write([]byte(currentPassword + "\n"))
passwordSent = true
shouldSkipEcho = true
log.Debug("Password sent")
return
}
if shouldSkipEcho && exEcho.Match(in) {
log.Debug("Echo skipped")
shouldSkipEcho = false
return
}
if passwordSent && exWrongPassword.Match(in) {
log.Debug("Authentication error while raising privileges")
err = fmt.Errorf("auth_error")
return
}
out = in
return
}
func runAtHost(host string, cmd *exec.Cmd, r *ExecResult) {
var (
ptmx *os.File
si *poller.FD
buf []byte
err error
stopped = false
)
passwordSent = false
shouldSkipEcho = false
sigs := make(chan os.Signal, 1)
signal.Notify(sigs, syscall.SIGWINCH)
defer signal.Reset()
ptmx, err = pty.Start(cmd)
if err != nil {
term.Errorf("Error creating PTY: %s\n", err)
r.ErrorHosts = append(r.ErrorHosts, host)
r.Codes[host] = ErrTerminalError
return
}
pty.InheritSize(os.Stdin, ptmx)
defer ptmx.Close()
stdinBackup, err := syscall.Dup(int(os.Stdin.Fd()))
if err != nil {
term.Errorf("Error duplicating stdin descriptor: %s\n", err)
r.ErrorHosts = append(r.ErrorHosts, host)
r.Codes[host] = ErrTerminalError
return
}
stdinState, err := terminal.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
term.Errorf("Error setting stdin to raw mode: %s\n", err)
r.ErrorHosts = append(r.ErrorHosts, host)
r.Codes[host] = ErrTerminalError
return
}
defer func() {
terminal.Restore(int(os.Stdin.Fd()), stdinState)
}()
si, err = poller.NewFD(int(os.Stdin.Fd()))
if err != nil {
term.Errorf("Error initializing poller: %s\n", err)
r.ErrorHosts = append(r.ErrorHosts, host)
r.Codes[host] = ErrTerminalError
return
}
defer func() {
log.Debug("Setting stdin back to blocking mode")
si.Close()
syscall.Dup2(stdinBackup, int(os.Stdin.Fd()))
syscall.SetNonblock(int(os.Stdin.Fd()), false)
}()
buf = make([]byte, bufferSize)
go forwardUserInput(si, ptmx, &stopped)
for {
n, err := ptmx.Read(buf)
if n > 0 {
// TODO random stuff with intercepting and omitting data
data, err := interceptProcessOutput(buf[:n], ptmx)
if err != nil {
// auth error, can't proceed
raise := "su"
if currentRaise == RTSudo {
raise = "sudo"
}
log.Debugf("Wrong %s password\n", raise)
term.Errorf("Wrong %s password\n", raise)
r.ErrorHosts = append(r.ErrorHosts, host)
r.Codes[host] = ErrAuthenticationError
break
}
if len(data) > 0 {
// copy stdin to process ptmx
_, err = os.Stdout.Write(data)
if err != nil {
count := stdoutWriteRetry
for os.IsTimeout(err) && count > 0 {
time.Sleep(time.Millisecond)
_, err = os.Stdout.Write(data)
count--
}
if err != nil {
log.Debugf("error writing to stdout not resolved in %d steps", stdoutWriteRetry)
}
}
}
}
if err != nil && err != poller.ErrTimeout {
stopped = true
break
}
select {
case <-sigs:
pty.InheritSize(os.Stdin, ptmx)
default:
continue
}
}
}
// RunSerial runs cmd on hosts in serial mode
func RunSerial(hosts []string, argv string, delay int) *ExecResult {
var (
err error
cmd *exec.Cmd
local string
remotePrefix string
remoteCmd string
sigs = make(chan os.Signal, 1)
)
r := newExecResult()
if argv != "" {
local, remotePrefix, err = prepareTempFiles(argv)
if err != nil {
term.Errorf("Error creating tempfile: %s\n", err)
return r
}
defer os.Remove(local)
}
execLoop:
for i, host := range hosts {
msg := term.HR(7) + " " + host + " " + term.HR(36-len(host))
fmt.Println(term.Blue(msg))
if argv != "" {
remoteCmd = fmt.Sprintf("%s.%s.sh", remotePrefix, host)
cmd = createSCPCmd(host, local, remoteCmd, false)
log.Debugf("Created SCP command: %v", cmd)
signal.Notify(sigs, syscall.SIGINT)
err = cmd.Run()
signal.Reset()
if err != nil {
term.Errorf("Error copying tempfile: %s\n", err)
r.ErrorHosts = append(r.ErrorHosts, host)
r.Codes[host] = ErrCopyFailed
continue
}
}
cmd = createSSHCmd(host, remoteCmd)
log.Debugf("Created SSH command: %v", cmd)
runAtHost(host, cmd, r)
exitCode := 0
err = cmd.Wait()
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
ws := exitErr.Sys().(syscall.WaitStatus)
exitCode = ws.ExitStatus()
} else {
// MacOS hack
exitCode = ErrMacOsExit
}
}
r.Codes[host] = exitCode
if exitCode != 0 {
r.ErrorHosts = append(r.ErrorHosts, host)
} else {
r.SuccessHosts = append(r.SuccessHosts, host)
}
// no delay after the last host
if delay > 0 && i != len(hosts)-1 {
log.Debugf("Delay %d secs", delay)
timer := time.After(time.Duration(delay) * time.Second)
signal.Notify(sigs, syscall.SIGINT)
timeLoop:
for {
select {
case <-sigs:
log.Debugf("Delay interrupted by ^C")
signal.Reset()
break execLoop
case <-timer:
log.Debugf("Delay finished")
signal.Reset()
break timeLoop
default:
continue
}
}
}
}
return r
}

68
remote/ssh.go Normal file
View File

@@ -0,0 +1,68 @@
package remote
import (
"fmt"
"os/exec"
"strings"
"github.com/viert/xc/log"
)
var (
sshOptions = map[string]string{
"PasswordAuthentication": "no",
"PubkeyAuthentication": "yes",
"StrictHostKeyChecking": "no",
"TCPKeepAlive": "yes",
"ServerAliveCountMax": "12",
"ServerAliveInterval": "5",
}
)
func sshOpts() (params []string) {
params = make([]string, 0)
for opt, value := range sshOptions {
option := fmt.Sprintf("%s=%s", opt, value)
params = append(params, "-o", option)
}
return
}
func createSCPCmd(host string, local string, remote string, recursive bool) *exec.Cmd {
params := []string{}
if recursive {
params = []string{"-r"}
}
params = append(params, sshOpts()...)
remoteExpr := fmt.Sprintf("%s@%s:%s", currentUser, host, remote)
params = append(params, local, remoteExpr)
log.Debugf("Created command scp %v", params)
return exec.Command("scp", params...)
}
func createSSHCmd(host string, argv string) *exec.Cmd {
params := []string{
"-tt",
"-l",
currentUser,
}
params = append(params, sshOpts()...)
params = append(params, host)
params = append(params, getInterpreter()...)
if argv != "" {
params = append(params, "-c", argv)
}
log.Debugf("Created command ssh %v", params)
return exec.Command("ssh", params...)
}
func getInterpreter() []string {
switch currentRaise {
case RTSudo:
return strings.Split(sudoInterpreter, " ")
case RTSu:
return strings.Split(suInterpreter, " ")
default:
return strings.Split(noneInterpreter, " ")
}
}

174
remote/worker.go Normal file
View File

@@ -0,0 +1,174 @@
package remote
import (
"regexp"
"sync"
"time"
"github.com/viert/xc/log"
)
// RaiseType enum
type RaiseType int
// Raise types
const (
RTNone RaiseType = iota
RTSu
RTSudo
)
// Task type represents a worker task descriptor
type Task struct {
Hostname string
LocalFilename string
RemoteFilename string
RecursiveCopy bool
Cmd string
WG *sync.WaitGroup
}
// MessageType describes a type of worker message
type MessageType int
// Message represents a worker message
type Message struct {
Data []byte
Type MessageType
Hostname string
StatusCode int
}
// Enum of OutputTypes
const (
MTData MessageType = iota
MTDebug
MTCopyFinished
MTExecFinished
)
// Custom error codes
const (
ErrMacOsExit = 32500 + iota
ErrForceStop
ErrCopyFailed
ErrTerminalError
ErrAuthenticationError
)
const (
pollDeadline = 50 * time.Millisecond
bufferSize = 4096
)
// Worker type represents a worker object
type Worker struct {
id int
queue chan *Task
data chan *Message
stop chan bool
busy bool
}
var (
wrkseq = 1
environment = []string{"LC_ALL=en_US.UTF-8", "LANG=en_US.UTF-8"}
// remote expressions to catch
exConnectionClosed = regexp.MustCompile(`([Ss]hared\s+)?[Cc]onnection\s+to\s+.+\s+closed\.?[\n\r]+`)
exPasswdPrompt = regexp.MustCompile(`[Pp]assword`)
exWrongPassword = regexp.MustCompile(`[Ss]orry.+try.+again\.?`)
exPermissionDenied = regexp.MustCompile(`[Pp]ermission\s+denied`)
exLostConnection = regexp.MustCompile(`[Ll]ost\sconnection`)
exEcho = regexp.MustCompile(`^[\n\r]+$`)
)
// NewWorker creates a new worker
func NewWorker(queue chan *Task, data chan *Message) *Worker {
w := &Worker{
id: wrkseq,
queue: queue,
data: data,
stop: make(chan bool, 1),
busy: false,
}
wrkseq++
go w.run()
return w
}
// ID is a worker id getter
func (w *Worker) ID() int {
return w.id
}
func (w *Worker) run() {
var result int
log.Debugf("WRK[%d] Started", w.id)
for task := range w.queue {
// Every task consists of copying part and executing part
// It may contain both or just one of them
// If there are both parts, worker copies data and then runs
// the given command immediately. This behaviour is handy for runscript
// command when the script is being copied to a remote server
// and called right after it.
w.busy = true
log.Debugf("WRK[%d] Got a task for host %s by worker", w.id, task.Hostname)
// does the task have anything to copy?
if task.RemoteFilename != "" && task.LocalFilename != "" {
result = w.copy(task)
log.Debugf("WRK[%d] Copy on %s, status=%d", w.id, task.Hostname, result)
w.data <- &Message{nil, MTCopyFinished, task.Hostname, result}
if result != 0 {
log.Debugf("WRK[%d] Copy on %s, result != 0, catching", w.id, task.Hostname)
// if copying failed we can't proceed further with the task if there's anything to run
if task.Cmd != "" {
log.Debugf("WRK[%d] Copy on %s, result != 0, task.Cmd == \"%s\", sending ExecFinished", w.id, task.Hostname, task.Cmd)
w.data <- &Message{nil, MTExecFinished, task.Hostname, ErrCopyFailed}
}
w.busy = false
if task.WG != nil {
task.WG.Done()
}
// next task
continue
}
}
// does the task have anything to run?
if task.Cmd != "" {
log.Debugf("WRK[%d] runcmd(%s) at %s", task.Cmd, task.Hostname)
result = w.runcmd(task)
w.data <- &Message{nil, MTExecFinished, task.Hostname, result}
}
if task.WG != nil {
task.WG.Done()
}
w.busy = false
}
log.Debugf("WRK[%d] Task queue has closed, worker is exiting", w.id)
}
// ForceStop stops the current task execution and returns true
// if any task were actually executed at the moment of calling ForceStop
func (w *Worker) ForceStop() bool {
if w.busy {
w.stop <- true
return true
}
return false
}
func (w *Worker) forceStopped() bool {
select {
case <-w.stop:
return true
default:
return false
}
}