lib/protocol: Send Close message on read error (#7141)

This commit is contained in:
Simon Frei 2020-11-27 11:31:20 +01:00 committed by GitHub
parent a9764fc16c
commit bbb22c8c80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 94 additions and 23 deletions

View File

@ -170,6 +170,8 @@ type rawConnection struct {
closeOnce sync.Once
sendCloseOnce sync.Once
compression Compression
loopWG sync.WaitGroup // Need to ensure no leftover routines in testing
}
type asyncResult struct {
@ -244,20 +246,35 @@ func newRawConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, rec
dispatcherLoopStopped: make(chan struct{}),
closed: make(chan struct{}),
compression: compress,
loopWG: sync.WaitGroup{},
}
}
// Start creates the goroutines for sending and receiving of messages. It must
// be called exactly once after creating a connection.
func (c *rawConnection) Start() {
go c.readerLoop()
c.loopWG.Add(5)
go func() {
c.readerLoop()
c.loopWG.Done()
}()
go func() {
err := c.dispatcherLoop()
c.internalClose(err)
c.Close(err)
c.loopWG.Done()
}()
go func() {
c.writerLoop()
c.loopWG.Done()
}()
go func() {
c.pingSender()
c.loopWG.Done()
}()
go func() {
c.pingReceiver()
c.loopWG.Done()
}()
go c.writerLoop()
go c.pingSender()
go c.pingReceiver()
c.startTime = time.Now()
}
@ -410,7 +427,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
state = stateReady
}
if err := c.receiver.ClusterConfig(c.id, *msg); err != nil {
return errors.Wrap(err, "receiver error")
return fmt.Errorf("receiving cluster config: %w", err)
}
case *Index:
@ -422,7 +439,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
return errors.Wrap(err, "protocol error: index")
}
if err := c.handleIndex(*msg); err != nil {
return errors.Wrap(err, "receiver error")
return fmt.Errorf("receiving index: %w", err)
}
state = stateReady
@ -435,7 +452,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
return errors.Wrap(err, "protocol error: index update")
}
if err := c.handleIndexUpdate(*msg); err != nil {
return errors.Wrap(err, "receiver error")
return fmt.Errorf("receiving index update: %w", err)
}
state = stateReady
@ -462,7 +479,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
return fmt.Errorf("protocol error: response message in state %d", state)
}
if err := c.receiver.DownloadProgress(c.id, msg.Folder, msg.Updates); err != nil {
return errors.Wrap(err, "receiver error")
return fmt.Errorf("receiving download progress: %w", err)
}
case *Ping:
@ -474,7 +491,7 @@ func (c *rawConnection) dispatcherLoop() (err error) {
case *Close:
l.Debugln("read Close message")
return errors.New(msg.Reason)
return fmt.Errorf("closed by remote: %v", msg.Reason)
default:
l.Debugf("read unknown message: %+T", msg)

View File

@ -33,8 +33,10 @@ func TestPing(t *testing.T) {
c0 := NewConnection(c0ID, ar, bw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, newTestModel(), "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{})
c1.ClusterConfig(ClusterConfig{})
@ -57,8 +59,10 @@ func TestClose(t *testing.T) {
c0 := NewConnection(c0ID, ar, bw, m0, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, m1, "name", CompressionAlways)
c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{})
c1.ClusterConfig(ClusterConfig{})
@ -97,8 +101,10 @@ func TestCloseOnBlockingSend(t *testing.T) {
m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start()
defer closeAndWait(c, rw)
wg := sync.WaitGroup{}
@ -149,8 +155,10 @@ func TestCloseRace(t *testing.T) {
c0 := NewConnection(c0ID, ar, bw, m0, "c0", CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, m1, "c1", CompressionNever)
c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{})
c1.ClusterConfig(ClusterConfig{})
@ -184,8 +192,10 @@ func TestCloseRace(t *testing.T) {
func TestClusterConfigFirst(t *testing.T) {
m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start()
defer closeAndWait(c, rw)
select {
case c.outbox <- asyncMessage{&Ping{}, nil}:
@ -234,8 +244,10 @@ func TestCloseTimeout(t *testing.T) {
m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start()
defer closeAndWait(c, rw)
done := make(chan struct{})
go func() {
@ -852,8 +864,10 @@ func TestSha256OfEmptyBlock(t *testing.T) {
func TestClusterConfigAfterClose(t *testing.T) {
m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.BlockingRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c.Start()
defer closeAndWait(c, rw)
c.internalClose(errManual)
@ -874,11 +888,13 @@ func TestDispatcherToCloseDeadlock(t *testing.T) {
// Verify that we don't deadlock when calling Close() from within one of
// the model callbacks (ClusterConfig).
m := newTestModel()
c := NewConnection(c0ID, &testutils.BlockingRW{}, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, m, "name", CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
m.ccFn = func(devID DeviceID, cc ClusterConfig) {
c.Close(errManual)
}
c.Start()
defer closeAndWait(c, rw)
c.inbox <- &ClusterConfig{}
@ -945,3 +961,18 @@ func TestIndexIDString(t *testing.T) {
t.Error(i.String())
}
}
func closeAndWait(c Connection, closers ...io.Closer) {
for _, closer := range closers {
closer.Close()
}
var raw *rawConnection
switch i := c.(type) {
case wireFormatConnection:
raw = i.Connection.(*rawConnection)
case *rawConnection:
raw = i
}
raw.internalClose(ErrClosed)
raw.loopWG.Wait()
}

View File

@ -6,17 +6,40 @@
package testutils
// BlockingRW implements io.Reader and Writer but never returns when called
type BlockingRW struct{ nilChan chan struct{} }
import (
"errors"
"sync"
)
func (rw *BlockingRW) Read(p []byte) (n int, err error) {
<-rw.nilChan
return
var ErrClosed = errors.New("closed")
// BlockingRW implements io.Reader, Writer and Closer, but only returns when closed
type BlockingRW struct {
c chan struct{}
closeOnce sync.Once
}
func (rw *BlockingRW) Write(p []byte) (n int, err error) {
<-rw.nilChan
return
func NewBlockingRW() *BlockingRW {
return &BlockingRW{
c: make(chan struct{}),
closeOnce: sync.Once{},
}
}
func (rw *BlockingRW) Read(p []byte) (int, error) {
<-rw.c
return 0, ErrClosed
}
func (rw *BlockingRW) Write(p []byte) (int, error) {
<-rw.c
return 0, ErrClosed
}
func (rw *BlockingRW) Close() error {
rw.closeOnce.Do(func() {
close(rw.c)
})
return nil
}
// NoopRW implements io.Reader and Writer but never returns when called