all: Refactor relay invitations (#7646)

This commit is contained in:
Simon Frei 2021-05-10 22:25:43 +02:00 committed by GitHub
parent 6e662dc9fc
commit 713527facf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 38 additions and 44 deletions

View File

@ -57,7 +57,7 @@ func main() {
if join {
log.Println("Creating client")
relay, err := client.NewClient(uri, []tls.Certificate{cert}, nil, 10*time.Second)
relay, err := client.NewClient(uri, []tls.Certificate{cert}, 10*time.Second)
if err != nil {
log.Fatal(err)
}

View File

@ -44,38 +44,39 @@ type relayListener struct {
}
func (t *relayListener) serve(ctx context.Context) error {
clnt, err := client.NewClient(t.uri, t.tlsCfg.Certificates, nil, 10*time.Second)
clnt, err := client.NewClient(t.uri, t.tlsCfg.Certificates, 10*time.Second)
if err != nil {
l.Infoln("Listen (BEP/relay):", err)
return err
}
invitations := clnt.Invitations()
t.mut.Lock()
t.client = clnt
go clnt.Serve(ctx)
t.mut.Unlock()
// Start with nil, so that we send a addresses changed notification as soon as we connect somewhere.
var oldURI *url.URL
l.Infof("Relay listener (%v) starting", t)
defer l.Infof("Relay listener (%v) shutting down", t)
defer t.clearAddresses(t)
invitationCtx, cancel := context.WithCancel(ctx)
defer cancel()
go t.handleInvitations(invitationCtx, clnt)
return clnt.Serve(ctx)
}
func (t *relayListener) handleInvitations(ctx context.Context, clnt client.RelayClient) {
invitations := clnt.Invitations()
// Start with nil, so that we send a addresses changed notification as soon as we connect somewhere.
var oldURI *url.URL
for {
select {
case inv, ok := <-invitations:
if !ok {
if err := clnt.Error(); err != nil {
l.Infoln("Listen (BEP/relay):", err)
}
return err
}
case inv := <-invitations:
conn, err := client.JoinSession(ctx, inv)
if err != nil {
if errors.Cause(err) != context.Canceled {
if !errors.Is(err, context.Canceled) {
l.Infoln("Listen (BEP/relay): joining session:", err)
}
continue
@ -119,7 +120,7 @@ func (t *relayListener) serve(ctx context.Context) error {
}
case <-ctx.Done():
return ctx.Err()
return
}
}
}

View File

@ -35,21 +35,21 @@ type RelayClient interface {
URI() *url.URL
}
func NewClient(uri *url.URL, certs []tls.Certificate, invitations chan protocol.SessionInvitation, timeout time.Duration) (RelayClient, error) {
func NewClient(uri *url.URL, certs []tls.Certificate, timeout time.Duration) (RelayClient, error) {
factory, ok := supportedSchemes[uri.Scheme]
if !ok {
return nil, fmt.Errorf("unsupported scheme: %s", uri.Scheme)
}
invitations := make(chan protocol.SessionInvitation)
return factory(uri, certs, invitations, timeout), nil
}
type commonClient struct {
svcutil.ServiceWithError
invitations chan protocol.SessionInvitation
closeInvitationsOnFinish bool
mut sync.RWMutex
invitations chan protocol.SessionInvitation
mut sync.RWMutex
}
func newCommonClient(invitations chan protocol.SessionInvitation, serve func(context.Context) error, creator string) commonClient {
@ -57,26 +57,10 @@ func newCommonClient(invitations chan protocol.SessionInvitation, serve func(con
invitations: invitations,
mut: sync.NewRWMutex(),
}
newServe := func(ctx context.Context) error {
defer c.cleanup()
return serve(ctx)
}
c.ServiceWithError = svcutil.AsService(newServe, creator)
if c.invitations == nil {
c.closeInvitationsOnFinish = true
c.invitations = make(chan protocol.SessionInvitation)
}
c.ServiceWithError = svcutil.AsService(serve, creator)
return c
}
func (c *commonClient) cleanup() {
c.mut.Lock()
if c.closeInvitationsOnFinish {
close(c.invitations)
}
c.mut.Unlock()
}
func (c *commonClient) Invitations() chan protocol.SessionInvitation {
c.mut.RLock()
defer c.mut.RUnlock()

View File

@ -114,16 +114,20 @@ func JoinSession(ctx context.Context, invitation protocol.SessionInvitation) (ne
func TestRelay(ctx context.Context, uri *url.URL, certs []tls.Certificate, sleep, timeout time.Duration, times int) error {
id := syncthingprotocol.NewDeviceID(certs[0].Certificate[0])
invs := make(chan protocol.SessionInvitation, 1)
c, err := NewClient(uri, certs, invs, timeout)
c, err := NewClient(uri, certs, timeout)
if err != nil {
close(invs)
return fmt.Errorf("creating client: %w", err)
}
ctx, cancel := context.WithCancel(context.Background())
go c.Serve(ctx)
go func() {
c.Serve(ctx)
close(invs)
for {
select {
case <-c.Invitations():
case <-ctx.Done():
return
}
}
}()
defer cancel()

View File

@ -98,7 +98,12 @@ func (c *staticClient) serve(ctx context.Context) error {
if len(ip) == 0 || ip.IsUnspecified() {
msg.Address = remoteIPBytes(c.conn)
}
c.invitations <- msg
select {
case c.invitations <- msg:
case <-ctx.Done():
l.Debugln(c, "stopping")
return ctx.Err()
}
case protocol.RelayFull:
l.Infof("Disconnected from relay %s due to it becoming full.", c.uri)