diff --git a/meter.go b/meter.go index aa5e587..b70593e 100644 --- a/meter.go +++ b/meter.go @@ -33,6 +33,9 @@ func (s Snapshot) String() string { type Meter struct { accumulator uint64 + // managed by the sweeper loop. + registered bool + // Take lock. snapshot Snapshot } @@ -40,8 +43,9 @@ type Meter struct { // Mark updates the total. func (m *Meter) Mark(count uint64) { if count > 0 && atomic.AddUint64(&m.accumulator, count) == count { - // I'm the first one to bump this above 0. - // Register it. + // The accumulator is 0 so we probably need to register. We may + // already _be_ registered however, if we are, the registration + // loop will notice that `m.registered` is set and ignore us. globalSweeper.Register(m) } } @@ -53,6 +57,15 @@ func (m *Meter) Snapshot() Snapshot { return m.snapshot } +// Reset sets accumulator, total and rate to zero. +func (m *Meter) Reset() { + globalSweeper.snapshotMu.Lock() + atomic.StoreUint64(&m.accumulator, 0) + m.snapshot.Rate = 0 + m.snapshot.Total = 0 + globalSweeper.snapshotMu.Unlock() +} + func (m *Meter) String() string { return m.Snapshot().String() } diff --git a/meter_test.go b/meter_test.go index f0bd810..0f29a8e 100644 --- a/meter_test.go +++ b/meter_test.go @@ -3,6 +3,8 @@ package flow import ( "fmt" "math" + "sync" + "testing" "time" ) @@ -29,6 +31,43 @@ func ExampleMeter() { // Output: 3000 (300/s) } +func TestResetMeter(t *testing.T) { + meter := new(Meter) + + meter.Mark(30) + + time.Sleep(2 * time.Second) + + if total := meter.Snapshot().Total; total != 30 { + t.Errorf("total = %d; want 30", total) + } + + meter.Reset() + + if total := meter.Snapshot().Total; total != 0 { + t.Errorf("total = %d; want 0", total) + } +} + +func TestMarkResetMeterMulti(t *testing.T) { + var wg sync.WaitGroup + wg.Add(2) + + meter := new(Meter) + go func(meter *Meter) { + meter.Mark(30) + meter.Mark(30) + wg.Done() + }(meter) + + go func(meter *Meter) { + meter.Reset() + wg.Done() + }(meter) + + wg.Wait() +} + func roundTens(x float64) int64 { return int64(math.Floor(x/10+0.5)) * 10 } diff --git a/registry.go b/registry.go index 81c8d85..d3e47be 100644 --- a/registry.go +++ b/registry.go @@ -72,3 +72,11 @@ func (r *MeterRegistry) ForEach(iterFunc func(string, *Meter)) { return true }) } + +// Clear removes all meters from the registry. +func (r *MeterRegistry) Clear() { + r.meters.Range(func(k, v interface{}) bool { + r.meters.Delete(k) + return true + }) +} diff --git a/registry_test.go b/registry_test.go index 5795029..f7976c9 100644 --- a/registry_test.go +++ b/registry_test.go @@ -15,7 +15,7 @@ func TestRegistry(t *testing.T) { m1.Mark(10) m2.Mark(30) - time.Sleep(2 * time.Second) + time.Sleep(2*time.Second + time.Millisecond) if total := r.Get("first").Snapshot().Total; total != 10 { t.Errorf("expected first total to be 10, got %d", total) @@ -98,3 +98,27 @@ func TestRegistry(t *testing.T) { t.Error("expected to trim 2 idle timers") } } + +func TestClearRegistry(t *testing.T) { + r := new(MeterRegistry) + m1 := r.Get("first") + m2 := r.Get("second") + + m1.Mark(10) + m2.Mark(30) + + time.Sleep(2 * time.Second) + + r.Clear() + + r.ForEach(func(n string, _m *Meter) { + t.Errorf("expected no meters at all, found a meter %s", n) + }) + + if total := r.Get("first").Snapshot().Total; total != 0 { + t.Errorf("expected first total to be 0, got %d", total) + } + if total := r.Get("second").Snapshot().Total; total != 0 { + t.Errorf("expected second total to be 0, got %d", total) + } +} diff --git a/sweeper.go b/sweeper.go index 48e301c..e4294ed 100644 --- a/sweeper.go +++ b/sweeper.go @@ -23,8 +23,9 @@ var globalSweeper sweeper type sweeper struct { sweepOnce sync.Once - snapshotMu sync.RWMutex - meters []*Meter + snapshotMu sync.RWMutex + meters []*Meter + activeMeters int lastUpdateTime time.Time registerChannel chan *Meter @@ -43,9 +44,11 @@ func (sw *sweeper) run() { } func (sw *sweeper) register(m *Meter) { - // Add back the snapshot total. If we unregistered this - // one, we set it to zero. - atomic.AddUint64(&m.accumulator, m.snapshot.Total) + if m.registered { + // registered twice, move on. + return + } + m.registered = true sw.meters = append(sw.meters, m) } @@ -85,9 +88,9 @@ func (sw *sweeper) update() { sw.lastUpdateTime = now timeMultiplier := float64(time.Second) / float64(tdiff) + // Calculate the bandwidth for all active meters. newLen := len(sw.meters) - - for i, m := range sw.meters { + for i, m := range sw.meters[:sw.activeMeters] { total := atomic.LoadUint64(&m.accumulator) diff := total - m.snapshot.Total instant := timeMultiplier * float64(diff) @@ -142,16 +145,31 @@ func (sw *sweeper) update() { } // Reset the rate, keep the total. + m.registered = false m.snapshot.Rate = 0 newLen-- sw.meters[i] = sw.meters[newLen] } + // Re-add the total to all the newly active accumulators and set the snapshot to the total. + // 1. We don't do this on register to avoid having to take the snapshot lock. + // 2. We skip calculating the bandwidth for this round so we get an _accurate_ bandwidth calculation. + for _, m := range sw.meters[sw.activeMeters:] { + total := atomic.AddUint64(&m.accumulator, m.snapshot.Total) + if total > m.snapshot.Total { + m.snapshot.LastUpdate = now + } + m.snapshot.Total = total + } + // trim the meter list for i := newLen; i < len(sw.meters); i++ { sw.meters[i] = nil } sw.meters = sw.meters[:newLen] + + // Finally, mark all meters still in the list as "active". + sw.activeMeters = len(sw.meters) } func (sw *sweeper) Register(m *Meter) {