diff --git a/cmd/stfinddevice/main.go b/cmd/stfinddevice/main.go index 7790ba0e6..688c29b2c 100644 --- a/cmd/stfinddevice/main.go +++ b/cmd/stfinddevice/main.go @@ -7,6 +7,7 @@ package main import ( + "context" "crypto/tls" "errors" "flag" @@ -95,7 +96,7 @@ func checkServer(deviceID protocol.DeviceID, server string) checkResult { }) go func() { - addresses, err := disco.Lookup(deviceID) + addresses, err := disco.Lookup(context.Background(), deviceID) res <- checkResult{addresses: addresses, error: err} }() diff --git a/lib/api/mocked_discovery_test.go b/lib/api/mocked_discovery_test.go index 885dfc804..eb0c11a20 100644 --- a/lib/api/mocked_discovery_test.go +++ b/lib/api/mocked_discovery_test.go @@ -7,6 +7,7 @@ package api import ( + "context" "time" "github.com/syncthing/syncthing/lib/discover" @@ -26,7 +27,7 @@ func (m *mockedCachingMux) Stop() { // from events.Finder -func (m *mockedCachingMux) Lookup(deviceID protocol.DeviceID) (direct []string, err error) { +func (m *mockedCachingMux) Lookup(ctx context.Context, deviceID protocol.DeviceID) (direct []string, err error) { return nil, nil } diff --git a/lib/connections/service.go b/lib/connections/service.go index 69542d23a..0fce09d2d 100644 --- a/lib/connections/service.go +++ b/lib/connections/service.go @@ -360,6 +360,12 @@ func (s *service) connect(ctx context.Context) { var seen []string for _, deviceCfg := range cfg.Devices { + select { + case <-ctx.Done(): + return + default: + } + deviceID := deviceCfg.DeviceID if deviceID == s.myID { continue @@ -380,7 +386,7 @@ func (s *service) connect(ctx context.Context) { for _, addr := range deviceCfg.Addresses { if addr == "dynamic" { if s.discoverer != nil { - if t, err := s.discoverer.Lookup(deviceID); err == nil { + if t, err := s.discoverer.Lookup(ctx, deviceID); err == nil { addrs = append(addrs, t...) } } diff --git a/lib/discover/cache.go b/lib/discover/cache.go index 230d89fdc..fc74fe85c 100644 --- a/lib/discover/cache.go +++ b/lib/discover/cache.go @@ -7,6 +7,7 @@ package discover import ( + "context" "sort" stdsync "sync" "time" @@ -73,7 +74,7 @@ func (m *cachingMux) Add(finder Finder, cacheTime, negCacheTime time.Duration) { // Lookup attempts to resolve the device ID using any of the added Finders, // while obeying the cache settings. -func (m *cachingMux) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) { +func (m *cachingMux) Lookup(ctx context.Context, deviceID protocol.DeviceID) (addresses []string, err error) { m.mut.RLock() for i, finder := range m.finders { if cacheEntry, ok := m.caches[i].Get(deviceID); ok { @@ -99,7 +100,7 @@ func (m *cachingMux) Lookup(deviceID protocol.DeviceID) (addresses []string, err } // Perform the actual lookup and cache the result. - if addrs, err := finder.Lookup(deviceID); err == nil { + if addrs, err := finder.Lookup(ctx, deviceID); err == nil { l.Debugln("lookup for", deviceID, "at", finder) l.Debugln(" addresses:", addrs) addresses = append(addresses, addrs...) diff --git a/lib/discover/cache_test.go b/lib/discover/cache_test.go index 69d85eebe..e44c18446 100644 --- a/lib/discover/cache_test.go +++ b/lib/discover/cache_test.go @@ -7,6 +7,7 @@ package discover import ( + "context" "reflect" "testing" "time" @@ -39,7 +40,9 @@ func TestCacheUnique(t *testing.T) { f1 := &fakeDiscovery{addresses0} c.Add(f1, time.Minute, 0) - addr, err := c.Lookup(protocol.LocalDeviceID) + ctx := context.Background() + + addr, err := c.Lookup(ctx, protocol.LocalDeviceID) if err != nil { t.Fatal(err) } @@ -53,7 +56,7 @@ func TestCacheUnique(t *testing.T) { f2 := &fakeDiscovery{addresses1} c.Add(f2, time.Minute, 0) - addr, err = c.Lookup(protocol.LocalDeviceID) + addr, err = c.Lookup(ctx, protocol.LocalDeviceID) if err != nil { t.Fatal(err) } @@ -66,7 +69,7 @@ type fakeDiscovery struct { addresses []string } -func (f *fakeDiscovery) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) { +func (f *fakeDiscovery) Lookup(_ context.Context, deviceID protocol.DeviceID) (addresses []string, err error) { return f.addresses, nil } @@ -96,7 +99,7 @@ func TestCacheSlowLookup(t *testing.T) { // Start a lookup, which will take at least a second t0 := time.Now() - go c.Lookup(protocol.LocalDeviceID) + go c.Lookup(context.Background(), protocol.LocalDeviceID) <-started // The slow lookup method has been called so we're inside the lock // It should be possible to get ChildErrors while it's running @@ -116,7 +119,7 @@ type slowDiscovery struct { started chan struct{} } -func (f *slowDiscovery) Lookup(deviceID protocol.DeviceID) (addresses []string, err error) { +func (f *slowDiscovery) Lookup(_ context.Context, deviceID protocol.DeviceID) (addresses []string, err error) { close(f.started) time.Sleep(f.delay) return nil, nil diff --git a/lib/discover/discover.go b/lib/discover/discover.go index 4d4821332..383049693 100644 --- a/lib/discover/discover.go +++ b/lib/discover/discover.go @@ -7,6 +7,7 @@ package discover import ( + "context" "time" "github.com/syncthing/syncthing/lib/protocol" @@ -15,7 +16,7 @@ import ( // A Finder provides lookup services of some kind. type Finder interface { - Lookup(deviceID protocol.DeviceID) (address []string, err error) + Lookup(ctx context.Context, deviceID protocol.DeviceID) (address []string, err error) Error() error String() string Cache() map[protocol.DeviceID]CacheEntry diff --git a/lib/discover/global.go b/lib/discover/global.go index a95c8af92..cc58a2b72 100644 --- a/lib/discover/global.go +++ b/lib/discover/global.go @@ -41,8 +41,8 @@ type globalClient struct { } type httpClient interface { - Get(url string) (*http.Response, error) - Post(url, ctype string, data io.Reader) (*http.Response, error) + Get(ctx context.Context, url string) (*http.Response, error) + Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error) } const ( @@ -89,7 +89,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo // The http.Client used for announcements. It needs to have our // certificate to prove our identity, and may or may not verify the server // certificate depending on the insecure setting. - var announceClient httpClient = &http.Client{ + var announceClient httpClient = &contextClient{&http.Client{ Timeout: requestTimeout, Transport: &http.Transport{ DialContext: dialer.DialContext, @@ -99,14 +99,14 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo Certificates: []tls.Certificate{cert}, }, }, - } + }} if opts.id != "" { announceClient = newIDCheckingHTTPClient(announceClient, devID) } // The http.Client used for queries. We don't need to present our // certificate here, so lets not include it. May be insecure if requested. - var queryClient httpClient = &http.Client{ + var queryClient httpClient = &contextClient{&http.Client{ Timeout: requestTimeout, Transport: &http.Transport{ DialContext: dialer.DialContext, @@ -115,7 +115,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo InsecureSkipVerify: opts.insecure, }, }, - } + }} if opts.id != "" { queryClient = newIDCheckingHTTPClient(queryClient, devID) } @@ -139,7 +139,7 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo } // Lookup returns the list of addresses where the given device is available -func (c *globalClient) Lookup(device protocol.DeviceID) (addresses []string, err error) { +func (c *globalClient) Lookup(ctx context.Context, device protocol.DeviceID) (addresses []string, err error) { if c.noLookup { return nil, lookupError{ error: errors.New("lookups not supported"), @@ -156,7 +156,7 @@ func (c *globalClient) Lookup(device protocol.DeviceID) (addresses []string, err q.Set("device", device.String()) qURL.RawQuery = q.Encode() - resp, err := c.queryClient.Get(qURL.String()) + resp, err := c.queryClient.Get(ctx, qURL.String()) if err != nil { l.Debugln("globalClient.Lookup", qURL, err) return nil, err @@ -211,7 +211,7 @@ func (c *globalClient) serve(ctx context.Context) { timer.Reset(2 * time.Second) case <-timer.C: - c.sendAnnouncement(timer) + c.sendAnnouncement(ctx, timer) case <-ctx.Done(): return @@ -219,7 +219,7 @@ func (c *globalClient) serve(ctx context.Context) { } } -func (c *globalClient) sendAnnouncement(timer *time.Timer) { +func (c *globalClient) sendAnnouncement(ctx context.Context, timer *time.Timer) { var ann announcement if c.addrList != nil { ann.Addresses = c.addrList.ExternalAddresses() @@ -239,7 +239,7 @@ func (c *globalClient) sendAnnouncement(timer *time.Timer) { l.Debugf("Announcement: %s", postData) - resp, err := c.announceClient.Post(c.server, "application/json", bytes.NewReader(postData)) + resp, err := c.announceClient.Post(ctx, c.server, "application/json", bytes.NewReader(postData)) if err != nil { l.Debugln("announce POST:", err) c.setError(err) @@ -362,8 +362,8 @@ func (c *idCheckingHTTPClient) check(resp *http.Response) error { return nil } -func (c *idCheckingHTTPClient) Get(url string) (*http.Response, error) { - resp, err := c.httpClient.Get(url) +func (c *idCheckingHTTPClient) Get(ctx context.Context, url string) (*http.Response, error) { + resp, err := c.httpClient.Get(ctx, url) if err != nil { return nil, err } @@ -374,8 +374,8 @@ func (c *idCheckingHTTPClient) Get(url string) (*http.Response, error) { return resp, nil } -func (c *idCheckingHTTPClient) Post(url, ctype string, data io.Reader) (*http.Response, error) { - resp, err := c.httpClient.Post(url, ctype, data) +func (c *idCheckingHTTPClient) Post(ctx context.Context, url, ctype string, data io.Reader) (*http.Response, error) { + resp, err := c.httpClient.Post(ctx, url, ctype, data) if err != nil { return nil, err } @@ -403,3 +403,32 @@ func (e *errorHolder) Error() error { e.mut.Unlock() return err } + +type contextClient struct { + *http.Client +} + +func (c *contextClient) Get(ctx context.Context, url string) (*http.Response, error) { + // For