lib: Handle adding enc folders on an existing conn (fixes #7509) (#7510)

This commit is contained in:
Simon Frei 2021-03-22 21:50:19 +01:00 committed by GitHub
parent f7929229c8
commit 924b96856f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 160 additions and 87 deletions

View File

@ -335,13 +335,7 @@ func (s *service) handle(ctx context.Context) error {
isLAN := s.isLAN(c.RemoteAddr())
rd, wr := s.limiter.getLimiters(remoteID, c, isLAN)
var protoConn protocol.Connection
passwords := s.cfg.FolderPasswords(remoteID)
if len(passwords) > 0 {
protoConn = protocol.NewEncryptedConnection(passwords, remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
} else {
protoConn = protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression)
}
protoConn := protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression, s.cfg.FolderPasswords(remoteID))
l.Infof("Established secure connection to %s at %s", remoteID, c)

View File

@ -33,7 +33,7 @@ func newFakeConnection(id protocol.DeviceID, model Model) *fakeConnection {
})
f.IDReturns(id)
f.CloseCalls(func(err error) {
model.Closed(f, err)
model.Closed(id, err)
f.ClosedReturns(true)
})
return f

View File

@ -43,10 +43,10 @@ type Model struct {
arg1 string
arg2 string
}
ClosedStub func(protocol.Connection, error)
ClosedStub func(protocol.DeviceID, error)
closedMutex sync.RWMutex
closedArgsForCall []struct {
arg1 protocol.Connection
arg1 protocol.DeviceID
arg2 error
}
ClusterConfigStub func(protocol.DeviceID, protocol.ClusterConfig) error
@ -684,10 +684,10 @@ func (fake *Model) BringToFrontArgsForCall(i int) (string, string) {
return argsForCall.arg1, argsForCall.arg2
}
func (fake *Model) Closed(arg1 protocol.Connection, arg2 error) {
func (fake *Model) Closed(arg1 protocol.DeviceID, arg2 error) {
fake.closedMutex.Lock()
fake.closedArgsForCall = append(fake.closedArgsForCall, struct {
arg1 protocol.Connection
arg1 protocol.DeviceID
arg2 error
}{arg1, arg2})
stub := fake.ClosedStub
@ -704,13 +704,13 @@ func (fake *Model) ClosedCallCount() int {
return len(fake.closedArgsForCall)
}
func (fake *Model) ClosedCalls(stub func(protocol.Connection, error)) {
func (fake *Model) ClosedCalls(stub func(protocol.DeviceID, error)) {
fake.closedMutex.Lock()
defer fake.closedMutex.Unlock()
fake.ClosedStub = stub
}
func (fake *Model) ClosedArgsForCall(i int) (protocol.Connection, error) {
func (fake *Model) ClosedArgsForCall(i int) (protocol.DeviceID, error) {
fake.closedMutex.RLock()
defer fake.closedMutex.RUnlock()
argsForCall := fake.closedArgsForCall[i]

View File

@ -293,7 +293,7 @@ func (m *model) initFolders(cfg config.Configuration) error {
ignoredDevices := observedDeviceSet(m.cfg.IgnoredDevices())
m.cleanPending(cfg.DeviceMap(), cfg.FolderMap(), ignoredDevices, nil)
m.resendClusterConfig(clusterConfigDevices.AsSlice())
m.sendClusterConfig(clusterConfigDevices.AsSlice())
return nil
}
@ -1510,7 +1510,7 @@ func (m *model) ccCheckEncryption(fcfg config.FolderConfiguration, folderDevice
m.fmut.Unlock()
// We can only announce ourselfs once we have the token,
// thus we need to resend CCs now that we have it.
m.resendClusterConfig(fcfg.DeviceIDs())
m.sendClusterConfig(fcfg.DeviceIDs())
return nil
}
}
@ -1520,7 +1520,7 @@ func (m *model) ccCheckEncryption(fcfg config.FolderConfiguration, folderDevice
return nil
}
func (m *model) resendClusterConfig(ids []protocol.DeviceID) {
func (m *model) sendClusterConfig(ids []protocol.DeviceID) {
if len(ids) == 0 {
return
}
@ -1534,7 +1534,8 @@ func (m *model) resendClusterConfig(ids []protocol.DeviceID) {
m.pmut.RUnlock()
// Generating cluster-configs acquires fmut -> must happen outside of pmut.
for _, conn := range ccConns {
cm := m.generateClusterConfig(conn.ID())
cm, passwords := m.generateClusterConfig(conn.ID())
conn.SetFolderPasswords(passwords)
go conn.ClusterConfig(cm)
}
}
@ -1728,9 +1729,7 @@ func (m *model) introduceDevice(device protocol.Device, introducerCfg config.Dev
}
// Closed is called when a connection has been closed
func (m *model) Closed(conn protocol.Connection, err error) {
device := conn.ID()
func (m *model) Closed(device protocol.DeviceID, err error) {
m.pmut.Lock()
conn, ok := m.conn[device]
if !ok {
@ -2247,7 +2246,8 @@ func (m *model) AddConnection(conn protocol.Connection, hello protocol.Hello) {
m.pmut.Unlock()
// Acquires fmut, so has to be done outside of pmut.
cm := m.generateClusterConfig(deviceID)
cm, passwords := m.generateClusterConfig(deviceID)
conn.SetFolderPasswords(passwords)
conn.ClusterConfig(cm)
if (device.Name == "" || m.cfg.Options().OverwriteRemoteDevNames) && hello.DeviceName != "" {
@ -2407,15 +2407,17 @@ func (m *model) numHashers(folder string) int {
return 1
}
// generateClusterConfig returns a ClusterConfigMessage that is correct for
// the given peer device
func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.ClusterConfig {
// generateClusterConfig returns a ClusterConfigMessage that is correct and the
// set of folder passwords for the given peer device
func (m *model) generateClusterConfig(device protocol.DeviceID) (protocol.ClusterConfig, map[string]string) {
var message protocol.ClusterConfig
m.fmut.RLock()
defer m.fmut.RUnlock()
for _, folderCfg := range m.cfg.FolderList() {
folders := m.cfg.FolderList()
passwords := make(map[string]string, len(folders))
for _, folderCfg := range folders {
if !folderCfg.SharedWith(device) {
continue
}
@ -2448,8 +2450,8 @@ func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.Cluster
// another cluster config once the folder is started.
protocolFolder.Paused = folderCfg.Paused || fs == nil
for _, device := range folderCfg.Devices {
deviceCfg, _ := m.cfg.Device(device.DeviceID)
for _, folderDevice := range folderCfg.Devices {
deviceCfg, _ := m.cfg.Device(folderDevice.DeviceID)
protocolDevice := protocol.Device{
ID: deviceCfg.DeviceID,
@ -2462,8 +2464,11 @@ func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.Cluster
if deviceCfg.DeviceID == m.id && hasEncryptionToken {
protocolDevice.EncryptionPasswordToken = encryptionToken
} else if device.EncryptionPassword != "" {
protocolDevice.EncryptionPasswordToken = protocol.PasswordToken(folderCfg.ID, device.EncryptionPassword)
} else if folderDevice.EncryptionPassword != "" {
protocolDevice.EncryptionPasswordToken = protocol.PasswordToken(folderCfg.ID, folderDevice.EncryptionPassword)
if folderDevice.DeviceID == device {
passwords[folderCfg.ID] = folderDevice.EncryptionPassword
}
}
if fs != nil {
@ -2482,7 +2487,7 @@ func (m *model) generateClusterConfig(device protocol.DeviceID) protocol.Cluster
message.Folders = append(message.Folders, protocolFolder)
}
return message
return message, passwords
}
func (m *model) State(folder string) (string, time.Time, error) {
@ -2891,7 +2896,7 @@ func (m *model) CommitConfiguration(from, to config.Configuration) bool {
}
m.pmut.RUnlock()
// Generating cluster-configs acquires fmut -> must happen outside of pmut.
m.resendClusterConfig(clusterConfigDevices.AsSlice())
m.sendClusterConfig(clusterConfigDevices.AsSlice())
ignoredDevices := observedDeviceSet(to.IgnoredDevices)
m.cleanPending(toDevices, toFolders, ignoredDevices, removedFolders)

View File

@ -341,7 +341,7 @@ func TestDeviceRename(t *testing.T) {
t.Errorf("Device already has a name")
}
m.Closed(conn, protocol.ErrTimeout)
m.Closed(conn.ID(), protocol.ErrTimeout)
hello.DeviceName = "tester"
m.AddConnection(conn, hello)
@ -349,7 +349,7 @@ func TestDeviceRename(t *testing.T) {
t.Errorf("Device did not get a name")
}
m.Closed(conn, protocol.ErrTimeout)
m.Closed(conn.ID(), protocol.ErrTimeout)
hello.DeviceName = "tester2"
m.AddConnection(conn, hello)
@ -367,7 +367,7 @@ func TestDeviceRename(t *testing.T) {
t.Errorf("Device name not saved in config")
}
m.Closed(conn, protocol.ErrTimeout)
m.Closed(conn.ID(), protocol.ErrTimeout)
waiter, err := cfg.Modify(func(cfg *config.Configuration) {
cfg.Options.OverwriteRemoteDevNames = true
@ -428,7 +428,7 @@ func TestClusterConfig(t *testing.T) {
m.ServeBackground()
defer cleanupModel(m)
cm := m.generateClusterConfig(device2)
cm, _ := m.generateClusterConfig(device2)
if l := len(cm.Folders); l != 2 {
t.Fatalf("Incorrect number of folders %d != 2", l)
@ -853,7 +853,7 @@ func TestIssue4897(t *testing.T) {
defer cleanupModel(m)
cancel()
cm := m.generateClusterConfig(device1)
cm, _ := m.generateClusterConfig(device1)
if l := len(cm.Folders); l != 1 {
t.Errorf("Cluster config contains %v folders, expected 1", l)
}
@ -873,7 +873,7 @@ func TestIssue5063(t *testing.T) {
for _, c := range m.conn {
conn := c.(*fakeConnection)
conn.CloseCalls(func(_ error) {})
defer m.Closed(c, errStopped) // to unblock deferred m.Stop()
defer m.Closed(c.ID(), errStopped) // to unblock deferred m.Stop()
}
m.pmut.Unlock()
@ -2428,8 +2428,8 @@ func TestNoRequestsFromPausedDevices(t *testing.T) {
t.Errorf("should have two available")
}
m.Closed(newFakeConnection(device1, m), errDeviceUnknown)
m.Closed(newFakeConnection(device2, m), errDeviceUnknown)
m.Closed(device1, errDeviceUnknown)
m.Closed(device2, errDeviceUnknown)
avail = m.testAvailability("default", file, file.Blocks[0])
if len(avail) != 0 {
@ -3171,7 +3171,7 @@ func TestConnCloseOnRestart(t *testing.T) {
br := &testutils.BlockingRW{}
nw := &testutils.NoopRW{}
m.AddConnection(protocol.NewConnection(device1, br, nw, testutils.NoopCloser{}, m, new(protocolmocks.ConnectionInfo), protocol.CompressionNever), protocol.Hello{})
m.AddConnection(protocol.NewConnection(device1, br, nw, testutils.NoopCloser{}, m, new(protocolmocks.ConnectionInfo), protocol.CompressionNever, nil), protocol.Hello{})
m.pmut.RLock()
if len(m.closed) != 1 {
t.Fatalf("Expected just one conn (len(m.conn) == %v)", len(m.conn))
@ -4142,7 +4142,7 @@ func TestCCFolderNotRunning(t *testing.T) {
defer cleanupModelAndRemoveDir(m, tfs.URI())
// A connection can happen before all the folders are started.
cc := m.generateClusterConfig(device1)
cc, _ := m.generateClusterConfig(device1)
if l := len(cc.Folders); l != 1 {
t.Fatalf("Expected 1 folder in CC, got %v", l)
}

View File

@ -60,9 +60,9 @@ func benchmarkRequestsTLS(b *testing.B, conn0, conn1 net.Conn) {
func benchmarkRequestsConnPair(b *testing.B, conn0, conn1 net.Conn) {
// Start up Connections on them
c0 := NewConnection(LocalDeviceID, conn0, conn0, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata)
c0 := NewConnection(LocalDeviceID, conn0, conn0, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata, nil)
c0.Start()
c1 := NewConnection(LocalDeviceID, conn1, conn1, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata)
c1 := NewConnection(LocalDeviceID, conn1, conn1, testutils.NoopCloser{}, new(fakeModel), new(mockedConnectionInfo), CompressionMetadata, nil)
c1.Start()
// Satisfy the assertions in the protocol by sending an initial cluster config
@ -188,7 +188,7 @@ func (m *fakeModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) error
return nil
}
func (m *fakeModel) Closed(conn Connection, err error) {
func (m *fakeModel) Closed(DeviceID, error) {
}
func (m *fakeModel) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) error {

View File

@ -49,7 +49,7 @@ func (t *TestModel) Request(deviceID DeviceID, folder, name string, blockNo, siz
return &fakeRequestResponse{buf}, nil
}
func (t *TestModel) Closed(conn Connection, err error) {
func (t *TestModel) Closed(_ DeviceID, err error) {
t.closedErr = err
close(t.closedCh)
}

View File

@ -14,6 +14,7 @@ import (
"fmt"
"io"
"strings"
"sync"
"time"
"github.com/gogo/protobuf/proto"
@ -41,11 +42,11 @@ const (
// must decrypt those and answer requests by encrypting the data.
type encryptedModel struct {
model Model
folderKeys map[string]*[keySize]byte // folder ID -> key
folderKeys *folderKeyRegistry
}
func (e encryptedModel) Index(deviceID DeviceID, folder string, files []FileInfo) error {
if folderKey, ok := e.folderKeys[folder]; ok {
if folderKey, ok := e.folderKeys.get(folder); ok {
// incoming index data to be decrypted
if err := decryptFileInfos(files, folderKey); err != nil {
return err
@ -55,7 +56,7 @@ func (e encryptedModel) Index(deviceID DeviceID, folder string, files []FileInfo
}
func (e encryptedModel) IndexUpdate(deviceID DeviceID, folder string, files []FileInfo) error {
if folderKey, ok := e.folderKeys[folder]; ok {
if folderKey, ok := e.folderKeys.get(folder); ok {
// incoming index data to be decrypted
if err := decryptFileInfos(files, folderKey); err != nil {
return err
@ -65,7 +66,7 @@ func (e encryptedModel) IndexUpdate(deviceID DeviceID, folder string, files []Fi
}
func (e encryptedModel) Request(deviceID DeviceID, folder, name string, blockNo, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error) {
folderKey, ok := e.folderKeys[folder]
folderKey, ok := e.folderKeys.get(folder)
if !ok {
return e.model.Request(deviceID, folder, name, blockNo, size, offset, hash, weakHash, fromTemporary)
}
@ -123,7 +124,7 @@ func (e encryptedModel) Request(deviceID DeviceID, folder, name string, blockNo,
}
func (e encryptedModel) DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) error {
if _, ok := e.folderKeys[folder]; !ok {
if _, ok := e.folderKeys.get(folder); !ok {
return e.model.DownloadProgress(deviceID, folder, updates)
}
@ -135,42 +136,46 @@ func (e encryptedModel) ClusterConfig(deviceID DeviceID, config ClusterConfig) e
return e.model.ClusterConfig(deviceID, config)
}
func (e encryptedModel) Closed(conn Connection, err error) {
e.model.Closed(conn, err)
func (e encryptedModel) Closed(device DeviceID, err error) {
e.model.Closed(device, err)
}
// The encryptedConnection sits between the model and the encrypted device. It
// encrypts outgoing metadata and decrypts incoming responses.
type encryptedConnection struct {
ConnectionInfo
conn Connection
folderKeys map[string]*[keySize]byte // folder ID -> key
conn *rawConnection
folderKeys *folderKeyRegistry
}
func (e encryptedConnection) Start() {
e.conn.Start()
}
func (e encryptedConnection) SetFolderPasswords(passwords map[string]string) {
e.folderKeys.setPasswords(passwords)
}
func (e encryptedConnection) ID() DeviceID {
return e.conn.ID()
}
func (e encryptedConnection) Index(ctx context.Context, folder string, files []FileInfo) error {
if folderKey, ok := e.folderKeys[folder]; ok {
if folderKey, ok := e.folderKeys.get(folder); ok {
encryptFileInfos(files, folderKey)
}
return e.conn.Index(ctx, folder, files)
}
func (e encryptedConnection) IndexUpdate(ctx context.Context, folder string, files []FileInfo) error {
if folderKey, ok := e.folderKeys[folder]; ok {
if folderKey, ok := e.folderKeys.get(folder); ok {
encryptFileInfos(files, folderKey)
}
return e.conn.IndexUpdate(ctx, folder, files)
}
func (e encryptedConnection) Request(ctx context.Context, folder string, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error) {
folderKey, ok := e.folderKeys[folder]
folderKey, ok := e.folderKeys.get(folder)
if !ok {
return e.conn.Request(ctx, folder, name, blockNo, offset, size, hash, weakHash, fromTemporary)
}
@ -205,7 +210,7 @@ func (e encryptedConnection) Request(ctx context.Context, folder string, name st
}
func (e encryptedConnection) DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate) {
if _, ok := e.folderKeys[folder]; !ok {
if _, ok := e.folderKeys.get(folder); !ok {
e.conn.DownloadProgress(ctx, folder, updates)
}
@ -590,3 +595,27 @@ func isEncryptedParentFromComponents(pathComponents []string) bool {
}
return true
}
type folderKeyRegistry struct {
keys map[string]*[keySize]byte // folder ID -> key
mut sync.RWMutex
}
func newFolderKeyRegistry(passwords map[string]string) *folderKeyRegistry {
return &folderKeyRegistry{
keys: keysFromPasswords(passwords),
}
}
func (r *folderKeyRegistry) get(folder string) (*[keySize]byte, bool) {
r.mut.RLock()
key, ok := r.keys[folder]
r.mut.RUnlock()
return key, ok
}
func (r *folderKeyRegistry) setPasswords(passwords map[string]string) {
r.mut.Lock()
r.keys = keysFromPasswords(passwords)
r.mut.Unlock()
}

View File

@ -135,6 +135,11 @@ type Connection struct {
result1 []byte
result2 error
}
SetFolderPasswordsStub func(map[string]string)
setFolderPasswordsMutex sync.RWMutex
setFolderPasswordsArgsForCall []struct {
arg1 map[string]string
}
StartStub func()
startMutex sync.RWMutex
startArgsForCall []struct {
@ -817,6 +822,38 @@ func (fake *Connection) RequestReturnsOnCall(i int, result1 []byte, result2 erro
}{result1, result2}
}
func (fake *Connection) SetFolderPasswords(arg1 map[string]string) {
fake.setFolderPasswordsMutex.Lock()
fake.setFolderPasswordsArgsForCall = append(fake.setFolderPasswordsArgsForCall, struct {
arg1 map[string]string
}{arg1})
stub := fake.SetFolderPasswordsStub
fake.recordInvocation("SetFolderPasswords", []interface{}{arg1})
fake.setFolderPasswordsMutex.Unlock()
if stub != nil {
fake.SetFolderPasswordsStub(arg1)
}
}
func (fake *Connection) SetFolderPasswordsCallCount() int {
fake.setFolderPasswordsMutex.RLock()
defer fake.setFolderPasswordsMutex.RUnlock()
return len(fake.setFolderPasswordsArgsForCall)
}
func (fake *Connection) SetFolderPasswordsCalls(stub func(map[string]string)) {
fake.setFolderPasswordsMutex.Lock()
defer fake.setFolderPasswordsMutex.Unlock()
fake.SetFolderPasswordsStub = stub
}
func (fake *Connection) SetFolderPasswordsArgsForCall(i int) map[string]string {
fake.setFolderPasswordsMutex.RLock()
defer fake.setFolderPasswordsMutex.RUnlock()
argsForCall := fake.setFolderPasswordsArgsForCall[i]
return argsForCall.arg1
}
func (fake *Connection) Start() {
fake.startMutex.Lock()
fake.startArgsForCall = append(fake.startArgsForCall, struct {
@ -1080,6 +1117,8 @@ func (fake *Connection) Invocations() map[string][][]interface{} {
defer fake.remoteAddrMutex.RUnlock()
fake.requestMutex.RLock()
defer fake.requestMutex.RUnlock()
fake.setFolderPasswordsMutex.RLock()
defer fake.setFolderPasswordsMutex.RUnlock()
fake.startMutex.RLock()
defer fake.startMutex.RUnlock()
fake.statisticsMutex.RLock()

View File

@ -126,8 +126,8 @@ type Model interface {
Request(deviceID DeviceID, folder, name string, blockNo, size int32, offset int64, hash []byte, weakHash uint32, fromTemporary bool) (RequestResponse, error)
// A cluster configuration message was received
ClusterConfig(deviceID DeviceID, config ClusterConfig) error
// The peer device closed the connection
Closed(conn Connection, err error)
// The peer device closed the connection or an error occurred
Closed(device DeviceID, err error)
// The peer device sent progress updates for the files it is currently downloading
DownloadProgress(deviceID DeviceID, folder string, updates []FileDownloadProgressUpdate) error
}
@ -140,6 +140,7 @@ type RequestResponse interface {
type Connection interface {
Start()
SetFolderPasswords(passwords map[string]string)
Close(err error)
ID() DeviceID
Index(ctx context.Context, folder string, files []FileInfo) error
@ -225,24 +226,16 @@ const (
// Should not be modified in production code, just for testing.
var CloseTimeout = 10 * time.Second
func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
receiver = nativeModel{receiver}
rc := newRawConnection(deviceID, reader, writer, closer, receiver, connInfo, compress)
return wireFormatConnection{rc}
}
func NewEncryptedConnection(passwords map[string]string, deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression) Connection {
keys := keysFromPasswords(passwords)
func NewConnection(deviceID DeviceID, reader io.Reader, writer io.Writer, closer io.Closer, receiver Model, connInfo ConnectionInfo, compress Compression, passwords map[string]string) Connection {
// Encryption / decryption is first (outermost) before conversion to
// native path formats.
nm := nativeModel{receiver}
em := encryptedModel{model: nm, folderKeys: keys}
em := &encryptedModel{model: nm, folderKeys: newFolderKeyRegistry(passwords)}
// We do the wire format conversion first (outermost) so that the
// metadata is in wire format when it reaches the encryption step.
rc := newRawConnection(deviceID, reader, writer, closer, em, connInfo, compress)
ec := encryptedConnection{ConnectionInfo: rc, conn: rc, folderKeys: keys}
ec := encryptedConnection{ConnectionInfo: rc, conn: rc, folderKeys: em.folderKeys}
wc := wireFormatConnection{ec}
return wc
@ -748,6 +741,8 @@ func (c *rawConnection) writerLoop() {
}
func (c *rawConnection) writeMessage(msg message) error {
msgContext, _ := messageContext(msg)
l.Debugf("Writing %v", msgContext)
if c.shouldCompressMessage(msg) {
return c.writeCompressedMessage(msg)
}
@ -955,7 +950,7 @@ func (c *rawConnection) internalClose(err error) {
<-c.dispatcherLoopStopped
c.receiver.Closed(c, err)
c.receiver.Closed(c.ID(), err)
})
}

View File

@ -31,10 +31,10 @@ func TestPing(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0 := getRawConnection(NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways, nil))
c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c1 := getRawConnection(NewConnection(c1ID, br, aw, testutils.NoopCloser{}, newTestModel(), new(mockedConnectionInfo), CompressionAlways, nil))
c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{})
@ -57,10 +57,10 @@ func TestClose(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c0 := getRawConnection(NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionAlways, nil))
c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionAlways)
c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionAlways, nil)
c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{})
@ -102,7 +102,7 @@ func TestCloseOnBlockingSend(t *testing.T) {
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := getRawConnection(NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
c.Start()
defer closeAndWait(c, rw)
@ -153,10 +153,10 @@ func TestCloseRace(t *testing.T) {
ar, aw := io.Pipe()
br, bw := io.Pipe()
c0 := NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionNever).(wireFormatConnection).Connection.(*rawConnection)
c0 := getRawConnection(NewConnection(c0ID, ar, bw, testutils.NoopCloser{}, m0, new(mockedConnectionInfo), CompressionNever, nil))
c0.Start()
defer closeAndWait(c0, ar, bw)
c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionNever)
c1 := NewConnection(c1ID, br, aw, testutils.NoopCloser{}, m1, new(mockedConnectionInfo), CompressionNever, nil)
c1.Start()
defer closeAndWait(c1, ar, bw)
c0.ClusterConfig(ClusterConfig{})
@ -193,7 +193,7 @@ func TestClusterConfigFirst(t *testing.T) {
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := getRawConnection(NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
c.Start()
defer closeAndWait(c, rw)
@ -245,7 +245,7 @@ func TestCloseTimeout(t *testing.T) {
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := getRawConnection(NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
c.Start()
defer closeAndWait(c, rw)
@ -865,7 +865,7 @@ func TestClusterConfigAfterClose(t *testing.T) {
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := getRawConnection(NewConnection(c0ID, rw, rw, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
c.Start()
defer closeAndWait(c, rw)
@ -889,7 +889,7 @@ func TestDispatcherToCloseDeadlock(t *testing.T) {
// the model callbacks (ClusterConfig).
m := newTestModel()
rw := testutils.NewBlockingRW()
c := NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways).(wireFormatConnection).Connection.(*rawConnection)
c := getRawConnection(NewConnection(c0ID, rw, &testutils.NoopRW{}, testutils.NoopCloser{}, m, new(mockedConnectionInfo), CompressionAlways, nil))
m.ccFn = func(devID DeviceID, cc ClusterConfig) {
c.Close(errManual)
}
@ -962,17 +962,28 @@ func TestIndexIDString(t *testing.T) {
}
}
func closeAndWait(c Connection, closers ...io.Closer) {
func closeAndWait(c interface{}, closers ...io.Closer) {
for _, closer := range closers {
closer.Close()
}
var raw *rawConnection
switch i := c.(type) {
case wireFormatConnection:
raw = i.Connection.(*rawConnection)
case *rawConnection:
raw = i
default:
raw = getRawConnection(c.(Connection))
}
raw.internalClose(ErrClosed)
raw.loopWG.Wait()
}
func getRawConnection(c Connection) *rawConnection {
var raw *rawConnection
switch i := c.(type) {
case wireFormatConnection:
raw = i.Connection.(encryptedConnection).conn
case encryptedConnection:
raw = i.conn
}
return raw
}