Update yamux library to pick up memory performance optimization

This commit is contained in:
Preetha Appan
2018-03-14 15:14:52 -05:00
parent 4252ffe71c
commit 1480cd5f03
4 changed files with 95 additions and 36 deletions

View File

@@ -123,6 +123,12 @@ func (s *Session) IsClosed() bool {
}
}
// CloseChan returns a read-only channel which is closed as
// soon as the session is closed.
func (s *Session) CloseChan() <-chan struct{} {
return s.shutdownCh
}
// NumStreams returns the number of currently open streams
func (s *Session) NumStreams() int {
s.streamLock.Lock()
@@ -323,8 +329,17 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error {
// potential shutdown. Since there's the expectation that sends can happen
// in a timely manner, we enforce the connection write timeout here.
func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
select {
case <-timer.C:
default:
}
timerPool.Put(t)
}()
ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
select {
@@ -349,8 +364,17 @@ func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) e
// the send happens right here, we enforce the connection write timeout if we
// can't queue the header to be sent.
func (s *Session) sendNoWait(hdr header) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
select {
case <-timer.C:
default:
}
timerPool.Put(t)
}()
select {
case s.sendCh <- sendReady{Hdr: hdr}:
@@ -408,11 +432,20 @@ func (s *Session) recv() {
}
}
// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
var (
handlers = []func(*Session, header) error{
typeData: (*Session).handleStreamMessage,
typeWindowUpdate: (*Session).handleStreamMessage,
typePing: (*Session).handlePing,
typeGoAway: (*Session).handleGoAway,
}
)
// recvLoop continues to receive data until a fatal error is encountered
func (s *Session) recvLoop() error {
defer close(s.recvDoneCh)
hdr := header(make([]byte, headerSize))
var handler func(header) error
for {
// Read the header
if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
@@ -428,22 +461,12 @@ func (s *Session) recvLoop() error {
return ErrInvalidVersion
}
// Switch on the type
switch hdr.MsgType() {
case typeData:
handler = s.handleStreamMessage
case typeWindowUpdate:
handler = s.handleStreamMessage
case typeGoAway:
handler = s.handleGoAway
case typePing:
handler = s.handlePing
default:
mt := hdr.MsgType()
if mt < typeData || mt > typeGoAway {
return ErrInvalidMsgType
}
// Invoke the handler
if err := handler(hdr); err != nil {
if err := handlers[mt](s, hdr); err != nil {
return err
}
}

View File

@@ -47,8 +47,8 @@ type Stream struct {
recvNotifyCh chan struct{}
sendNotifyCh chan struct{}
readDeadline time.Time
writeDeadline time.Time
readDeadline atomic.Value // time.Time
writeDeadline atomic.Value // time.Time
}
// newStream is used to construct a new stream within
@@ -67,6 +67,8 @@ func newStream(session *Session, id uint32, state streamState) *Stream {
recvNotifyCh: make(chan struct{}, 1),
sendNotifyCh: make(chan struct{}, 1),
}
s.readDeadline.Store(time.Time{})
s.writeDeadline.Store(time.Time{})
return s
}
@@ -91,10 +93,13 @@ START:
case streamRemoteClose:
fallthrough
case streamClosed:
s.recvLock.Lock()
if s.recvBuf == nil || s.recvBuf.Len() == 0 {
s.recvLock.Unlock()
s.stateLock.Unlock()
return 0, io.EOF
}
s.recvLock.Unlock()
case streamReset:
s.stateLock.Unlock()
return 0, ErrConnectionReset
@@ -118,12 +123,18 @@ START:
WAIT:
var timeout <-chan time.Time
if !s.readDeadline.IsZero() {
delay := s.readDeadline.Sub(time.Now())
timeout = time.After(delay)
var timer *time.Timer
readDeadline := s.readDeadline.Load().(time.Time)
if !readDeadline.IsZero() {
delay := readDeadline.Sub(time.Now())
timer = time.NewTimer(delay)
timeout = timer.C
}
select {
case <-s.recvNotifyCh:
if timer != nil {
timer.Stop()
}
goto START
case <-timeout:
return 0, ErrTimeout
@@ -180,7 +191,7 @@ START:
// Send the header
s.sendHdr.encode(typeData, flags, s.id, max)
if err := s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
return 0, err
}
@@ -192,8 +203,9 @@ START:
WAIT:
var timeout <-chan time.Time
if !s.writeDeadline.IsZero() {
delay := s.writeDeadline.Sub(time.Now())
writeDeadline := s.writeDeadline.Load().(time.Time)
if !writeDeadline.IsZero() {
delay := writeDeadline.Sub(time.Now())
timeout = time.After(delay)
}
select {
@@ -230,18 +242,25 @@ func (s *Stream) sendWindowUpdate() error {
// Determine the delta update
max := s.session.config.MaxStreamWindowSize
delta := max - atomic.LoadUint32(&s.recvWindow)
var bufLen uint32
s.recvLock.Lock()
if s.recvBuf != nil {
bufLen = uint32(s.recvBuf.Len())
}
delta := (max - bufLen) - s.recvWindow
// Determine the flags if any
flags := s.sendFlags()
// Check if we can omit the update
if delta < (max/2) && flags == 0 {
s.recvLock.Unlock()
return nil
}
// Update our window
atomic.AddUint32(&s.recvWindow, delta)
s.recvWindow += delta
s.recvLock.Unlock()
// Send the header
s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
@@ -384,16 +403,18 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
if length == 0 {
return nil
}
if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
return ErrRecvWindowExceeded
}
// Wrap in a limited reader
conn = &io.LimitedReader{R: conn, N: int64(length)}
// Copy into buffer
s.recvLock.Lock()
if length > s.recvWindow {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
return ErrRecvWindowExceeded
}
if s.recvBuf == nil {
// Allocate the receive buffer just-in-time to fit the full data frame.
// This way we can read in the whole packet without further allocations.
@@ -406,7 +427,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
}
// Decrement the receive window
atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
s.recvWindow -= length
s.recvLock.Unlock()
// Unblock any readers
@@ -427,13 +448,13 @@ func (s *Stream) SetDeadline(t time.Time) error {
// SetReadDeadline sets the deadline for future Read calls.
func (s *Stream) SetReadDeadline(t time.Time) error {
s.readDeadline = t
s.readDeadline.Store(t)
return nil
}
// SetWriteDeadline sets the deadline for future Write calls
func (s *Stream) SetWriteDeadline(t time.Time) error {
s.writeDeadline = t
s.writeDeadline.Store(t)
return nil
}

View File

@@ -1,5 +1,20 @@
package yamux
import (
"sync"
"time"
)
var (
timerPool = &sync.Pool{
New: func() interface{} {
timer := time.NewTimer(time.Hour * 1e6)
timer.Stop()
return timer
},
}
)
// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
if ch == nil {

2
vendor/vendor.json vendored
View File

@@ -173,7 +173,7 @@
{"path":"github.com/hashicorp/vault/helper/compressutil","checksumSHA1":"au+CDkddC4sVFV15UaPiI7FvSw0=","revision":"1fd46cbcb10569bd205c3f662e7a4f16f1e69056","revisionTime":"2017-08-11T01:28:18Z"},
{"path":"github.com/hashicorp/vault/helper/jsonutil","checksumSHA1":"yUiSTPf0QUuL2r/81sjuytqBoeQ=","revision":"0c3e14f047aede0a70256e1e8b321610910b246e","revisionTime":"2017-08-01T15:50:41Z"},
{"path":"github.com/hashicorp/vault/helper/parseutil","checksumSHA1":"GGveKvOwScWGZAAnupzpyw+0Jko=","revision":"1fd46cbcb10569bd205c3f662e7a4f16f1e69056","revisionTime":"2017-08-11T01:28:18Z"},
{"path":"github.com/hashicorp/yamux","checksumSHA1":"VMaF3Q7RIrRzvbnPbqxuSLryOvc=","revision":"badf81fca035b8ebac61b5ab83330b72541056f4","revisionTime":"2016-06-09T13:59:02Z"},
{"path":"github.com/hashicorp/yamux","checksumSHA1":"NnWv17i1tpvBNJtpdRRWpE6j4LY=","revision":"2658be15c5f05e76244154714161f17e3e77de2e","revisionTime":"2018-03-14T20:07:45Z"},
{"path":"github.com/hpcloud/tail/util","checksumSHA1":"0xM336Lb25URO/1W1/CtGoRygVU=","revision":"37f4271387456dd1bf82ab1ad9229f060cc45386","revisionTime":"2017-08-14T16:06:53Z"},
{"path":"github.com/hpcloud/tail/watch","checksumSHA1":"TP4OAv5JMtzj2TB6OQBKqauaKDc=","revision":"37f4271387456dd1bf82ab1ad9229f060cc45386","revisionTime":"2017-08-14T16:06:53Z"},
{"path":"github.com/jmespath/go-jmespath","comment":"0.2.2-2-gc01cf91","revision":"c01cf91b011868172fdcd9f41838e80c9d716264"},