From c78fa42f3199b4320085bc53e08b06ec8a0c9ea5 Mon Sep 17 00:00:00 2001 From: Simon Frei Date: Fri, 25 Jun 2021 11:38:04 +0200 Subject: [PATCH] lib/connections: Dial devices in parallel (#7783) --- lib/connections/service.go | 51 ++++-- lib/model/bytesemaphore.go | 109 ------------- lib/model/folder.go | 16 +- lib/model/folder_recvenc.go | 3 +- lib/model/folder_recvonly.go | 3 +- lib/model/folder_sendonly.go | 3 +- lib/model/folder_sendrecv.go | 17 +- lib/model/model.go | 39 ++--- lib/model/model_test.go | 11 +- lib/util/semaphore.go | 148 ++++++++++++++++++ .../semaphore_test.go} | 46 +++--- 11 files changed, 258 insertions(+), 188 deletions(-) delete mode 100644 lib/model/bytesemaphore.go create mode 100644 lib/util/semaphore.go rename lib/{model/bytesemaphore_test.go => util/semaphore_test.go} (82%) diff --git a/lib/connections/service.go b/lib/connections/service.go index 4734e598a..8613626ce 100644 --- a/lib/connections/service.go +++ b/lib/connections/service.go @@ -66,6 +66,8 @@ const ( worstDialerPriority = math.MaxInt32 recentlySeenCutoff = 7 * 24 * time.Hour shortLivedConnectionThreshold = 5 * time.Second + dialMaxParallel = 64 + dialMaxParallelPerDevice = 8 ) // From go/src/crypto/tls/cipher_suites.go @@ -490,14 +492,40 @@ func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Con // Perform dials according to the queue, stopping when we've reached the // allowed additional number of connections (if limited). numConns := 0 - for _, entry := range queue { - if conn, ok := s.dialParallel(ctx, entry.id, entry.targets); ok { - s.conns <- conn - numConns++ - if allowAdditional > 0 && numConns >= allowAdditional { - break - } + var numConnsMut stdsync.Mutex + dialSemaphore := util.NewSemaphore(dialMaxParallel) + dialWG := new(stdsync.WaitGroup) + dialCtx, dialCancel := context.WithCancel(ctx) + defer func() { + dialWG.Wait() + dialCancel() + }() + for i := range queue { + select { + case <-dialCtx.Done(): + return + default: } + dialWG.Add(1) + go func(entry dialQueueEntry) { + defer dialWG.Done() + conn, ok := s.dialParallel(dialCtx, entry.id, entry.targets, dialSemaphore) + if !ok { + return + } + numConnsMut.Lock() + if allowAdditional == 0 || numConns < allowAdditional { + select { + case s.conns <- conn: + numConns++ + if allowAdditional > 0 && numConns >= allowAdditional { + dialCancel() + } + case <-dialCtx.Done(): + } + } + numConnsMut.Unlock() + }(queue[i]) } } @@ -959,7 +987,7 @@ func IsAllowedNetwork(host string, allowed []string) bool { return false } -func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) { +func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget, parentSema *util.Semaphore) (internalConn, bool) { // Group targets into buckets by priority dialTargetBuckets := make(map[int][]dialTarget, len(dialTargets)) for _, tgt := range dialTargets { @@ -975,13 +1003,19 @@ func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, // Sort the priorities so that we dial lowest first (which means highest...) sort.Ints(priorities) + sema := util.MultiSemaphore{util.NewSemaphore(dialMaxParallelPerDevice), parentSema} for _, prio := range priorities { tgts := dialTargetBuckets[prio] res := make(chan internalConn, len(tgts)) wg := stdsync.WaitGroup{} for _, tgt := range tgts { + sema.Take(1) wg.Add(1) go func(tgt dialTarget) { + defer func() { + wg.Done() + sema.Give(1) + }() conn, err := tgt.Dial(ctx) if err == nil { // Closes the connection on error @@ -994,7 +1028,6 @@ func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, l.Debugln("dialing", deviceID, tgt.uri, "success:", conn) res <- conn } - wg.Done() }(tgt) } diff --git a/lib/model/bytesemaphore.go b/lib/model/bytesemaphore.go deleted file mode 100644 index 92feecf9e..000000000 --- a/lib/model/bytesemaphore.go +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright (C) 2018 The Syncthing Authors. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at https://mozilla.org/MPL/2.0/. - -package model - -import ( - "context" - "sync" -) - -type byteSemaphore struct { - max int - available int - mut sync.Mutex - cond *sync.Cond -} - -func newByteSemaphore(max int) *byteSemaphore { - if max < 0 { - max = 0 - } - s := byteSemaphore{ - max: max, - available: max, - } - s.cond = sync.NewCond(&s.mut) - return &s -} - -func (s *byteSemaphore) takeWithContext(ctx context.Context, bytes int) error { - done := make(chan struct{}) - var err error - go func() { - err = s.takeInner(ctx, bytes) - close(done) - }() - select { - case <-done: - case <-ctx.Done(): - s.cond.Broadcast() - <-done - } - return err -} - -func (s *byteSemaphore) take(bytes int) { - _ = s.takeInner(context.Background(), bytes) -} - -func (s *byteSemaphore) takeInner(ctx context.Context, bytes int) error { - // Checking context for bytes <= s.available is required for testing and doesn't do any harm. - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - s.mut.Lock() - defer s.mut.Unlock() - if bytes > s.max { - bytes = s.max - } - for bytes > s.available { - s.cond.Wait() - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - if bytes > s.max { - bytes = s.max - } - } - s.available -= bytes - return nil -} - -func (s *byteSemaphore) give(bytes int) { - s.mut.Lock() - if bytes > s.max { - bytes = s.max - } - if s.available+bytes > s.max { - s.available = s.max - } else { - s.available += bytes - } - s.cond.Broadcast() - s.mut.Unlock() -} - -func (s *byteSemaphore) setCapacity(cap int) { - if cap < 0 { - cap = 0 - } - s.mut.Lock() - diff := cap - s.max - s.max = cap - s.available += diff - if s.available < 0 { - s.available = 0 - } else if s.available > s.max { - s.available = s.max - } - s.cond.Broadcast() - s.mut.Unlock() -} diff --git a/lib/model/folder.go b/lib/model/folder.go index 71e833026..159f92f67 100644 --- a/lib/model/folder.go +++ b/lib/model/folder.go @@ -38,7 +38,7 @@ type folder struct { stateTracker config.FolderConfiguration *stats.FolderStatisticsReference - ioLimiter *byteSemaphore + ioLimiter *util.Semaphore localFlags uint32 @@ -91,7 +91,7 @@ type puller interface { pull() (bool, error) // true when successful and should not be retried } -func newFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, evLogger events.Logger, ioLimiter *byteSemaphore, ver versioner.Versioner) folder { +func newFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, evLogger events.Logger, ioLimiter *util.Semaphore, ver versioner.Versioner) folder { f := folder{ stateTracker: newStateTracker(cfg.ID, evLogger), FolderConfiguration: cfg, @@ -375,10 +375,10 @@ func (f *folder) pull() (success bool, err error) { if f.Type != config.FolderTypeSendOnly { f.setState(FolderSyncWaiting) - if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil { + if err := f.ioLimiter.TakeWithContext(f.ctx, 1); err != nil { return true, err } - defer f.ioLimiter.give(1) + defer f.ioLimiter.Give(1) } startTime := time.Now() @@ -439,10 +439,10 @@ func (f *folder) scanSubdirs(subDirs []string) error { f.setState(FolderScanWaiting) defer f.setState(FolderIdle) - if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil { + if err := f.ioLimiter.TakeWithContext(f.ctx, 1); err != nil { return err } - defer f.ioLimiter.give(1) + defer f.ioLimiter.Give(1) for i := range subDirs { sub := osutil.NativeFilename(subDirs[i]) @@ -870,10 +870,10 @@ func (f *folder) versionCleanupTimerFired() { f.setState(FolderCleanWaiting) defer f.setState(FolderIdle) - if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil { + if err := f.ioLimiter.TakeWithContext(f.ctx, 1); err != nil { return } - defer f.ioLimiter.give(1) + defer f.ioLimiter.Give(1) f.setState(FolderCleaning) diff --git a/lib/model/folder_recvenc.go b/lib/model/folder_recvenc.go index 60031001b..2ddb6dde7 100644 --- a/lib/model/folder_recvenc.go +++ b/lib/model/folder_recvenc.go @@ -16,6 +16,7 @@ import ( "github.com/syncthing/syncthing/lib/fs" "github.com/syncthing/syncthing/lib/ignore" "github.com/syncthing/syncthing/lib/protocol" + "github.com/syncthing/syncthing/lib/util" "github.com/syncthing/syncthing/lib/versioner" ) @@ -27,7 +28,7 @@ type receiveEncryptedFolder struct { *sendReceiveFolder } -func newReceiveEncryptedFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service { +func newReceiveEncryptedFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service { return &receiveEncryptedFolder{newSendReceiveFolder(model, fset, ignores, cfg, ver, evLogger, ioLimiter).(*sendReceiveFolder)} } diff --git a/lib/model/folder_recvonly.go b/lib/model/folder_recvonly.go index 210308c5f..6351409bf 100644 --- a/lib/model/folder_recvonly.go +++ b/lib/model/folder_recvonly.go @@ -15,6 +15,7 @@ import ( "github.com/syncthing/syncthing/lib/events" "github.com/syncthing/syncthing/lib/ignore" "github.com/syncthing/syncthing/lib/protocol" + "github.com/syncthing/syncthing/lib/util" "github.com/syncthing/syncthing/lib/versioner" ) @@ -56,7 +57,7 @@ type receiveOnlyFolder struct { *sendReceiveFolder } -func newReceiveOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service { +func newReceiveOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service { sr := newSendReceiveFolder(model, fset, ignores, cfg, ver, evLogger, ioLimiter).(*sendReceiveFolder) sr.localFlags = protocol.FlagLocalReceiveOnly // gets propagated to the scanner, and set on locally changed files return &receiveOnlyFolder{sr} diff --git a/lib/model/folder_sendonly.go b/lib/model/folder_sendonly.go index 3b201885d..9b6104fff 100644 --- a/lib/model/folder_sendonly.go +++ b/lib/model/folder_sendonly.go @@ -12,6 +12,7 @@ import ( "github.com/syncthing/syncthing/lib/events" "github.com/syncthing/syncthing/lib/ignore" "github.com/syncthing/syncthing/lib/protocol" + "github.com/syncthing/syncthing/lib/util" "github.com/syncthing/syncthing/lib/versioner" ) @@ -23,7 +24,7 @@ type sendOnlyFolder struct { folder } -func newSendOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, _ versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service { +func newSendOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, _ versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service { f := &sendOnlyFolder{ folder: newFolder(model, fset, ignores, cfg, evLogger, ioLimiter, nil), } diff --git a/lib/model/folder_sendrecv.go b/lib/model/folder_sendrecv.go index 98de040a4..7d28790db 100644 --- a/lib/model/folder_sendrecv.go +++ b/lib/model/folder_sendrecv.go @@ -28,6 +28,7 @@ import ( "github.com/syncthing/syncthing/lib/scanner" "github.com/syncthing/syncthing/lib/sha256" "github.com/syncthing/syncthing/lib/sync" + "github.com/syncthing/syncthing/lib/util" "github.com/syncthing/syncthing/lib/versioner" "github.com/syncthing/syncthing/lib/weakhash" ) @@ -123,17 +124,17 @@ type sendReceiveFolder struct { queue *jobQueue blockPullReorderer blockPullReorderer - writeLimiter *byteSemaphore + writeLimiter *util.Semaphore tempPullErrors map[string]string // pull errors that might be just transient } -func newSendReceiveFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service { +func newSendReceiveFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service { f := &sendReceiveFolder{ folder: newFolder(model, fset, ignores, cfg, evLogger, ioLimiter, ver), queue: newJobQueue(), blockPullReorderer: newBlockPullReorderer(cfg.BlockPullOrder, model.id, cfg.DeviceIDs()), - writeLimiter: newByteSemaphore(cfg.MaxConcurrentWrites), + writeLimiter: util.NewSemaphore(cfg.MaxConcurrentWrites), } f.folder.puller = f @@ -1435,7 +1436,7 @@ func (f *sendReceiveFolder) verifyBuffer(buf []byte, block protocol.BlockInfo) e } func (f *sendReceiveFolder) pullerRoutine(snap *db.Snapshot, in <-chan pullBlockState, out chan<- *sharedPullerState) { - requestLimiter := newByteSemaphore(f.PullerMaxPendingKiB * 1024) + requestLimiter := util.NewSemaphore(f.PullerMaxPendingKiB * 1024) wg := sync.NewWaitGroup() for state := range in { @@ -1453,7 +1454,7 @@ func (f *sendReceiveFolder) pullerRoutine(snap *db.Snapshot, in <-chan pullBlock state := state bytes := int(state.block.Size) - if err := requestLimiter.takeWithContext(f.ctx, bytes); err != nil { + if err := requestLimiter.TakeWithContext(f.ctx, bytes); err != nil { state.fail(err) out <- state.sharedPullerState continue @@ -1463,7 +1464,7 @@ func (f *sendReceiveFolder) pullerRoutine(snap *db.Snapshot, in <-chan pullBlock go func() { defer wg.Done() - defer requestLimiter.give(bytes) + defer requestLimiter.Give(bytes) f.pullBlock(state, snap, out) }() @@ -2085,10 +2086,10 @@ func (f *sendReceiveFolder) limitedWriteAt(fd io.WriterAt, data []byte, offset i } func (f *sendReceiveFolder) withLimiter(fn func() error) error { - if err := f.writeLimiter.takeWithContext(f.ctx, 1); err != nil { + if err := f.writeLimiter.TakeWithContext(f.ctx, 1); err != nil { return err } - defer f.writeLimiter.give(1) + defer f.writeLimiter.Give(1) return fn() } diff --git a/lib/model/model.go b/lib/model/model.go index 07bd6f22a..376a69ac0 100644 --- a/lib/model/model.go +++ b/lib/model/model.go @@ -40,6 +40,7 @@ import ( "github.com/syncthing/syncthing/lib/svcutil" "github.com/syncthing/syncthing/lib/sync" "github.com/syncthing/syncthing/lib/ur/contract" + "github.com/syncthing/syncthing/lib/util" "github.com/syncthing/syncthing/lib/versioner" ) @@ -132,10 +133,10 @@ type model struct { shortID protocol.ShortID // globalRequestLimiter limits the amount of data in concurrent incoming // requests - globalRequestLimiter *byteSemaphore + globalRequestLimiter *util.Semaphore // folderIOLimiter limits the number of concurrent I/O heavy operations, // such as scans and pulls. - folderIOLimiter *byteSemaphore + folderIOLimiter *util.Semaphore fatalChan chan error started chan struct{} @@ -155,7 +156,7 @@ type model struct { // fields protected by pmut pmut sync.RWMutex conn map[protocol.DeviceID]protocol.Connection - connRequestLimiters map[protocol.DeviceID]*byteSemaphore + connRequestLimiters map[protocol.DeviceID]*util.Semaphore closed map[protocol.DeviceID]chan struct{} helloMessages map[protocol.DeviceID]protocol.Hello deviceDownloads map[protocol.DeviceID]*deviceDownloadState @@ -166,7 +167,7 @@ type model struct { foldersRunning int32 } -type folderFactory func(*model, *db.FileSet, *ignore.Matcher, config.FolderConfiguration, versioner.Versioner, events.Logger, *byteSemaphore) service +type folderFactory func(*model, *db.FileSet, *ignore.Matcher, config.FolderConfiguration, versioner.Versioner, events.Logger, *util.Semaphore) service var ( folderFactories = make(map[config.FolderType]folderFactory) @@ -220,8 +221,8 @@ func NewModel(cfg config.Wrapper, id protocol.DeviceID, clientName, clientVersio finder: db.NewBlockFinder(ldb), progressEmitter: NewProgressEmitter(cfg, evLogger), shortID: id.Short(), - globalRequestLimiter: newByteSemaphore(1024 * cfg.Options().MaxConcurrentIncomingRequestKiB()), - folderIOLimiter: newByteSemaphore(cfg.Options().MaxFolderConcurrency()), + globalRequestLimiter: util.NewSemaphore(1024 * cfg.Options().MaxConcurrentIncomingRequestKiB()), + folderIOLimiter: util.NewSemaphore(cfg.Options().MaxFolderConcurrency()), fatalChan: make(chan error), started: make(chan struct{}), @@ -240,7 +241,7 @@ func NewModel(cfg config.Wrapper, id protocol.DeviceID, clientName, clientVersio // fields protected by pmut pmut: sync.NewRWMutex(), conn: make(map[protocol.DeviceID]protocol.Connection), - connRequestLimiters: make(map[protocol.DeviceID]*byteSemaphore), + connRequestLimiters: make(map[protocol.DeviceID]*util.Semaphore), closed: make(map[protocol.DeviceID]chan struct{}), helloMessages: make(map[protocol.DeviceID]protocol.Hello), deviceDownloads: make(map[protocol.DeviceID]*deviceDownloadState), @@ -1906,23 +1907,15 @@ func (m *model) Request(deviceID protocol.DeviceID, folder, name string, blockNo // skipping nil limiters, then returns a requestResponse of the given size. // When the requestResponse is closed the limiters are given back the bytes, // in reverse order. -func newLimitedRequestResponse(size int, limiters ...*byteSemaphore) *requestResponse { - for _, limiter := range limiters { - if limiter != nil { - limiter.take(size) - } - } +func newLimitedRequestResponse(size int, limiters ...*util.Semaphore) *requestResponse { + multi := util.MultiSemaphore(limiters) + multi.Take(size) res := newRequestResponse(size) go func() { res.Wait() - for i := range limiters { - limiter := limiters[len(limiters)-1-i] - if limiter != nil { - limiter.give(size) - } - } + multi.Give(size) }() return res @@ -2230,9 +2223,9 @@ func (m *model) AddConnection(conn protocol.Connection, hello protocol.Hello) { // 0: default, <0: no limiting switch { case device.MaxRequestKiB > 0: - m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * device.MaxRequestKiB) + m.connRequestLimiters[deviceID] = util.NewSemaphore(1024 * device.MaxRequestKiB) case device.MaxRequestKiB == 0: - m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * defaultPullerPendingKiB) + m.connRequestLimiters[deviceID] = util.NewSemaphore(1024 * defaultPullerPendingKiB) } m.helloMessages[deviceID] = hello @@ -2927,8 +2920,8 @@ func (m *model) CommitConfiguration(from, to config.Configuration) bool { ignoredDevices := observedDeviceSet(to.IgnoredDevices) m.cleanPending(toDevices, toFolders, ignoredDevices, removedFolders) - m.globalRequestLimiter.setCapacity(1024 * to.Options.MaxConcurrentIncomingRequestKiB()) - m.folderIOLimiter.setCapacity(to.Options.MaxFolderConcurrency()) + m.globalRequestLimiter.SetCapacity(1024 * to.Options.MaxConcurrentIncomingRequestKiB()) + m.folderIOLimiter.SetCapacity(to.Options.MaxFolderConcurrency()) // Some options don't require restart as those components handle it fine // by themselves. Compare the options structs containing only the diff --git a/lib/model/model_test.go b/lib/model/model_test.go index db288396d..b665abc01 100644 --- a/lib/model/model_test.go +++ b/lib/model/model_test.go @@ -38,6 +38,7 @@ import ( protocolmocks "github.com/syncthing/syncthing/lib/protocol/mocks" srand "github.com/syncthing/syncthing/lib/rand" "github.com/syncthing/syncthing/lib/testutils" + "github.com/syncthing/syncthing/lib/util" "github.com/syncthing/syncthing/lib/versioner" ) @@ -3319,14 +3320,14 @@ func TestDeviceWasSeen(t *testing.T) { } func TestNewLimitedRequestResponse(t *testing.T) { - l0 := newByteSemaphore(0) - l1 := newByteSemaphore(1024) - l2 := (*byteSemaphore)(nil) + l0 := util.NewSemaphore(0) + l1 := util.NewSemaphore(1024) + l2 := (*util.Semaphore)(nil) // Should take 500 bytes from any non-unlimited non-nil limiters. res := newLimitedRequestResponse(500, l0, l1, l2) - if l1.available != 1024-500 { + if l1.Available() != 1024-500 { t.Error("should have taken bytes from limited limiter") } @@ -3336,7 +3337,7 @@ func TestNewLimitedRequestResponse(t *testing.T) { // Try to take 1024 bytes to make sure the bytes were returned. done := make(chan struct{}) go func() { - l1.take(1024) + l1.Take(1024) close(done) }() select { diff --git a/lib/util/semaphore.go b/lib/util/semaphore.go new file mode 100644 index 000000000..276234ed7 --- /dev/null +++ b/lib/util/semaphore.go @@ -0,0 +1,148 @@ +// Copyright (C) 2018 The Syncthing Authors. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at https://mozilla.org/MPL/2.0/. + +package util + +import ( + "context" + "sync" +) + +type Semaphore struct { + max int + available int + mut sync.Mutex + cond *sync.Cond +} + +func NewSemaphore(max int) *Semaphore { + if max < 0 { + max = 0 + } + s := Semaphore{ + max: max, + available: max, + } + s.cond = sync.NewCond(&s.mut) + return &s +} + +func (s *Semaphore) TakeWithContext(ctx context.Context, size int) error { + done := make(chan struct{}) + var err error + go func() { + err = s.takeInner(ctx, size) + close(done) + }() + select { + case <-done: + case <-ctx.Done(): + s.cond.Broadcast() + <-done + } + return err +} + +func (s *Semaphore) Take(size int) { + _ = s.takeInner(context.Background(), size) +} + +func (s *Semaphore) takeInner(ctx context.Context, size int) error { + // Checking context for size <= s.available is required for testing and doesn't do any harm. + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + s.mut.Lock() + defer s.mut.Unlock() + if size > s.max { + size = s.max + } + for size > s.available { + s.cond.Wait() + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + if size > s.max { + size = s.max + } + } + s.available -= size + return nil +} + +func (s *Semaphore) Give(size int) { + s.mut.Lock() + if size > s.max { + size = s.max + } + if s.available+size > s.max { + s.available = s.max + } else { + s.available += size + } + s.cond.Broadcast() + s.mut.Unlock() +} + +func (s *Semaphore) SetCapacity(capacity int) { + if capacity < 0 { + capacity = 0 + } + s.mut.Lock() + diff := capacity - s.max + s.max = capacity + s.available += diff + if s.available < 0 { + s.available = 0 + } else if s.available > s.max { + s.available = s.max + } + s.cond.Broadcast() + s.mut.Unlock() +} + +func (s *Semaphore) Available() int { + s.mut.Lock() + defer s.mut.Unlock() + return s.available +} + +// MultiSemaphore combines semaphores, making sure to always take and give in +// the same order (reversed for give). A semaphore may be nil, in which case it +// is skipped. +type MultiSemaphore []*Semaphore + +func (s MultiSemaphore) TakeWithContext(ctx context.Context, size int) error { + for _, limiter := range s { + if limiter != nil { + if err := limiter.TakeWithContext(ctx, size); err != nil { + return err + } + } + } + return nil +} + +func (s MultiSemaphore) Take(size int) { + for _, limiter := range s { + if limiter != nil { + limiter.Take(size) + } + } +} + +func (s MultiSemaphore) Give(size int) { + for i := range s { + limiter := s[len(s)-1-i] + if limiter != nil { + limiter.Give(size) + } + } +} diff --git a/lib/model/bytesemaphore_test.go b/lib/util/semaphore_test.go similarity index 82% rename from lib/model/bytesemaphore_test.go rename to lib/util/semaphore_test.go index 1efaa0069..99b583887 100644 --- a/lib/model/bytesemaphore_test.go +++ b/lib/util/semaphore_test.go @@ -4,38 +4,38 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at https://mozilla.org/MPL/2.0/. -package model +package util import "testing" func TestZeroByteSempahore(t *testing.T) { // A semaphore with zero capacity is just a no-op. - s := newByteSemaphore(0) + s := NewSemaphore(0) // None of these should block or panic - s.take(123) - s.take(456) - s.give(1 << 30) + s.Take(123) + s.Take(456) + s.Give(1 << 30) } func TestByteSempahoreCapChangeUp(t *testing.T) { // Waiting takes should unblock when the capacity increases - s := newByteSemaphore(100) + s := NewSemaphore(100) - s.take(75) + s.Take(75) if s.available != 25 { t.Error("bad state after take") } gotit := make(chan struct{}) go func() { - s.take(75) + s.Take(75) close(gotit) }() - s.setCapacity(155) + s.SetCapacity(155) <-gotit if s.available != 5 { t.Error("bad state after both takes") @@ -45,19 +45,19 @@ func TestByteSempahoreCapChangeUp(t *testing.T) { func TestByteSempahoreCapChangeDown1(t *testing.T) { // Things should make sense when capacity is adjusted down - s := newByteSemaphore(100) + s := NewSemaphore(100) - s.take(75) + s.Take(75) if s.available != 25 { t.Error("bad state after take") } - s.setCapacity(90) + s.SetCapacity(90) if s.available != 15 { t.Error("bad state after adjust") } - s.give(75) + s.Give(75) if s.available != 90 { t.Error("bad state after give") } @@ -66,19 +66,19 @@ func TestByteSempahoreCapChangeDown1(t *testing.T) { func TestByteSempahoreCapChangeDown2(t *testing.T) { // Things should make sense when capacity is adjusted down, different case - s := newByteSemaphore(100) + s := NewSemaphore(100) - s.take(75) + s.Take(75) if s.available != 25 { t.Error("bad state after take") } - s.setCapacity(10) + s.SetCapacity(10) if s.available != 0 { t.Error("bad state after adjust") } - s.give(75) + s.Give(75) if s.available != 10 { t.Error("bad state after give") } @@ -87,26 +87,26 @@ func TestByteSempahoreCapChangeDown2(t *testing.T) { func TestByteSempahoreGiveMore(t *testing.T) { // We shouldn't end up with more available than we have capacity... - s := newByteSemaphore(100) + s := NewSemaphore(100) - s.take(150) + s.Take(150) if s.available != 0 { t.Errorf("bad state after large take") } - s.give(150) + s.Give(150) if s.available != 100 { t.Errorf("bad state after large take + give") } - s.take(150) - s.setCapacity(125) + s.Take(150) + s.SetCapacity(125) // available was zero before, we're increasing capacity by 25 if s.available != 25 { t.Errorf("bad state after setcap") } - s.give(150) + s.Give(150) if s.available != 125 { t.Errorf("bad state after large take + give with adjustment") }