diff --git a/command/alloc_exec.go b/command/alloc_exec.go index 0f93f408f..6dd845edb 100644 --- a/command/alloc_exec.go +++ b/command/alloc_exec.go @@ -265,6 +265,11 @@ func (l *AllocExecCommand) execImpl(client *api.Client, alloc *api.Allocation, t stdin = escapingio.NewReader(stdin, escapeChar[0], func(c byte) bool { switch c { case '.': + // need to restore tty state so error reporting here + // gets emitted at beginning of line + outCleanup() + inCleanup() + stderr.Write([]byte("\nConnection closed\n")) cancelFn() return true @@ -272,7 +277,6 @@ func (l *AllocExecCommand) execImpl(client *api.Client, alloc *api.Allocation, t return false } }) - } } diff --git a/helper/escapingio/reader.go b/helper/escapingio/reader.go index 704bdb483..099654c86 100644 --- a/helper/escapingio/reader.go +++ b/helper/escapingio/reader.go @@ -1,6 +1,7 @@ package escapingio import ( + "bufio" "io" ) @@ -22,12 +23,16 @@ type Handler func(c byte) bool // // Appearances of `~` when not preceded by a new line are propagated unmodified. func NewReader(r io.Reader, c byte, h Handler) io.Reader { - return &reader{ + pr, pw := io.Pipe() + reader := &reader{ impl: r, escapeChar: c, - state: sLookEscapeChar, handler: h, + pr: pr, + pw: pw, } + go reader.pipe() + return reader } // lookState represents the state of reader for what character of `\n~.` sequence @@ -52,112 +57,115 @@ type reader struct { escapeChar uint8 handler Handler - state lookState - - // unread is a buffered character for next read if not-nil - unread *byte + // buffers + pw *io.PipeWriter + pr *io.PipeReader } func (r *reader) Read(buf []byte) (int, error) { + return r.pr.Read(buf) +} + +func (r *reader) pipe() { + rb := make([]byte, 4096) + bw := bufio.NewWriter(r.pw) + + state := sLookEscapeChar + + for { + n, err := r.impl.Read(rb) + + if n > 0 { + state = r.processBuf(bw, rb, n, state) + bw.Flush() + if state == sLookChar { + // terminated with ~ - let's read one more character + n, err = r.impl.Read(rb[:1]) + if n == 1 { + state = sLookNewLine + if rb[0] == r.escapeChar { + // only emit escape character once + bw.WriteByte(rb[0]) + bw.Flush() + } else if r.handler(rb[0]) { + // skip if handled + } else { + bw.WriteByte(r.escapeChar) + bw.WriteByte(rb[0]) + bw.Flush() + } + } + } + } + + if err != nil { + // write ~ if it's the last thing + if state == sLookChar { + bw.WriteByte(r.escapeChar) + } + bw.Flush() + r.pw.CloseWithError(err) + break + } + } +} + +// processBuf process buffer and emits all output to writer +// if the last part of buffer is a new line followed by sequnce, it writes +// all output until the new line and returns sLookChar +func (r *reader) processBuf(bw io.Writer, buf []byte, n int, s lookState) lookState { + i := 0 + + wi := 0 + START: - var n int - var err error - - if r.unread != nil { - // try to return the unread character immediately - // without trying to block for another read - buf[0] = *r.unread - n = 1 - r.unread = nil - } else { - n, err = r.impl.Read(buf) - } - - // when we get to the end, check if we have any unprocessed \n~ - if n == 0 && err != nil { - if r.state == sLookChar && err != nil { - buf[0] = r.escapeChar - n = 1 + if s == sLookEscapeChar && buf[i] == r.escapeChar { + if i+1 >= n { + // buf terminates with ~ - write all before + bw.Write(buf[wi:i]) + return sLookChar } - return n, err - } - // inspect the state at beginning of read - if r.state == sLookChar { - r.state = sLookNewLine - - // escape character hasn't been emitted yet - if buf[0] == r.escapeChar { - // earlier ~ was swallowed already, so leave this as is - } else if handled := r.handler(buf[0]); handled { - // need to drop a single letter - copy(buf, buf[1:n]) - n-- + nc := buf[i+1] + if nc == r.escapeChar { + // skip one escape char + bw.Write(buf[wi:i]) + i++ + wi = i + } else if r.handler(nc) { + // skip both characters + bw.Write(buf[wi:i]) + i = i + 2 + wi = i } else { - // we need to re-introduce ~ with rest of body - // but be mindful if reintroducing ~ causes buffer to overflow - if n == len(buf) { - // in which case, save it for next read - c := buf[n-1] - r.unread = &c - copy(buf[1:], buf[:n]) - buf[0] = r.escapeChar - } else { - copy(buf[1:], buf[:n]) - buf[0] = r.escapeChar - n++ - } + i = i + 2 + // need to write everything keep going } } - n = r.processBuffer(buf, n) - if n == 0 && err == nil { - goto START - } - - return n, err -} - -// handles escaped character inside body of read buf. -func (r *reader) processBuffer(buf []byte, read int) int { - b := 0 - - for b < read { - - c := buf[b] - if r.state == sLookEscapeChar && r.escapeChar == c { - r.state = sLookEscapeChar - - // are we at the end of read; wait for next read - if b == read-1 { - read-- - r.state = sLookChar - return read - } - - // otherwise peek at next - nc := buf[b+1] - if nc == r.escapeChar { - // repeated ~, only emit one - skip one character - copy(buf[b:], buf[b+1:read]) - read-- - b++ - continue - } else if handled := r.handler(nc); handled { - // need to drop both ~ and letter - copy(buf[b:], buf[b+2:read]) - read -= 2 - continue - } else { - // need to pass output unmodified with ~ and letter - } - } else if c == '\n' || c == '\r' { - r.state = sLookEscapeChar - } else { - r.state = sLookNewLine + // search until we get \n~, or buf terminates + for { + if i >= n { + // got to end without new line, write and return + bw.Write(buf[wi:n]) + return sLookNewLine } - b++ - } - return read + if buf[i] == '\n' || buf[i] == '\r' { + // buf terminated at new line + if i+1 >= n { + bw.Write(buf[wi:n]) + return sLookEscapeChar + } + + // peek to see escape character go back to START if so + if buf[i+1] == r.escapeChar { + s = sLookEscapeChar + i++ + goto START + } + } + + i++ + } } diff --git a/helper/escapingio/reader_test.go b/helper/escapingio/reader_test.go index 807ee31f9..d35cdd454 100644 --- a/helper/escapingio/reader_test.go +++ b/helper/escapingio/reader_test.go @@ -2,17 +2,21 @@ package escapingio import ( "bytes" + "errors" "fmt" "io" "math/rand" "reflect" "regexp" "strings" + "sync" "testing" "testing/iotest" "testing/quick" + "time" "unicode" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -42,12 +46,11 @@ func TestEscapingReader_Static(t *testing.T) { for _, c := range cases { t.Run("sanity check naive implementation", func(t *testing.T) { - foundEscaped := "" - h := testHandler(&foundEscaped) + h := &testHandler{} - processed := naiveEscapeCharacters(c.input, '~', h) + processed := naiveEscapeCharacters(c.input, '~', h.handler) require.Equal(t, c.expected, processed) - require.Equal(t, c.escaped, foundEscaped) + require.Equal(t, c.escaped, h.escaped()) }) t.Run("chunks at a time: "+c.input, func(t *testing.T) { @@ -55,16 +58,15 @@ func TestEscapingReader_Static(t *testing.T) { input := strings.NewReader(c.input) - foundEscaped := "" - h := testHandler(&foundEscaped) + h := &testHandler{} - filter := NewReader(input, '~', h) + filter := NewReader(input, '~', h.handler) _, err := io.Copy(&found, filter) require.NoError(t, err) require.Equal(t, c.expected, found.String()) - require.Equal(t, c.escaped, foundEscaped) + require.Equal(t, c.escaped, h.escaped()) }) t.Run("1 byte at a time: "+c.input, func(t *testing.T) { @@ -72,19 +74,173 @@ func TestEscapingReader_Static(t *testing.T) { input := iotest.OneByteReader(strings.NewReader(c.input)) - foundEscaped := "" - h := testHandler(&foundEscaped) + h := &testHandler{} - filter := NewReader(input, '~', h) + filter := NewReader(input, '~', h.handler) _, err := io.Copy(&found, filter) require.NoError(t, err) require.Equal(t, c.expected, found.String()) - require.Equal(t, c.escaped, foundEscaped) + require.Equal(t, c.escaped, h.escaped()) + }) + + t.Run("without reading: "+c.input, func(t *testing.T) { + input := strings.NewReader(c.input) + + h := &testHandler{} + + filter := NewReader(input, '~', h.handler) + + // don't read to mimic a stalled reader + _ = filter + + assertEventually(t, func() (bool, error) { + escaped := h.escaped() + if c.escaped == escaped { + return true, nil + } + + return false, fmt.Errorf("expected %v but found %v", c.escaped, escaped) + }) }) } } +// TestEscapingReader_EmitsPartialReads should emit partial results +// if next character is not read +func TestEscapingReader_FlushesPartialReads(t *testing.T) { + pr, pw := io.Pipe() + + h := &testHandler{} + filter := NewReader(pr, '~', h.handler) + + var lock sync.Mutex + var read bytes.Buffer + + // helper for asserting reads + requireRead := func(expected *bytes.Buffer) { + readSoFar := "" + + start := time.Now() + for time.Since(start) < 2*time.Second { + lock.Lock() + readSoFar = read.String() + lock.Unlock() + + if readSoFar == expected.String() { + break + } + + time.Sleep(50 * time.Millisecond) + } + + require.Equal(t, expected.String(), readSoFar, "timed out without output") + } + + var rerr error + var wg sync.WaitGroup + wg.Add(1) + + // goroutine for reading partial data + go func() { + defer wg.Done() + + buf := make([]byte, 1024) + for { + n, err := filter.Read(buf) + lock.Lock() + read.Write(buf[:n]) + lock.Unlock() + + if err != nil { + rerr = err + break + } + } + }() + + expected := &bytes.Buffer{} + + // test basic start and no new lines + pw.Write([]byte("first data")) + expected.WriteString("first data") + requireRead(expected) + require.Equal(t, "", h.escaped()) + + // test ~. appearing in middle of line but stop at new line + pw.Write([]byte("~.inmiddleappears\n")) + expected.WriteString("~.inmiddleappears\n") + requireRead(expected) + require.Equal(t, "", h.escaped()) + + // from here on we test \n~ at boundary + + // ~~ after new line; and stop at \n~ + pw.Write([]byte("~~second line\n~")) + expected.WriteString("~second line\n") + requireRead(expected) + require.Equal(t, "", h.escaped()) + + // . to be skipped; stop at \n~ again + pw.Write([]byte(".third line\n~")) + expected.WriteString("third line\n") + requireRead(expected) + require.Equal(t, ".", h.escaped()) + + // q to be emitted; stop at \n + pw.Write([]byte("qfourth line\n")) + expected.WriteString("~qfourth line\n") + requireRead(expected) + require.Equal(t, ".q", h.escaped()) + + // ~. to be skipped; stop at \n~ + pw.Write([]byte("~.fifth line\n~")) + expected.WriteString("fifth line\n") + requireRead(expected) + require.Equal(t, ".q.", h.escaped()) + + // ~ alone after \n~ - should be emitted + pw.Write([]byte("~")) + expected.WriteString("~") + requireRead(expected) + require.Equal(t, ".q.", h.escaped()) + + // rest of line ending with \n~ + pw.Write([]byte("rest of line\n~")) + expected.WriteString("rest of line\n") + requireRead(expected) + require.Equal(t, ".q.", h.escaped()) + + // m alone after \n~ - should be emitted with ~ + pw.Write([]byte("m")) + expected.WriteString("~m") + requireRead(expected) + require.Equal(t, ".q.m", h.escaped()) + + // rest of line and end with \n + pw.Write([]byte("onemore line\n")) + expected.WriteString("onemore line\n") + requireRead(expected) + require.Equal(t, ".q.m", h.escaped()) + + // ~q to be emitted stop at \n~; last charcater + pw.Write([]byte("~qlast line\n~")) + expected.WriteString("~qlast line\n") + requireRead(expected) + require.Equal(t, ".q.mq", h.escaped()) + + // last ~ gets emitted and we preserve error + eerr := errors.New("my custom error") + pw.CloseWithError(eerr) + expected.WriteString("~") + requireRead(expected) + require.Equal(t, ".q.mq", h.escaped()) + + wg.Wait() + require.Error(t, rerr) + require.Equal(t, eerr, rerr) +} + func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) { f := func(v readingInput) bool { return checkEquivalenceToNaive(t, string(v)) @@ -95,40 +251,52 @@ func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) { })) } -// testHandler returns a handler that stores all basic ascii letters in result -// reference. We avoid complicated unicode characters that may cross -// byte boundary -func testHandler(result *string) Handler { - return func(c byte) bool { - rc := rune(c) - simple := unicode.IsLetter(rc) || - unicode.IsDigit(rc) || - unicode.IsPunct(rc) || - unicode.IsSymbol(rc) +// testHandler is a conveneient struct for finding "escaped" ascii letters +// in escaping reader. +// We avoid complicated unicode characters that may cross byte boundary +type testHandler struct { + l sync.Mutex + result string +} - if simple { - *result += string([]byte{c}) - } - return c == '.' +// handler is method to be passed to escaping io reader +func (t *testHandler) handler(c byte) bool { + rc := rune(c) + simple := unicode.IsLetter(rc) || + unicode.IsDigit(rc) || + unicode.IsPunct(rc) || + unicode.IsSymbol(rc) + + if simple { + t.l.Lock() + t.result += string([]byte{c}) + t.l.Unlock() } + return c == '.' +} + +// escaped returns all seen escaped characters so far +func (t *testHandler) escaped() string { + t.l.Lock() + defer t.l.Unlock() + + return t.result } // checkEquivalence returns true if parsing input with naive implementation // is equivalent to our reader func checkEquivalenceToNaive(t *testing.T, input string) bool { - nfe := "" - nh := testHandler(&nfe) - expected := naiveEscapeCharacters(input, '~', nh) + nh := &testHandler{} + expected := naiveEscapeCharacters(input, '~', nh.handler) - foundEscaped := "" - h := testHandler(&foundEscaped) + foundH := &testHandler{} var inputReader io.Reader = bytes.NewBufferString(input) inputReader = &arbtiraryReader{ buf: inputReader.(*bytes.Buffer), maxReadOnce: 10, } - filter := NewReader(inputReader, '~', h) + filter := NewReader(inputReader, '~', foundH.handler) var found bytes.Buffer _, err := io.Copy(&found, filter) if err != nil { @@ -136,11 +304,11 @@ func checkEquivalenceToNaive(t *testing.T, input string) bool { return false } - if nfe == foundEscaped && expected == found.String() { + if nh.escaped() == foundH.escaped() && expected == found.String() { return true } - t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped) + t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped()) t.Logf("read differed=%v expected=%s found=%v", expected != found.String(), expected, found.String()) return false @@ -159,15 +327,13 @@ func TestEscapingReader_Generated_EquivalentToReadOnce(t *testing.T) { // checkEquivalenceToReadOnce returns true if parsing input in a single // read matches multiple reads func checkEquivalenceToReadOnce(t *testing.T, input string) bool { - nfe := "" + nh := &testHandler{} var expected bytes.Buffer // getting expected value from read all at once { - h := testHandler(&nfe) - buf := make([]byte, len(input)+5) - inputReader := NewReader(bytes.NewBufferString(input), '~', h) + inputReader := NewReader(bytes.NewBufferString(input), '~', nh.handler) _, err := io.CopyBuffer(&expected, inputReader, buf) if err != nil { t.Logf("unexpected error while reading: %v", err) @@ -175,18 +341,16 @@ func checkEquivalenceToReadOnce(t *testing.T, input string) bool { } } - foundEscaped := "" + foundH := &testHandler{} var found bytes.Buffer // getting found by using arbitrary reader { - h := testHandler(&foundEscaped) - inputReader := &arbtiraryReader{ buf: bytes.NewBufferString(input), maxReadOnce: 10, } - filter := NewReader(inputReader, '~', h) + filter := NewReader(inputReader, '~', foundH.handler) _, err := io.Copy(&found, filter) if err != nil { t.Logf("unexpected error while reading: %v", err) @@ -194,11 +358,11 @@ func checkEquivalenceToReadOnce(t *testing.T, input string) bool { } } - if nfe == foundEscaped && expected.String() == found.String() { + if nh.escaped() == foundH.escaped() && expected.String() == found.String() { return true } - t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped) + t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped()) t.Logf("read differed=%v expected=%s found=%v", expected.String() != found.String(), expected.String(), found.String()) return false @@ -309,3 +473,21 @@ func (r *arbtiraryReader) Read(buf []byte) (int, error) { return r.buf.Read(buf[:l]) } + +func assertEventually(t *testing.T, testFn func() (bool, error)) { + start := time.Now() + var err error + var b bool + for { + if time.Since(start) > 2*time.Second { + assert.Fail(t, "timed out", "error: %v", err) + } + + b, err = testFn() + if b { + return + } + + time.Sleep(50 * time.Millisecond) + } +}