diff --git a/cmd/strelaysrv/testutil/main.go b/cmd/strelaysrv/testutil/main.go index fdbb971f9..15969ef4c 100644 --- a/cmd/strelaysrv/testutil/main.go +++ b/cmd/strelaysrv/testutil/main.go @@ -57,7 +57,7 @@ func main() { if join { log.Println("Creating client") - relay, err := client.NewClient(uri, []tls.Certificate{cert}, nil, 10*time.Second) + relay, err := client.NewClient(uri, []tls.Certificate{cert}, 10*time.Second) if err != nil { log.Fatal(err) } diff --git a/lib/connections/relay_listen.go b/lib/connections/relay_listen.go index efce199d8..4b93d2c92 100644 --- a/lib/connections/relay_listen.go +++ b/lib/connections/relay_listen.go @@ -44,38 +44,39 @@ type relayListener struct { } func (t *relayListener) serve(ctx context.Context) error { - clnt, err := client.NewClient(t.uri, t.tlsCfg.Certificates, nil, 10*time.Second) + clnt, err := client.NewClient(t.uri, t.tlsCfg.Certificates, 10*time.Second) if err != nil { l.Infoln("Listen (BEP/relay):", err) return err } - invitations := clnt.Invitations() t.mut.Lock() t.client = clnt - go clnt.Serve(ctx) t.mut.Unlock() - // Start with nil, so that we send a addresses changed notification as soon as we connect somewhere. - var oldURI *url.URL - l.Infof("Relay listener (%v) starting", t) defer l.Infof("Relay listener (%v) shutting down", t) defer t.clearAddresses(t) + invitationCtx, cancel := context.WithCancel(ctx) + defer cancel() + go t.handleInvitations(invitationCtx, clnt) + + return clnt.Serve(ctx) +} + +func (t *relayListener) handleInvitations(ctx context.Context, clnt client.RelayClient) { + invitations := clnt.Invitations() + + // Start with nil, so that we send a addresses changed notification as soon as we connect somewhere. + var oldURI *url.URL + for { select { - case inv, ok := <-invitations: - if !ok { - if err := clnt.Error(); err != nil { - l.Infoln("Listen (BEP/relay):", err) - } - return err - } - + case inv := <-invitations: conn, err := client.JoinSession(ctx, inv) if err != nil { - if errors.Cause(err) != context.Canceled { + if !errors.Is(err, context.Canceled) { l.Infoln("Listen (BEP/relay): joining session:", err) } continue @@ -119,7 +120,7 @@ func (t *relayListener) serve(ctx context.Context) error { } case <-ctx.Done(): - return ctx.Err() + return } } } diff --git a/lib/relay/client/client.go b/lib/relay/client/client.go index 50d38a962..f92a4cd95 100644 --- a/lib/relay/client/client.go +++ b/lib/relay/client/client.go @@ -35,21 +35,21 @@ type RelayClient interface { URI() *url.URL } -func NewClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation, timeout time.Duration) (RelayClient, error) { +func NewClient(uri *url.URL, certs []tls.Certificate, timeout time.Duration) (RelayClient, error) { factory, ok := supportedSchemes[uri.Scheme] if !ok { return nil, fmt.Errorf("unsupported scheme: %s", uri.Scheme) } + invitations := make(chan protocol.SessionInvitation) return factory(uri, certs, invitations, timeout), nil } type commonClient struct { svcutil.ServiceWithError - invitations chan protocol.SessionInvitation - closeInvitationsOnFinish bool - mut sync.RWMutex + invitations chan protocol.SessionInvitation + mut sync.RWMutex } func newCommonClient(invitations chan protocol.SessionInvitation, serve func(context.Context) error, creator string) commonClient { @@ -57,26 +57,10 @@ func newCommonClient(invitations chan protocol.SessionInvitation, serve func(con invitations: invitations, mut: sync.NewRWMutex(), } - newServe := func(ctx context.Context) error { - defer c.cleanup() - return serve(ctx) - } - c.ServiceWithError = svcutil.AsService(newServe, creator) - if c.invitations == nil { - c.closeInvitationsOnFinish = true - c.invitations = make(chan protocol.SessionInvitation) - } + c.ServiceWithError = svcutil.AsService(serve, creator) return c } -func (c *commonClient) cleanup() { - c.mut.Lock() - if c.closeInvitationsOnFinish { - close(c.invitations) - } - c.mut.Unlock() -} - func (c *commonClient) Invitations() chan protocol.SessionInvitation { c.mut.RLock() defer c.mut.RUnlock() diff --git a/lib/relay/client/methods.go b/lib/relay/client/methods.go index 297e5e221..cbc8b0f32 100644 --- a/lib/relay/client/methods.go +++ b/lib/relay/client/methods.go @@ -114,16 +114,20 @@ func JoinSession(ctx context.Context, invitation protocol.SessionInvitation) (ne func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) error { id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0]) - invs := make(chan protocol.SessionInvitation, 1) - c, err := NewClient(uri, certs, invs, timeout) + c, err := NewClient(uri, certs, timeout) if err != nil { - close(invs) return fmt.Errorf("creating client: %w", err) } ctx, cancel := context.WithCancel(context.Background()) + go c.Serve(ctx) go func() { - c.Serve(ctx) - close(invs) + for { + select { + case <-c.Invitations(): + case <-ctx.Done(): + return + } + } }() defer cancel() diff --git a/lib/relay/client/static.go b/lib/relay/client/static.go index a26eb03b4..4e6ec94a8 100644 --- a/lib/relay/client/static.go +++ b/lib/relay/client/static.go @@ -98,7 +98,12 @@ func (c *staticClient) serve(ctx context.Context) error { if len(ip) == 0 || ip.IsUnspecified() { msg.Address = remoteIPBytes(c.conn) } - c.invitations <- msg + select { + case c.invitations <- msg: + case <-ctx.Done(): + l.Debugln(c, "stopping") + return ctx.Err() + } case protocol.RelayFull: l.Infof("Disconnected from relay %s due to it becoming full.", c.uri)