From 7c292cc812320c01675697f50f0f419f7fc93bb7 Mon Sep 17 00:00:00 2001 From: greatroar <61184462+greatroar@users.noreply.github.com> Date: Wed, 6 Oct 2021 10:52:51 +0200 Subject: [PATCH] lib/connections: Fix and optimize registry (#7996) Registry.Get used a full sort to get the minimum of a list, and the sort was broken because util.AddressUnspecifiedLess assumed it could find out whether an address is IPv4 or IPv6 from its Network method. However, net.(TCP|UDP)Addr.Network always returns "tcp"/"udp". --- lib/connections/quic_dial.go | 3 +- lib/connections/quic_misc.go | 10 +++-- lib/connections/registry/registry.go | 44 +++++++++++-------- lib/connections/registry/registry_test.go | 51 +++++++++++++++++------ lib/dialer/internal.go | 6 --- lib/dialer/public.go | 4 +- lib/util/utils.go | 20 --------- lib/util/utils_test.go | 44 ------------------- 8 files changed, 76 insertions(+), 106 deletions(-) diff --git a/lib/connections/quic_dial.go b/lib/connections/quic_dial.go index 58eb1ff60..d47d3b84d 100644 --- a/lib/connections/quic_dial.go +++ b/lib/connections/quic_dial.go @@ -58,7 +58,8 @@ func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL // Given we always pass the connection to quic, it assumes it's a remote connection it never closes it, // So our wrapper around it needs to close it, but it only needs to close it if it's not the listening connection. var createdConn net.PacketConn - if listenConn := registry.Get(uri.Scheme, packetConnLess); listenConn != nil { + listenConn := registry.Get(uri.Scheme, packetConnUnspecified) + if listenConn != nil { conn = listenConn.(net.PacketConn) } else { if packetConn, err := net.ListenPacket("udp", ":0"); err != nil { diff --git a/lib/connections/quic_misc.go b/lib/connections/quic_misc.go index 5014dd259..fb8e7044f 100644 --- a/lib/connections/quic_misc.go +++ b/lib/connections/quic_misc.go @@ -15,7 +15,6 @@ import ( "net/url" "github.com/lucas-clemente/quic-go" - "github.com/syncthing/syncthing/lib/util" ) var ( @@ -63,7 +62,10 @@ func (q *quicTlsConn) ConnectionState() tls.ConnectionState { return q.Session.ConnectionState().TLS.ConnectionState } -// Sort available packet connections by ip address, preferring unspecified local address. -func packetConnLess(i interface{}, j interface{}) bool { - return util.AddressUnspecifiedLess(i.(net.PacketConn).LocalAddr(), j.(net.PacketConn).LocalAddr()) +func packetConnUnspecified(conn interface{}) bool { + // Since QUIC connections are wrapped, we can't do a simple typecheck + // on *net.UDPAddr here. + addr := conn.(net.PacketConn).LocalAddr() + host, _, err := net.SplitHostPort(addr.String()) + return err == nil && net.ParseIP(host).IsUnspecified() } diff --git a/lib/connections/registry/registry.go b/lib/connections/registry/registry.go index 42dcf699b..4d9be9510 100644 --- a/lib/connections/registry/registry.go +++ b/lib/connections/registry/registry.go @@ -10,7 +10,6 @@ package registry import ( - "sort" "strings" "github.com/syncthing/syncthing/lib/sync" @@ -46,7 +45,7 @@ func (r *Registry) Unregister(scheme string, item interface{}) { candidates := r.available[scheme] for i, existingItem := range candidates { if existingItem == item { - copy(candidates[i:], candidates[i+1:]) + candidates[i] = candidates[len(candidates)-1] candidates[len(candidates)-1] = nil r.available[scheme] = candidates[:len(candidates)-1] break @@ -54,26 +53,37 @@ func (r *Registry) Unregister(scheme string, item interface{}) { } } -func (r *Registry) Get(scheme string, less func(i, j interface{}) bool) interface{} { +// Get returns an item for a schema compatible with the given scheme. +// If any item satisfies preferred, that has precedence over other items. +func (r *Registry) Get(scheme string, preferred func(interface{}) bool) interface{} { r.mut.Lock() defer r.mut.Unlock() - candidates := make([]interface{}, 0) + var ( + best interface{} + bestPref bool + bestScheme string + ) for availableScheme, items := range r.available { // quic:// should be considered ok for both quic4:// and quic6:// - if strings.HasPrefix(scheme, availableScheme) { - candidates = append(candidates, items...) + if !strings.HasPrefix(scheme, availableScheme) { + continue + } + for _, item := range items { + better := best == nil + pref := preferred(item) + if !better { + // In case of a tie, prefer "quic" to "quic[46]" etc. + better = pref && + (!bestPref || len(availableScheme) < len(bestScheme)) + } + if !better { + continue + } + best, bestPref, bestScheme = item, pref, availableScheme } } - - if len(candidates) == 0 { - return nil - } - - sort.Slice(candidates, func(i, j int) bool { - return less(candidates[i], candidates[j]) - }) - return candidates[0] + return best } func Register(scheme string, item interface{}) { @@ -84,6 +94,6 @@ func Unregister(scheme string, item interface{}) { Default.Unregister(scheme, item) } -func Get(scheme string, less func(i, j interface{}) bool) interface{} { - return Default.Get(scheme, less) +func Get(scheme string, preferred func(interface{}) bool) interface{} { + return Default.Get(scheme, preferred) } diff --git a/lib/connections/registry/registry_test.go b/lib/connections/registry/registry_test.go index f62a5f5c7..bf7a2e24a 100644 --- a/lib/connections/registry/registry_test.go +++ b/lib/connections/registry/registry_test.go @@ -7,13 +7,18 @@ package registry import ( + "net" "testing" ) func TestRegistry(t *testing.T) { r := New() - if res := r.Get("int", intLess); res != nil { + want := func(i int) func(interface{}) bool { + return func(x interface{}) bool { return x.(int) == i } + } + + if res := r.Get("int", want(1)); res != nil { t.Error("unexpected") } @@ -24,30 +29,28 @@ func TestRegistry(t *testing.T) { r.Register("int6", 6) r.Register("int6", 66) - if res := r.Get("int", intLess).(int); res != 1 { + if res := r.Get("int", want(1)).(int); res != 1 { t.Error("unexpected", res) } // int is prefix of int4, so returns 1 - if res := r.Get("int4", intLess).(int); res != 1 { + if res := r.Get("int4", want(1)).(int); res != 1 { t.Error("unexpected", res) } r.Unregister("int", 1) - // Check that falls through to 11 - if res := r.Get("int", intLess).(int); res != 11 { + if res := r.Get("int", want(1)).(int); res == 1 { t.Error("unexpected", res) } - // 6 is smaller than 11 available in int. - if res := r.Get("int6", intLess).(int); res != 6 { + if res := r.Get("int6", want(6)).(int); res != 6 { t.Error("unexpected", res) } // Unregister 11, int should be impossible to find r.Unregister("int", 11) - if res := r.Get("int", intLess); res != nil { + if res := r.Get("int", want(11)); res != nil { t.Error("unexpected") } @@ -59,13 +62,35 @@ func TestRegistry(t *testing.T) { r.Register("int", 1) r.Unregister("int", 1) - if res := r.Get("int4", intLess).(int); res != 1 { + if res := r.Get("int4", want(1)).(int); res != 1 { t.Error("unexpected", res) } } -func intLess(i, j interface{}) bool { - iInt := i.(int) - jInt := j.(int) - return iInt < jInt +func TestShortSchemeFirst(t *testing.T) { + r := New() + r.Register("foo", 0) + r.Register("foobar", 1) + + // If we don't care about the value, we should get the one with "foo". + res := r.Get("foo", func(interface{}) bool { return false }) + if res != 0 { + t.Error("unexpected", res) + } +} + +func BenchmarkGet(b *testing.B) { + r := New() + for _, addr := range []string{"192.168.1.1", "172.1.1.1", "10.1.1.1"} { + r.Register("tcp", &net.TCPAddr{IP: net.ParseIP(addr)}) + } + + b.ReportAllocs() + b.ResetTimer() + + for i := 0; i < b.N; i++ { + r.Get("tcp", func(x interface{}) bool { + return x.(*net.TCPAddr).IP.IsUnspecified() + }) + } } diff --git a/lib/dialer/internal.go b/lib/dialer/internal.go index ac3d10f9a..1eb44ca94 100644 --- a/lib/dialer/internal.go +++ b/lib/dialer/internal.go @@ -13,7 +13,6 @@ import ( "os" "time" - "github.com/syncthing/syncthing/lib/util" "golang.org/x/net/proxy" ) @@ -61,11 +60,6 @@ func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error) return proxy.SOCKS5("tcp", u.Host, auth, forward) } -// Sort available addresses, preferring unspecified address. -func tcpAddrLess(i interface{}, j interface{}) bool { - return util.AddressUnspecifiedLess(i.(*net.TCPAddr), j.(*net.TCPAddr)) -} - // dialerConn is needed because proxy dialed connections have RemoteAddr() pointing at the proxy, // which then screws up various things such as IsLAN checks, and "let's populate the relay invitation address from // existing connection" shenanigans. diff --git a/lib/dialer/public.go b/lib/dialer/public.go index 8326e5626..cfd632ced 100644 --- a/lib/dialer/public.go +++ b/lib/dialer/public.go @@ -110,7 +110,9 @@ func DialContextReusePort(ctx context.Context, network, addr string) (net.Conn, return DialContext(ctx, network, addr) } - localAddrInterface := registry.Get(network, tcpAddrLess) + localAddrInterface := registry.Get(network, func(addr interface{}) bool { + return addr.(*net.TCPAddr).IP.IsUnspecified() + }) if localAddrInterface == nil { // Nothing listening, nothing to reuse. return DialContext(ctx, network, addr) diff --git a/lib/util/utils.go b/lib/util/utils.go index 57b317d28..43fafde78 100644 --- a/lib/util/utils.go +++ b/lib/util/utils.go @@ -9,7 +9,6 @@ package util import ( "context" "fmt" - "net" "net/url" "reflect" "strconv" @@ -231,25 +230,6 @@ func Address(network, host string) string { return u.String() } -// AddressUnspecifiedLess is a comparator function preferring least specific network address (most widely listening, -// namely preferring 0.0.0.0 over some IP), if both IPs are equal, it prefers the less restrictive network (prefers tcp -// over tcp4) -func AddressUnspecifiedLess(a, b net.Addr) bool { - aIsUnspecified := false - bIsUnspecified := false - if host, _, err := net.SplitHostPort(a.String()); err == nil { - aIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified() - } - if host, _, err := net.SplitHostPort(b.String()); err == nil { - bIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified() - } - - if aIsUnspecified == bIsUnspecified { - return len(a.Network()) < len(b.Network()) - } - return aIsUnspecified -} - func CallWithContext(ctx context.Context, fn func() error) error { var err error done := make(chan struct{}) diff --git a/lib/util/utils_test.go b/lib/util/utils_test.go index 40099c376..d16ef2f48 100644 --- a/lib/util/utils_test.go +++ b/lib/util/utils_test.go @@ -225,50 +225,6 @@ func TestCopyMatching(t *testing.T) { } } -type mockedAddr struct { - network string - addr string -} - -func (a mockedAddr) Network() string { - return a.network -} - -func (a mockedAddr) String() string { - return a.addr -} - -func TestInspecifiedAddressLess(t *testing.T) { - cases := []struct { - netA string - addrA string - netB string - addrB string - }{ - // B is assumed the winner. - {"tcp", "127.0.0.1:1234", "tcp", ":1235"}, - {"tcp", "127.0.0.1:1234", "tcp", "0.0.0.0:1235"}, - {"tcp4", "0.0.0.0:1234", "tcp", "0.0.0.0:1235"}, // tcp4 on the first one - } - - for i, testCase := range cases { - addrs := []mockedAddr{ - {testCase.netA, testCase.addrA}, - {testCase.netB, testCase.addrB}, - } - - if AddressUnspecifiedLess(addrs[0], addrs[1]) { - t.Error(i, "unexpected") - } - if !AddressUnspecifiedLess(addrs[1], addrs[0]) { - t.Error(i, "unexpected") - } - if AddressUnspecifiedLess(addrs[0], addrs[0]) || AddressUnspecifiedLess(addrs[1], addrs[1]) { - t.Error(i, "unexpected") - } - } -} - func TestFillNil(t *testing.T) { type A struct { Slice []int