No parallel pty (#5)

* poller and pty removed for parallel execution
This commit is contained in:
Pavel Vorobyov
2019-10-01 10:42:10 +03:00
committed by GitHub
parent 7437dcc9e9
commit 7ea2ec5330
9 changed files with 320 additions and 394 deletions

75
remote/commands.go Normal file
View File

@@ -0,0 +1,75 @@
package remote
import (
"fmt"
"os/exec"
"strings"
"github.com/viert/xc/log"
)
func (w *Worker) copy(task *Task) int {
cmd := createSCPCmd(task.Hostname, task.LocalFilename, task.RemoteFilename, task.RecursiveCopy)
return w._run(task, cmd)
}
func (w *Worker) runcmd(task *Task) int {
cmd := createSSHCmd(task.Hostname, task.Cmd)
return w._run(task, cmd)
}
func (w *Worker) tarcopy(task *Task) int {
cmd := createTarCopyCmd(task.Hostname, task.LocalFilename, task.RemoteFilename)
return w._run(task, cmd)
}
func createTarCopyCmd(host string, local string, remote string) *exec.Cmd {
if remote == "" || remote == local {
remote = "."
}
options := strings.Join(sshOpts(), " ")
sshCmd := fmt.Sprintf("ssh -l %s %s %s", currentUser, options, host)
tarCmd := fmt.Sprintf("tar c %s | %s tar x -C %s", local, sshCmd, remote)
params := []string{"-c", tarCmd}
log.Debugf("Created command bash %v", params)
return exec.Command("bash", params...)
}
func createSCPCmd(host string, local string, remote string, recursive bool) *exec.Cmd {
params := []string{}
if recursive {
params = []string{"-r"}
}
params = append(params, sshOpts()...)
remoteExpr := fmt.Sprintf("%s@%s:%s", currentUser, host, remote)
params = append(params, local, remoteExpr)
log.Debugf("Created command scp %v", params)
return exec.Command("scp", params...)
}
func createSSHCmd(host string, argv string) *exec.Cmd {
params := []string{
"-tt",
"-l",
currentUser,
}
params = append(params, sshOpts()...)
params = append(params, host)
params = append(params, getInterpreter()...)
if argv != "" {
params = append(params, "-c", argv)
}
log.Debugf("Created command ssh %v", params)
return exec.Command("ssh", params...)
}
func getInterpreter() []string {
switch currentRaise {
case RTSudo:
return strings.Split(sudoInterpreter, " ")
case RTSu:
return strings.Split(suInterpreter, " ")
default:
return strings.Split(noneInterpreter, " ")
}
}

View File

@@ -1,145 +0,0 @@
package remote
import (
"os"
"os/exec"
"syscall"
"time"
"github.com/kr/pty"
"github.com/npat-efault/poller"
"github.com/viert/xc/log"
)
func (w *Worker) tarcopy(task *Task) int {
var err error
var n int
cmd := createTarCopyCmd(task.Hostname, task.LocalFilename, task.RemoteFilename)
cmd.Env = append(os.Environ(), environment...)
ptmx, err := pty.Start(cmd)
if err != nil {
return ErrTerminalError
}
defer ptmx.Close()
fd, err := poller.NewFD(int(ptmx.Fd()))
if err != nil {
return ErrTerminalError
}
defer fd.Close()
buf := make([]byte, bufferSize)
taskForceStopped := false
for {
if w.forceStopped() {
taskForceStopped = true
break
}
fd.SetReadDeadline(time.Now().Add(pollDeadline))
n, err = fd.Read(buf)
if err != nil {
if err != poller.ErrTimeout {
// EOF, done
break
} else {
continue
}
}
if n == 0 {
continue
}
w.data <- &Message{buf[:n], MTData, task.Hostname, 0}
buf = make([]byte, bufferSize)
}
exitCode := 0
if taskForceStopped {
cmd.Process.Kill()
exitCode = ErrForceStop
log.Debugf("WRK[%d]: Task on %s was force stopped", w.id, task.Hostname)
}
err = cmd.Wait()
if !taskForceStopped {
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
ws := exitErr.Sys().(syscall.WaitStatus)
exitCode = ws.ExitStatus()
} else {
// MacOS hack
exitCode = ErrMacOsExit
}
}
log.Debugf("WRK[%d]: Task on %s exit code is %d", w.id, task.Hostname, exitCode)
}
return exitCode
}
func (w *Worker) copy(task *Task) int {
var err error
var n int
cmd := createSCPCmd(task.Hostname, task.LocalFilename, task.RemoteFilename, task.RecursiveCopy)
cmd.Env = append(os.Environ(), environment...)
ptmx, err := pty.Start(cmd)
if err != nil {
return ErrTerminalError
}
defer ptmx.Close()
fd, err := poller.NewFD(int(ptmx.Fd()))
if err != nil {
return ErrTerminalError
}
defer fd.Close()
buf := make([]byte, bufferSize)
taskForceStopped := false
for {
if w.forceStopped() {
taskForceStopped = true
break
}
fd.SetReadDeadline(time.Now().Add(pollDeadline))
n, err = fd.Read(buf)
if err != nil {
if err != poller.ErrTimeout {
// EOF, done
break
} else {
continue
}
}
if n == 0 {
continue
}
w.data <- &Message{buf[:n], MTDebug, task.Hostname, 0}
buf = make([]byte, bufferSize)
}
exitCode := 0
if taskForceStopped {
cmd.Process.Kill()
exitCode = ErrForceStop
log.Debugf("WRK[%d]: Task on %s was force stopped", w.id, task.Hostname)
}
err = cmd.Wait()
if !taskForceStopped {
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
ws := exitErr.Sys().(syscall.WaitStatus)
exitCode = ws.ExitStatus()
} else {
// MacOS hack
exitCode = ErrMacOsExit
}
}
log.Debugf("WRK[%d]: Task on %s exit code is %d", w.id, task.Hostname, exitCode)
}
return exitCode
}

View File

@@ -22,7 +22,7 @@ var (
currentRemoteTmpdir string
currentDebug bool
outputFile *os.File
ptyLock *sync.Mutex
poolLock *sync.Mutex
noneInterpreter string
suInterpreter string
@@ -32,7 +32,7 @@ var (
// Initialize initializes new execution pool
func Initialize(numThreads int, username string) {
pool = NewPool(numThreads)
ptyLock = new(sync.Mutex)
poolLock = new(sync.Mutex)
SetUser(username)
SetPassword("")
SetRaise(RTNone)

View File

@@ -1,160 +0,0 @@
package remote
import (
"bytes"
"os"
"os/exec"
"syscall"
"time"
"github.com/kr/pty"
"github.com/npat-efault/poller"
"github.com/viert/xc/log"
"github.com/viert/xc/passmgr"
)
func (w *Worker) runcmd(task *Task) int {
var (
err error
n int
password string
passwordSent bool
ptmx *os.File
fd *poller.FD
)
cmd := createSSHCmd(task.Hostname, task.Cmd)
cmd.Env = append(os.Environ(), environment...)
// threadsafe acquiring necessary file descriptors
ptyLock.Lock()
ptmx, err = pty.Start(cmd)
if err != nil {
log.Debugf("WRK[%d]: Error creating ptmx: %v", w.id, err)
ptyLock.Unlock()
return ErrTerminalError
}
defer ptmx.Close()
fd, err = poller.NewFD(int(ptmx.Fd()))
if err != nil {
log.Debugf("WRK[%d]: Error creating poller FD: %v", w.id, err)
ptyLock.Unlock()
return ErrTerminalError
}
defer fd.Close()
ptyLock.Unlock()
// threadsafe acquiring necessary file descriptors ends
buf := make([]byte, bufferSize)
taskForceStopped := false
shouldSkipEcho := false
msgCount := 0
if currentRaise != RTNone {
passwordSent = false
if currentUsePasswordManager {
password = passmgr.GetPass(task.Hostname)
} else {
password = currentPassword
}
} else {
passwordSent = true
}
execLoop:
for {
if w.forceStopped() {
taskForceStopped = true
break
}
fd.SetReadDeadline(time.Now().Add(pollDeadline))
n, err = fd.Read(buf)
if err != nil {
if err != poller.ErrTimeout {
// EOF, done
log.Debugf("WRK[%d]: error reading process output: %v", w.id, err)
break
} else {
continue
}
}
if n == 0 {
continue
}
w.data <- &Message{buf[:n], MTDebug, task.Hostname, -1}
msgCount++
chunks := bytes.SplitAfter(buf[:n], []byte{'\n'})
for _, chunk := range chunks {
// Trying to find Password prompt in first 5 chunks of data from server
if msgCount < 5 {
if !passwordSent && exPasswdPrompt.Match(chunk) {
_, err := ptmx.Write([]byte(password + "\n"))
if err != nil {
log.Debugf("WRK[%d]: Error sending password: %v", w.id, err)
}
passwordSent = true
shouldSkipEcho = true
continue
}
if shouldSkipEcho && exEcho.Match(chunk) {
shouldSkipEcho = false
continue
}
if passwordSent && exWrongPassword.Match(chunk) {
w.data <- &Message{[]byte("sudo: Authentication failure\n"), MTData, task.Hostname, -1}
taskForceStopped = true
break execLoop
}
}
if len(chunk) == 0 {
continue
}
if exConnectionClosed.Match(chunk) {
continue
}
if exLostConnection.Match(chunk) {
continue
}
// avoiding passing loop variable further as it's going to change its contents
data := make([]byte, len(chunk))
copy(data, chunk)
w.data <- &Message{data, MTData, task.Hostname, -1}
}
}
exitCode := 0
if taskForceStopped {
err = cmd.Process.Kill()
if err != nil {
log.Debugf("WRK[%d]: Error killing the process: %v", w.id, err)
}
exitCode = ErrForceStop
log.Debugf("WRK[%d]: Task on %s was force stopped", w.id, task.Hostname)
}
err = cmd.Wait()
if !taskForceStopped {
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
ws := exitErr.Sys().(syscall.WaitStatus)
exitCode = ws.ExitStatus()
} else {
// MacOS hack
exitCode = ErrMacOsExit
}
}
log.Debugf("WRK[%d]: Task on %s exit code is %d", w.id, task.Hostname, exitCode)
}
return exitCode
}

View File

@@ -17,11 +17,6 @@ import (
"golang.org/x/crypto/ssh/terminal"
)
var (
passwordSent = false
shouldSkipEcho = false
)
func forwardUserInput(in *poller.FD, out *os.File, stopped *bool) {
inBuf := make([]byte, bufferSize)
// processing stdin
@@ -45,35 +40,42 @@ func forwardUserInput(in *poller.FD, out *os.File, stopped *bool) {
}
}
func interceptProcessOutput(in []byte, ptmx *os.File, password string) (out []byte, err error) {
func interceptProcessOutput(in []byte, ptmx *os.File, password string, passwordSent *bool, shouldSkipEcho *bool) (out []byte, err error) {
out = []byte{}
err = nil
if currentDebug {
log.Debugf("DATASTREAM: %s", string(in))
}
if exConnectionClosed.Match(in) {
out = exConnectionClosed.ReplaceAll(in, []byte{})
log.Debug("Connection closed message catched")
return
}
if exLostConnection.Match(in) {
out = exLostConnection.ReplaceAll(in, []byte{})
log.Debug("Lost connection message catched")
return
}
if !passwordSent && exPasswdPrompt.Match(in) {
if !*passwordSent && exPasswdPrompt.Match(in) {
ptmx.Write([]byte(password + "\n"))
passwordSent = true
shouldSkipEcho = true
*passwordSent = true
*shouldSkipEcho = true
log.Debug("Password sent")
return
}
if shouldSkipEcho && exEcho.Match(in) {
if *shouldSkipEcho && exEcho.Match(in) {
log.Debug("Echo skipped")
shouldSkipEcho = false
*shouldSkipEcho = false
out = exEcho.ReplaceAll(in, []byte{})
return
}
if passwordSent && exWrongPassword.Match(in) {
if *passwordSent && exWrongPassword.Match(in) {
log.Debug("Authentication error while raising privileges")
err = fmt.Errorf("auth_error")
return
@@ -91,7 +93,9 @@ func runAtHost(host string, cmd *exec.Cmd, r *ExecResult) {
err error
password string
stopped = false
passwordSent = false
shouldSkipEcho = false
stopped = false
)
password = currentPassword
@@ -138,6 +142,7 @@ func runAtHost(host string, cmd *exec.Cmd, r *ExecResult) {
r.Codes[host] = ErrTerminalError
return
}
defer func() {
log.Debug("Setting stdin back to blocking mode")
si.Close()
@@ -156,7 +161,7 @@ func runAtHost(host string, cmd *exec.Cmd, r *ExecResult) {
n, err := ptmx.Read(buf)
if n > 0 {
// TODO random stuff with intercepting and omitting data
data, err := interceptProcessOutput(buf[:n], ptmx, password)
data, err := interceptProcessOutput(buf[:n], ptmx, password, &passwordSent, &shouldSkipEcho)
if err != nil {
// auth error, can't proceed
raise := "su"
@@ -188,6 +193,7 @@ func runAtHost(host string, cmd *exec.Cmd, r *ExecResult) {
}
if err != nil && err != poller.ErrTimeout {
log.Debugf("pty read error: %v", err)
stopped = true
break
}
@@ -231,7 +237,6 @@ execLoop:
if argv != "" {
remoteCmd = fmt.Sprintf("%s.%s.sh", remotePrefix, host)
cmd = createSCPCmd(host, local, remoteCmd, false)
log.Debugf("Created SCP command: %v", cmd)
signal.Notify(sigs, syscall.SIGINT)
err = cmd.Run()
signal.Reset()

View File

@@ -2,10 +2,6 @@ package remote
import (
"fmt"
"os/exec"
"strings"
"github.com/viert/xc/log"
)
var (
@@ -27,54 +23,3 @@ func sshOpts() (params []string) {
}
return
}
func createTarCopyCmd(host string, local string, remote string) *exec.Cmd {
if remote == "" || remote == local {
remote = "."
}
options := strings.Join(sshOpts(), " ")
sshCmd := fmt.Sprintf("ssh -l %s %s %s", currentUser, options, host)
tarCmd := fmt.Sprintf("tar c %s | %s tar x -C %s", local, sshCmd, remote)
params := []string{"-c", tarCmd}
log.Debugf("Created command bash %v", params)
return exec.Command("bash", params...)
}
func createSCPCmd(host string, local string, remote string, recursive bool) *exec.Cmd {
params := []string{}
if recursive {
params = []string{"-r"}
}
params = append(params, sshOpts()...)
remoteExpr := fmt.Sprintf("%s@%s:%s", currentUser, host, remote)
params = append(params, local, remoteExpr)
log.Debugf("Created command scp %v", params)
return exec.Command("scp", params...)
}
func createSSHCmd(host string, argv string) *exec.Cmd {
params := []string{
"-tt",
"-l",
currentUser,
}
params = append(params, sshOpts()...)
params = append(params, host)
params = append(params, getInterpreter()...)
if argv != "" {
params = append(params, "-c", argv)
}
log.Debugf("Created command ssh %v", params)
return exec.Command("ssh", params...)
}
func getInterpreter() []string {
switch currentRaise {
case RTSudo:
return strings.Split(sudoInterpreter, " ")
case RTSu:
return strings.Split(suInterpreter, " ")
default:
return strings.Split(noneInterpreter, " ")
}
}

View File

@@ -1,11 +1,18 @@
package remote
import (
"bytes"
"fmt"
"io"
"os"
"os/exec"
"regexp"
"sync"
"syscall"
"time"
"github.com/viert/xc/log"
"github.com/viert/xc/passmgr"
)
// RaiseType enum
@@ -64,6 +71,7 @@ const (
ErrCopyFailed
ErrTerminalError
ErrAuthenticationError
ErrCommandStartFailed
)
const (
@@ -186,3 +194,216 @@ func (w *Worker) forceStopped() bool {
return false
}
}
func (w *Worker) log(format string, args ...interface{}) {
format = fmt.Sprintf("WRK[%d]: %s", w.id, format)
log.Debugf(format, args...)
}
func (w *Worker) processStderr(rd io.ReadCloser, wr io.WriteCloser, finished *bool, task *Task) {
var (
n int
err error
buf []byte
)
buf = make([]byte, bufferSize)
w.log("starting stderr processor for host %s", task.Hostname)
for {
n, err = rd.Read(buf)
if err != nil {
*finished = true
break
}
if n > 0 {
w.data <- &Message{buf[:n], MTDebug, task.Hostname, -1}
chunks := bytes.SplitAfter(buf[:n], []byte{'\n'})
for _, chunk := range chunks {
if currentDebug {
w.log("STDERR CHUNK IN @ %s: %v %s", task.Hostname, chunk, string(chunk))
}
if exConnectionClosed.Match(chunk) {
chunk = exConnectionClosed.ReplaceAll(chunk, []byte{})
w.log("expr connection closed on stderr")
}
if exLostConnection.Match(chunk) {
chunk = exLostConnection.ReplaceAll(chunk, []byte{})
w.log("expr lost connection on stderr")
}
if len(chunk) == 0 {
continue
}
// avoiding passing loop variable further as it's going to change its contents
data := make([]byte, len(chunk))
copy(data, chunk)
if currentDebug {
w.log("STDERR CHUNK OUT @ %s: %v %s", task.Hostname, data, string(data))
}
w.data <- &Message{data, MTData, task.Hostname, -1}
}
}
}
w.log("exiting stderr processor for host %s", task.Hostname)
}
func (w *Worker) processStdout(rd io.ReadCloser, wr io.WriteCloser, finished *bool, task *Task) {
var (
n int
msgCount int
err error
buf []byte
password string
passwordSent bool
shouldSkipEcho bool
)
w.log("starting stdout processor for host %s", task.Hostname)
buf = make([]byte, bufferSize)
msgCount = 0
if currentRaise != RTNone {
passwordSent = false
if currentUsePasswordManager {
password = passmgr.GetPass(task.Hostname)
} else {
password = currentPassword
}
} else {
passwordSent = true
}
execLoop:
for {
n, err = rd.Read(buf)
if err != nil {
*finished = true
break
}
if n > 0 {
w.data <- &Message{buf[:n], MTDebug, task.Hostname, -1}
msgCount++
chunks := bytes.SplitAfter(buf[:n], []byte{'\n'})
for _, chunk := range chunks {
if currentDebug {
w.log("STDOUT CHUNK IN @ %s: %v %s", task.Hostname, chunk, string(chunk))
}
// Trying to find Password prompt in first 5 chunks of data from server
if msgCount < 10 {
if !passwordSent && exPasswdPrompt.Match(chunk) {
w.log("sending password for %s, msgCount=%d", task.Hostname, msgCount)
_, err := wr.Write([]byte(password + "\n"))
if err != nil {
w.log("error sending password: %v", err)
}
passwordSent = true
shouldSkipEcho = true
continue
}
}
if shouldSkipEcho && exEcho.Match(chunk) {
shouldSkipEcho = false
continue
}
if passwordSent && exWrongPassword.Match(chunk) {
w.data <- &Message{[]byte("sudo: Authentication failure\n"), MTData, task.Hostname, -1}
*finished = true
break execLoop
}
if len(chunk) == 0 {
continue
}
// avoiding passing loop variable further as it's going to change its contents
data := make([]byte, len(chunk))
copy(data, chunk)
if currentDebug {
w.log("STDOUT CHUNK OUT @ %s: %v %s", task.Hostname, data, string(data))
}
w.data <- &Message{data, MTData, task.Hostname, -1}
}
}
}
w.log("exiting stdout processor for host %s", task.Hostname)
}
func (w *Worker) _run(task *Task, cmd *exec.Cmd) int {
cmd.Env = append(os.Environ(), environment...)
sout, err := cmd.StdoutPipe()
if err != nil {
w.log("error creating stdout pipe: %v", err)
return ErrTerminalError
}
serr, err := cmd.StderrPipe()
if err != nil {
w.log("error creating stderr pipe: %v", err)
w.log("closing stdout pipe, err=%v", sout.Close())
return ErrTerminalError
}
sin, err := cmd.StdinPipe()
if err != nil {
w.log("error creating stdin pipe: %v", err)
w.log("closing stderr pipe, err=%v", serr.Close())
w.log("closing stdout pipe, err=%v", sout.Close())
return ErrTerminalError
}
err = cmd.Start()
if err != nil {
w.log("error starting cmd: %v", err)
w.log("closing stderr pipe, err=%v", serr.Close())
w.log("closing stdout pipe, err=%v", sout.Close())
w.log("closing stdin pipe, err=%v", sin.Close())
return ErrCommandStartFailed
}
stdoutFinished := false
stderrFinished := false
taskForceStopped := false
go w.processStdout(sout, sin, &stdoutFinished, task)
go w.processStderr(serr, sin, &stderrFinished, task)
for !(stdoutFinished && stderrFinished) {
if w.forceStopped() {
taskForceStopped = true
err = cmd.Process.Kill()
if err != nil {
w.log("error killing process: %v", err)
}
break
}
time.Sleep(pollDeadline)
}
exitCode := 0
w.log("out of waitloop running cmd.Wait to cleanup")
err = cmd.Wait()
if taskForceStopped {
return ErrForceStop
}
if err != nil {
if exitErr, ok := err.(*exec.ExitError); ok {
ws := exitErr.Sys().(syscall.WaitStatus)
exitCode = ws.ExitStatus()
} else {
// MacOS hack
exitCode = ErrMacOsExit
}
}
w.log("Task on %s exit coded is %d", task.Hostname, exitCode)
return exitCode
}