From 857caf363735dc2c79795bc4511b0bbe2fcc0d0b Mon Sep 17 00:00:00 2001 From: Simon Frei Date: Thu, 17 Jun 2021 13:57:44 +0200 Subject: [PATCH] lib/connections: Trigger dialer when connection gets closed (#7753) --- lib/connections/connections_test.go | 65 ++++++++++ lib/connections/service.go | 185 +++++++++++++++++++++------- lib/model/fakeconns_test.go | 8 +- lib/model/model_test.go | 8 +- lib/protocol/encryption.go | 2 +- lib/protocol/mocks/connection.go | 20 +-- lib/protocol/protocol.go | 11 +- 7 files changed, 231 insertions(+), 68 deletions(-) diff --git a/lib/connections/connections_test.go b/lib/connections/connections_test.go index 72ff3cc52..b9a10d5ca 100644 --- a/lib/connections/connections_test.go +++ b/lib/connections/connections_test.go @@ -231,6 +231,71 @@ func TestConnectionStatus(t *testing.T) { check(nil, nil) } +func TestNextDialRegistryCleanup(t *testing.T) { + now := time.Now() + firsts := []time.Time{ + now.Add(-dialCoolDownInterval + time.Second), + now.Add(-dialCoolDownDelay + time.Second), + now.Add(-2 * dialCoolDownDelay), + } + + r := make(nextDialRegistry) + + // Cases where the device should be cleaned up + + r[protocol.LocalDeviceID] = nextDialDevice{} + r.sleepDurationAndCleanup(now) + if l := len(r); l > 0 { + t.Errorf("Expected empty to be cleaned up, got length %v", l) + } + for _, dev := range []nextDialDevice{ + // attempts below threshold, outside of interval + { + attempts: 1, + coolDownIntervalStart: firsts[1], + }, + { + attempts: 1, + coolDownIntervalStart: firsts[2], + }, + // Threshold reached, but outside of cooldown delay + { + attempts: dialCoolDownMaxAttemps, + coolDownIntervalStart: firsts[2], + }, + } { + r[protocol.LocalDeviceID] = dev + r.sleepDurationAndCleanup(now) + if l := len(r); l > 0 { + t.Errorf("attempts: %v, start: %v: Expected all cleaned up, got length %v", dev.attempts, dev.coolDownIntervalStart, l) + } + } + + // Cases where the device should stay monitored + for _, dev := range []nextDialDevice{ + // attempts below threshold, inside of interval + { + attempts: 1, + coolDownIntervalStart: firsts[0], + }, + // attempts at threshold, inside delay + { + attempts: dialCoolDownMaxAttemps, + coolDownIntervalStart: firsts[0], + }, + { + attempts: dialCoolDownMaxAttemps, + coolDownIntervalStart: firsts[1], + }, + } { + r[protocol.LocalDeviceID] = dev + r.sleepDurationAndCleanup(now) + if l := len(r); l != 1 { + t.Errorf("attempts: %v, start: %v: Expected device still tracked, got length %v", dev.attempts, dev.coolDownIntervalStart, l) + } + } +} + func BenchmarkConnections(pb *testing.B) { addrs := []string{ "tcp://127.0.0.1:0", diff --git a/lib/connections/service.go b/lib/connections/service.go index 7859a9e93..565815c7c 100644 --- a/lib/connections/service.go +++ b/lib/connections/service.go @@ -142,10 +142,13 @@ type service struct { natService *nat.Service evLogger events.Logger - deviceAddressesChanged chan struct{} - listenersMut sync.RWMutex - listeners map[string]genericListener - listenerTokens map[string]suture.ServiceToken + dialNow chan struct{} + dialNowDevices map[protocol.DeviceID]struct{} + dialNowDevicesMut sync.Mutex + + listenersMut sync.RWMutex + listeners map[string]genericListener + listenerTokens map[string]suture.ServiceToken } func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, bepProtocolName string, tlsDefaultCommonName string, evLogger events.Logger) Service { @@ -166,10 +169,13 @@ func NewService(cfg config.Wrapper, myID 
protocol.DeviceID, mdl Model, tlsCfg *t natService: nat.NewService(myID, cfg), evLogger: evLogger, - deviceAddressesChanged: make(chan struct{}, 1), - listenersMut: sync.NewRWMutex(), - listeners: make(map[string]genericListener), - listenerTokens: make(map[string]suture.ServiceToken), + dialNowDevicesMut: sync.NewMutex(), + dialNow: make(chan struct{}, 1), + dialNowDevices: make(map[protocol.DeviceID]struct{}), + + listenersMut: sync.NewRWMutex(), + listeners: make(map[string]genericListener), + listenerTokens: make(map[string]suture.ServiceToken), } cfg.Subscribe(service) @@ -324,6 +330,13 @@ func (s *service) handle(ctx context.Context) error { rd, wr := s.limiter.getLimiters(remoteID, c, isLAN) protoConn := protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression, s.cfg.FolderPasswords(remoteID)) + go func() { + <-protoConn.Closed() + s.dialNowDevicesMut.Lock() + s.dialNowDevices[remoteID] = struct{}{} + s.scheduleDialNow() + s.dialNowDevicesMut.Unlock() + }() l.Infof("Established secure connection to %s at %s", remoteID, c) @@ -334,7 +347,7 @@ func (s *service) handle(ctx context.Context) error { func (s *service) connect(ctx context.Context) error { // Map of when to earliest dial each given device + address again - nextDialAt := make(map[string]time.Time) + nextDialAt := make(nextDialRegistry) // Used as delay for the first few connection attempts (adjusted up to // minConnectionLoopSleep), increased exponentially until it reaches @@ -369,7 +382,7 @@ func (s *service) connect(ctx context.Context) error { // The sleep time is until the next dial scheduled in nextDialAt, // clamped by stdConnectionLoopSleep as we don't want to sleep too // long (config changes might happen). - sleep = filterAndFindSleepDuration(nextDialAt, now) + sleep = nextDialAt.sleepDurationAndCleanup(now) } // ... while making sure not to loop too quickly either. @@ -379,9 +392,20 @@ func (s *service) connect(ctx context.Context) error { l.Debugln("Next connection loop in", sleep) + timeout := time.NewTimer(sleep) select { - case <-s.deviceAddressesChanged: - case <-time.After(sleep): + case <-s.dialNow: + // Remove affected devices from nextDialAt to dial immediately, + // regardless of when we last dialed it (there's cool down in the + // registry for too many repeat dials). + s.dialNowDevicesMut.Lock() + for device := range s.dialNowDevices { + nextDialAt.redialDevice(device, now) + } + s.dialNowDevices = make(map[protocol.DeviceID]struct{}) + s.dialNowDevicesMut.Unlock() + timeout.Stop() + case <-timeout.C: case <-ctx.Done(): return ctx.Err() } @@ -401,7 +425,7 @@ func (s *service) bestDialerPriority(cfg config.Configuration) int { return bestDialerPriority } -func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Configuration, bestDialerPriority int, nextDialAt map[string]time.Time, initial bool) { +func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Configuration, bestDialerPriority int, nextDialAt nextDialRegistry, initial bool) { // Figure out current connection limits up front to see if there's any // point in resolving devices and such at all. 
allowAdditional := 0 // no limit @@ -477,7 +501,7 @@ func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Con } } -func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg config.Configuration, deviceCfg config.DeviceConfiguration, nextDialAt map[string]time.Time, initial bool, priorityCutoff int) []dialTarget { +func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg config.Configuration, deviceCfg config.DeviceConfiguration, nextDialAt nextDialRegistry, initial bool, priorityCutoff int) []dialTarget { deviceID := deviceCfg.DeviceID addrs := s.resolveDeviceAddrs(ctx, deviceCfg) @@ -485,18 +509,16 @@ func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg con dialTargets := make([]dialTarget, 0, len(addrs)) for _, addr := range addrs { - // Use a special key that is more than just the address, as you - // might have two devices connected to the same relay - nextDialKey := deviceID.String() + "/" + addr - when, ok := nextDialAt[nextDialKey] - if ok && !initial && when.After(now) { + // Use both device and address, as you might have two devices connected + // to the same relay + if !initial && nextDialAt.get(deviceID, addr).After(now) { l.Debugf("Not dialing %s via %v as it's not time yet", deviceID, addr) continue } // If we fail at any step before actually getting the dialer // retry in a minute - nextDialAt[nextDialKey] = now.Add(time.Minute) + nextDialAt.set(deviceID, addr, now.Add(time.Minute)) uri, err := url.Parse(addr) if err != nil { @@ -532,7 +554,7 @@ func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg con } dialer := dialerFactory.New(s.cfg.Options(), s.tlsCfg) - nextDialAt[nextDialKey] = now.Add(dialer.RedialFrequency()) + nextDialAt.set(deviceID, addr, now.Add(dialer.RedialFrequency())) // For LAN addresses, increase the priority so that we // try these first. @@ -735,24 +757,24 @@ func (s *service) CommitConfiguration(from, to config.Configuration) bool { } func (s *service) checkAndSignalConnectLoopOnUpdatedDevices(from, to config.Configuration) { - oldDevices := make(map[protocol.DeviceID]config.DeviceConfiguration, len(from.Devices)) - for _, dev := range from.Devices { - oldDevices[dev.DeviceID] = dev - } - + oldDevices := from.DeviceMap() for _, dev := range to.Devices { oldDev, ok := oldDevices[dev.DeviceID] if !ok || !util.EqualStrings(oldDev.Addresses, dev.Addresses) { - select { - case s.deviceAddressesChanged <- struct{}{}: - default: - // channel is blocked - a config update is already pending for the connection loop. - } + s.scheduleDialNow() break } } } +func (s *service) scheduleDialNow() { + select { + case s.dialNow <- struct{}{}: + default: + // channel is blocked - a config update is already pending for the connection loop. 
+ } +} + func (s *service) AllAddresses() []string { s.listenersMut.RLock() var addrs []string @@ -877,21 +899,6 @@ func getListenerFactory(cfg config.Configuration, uri *url.URL) (listenerFactory return listenerFactory, nil } -func filterAndFindSleepDuration(nextDialAt map[string]time.Time, now time.Time) time.Duration { - sleep := stdConnectionLoopSleep - for key, next := range nextDialAt { - if next.Before(now) { - // Expired entry, address was not seen in last pass(es) - delete(nextDialAt, key) - continue - } - if cur := next.Sub(now); cur < sleep { - sleep = cur - } - } - return sleep -} - func urlsToStrings(urls []*url.URL) []string { strings := make([]string, len(urls)) for i, url := range urls { @@ -1050,3 +1057,89 @@ func (s *service) validateIdentity(c internalConn, expectedID protocol.DeviceID) return nil } + +type nextDialRegistry map[protocol.DeviceID]nextDialDevice + +type nextDialDevice struct { + nextDial map[string]time.Time + coolDownIntervalStart time.Time + attempts int +} + +func (r nextDialRegistry) get(device protocol.DeviceID, addr string) time.Time { + return r[device].nextDial[addr] +} + +const ( + dialCoolDownInterval = 2 * time.Minute + dialCoolDownDelay = 5 * time.Minute + dialCoolDownMaxAttemps = 3 +) + +// redialDevice marks the device for immediate redial, unless the remote keeps +// dropping established connections. Thus we keep track of when the first forced +// re-dial happened, and how many attempts happen in the dialCoolDownInterval +// after that. If it's more than dialCoolDownMaxAttempts, don't force-redial +// that device for dialCoolDownDelay (regular dials still happen). +func (r nextDialRegistry) redialDevice(device protocol.DeviceID, now time.Time) { + dev, ok := r[device] + if !ok { + r[device] = nextDialDevice{ + coolDownIntervalStart: now, + attempts: 1, + } + return + } + if dev.attempts == 0 || now.Before(dev.coolDownIntervalStart.Add(dialCoolDownInterval)) { + if dev.attempts >= dialCoolDownMaxAttemps { + // Device has been force redialed too often - let it cool down. 
+ return + } + if dev.attempts == 0 { + dev.coolDownIntervalStart = now + } + dev.attempts++ + dev.nextDial = make(map[string]time.Time) + return + } + if dev.attempts >= dialCoolDownMaxAttemps && now.Before(dev.coolDownIntervalStart.Add(dialCoolDownDelay)) { + return // Still cooling down + } + delete(r, device) +} + +func (r nextDialRegistry) set(device protocol.DeviceID, addr string, next time.Time) { + if _, ok := r[device]; !ok { + r[device] = nextDialDevice{nextDial: make(map[string]time.Time)} + } + r[device].nextDial[addr] = next +} + +func (r nextDialRegistry) sleepDurationAndCleanup(now time.Time) time.Duration { + sleep := stdConnectionLoopSleep + for id, dev := range r { + for address, next := range dev.nextDial { + if next.Before(now) { + // Expired entry, address was not seen in last pass(es) + delete(dev.nextDial, address) + continue + } + if cur := next.Sub(now); cur < sleep { + sleep = cur + } + } + if dev.attempts > 0 { + interval := dialCoolDownInterval + if dev.attempts >= dialCoolDownMaxAttemps { + interval = dialCoolDownDelay + } + if now.After(dev.coolDownIntervalStart.Add(interval)) { + dev.attempts = 0 + } + } + if len(dev.nextDial) == 0 && dev.attempts == 0 { + delete(r, id) + } + } + return sleep +} diff --git a/lib/model/fakeconns_test.go b/lib/model/fakeconns_test.go index 87abe11c1..79b4e8552 100644 --- a/lib/model/fakeconns_test.go +++ b/lib/model/fakeconns_test.go @@ -27,14 +27,18 @@ func newFakeConnection(id protocol.DeviceID, model Model) *fakeConnection { Connection: new(protocolmocks.Connection), id: id, model: model, + closed: make(chan struct{}), } f.RequestCalls(func(ctx context.Context, folder, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error) { return f.fileData[name], nil }) f.IDReturns(id) f.CloseCalls(func(err error) { + f.closeOnce.Do(func() { + close(f.closed) + }) model.Closed(id, err) - f.ClosedReturns(true) + f.ClosedReturns(f.closed) }) return f } @@ -47,6 +51,8 @@ type fakeConnection struct { fileData map[string][]byte folder string model Model + closed chan struct{} + closeOnce sync.Once mut sync.Mutex } diff --git a/lib/model/model_test.go b/lib/model/model_test.go index 3921bba72..db288396d 100644 --- a/lib/model/model_test.go +++ b/lib/model/model_test.go @@ -2245,8 +2245,10 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) { t.Error("not shared with device2") } - if conn2.Closed() { + select { + case <-conn2.Closed(): t.Error("conn already closed") + default: } if _, err := wcfg.RemoveDevice(device2); err != nil { @@ -2271,7 +2273,9 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) { } } - if !conn2.Closed() { + select { + case <-conn2.Closed(): + default: t.Error("connection not closed") } diff --git a/lib/protocol/encryption.go b/lib/protocol/encryption.go index b977d2e50..2a634ce02 100644 --- a/lib/protocol/encryption.go +++ b/lib/protocol/encryption.go @@ -224,7 +224,7 @@ func (e encryptedConnection) Close(err error) { e.conn.Close(err) } -func (e encryptedConnection) Closed() bool { +func (e encryptedConnection) Closed() <-chan struct{} { return e.conn.Closed() } diff --git a/lib/protocol/mocks/connection.go b/lib/protocol/mocks/connection.go index a07ce4c2f..3aab4b8b9 100644 --- a/lib/protocol/mocks/connection.go +++ b/lib/protocol/mocks/connection.go @@ -16,15 +16,15 @@ type Connection struct { closeArgsForCall []struct { arg1 error } - ClosedStub func() bool + ClosedStub func() <-chan struct{} closedMutex sync.RWMutex closedArgsForCall []struct { } 
closedReturns struct { - result1 bool + result1 <-chan struct{} } closedReturnsOnCall map[int]struct { - result1 bool + result1 <-chan struct{} } ClusterConfigStub func(protocol.ClusterConfig) clusterConfigMutex sync.RWMutex @@ -220,7 +220,7 @@ func (fake *Connection) CloseArgsForCall(i int) error { return argsForCall.arg1 } -func (fake *Connection) Closed() bool { +func (fake *Connection) Closed() <-chan struct{} { fake.closedMutex.Lock() ret, specificReturn := fake.closedReturnsOnCall[len(fake.closedArgsForCall)] fake.closedArgsForCall = append(fake.closedArgsForCall, struct { @@ -244,32 +244,32 @@ func (fake *Connection) ClosedCallCount() int { return len(fake.closedArgsForCall) } -func (fake *Connection) ClosedCalls(stub func() bool) { +func (fake *Connection) ClosedCalls(stub func() <-chan struct{}) { fake.closedMutex.Lock() defer fake.closedMutex.Unlock() fake.ClosedStub = stub } -func (fake *Connection) ClosedReturns(result1 bool) { +func (fake *Connection) ClosedReturns(result1 <-chan struct{}) { fake.closedMutex.Lock() defer fake.closedMutex.Unlock() fake.ClosedStub = nil fake.closedReturns = struct { - result1 bool + result1 <-chan struct{} }{result1} } -func (fake *Connection) ClosedReturnsOnCall(i int, result1 bool) { +func (fake *Connection) ClosedReturnsOnCall(i int, result1 <-chan struct{}) { fake.closedMutex.Lock() defer fake.closedMutex.Unlock() fake.ClosedStub = nil if fake.closedReturnsOnCall == nil { fake.closedReturnsOnCall = make(map[int]struct { - result1 bool + result1 <-chan struct{} }) } fake.closedReturnsOnCall[i] = struct { - result1 bool + result1 <-chan struct{} }{result1} } diff --git a/lib/protocol/protocol.go b/lib/protocol/protocol.go index 4c89c76b6..28244f092 100644 --- a/lib/protocol/protocol.go +++ b/lib/protocol/protocol.go @@ -151,7 +151,7 @@ type Connection interface { ClusterConfig(config ClusterConfig) DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate) Statistics() Statistics - Closed() bool + Closed() <-chan struct{} ConnectionInfo } @@ -380,13 +380,8 @@ func (c *rawConnection) ClusterConfig(config ClusterConfig) { } } -func (c *rawConnection) Closed() bool { - select { - case <-c.closed: - return true - default: - return false - } +func (c *rawConnection) Closed() <-chan struct{} { + return c.closed } // DownloadProgress sends the progress updates for the files that are currently being downloaded.
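
Reviewer note (illustrative, not part of the patch): the gist of the change is that protocol.Connection.Closed() now returns a <-chan struct{} instead of a bool, so the connection service can block on it and force an immediate redial when a peer drops, while the new nextDialRegistry applies a cool-down (at most dialCoolDownMaxAttemps forced redials per dialCoolDownInterval, then back off for dialCoolDownDelay) so a flapping remote cannot cause a tight redial loop. The standalone Go sketch below shows only that closed-channel-plus-dialNow pattern under those assumptions; every name in it (conn, dialer, watch, loop) is hypothetical and none of the real Syncthing types or the cool-down logic appear.

package main

import (
	"fmt"
	"sync"
	"time"
)

// conn is a stand-in for protocol.Connection: Closed returns a channel that
// is closed when the connection goes away, instead of the old bool result.
type conn struct {
	id     string
	closed chan struct{}
	once   sync.Once
}

func (c *conn) Closed() <-chan struct{} { return c.closed }

func (c *conn) Close() { c.once.Do(func() { close(c.closed) }) }

// dialer mimics the service's bookkeeping: a buffered signal channel plus a
// set of device IDs that should be redialed immediately.
type dialer struct {
	mut     sync.Mutex
	dialNow chan struct{}
	devices map[string]struct{}
}

func newDialer() *dialer {
	return &dialer{
		dialNow: make(chan struct{}, 1),
		devices: make(map[string]struct{}),
	}
}

// watch starts a goroutine that waits for the connection to close and then
// schedules an immediate redial of that device, mirroring the goroutine
// added in service.handle.
func (d *dialer) watch(c *conn) {
	go func() {
		<-c.Closed()
		d.mut.Lock()
		d.devices[c.id] = struct{}{}
		d.scheduleDialNow()
		d.mut.Unlock()
	}()
}

// scheduleDialNow is non-blocking: if a signal is already pending, the
// connect loop will pick up the newly added device on its next wakeup anyway.
func (d *dialer) scheduleDialNow() {
	select {
	case d.dialNow <- struct{}{}:
	default:
	}
}

// loop is a tiny stand-in for service.connect: it sleeps until either the
// timer fires or dialNow is signalled, then dials the affected devices.
func (d *dialer) loop(stop <-chan struct{}) {
	for {
		timeout := time.NewTimer(time.Minute)
		select {
		case <-d.dialNow:
			d.mut.Lock()
			for dev := range d.devices {
				fmt.Println("redialing", dev, "immediately")
			}
			d.devices = make(map[string]struct{})
			d.mut.Unlock()
			timeout.Stop()
		case <-timeout.C:
		case <-stop:
			timeout.Stop()
			return
		}
	}
}

func main() {
	d := newDialer()
	stop := make(chan struct{})
	go d.loop(stop)

	c := &conn{id: "device-1", closed: make(chan struct{})}
	d.watch(c)
	c.Close() // dropping the connection triggers an immediate redial

	time.Sleep(100 * time.Millisecond)
	close(stop)
}

The dialNow channel is buffered with capacity one deliberately, as in the patch: scheduleDialNow never blocks, and several close events arriving before the connect loop wakes up coalesce into a single signal, after which the loop drains the whole dialNowDevices set in one pass.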