From 38f2b34d29704db2ce8ca064e35937fd84792f91 Mon Sep 17 00:00:00 2001
From: greatroar <61184462+greatroar@users.noreply.github.com>
Date: Tue, 7 Feb 2023 12:07:34 +0100
Subject: [PATCH] all: Use new Go 1.19 atomic types (#8772)

---
 cmd/strelaysrv/listener.go      | 10 +++++-----
 cmd/strelaysrv/main.go          |  8 ++++----
 cmd/strelaysrv/session.go       | 10 +++++-----
 cmd/strelaysrv/status.go        | 12 ++++++------
 cmd/stvanity/main.go            | 10 +++++-----
 lib/config/wrapper.go           | 12 +++---------
 lib/connections/limiter.go      | 22 ++++------------------
 lib/connections/limiter_test.go |  9 +++++----
 lib/connections/quic_listen.go  |  8 ++++----
 lib/model/folder.go             |  5 ++---
 lib/model/model.go              |  3 ++-
 lib/model/model_test.go         |  5 ++---
 lib/protocol/bufferpool.go      | 20 ++++++++++----------
 lib/protocol/bufferpool_test.go |  8 ++++----
 lib/protocol/counting.go        | 38 +++++++++++++++---------------------
 lib/protocol/protocol_test.go   |  4 ++--
 lib/scanner/walk.go             |  8 +++-----
 lib/stun/stun.go                | 15 +++++++--------
 lib/sync/sync.go                |  8 ++++----
 19 files changed, 94 insertions(+), 121 deletions(-)

diff --git a/cmd/strelaysrv/listener.go b/cmd/strelaysrv/listener.go
index d31f9b84a..dc3724c0f 100644
--- a/cmd/strelaysrv/listener.go
+++ b/cmd/strelaysrv/listener.go
@@ -20,7 +20,7 @@ import (
 var (
 	outboxesMut    = sync.RWMutex{}
 	outboxes       = make(map[syncthingprotocol.DeviceID]chan interface{})
-	numConnections int64
+	numConnections atomic.Int64
 )
 
 func listener(_, addr string, config *tls.Config, token string) {
@@ -128,7 +128,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config, token strin
 				continue
 			}
-			if atomic.LoadInt32(&overLimit) > 0 {
+			if overLimit.Load() {
 				protocol.WriteMessage(conn, protocol.RelayFull{})
 				if debug {
 					log.Println("Refusing join request from", id, "due to being over limits")
 				}
@@ -267,7 +267,7 @@ func protocolConnectionHandler(tcpConn net.Conn, config *tls.Config, token strin
 			conn.Close()
 		}
 
-		if atomic.LoadInt32(&overLimit) > 0 && !hasSessions(id) {
+		if overLimit.Load() && !hasSessions(id) {
 			if debug {
 				log.Println("Dropping", id, "as it has no sessions and we are over our limits")
 			}
@@ -360,8 +360,8 @@ func sessionConnectionHandler(conn net.Conn) {
 }
 
 func messageReader(conn net.Conn, messages chan<- interface{}, errors chan<- error) {
-	atomic.AddInt64(&numConnections, 1)
-	defer atomic.AddInt64(&numConnections, -1)
+	numConnections.Add(1)
+	defer numConnections.Add(-1)
 
 	for {
 		msg, err := protocol.ReadMessage(conn)
diff --git a/cmd/strelaysrv/main.go b/cmd/strelaysrv/main.go
index bdf317875..be682a565 100644
--- a/cmd/strelaysrv/main.go
+++ b/cmd/strelaysrv/main.go
@@ -49,7 +49,7 @@ var (
 	sessionLimitBps int
 	globalLimitBps  int
-	overLimit       int32
+	overLimit       atomic.Bool
 	descriptorLimit int64
 	sessionLimiter  *rate.Limiter
 	globalLimiter   *rate.Limiter
 
@@ -308,10 +308,10 @@ func main() {
 func monitorLimits() {
 	limitCheckTimer = time.NewTimer(time.Minute)
 	for range limitCheckTimer.C {
-		if atomic.LoadInt64(&numConnections)+atomic.LoadInt64(&numProxies) > descriptorLimit {
-			atomic.StoreInt32(&overLimit, 1)
+		if numConnections.Load()+numProxies.Load() > descriptorLimit {
+			overLimit.Store(true)
 			log.Println("Gone past our connection limits. Starting to refuse new/drop idle connections.")
-		} else if atomic.CompareAndSwapInt32(&overLimit, 1, 0) {
+		} else if overLimit.CompareAndSwap(true, false) {
 			log.Println("Dropped below our connection limits. Accepting new connections.")
 		}
 		limitCheckTimer.Reset(time.Minute)
diff --git a/cmd/strelaysrv/session.go b/cmd/strelaysrv/session.go
index 0221434bd..52cc39441 100644
--- a/cmd/strelaysrv/session.go
+++ b/cmd/strelaysrv/session.go
@@ -23,8 +23,8 @@ var (
 	sessionMut      = sync.RWMutex{}
 	activeSessions  = make([]*session, 0)
 	pendingSessions = make(map[string]*session)
-	numProxies      int64
-	bytesProxied    int64
+	numProxies      atomic.Int64
+	bytesProxied    atomic.Int64
 )
 
 func newSession(serverid, clientid syncthingprotocol.DeviceID, sessionRateLimit, globalRateLimit *rate.Limiter) *session {
@@ -251,8 +251,8 @@ func (s *session) proxy(c1, c2 net.Conn) error {
 		log.Println("Proxy", c1.RemoteAddr(), "->", c2.RemoteAddr())
 	}
 
-	atomic.AddInt64(&numProxies, 1)
-	defer atomic.AddInt64(&numProxies, -1)
+	numProxies.Add(1)
+	defer numProxies.Add(-1)
 
 	buf := make([]byte, networkBufferSize)
 	for {
@@ -262,7 +262,7 @@
 			return err
 		}
 
-		atomic.AddInt64(&bytesProxied, int64(n))
+		bytesProxied.Add(int64(n))
 
 		if debug {
 			log.Printf("%d bytes from %s to %s", n, c1.RemoteAddr(), c2.RemoteAddr())
diff --git a/cmd/strelaysrv/status.go b/cmd/strelaysrv/status.go
index 838d653b0..b5437ceb1 100644
--- a/cmd/strelaysrv/status.go
+++ b/cmd/strelaysrv/status.go
@@ -51,9 +51,9 @@ func getStatus(w http.ResponseWriter, _ *http.Request) {
 	status["numPendingSessionKeys"] = len(pendingSessions)
 	status["numActiveSessions"] = len(activeSessions)
 	sessionMut.Unlock()
-	status["numConnections"] = atomic.LoadInt64(&numConnections)
-	status["numProxies"] = atomic.LoadInt64(&numProxies)
-	status["bytesProxied"] = atomic.LoadInt64(&bytesProxied)
+	status["numConnections"] = numConnections.Load()
+	status["numProxies"] = numProxies.Load()
+	status["bytesProxied"] = bytesProxied.Load()
 	status["goVersion"] = runtime.Version()
 	status["goOS"] = runtime.GOOS
 	status["goArch"] = runtime.GOARCH
@@ -88,13 +88,13 @@
 }
 
 type rateCalculator struct {
-	counter   *int64 // atomic, must remain 64-bit aligned
+	counter   *atomic.Int64
 	rates     []int64
 	prev      int64
 	startTime time.Time
 }
 
-func newRateCalculator(keepIntervals int, interval time.Duration, counter *int64) *rateCalculator {
+func newRateCalculator(keepIntervals int, interval time.Duration, counter *atomic.Int64) *rateCalculator {
 	r := &rateCalculator{
 		rates:   make([]int64, keepIntervals),
 		counter: counter,
@@ -112,7 +112,7 @@ func (r *rateCalculator) updateRates(interval time.Duration) {
 		next := now.Truncate(interval).Add(interval)
 		time.Sleep(next.Sub(now))
 
-		cur := atomic.LoadInt64(r.counter)
+		cur := r.counter.Load()
 		rate := int64(float64(cur-r.prev) / interval.Seconds())
 		copy(r.rates[1:], r.rates)
 		r.rates[0] = rate
diff --git a/cmd/stvanity/main.go b/cmd/stvanity/main.go
index 128f35300..8917163b4 100644
--- a/cmd/stvanity/main.go
+++ b/cmd/stvanity/main.go
@@ -44,7 +44,7 @@ func main() {
 	found := make(chan result)
 	stop := make(chan struct{})
 
-	var count int64
+	var count atomic.Int64
 
 	// Print periodic progress reports.
 	go printProgress(prefix, &count)
@@ -72,7 +72,7 @@ func main() {
 // Try certificates until one is found that has the prefix at the start of
 // the resulting device ID. Increments count atomically, sends the result to
 // found, returns when stop is closed.
-func generatePrefixed(prefix string, count *int64, found chan<- result, stop <-chan struct{}) {
+func generatePrefixed(prefix string, count *atomic.Int64, found chan<- result, stop <-chan struct{}) {
 	notBefore := time.Now()
 	notAfter := time.Date(2049, 12, 31, 23, 59, 59, 0, time.UTC)
 
@@ -109,7 +109,7 @@ func generatePrefixed(prefix string, count *int64, found chan<- result, stop <-c
 		}
 
 		id := protocol.NewDeviceID(derBytes)
-		atomic.AddInt64(count, 1)
+		count.Add(1)
 
 		if strings.HasPrefix(id.String(), prefix) {
 			select {
@@ -121,7 +121,7 @@
 	}
 }
 
-func printProgress(prefix string, count *int64) {
+func printProgress(prefix string, count *atomic.Int64) {
 	started := time.Now()
 	wantBits := 5 * len(prefix)
 	if wantBits > 63 {
@@ -132,7 +132,7 @@ func printProgress(prefix string, count *int64) {
 	fmt.Printf("Want %d bits for prefix %q, about %.2g certs to test (statistically speaking)\n", wantBits, prefix, expectedIterations)
 
 	for range time.NewTicker(15 * time.Second).C {
-		tried := atomic.LoadInt64(count)
+		tried := count.Load()
 		elapsed := time.Since(started)
 		rate := float64(tried) / elapsed.Seconds()
 		expected := timeStr(expectedIterations / rate)
diff --git a/lib/config/wrapper.go b/lib/config/wrapper.go
index fb3ce6f76..6dbde86ae 100644
--- a/lib/config/wrapper.go
+++ b/lib/config/wrapper.go
@@ -134,7 +134,7 @@ type wrapper struct {
 	subs []Committer
 	mut  sync.Mutex
 
-	requiresRestart uint32 // an atomic bool
+	requiresRestart atomic.Bool
 }
 
 // Wrap wraps an existing Configuration structure and ties it to a file on
@@ -340,7 +340,7 @@ func (w *wrapper) notifyListener(sub Committer, from, to Configuration) {
 	l.Debugln(sub, "committing configuration")
 	if !sub.CommitConfiguration(from, to) {
 		l.Debugln(sub, "requires restart")
-		w.setRequiresRestart()
+		w.requiresRestart.Store(true)
 	}
 }
 
@@ -525,13 +525,7 @@ func (w *wrapper) Save() error {
 	return nil
 }
 
-func (w *wrapper) RequiresRestart() bool {
-	return atomic.LoadUint32(&w.requiresRestart) != 0
-}
-
-func (w *wrapper) setRequiresRestart() {
-	atomic.StoreUint32(&w.requiresRestart, 1)
-}
+func (w *wrapper) RequiresRestart() bool { return w.requiresRestart.Load() }
 
 type modifyEntry struct {
 	modifyFunc ModifyFunction
diff --git a/lib/connections/limiter.go b/lib/connections/limiter.go
index 9f279ba6e..4fa4578dc 100644
--- a/lib/connections/limiter.go
+++ b/lib/connections/limiter.go
@@ -25,7 +25,7 @@ type limiter struct {
 	mu        sync.Mutex
 	write     *rate.Limiter
 	read      *rate.Limiter
-	limitsLAN atomicBool
+	limitsLAN atomic.Bool
 	deviceReadLimiters  map[protocol.DeviceID]*rate.Limiter
 	deviceWriteLimiters map[protocol.DeviceID]*rate.Limiter
 }
@@ -157,7 +157,7 @@ func (lim *limiter) CommitConfiguration(from, to config.Configuration) bool {
 		limited = true
 	}
 
-	lim.limitsLAN.set(to.Options.LimitBandwidthInLan)
+	lim.limitsLAN.Store(to.Options.LimitBandwidthInLan)
 
 	l.Infof("Overall send rate %s, receive rate %s", sendLimitStr, recvLimitStr)
 
@@ -282,13 +282,13 @@ func (w *limitedWriter) Write(buf []byte) (int, error) {
 // waiter, valid for both writers and readers
 type waiterHolder struct {
 	waiter    waiter
-	limitsLAN *atomicBool
+	limitsLAN *atomic.Bool
 	isLAN     bool
 }
 
 // unlimited returns true if the waiter is not limiting the rate
 func (w waiterHolder) unlimited() bool {
-	if w.isLAN && !w.limitsLAN.get() {
+	if w.isLAN && !w.limitsLAN.Load() {
 		return true
 	}
 	return w.waiter.Limit() == rate.Inf
@@ -322,20 +322,6 @@ func (w waiterHolder) take(tokens int) {
 	}
 }
 
-type atomicBool int32
-
-func (b *atomicBool) set(v bool) {
-	if v {
-		atomic.StoreInt32((*int32)(b), 1)
-	} else {
-		atomic.StoreInt32((*int32)(b), 0)
-	}
-}
-
-func (b *atomicBool) get() bool {
-	return atomic.LoadInt32((*int32)(b)) != 0
-}
-
 // totalWaiter waits for all of the waiters
 type totalWaiter []waiter
 
diff --git a/lib/connections/limiter_test.go b/lib/connections/limiter_test.go
index 2379ec65c..fd7e3547e 100644
--- a/lib/connections/limiter_test.go
+++ b/lib/connections/limiter_test.go
@@ -12,6 +12,7 @@ import (
 	crand "crypto/rand"
 	"io"
 	"math/rand"
+	"sync/atomic"
 	"testing"
 
 	"github.com/syncthing/syncthing/lib/config"
@@ -234,7 +235,7 @@ func TestLimitedWriterWrite(t *testing.T) {
 		writer: cw,
 		waiterHolder: waiterHolder{
 			waiter:    rate.NewLimiter(rate.Limit(42), limiterBurstSize),
-			limitsLAN: new(atomicBool),
+			limitsLAN: new(atomic.Bool),
 			isLAN:     false, // enables limiting
 		},
 	}
@@ -263,7 +264,7 @@ func TestLimitedWriterWrite(t *testing.T) {
 		writer: cw,
 		waiterHolder: waiterHolder{
 			waiter:    rate.NewLimiter(rate.Limit(42), limiterBurstSize),
-			limitsLAN: new(atomicBool),
+			limitsLAN: new(atomic.Bool),
 			isLAN:     true, // disables limiting
 		},
 	}
@@ -287,7 +288,7 @@ func TestLimitedWriterWrite(t *testing.T) {
 		writer: cw,
 		waiterHolder: waiterHolder{
 			waiter:    totalWaiter{rate.NewLimiter(rate.Inf, limiterBurstSize), rate.NewLimiter(rate.Inf, limiterBurstSize)},
-			limitsLAN: new(atomicBool),
+			limitsLAN: new(atomic.Bool),
 			isLAN:     false, // enables limiting
 		},
 	}
@@ -315,7 +316,7 @@ func TestLimitedWriterWrite(t *testing.T) {
 				rate.NewLimiter(rate.Limit(42), limiterBurstSize),
 				rate.NewLimiter(rate.Inf, limiterBurstSize),
 			},
-			limitsLAN: new(atomicBool),
+			limitsLAN: new(atomic.Bool),
 			isLAN:     false, // enables limiting
 		},
 	}
diff --git a/lib/connections/quic_listen.go b/lib/connections/quic_listen.go
index 74a06725e..4d10d7e7a 100644
--- a/lib/connections/quic_listen.go
+++ b/lib/connections/quic_listen.go
@@ -36,7 +36,7 @@ func init() {
 
 type quicListener struct {
 	svcutil.ServiceWithError
 
-	nat atomic.Value
+	nat atomic.Uint64 // Holds a stun.NATType.
 
 	onAddressesChangedNotifier
@@ -56,7 +56,7 @@ func (t *quicListener) OnNATTypeChanged(natType stun.NATType) {
 	if natType != stun.NATUnknown {
 		l.Infof("%s detected NAT type: %s", t.uri, natType)
 	}
-	t.nat.Store(natType)
+	t.nat.Store(uint64(natType))
 }
 
 func (t *quicListener) OnExternalAddressChanged(address *stun.Host, via string) {
@@ -205,7 +205,7 @@ func (t *quicListener) Factory() listenerFactory {
 }
 
 func (t *quicListener) NATType() string {
-	v := t.nat.Load().(stun.NATType)
+	v := stun.NATType(t.nat.Load())
 	if v == stun.NATUnknown || v == stun.NATError {
 		return "unknown"
 	}
@@ -228,7 +228,7 @@ func (f *quicListenerFactory) New(uri *url.URL, cfg config.Wrapper, tlsCfg *tls.
 		registry: registry,
 	}
 
 	l.ServiceWithError = svcutil.AsService(l.serve, l.String())
-	l.nat.Store(stun.NATUnknown)
+	l.nat.Store(uint64(stun.NATUnknown))
 	return l
 }
diff --git a/lib/model/folder.go b/lib/model/folder.go
index 52638a2df..216776689 100644
--- a/lib/model/folder.go
+++ b/lib/model/folder.go
@@ -13,7 +13,6 @@ import (
 	"math/rand"
 	"path/filepath"
 	"sort"
-	"sync/atomic"
 	"time"
 
 	"github.com/syncthing/syncthing/lib/config"
@@ -142,8 +141,8 @@ func newFolder(model *model, fset *db.FileSet, ignores *ignore.Matcher, cfg conf
 }
 
 func (f *folder) Serve(ctx context.Context) error {
-	atomic.AddInt32(&f.model.foldersRunning, 1)
-	defer atomic.AddInt32(&f.model.foldersRunning, -1)
+	f.model.foldersRunning.Add(1)
+	defer f.model.foldersRunning.Add(-1)
 
 	f.ctx = ctx
 
diff --git a/lib/model/model.go b/lib/model/model.go
index 47fdc35e9..c72821224 100644
--- a/lib/model/model.go
+++ b/lib/model/model.go
@@ -23,6 +23,7 @@ import (
 	"runtime"
 	"strings"
 	stdsync "sync"
+	"sync/atomic"
 	"time"
 
 	"github.com/thejerf/suture/v4"
@@ -166,7 +167,7 @@ type model struct {
 	indexHandlers map[protocol.DeviceID]*indexHandlerRegistry
 
 	// for testing only
-	foldersRunning int32
+	foldersRunning atomic.Int32
 }
 
 var _ config.Verifier = &model{}
diff --git a/lib/model/model_test.go b/lib/model/model_test.go
index 725409d6e..ec27cd619 100644
--- a/lib/model/model_test.go
+++ b/lib/model/model_test.go
@@ -21,7 +21,6 @@ import (
 	"strconv"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"testing"
 	"time"
 
@@ -3097,7 +3096,7 @@ func TestFolderRestartZombies(t *testing.T) {
 	m.ScanFolder("default")
 
 	// Check how many running folders we have running before the test.
-	if r := atomic.LoadInt32(&m.foldersRunning); r != 1 {
+	if r := m.foldersRunning.Load(); r != 1 {
 		t.Error("Expected one running folder, not", r)
 	}
 
@@ -3122,7 +3121,7 @@ func TestFolderRestartZombies(t *testing.T) {
 	wg.Wait()
 	// Make sure the folder is up and running, because we want to count it.
 	m.ScanFolder("default")
-	if r := atomic.LoadInt32(&m.foldersRunning); r != 1 {
+	if r := m.foldersRunning.Load(); r != 1 {
 		t.Error("Expected one running folder, not", r)
 	}
 }
diff --git a/lib/protocol/bufferpool.go b/lib/protocol/bufferpool.go
index 20c330777..cde957fdb 100644
--- a/lib/protocol/bufferpool.go
+++ b/lib/protocol/bufferpool.go
@@ -13,24 +13,24 @@ import (
 var BufferPool bufferPool
 
 type bufferPool struct {
-	puts   int64
-	skips  int64
-	misses int64
+	puts   atomic.Int64
+	skips  atomic.Int64
+	misses atomic.Int64
 	pools  []sync.Pool
-	hits   []int64 // start of slice allocation is always aligned
+	hits   []atomic.Int64
 }
 
 func newBufferPool() bufferPool {
 	return bufferPool{
 		pools: make([]sync.Pool, len(BlockSizes)),
-		hits:  make([]int64, len(BlockSizes)),
+		hits:  make([]atomic.Int64, len(BlockSizes)),
 	}
 }
 
 func (p *bufferPool) Get(size int) []byte {
 	// Too big, isn't pooled
 	if size > MaxBlockSize {
-		atomic.AddInt64(&p.skips, 1)
+		p.skips.Add(1)
 		return make([]byte, size)
 	}
 
@@ -38,13 +38,13 @@ func (p *bufferPool) Get(size int) []byte {
 	bkt := getBucketForLen(size)
 	for j := bkt; j < len(BlockSizes); j++ {
 		if intf := p.pools[j].Get(); intf != nil {
-			atomic.AddInt64(&p.hits[j], 1)
+			p.hits[j].Add(1)
 			bs := *intf.(*[]byte)
 			return bs[:size]
 		}
 	}
 
-	atomic.AddInt64(&p.misses, 1)
+	p.misses.Add(1)
 
 	// All pools are empty, must allocate. For very small slices where we
 	// didn't have a block to reuse, just allocate a small slice instead of
@@ -60,11 +60,11 @@ func (p *bufferPool) Put(bs []byte) {
 	// Don't buffer slices outside of our pool range
 	if cap(bs) > MaxBlockSize || cap(bs) < MinBlockSize {
-		atomic.AddInt64(&p.skips, 1)
+		p.skips.Add(1)
 		return
 	}
 
-	atomic.AddInt64(&p.puts, 1)
+	p.puts.Add(1)
 	bkt := putBucketForCap(cap(bs))
 	p.pools[bkt].Put(&bs)
 }
 
diff --git a/lib/protocol/bufferpool_test.go b/lib/protocol/bufferpool_test.go
index ab889ba88..a4d08c791 100644
--- a/lib/protocol/bufferpool_test.go
+++ b/lib/protocol/bufferpool_test.go
@@ -108,13 +108,13 @@ func TestStressBufferPool(t *testing.T) {
 	default:
 	}
 
-	t.Log(bp.puts, bp.skips, bp.misses, bp.hits)
-	if bp.puts == 0 || bp.skips == 0 || bp.misses == 0 {
+	t.Log(bp.puts.Load(), bp.skips.Load(), bp.misses.Load(), bp.hits)
+	if bp.puts.Load() == 0 || bp.skips.Load() == 0 || bp.misses.Load() == 0 {
 		t.Error("didn't exercise some paths")
 	}
 	var hits int64
-	for _, h := range bp.hits {
-		hits += h
+	for i := range bp.hits {
+		hits += bp.hits[i].Load()
 	}
 	if hits == 0 {
 		t.Error("didn't exercise some paths")
diff --git a/lib/protocol/counting.go b/lib/protocol/counting.go
index 2310e91f8..e54c593ad 100644
--- a/lib/protocol/counting.go
+++ b/lib/protocol/counting.go
@@ -10,53 +10,49 @@ import (
 
 type countingReader struct {
 	io.Reader
-	tot  int64 // bytes (atomic, must remain 64-bit aligned)
-	last int64 // unix nanos (atomic, must remain 64-bit aligned)
+	tot  atomic.Int64 // bytes
+	last atomic.Int64 // unix nanos
 }
 
 var (
-	totalIncoming int64
-	totalOutgoing int64
+	totalIncoming atomic.Int64
+	totalOutgoing atomic.Int64
 )
 
 func (c *countingReader) Read(bs []byte) (int, error) {
 	n, err := c.Reader.Read(bs)
-	atomic.AddInt64(&c.tot, int64(n))
-	atomic.AddInt64(&totalIncoming, int64(n))
-	atomic.StoreInt64(&c.last, time.Now().UnixNano())
+	c.tot.Add(int64(n))
+	totalIncoming.Add(int64(n))
+	c.last.Store(time.Now().UnixNano())
 	return n, err
 }
 
-func (c *countingReader) Tot() int64 {
-	return atomic.LoadInt64(&c.tot)
-}
+func (c *countingReader) Tot() int64 { return c.tot.Load() }
 
 func (c *countingReader) Last() time.Time {
-	return time.Unix(0, atomic.LoadInt64(&c.last))
+	return time.Unix(0, c.last.Load())
 }
 
 type countingWriter struct {
 	io.Writer
-	tot  int64 // bytes (atomic, must remain 64-bit aligned)
-	last int64 // unix nanos (atomic, must remain 64-bit aligned)
+	tot  atomic.Int64 // bytes
+	last atomic.Int64 // unix nanos
 }
 
 func (c *countingWriter) Write(bs []byte) (int, error) {
 	n, err := c.Writer.Write(bs)
-	atomic.AddInt64(&c.tot, int64(n))
-	atomic.AddInt64(&totalOutgoing, int64(n))
-	atomic.StoreInt64(&c.last, time.Now().UnixNano())
+	c.tot.Add(int64(n))
+	totalOutgoing.Add(int64(n))
+	c.last.Store(time.Now().UnixNano())
 	return n, err
 }
 
-func (c *countingWriter) Tot() int64 {
-	return atomic.LoadInt64(&c.tot)
-}
+func (c *countingWriter) Tot() int64 { return c.tot.Load() }
 
 func (c *countingWriter) Last() time.Time {
-	return time.Unix(0, atomic.LoadInt64(&c.last))
+	return time.Unix(0, c.last.Load())
 }
 
 func TotalInOut() (int64, int64) {
-	return atomic.LoadInt64(&totalIncoming), atomic.LoadInt64(&totalOutgoing)
+	return totalIncoming.Load(), totalOutgoing.Load()
 }
diff --git a/lib/protocol/protocol_test.go b/lib/protocol/protocol_test.go
index 2ce8c0025..aac5e36d8 100644
--- a/lib/protocol/protocol_test.go
+++ b/lib/protocol/protocol_test.go
@@ -468,9 +468,9 @@ func TestWriteCompressed(t *testing.T) {
 		hdr := Header{Type: typeOf(msg)}
 		size := int64(2 + hdr.ProtoSize() + 4 + msg.ProtoSize())
 
-		if c.cr.tot > size {
+		if c.cr.Tot() > size {
 			t.Errorf("compression enlarged message from %d to %d",
-				size, c.cr.tot)
+				size, c.cr.Tot())
 		}
 	}
 }
diff --git a/lib/scanner/walk.go b/lib/scanner/walk.go
index b155925ca..2d8dde469 100644
--- a/lib/scanner/walk.go
+++ b/lib/scanner/walk.go
@@ -628,7 +628,7 @@ func (w *walker) String() string {
 // A byteCounter gets bytes added to it via Update() and then provides the
 // Total() and one minute moving average Rate() in bytes per second.
 type byteCounter struct {
-	total int64 // atomic, must remain 64-bit aligned
+	total atomic.Int64
 	metrics.EWMA
 	stop chan struct{}
 }
@@ -658,13 +658,11 @@ func (c *byteCounter) ticker() {
 }
 
 func (c *byteCounter) Update(bytes int64) {
-	atomic.AddInt64(&c.total, bytes)
+	c.total.Add(bytes)
 	c.EWMA.Update(bytes)
 }
 
-func (c *byteCounter) Total() int64 {
-	return atomic.LoadInt64(&c.total)
-}
+func (c *byteCounter) Total() int64 { return c.total.Load() }
 
 func (c *byteCounter) Close() {
 	close(c.stop)
diff --git a/lib/stun/stun.go b/lib/stun/stun.go
index 4ac476319..db96502cd 100644
--- a/lib/stun/stun.go
+++ b/lib/stun/stun.go
@@ -39,36 +39,35 @@ const (
 )
 
 type writeTrackingUdpConn struct {
-	lastWrite int64 // atomic, must remain 64-bit aligned
 	// Needs to be UDPConn not PacketConn, as pfilter checks for WriteMsgUDP/ReadMsgUDP
 	// and even if we embed UDPConn here, in place of a PacketConn, seems the interface
 	// check fails.
 	*net.UDPConn
+	lastWrite atomic.Int64
 }
 
 func (c *writeTrackingUdpConn) WriteTo(p []byte, addr net.Addr) (n int, err error) {
-	atomic.StoreInt64(&c.lastWrite, time.Now().Unix())
+	c.lastWrite.Store(time.Now().Unix())
 	return c.UDPConn.WriteTo(p, addr)
 }
 
 func (c *writeTrackingUdpConn) WriteMsgUDP(b, oob []byte, addr *net.UDPAddr) (n, oobn int, err error) {
-	atomic.StoreInt64(&c.lastWrite, time.Now().Unix())
+	c.lastWrite.Store(time.Now().Unix())
 	return c.UDPConn.WriteMsgUDP(b, oob, addr)
 }
 
 func (c *writeTrackingUdpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) {
-	atomic.StoreInt64(&c.lastWrite, time.Now().Unix())
+	c.lastWrite.Store(time.Now().Unix())
 	return c.UDPConn.WriteToUDP(b, addr)
 }
 
 func (c *writeTrackingUdpConn) Write(b []byte) (int, error) {
-	atomic.StoreInt64(&c.lastWrite, time.Now().Unix())
+	c.lastWrite.Store(time.Now().Unix())
 	return c.UDPConn.Write(b)
 }
 
 func (c *writeTrackingUdpConn) getLastWrite() time.Time {
-	unix := atomic.LoadInt64(&c.lastWrite)
-	return time.Unix(unix, 0)
+	return time.Unix(c.lastWrite.Load(), 0)
 }
 
 type Subscriber interface {
@@ -91,7 +90,7 @@ type Service struct {
 func New(cfg config.Wrapper, subscriber Subscriber, conn *net.UDPConn) (*Service, net.PacketConn) {
 	// Wrap the original connection to track writes on it
-	writeTrackingUdpConn := &writeTrackingUdpConn{lastWrite: 0, UDPConn: conn}
+	writeTrackingUdpConn := &writeTrackingUdpConn{UDPConn: conn}
 
 	// Wrap it in a filter and split it up, so that stun packets arrive on stun conn, others arrive on the data conn
 	filterConn := pfilter.NewPacketFilter(writeTrackingUdpConn)
 
diff --git a/lib/sync/sync.go b/lib/sync/sync.go
index 9d538c037..9d7661384 100644
--- a/lib/sync/sync.go
+++ b/lib/sync/sync.go
@@ -116,16 +116,16 @@ type loggedRWMutex struct {
 	readHolders    map[int][]holder
 	readHoldersMut sync.Mutex
 
-	logUnlockers int32
+	logUnlockers atomic.Bool
 	unlockers    chan holder
 }
 
 func (m *loggedRWMutex) Lock() {
 	start := timeNow()
 
-	atomic.StoreInt32(&m.logUnlockers, 1)
+	m.logUnlockers.Store(true)
 	m.RWMutex.Lock()
-	atomic.StoreInt32(&m.logUnlockers, 0)
+	m.logUnlockers.Store(false)
 
 	holder := getHolder()
 	m.holder.Store(holder)
@@ -173,7 +173,7 @@ func (m *loggedRWMutex) RUnlock() {
 		m.readHolders[id] = current[:len(current)-1]
 	}
 	m.readHoldersMut.Unlock()
-	if atomic.LoadInt32(&m.logUnlockers) == 1 {
+	if m.logUnlockers.Load() {
 		holder := getHolder()
 		select {
 		case m.unlockers <- holder:
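
The conversion above is the same in every file: a plain int64/int32/uint32 counter or flag updated through sync/atomic function calls becomes an atomic.Int64, atomic.Int32 or atomic.Bool value with methods, which also retires the "atomic, must remain 64-bit aligned" comments, since the Go 1.19 wrapper types carry that alignment guarantee themselves. Below is a minimal, self-contained sketch of the pattern; the hitCounter type and its names are illustrative only and do not appear in the Syncthing sources.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// hitCounter mirrors the shape of the counters converted in this patch.
// Before Go 1.19 the field would be a plain `n int64`, kept 64-bit aligned
// by careful struct layout and touched via atomic.AddInt64(&c.n, 1) and
// atomic.LoadInt64(&c.n). The wrapper type exposes methods instead.
type hitCounter struct {
	n atomic.Int64
}

func (c *hitCounter) inc()         { c.n.Add(1) }
func (c *hitCounter) total() int64 { return c.n.Load() }

func main() {
	var c hitCounter
	var wg sync.WaitGroup
	for i := 0; i < 8; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < 1000; j++ {
				c.inc()
			}
		}()
	}
	wg.Wait()
	fmt.Println(c.total()) // always 8000

	// atomic.Bool replaces the 0/1-in-an-int32 idiom that the patch removes
	// from cmd/strelaysrv/main.go and lib/connections/limiter.go above.
	var overLimit atomic.Bool
	overLimit.Store(true)
	fmt.Println(overLimit.CompareAndSwap(true, false)) // true; the flag is now false
}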