lib/connections: Trigger dialer when connection gets closed (#7753)

Simon Frei 2021-06-17 13:57:44 +02:00 committed by GitHub
parent aeca1fb575
commit 857caf3637
7 changed files with 231 additions and 68 deletions
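
In short: every connection established in handle() now gets a goroutine that waits on the connection's Closed channel; when it fires, the remote device is recorded in dialNowDevices and the connect loop is woken via dialNow, so the device is re-dialed immediately instead of waiting for the next sweep. A nextDialRegistry with a cooldown keeps this from hammering a remote that repeatedly drops connections. Below is a minimal, standalone sketch of that wake-up pattern; the names (redialer, connectionClosed, loop) are illustrative and not part of the diff.

package main

import (
	"fmt"
	"sync"
	"time"
)

// redialer mimics the dialNow mechanism added in service.go below:
// closed connections are recorded and the dial loop is woken immediately.
type redialer struct {
	mut     sync.Mutex
	pending map[string]struct{} // devices whose connection just closed
	dialNow chan struct{}       // capacity 1: at most one wake-up is queued
}

func (r *redialer) connectionClosed(device string) {
	r.mut.Lock()
	r.pending[device] = struct{}{}
	r.mut.Unlock()
	select { // non-blocking: a wake-up may already be pending
	case r.dialNow <- struct{}{}:
	default:
	}
}

func (r *redialer) loop(stop <-chan struct{}) {
	for {
		select {
		case <-r.dialNow:
			r.mut.Lock()
			for dev := range r.pending {
				fmt.Println("re-dialing", dev, "immediately")
			}
			r.pending = make(map[string]struct{})
			r.mut.Unlock()
		case <-time.After(time.Minute): // stands in for the regular dial sweep
		case <-stop:
			return
		}
	}
}

func main() {
	r := &redialer{pending: make(map[string]struct{}), dialNow: make(chan struct{}, 1)}
	stop := make(chan struct{})
	go r.loop(stop)
	r.connectionClosed("device-A")
	time.Sleep(100 * time.Millisecond)
	close(stop)
}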


@@ -231,6 +231,71 @@ func TestConnectionStatus(t *testing.T) {
check(nil, nil)
}
func TestNextDialRegistryCleanup(t *testing.T) {
now := time.Now()
firsts := []time.Time{
now.Add(-dialCoolDownInterval + time.Second),
now.Add(-dialCoolDownDelay + time.Second),
now.Add(-2 * dialCoolDownDelay),
}
r := make(nextDialRegistry)
// Cases where the device should be cleaned up
r[protocol.LocalDeviceID] = nextDialDevice{}
r.sleepDurationAndCleanup(now)
if l := len(r); l > 0 {
t.Errorf("Expected empty to be cleaned up, got length %v", l)
}
for _, dev := range []nextDialDevice{
// attempts below threshold, outside of interval
{
attempts: 1,
coolDownIntervalStart: firsts[1],
},
{
attempts: 1,
coolDownIntervalStart: firsts[2],
},
// Threshold reached, but outside of cooldown delay
{
attempts: dialCoolDownMaxAttemps,
coolDownIntervalStart: firsts[2],
},
} {
r[protocol.LocalDeviceID] = dev
r.sleepDurationAndCleanup(now)
if l := len(r); l > 0 {
t.Errorf("attempts: %v, start: %v: Expected all cleaned up, got length %v", dev.attempts, dev.coolDownIntervalStart, l)
}
}
// Cases where the device should stay monitored
for _, dev := range []nextDialDevice{
// attempts below threshold, inside of interval
{
attempts: 1,
coolDownIntervalStart: firsts[0],
},
// attempts at threshold, inside delay
{
attempts: dialCoolDownMaxAttemps,
coolDownIntervalStart: firsts[0],
},
{
attempts: dialCoolDownMaxAttemps,
coolDownIntervalStart: firsts[1],
},
} {
r[protocol.LocalDeviceID] = dev
r.sleepDurationAndCleanup(now)
if l := len(r); l != 1 {
t.Errorf("attempts: %v, start: %v: Expected device still tracked, got length %v", dev.attempts, dev.coolDownIntervalStart, l)
}
}
}
func BenchmarkConnections(pb *testing.B) {
addrs := []string{
"tcp://127.0.0.1:0",


@@ -142,10 +142,13 @@ type service struct {
natService *nat.Service
evLogger events.Logger
deviceAddressesChanged chan struct{}
listenersMut sync.RWMutex
listeners map[string]genericListener
listenerTokens map[string]suture.ServiceToken
dialNow chan struct{}
dialNowDevices map[protocol.DeviceID]struct{}
dialNowDevicesMut sync.Mutex
listenersMut sync.RWMutex
listeners map[string]genericListener
listenerTokens map[string]suture.ServiceToken
}
func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *tls.Config, discoverer discover.Finder, bepProtocolName string, tlsDefaultCommonName string, evLogger events.Logger) Service {
@@ -166,10 +169,13 @@ func NewService(cfg config.Wrapper, myID protocol.DeviceID, mdl Model, tlsCfg *t
natService: nat.NewService(myID, cfg),
evLogger: evLogger,
deviceAddressesChanged: make(chan struct{}, 1),
listenersMut: sync.NewRWMutex(),
listeners: make(map[string]genericListener),
listenerTokens: make(map[string]suture.ServiceToken),
dialNowDevicesMut: sync.NewMutex(),
dialNow: make(chan struct{}, 1),
dialNowDevices: make(map[protocol.DeviceID]struct{}),
listenersMut: sync.NewRWMutex(),
listeners: make(map[string]genericListener),
listenerTokens: make(map[string]suture.ServiceToken),
}
cfg.Subscribe(service)
@@ -324,6 +330,13 @@ func (s *service) handle(ctx context.Context) error {
rd, wr := s.limiter.getLimiters(remoteID, c, isLAN)
protoConn := protocol.NewConnection(remoteID, rd, wr, c, s.model, c, deviceCfg.Compression, s.cfg.FolderPasswords(remoteID))
go func() {
<-protoConn.Closed()
s.dialNowDevicesMut.Lock()
s.dialNowDevices[remoteID] = struct{}{}
s.scheduleDialNow()
s.dialNowDevicesMut.Unlock()
}()
l.Infof("Established secure connection to %s at %s", remoteID, c)
@@ -334,7 +347,7 @@ func (s *service) handle(ctx context.Context) error {
func (s *service) connect(ctx context.Context) error {
// Map of when to earliest dial each given device + address again
nextDialAt := make(map[string]time.Time)
nextDialAt := make(nextDialRegistry)
// Used as delay for the first few connection attempts (adjusted up to
// minConnectionLoopSleep), increased exponentially until it reaches
@@ -369,7 +382,7 @@ func (s *service) connect(ctx context.Context) error {
// The sleep time is until the next dial scheduled in nextDialAt,
// clamped by stdConnectionLoopSleep as we don't want to sleep too
// long (config changes might happen).
sleep = filterAndFindSleepDuration(nextDialAt, now)
sleep = nextDialAt.sleepDurationAndCleanup(now)
}
// ... while making sure not to loop too quickly either.
@@ -379,9 +392,20 @@ func (s *service) connect(ctx context.Context) error {
l.Debugln("Next connection loop in", sleep)
timeout := time.NewTimer(sleep)
select {
case <-s.deviceAddressesChanged:
case <-time.After(sleep):
case <-s.dialNow:
// Remove affected devices from nextDialAt to dial immediately,
// regardless of when we last dialed them (there's a cooldown in the
// registry for too many repeat dials).
s.dialNowDevicesMut.Lock()
for device := range s.dialNowDevices {
nextDialAt.redialDevice(device, now)
}
s.dialNowDevices = make(map[protocol.DeviceID]struct{})
s.dialNowDevicesMut.Unlock()
timeout.Stop()
case <-timeout.C:
case <-ctx.Done():
return ctx.Err()
}
@@ -401,7 +425,7 @@ func (s *service) bestDialerPriority(cfg config.Configuration) int {
return bestDialerPriority
}
func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Configuration, bestDialerPriority int, nextDialAt map[string]time.Time, initial bool) {
func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Configuration, bestDialerPriority int, nextDialAt nextDialRegistry, initial bool) {
// Figure out current connection limits up front to see if there's any
// point in resolving devices and such at all.
allowAdditional := 0 // no limit
@@ -477,7 +501,7 @@ func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Con
}
}
func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg config.Configuration, deviceCfg config.DeviceConfiguration, nextDialAt map[string]time.Time, initial bool, priorityCutoff int) []dialTarget {
func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg config.Configuration, deviceCfg config.DeviceConfiguration, nextDialAt nextDialRegistry, initial bool, priorityCutoff int) []dialTarget {
deviceID := deviceCfg.DeviceID
addrs := s.resolveDeviceAddrs(ctx, deviceCfg)
@@ -485,18 +509,16 @@ func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg con
dialTargets := make([]dialTarget, 0, len(addrs))
for _, addr := range addrs {
// Use a special key that is more than just the address, as you
// might have two devices connected to the same relay
nextDialKey := deviceID.String() + "/" + addr
when, ok := nextDialAt[nextDialKey]
if ok && !initial && when.After(now) {
// Use both device and address, as you might have two devices connected
// to the same relay
if !initial && nextDialAt.get(deviceID, addr).After(now) {
l.Debugf("Not dialing %s via %v as it's not time yet", deviceID, addr)
continue
}
// If we fail at any step before actually getting the dialer
// retry in a minute
nextDialAt[nextDialKey] = now.Add(time.Minute)
nextDialAt.set(deviceID, addr, now.Add(time.Minute))
uri, err := url.Parse(addr)
if err != nil {
@@ -532,7 +554,7 @@ func (s *service) resolveDialTargets(ctx context.Context, now time.Time, cfg con
}
dialer := dialerFactory.New(s.cfg.Options(), s.tlsCfg)
nextDialAt[nextDialKey] = now.Add(dialer.RedialFrequency())
nextDialAt.set(deviceID, addr, now.Add(dialer.RedialFrequency()))
// For LAN addresses, increase the priority so that we
// try these first.
@@ -735,24 +757,24 @@ func (s *service) CommitConfiguration(from, to config.Configuration) bool {
}
func (s *service) checkAndSignalConnectLoopOnUpdatedDevices(from, to config.Configuration) {
oldDevices := make(map[protocol.DeviceID]config.DeviceConfiguration, len(from.Devices))
for _, dev := range from.Devices {
oldDevices[dev.DeviceID] = dev
}
oldDevices := from.DeviceMap()
for _, dev := range to.Devices {
oldDev, ok := oldDevices[dev.DeviceID]
if !ok || !util.EqualStrings(oldDev.Addresses, dev.Addresses) {
select {
case s.deviceAddressesChanged <- struct{}{}:
default:
// channel is blocked - a config update is already pending for the connection loop.
}
s.scheduleDialNow()
break
}
}
}
func (s *service) scheduleDialNow() {
select {
case s.dialNow <- struct{}{}:
default:
// channel is blocked - a config update is already pending for the connection loop.
}
}
func (s *service) AllAddresses() []string {
s.listenersMut.RLock()
var addrs []string
@@ -877,21 +899,6 @@ func getListenerFactory(cfg config.Configuration, uri *url.URL) (listenerFactory
return listenerFactory, nil
}
func filterAndFindSleepDuration(nextDialAt map[string]time.Time, now time.Time) time.Duration {
sleep := stdConnectionLoopSleep
for key, next := range nextDialAt {
if next.Before(now) {
// Expired entry, address was not seen in last pass(es)
delete(nextDialAt, key)
continue
}
if cur := next.Sub(now); cur < sleep {
sleep = cur
}
}
return sleep
}
func urlsToStrings(urls []*url.URL) []string {
strings := make([]string, len(urls))
for i, url := range urls {
@@ -1050,3 +1057,89 @@ func (s *service) validateIdentity(c internalConn, expectedID protocol.DeviceID)
return nil
}
type nextDialRegistry map[protocol.DeviceID]nextDialDevice
type nextDialDevice struct {
nextDial map[string]time.Time
coolDownIntervalStart time.Time
attempts int
}
func (r nextDialRegistry) get(device protocol.DeviceID, addr string) time.Time {
return r[device].nextDial[addr]
}
const (
dialCoolDownInterval = 2 * time.Minute
dialCoolDownDelay = 5 * time.Minute
dialCoolDownMaxAttemps = 3
)
// redialDevice marks the device for immediate redial, unless the remote keeps
// dropping established connections. Thus we keep track of when the first forced
// re-dial happened, and how many attempts happen in the dialCoolDownInterval
// after that. If it's more than dialCoolDownMaxAttempts, don't force-redial
// that device for dialCoolDownDelay (regular dials still happen).
func (r nextDialRegistry) redialDevice(device protocol.DeviceID, now time.Time) {
dev, ok := r[device]
if !ok {
r[device] = nextDialDevice{
coolDownIntervalStart: now,
attempts: 1,
}
return
}
if dev.attempts == 0 || now.Before(dev.coolDownIntervalStart.Add(dialCoolDownInterval)) {
if dev.attempts >= dialCoolDownMaxAttemps {
// Device has been force redialed too often - let it cool down.
return
}
if dev.attempts == 0 {
dev.coolDownIntervalStart = now
}
dev.attempts++
dev.nextDial = make(map[string]time.Time)
r[device] = dev
return
}
if dev.attempts >= dialCoolDownMaxAttemps && now.Before(dev.coolDownIntervalStart.Add(dialCoolDownDelay)) {
return // Still cooling down
}
delete(r, device)
}
func (r nextDialRegistry) set(device protocol.DeviceID, addr string, next time.Time) {
if _, ok := r[device]; !ok {
r[device] = nextDialDevice{nextDial: make(map[string]time.Time)}
}
r[device].nextDial[addr] = next
}
func (r nextDialRegistry) sleepDurationAndCleanup(now time.Time) time.Duration {
sleep := stdConnectionLoopSleep
for id, dev := range r {
for address, next := range dev.nextDial {
if next.Before(now) {
// Expired entry, address was not seen in last pass(es)
delete(dev.nextDial, address)
continue
}
if cur := next.Sub(now); cur < sleep {
sleep = cur
}
}
if dev.attempts > 0 {
interval := dialCoolDownInterval
if dev.attempts >= dialCoolDownMaxAttemps {
interval = dialCoolDownDelay
}
if now.After(dev.coolDownIntervalStart.Add(interval)) {
dev.attempts = 0
r[id] = dev
}
}
if len(dev.nextDial) == 0 && dev.attempts == 0 {
delete(r, id)
}
}
return sleep
}
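
To make the cooldown above concrete: within dialCoolDownInterval a device can be force-redialed up to dialCoolDownMaxAttemps times, each call wiping its per-address schedule so it is dialed right away; after that, forced redials are ignored until dialCoolDownDelay has passed. A hypothetical companion test, not part of this commit, assuming the lib/connections package scope like TestNextDialRegistryCleanup above:

func TestNextDialRegistryCoolDownSketch(t *testing.T) {
	r := make(nextDialRegistry)
	dev := protocol.LocalDeviceID
	addr := "tcp://192.0.2.1:22000" // illustrative address
	now := time.Now()

	for i := 0; i < dialCoolDownMaxAttemps; i++ {
		// Pretend a regular dial is scheduled a minute from now...
		r.set(dev, addr, now.Add(time.Minute))
		// ...then force a redial: the schedule is wiped, so the device
		// would be dialed immediately on the next pass.
		r.redialDevice(dev, now)
		if r.get(dev, addr).After(now) {
			t.Errorf("attempt %d: expected schedule to be cleared", i+1)
		}
	}

	// The threshold is reached: a further forced redial within the cooldown
	// is ignored and the schedule stays untouched.
	r.set(dev, addr, now.Add(time.Minute))
	r.redialDevice(dev, now.Add(time.Minute))
	if !r.get(dev, addr).After(now) {
		t.Error("expected forced redial to be ignored during cooldown")
	}
}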


@@ -27,14 +27,18 @@ func newFakeConnection(id protocol.DeviceID, model Model) *fakeConnection {
Connection: new(protocolmocks.Connection),
id: id,
model: model,
closed: make(chan struct{}),
}
f.RequestCalls(func(ctx context.Context, folder, name string, blockNo int, offset int64, size int, hash []byte, weakHash uint32, fromTemporary bool) ([]byte, error) {
return f.fileData[name], nil
})
f.IDReturns(id)
f.CloseCalls(func(err error) {
f.closeOnce.Do(func() {
close(f.closed)
})
model.Closed(id, err)
f.ClosedReturns(true)
f.ClosedReturns(f.closed)
})
return f
}
@@ -47,6 +51,8 @@ type fakeConnection struct {
fileData map[string][]byte
folder string
model Model
closed chan struct{}
closeOnce sync.Once
mut sync.Mutex
}
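
The fake connection above uses the standard Go idiom for a broadcast close signal: a channel that is closed exactly once (guarded by sync.Once), so any number of goroutines can wait on Closed(). A minimal standalone illustration, with made-up names:

package main

import (
	"fmt"
	"sync"
)

type conn struct {
	closed    chan struct{}
	closeOnce sync.Once
}

func newConn() *conn {
	return &conn{closed: make(chan struct{})}
}

// Close may be called any number of times; the channel is closed only once.
func (c *conn) Close() {
	c.closeOnce.Do(func() { close(c.closed) })
}

// Closed returns a channel that is closed once the connection goes away,
// mirroring the new Connection interface below.
func (c *conn) Closed() <-chan struct{} {
	return c.closed
}

func main() {
	c := newConn()
	var wg sync.WaitGroup
	for i := 0; i < 3; i++ {
		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			<-c.Closed() // all waiters unblock on the single close
			fmt.Println("waiter", i, "saw the close")
		}(i)
	}
	c.Close()
	c.Close() // safe: sync.Once prevents a double close panic
	wg.Wait()
}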


@@ -2245,8 +2245,10 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
t.Error("not shared with device2")
}
if conn2.Closed() {
select {
case <-conn2.Closed():
t.Error("conn already closed")
default:
}
if _, err := wcfg.RemoveDevice(device2); err != nil {
@@ -2271,7 +2273,9 @@ func TestSharedWithClearedOnDisconnect(t *testing.T) {
}
}
if !conn2.Closed() {
select {
case <-conn2.Closed():
default:
t.Error("connection not closed")
}


@@ -224,7 +224,7 @@ func (e encryptedConnection) Close(err error) {
e.conn.Close(err)
}
func (e encryptedConnection) Closed() bool {
func (e encryptedConnection) Closed() <-chan struct{} {
return e.conn.Closed()
}


@@ -16,15 +16,15 @@ type Connection struct {
closeArgsForCall []struct {
arg1 error
}
ClosedStub func() bool
ClosedStub func() <-chan struct{}
closedMutex sync.RWMutex
closedArgsForCall []struct {
}
closedReturns struct {
result1 bool
result1 <-chan struct{}
}
closedReturnsOnCall map[int]struct {
result1 bool
result1 <-chan struct{}
}
ClusterConfigStub func(protocol.ClusterConfig)
clusterConfigMutex sync.RWMutex
@@ -220,7 +220,7 @@ func (fake *Connection) CloseArgsForCall(i int) error {
return argsForCall.arg1
}
func (fake *Connection) Closed() bool {
func (fake *Connection) Closed() <-chan struct{} {
fake.closedMutex.Lock()
ret, specificReturn := fake.closedReturnsOnCall[len(fake.closedArgsForCall)]
fake.closedArgsForCall = append(fake.closedArgsForCall, struct {
@@ -244,32 +244,32 @@ func (fake *Connection) ClosedCallCount() int {
return len(fake.closedArgsForCall)
}
func (fake *Connection) ClosedCalls(stub func() bool) {
func (fake *Connection) ClosedCalls(stub func() <-chan struct{}) {
fake.closedMutex.Lock()
defer fake.closedMutex.Unlock()
fake.ClosedStub = stub
}
func (fake *Connection) ClosedReturns(result1 bool) {
func (fake *Connection) ClosedReturns(result1 <-chan struct{}) {
fake.closedMutex.Lock()
defer fake.closedMutex.Unlock()
fake.ClosedStub = nil
fake.closedReturns = struct {
result1 bool
result1 <-chan struct{}
}{result1}
}
func (fake *Connection) ClosedReturnsOnCall(i int, result1 bool) {
func (fake *Connection) ClosedReturnsOnCall(i int, result1 <-chan struct{}) {
fake.closedMutex.Lock()
defer fake.closedMutex.Unlock()
fake.ClosedStub = nil
if fake.closedReturnsOnCall == nil {
fake.closedReturnsOnCall = make(map[int]struct {
result1 bool
result1 <-chan struct{}
})
}
fake.closedReturnsOnCall[i] = struct {
result1 bool
result1 <-chan struct{}
}{result1}
}
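
With the regenerated mock, tests stub Closed with a channel instead of a bool and simulate a dropped connection by closing it (the fake connection above does the same with f.closed). A hypothetical snippet, assuming the protocolmocks import already used in the tests:

func TestMockClosedSketch(t *testing.T) {
	fake := new(protocolmocks.Connection)
	closed := make(chan struct{})
	fake.ClosedReturns(closed)

	// Hand fake to the code under test, then simulate the connection dropping.
	close(closed) // everything waiting on fake.Closed() now unblocks

	select {
	case <-fake.Closed():
	default:
		t.Error("expected Closed() to report the closed channel")
	}
}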


@@ -151,7 +151,7 @@ type Connection interface {
ClusterConfig(config ClusterConfig)
DownloadProgress(ctx context.Context, folder string, updates []FileDownloadProgressUpdate)
Statistics() Statistics
Closed() bool
Closed() <-chan struct{}
ConnectionInfo
}
@@ -380,13 +380,8 @@ func (c *rawConnection) ClusterConfig(config ClusterConfig) {
}
}
func (c *rawConnection) Closed() bool {
select {
case <-c.closed:
return true
default:
return false
}
func (c *rawConnection) Closed() <-chan struct{} {
return c.closed
}
// DownloadProgress sends the progress updates for the files that are currently being downloaded.
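
With Closed() now returning a channel, callers either block until the connection goes away (as the new goroutine in lib/connections does) or do a non-blocking check (as the model test changes do). A short sketch of both patterns against the Connection interface; the helper names are illustrative only:

// Illustrative helpers, not part of the package.
func waitClosed(c Connection) {
	<-c.Closed() // blocks until the connection is closed
}

func isClosed(c Connection) bool {
	select {
	case <-c.Closed():
		return true
	default: // replaces the old Closed() bool polling style
		return false
	}
}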