From 4e7736023a73709c66b99844c05fab8127338c19 Mon Sep 17 00:00:00 2001 From: Mahmood Ali Date: Thu, 16 May 2019 14:12:40 -0400 Subject: [PATCH] Add a escaping reader that mimics ssh behavior Adds an escaping reading that mimics ssh handling of input escape sequences. The reader parses chunks to look for \n~ --- helper/escapingio/reader.go | 163 ++++++++++++++++ helper/escapingio/reader_test.go | 319 +++++++++++++++++++++++++++++++ 2 files changed, 482 insertions(+) create mode 100644 helper/escapingio/reader.go create mode 100644 helper/escapingio/reader_test.go diff --git a/helper/escapingio/reader.go b/helper/escapingio/reader.go new file mode 100644 index 000000000..999cceb87 --- /dev/null +++ b/helper/escapingio/reader.go @@ -0,0 +1,163 @@ +package escapingio + +import ( + "io" +) + +// Handler is a callback for handling an escaped char. Reader would skip +// the escape char and passed char if returns true; otherwise, it preserves them +// in output +type Handler func(c byte) bool + +// NewReader returns a reader that escapes the c character (following new lines), +// in the same manner OpenSSH handling, which defaults to `~`. +// +// For illustrative purposes, we use `~` in documentation as a shorthand for escaping character. +// +// If following a new line, reader sees: +// * `~~`, only one is emitted +// * `~.` (or any character), the handler is invoked with the character. +// If handler returns true, `~.` will be skipped; otherwise, it's propagated. +// * `~` and it's the last character in stream, it's propagated +// +// Appearances of `~` when not followed by a new line is propagated unmodified. +func NewReader(r io.Reader, c byte, h Handler) io.Reader { + return &reader{ + impl: r, + escapeChar: c, + state: sLookEscapeChar, + handler: h, + } +} + +// lookState represents the state of reader for what character of `\n~.` sequence +// reader is looking for +type lookState int + +const ( + // sLookNewLine indicates that reader is looking for new line + sLookNewLine lookState = iota + + // sLookEscapeChar indicates that reader is looking for ~ + sLookEscapeChar + + // sLookChar indicates that reader just read `~` is waiting for next character + // before acting + sLookChar +) + +// to ease comments, i'll assume escape character to be `~` +type reader struct { + impl io.Reader + escapeChar uint8 + handler Handler + + state lookState + + // unread is a buffered character for next read if not-nil + unread *byte +} + +func (r *reader) Read(buf []byte) (int, error) { +START: + var n int + var err error + + if r.unread != nil { + // try to return the unread character immediately + // without trying to block for another read + buf[0] = *r.unread + n = 1 + r.unread = nil + } else { + n, err = r.impl.Read(buf) + } + + // when we get to the end, check if we have any unprocessed \n~ + if n == 0 && err != nil { + if r.state == sLookChar && err != nil { + buf[0] = r.escapeChar + n = 1 + } + return n, err + } + + // inspect the state at beginning of read + if r.state == sLookChar { + r.state = sLookNewLine + + // escape character hasn't been emitted yet + if buf[0] == r.escapeChar { + // earlier ~ was sallowed already, so leave this as is + } else if handled := r.handler(buf[0]); handled { + // need to drop a single letter + copy(buf, buf[1:n]) + n-- + } else { + // we need to re-introduce ~ with rest of body + // but be mindful if reintroducing ~ causes buffer to overflow + if n == len(buf) { + // in which case, save it for next read + c := buf[n-1] + r.unread = &c + copy(buf[1:], buf[:n]) + buf[0] = r.escapeChar + } else { + copy(buf[1:], buf[:n]) + buf[0] = r.escapeChar + n++ + } + } + } + + n = r.processBuffer(buf, n) + if n == 0 && err == nil { + goto START + } + + return n, err +} + +// handles escaped character inside body of read buf. +func (r *reader) processBuffer(buf []byte, read int) int { + b := 0 + + for b < read { + + c := buf[b] + if r.state == sLookEscapeChar && r.escapeChar == c { + r.state = sLookEscapeChar + + // are we at the end of read; wait for next read + if b == read-1 { + read-- + r.state = sLookChar + return read + } + + // otherwise peek at next + nc := buf[b+1] + if nc == r.escapeChar { + // repeated ~, only emit one - skip one character + copy(buf[b:], buf[b+1:read]) + read-- + b++ + continue + } else if handled := r.handler(nc); handled { + // need to drop both ~ and letter + copy(buf[b:], buf[b+2:read]) + read -= 2 + continue + } else { + // need to pass output unmodified with ~ and letter + } + } else if c == '\n' || c == '\r' { + r.state = sLookEscapeChar + } else { + r.state = sLookNewLine + } + b++ + } + + return read +} diff --git a/helper/escapingio/reader_test.go b/helper/escapingio/reader_test.go new file mode 100644 index 000000000..762195ec1 --- /dev/null +++ b/helper/escapingio/reader_test.go @@ -0,0 +1,319 @@ +package escapingio + +import ( + "bytes" + "fmt" + "io" + "math/rand" + "reflect" + "regexp" + "strings" + "testing" + "testing/iotest" + "testing/quick" + "unicode" + + "github.com/stretchr/testify/require" +) + +func TestEscapingReader_Static(t *testing.T) { + cases := []struct { + input string + expected string + escaped string + }{ + {"hello", "hello", ""}, + {"he\nllo", "he\nllo", ""}, + {"he~.lo", "he~.lo", ""}, + {"he\n~.rest", "he\nrest", "."}, + {"he\n~.r\n~.est", "he\nr\nest", ".."}, + {"he\n~~r\n~~est", "he\n~r\n~est", ""}, + {"he\n~~r\n~.est", "he\n~r\nest", "."}, + {"he\nr~~est", "he\nr~~est", ""}, + {"he\nr\n~qest", "he\nr\n~qest", "q"}, + {"he\nr\r~qe\r~.st", "he\nr\r~qe\rst", "q."}, + {"~q", "~q", "q"}, + {"~.", "", "."}, + {"m~.", "m~.", ""}, + {"\n~.", "\n", "."}, + {"~", "~", ""}, + {"\r~.", "\r", "."}, + } + + for _, c := range cases { + t.Run("sanity check naive implementation", func(t *testing.T) { + foundEscaped := "" + h := testHandler(&foundEscaped) + + processed := naiveEscapeCharacters(c.input, '~', h) + require.Equal(t, c.expected, processed) + require.Equal(t, c.escaped, foundEscaped) + }) + + t.Run("chunks at a time: "+c.input, func(t *testing.T) { + var found bytes.Buffer + + input := strings.NewReader(c.input) + + foundEscaped := "" + h := testHandler(&foundEscaped) + + filter := NewReader(input, '~', h) + + _, err := io.Copy(&found, filter) + require.NoError(t, err) + + require.Equal(t, c.expected, found.String()) + require.Equal(t, c.escaped, foundEscaped) + }) + + t.Run("1 byte at a time: "+c.input, func(t *testing.T) { + var found bytes.Buffer + + input := iotest.OneByteReader(strings.NewReader(c.input)) + + foundEscaped := "" + h := testHandler(&foundEscaped) + + filter := NewReader(input, '~', h) + _, err := io.Copy(&found, filter) + require.NoError(t, err) + + require.Equal(t, c.expected, found.String()) + require.Equal(t, c.escaped, foundEscaped) + }) + } +} + +func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) { + called := 0 + f := func(v readingInput) bool { + called++ + return checkEquivalenceToNaive(t, string(v)) + } + + require.NoError(t, quick.Check(f, &quick.Config{ + MaxCountScale: 200, + })) + + fmt.Println("CALLED ", called) +} + +// testHandler returns a handler that stores all basic ascii letters in result +// reference. We avoid complicated unicode characters that may cross +// byte boundary +func testHandler(result *string) Handler { + return func(c byte) bool { + rc := rune(c) + simple := unicode.IsLetter(rc) || + unicode.IsDigit(rc) || + unicode.IsPunct(rc) || + unicode.IsSymbol(rc) + + if simple { + *result += string([]byte{c}) + } + return c == '.' + } +} + +// checkEquivalence returns true if parsing input with naive implementation +// is equivalent to our reader +func checkEquivalenceToNaive(t *testing.T, input string) bool { + nfe := "" + nh := testHandler(&nfe) + expected := naiveEscapeCharacters(input, '~', nh) + + foundEscaped := "" + h := testHandler(&foundEscaped) + + var inputReader io.Reader = bytes.NewBufferString(input) + inputReader = &arbtiraryReader{ + buf: inputReader.(*bytes.Buffer), + maxReadOnce: 10, + } + filter := NewReader(inputReader, '~', h) + var found bytes.Buffer + _, err := io.Copy(&found, filter) + if err != nil { + t.Logf("unexpected error while reading: %v", err) + return false + } + + if nfe == foundEscaped && expected == found.String() { + return true + } + + t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped) + t.Logf("read differed=%v expected=%s found=%v", expected != found.String(), expected, found.String()) + return false + +} + +func TestEscapingReader_Generated_EquivalentToReadOnce(t *testing.T) { + called := 0 + f := func(v readingInput) bool { + called++ + return checkEquivalenceToNaive(t, string(v)) + } + + require.NoError(t, quick.Check(f, &quick.Config{ + MaxCountScale: 200, + })) + + fmt.Println("CALLED ", called) +} + +// checkEquivalenceToReadOnce returns true if parsing input in a single +// read matches multiple reads +func checkEquivalenceToReadOnce(t *testing.T, input string) bool { + nfe := "" + var expected bytes.Buffer + + // getting expected value from read all at once + { + h := testHandler(&nfe) + + buf := make([]byte, len(input)+5) + inputReader := NewReader(bytes.NewBufferString(input), '~', h) + _, err := io.CopyBuffer(&expected, inputReader, buf) + if err != nil { + t.Logf("unexpected error while reading: %v", err) + return false + } + } + + foundEscaped := "" + var found bytes.Buffer + + // getting found by using arbitrary reader + { + h := testHandler(&foundEscaped) + + inputReader := &arbtiraryReader{ + buf: bytes.NewBufferString(input), + maxReadOnce: 10, + } + filter := NewReader(inputReader, '~', h) + _, err := io.Copy(&found, filter) + if err != nil { + t.Logf("unexpected error while reading: %v", err) + return false + } + } + + if nfe == foundEscaped && expected.String() == found.String() { + return true + } + + t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped) + t.Logf("read differed=%v expected=%s found=%v", expected.String() != found.String(), expected.String(), found.String()) + return false + +} + +// readingInput is a string with some quick generation capability to +// inject some \n, \n~., \n~q in text +type readingInput string + +func (i readingInput) Generate(rand *rand.Rand, size int) reflect.Value { + v, ok := quick.Value(reflect.TypeOf(""), rand) + if !ok { + panic("couldn't generate a string") + } + + // inject some terminals + var b bytes.Buffer + injectProbabilistically := func() { + p := rand.Float32() + if p < 0.05 { + b.WriteString("\n~.") + } else if p < 0.10 { + b.WriteString("\n~q") + } else if p < 0.15 { + b.WriteString("\n") + } else if p < 0.2 { + b.WriteString("~") + } else if p < 0.25 { + b.WriteString("~~") + } + } + + for _, c := range v.String() { + injectProbabilistically() + b.WriteRune(c) + } + + injectProbabilistically() + + return reflect.ValueOf(readingInput(b.String())) +} + +// naiveEscapeCharacters is a simplified implementation that operates +// on entire unchunked string. Uses regexp implementation. +// +// It differs from the other implementation in handling unicode characters +// proceeding `\n~` +func naiveEscapeCharacters(input string, escapeChar byte, h Handler) string { + reg := regexp.MustCompile(fmt.Sprintf("(\n|\r)%c.", escapeChar)) + + // check first appearances + if len(input) > 1 && input[0] == escapeChar { + if input[1] == escapeChar { + input = input[1:] + } else if h(input[1]) { + input = input[2:] + } else { + // we are good + } + + } + + return reg.ReplaceAllStringFunc(input, func(match string) string { + if len(match) != 3 { + panic(fmt.Errorf("match isn't 3 characters: %s", match)) + } + + c := match[2] + + // ignore some unicode partial codes + ltr := ('a' <= c && c <= 'z') || + ('A' <= c && c <= 'Z') || + ('0' <= c && c <= '9') || + (c == '~' || c == '.' || c == escapeChar) + + if c == escapeChar { + return match[:2] + } else if ltr && h(c) { + return match[:1] + } else { + return match + } + }) +} + +// arbitraryReader is a reader that reads arbitrary length at a time +// to simulate input being read in chunks. +type arbtiraryReader struct { + buf *bytes.Buffer + maxReadOnce int +} + +func (r *arbtiraryReader) Read(buf []byte) (int, error) { + l := r.buf.Len() + if l == 0 || l == 1 { + return r.buf.Read(buf) + } + + if l > r.maxReadOnce { + l = r.maxReadOnce + } + if l != 1 { + l = rand.Intn(l-1) + 1 + } + if l > len(buf) { + l = len(buf) + } + + return r.buf.Read(buf[:l]) +}