lib: Faster termination on exit (ref #6319) (#6329)

This commit is contained in:
Simon Frei 2020-02-13 14:43:00 +01:00 committed by GitHub
parent ca90f4e6af
commit c3637f2191
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 71 additions and 28 deletions

View File

@ -7,6 +7,7 @@
package main package main
import ( import (
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"flag" "flag"
@ -95,7 +96,7 @@ func checkServer(deviceID protocol.DeviceID, server string) checkResult {
}) })
go func() { go func() {
addresses, err := disco.Lookup(deviceID) addresses, err := disco.Lookup(context.Background(), deviceID)
res <- checkResult{addresses: addresses, error: err} res <- checkResult{addresses: addresses, error: err}
}() }()

View File

@ -7,6 +7,7 @@
package api package api
import ( import (
"context"
"time" "time"
"github.com/syncthing/syncthing/lib/discover" "github.com/syncthing/syncthing/lib/discover"
@ -26,7 +27,7 @@ func (m *mockedCachingMux) Stop() {
// from events.Finder // from events.Finder
func (m *mockedCachingMux) Lookup(deviceID protocol.DeviceID) (direct []string, err error) { func (m *mockedCachingMux) Lookup(ctx context.Context, deviceID protocol.DeviceID) (direct []string, err error) {
return nil, nil return nil, nil
} }

View File

@ -360,6 +360,12 @@ func (s *service) connect(ctx context.Context) {
var seen []string var seen []string
for _, deviceCfg := range cfg.Devices { for _, deviceCfg := range cfg.Devices {
select {
case <-ctx.Done():
return
default:
}
deviceID := deviceCfg.DeviceID deviceID := deviceCfg.DeviceID
if deviceID == s.myID { if deviceID == s.myID {
continue continue
@ -380,7 +386,7 @@ func (s *service) connect(ctx context.Context) {
for _, addr := range deviceCfg.Addresses { for _, addr := range deviceCfg.Addresses {
if addr == "dynamic" { if addr == "dynamic" {
if s.discoverer != nil { if s.discoverer != nil {
if t, err := s.discoverer.Lookup(deviceID); err == nil { if t, err := s.discoverer.Lookup(ctx, deviceID); err == nil {
addrs = append(addrs, t...) addrs = append(addrs, t...)
} }
} }

View File

@ -7,6 +7,7 @@
package discover package discover
import ( import (
"context"
"sort" "sort"
stdsync "sync" stdsync "sync"
"time" "time"
@ -73,7 +74,7 @@ func (m *cachingMux) Add(finder Finder, cacheTime, negCacheTime time.Duration) {
// Lookup attempts to resolve the device ID using any of the added Finders, // Lookup attempts to resolve the device ID using any of the added Finders,
// while obeying the cache settings. // while obeying the cache settings.
func (m *cachingMux) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) { func (m *cachingMux) Lookup(ctx context.Context, deviceID protocol.DeviceID) (addresses []string, err error) {
m.mut.RLock() m.mut.RLock()
for i, finder := range m.finders { for i, finder := range m.finders {
if cacheEntry, ok := m.caches[i].Get(deviceID); ok { if cacheEntry, ok := m.caches[i].Get(deviceID); ok {
@ -99,7 +100,7 @@ func (m *cachingMux) Lookup(deviceID protocol.DeviceID) (addresses []string, err
} }
// Perform the actual lookup and cache the result. // Perform the actual lookup and cache the result.
if addrs, err := finder.Lookup(deviceID); err == nil { if addrs, err := finder.Lookup(ctx, deviceID); err == nil {
l.Debugln("lookup for", deviceID, "at", finder) l.Debugln("lookup for", deviceID, "at", finder)
l.Debugln(" addresses:", addrs) l.Debugln(" addresses:", addrs)
addresses = append(addresses, addrs...) addresses = append(addresses, addrs...)

View File

@ -7,6 +7,7 @@
package discover package discover
import ( import (
"context"
"reflect" "reflect"
"testing" "testing"
"time" "time"
@ -39,7 +40,9 @@ func TestCacheUnique(t *testing.T) {
f1 := &fakeDiscovery{addresses0} f1 := &fakeDiscovery{addresses0}
c.Add(f1, time.Minute, 0) c.Add(f1, time.Minute, 0)
addr, err := c.Lookup(protocol.LocalDeviceID) ctx := context.Background()
addr, err := c.Lookup(ctx, protocol.LocalDeviceID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -53,7 +56,7 @@ func TestCacheUnique(t *testing.T) {
f2 := &fakeDiscovery{addresses1} f2 := &fakeDiscovery{addresses1}
c.Add(f2, time.Minute, 0) c.Add(f2, time.Minute, 0)
addr, err = c.Lookup(protocol.LocalDeviceID) addr, err = c.Lookup(ctx, protocol.LocalDeviceID)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -66,7 +69,7 @@ type fakeDiscovery struct {
addresses []string addresses []string
} }
func (f *fakeDiscovery) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) { func (f *fakeDiscovery) Lookup(_ context.Context, deviceID protocol.DeviceID) (addresses []string, err error) {
return f.addresses, nil return f.addresses, nil
} }
@ -96,7 +99,7 @@ func TestCacheSlowLookup(t *testing.T) {
// Start a lookup, which will take at least a second // Start a lookup, which will take at least a second
t0 := time.Now() t0 := time.Now()
go c.Lookup(protocol.LocalDeviceID) go c.Lookup(context.Background(), protocol.LocalDeviceID)
<-started // The slow lookup method has been called so we're inside the lock <-started // The slow lookup method has been called so we're inside the lock
// It should be possible to get ChildErrors while it's running // It should be possible to get ChildErrors while it's running
@ -116,7 +119,7 @@ type slowDiscovery struct {
started chan struct{} started chan struct{}
} }
func (f *slowDiscovery) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) { func (f *slowDiscovery) Lookup(_ context.Context, deviceID protocol.DeviceID) (addresses []string, err error) {
close(f.started) close(f.started)
time.Sleep(f.delay) time.Sleep(f.delay)
return nil, nil return nil, nil

View File

@ -7,6 +7,7 @@
package discover package discover
import ( import (
"context"
"time" "time"
"github.com/syncthing/syncthing/lib/protocol" "github.com/syncthing/syncthing/lib/protocol"
@ -15,7 +16,7 @@ import (
// A Finder provides lookup services of some kind. // A Finder provides lookup services of some kind.
type Finder interface { type Finder interface {
Lookup(deviceID protocol.DeviceID) (address []string, err error) Lookup(ctx context.Context, deviceID protocol.DeviceID) (address []string, err error)
Error() error Error() error
String() string String() string
Cache() map[protocol.DeviceID]CacheEntry Cache() map[protocol.DeviceID]CacheEntry

View File

@ -41,8 +41,8 @@ type globalClient struct {
} }
type httpClient interface { type httpClient interface {
Get(url string) (*http.Response, error) Get(ctx context.Context, url string) (*http.Response, error)
Post(url, ctype string, data io.Reader) (*http.Response, error) Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error)
} }
const ( const (
@ -89,7 +89,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
// The http.Client used for announcements. It needs to have our // The http.Client used for announcements. It needs to have our
// certificate to prove our identity, and may or may not verify the server // certificate to prove our identity, and may or may not verify the server
// certificate depending on the insecure setting. // certificate depending on the insecure setting.
var announceClient httpClient = &http.Client{ var announceClient httpClient = &contextClient{&http.Client{
Timeout: requestTimeout, Timeout: requestTimeout,
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: dialer.DialContext, DialContext: dialer.DialContext,
@ -99,14 +99,14 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
Certificates: []tls.Certificate{cert}, Certificates: []tls.Certificate{cert},
}, },
}, },
} }}
if opts.id != "" { if opts.id != "" {
announceClient = newIDCheckingHTTPClient(announceClient, devID) announceClient = newIDCheckingHTTPClient(announceClient, devID)
} }
// The http.Client used for queries. We don't need to present our // The http.Client used for queries. We don't need to present our
// certificate here, so lets not include it. May be insecure if requested. // certificate here, so lets not include it. May be insecure if requested.
var queryClient httpClient = &http.Client{ var queryClient httpClient = &contextClient{&http.Client{
Timeout: requestTimeout, Timeout: requestTimeout,
Transport: &http.Transport{ Transport: &http.Transport{
DialContext: dialer.DialContext, DialContext: dialer.DialContext,
@ -115,7 +115,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
InsecureSkipVerify: opts.insecure, InsecureSkipVerify: opts.insecure,
}, },
}, },
} }}
if opts.id != "" { if opts.id != "" {
queryClient = newIDCheckingHTTPClient(queryClient, devID) queryClient = newIDCheckingHTTPClient(queryClient, devID)
} }
@ -139,7 +139,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
} }
// Lookup returns the list of addresses where the given device is available // Lookup returns the list of addresses where the given device is available
func (c *globalClient) Lookup(device protocol.DeviceID) (addresses []string, err error) { func (c *globalClient) Lookup(ctx context.Context, device protocol.DeviceID) (addresses []string, err error) {
if c.noLookup { if c.noLookup {
return nil, lookupError{ return nil, lookupError{
error: errors.New("lookups not supported"), error: errors.New("lookups not supported"),
@ -156,7 +156,7 @@ func (c *globalClient) Lookup(device protocol.DeviceID) (addresses []string, err
q.Set("device", device.String()) q.Set("device", device.String())
qURL.RawQuery = q.Encode() qURL.RawQuery = q.Encode()
resp, err := c.queryClient.Get(qURL.String()) resp, err := c.queryClient.Get(ctx, qURL.String())
if err != nil { if err != nil {
l.Debugln("globalClient.Lookup", qURL, err) l.Debugln("globalClient.Lookup", qURL, err)
return nil, err return nil, err
@ -211,7 +211,7 @@ func (c *globalClient) serve(ctx context.Context) {
timer.Reset(2 * time.Second) timer.Reset(2 * time.Second)
case <-timer.C: case <-timer.C:
c.sendAnnouncement(timer) c.sendAnnouncement(ctx, timer)
case <-ctx.Done(): case <-ctx.Done():
return return
@ -219,7 +219,7 @@ func (c *globalClient) serve(ctx context.Context) {
} }
} }
func (c *globalClient) sendAnnouncement(timer *time.Timer) { func (c *globalClient) sendAnnouncement(ctx context.Context, timer *time.Timer) {
var ann announcement var ann announcement
if c.addrList != nil { if c.addrList != nil {
ann.Addresses = c.addrList.ExternalAddresses() ann.Addresses = c.addrList.ExternalAddresses()
@ -239,7 +239,7 @@ func (c *globalClient) sendAnnouncement(timer *time.Timer) {
l.Debugf("Announcement: %s", postData) l.Debugf("Announcement: %s", postData)
resp, err := c.announceClient.Post(c.server, "application/json", bytes.NewReader(postData)) resp, err := c.announceClient.Post(ctx, c.server, "application/json", bytes.NewReader(postData))
if err != nil { if err != nil {
l.Debugln("announce POST:", err) l.Debugln("announce POST:", err)
c.setError(err) c.setError(err)
@ -362,8 +362,8 @@ func (c *idCheckingHTTPClient) check(resp *http.Response) error {
return nil return nil
} }
func (c *idCheckingHTTPClient) Get(url string) (*http.Response, error) { func (c *idCheckingHTTPClient) Get(ctx context.Context, url string) (*http.Response, error) {
resp, err := c.httpClient.Get(url) resp, err := c.httpClient.Get(ctx, url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -374,8 +374,8 @@ func (c *idCheckingHTTPClient) Get(url string) (*http.Response, error) {
return resp, nil return resp, nil
} }
func (c *idCheckingHTTPClient) Post(url, ctype string, data io.Reader) (*http.Response, error) { func (c *idCheckingHTTPClient) Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error) {
resp, err := c.httpClient.Post(url, ctype, data) resp, err := c.httpClient.Post(ctx, url, ctype, data)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -403,3 +403,32 @@ func (e *errorHolder) Error() error {
e.mut.Unlock() e.mut.Unlock()
return err return err
} }
type contextClient struct {
*http.Client
}
func (c *contextClient) Get(ctx context.Context, url string) (*http.Response, error) {
// For <go1.13 compatibility. Use the following commented line once that
// isn't required anymore.
// req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
req.Cancel = ctx.Done()
return c.Client.Do(req)
}
func (c *contextClient) Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error) {
// For <go1.13 compatibility. Use the following commented line once that
// isn't required anymore.
// req, err := http.NewRequestWithContext(ctx, "POST", url, data)
req, err := http.NewRequest("POST", url, data)
if err != nil {
return nil, err
}
req.Cancel = ctx.Done()
req.Header.Set("Content-Type", ctype)
return c.Client.Do(req)
}

View File

@ -7,6 +7,7 @@
package discover package discover
import ( import (
"context"
"crypto/tls" "crypto/tls"
"io/ioutil" "io/ioutil"
"net" "net"
@ -225,7 +226,7 @@ func testLookup(url string) ([]string, error) {
go disco.Serve() go disco.Serve()
defer disco.Stop() defer disco.Stop()
return disco.Lookup(protocol.LocalDeviceID) return disco.Lookup(context.Background(), protocol.LocalDeviceID)
} }
type fakeDiscoveryServer struct { type fakeDiscoveryServer struct {

View File

@ -91,7 +91,7 @@ func NewLocal(id protocol.DeviceID, addr string, addrList AddressLister, evLogge
} }
// Lookup returns a list of addresses the device is available at. // Lookup returns a list of addresses the device is available at.
func (c *localClient) Lookup(device protocol.DeviceID) (addresses []string, err error) { func (c *localClient) Lookup(_ context.Context, device protocol.DeviceID) (addresses []string, err error) {
if cache, ok := c.Get(device); ok { if cache, ok := c.Get(device); ok {
if time.Since(cache.when) < CacheLifeTime { if time.Since(cache.when) < CacheLifeTime {
addresses = cache.Addresses addresses = cache.Addresses