lib: Faster termination on exit (ref #6319) (#6329)

This commit is contained in:
Simon Frei 2020-02-13 14:43:00 +01:00 committed by GitHub
parent ca90f4e6af
commit c3637f2191
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 71 additions and 28 deletions

View File

@ -7,6 +7,7 @@
package main
import (
"context"
"crypto/tls"
"errors"
"flag"
@ -95,7 +96,7 @@ func checkServer(deviceID protocol.DeviceID, server string) checkResult {
})
go func() {
addresses, err := disco.Lookup(deviceID)
addresses, err := disco.Lookup(context.Background(), deviceID)
res <- checkResult{addresses: addresses, error: err}
}()

View File

@ -7,6 +7,7 @@
package api
import (
"context"
"time"
"github.com/syncthing/syncthing/lib/discover"
@ -26,7 +27,7 @@ func (m *mockedCachingMux) Stop() {
// from events.Finder
func (m *mockedCachingMux) Lookup(deviceID protocol.DeviceID) (direct []string, err error) {
func (m *mockedCachingMux) Lookup(ctx context.Context, deviceID protocol.DeviceID) (direct []string, err error) {
return nil, nil
}

View File

@ -360,6 +360,12 @@ func (s *service) connect(ctx context.Context) {
var seen []string
for _, deviceCfg := range cfg.Devices {
select {
case <-ctx.Done():
return
default:
}
deviceID := deviceCfg.DeviceID
if deviceID == s.myID {
continue
@ -380,7 +386,7 @@ func (s *service) connect(ctx context.Context) {
for _, addr := range deviceCfg.Addresses {
if addr == "dynamic" {
if s.discoverer != nil {
if t, err := s.discoverer.Lookup(deviceID); err == nil {
if t, err := s.discoverer.Lookup(ctx, deviceID); err == nil {
addrs = append(addrs, t...)
}
}

View File

@ -7,6 +7,7 @@
package discover
import (
"context"
"sort"
stdsync "sync"
"time"
@ -73,7 +74,7 @@ func (m *cachingMux) Add(finder Finder, cacheTime, negCacheTime time.Duration) {
// Lookup attempts to resolve the device ID using any of the added Finders,
// while obeying the cache settings.
func (m *cachingMux) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) {
func (m *cachingMux) Lookup(ctx context.Context, deviceID protocol.DeviceID) (addresses []string, err error) {
m.mut.RLock()
for i, finder := range m.finders {
if cacheEntry, ok := m.caches[i].Get(deviceID); ok {
@ -99,7 +100,7 @@ func (m *cachingMux) Lookup(deviceID protocol.DeviceID) (addresses []string, err
}
// Perform the actual lookup and cache the result.
if addrs, err := finder.Lookup(deviceID); err == nil {
if addrs, err := finder.Lookup(ctx, deviceID); err == nil {
l.Debugln("lookup for", deviceID, "at", finder)
l.Debugln(" addresses:", addrs)
addresses = append(addresses, addrs...)

View File

@ -7,6 +7,7 @@
package discover
import (
"context"
"reflect"
"testing"
"time"
@ -39,7 +40,9 @@ func TestCacheUnique(t *testing.T) {
f1 := &fakeDiscovery{addresses0}
c.Add(f1, time.Minute, 0)
addr, err := c.Lookup(protocol.LocalDeviceID)
ctx := context.Background()
addr, err := c.Lookup(ctx, protocol.LocalDeviceID)
if err != nil {
t.Fatal(err)
}
@ -53,7 +56,7 @@ func TestCacheUnique(t *testing.T) {
f2 := &fakeDiscovery{addresses1}
c.Add(f2, time.Minute, 0)
addr, err = c.Lookup(protocol.LocalDeviceID)
addr, err = c.Lookup(ctx, protocol.LocalDeviceID)
if err != nil {
t.Fatal(err)
}
@ -66,7 +69,7 @@ type fakeDiscovery struct {
addresses []string
}
func (f *fakeDiscovery) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) {
func (f *fakeDiscovery) Lookup(_ context.Context, deviceID protocol.DeviceID) (addresses []string, err error) {
return f.addresses, nil
}
@ -96,7 +99,7 @@ func TestCacheSlowLookup(t *testing.T) {
// Start a lookup, which will take at least a second
t0 := time.Now()
go c.Lookup(protocol.LocalDeviceID)
go c.Lookup(context.Background(), protocol.LocalDeviceID)
<-started // The slow lookup method has been called so we're inside the lock
// It should be possible to get ChildErrors while it's running
@ -116,7 +119,7 @@ type slowDiscovery struct {
started chan struct{}
}
func (f *slowDiscovery) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) {
func (f *slowDiscovery) Lookup(_ context.Context, deviceID protocol.DeviceID) (addresses []string, err error) {
close(f.started)
time.Sleep(f.delay)
return nil, nil

View File

@ -7,6 +7,7 @@
package discover
import (
"context"
"time"
"github.com/syncthing/syncthing/lib/protocol"
@ -15,7 +16,7 @@ import (
// A Finder provides lookup services of some kind.
type Finder interface {
Lookup(deviceID protocol.DeviceID) (address []string, err error)
Lookup(ctx context.Context, deviceID protocol.DeviceID) (address []string, err error)
Error() error
String() string
Cache() map[protocol.DeviceID]CacheEntry

View File

@ -41,8 +41,8 @@ type globalClient struct {
}
type httpClient interface {
Get(url string) (*http.Response, error)
Post(url, ctype string, data io.Reader) (*http.Response, error)
Get(ctx context.Context, url string) (*http.Response, error)
Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error)
}
const (
@ -89,7 +89,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
// The http.Client used for announcements. It needs to have our
// certificate to prove our identity, and may or may not verify the server
// certificate depending on the insecure setting.
var announceClient httpClient = &http.Client{
var announceClient httpClient = &contextClient{&http.Client{
Timeout: requestTimeout,
Transport: &http.Transport{
DialContext: dialer.DialContext,
@ -99,14 +99,14 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
Certificates: []tls.Certificate{cert},
},
},
}
}}
if opts.id != "" {
announceClient = newIDCheckingHTTPClient(announceClient, devID)
}
// The http.Client used for queries. We don't need to present our
// certificate here, so lets not include it. May be insecure if requested.
var queryClient httpClient = &http.Client{
var queryClient httpClient = &contextClient{&http.Client{
Timeout: requestTimeout,
Transport: &http.Transport{
DialContext: dialer.DialContext,
@ -115,7 +115,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
InsecureSkipVerify: opts.insecure,
},
},
}
}}
if opts.id != "" {
queryClient = newIDCheckingHTTPClient(queryClient, devID)
}
@ -139,7 +139,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo
}
// Lookup returns the list of addresses where the given device is available
func (c *globalClient) Lookup(device protocol.DeviceID) (addresses []string, err error) {
func (c *globalClient) Lookup(ctx context.Context, device protocol.DeviceID) (addresses []string, err error) {
if c.noLookup {
return nil, lookupError{
error: errors.New("lookups not supported"),
@ -156,7 +156,7 @@ func (c *globalClient) Lookup(device protocol.DeviceID) (addresses []string, err
q.Set("device", device.String())
qURL.RawQuery = q.Encode()
resp, err := c.queryClient.Get(qURL.String())
resp, err := c.queryClient.Get(ctx, qURL.String())
if err != nil {
l.Debugln("globalClient.Lookup", qURL, err)
return nil, err
@ -211,7 +211,7 @@ func (c *globalClient) serve(ctx context.Context) {
timer.Reset(2 * time.Second)
case <-timer.C:
c.sendAnnouncement(timer)
c.sendAnnouncement(ctx, timer)
case <-ctx.Done():
return
@ -219,7 +219,7 @@ func (c *globalClient) serve(ctx context.Context) {
}
}
func (c *globalClient) sendAnnouncement(timer *time.Timer) {
func (c *globalClient) sendAnnouncement(ctx context.Context, timer *time.Timer) {
var ann announcement
if c.addrList != nil {
ann.Addresses = c.addrList.ExternalAddresses()
@ -239,7 +239,7 @@ func (c *globalClient) sendAnnouncement(timer *time.Timer) {
l.Debugf("Announcement: %s", postData)
resp, err := c.announceClient.Post(c.server, "application/json", bytes.NewReader(postData))
resp, err := c.announceClient.Post(ctx, c.server, "application/json", bytes.NewReader(postData))
if err != nil {
l.Debugln("announce POST:", err)
c.setError(err)
@ -362,8 +362,8 @@ func (c *idCheckingHTTPClient) check(resp *http.Response) error {
return nil
}
func (c *idCheckingHTTPClient) Get(url string) (*http.Response, error) {
resp, err := c.httpClient.Get(url)
func (c *idCheckingHTTPClient) Get(ctx context.Context, url string) (*http.Response, error) {
resp, err := c.httpClient.Get(ctx, url)
if err != nil {
return nil, err
}
@ -374,8 +374,8 @@ func (c *idCheckingHTTPClient) Get(url string) (*http.Response, error) {
return resp, nil
}
func (c *idCheckingHTTPClient) Post(url, ctype string, data io.Reader) (*http.Response, error) {
resp, err := c.httpClient.Post(url, ctype, data)
func (c *idCheckingHTTPClient) Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error) {
resp, err := c.httpClient.Post(ctx, url, ctype, data)
if err != nil {
return nil, err
}
@ -403,3 +403,32 @@ func (e *errorHolder) Error() error {
e.mut.Unlock()
return err
}
type contextClient struct {
*http.Client
}
func (c *contextClient) Get(ctx context.Context, url string) (*http.Response, error) {
// For <go1.13 compatibility. Use the following commented line once that
// isn't required anymore.
// req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
req, err := http.NewRequest("GET", url, nil)
if err != nil {
return nil, err
}
req.Cancel = ctx.Done()
return c.Client.Do(req)
}
func (c *contextClient) Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error) {
// For <go1.13 compatibility. Use the following commented line once that
// isn't required anymore.
// req, err := http.NewRequestWithContext(ctx, "POST", url, data)
req, err := http.NewRequest("POST", url, data)
if err != nil {
return nil, err
}
req.Cancel = ctx.Done()
req.Header.Set("Content-Type", ctype)
return c.Client.Do(req)
}

View File

@ -7,6 +7,7 @@
package discover
import (
"context"
"crypto/tls"
"io/ioutil"
"net"
@ -225,7 +226,7 @@ func testLookup(url string) ([]string, error) {
go disco.Serve()
defer disco.Stop()
return disco.Lookup(protocol.LocalDeviceID)
return disco.Lookup(context.Background(), protocol.LocalDeviceID)
}
type fakeDiscoveryServer struct {

View File

@ -91,7 +91,7 @@ func NewLocal(id protocol.DeviceID, addr string, addrList AddressLister, evLogge
}
// Lookup returns a list of addresses the device is available at.
func (c *localClient) Lookup(device protocol.DeviceID) (addresses []string, err error) {
func (c *localClient) Lookup(_ context.Context, device protocol.DeviceID) (addresses []string, err error) {
if cache, ok := c.Get(device); ok {
if time.Since(cache.when) < CacheLifeTime {
addresses = cache.Addresses