lib/connections: Fix and optimize registry (#7996)

Registry.Get used a full sort to get the minimum of a list, and the sort
was broken because util.AddressUnspecifiedLess assumed it could find out
whether an address is IPv4 or IPv6 from its Network method. However,
net.(TCP|UDP)Addr.Network always returns "tcp"/"udp".
This commit is contained in:
greatroar 2021-10-06 10:52:51 +02:00 committed by GitHub
parent c94b797f00
commit 7c292cc812
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 76 additions and 106 deletions

View File

@ -58,7 +58,8 @@ func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL
// Given we always pass the connection to quic, it assumes it's a remote connection it never closes it,
// So our wrapper around it needs to close it, but it only needs to close it if it's not the listening connection.
var createdConn net.PacketConn
if listenConn := registry.Get(uri.Scheme, packetConnLess); listenConn != nil {
listenConn := registry.Get(uri.Scheme, packetConnUnspecified)
if listenConn != nil {
conn = listenConn.(net.PacketConn)
} else {
if packetConn, err := net.ListenPacket("udp", ":0"); err != nil {

View File

@ -15,7 +15,6 @@ import (
"net/url"
"github.com/lucas-clemente/quic-go"
"github.com/syncthing/syncthing/lib/util"
)
var (
@ -63,7 +62,10 @@ func (q *quicTlsConn) ConnectionState() tls.ConnectionState {
return q.Session.ConnectionState().TLS.ConnectionState
}
// Sort available packet connections by ip address, preferring unspecified local address.
func packetConnLess(i interface{}, j interface{}) bool {
return util.AddressUnspecifiedLess(i.(net.PacketConn).LocalAddr(), j.(net.PacketConn).LocalAddr())
func packetConnUnspecified(conn interface{}) bool {
// Since QUIC connections are wrapped, we can't do a simple typecheck
// on *net.UDPAddr here.
addr := conn.(net.PacketConn).LocalAddr()
host, _, err := net.SplitHostPort(addr.String())
return err == nil && net.ParseIP(host).IsUnspecified()
}

View File

@ -10,7 +10,6 @@
package registry
import (
"sort"
"strings"
"github.com/syncthing/syncthing/lib/sync"
@ -46,7 +45,7 @@ func (r *Registry) Unregister(scheme string, item interface{}) {
candidates := r.available[scheme]
for i, existingItem := range candidates {
if existingItem == item {
copy(candidates[i:], candidates[i+1:])
candidates[i] = candidates[len(candidates)-1]
candidates[len(candidates)-1] = nil
r.available[scheme] = candidates[:len(candidates)-1]
break
@ -54,26 +53,37 @@ func (r *Registry) Unregister(scheme string, item interface{}) {
}
}
func (r *Registry) Get(scheme string, less func(i, j interface{}) bool) interface{} {
// Get returns an item for a schema compatible with the given scheme.
// If any item satisfies preferred, that has precedence over other items.
func (r *Registry) Get(scheme string, preferred func(interface{}) bool) interface{} {
r.mut.Lock()
defer r.mut.Unlock()
candidates := make([]interface{}, 0)
var (
best interface{}
bestPref bool
bestScheme string
)
for availableScheme, items := range r.available {
// quic:// should be considered ok for both quic4:// and quic6://
if strings.HasPrefix(scheme, availableScheme) {
candidates = append(candidates, items...)
if !strings.HasPrefix(scheme, availableScheme) {
continue
}
for _, item := range items {
better := best == nil
pref := preferred(item)
if !better {
// In case of a tie, prefer "quic" to "quic[46]" etc.
better = pref &&
(!bestPref || len(availableScheme) < len(bestScheme))
}
if !better {
continue
}
best, bestPref, bestScheme = item, pref, availableScheme
}
}
if len(candidates) == 0 {
return nil
}
sort.Slice(candidates, func(i, j int) bool {
return less(candidates[i], candidates[j])
})
return candidates[0]
return best
}
func Register(scheme string, item interface{}) {
@ -84,6 +94,6 @@ func Unregister(scheme string, item interface{}) {
Default.Unregister(scheme, item)
}
func Get(scheme string, less func(i, j interface{}) bool) interface{} {
return Default.Get(scheme, less)
func Get(scheme string, preferred func(interface{}) bool) interface{} {
return Default.Get(scheme, preferred)
}

View File

@ -7,13 +7,18 @@
package registry
import (
"net"
"testing"
)
func TestRegistry(t *testing.T) {
r := New()
if res := r.Get("int", intLess); res != nil {
want := func(i int) func(interface{}) bool {
return func(x interface{}) bool { return x.(int) == i }
}
if res := r.Get("int", want(1)); res != nil {
t.Error("unexpected")
}
@ -24,30 +29,28 @@ func TestRegistry(t *testing.T) {
r.Register("int6", 6)
r.Register("int6", 66)
if res := r.Get("int", intLess).(int); res != 1 {
if res := r.Get("int", want(1)).(int); res != 1 {
t.Error("unexpected", res)
}
// int is prefix of int4, so returns 1
if res := r.Get("int4", intLess).(int); res != 1 {
if res := r.Get("int4", want(1)).(int); res != 1 {
t.Error("unexpected", res)
}
r.Unregister("int", 1)
// Check that falls through to 11
if res := r.Get("int", intLess).(int); res != 11 {
if res := r.Get("int", want(1)).(int); res == 1 {
t.Error("unexpected", res)
}
// 6 is smaller than 11 available in int.
if res := r.Get("int6", intLess).(int); res != 6 {
if res := r.Get("int6", want(6)).(int); res != 6 {
t.Error("unexpected", res)
}
// Unregister 11, int should be impossible to find
r.Unregister("int", 11)
if res := r.Get("int", intLess); res != nil {
if res := r.Get("int", want(11)); res != nil {
t.Error("unexpected")
}
@ -59,13 +62,35 @@ func TestRegistry(t *testing.T) {
r.Register("int", 1)
r.Unregister("int", 1)
if res := r.Get("int4", intLess).(int); res != 1 {
if res := r.Get("int4", want(1)).(int); res != 1 {
t.Error("unexpected", res)
}
}
func intLess(i, j interface{}) bool {
iInt := i.(int)
jInt := j.(int)
return iInt < jInt
func TestShortSchemeFirst(t *testing.T) {
r := New()
r.Register("foo", 0)
r.Register("foobar", 1)
// If we don't care about the value, we should get the one with "foo".
res := r.Get("foo", func(interface{}) bool { return false })
if res != 0 {
t.Error("unexpected", res)
}
}
func BenchmarkGet(b *testing.B) {
r := New()
for _, addr := range []string{"192.168.1.1", "172.1.1.1", "10.1.1.1"} {
r.Register("tcp", &net.TCPAddr{IP: net.ParseIP(addr)})
}
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
r.Get("tcp", func(x interface{}) bool {
return x.(*net.TCPAddr).IP.IsUnspecified()
})
}
}

View File

@ -13,7 +13,6 @@ import (
"os"
"time"
"github.com/syncthing/syncthing/lib/util"
"golang.org/x/net/proxy"
)
@ -61,11 +60,6 @@ func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error)
return proxy.SOCKS5("tcp", u.Host, auth, forward)
}
// Sort available addresses, preferring unspecified address.
func tcpAddrLess(i interface{}, j interface{}) bool {
return util.AddressUnspecifiedLess(i.(*net.TCPAddr), j.(*net.TCPAddr))
}
// dialerConn is needed because proxy dialed connections have RemoteAddr() pointing at the proxy,
// which then screws up various things such as IsLAN checks, and "let's populate the relay invitation address from
// existing connection" shenanigans.

View File

@ -110,7 +110,9 @@ func DialContextReusePort(ctx context.Context, network, addr string) (net.Conn,
return DialContext(ctx, network, addr)
}
localAddrInterface := registry.Get(network, tcpAddrLess)
localAddrInterface := registry.Get(network, func(addr interface{}) bool {
return addr.(*net.TCPAddr).IP.IsUnspecified()
})
if localAddrInterface == nil {
// Nothing listening, nothing to reuse.
return DialContext(ctx, network, addr)

View File

@ -9,7 +9,6 @@ package util
import (
"context"
"fmt"
"net"
"net/url"
"reflect"
"strconv"
@ -231,25 +230,6 @@ func Address(network, host string) string {
return u.String()
}
// AddressUnspecifiedLess is a comparator function preferring least specific network address (most widely listening,
// namely preferring 0.0.0.0 over some IP), if both IPs are equal, it prefers the less restrictive network (prefers tcp
// over tcp4)
func AddressUnspecifiedLess(a, b net.Addr) bool {
aIsUnspecified := false
bIsUnspecified := false
if host, _, err := net.SplitHostPort(a.String()); err == nil {
aIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
}
if host, _, err := net.SplitHostPort(b.String()); err == nil {
bIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
}
if aIsUnspecified == bIsUnspecified {
return len(a.Network()) < len(b.Network())
}
return aIsUnspecified
}
func CallWithContext(ctx context.Context, fn func() error) error {
var err error
done := make(chan struct{})

View File

@ -225,50 +225,6 @@ func TestCopyMatching(t *testing.T) {
}
}
type mockedAddr struct {
network string
addr string
}
func (a mockedAddr) Network() string {
return a.network
}
func (a mockedAddr) String() string {
return a.addr
}
func TestInspecifiedAddressLess(t *testing.T) {
cases := []struct {
netA string
addrA string
netB string
addrB string
}{
// B is assumed the winner.
{"tcp", "127.0.0.1:1234", "tcp", ":1235"},
{"tcp", "127.0.0.1:1234", "tcp", "0.0.0.0:1235"},
{"tcp4", "0.0.0.0:1234", "tcp", "0.0.0.0:1235"}, // tcp4 on the first one
}
for i, testCase := range cases {
addrs := []mockedAddr{
{testCase.netA, testCase.addrA},
{testCase.netB, testCase.addrB},
}
if AddressUnspecifiedLess(addrs[0], addrs[1]) {
t.Error(i, "unexpected")
}
if !AddressUnspecifiedLess(addrs[1], addrs[0]) {
t.Error(i, "unexpected")
}
if AddressUnspecifiedLess(addrs[0], addrs[0]) || AddressUnspecifiedLess(addrs[1], addrs[1]) {
t.Error(i, "unexpected")
}
}
}
func TestFillNil(t *testing.T) {
type A struct {
Slice []int