lib: Ensure timely service termination (fixes #5860) (#5863)

This commit is contained in:
Simon Frei 2019-07-19 19:40:40 +02:00 committed by Jakob Borg
parent 1cb55904bc
commit 4d3432af3e
8 changed files with 174 additions and 150 deletions

View File

@ -8,7 +8,6 @@ package beacon
import (
"net"
stdsync "sync"
"github.com/thejerf/suture"
)
@ -24,21 +23,3 @@ type Interface interface {
Recv() ([]byte, net.Addr)
Error() error
}
type errorHolder struct {
err error
mut stdsync.Mutex // uses stdlib sync as I want this to be trivially embeddable, and there is no risk of blocking
}
func (e *errorHolder) setError(err error) {
e.mut.Lock()
e.err = err
e.mut.Unlock()
}
func (e *errorHolder) Error() error {
e.mut.Lock()
err := e.err
e.mut.Unlock()
return err
}

View File

@ -11,8 +11,9 @@ import (
"net"
"time"
"github.com/syncthing/syncthing/lib/sync"
"github.com/thejerf/suture"
"github.com/syncthing/syncthing/lib/util"
)
type Broadcast struct {
@ -44,16 +45,16 @@ func NewBroadcast(port int) *Broadcast {
}
b.br = &broadcastReader{
port: port,
outbox: b.outbox,
connMut: sync.NewMutex(),
port: port,
outbox: b.outbox,
}
b.br.ServiceWithError = util.AsServiceWithError(b.br.serve)
b.Add(b.br)
b.bw = &broadcastWriter{
port: port,
inbox: b.inbox,
connMut: sync.NewMutex(),
port: port,
inbox: b.inbox,
}
b.bw.ServiceWithError = util.AsServiceWithError(b.bw.serve)
b.Add(b.bw)
return b
@ -76,34 +77,42 @@ func (b *Broadcast) Error() error {
}
type broadcastWriter struct {
port int
inbox chan []byte
conn *net.UDPConn
connMut sync.Mutex
errorHolder
util.ServiceWithError
port int
inbox chan []byte
}
func (w *broadcastWriter) Serve() {
func (w *broadcastWriter) serve(stop chan struct{}) error {
l.Debugln(w, "starting")
defer l.Debugln(w, "stopping")
conn, err := net.ListenUDP("udp4", nil)
if err != nil {
l.Debugln(err)
w.setError(err)
return
return err
}
defer conn.Close()
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-stop:
case <-done:
}
conn.Close()
}()
w.connMut.Lock()
w.conn = conn
w.connMut.Unlock()
for {
var bs []byte
select {
case bs = <-w.inbox:
case <-stop:
return nil
}
for bs := range w.inbox {
addrs, err := net.InterfaceAddrs()
if err != nil {
l.Debugln(err)
w.setError(err)
w.SetError(err)
continue
}
@ -134,14 +143,13 @@ func (w *broadcastWriter) Serve() {
// Write timeouts should not happen. We treat it as a fatal
// error on the socket.
l.Debugln(err)
w.setError(err)
return
return err
}
if err != nil {
// Some other error that we don't expect. Debug and continue.
l.Debugln(err)
w.setError(err)
w.SetError(err)
continue
}
@ -150,57 +158,49 @@ func (w *broadcastWriter) Serve() {
}
if success > 0 {
w.setError(nil)
w.SetError(nil)
}
}
}
func (w *broadcastWriter) Stop() {
w.connMut.Lock()
if w.conn != nil {
w.conn.Close()
}
w.connMut.Unlock()
}
func (w *broadcastWriter) String() string {
return fmt.Sprintf("broadcastWriter@%p", w)
}
type broadcastReader struct {
port int
outbox chan recv
conn *net.UDPConn
connMut sync.Mutex
errorHolder
util.ServiceWithError
port int
outbox chan recv
}
func (r *broadcastReader) Serve() {
func (r *broadcastReader) serve(stop chan struct{}) error {
l.Debugln(r, "starting")
defer l.Debugln(r, "stopping")
conn, err := net.ListenUDP("udp4", &net.UDPAddr{Port: r.port})
if err != nil {
l.Debugln(err)
r.setError(err)
return
return err
}
defer conn.Close()
r.connMut.Lock()
r.conn = conn
r.connMut.Unlock()
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-stop:
case <-done:
}
conn.Close()
}()
bs := make([]byte, 65536)
for {
n, addr, err := conn.ReadFrom(bs)
if err != nil {
l.Debugln(err)
r.setError(err)
return
return err
}
r.setError(nil)
r.SetError(nil)
l.Debugf("recv %d bytes from %s", n, addr)
@ -208,19 +208,12 @@ func (r *broadcastReader) Serve() {
copy(c, bs)
select {
case r.outbox <- recv{c, addr}:
case <-stop:
return nil
default:
l.Debugln("dropping message")
}
}
}
func (r *broadcastReader) Stop() {
r.connMut.Lock()
if r.conn != nil {
r.conn.Close()
}
r.connMut.Unlock()
}
func (r *broadcastReader) String() string {

View File

@ -48,14 +48,14 @@ func NewMulticast(addr string) *Multicast {
addr: addr,
outbox: m.outbox,
}
m.mr.Service = util.AsService(m.mr.serve)
m.mr.ServiceWithError = util.AsServiceWithError(m.mr.serve)
m.Add(m.mr)
m.mw = &multicastWriter{
addr: addr,
inbox: m.inbox,
}
m.mw.Service = util.AsService(m.mw.serve)
m.mw.ServiceWithError = util.AsServiceWithError(m.mw.serve)
m.Add(m.mw)
return m
@ -78,29 +78,35 @@ func (m *Multicast) Error() error {
}
type multicastWriter struct {
suture.Service
util.ServiceWithError
addr string
inbox <-chan []byte
errorHolder
}
func (w *multicastWriter) serve(stop chan struct{}) {
func (w *multicastWriter) serve(stop chan struct{}) error {
l.Debugln(w, "starting")
defer l.Debugln(w, "stopping")
gaddr, err := net.ResolveUDPAddr("udp6", w.addr)
if err != nil {
l.Debugln(err)
w.setError(err)
return
return err
}
conn, err := net.ListenPacket("udp6", ":0")
if err != nil {
l.Debugln(err)
w.setError(err)
return
return err
}
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-stop:
case <-done:
}
conn.Close()
}()
pconn := ipv6.NewPacketConn(conn)
@ -113,14 +119,13 @@ func (w *multicastWriter) serve(stop chan struct{}) {
select {
case bs = <-w.inbox:
case <-stop:
return
return nil
}
intfs, err := net.Interfaces()
if err != nil {
l.Debugln(err)
w.setError(err)
return
return err
}
success := 0
@ -132,7 +137,7 @@ func (w *multicastWriter) serve(stop chan struct{}) {
if err != nil {
l.Debugln(err, "on write to", gaddr, intf.Name)
w.setError(err)
w.SetError(err)
continue
}
@ -142,16 +147,13 @@ func (w *multicastWriter) serve(stop chan struct{}) {
select {
case <-stop:
return
return nil
default:
}
}
if success > 0 {
w.setError(nil)
} else {
l.Debugln(err)
w.setError(err)
w.SetError(nil)
}
}
}
@ -161,35 +163,40 @@ func (w *multicastWriter) String() string {
}
type multicastReader struct {
suture.Service
util.ServiceWithError
addr string
outbox chan<- recv
errorHolder
}
func (r *multicastReader) serve(stop chan struct{}) {
func (r *multicastReader) serve(stop chan struct{}) error {
l.Debugln(r, "starting")
defer l.Debugln(r, "stopping")
gaddr, err := net.ResolveUDPAddr("udp6", r.addr)
if err != nil {
l.Debugln(err)
r.setError(err)
return
return err
}
conn, err := net.ListenPacket("udp6", r.addr)
if err != nil {
l.Debugln(err)
r.setError(err)
return
return err
}
done := make(chan struct{})
defer close(done)
go func() {
select {
case <-stop:
case <-done:
}
conn.Close()
}()
intfs, err := net.Interfaces()
if err != nil {
l.Debugln(err)
r.setError(err)
return
return err
}
pconn := ipv6.NewPacketConn(conn)
@ -206,16 +213,20 @@ func (r *multicastReader) serve(stop chan struct{}) {
if joined == 0 {
l.Debugln("no multicast interfaces available")
r.setError(errors.New("no multicast interfaces available"))
return
return errors.New("no multicast interfaces available")
}
bs := make([]byte, 65536)
for {
select {
case <-stop:
return nil
default:
}
n, _, addr, err := pconn.ReadFrom(bs)
if err != nil {
l.Debugln(err)
r.setError(err)
r.SetError(err)
continue
}
l.Debugf("recv %d bytes from %s", n, addr)
@ -224,8 +235,6 @@ func (r *multicastReader) serve(stop chan struct{}) {
copy(c, bs)
select {
case r.outbox <- recv{c, addr}:
case <-stop:
return
default:
l.Debugln("dropping message")
}

View File

@ -19,7 +19,7 @@ func Register(provider DiscoverFunc) {
providers = append(providers, provider)
}
func discoverAll(renewal, timeout time.Duration) map[string]Device {
func discoverAll(renewal, timeout time.Duration, stop chan struct{}) map[string]Device {
wg := &sync.WaitGroup{}
wg.Add(len(providers))
@ -28,20 +28,32 @@ func discoverAll(renewal, timeout time.Duration) map[string]Device {
for _, discoverFunc := range providers {
go func(f DiscoverFunc) {
defer wg.Done()
for _, dev := range f(renewal, timeout) {
c <- dev
select {
case c <- dev:
case <-stop:
return
}
}
wg.Done()
}(discoverFunc)
}
nats := make(map[string]Device)
go func() {
for dev := range c {
nats[dev.ID()] = dev
defer close(done)
for {
select {
case dev, ok := <-c:
if !ok {
return
}
nats[dev.ID()] = dev
case <-stop:
return
}
}
close(done)
}()
wg.Wait()

View File

@ -14,17 +14,21 @@ import (
stdsync "sync"
"time"
"github.com/thejerf/suture"
"github.com/syncthing/syncthing/lib/config"
"github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/sync"
"github.com/syncthing/syncthing/lib/util"
)
// Service runs a loop for discovery of IGDs (Internet Gateway Devices) and
// setup/renewal of a port mapping.
type Service struct {
id protocol.DeviceID
cfg config.Wrapper
stop chan struct{}
suture.Service
id protocol.DeviceID
cfg config.Wrapper
mappings []*Mapping
timer *time.Timer
@ -32,27 +36,28 @@ type Service struct {
}
func NewService(id protocol.DeviceID, cfg config.Wrapper) *Service {
return &Service{
s := &Service{
id: id,
cfg: cfg,
timer: time.NewTimer(0),
mut: sync.NewRWMutex(),
}
s.Service = util.AsService(s.serve)
return s
}
func (s *Service) Serve() {
func (s *Service) serve(stop chan struct{}) {
announce := stdsync.Once{}
s.mut.Lock()
s.timer.Reset(0)
s.stop = make(chan struct{})
s.mut.Unlock()
for {
select {
case <-s.timer.C:
if found := s.process(); found != -1 {
if found := s.process(stop); found != -1 {
announce.Do(func() {
suffix := "s"
if found == 1 {
@ -61,7 +66,7 @@ func (s *Service) Serve() {
l.Infoln("Detected", found, "NAT service"+suffix)
})
}
case <-s.stop:
case <-stop:
s.timer.Stop()
s.mut.RLock()
for _, mapping := range s.mappings {
@ -73,7 +78,7 @@ func (s *Service) Serve() {
}
}
func (s *Service) process() int {
func (s *Service) process(stop chan struct{}) int {
// toRenew are mappings which are due for renewal
// toUpdate are the remaining mappings, which will only be updated if one of
// the old IGDs has gone away, or a new IGD has appeared, but only if we
@ -115,25 +120,19 @@ func (s *Service) process() int {
return -1
}
nats := discoverAll(time.Duration(s.cfg.Options().NATRenewalM)*time.Minute, time.Duration(s.cfg.Options().NATTimeoutS)*time.Second)
nats := discoverAll(time.Duration(s.cfg.Options().NATRenewalM)*time.Minute, time.Duration(s.cfg.Options().NATTimeoutS)*time.Second, stop)
for _, mapping := range toRenew {
s.updateMapping(mapping, nats, true)
s.updateMapping(mapping, nats, true, stop)
}
for _, mapping := range toUpdate {
s.updateMapping(mapping, nats, false)
s.updateMapping(mapping, nats, false, stop)
}
return len(nats)
}
func (s *Service) Stop() {
s.mut.RLock()
close(s.stop)
s.mut.RUnlock()
}
func (s *Service) NewMapping(protocol Protocol, ip net.IP, port int) *Mapping {
mapping := &Mapping{
protocol: protocol,
@ -178,17 +177,17 @@ func (s *Service) RemoveMapping(mapping *Mapping) {
// acquire mappings for natds which the mapping was unaware of before.
// Optionally takes renew flag which indicates whether or not we should renew
// mappings with existing natds
func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew bool) {
func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew bool, stop chan struct{}) {
var added, removed []Address
renewalTime := time.Duration(s.cfg.Options().NATRenewalM) * time.Minute
mapping.expires = time.Now().Add(renewalTime)
newAdded, newRemoved := s.verifyExistingMappings(mapping, nats, renew)
newAdded, newRemoved := s.verifyExistingMappings(mapping, nats, renew, stop)
added = append(added, newAdded...)
removed = append(removed, newRemoved...)
newAdded, newRemoved = s.acquireNewMappings(mapping, nats)
newAdded, newRemoved = s.acquireNewMappings(mapping, nats, stop)
added = append(added, newAdded...)
removed = append(removed, newRemoved...)
@ -197,12 +196,18 @@ func (s *Service) updateMapping(mapping *Mapping, nats map[string]Device, renew
}
}
func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Device, renew bool) ([]Address, []Address) {
func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Device, renew bool, stop chan struct{}) ([]Address, []Address) {
var added, removed []Address
leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute
for id, address := range mapping.addressMap() {
select {
case <-stop:
return nil, nil
default:
}
// Delete addresses for NATDevice's that do not exist anymore
nat, ok := nats[id]
if !ok {
@ -242,13 +247,19 @@ func (s *Service) verifyExistingMappings(mapping *Mapping, nats map[string]Devic
return added, removed
}
func (s *Service) acquireNewMappings(mapping *Mapping, nats map[string]Device) ([]Address, []Address) {
func (s *Service) acquireNewMappings(mapping *Mapping, nats map[string]Device, stop chan struct{}) ([]Address, []Address) {
var added, removed []Address
leaseTime := time.Duration(s.cfg.Options().NATLeaseM) * time.Minute
addrMap := mapping.addressMap()
for id, nat := range nats {
select {
case <-stop:
return nil, nil
default:
}
if _, ok := addrMap[id]; ok {
continue
}

View File

@ -69,15 +69,7 @@ func (c *dynamicClient) serve(stop chan struct{}) error {
addrs = append(addrs, ruri.String())
}
defer func() {
c.mut.RLock()
if c.client != nil {
c.client.Stop()
}
c.mut.RUnlock()
}()
for _, addr := range relayAddressesOrder(addrs) {
for _, addr := range relayAddressesOrder(addrs, stop) {
select {
case <-stop:
l.Debugln(c, "stopping")
@ -104,6 +96,15 @@ func (c *dynamicClient) serve(stop chan struct{}) error {
return fmt.Errorf("could not find a connectable relay")
}
func (c *dynamicClient) Stop() {
c.mut.RLock()
if c.client != nil {
c.client.Stop()
}
c.mut.RUnlock()
c.commonClient.Stop()
}
func (c *dynamicClient) Error() error {
c.mut.RLock()
defer c.mut.RUnlock()
@ -147,7 +148,7 @@ type dynamicAnnouncement struct {
// the closest 50ms, and puts them in buckets of 50ms latency ranges. Then
// shuffles each bucket, and returns all addresses starting with the ones from
// the lowest latency bucket, ending with the highest latency buceket.
func relayAddressesOrder(input []string) []string {
func relayAddressesOrder(input []string, stop chan struct{}) []string {
buckets := make(map[int][]string)
for _, relay := range input {
@ -159,6 +160,12 @@ func relayAddressesOrder(input []string) []string {
id := int(latency/time.Millisecond) / 50
buckets[id] = append(buckets[id], relay)
select {
case <-stop:
return nil
default:
}
}
var ids []int

View File

@ -109,8 +109,8 @@ func New(cfg config.Wrapper, subscriber Subscriber, conn net.PacketConn) (*Servi
}
func (s *Service) Stop() {
s.Service.Stop()
_ = s.stunConn.Close()
s.Service.Stop()
}
func (s *Service) serve(stop chan struct{}) {
@ -163,7 +163,11 @@ func (s *Service) serve(stop chan struct{}) {
// We failed to contact all provided stun servers or the nat is not punchable.
// Chillout for a while.
time.Sleep(stunRetryInterval)
select {
case <-time.After(stunRetryInterval):
case <-stop:
return
}
}
}

View File

@ -187,6 +187,7 @@ func AsService(fn func(stop chan struct{})) suture.Service {
type ServiceWithError interface {
suture.Service
Error() error
SetError(error)
}
// AsServiceWithError does the same as AsService, except that it keeps track
@ -244,3 +245,9 @@ func (s *service) Error() error {
defer s.mut.Unlock()
return s.err
}
func (s *service) SetError(err error) {
s.mut.Lock()
s.err = err
s.mut.Unlock()
}