lib/connections: Dial devices in parallel (#7783)

Simon Frei 2021-06-25 11:38:04 +02:00 committed by GitHub
parent 993a3ebe73
commit c78fa42f31
11 changed files with 258 additions and 188 deletions


@@ -66,6 +66,8 @@ const (
worstDialerPriority = math.MaxInt32
recentlySeenCutoff = 7 * 24 * time.Hour
shortLivedConnectionThreshold = 5 * time.Second
dialMaxParallel = 64
dialMaxParallelPerDevice = 8
)
// From go/src/crypto/tls/cipher_suites.go
@@ -490,14 +492,40 @@ func (s *service) dialDevices(ctx context.Context, now time.Time, cfg config.Con
// Perform dials according to the queue, stopping when we've reached the
// allowed additional number of connections (if limited).
numConns := 0
for _, entry := range queue {
if conn, ok := s.dialParallel(ctx, entry.id, entry.targets); ok {
s.conns <- conn
numConns++
if allowAdditional > 0 && numConns >= allowAdditional {
break
}
var numConnsMut stdsync.Mutex
dialSemaphore := util.NewSemaphore(dialMaxParallel)
dialWG := new(stdsync.WaitGroup)
dialCtx, dialCancel := context.WithCancel(ctx)
defer func() {
dialWG.Wait()
dialCancel()
}()
for i := range queue {
select {
case <-dialCtx.Done():
return
default:
}
dialWG.Add(1)
go func(entry dialQueueEntry) {
defer dialWG.Done()
conn, ok := s.dialParallel(dialCtx, entry.id, entry.targets, dialSemaphore)
if !ok {
return
}
numConnsMut.Lock()
if allowAdditional == 0 || numConns < allowAdditional {
select {
case s.conns <- conn:
numConns++
if allowAdditional > 0 && numConns >= allowAdditional {
dialCancel()
}
case <-dialCtx.Done():
}
}
numConnsMut.Unlock()
}(queue[i])
}
}
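The rewritten loop above fans each queue entry out into its own goroutine, caps overall dialing concurrency with a semaphore (dialMaxParallel), and cancels the shared context once the allowed number of additional connections has been reached. The following standalone sketch shows the same bounded fan-out pattern using a plain buffered channel as the semaphore; all names in it (dialOne, target, maxParallel, wanted) are illustrative and not part of the commit.

package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

type target struct{ addr string }

// dialOne stands in for a real dial attempt; it either "connects" after a
// short delay or returns early when the context is cancelled.
func dialOne(ctx context.Context, t target) (string, error) {
	select {
	case <-time.After(10 * time.Millisecond):
		return "conn to " + t.addr, nil
	case <-ctx.Done():
		return "", ctx.Err()
	}
}

func main() {
	const maxParallel = 4 // global cap, in the spirit of dialMaxParallel
	const wanted = 2      // stop once this many connections have succeeded

	queue := []target{{"a"}, {"b"}, {"c"}, {"d"}, {"e"}, {"f"}}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sema := make(chan struct{}, maxParallel) // counting semaphore
	var (
		wg       sync.WaitGroup
		mut      sync.Mutex
		numConns int
	)

	for _, t := range queue {
		if ctx.Err() != nil {
			break // enough connections; skip the rest of the queue
		}
		sema <- struct{}{} // take a slot; blocks at maxParallel in-flight dials
		wg.Add(1)
		go func(t target) {
			defer wg.Done()
			defer func() { <-sema }() // return the slot
			conn, err := dialOne(ctx, t)
			if err != nil {
				return
			}
			mut.Lock()
			defer mut.Unlock()
			if numConns < wanted {
				numConns++
				fmt.Println("established", conn)
				if numConns >= wanted {
					cancel() // stop launching new dials and abort in-flight ones
				}
			}
			// A success arriving after the quota is reached is simply dropped
			// here; the real code closes the surplus connection instead.
		}(t)
	}
	wg.Wait()
}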
@@ -959,7 +987,7 @@ func IsAllowedNetwork(host string, allowed []string) bool {
return false
}
func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget) (internalConn, bool) {
func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID, dialTargets []dialTarget, parentSema *util.Semaphore) (internalConn, bool) {
// Group targets into buckets by priority
dialTargetBuckets := make(map[int][]dialTarget, len(dialTargets))
for _, tgt := range dialTargets {
@@ -975,13 +1003,19 @@ func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID,
// Sort the priorities so that we dial lowest first (which means highest...)
sort.Ints(priorities)
sema := util.MultiSemaphore{util.NewSemaphore(dialMaxParallelPerDevice), parentSema}
for _, prio := range priorities {
tgts := dialTargetBuckets[prio]
res := make(chan internalConn, len(tgts))
wg := stdsync.WaitGroup{}
for _, tgt := range tgts {
sema.Take(1)
wg.Add(1)
go func(tgt dialTarget) {
defer func() {
wg.Done()
sema.Give(1)
}()
conn, err := tgt.Dial(ctx)
if err == nil {
// Closes the connection on error
@@ -994,7 +1028,6 @@ func (s *service) dialParallel(ctx context.Context, deviceID protocol.DeviceID,
l.Debugln("dialing", deviceID, tgt.uri, "success:", conn)
res <- conn
}
wg.Done()
}(tgt)
}


@@ -1,109 +0,0 @@
// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
package model
import (
"context"
"sync"
)
type byteSemaphore struct {
max int
available int
mut sync.Mutex
cond *sync.Cond
}
func newByteSemaphore(max int) *byteSemaphore {
if max < 0 {
max = 0
}
s := byteSemaphore{
max: max,
available: max,
}
s.cond = sync.NewCond(&s.mut)
return &s
}
func (s *byteSemaphore) takeWithContext(ctx context.Context, bytes int) error {
done := make(chan struct{})
var err error
go func() {
err = s.takeInner(ctx, bytes)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
s.cond.Broadcast()
<-done
}
return err
}
func (s *byteSemaphore) take(bytes int) {
_ = s.takeInner(context.Background(), bytes)
}
func (s *byteSemaphore) takeInner(ctx context.Context, bytes int) error {
// Checking context for bytes <= s.available is required for testing and doesn't do any harm.
select {
case <-ctx.Done():
return ctx.Err()
default:
}
s.mut.Lock()
defer s.mut.Unlock()
if bytes > s.max {
bytes = s.max
}
for bytes > s.available {
s.cond.Wait()
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if bytes > s.max {
bytes = s.max
}
}
s.available -= bytes
return nil
}
func (s *byteSemaphore) give(bytes int) {
s.mut.Lock()
if bytes > s.max {
bytes = s.max
}
if s.available+bytes > s.max {
s.available = s.max
} else {
s.available += bytes
}
s.cond.Broadcast()
s.mut.Unlock()
}
func (s *byteSemaphore) setCapacity(cap int) {
if cap < 0 {
cap = 0
}
s.mut.Lock()
diff := cap - s.max
s.max = cap
s.available += diff
if s.available < 0 {
s.available = 0
} else if s.available > s.max {
s.available = s.max
}
s.cond.Broadcast()
s.mut.Unlock()
}


@@ -38,7 +38,7 @@ type folder struct {
stateTracker
config.FolderConfiguration
*stats.FolderStatisticsReference
ioLimiter *byteSemaphore
ioLimiter *util.Semaphore
localFlags uint32
@@ -91,7 +91,7 @@ type puller interface {
pull() (bool, error) // true when successful and should not be retried
}
func newFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, evLogger events.Logger, ioLimiter *byteSemaphore, ver versioner.Versioner) folder {
func newFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, evLogger events.Logger, ioLimiter *util.Semaphore, ver versioner.Versioner) folder {
f := folder{
stateTracker: newStateTracker(cfg.ID, evLogger),
FolderConfiguration: cfg,
@@ -375,10 +375,10 @@ func (f *folder) pull() (success bool, err error) {
if f.Type != config.FolderTypeSendOnly {
f.setState(FolderSyncWaiting)
if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil {
if err := f.ioLimiter.TakeWithContext(f.ctx, 1); err != nil {
return true, err
}
defer f.ioLimiter.give(1)
defer f.ioLimiter.Give(1)
}
startTime := time.Now()
@@ -439,10 +439,10 @@ func (f *folder) scanSubdirs(subDirs []string) error {
f.setState(FolderScanWaiting)
defer f.setState(FolderIdle)
if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil {
if err := f.ioLimiter.TakeWithContext(f.ctx, 1); err != nil {
return err
}
defer f.ioLimiter.give(1)
defer f.ioLimiter.Give(1)
for i := range subDirs {
sub := osutil.NativeFilename(subDirs[i])
@@ -870,10 +870,10 @@ func (f *folder) versionCleanupTimerFired() {
f.setState(FolderCleanWaiting)
defer f.setState(FolderIdle)
if err := f.ioLimiter.takeWithContext(f.ctx, 1); err != nil {
if err := f.ioLimiter.TakeWithContext(f.ctx, 1); err != nil {
return
}
defer f.ioLimiter.give(1)
defer f.ioLimiter.Give(1)
f.setState(FolderCleaning)


@@ -16,6 +16,7 @@ import (
"github.com/syncthing/syncthing/lib/fs"
"github.com/syncthing/syncthing/lib/ignore"
"github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/util"
"github.com/syncthing/syncthing/lib/versioner"
)
@@ -27,7 +28,7 @@ type receiveEncryptedFolder struct {
*sendReceiveFolder
}
func newReceiveEncryptedFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service {
func newReceiveEncryptedFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service {
return &receiveEncryptedFolder{newSendReceiveFolder(model, fset, ignores, cfg, ver, evLogger, ioLimiter).(*sendReceiveFolder)}
}


@@ -15,6 +15,7 @@ import (
"github.com/syncthing/syncthing/lib/events"
"github.com/syncthing/syncthing/lib/ignore"
"github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/util"
"github.com/syncthing/syncthing/lib/versioner"
)
@@ -56,7 +57,7 @@ type receiveOnlyFolder struct {
*sendReceiveFolder
}
func newReceiveOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service {
func newReceiveOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service {
sr := newSendReceiveFolder(model, fset, ignores, cfg, ver, evLogger, ioLimiter).(*sendReceiveFolder)
sr.localFlags = protocol.FlagLocalReceiveOnly // gets propagated to the scanner, and set on locally changed files
return &receiveOnlyFolder{sr}


@@ -12,6 +12,7 @@ import (
"github.com/syncthing/syncthing/lib/events"
"github.com/syncthing/syncthing/lib/ignore"
"github.com/syncthing/syncthing/lib/protocol"
"github.com/syncthing/syncthing/lib/util"
"github.com/syncthing/syncthing/lib/versioner"
)
@@ -23,7 +24,7 @@ type sendOnlyFolder struct {
folder
}
func newSendOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, _ versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service {
func newSendOnlyFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, _ versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service {
f := &sendOnlyFolder{
folder: newFolder(model, fset, ignores, cfg, evLogger, ioLimiter, nil),
}


@@ -28,6 +28,7 @@ import (
"github.com/syncthing/syncthing/lib/scanner"
"github.com/syncthing/syncthing/lib/sha256"
"github.com/syncthing/syncthing/lib/sync"
"github.com/syncthing/syncthing/lib/util"
"github.com/syncthing/syncthing/lib/versioner"
"github.com/syncthing/syncthing/lib/weakhash"
)
@@ -123,17 +124,17 @@ type sendReceiveFolder struct {
queue *jobQueue
blockPullReorderer blockPullReorderer
writeLimiter *byteSemaphore
writeLimiter *util.Semaphore
tempPullErrors map[string]string // pull errors that might be just transient
}
func newSendReceiveFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *byteSemaphore) service {
func newSendReceiveFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg config.FolderConfiguration, ver versioner.Versioner, evLogger events.Logger, ioLimiter *util.Semaphore) service {
f := &sendReceiveFolder{
folder: newFolder(model, fset, ignores, cfg, evLogger, ioLimiter, ver),
queue: newJobQueue(),
blockPullReorderer: newBlockPullReorderer(cfg.BlockPullOrder, model.id, cfg.DeviceIDs()),
writeLimiter: newByteSemaphore(cfg.MaxConcurrentWrites),
writeLimiter: util.NewSemaphore(cfg.MaxConcurrentWrites),
}
f.folder.puller = f
@@ -1435,7 +1436,7 @@ func (f *sendReceiveFolder) verifyBuffer(buf []byte, block protocol.BlockInfo) e
}
func (f *sendReceiveFolder) pullerRoutine(snap *db.Snapshot, in <-chan pullBlockState, out chan<- *sharedPullerState) {
requestLimiter := newByteSemaphore(f.PullerMaxPendingKiB * 1024)
requestLimiter := util.NewSemaphore(f.PullerMaxPendingKiB * 1024)
wg := sync.NewWaitGroup()
for state := range in {
@@ -1453,7 +1454,7 @@ func (f *sendReceiveFolder) pullerRoutine(snap *db.Snapshot, in <-chan pullBlock
state := state
bytes := int(state.block.Size)
if err := requestLimiter.takeWithContext(f.ctx, bytes); err != nil {
if err := requestLimiter.TakeWithContext(f.ctx, bytes); err != nil {
state.fail(err)
out <- state.sharedPullerState
continue
@@ -1463,7 +1464,7 @@ func (f *sendReceiveFolder) pullerRoutine(snap *db.Snapshot, in <-chan pullBlock
go func() {
defer wg.Done()
defer requestLimiter.give(bytes)
defer requestLimiter.Give(bytes)
f.pullBlock(state, snap, out)
}()
@@ -2085,10 +2086,10 @@ func (f *sendReceiveFolder) limitedWriteAt(fd io.WriterAt, data []byte, offset i
}
func (f *sendReceiveFolder) withLimiter(fn func() error) error {
if err := f.writeLimiter.takeWithContext(f.ctx, 1); err != nil {
if err := f.writeLimiter.TakeWithContext(f.ctx, 1); err != nil {
return err
}
defer f.writeLimiter.give(1)
defer f.writeLimiter.Give(1)
return fn()
}


@@ -40,6 +40,7 @@ import (
"github.com/syncthing/syncthing/lib/svcutil"
"github.com/syncthing/syncthing/lib/sync"
"github.com/syncthing/syncthing/lib/ur/contract"
"github.com/syncthing/syncthing/lib/util"
"github.com/syncthing/syncthing/lib/versioner"
)
@@ -132,10 +133,10 @@ type model struct {
shortID protocol.ShortID
// globalRequestLimiter limits the amount of data in concurrent incoming
// requests
globalRequestLimiter *byteSemaphore
globalRequestLimiter *util.Semaphore
// folderIOLimiter limits the number of concurrent I/O heavy operations,
// such as scans and pulls.
folderIOLimiter *byteSemaphore
folderIOLimiter *util.Semaphore
fatalChan chan error
started chan struct{}
@@ -155,7 +156,7 @@ type model struct {
// fields protected by pmut
pmut sync.RWMutex
conn map[protocol.DeviceID]protocol.Connection
connRequestLimiters map[protocol.DeviceID]*byteSemaphore
connRequestLimiters map[protocol.DeviceID]*util.Semaphore
closed map[protocol.DeviceID]chan struct{}
helloMessages map[protocol.DeviceID]protocol.Hello
deviceDownloads map[protocol.DeviceID]*deviceDownloadState
@@ -166,7 +167,7 @@ type model struct {
foldersRunning int32
}
type folderFactory func(*model, *db.FileSet, *ignore.Matcher, config.FolderConfiguration, versioner.Versioner, events.Logger, *byteSemaphore) service
type folderFactory func(*model, *db.FileSet, *ignore.Matcher, config.FolderConfiguration, versioner.Versioner, events.Logger, *util.Semaphore) service
var (
folderFactories = make(map[config.FolderType]folderFactory)
@@ -220,8 +221,8 @@ func NewModel(cfg config.Wrapper, id protocol.DeviceID, clientName, clientVersio
finder: db.NewBlockFinder(ldb),
progressEmitter: NewProgressEmitter(cfg, evLogger),
shortID: id.Short(),
globalRequestLimiter: newByteSemaphore(1024 * cfg.Options().MaxConcurrentIncomingRequestKiB()),
folderIOLimiter: newByteSemaphore(cfg.Options().MaxFolderConcurrency()),
globalRequestLimiter: util.NewSemaphore(1024 * cfg.Options().MaxConcurrentIncomingRequestKiB()),
folderIOLimiter: util.NewSemaphore(cfg.Options().MaxFolderConcurrency()),
fatalChan: make(chan error),
started: make(chan struct{}),
@@ -240,7 +241,7 @@ func NewModel(cfg config.Wrapper, id protocol.DeviceID, clientName, clientVersio
// fields protected by pmut
pmut: sync.NewRWMutex(),
conn: make(map[protocol.DeviceID]protocol.Connection),
connRequestLimiters: make(map[protocol.DeviceID]*byteSemaphore),
connRequestLimiters: make(map[protocol.DeviceID]*util.Semaphore),
closed: make(map[protocol.DeviceID]chan struct{}),
helloMessages: make(map[protocol.DeviceID]protocol.Hello),
deviceDownloads: make(map[protocol.DeviceID]*deviceDownloadState),
@@ -1906,23 +1907,15 @@ func (m *model) Request(deviceID protocol.DeviceID, folder, name string, blockNo
// skipping nil limiters, then returns a requestResponse of the given size.
// When the requestResponse is closed the limiters are given back the bytes,
// in reverse order.
func newLimitedRequestResponse(size int, limiters ...*byteSemaphore) *requestResponse {
for _, limiter := range limiters {
if limiter != nil {
limiter.take(size)
}
}
func newLimitedRequestResponse(size int, limiters ...*util.Semaphore) *requestResponse {
multi := util.MultiSemaphore(limiters)
multi.Take(size)
res := newRequestResponse(size)
go func() {
res.Wait()
for i := range limiters {
limiter := limiters[len(limiters)-1-i]
if limiter != nil {
limiter.give(size)
}
}
multi.Give(size)
}()
return res
@@ -2230,9 +2223,9 @@ func (m *model) AddConnection(conn protocol.Connection, hello protocol.Hello) {
// 0: default, <0: no limiting
switch {
case device.MaxRequestKiB > 0:
m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * device.MaxRequestKiB)
m.connRequestLimiters[deviceID] = util.NewSemaphore(1024 * device.MaxRequestKiB)
case device.MaxRequestKiB == 0:
m.connRequestLimiters[deviceID] = newByteSemaphore(1024 * defaultPullerPendingKiB)
m.connRequestLimiters[deviceID] = util.NewSemaphore(1024 * defaultPullerPendingKiB)
}
m.helloMessages[deviceID] = hello
@@ -2927,8 +2920,8 @@ func (m *model) CommitConfiguration(from, to config.Configuration) bool {
ignoredDevices := observedDeviceSet(to.IgnoredDevices)
m.cleanPending(toDevices, toFolders, ignoredDevices, removedFolders)
m.globalRequestLimiter.setCapacity(1024 * to.Options.MaxConcurrentIncomingRequestKiB())
m.folderIOLimiter.setCapacity(to.Options.MaxFolderConcurrency())
m.globalRequestLimiter.SetCapacity(1024 * to.Options.MaxConcurrentIncomingRequestKiB())
m.folderIOLimiter.SetCapacity(to.Options.MaxFolderConcurrency())
// Some options don't require restart as those components handle it fine
// by themselves. Compare the options structs containing only the


@@ -38,6 +38,7 @@ import (
protocolmocks "github.com/syncthing/syncthing/lib/protocol/mocks"
srand "github.com/syncthing/syncthing/lib/rand"
"github.com/syncthing/syncthing/lib/testutils"
"github.com/syncthing/syncthing/lib/util"
"github.com/syncthing/syncthing/lib/versioner"
)
@@ -3319,14 +3320,14 @@ func TestDeviceWasSeen(t *testing.T) {
}
func TestNewLimitedRequestResponse(t *testing.T) {
l0 := newByteSemaphore(0)
l1 := newByteSemaphore(1024)
l2 := (*byteSemaphore)(nil)
l0 := util.NewSemaphore(0)
l1 := util.NewSemaphore(1024)
l2 := (*util.Semaphore)(nil)
// Should take 500 bytes from any non-unlimited non-nil limiters.
res := newLimitedRequestResponse(500, l0, l1, l2)
if l1.available != 1024-500 {
if l1.Available() != 1024-500 {
t.Error("should have taken bytes from limited limiter")
}
@@ -3336,7 +3337,7 @@ func TestNewLimitedRequestResponse(t *testing.T) {
// Try to take 1024 bytes to make sure the bytes were returned.
done := make(chan struct{})
go func() {
l1.take(1024)
l1.Take(1024)
close(done)
}()
select {

lib/util/semaphore.go (new file, 148 lines)

@@ -0,0 +1,148 @@
// Copyright (C) 2018 The Syncthing Authors.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
package util
import (
"context"
"sync"
)
type Semaphore struct {
max int
available int
mut sync.Mutex
cond *sync.Cond
}
func NewSemaphore(max int) *Semaphore {
if max < 0 {
max = 0
}
s := Semaphore{
max: max,
available: max,
}
s.cond = sync.NewCond(&s.mut)
return &s
}
func (s *Semaphore) TakeWithContext(ctx context.Context, size int) error {
done := make(chan struct{})
var err error
go func() {
err = s.takeInner(ctx, size)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
s.cond.Broadcast()
<-done
}
return err
}
func (s *Semaphore) Take(size int) {
_ = s.takeInner(context.Background(), size)
}
func (s *Semaphore) takeInner(ctx context.Context, size int) error {
// Checking context for size <= s.available is required for testing and doesn't do any harm.
select {
case <-ctx.Done():
return ctx.Err()
default:
}
s.mut.Lock()
defer s.mut.Unlock()
if size > s.max {
size = s.max
}
for size > s.available {
s.cond.Wait()
select {
case <-ctx.Done():
return ctx.Err()
default:
}
if size > s.max {
size = s.max
}
}
s.available -= size
return nil
}
func (s *Semaphore) Give(size int) {
s.mut.Lock()
if size > s.max {
size = s.max
}
if s.available+size > s.max {
s.available = s.max
} else {
s.available += size
}
s.cond.Broadcast()
s.mut.Unlock()
}
func (s *Semaphore) SetCapacity(capacity int) {
if capacity < 0 {
capacity = 0
}
s.mut.Lock()
diff := capacity - s.max
s.max = capacity
s.available += diff
if s.available < 0 {
s.available = 0
} else if s.available > s.max {
s.available = s.max
}
s.cond.Broadcast()
s.mut.Unlock()
}
func (s *Semaphore) Available() int {
s.mut.Lock()
defer s.mut.Unlock()
return s.available
}
// MultiSemaphore combines semaphores, making sure to always take and give in
// the same order (reversed for give). A semaphore may be nil, in which case it
// is skipped.
type MultiSemaphore []*Semaphore
func (s MultiSemaphore) TakeWithContext(ctx context.Context, size int) error {
for _, limiter := range s {
if limiter != nil {
if err := limiter.TakeWithContext(ctx, size); err != nil {
return err
}
}
}
return nil
}
func (s MultiSemaphore) Take(size int) {
for _, limiter := range s {
if limiter != nil {
limiter.Take(size)
}
}
}
func (s MultiSemaphore) Give(size int) {
for i := range s {
limiter := s[len(s)-1-i]
if limiter != nil {
limiter.Give(size)
}
}
}
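For reference, a short usage sketch of the API added above. The Semaphore and MultiSemaphore calls are the ones defined in lib/util/semaphore.go; the surrounding main function and the capacity values are assumptions for illustration only.

package main

import (
	"context"
	"fmt"

	"github.com/syncthing/syncthing/lib/util"
)

func main() {
	perDevice := util.NewSemaphore(8) // cf. dialMaxParallelPerDevice
	global := util.NewSemaphore(64)   // cf. dialMaxParallel

	// MultiSemaphore takes from left to right and gives back in reverse;
	// nil entries are skipped, which is how optional limiters are composed.
	sema := util.MultiSemaphore{perDevice, global, nil}

	if err := sema.TakeWithContext(context.Background(), 1); err != nil {
		fmt.Println("cancelled while waiting:", err)
		return
	}
	fmt.Println("per-device slots left:", perDevice.Available()) // prints 7
	sema.Give(1)

	// Capacity can be raised or lowered at runtime; blocked takers are
	// woken up and re-evaluate against the new limit.
	global.SetCapacity(128)
}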


@@ -4,38 +4,38 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at https://mozilla.org/MPL/2.0/.
package model
package util
import "testing"
func TestZeroByteSempahore(t *testing.T) {
// A semaphore with zero capacity is just a no-op.
s := newByteSemaphore(0)
s := NewSemaphore(0)
// None of these should block or panic
s.take(123)
s.take(456)
s.give(1 << 30)
s.Take(123)
s.Take(456)
s.Give(1 << 30)
}
func TestByteSempahoreCapChangeUp(t *testing.T) {
// Waiting takes should unblock when the capacity increases
s := newByteSemaphore(100)
s := NewSemaphore(100)
s.take(75)
s.Take(75)
if s.available != 25 {
t.Error("bad state after take")
}
gotit := make(chan struct{})
go func() {
s.take(75)
s.Take(75)
close(gotit)
}()
s.setCapacity(155)
s.SetCapacity(155)
<-gotit
if s.available != 5 {
t.Error("bad state after both takes")
@@ -45,19 +45,19 @@ func TestByteSempahoreCapChangeUp(t *testing.T) {
func TestByteSempahoreCapChangeDown1(t *testing.T) {
// Things should make sense when capacity is adjusted down
s := newByteSemaphore(100)
s := NewSemaphore(100)
s.take(75)
s.Take(75)
if s.available != 25 {
t.Error("bad state after take")
}
s.setCapacity(90)
s.SetCapacity(90)
if s.available != 15 {
t.Error("bad state after adjust")
}
s.give(75)
s.Give(75)
if s.available != 90 {
t.Error("bad state after give")
}
@@ -66,19 +66,19 @@ func TestByteSempahoreCapChangeDown1(t *testing.T) {
func TestByteSempahoreCapChangeDown2(t *testing.T) {
// Things should make sense when capacity is adjusted down, different case
s := newByteSemaphore(100)
s := NewSemaphore(100)
s.take(75)
s.Take(75)
if s.available != 25 {
t.Error("bad state after take")
}
s.setCapacity(10)
s.SetCapacity(10)
if s.available != 0 {
t.Error("bad state after adjust")
}
s.give(75)
s.Give(75)
if s.available != 10 {
t.Error("bad state after give")
}
@@ -87,26 +87,26 @@ func TestByteSempahoreCapChangeDown2(t *testing.T) {
func TestByteSempahoreGiveMore(t *testing.T) {
// We shouldn't end up with more available than we have capacity...
s := newByteSemaphore(100)
s := NewSemaphore(100)
s.take(150)
s.Take(150)
if s.available != 0 {
t.Errorf("bad state after large take")
}
s.give(150)
s.Give(150)
if s.available != 100 {
t.Errorf("bad state after large take + give")
}
s.take(150)
s.setCapacity(125)
s.Take(150)
s.SetCapacity(125)
// available was zero before, we're increasing capacity by 25
if s.available != 25 {
t.Errorf("bad state after setcap")
}
s.give(150)
s.Give(150)
if s.available != 125 {
t.Errorf("bad state after large take + give with adjustment")
}