From 1bae4b7f501679dffbc14dad45ce0cd42e9f4de0 Mon Sep 17 00:00:00 2001 From: Simon Frei Date: Tue, 26 Nov 2019 08:39:51 +0100 Subject: [PATCH] all: Use context in lib/dialer (#6177) * all: Use context in lib/dialer * a bit slimmer * https://github.com/syncthing/syncthing/pull/5753 * bot * missed adding debug.go * errors.Cause * simultaneous dialing * anti-leak --- cmd/strelaypoolsrv/main.go | 3 +- cmd/strelaysrv/testutil/main.go | 11 ++-- cmd/syncthing/main.go | 2 +- lib/connections/quic_dial.go | 4 +- lib/connections/relay_dial.go | 7 ++- lib/connections/relay_listen.go | 8 ++- lib/connections/service.go | 12 ++-- lib/connections/structs.go | 7 ++- lib/connections/tcp_dial.go | 7 ++- lib/dialer/debug.go | 23 +++++++ lib/dialer/internal.go | 105 +------------------------------ lib/dialer/public.go | 94 ++++++++++++++------------- lib/discover/global.go | 8 +-- lib/nat/registry.go | 4 +- lib/osutil/ping.go | 11 ++-- lib/pmp/pmp.go | 7 ++- lib/rc/rc.go | 2 +- lib/relay/client/dynamic.go | 2 +- lib/relay/client/methods.go | 17 +++-- lib/relay/client/static.go | 8 ++- lib/upgrade/upgrade_supported.go | 4 +- lib/upnp/upnp.go | 26 +++++--- lib/ur/usage_report.go | 4 +- lib/util/utils.go | 3 + 24 files changed, 175 insertions(+), 204 deletions(-) create mode 100644 lib/dialer/debug.go diff --git a/cmd/strelaypoolsrv/main.go b/cmd/strelaypoolsrv/main.go index 1d3bb37e1..da3f7f08f 100644 --- a/cmd/strelaypoolsrv/main.go +++ b/cmd/strelaypoolsrv/main.go @@ -7,6 +7,7 @@ package main import ( "bytes" "compress/gzip" + "context" "crypto/tls" "encoding/json" "flag" @@ -480,7 +481,7 @@ func handleRelayTest(request request) { if debug { log.Println("Request for", request.relay) } - if !client.TestRelay(request.relay.uri, []tls.Certificate{testCert}, time.Second, 2*time.Second, 3) { + if !client.TestRelay(context.TODO(), request.relay.uri, []tls.Certificate{testCert}, time.Second, 2*time.Second, 3) { if debug { log.Println("Test for relay", request.relay, "failed") } diff --git a/cmd/strelaysrv/testutil/main.go b/cmd/strelaysrv/testutil/main.go index 2db410d42..2818b7019 100644 --- a/cmd/strelaysrv/testutil/main.go +++ b/cmd/strelaysrv/testutil/main.go @@ -4,6 +4,7 @@ package main import ( "bufio" + "context" "crypto/tls" "flag" "log" @@ -19,6 +20,8 @@ import ( ) func main() { + ctx := context.Background() + log.SetOutput(os.Stdout) log.SetFlags(log.LstdFlags | log.Lshortfile) @@ -76,7 +79,7 @@ func main() { }() for { - conn, err := client.JoinSession(<-recv) + conn, err := client.JoinSession(ctx, <-recv) if err != nil { log.Fatalln("Failed to join", err) } @@ -90,13 +93,13 @@ func main() { log.Fatal(err) } - invite, err := client.GetInvitationFromRelay(uri, id, []tls.Certificate{cert}, 10*time.Second) + invite, err := client.GetInvitationFromRelay(ctx, uri, id, []tls.Certificate{cert}, 10*time.Second) if err != nil { log.Fatal(err) } log.Println("Received invitation", invite) - conn, err := client.JoinSession(invite) + conn, err := client.JoinSession(ctx, invite) if err != nil { log.Fatalln("Failed to join", err) } @@ -104,7 +107,7 @@ func main() { connectToStdio(stdin, conn) log.Println("Finished", conn.RemoteAddr(), conn.LocalAddr()) } else if test { - if client.TestRelay(uri, []tls.Certificate{cert}, time.Second, 2*time.Second, 4) { + if client.TestRelay(ctx, uri, []tls.Certificate{cert}, time.Second, 2*time.Second, 4) { log.Println("OK") } else { log.Println("FAIL") diff --git a/cmd/syncthing/main.go b/cmd/syncthing/main.go index 578f002eb..2f2fb49bd 100644 --- a/cmd/syncthing/main.go +++ b/cmd/syncthing/main.go @@ -512,7 +512,7 @@ func upgradeViaRest() error { r.Header.Set("X-API-Key", cfg.GUI().APIKey) tr := &http.Transport{ - Dial: dialer.Dial, + DialContext: dialer.DialContext, Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, } diff --git a/lib/connections/quic_dial.go b/lib/connections/quic_dial.go index 7121b8325..d4b83c649 100644 --- a/lib/connections/quic_dial.go +++ b/lib/connections/quic_dial.go @@ -42,7 +42,7 @@ type quicDialer struct { commonDialer } -func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) { +func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL) (internalConn, error) { uri = fixupPort(uri, config.DefaultQUICPort) addr, err := net.ResolveUDPAddr("udp", uri.Host) @@ -66,7 +66,7 @@ func (d *quicDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, erro } } - ctx, cancel := context.WithTimeout(context.Background(), quicOperationTimeout) + ctx, cancel := context.WithTimeout(ctx, quicOperationTimeout) defer cancel() session, err := quic.DialContext(ctx, conn, addr, uri.Host, d.tlsCfg, quicConfig) diff --git a/lib/connections/relay_dial.go b/lib/connections/relay_dial.go index c98a7054d..1d79f3ff5 100644 --- a/lib/connections/relay_dial.go +++ b/lib/connections/relay_dial.go @@ -7,6 +7,7 @@ package connections import ( + "context" "crypto/tls" "net/url" "time" @@ -27,13 +28,13 @@ type relayDialer struct { commonDialer } -func (d *relayDialer) Dial(id protocol.DeviceID, uri *url.URL) (internalConn, error) { - inv, err := client.GetInvitationFromRelay(uri, id, d.tlsCfg.Certificates, 10*time.Second) +func (d *relayDialer) Dial(ctx context.Context, id protocol.DeviceID, uri *url.URL) (internalConn, error) { + inv, err := client.GetInvitationFromRelay(ctx, uri, id, d.tlsCfg.Certificates, 10*time.Second) if err != nil { return internalConn{}, err } - conn, err := client.JoinSession(inv) + conn, err := client.JoinSession(ctx, inv) if err != nil { return internalConn{}, err } diff --git a/lib/connections/relay_listen.go b/lib/connections/relay_listen.go index 81a105db4..c66b344dd 100644 --- a/lib/connections/relay_listen.go +++ b/lib/connections/relay_listen.go @@ -13,6 +13,8 @@ import ( "sync" "time" + "github.com/pkg/errors" + "github.com/syncthing/syncthing/lib/config" "github.com/syncthing/syncthing/lib/dialer" "github.com/syncthing/syncthing/lib/nat" @@ -70,9 +72,11 @@ func (t *relayListener) serve(ctx context.Context) error { return err } - conn, err := client.JoinSession(inv) + conn, err := client.JoinSession(ctx, inv) if err != nil { - l.Infoln("Listen (BEP/relay): joining session:", err) + if errors.Cause(err) != context.Canceled { + l.Infoln("Listen (BEP/relay): joining session:", err) + } continue } diff --git a/lib/connections/service.go b/lib/connections/service.go index 90cb723b1..3b5e2c86b 100644 --- a/lib/connections/service.go +++ b/lib/connections/service.go @@ -9,7 +9,6 @@ package connections import ( "context" "crypto/tls" - "errors" "fmt" "net" "net/url" @@ -31,6 +30,7 @@ import ( _ "github.com/syncthing/syncthing/lib/pmp" _ "github.com/syncthing/syncthing/lib/upnp" + "github.com/pkg/errors" "github.com/thejerf/suture" "golang.org/x/time/rate" ) @@ -463,7 +463,7 @@ func (s *service) connect(ctx context.Context) { }) } - conn, ok := s.dialParallel(deviceCfg.DeviceID, dialTargets) + conn, ok := s.dialParallel(ctx, deviceCfg.DeviceID, dialTargets) if ok { s.conns <- conn } @@ -701,6 +701,10 @@ func (s *service) ConnectionStatus() map[string]ConnectionStatusEntry { } func (s *service) setConnectionStatus(address string, err error) { + if errors.Cause(err) != context.Canceled { + return + } + status := ConnectionStatusEntry{When: time.Now().UTC().Truncate(time.Second)} if err != nil { errStr := err.Error() @@ -828,7 +832,7 @@ func IsAllowedNetwork(host string, allowed []string) bool { return false } -func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) { +func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) { // Group targets into buckets by priority dialTargetBuckets := make(map[int][]dialTarget, len(dialTargets)) for _, tgt := range dialTargets { @@ -851,7 +855,7 @@ func (s *service) dialParallel(deviceID protocol.DeviceID, dialTargets []dialTar for _, tgt := range tgts { wg.Add(1) go func(tgt dialTarget) { - conn, err := tgt.Dial() + conn, err := tgt.Dial(ctx) if err == nil { // Closes the connection on error err = s.validateIdentity(conn, deviceID) diff --git a/lib/connections/structs.go b/lib/connections/structs.go index 8cbe6dae8..d85adde44 100644 --- a/lib/connections/structs.go +++ b/lib/connections/structs.go @@ -7,6 +7,7 @@ package connections import ( + "context" "crypto/tls" "fmt" "io" @@ -164,7 +165,7 @@ func (d *commonDialer) RedialFrequency() time.Duration { } type genericDialer interface { - Dial(protocol.DeviceID, *url.URL) (internalConn, error) + Dial(context.Context, protocol.DeviceID, *url.URL) (internalConn, error) RedialFrequency() time.Duration } @@ -223,7 +224,7 @@ type dialTarget struct { deviceID protocol.DeviceID } -func (t dialTarget) Dial() (internalConn, error) { +func (t dialTarget) Dial(ctx context.Context) (internalConn, error) { l.Debugln("dialing", t.deviceID, t.uri, "prio", t.priority) - return t.dialer.Dial(t.deviceID, t.uri) + return t.dialer.Dial(ctx, t.deviceID, t.uri) } diff --git a/lib/connections/tcp_dial.go b/lib/connections/tcp_dial.go index a8ba4dcb5..315316796 100644 --- a/lib/connections/tcp_dial.go +++ b/lib/connections/tcp_dial.go @@ -7,6 +7,7 @@ package connections import ( + "context" "crypto/tls" "net/url" "time" @@ -29,10 +30,12 @@ type tcpDialer struct { commonDialer } -func (d *tcpDialer) Dial(_ protocol.DeviceID, uri *url.URL) (internalConn, error) { +func (d *tcpDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL) (internalConn, error) { uri = fixupPort(uri, config.DefaultTCPPort) - conn, err := dialer.DialTimeout(uri.Scheme, uri.Host, 10*time.Second) + timeoutCtx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + conn, err := dialer.DialContext(timeoutCtx, uri.Scheme, uri.Host) if err != nil { return internalConn{}, err } diff --git a/lib/dialer/debug.go b/lib/dialer/debug.go new file mode 100644 index 000000000..9891be4af --- /dev/null +++ b/lib/dialer/debug.go @@ -0,0 +1,23 @@ +// Copyright (C) 2019 The Syncthing Authors. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +package dialer + +import ( + "os" + "strings" + + "github.com/syncthing/syncthing/lib/logger" +) + +var ( + l = logger.DefaultLogger.NewFacility("dialer", "Dialing connections") + // To run before init() of other files that log on init. + _ = func() error { + l.SetDebug("dialer", strings.Contains(os.Getenv("STTRACE"), "dialer") || os.Getenv("STTRACE") == "all") + return nil + }() +) diff --git a/lib/dialer/internal.go b/lib/dialer/internal.go index 31a8d4dad..e67859258 100644 --- a/lib/dialer/internal.go +++ b/lib/dialer/internal.go @@ -7,34 +7,24 @@ package dialer import ( - "net" "net/http" "net/url" "os" "time" "golang.org/x/net/proxy" - - "github.com/syncthing/syncthing/lib/logger" ) var ( - l = logger.DefaultLogger.NewFacility("dialer", "Dialing connections") - proxyDialer proxy.Dialer - usingProxy bool - noFallback = os.Getenv("ALL_PROXY_NO_FALLBACK") != "" + noFallback = os.Getenv("ALL_PROXY_NO_FALLBACK") != "" ) -type dialFunc func(network, addr string) (net.Conn, error) - func init() { proxy.RegisterDialerType("socks", socksDialerFunction) - proxyDialer = getDialer(proxy.Direct) - usingProxy = proxyDialer != proxy.Direct - if usingProxy { + if proxyDialer := proxy.FromEnvironment(); proxyDialer != proxy.Direct { http.DefaultTransport = &http.Transport{ - Dial: Dial, + DialContext: DialContext, Proxy: http.ProxyFromEnvironment, TLSHandshakeTimeout: 10 * time.Second, } @@ -55,31 +45,6 @@ func init() { } } -func dialWithFallback(proxyDialFunc dialFunc, fallbackDialFunc dialFunc, network, addr string) (net.Conn, error) { - conn, err := proxyDialFunc(network, addr) - if err == nil { - l.Debugf("Dialing %s address %s via proxy - success, %s -> %s", network, addr, conn.LocalAddr(), conn.RemoteAddr()) - SetTCPOptions(conn) - return dialerConn{ - conn, newDialerAddr(network, addr), - }, nil - } - l.Debugf("Dialing %s address %s via proxy - error %s", network, addr, err) - - if noFallback { - return conn, err - } - - conn, err = fallbackDialFunc(network, addr) - if err == nil { - l.Debugf("Dialing %s address %s via fallback - success, %s -> %s", network, addr, conn.LocalAddr(), conn.RemoteAddr()) - SetTCPOptions(conn) - } else { - l.Debugf("Dialing %s address %s via fallback - error %s", network, addr, err) - } - return conn, err -} - // This is a rip off of proxy.FromURL for "socks" URL scheme func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error) { var auth *proxy.Auth @@ -93,67 +58,3 @@ func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error) return proxy.SOCKS5("tcp", u.Host, auth, forward) } - -// This is a rip off of proxy.FromEnvironment with a custom forward dialer -func getDialer(forward proxy.Dialer) proxy.Dialer { - allProxy := os.Getenv("all_proxy") - if len(allProxy) == 0 { - return forward - } - - proxyURL, err := url.Parse(allProxy) - if err != nil { - return forward - } - prxy, err := proxy.FromURL(proxyURL, forward) - if err != nil { - return forward - } - - noProxy := os.Getenv("no_proxy") - if len(noProxy) == 0 { - return prxy - } - - perHost := proxy.NewPerHost(prxy, forward) - perHost.AddFromString(noProxy) - return perHost -} - -type timeoutDirectDialer struct { - timeout time.Duration -} - -func (d *timeoutDirectDialer) Dial(network, addr string) (net.Conn, error) { - return net.DialTimeout(network, addr, d.timeout) -} - -type dialerConn struct { - net.Conn - addr net.Addr -} - -func (c dialerConn) RemoteAddr() net.Addr { - return c.addr -} - -func newDialerAddr(network, addr string) net.Addr { - netaddr, err := net.ResolveIPAddr(network, addr) - if err == nil { - return netaddr - } - return fallbackAddr{network, addr} -} - -type fallbackAddr struct { - network string - addr string -} - -func (a fallbackAddr) Network() string { - return a.network -} - -func (a fallbackAddr) String() string { - return a.addr -} diff --git a/lib/dialer/public.go b/lib/dialer/public.go index e5663a3a9..30cf6c9ac 100644 --- a/lib/dialer/public.go +++ b/lib/dialer/public.go @@ -7,49 +7,18 @@ package dialer import ( + "context" + "errors" "fmt" "net" "time" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" + "golang.org/x/net/proxy" ) -// Dial tries dialing via proxy if a proxy is configured, and falls back to -// a direct connection if no proxy is defined, or connecting via proxy fails. -func Dial(network, addr string) (net.Conn, error) { - if usingProxy { - return dialWithFallback(proxyDialer.Dial, net.Dial, network, addr) - } - return net.Dial(network, addr) -} - -// DialTimeout tries dialing via proxy with a timeout if a proxy is configured, -// and falls back to a direct connection if no proxy is defined, or connecting -// via proxy fails. The timeout can potentially be applied twice, once trying -// to connect via the proxy connection, and second time trying to connect -// directly. -func DialTimeout(network, addr string, timeout time.Duration) (net.Conn, error) { - if usingProxy { - // Because the proxy package is poorly structured, we have to - // construct a struct that matches proxy.Dialer but has a timeout - // and reconstrcut the proxy dialer using that, in order to be able to - // set a timeout. - dd := &timeoutDirectDialer{ - timeout: timeout, - } - // Check if the dialer we are getting is not timeoutDirectDialer we just - // created. It could happen that usingProxy is true, but getDialer - // returns timeoutDirectDialer due to env vars changing. - if timeoutProxyDialer := getDialer(dd); timeoutProxyDialer != dd { - directDialFunc := func(inetwork, iaddr string) (net.Conn, error) { - return net.DialTimeout(inetwork, iaddr, timeout) - } - return dialWithFallback(timeoutProxyDialer.Dial, directDialFunc, network, addr) - } - } - return net.DialTimeout(network, addr, timeout) -} +var errUnexpectedInterfaceType = errors.New("unexpected interface type") // SetTCPOptions sets our default TCP options on a TCP connection, possibly // digging through dialerConn to extract the *net.TCPConn @@ -70,10 +39,6 @@ func SetTCPOptions(conn net.Conn) error { return err } return nil - - case dialerConn: - return SetTCPOptions(conn.Conn) - default: return fmt.Errorf("unknown connection type %T", conn) } @@ -89,11 +54,54 @@ func SetTrafficClass(conn net.Conn, class int) error { return e1 } return e2 - - case dialerConn: - return SetTrafficClass(conn.Conn, class) - default: return fmt.Errorf("unknown connection type %T", conn) } } + +func dialContextWithFallback(ctx context.Context, fallback proxy.ContextDialer, network, addr string) (net.Conn, error) { + dialer, ok := proxy.FromEnvironment().(proxy.ContextDialer) + if !ok { + return nil, errUnexpectedInterfaceType + } + if dialer == proxy.Direct { + return fallback.DialContext(ctx, network, addr) + } + if noFallback { + return dialer.DialContext(ctx, network, addr) + } + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + var proxyConn, fallbackConn net.Conn + var proxyErr, fallbackErr error + proxyDone := make(chan struct{}) + fallbackDone := make(chan struct{}) + go func() { + proxyConn, proxyErr = dialer.DialContext(ctx, network, addr) + close(proxyDone) + }() + go func() { + fallbackConn, fallbackErr = fallback.DialContext(ctx, network, addr) + close(fallbackDone) + }() + <-proxyDone + if proxyErr == nil { + go func() { + <-fallbackDone + if fallbackErr == nil { + fallbackConn.Close() + } + }() + return proxyConn, nil + } + <-fallbackDone + return fallbackConn, fallbackErr +} + +// DialContext dials via context and/or directly, depending on how it is configured. +// If dialing via proxy and allowing fallback, dialing for both happens simultaneously +// and the proxy connection is returned if successful. +func DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return dialContextWithFallback(ctx, proxy.Direct, network, addr) +} diff --git a/lib/discover/global.go b/lib/discover/global.go index aa072e8e4..a95c8af92 100644 --- a/lib/discover/global.go +++ b/lib/discover/global.go @@ -92,8 +92,8 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo var announceClient httpClient = &http.Client{ Timeout: requestTimeout, Transport: &http.Transport{ - Dial: dialer.Dial, - Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{ InsecureSkipVerify: opts.insecure, Certificates: []tls.Certificate{cert}, @@ -109,8 +109,8 @@ func NewGlobal(server string, cert tls.Certificate, addrList AddressLister, evLo var queryClient httpClient = &http.Client{ Timeout: requestTimeout, Transport: &http.Transport{ - Dial: dialer.Dial, - Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{ InsecureSkipVerify: opts.insecure, }, diff --git a/lib/nat/registry.go b/lib/nat/registry.go index 889884759..14687c677 100644 --- a/lib/nat/registry.go +++ b/lib/nat/registry.go @@ -12,7 +12,7 @@ import ( "time" ) -type DiscoverFunc func(renewal, timeout time.Duration) []Device +type DiscoverFunc func(ctx context.Context, renewal, timeout time.Duration) []Device var providers []DiscoverFunc @@ -30,7 +30,7 @@ func discoverAll(ctx context.Context, renewal, timeout time.Duration) map[string for _, discoverFunc := range providers { go func(f DiscoverFunc) { defer wg.Done() - for _, dev := range f(renewal, timeout) { + for _, dev := range f(ctx, renewal, timeout) { select { case c <- dev: case <-ctx.Done(): diff --git a/lib/osutil/ping.go b/lib/osutil/ping.go index dc385da65..1ecd3080f 100644 --- a/lib/osutil/ping.go +++ b/lib/osutil/ping.go @@ -7,6 +7,7 @@ package osutil import ( + "context" "net/url" "time" @@ -16,9 +17,11 @@ import ( // TCPPing returns the duration required to establish a TCP connection // to the given host. ICMP packets require root privileges, hence why we use // tcp. -func TCPPing(address string) (time.Duration, error) { +func TCPPing(ctx context.Context, address string) (time.Duration, error) { start := time.Now() - conn, err := dialer.DialTimeout("tcp", address, time.Second) + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + conn, err := dialer.DialContext(ctx, "tcp", address) if conn != nil { conn.Close() } @@ -27,11 +30,11 @@ func TCPPing(address string) (time.Duration, error) { // GetLatencyForURL parses the given URL, tries opening a TCP connection to it // and returns the time it took to establish the connection. -func GetLatencyForURL(addr string) (time.Duration, error) { +func GetLatencyForURL(ctx context.Context, addr string) (time.Duration, error) { uri, err := url.Parse(addr) if err != nil { return 0, err } - return TCPPing(uri.Host) + return TCPPing(ctx, uri.Host) } diff --git a/lib/pmp/pmp.go b/lib/pmp/pmp.go index 1968a0158..2d584a128 100644 --- a/lib/pmp/pmp.go +++ b/lib/pmp/pmp.go @@ -7,6 +7,7 @@ package pmp import ( + "context" "fmt" "net" "strings" @@ -21,7 +22,7 @@ func init() { nat.Register(Discover) } -func Discover(renewal, timeout time.Duration) []nat.Device { +func Discover(ctx context.Context, renewal, timeout time.Duration) []nat.Device { ip, err := gateway.DiscoverGateway() if err != nil { l.Debugln("Failed to discover gateway", err) @@ -44,7 +45,9 @@ func Discover(renewal, timeout time.Duration) []nat.Device { var localIP net.IP // Port comes from the natpmp package - conn, err := net.DialTimeout("udp", net.JoinHostPort(ip.String(), "5351"), timeout) + timeoutCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + conn, err := (&net.Dialer{}).DialContext(timeoutCtx, "udp", net.JoinHostPort(ip.String(), "5351")) if err == nil { conn.Close() localIPAddress, _, err := net.SplitHostPort(conn.LocalAddr().String()) diff --git a/lib/rc/rc.go b/lib/rc/rc.go index 78e07b4d2..39d1a5526 100644 --- a/lib/rc/rc.go +++ b/lib/rc/rc.go @@ -166,7 +166,7 @@ func (p *Process) Get(path string) ([]byte, error) { client := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ - Dial: dialer.Dial, + DialContext: dialer.DialContext, Proxy: http.ProxyFromEnvironment, DisableKeepAlives: true, }, diff --git a/lib/relay/client/dynamic.go b/lib/relay/client/dynamic.go index 8e831ac19..337de16fd 100644 --- a/lib/relay/client/dynamic.go +++ b/lib/relay/client/dynamic.go @@ -153,7 +153,7 @@ func relayAddressesOrder(ctx context.Context, input []string) []string { buckets := make(map[int][]string) for _, relay := range input { - latency, err := osutil.GetLatencyForURL(relay) + latency, err := osutil.GetLatencyForURL(ctx, relay) if err != nil { latency = time.Hour } diff --git a/lib/relay/client/methods.go b/lib/relay/client/methods.go index 6e4f1213e..972fcfec7 100644 --- a/lib/relay/client/methods.go +++ b/lib/relay/client/methods.go @@ -3,6 +3,7 @@ package client import ( + "context" "crypto/tls" "fmt" "net" @@ -16,12 +17,14 @@ import ( "github.com/syncthing/syncthing/lib/relay/protocol" ) -func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate, timeout time.Duration) (protocol.SessionInvitation, error) { +func GetInvitationFromRelay(ctx context.Context, uri *url.URL, id syncthingprotocol.DeviceID, certs []tls.Certificate, timeout time.Duration) (protocol.SessionInvitation, error) { if uri.Scheme != "relay" { return protocol.SessionInvitation{}, fmt.Errorf("Unsupported relay scheme: %v", uri.Scheme) } - rconn, err := dialer.DialTimeout("tcp", uri.Host, timeout) + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + rconn, err := dialer.DialContext(ctx, "tcp", uri.Host) if err != nil { return protocol.SessionInvitation{}, err } @@ -63,10 +66,12 @@ func GetInvitationFromRelay(uri *url.URL, id syncthingprotocol.DeviceID, certs [ } } -func JoinSession(invitation protocol.SessionInvitation) (net.Conn, error) { +func JoinSession(ctx context.Context, invitation protocol.SessionInvitation) (net.Conn, error) { addr := net.JoinHostPort(net.IP(invitation.Address).String(), strconv.Itoa(int(invitation.Port))) - conn, err := dialer.Dial("tcp", addr) + ctx, cancel := context.WithTimeout(ctx, 10*time.Second) + defer cancel() + conn, err := dialer.DialContext(ctx, "tcp", addr) if err != nil { return nil, err } @@ -99,7 +104,7 @@ func JoinSession(invitation protocol.SessionInvitation) (net.Conn, error) { } } -func TestRelay(uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) bool { +func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) bool { id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0]) invs := make(chan protocol.SessionInvitation, 1) c, err := NewClient(uri, certs, invs, timeout) @@ -114,7 +119,7 @@ func TestRelay(uri *url.URL, certs []tls.Certificate, sleep, timeout time.Durati }() for i := 0; i < times; i++ { - _, err := GetInvitationFromRelay(uri, id, certs, timeout) + _, err := GetInvitationFromRelay(ctx, uri, id, certs, timeout) if err == nil { return true } diff --git a/lib/relay/client/static.go b/lib/relay/client/static.go index b7227ff1d..f977564ba 100644 --- a/lib/relay/client/static.go +++ b/lib/relay/client/static.go @@ -47,7 +47,7 @@ func newStaticClient(uri *url.URL, certs []tls.Certificate, invitations chan pro } func (c *staticClient) serve(ctx context.Context) error { - if err := c.connect(); err != nil { + if err := c.connect(ctx); err != nil { l.Infof("Could not connect to relay %s: %s", c.uri, err) return err } @@ -146,13 +146,15 @@ func (c *staticClient) URI() *url.URL { return c.uri } -func (c *staticClient) connect() error { +func (c *staticClient) connect(ctx context.Context) error { if c.uri.Scheme != "relay" { return fmt.Errorf("unsupported relay scheme: %v", c.uri.Scheme) } t0 := time.Now() - tcpConn, err := dialer.DialTimeout("tcp", c.uri.Host, c.connectTimeout) + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + tcpConn, err := dialer.DialContext(timeoutCtx, "tcp", c.uri.Host) if err != nil { return err } diff --git a/lib/upgrade/upgrade_supported.go b/lib/upgrade/upgrade_supported.go index 0fadbb3d1..ec8dc7e1d 100644 --- a/lib/upgrade/upgrade_supported.go +++ b/lib/upgrade/upgrade_supported.go @@ -66,8 +66,8 @@ const ( var insecureHTTP = &http.Client{ Timeout: readTimeout, Transport: &http.Transport{ - Dial: dialer.Dial, - Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{ InsecureSkipVerify: true, }, diff --git a/lib/upnp/upnp.go b/lib/upnp/upnp.go index 8c7f3873b..98cca31d0 100644 --- a/lib/upnp/upnp.go +++ b/lib/upnp/upnp.go @@ -35,8 +35,8 @@ package upnp import ( "bufio" "bytes" + "context" "encoding/xml" - "errors" "fmt" "io/ioutil" "net" @@ -47,6 +47,8 @@ import ( "sync" "time" + "github.com/pkg/errors" + "github.com/syncthing/syncthing/lib/dialer" "github.com/syncthing/syncthing/lib/nat" ) @@ -83,7 +85,7 @@ func (e UnsupportedDeviceTypeError) Error() string { // Discover discovers UPnP InternetGatewayDevices. // The order in which the devices appear in the results list is not deterministic. -func Discover(renewal, timeout time.Duration) []nat.Device { +func Discover(ctx context.Context, renewal, timeout time.Duration) []nat.Device { var results []nat.Device interfaces, err := net.Interfaces() @@ -105,7 +107,7 @@ func Discover(renewal, timeout time.Duration) []nat.Device { for _, deviceType := range []string{"urn:schemas-upnp-org:device:InternetGatewayDevice:1", "urn:schemas-upnp-org:device:InternetGatewayDevice:2"} { wg.Add(1) go func(intf net.Interface, deviceType string) { - discover(&intf, deviceType, timeout, resultChan) + discover(ctx, &intf, deviceType, timeout, resultChan) wg.Done() }(intf, deviceType) } @@ -135,7 +137,7 @@ nextResult: // Search for UPnP InternetGatewayDevices for seconds. // The order in which the devices appear in the result list is not deterministic -func discover(intf *net.Interface, deviceType string, timeout time.Duration, results chan<- nat.Device) { +func discover(ctx context.Context, intf *net.Interface, deviceType string, timeout time.Duration, results chan<- nat.Device) { ssdp := &net.UDPAddr{IP: []byte{239, 255, 255, 250}, Port: 1900} tpl := `M-SEARCH * HTTP/1.1 @@ -187,13 +189,15 @@ USER-AGENT: syncthing/1.0 } break } - igds, err := parseResponse(deviceType, resp[:n]) + igds, err := parseResponse(ctx, deviceType, resp[:n]) if err != nil { switch err.(type) { case *UnsupportedDeviceTypeError: l.Debugln(err.Error()) default: - l.Infoln("UPnP parse:", err) + if errors.Cause(err) != context.Canceled { + l.Infoln("UPnP parse:", err) + } } continue } @@ -205,7 +209,7 @@ USER-AGENT: syncthing/1.0 l.Debugln("Discovery for device type", deviceType, "on", intf.Name, "finished.") } -func parseResponse(deviceType string, resp []byte) ([]IGDService, error) { +func parseResponse(ctx context.Context, deviceType string, resp []byte) ([]IGDService, error) { l.Debugln("Handling UPnP response:\n\n" + string(resp)) reader := bufio.NewReader(bytes.NewBuffer(resp)) @@ -257,7 +261,7 @@ func parseResponse(deviceType string, resp []byte) ([]IGDService, error) { // We do this in a fairly roundabout way by connecting to the IGD and // checking the address of the local end of the socket. I'm open to // suggestions on a better way to do this... - localIPAddress, err := localIP(deviceDescriptionURL) + localIPAddress, err := localIP(ctx, deviceDescriptionURL) if err != nil { return nil, err } @@ -270,8 +274,10 @@ func parseResponse(deviceType string, resp []byte) ([]IGDService, error) { return services, nil } -func localIP(url *url.URL) (net.IP, error) { - conn, err := dialer.DialTimeout("tcp", url.Host, time.Second) +func localIP(ctx context.Context, url *url.URL) (net.IP, error) { + timeoutCtx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + conn, err := dialer.DialContext(timeoutCtx, "tcp", url.Host) if err != nil { return nil, err } diff --git a/lib/ur/usage_report.go b/lib/ur/usage_report.go index a8d619541..9f3cff3b0 100644 --- a/lib/ur/usage_report.go +++ b/lib/ur/usage_report.go @@ -373,8 +373,8 @@ func (s *Service) sendUsageReport() error { client := &http.Client{ Transport: &http.Transport{ - Dial: dialer.Dial, - Proxy: http.ProxyFromEnvironment, + DialContext: dialer.DialContext, + Proxy: http.ProxyFromEnvironment, TLSClientConfig: &tls.Config{ InsecureSkipVerify: s.cfg.Options().URPostInsecurely, }, diff --git a/lib/util/utils.go b/lib/util/utils.go index 010418f71..7224c6ad5 100644 --- a/lib/util/utils.go +++ b/lib/util/utils.go @@ -236,6 +236,9 @@ func (s *service) Serve() { var err error defer func() { + if err == context.Canceled { + err = nil + } s.mut.Lock() s.err = err close(s.stopped)