diff --git a/helper/escapingio/reader_test.go b/helper/escapingio/reader_test.go index 807ee31f9..28ca2fb59 100644 --- a/helper/escapingio/reader_test.go +++ b/helper/escapingio/reader_test.go @@ -8,6 +8,7 @@ import ( "reflect" "regexp" "strings" + "sync" "testing" "testing/iotest" "testing/quick" @@ -42,12 +43,11 @@ func TestEscapingReader_Static(t *testing.T) { for _, c := range cases { t.Run("sanity check naive implementation", func(t *testing.T) { - foundEscaped := "" - h := testHandler(&foundEscaped) + h := &testHandler{} - processed := naiveEscapeCharacters(c.input, '~', h) + processed := naiveEscapeCharacters(c.input, '~', h.handler) require.Equal(t, c.expected, processed) - require.Equal(t, c.escaped, foundEscaped) + require.Equal(t, c.escaped, h.escaped()) }) t.Run("chunks at a time: "+c.input, func(t *testing.T) { @@ -55,16 +55,15 @@ func TestEscapingReader_Static(t *testing.T) { input := strings.NewReader(c.input) - foundEscaped := "" - h := testHandler(&foundEscaped) + h := &testHandler{} - filter := NewReader(input, '~', h) + filter := NewReader(input, '~', h.handler) _, err := io.Copy(&found, filter) require.NoError(t, err) require.Equal(t, c.expected, found.String()) - require.Equal(t, c.escaped, foundEscaped) + require.Equal(t, c.escaped, h.escaped()) }) t.Run("1 byte at a time: "+c.input, func(t *testing.T) { @@ -72,15 +71,14 @@ func TestEscapingReader_Static(t *testing.T) { input := iotest.OneByteReader(strings.NewReader(c.input)) - foundEscaped := "" - h := testHandler(&foundEscaped) + h := &testHandler{} - filter := NewReader(input, '~', h) + filter := NewReader(input, '~', h.handler) _, err := io.Copy(&found, filter) require.NoError(t, err) require.Equal(t, c.expected, found.String()) - require.Equal(t, c.escaped, foundEscaped) + require.Equal(t, c.escaped, h.escaped()) }) } } @@ -95,40 +93,52 @@ func TestEscapingReader_Generated_EquivalentToNaive(t *testing.T) { })) } -// testHandler returns a handler that stores all basic ascii letters in result -// reference. We avoid complicated unicode characters that may cross -// byte boundary -func testHandler(result *string) Handler { - return func(c byte) bool { - rc := rune(c) - simple := unicode.IsLetter(rc) || - unicode.IsDigit(rc) || - unicode.IsPunct(rc) || - unicode.IsSymbol(rc) +// testHandler is a conveneient struct for finding "escaped" ascii letters +// in escaping reader. +// We avoid complicated unicode characters that may cross byte boundary +type testHandler struct { + l sync.Mutex + result string +} - if simple { - *result += string([]byte{c}) - } - return c == '.' +// handler is method to be passed to escaping io reader +func (t *testHandler) handler(c byte) bool { + rc := rune(c) + simple := unicode.IsLetter(rc) || + unicode.IsDigit(rc) || + unicode.IsPunct(rc) || + unicode.IsSymbol(rc) + + if simple { + t.l.Lock() + t.result += string([]byte{c}) + t.l.Unlock() } + return c == '.' +} + +// escaped returns all seen escaped characters so far +func (t *testHandler) escaped() string { + t.l.Lock() + defer t.l.Unlock() + + return t.result } // checkEquivalence returns true if parsing input with naive implementation // is equivalent to our reader func checkEquivalenceToNaive(t *testing.T, input string) bool { - nfe := "" - nh := testHandler(&nfe) - expected := naiveEscapeCharacters(input, '~', nh) + nh := &testHandler{} + expected := naiveEscapeCharacters(input, '~', nh.handler) - foundEscaped := "" - h := testHandler(&foundEscaped) + foundH := &testHandler{} var inputReader io.Reader = bytes.NewBufferString(input) inputReader = &arbtiraryReader{ buf: inputReader.(*bytes.Buffer), maxReadOnce: 10, } - filter := NewReader(inputReader, '~', h) + filter := NewReader(inputReader, '~', foundH.handler) var found bytes.Buffer _, err := io.Copy(&found, filter) if err != nil { @@ -136,11 +146,11 @@ func checkEquivalenceToNaive(t *testing.T, input string) bool { return false } - if nfe == foundEscaped && expected == found.String() { + if nh.escaped() == foundH.escaped() && expected == found.String() { return true } - t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped) + t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped()) t.Logf("read differed=%v expected=%s found=%v", expected != found.String(), expected, found.String()) return false @@ -159,15 +169,13 @@ func TestEscapingReader_Generated_EquivalentToReadOnce(t *testing.T) { // checkEquivalenceToReadOnce returns true if parsing input in a single // read matches multiple reads func checkEquivalenceToReadOnce(t *testing.T, input string) bool { - nfe := "" + nh := &testHandler{} var expected bytes.Buffer // getting expected value from read all at once { - h := testHandler(&nfe) - buf := make([]byte, len(input)+5) - inputReader := NewReader(bytes.NewBufferString(input), '~', h) + inputReader := NewReader(bytes.NewBufferString(input), '~', nh.handler) _, err := io.CopyBuffer(&expected, inputReader, buf) if err != nil { t.Logf("unexpected error while reading: %v", err) @@ -175,18 +183,16 @@ func checkEquivalenceToReadOnce(t *testing.T, input string) bool { } } - foundEscaped := "" + foundH := &testHandler{} var found bytes.Buffer // getting found by using arbitrary reader { - h := testHandler(&foundEscaped) - inputReader := &arbtiraryReader{ buf: bytes.NewBufferString(input), maxReadOnce: 10, } - filter := NewReader(inputReader, '~', h) + filter := NewReader(inputReader, '~', foundH.handler) _, err := io.Copy(&found, filter) if err != nil { t.Logf("unexpected error while reading: %v", err) @@ -194,11 +200,11 @@ func checkEquivalenceToReadOnce(t *testing.T, input string) bool { } } - if nfe == foundEscaped && expected.String() == found.String() { + if nh.escaped() == foundH.escaped() && expected.String() == found.String() { return true } - t.Logf("escaped differed=%v expected=%v found=%v", nfe != foundEscaped, nfe, foundEscaped) + t.Logf("escaped differed=%v expected=%v found=%v", nh.escaped() != foundH.escaped(), nh.escaped(), foundH.escaped()) t.Logf("read differed=%v expected=%s found=%v", expected.String() != found.String(), expected.String(), found.String()) return false