Skip to content

Commit

Permalink
Merge pull request #84 from nextmv-io/feature/eng-5676-fix-bug-in-nex…
Browse files Browse the repository at this point in the history
…troute-plateau-detection

Fixes bug in plateau detection
  • Loading branch information
merschformann authored Jan 21, 2025
2 parents d22e079 + 38ceffa commit f5358eb
Show file tree
Hide file tree
Showing 2 changed files with 156 additions and 6 deletions.
18 changes: 12 additions & 6 deletions solve_terminate.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,18 @@ func (t *plateauTracker) ShouldTerminate(iterations int, elapsed time.Duration)
if t.durationIndex >= len(t.progression) {
return true
}
// Compare the current value to the value at the duration index.
// Compare the current value to the first value within the cutoff. It
// needs to be (significantly) better to avoid termination.
// Note: The cutoff value will always be larger than (or equal to) the
// current value, as we are minimizing.
cutoffValue := t.progression[t.durationIndex].Value
if t.options.AbsoluteThreshold >= 0 &&
currentValue-cutoffValue < t.options.AbsoluteThreshold {
cutoffValue-currentValue < t.options.AbsoluteThreshold {
return true
}
if t.options.RelativeThreshold >= 0 &&
currentValue > 0 && // Relative threshold is only supported for positive values.
(currentValue-cutoffValue)/currentValue < t.options.RelativeThreshold {
(cutoffValue-currentValue)/currentValue < t.options.RelativeThreshold {
return true
}
}
Expand All @@ -102,15 +105,18 @@ func (t *plateauTracker) ShouldTerminate(iterations int, elapsed time.Duration)
if t.iterationsIndex >= len(t.progression) {
return true
}
// Compare the current value to the value at the iterations index.
// Compare the current value to the first value within the cutoff. It
// needs to be (significantly) better to avoid termination.
// Note: The cutoff value will always be larger than (or equal to) the
// current value, as we are minimizing.
cutoffValue := t.progression[t.iterationsIndex].Value
if t.options.AbsoluteThreshold >= 0 &&
currentValue-cutoffValue < t.options.AbsoluteThreshold {
cutoffValue-currentValue < t.options.AbsoluteThreshold {
return true
}
if t.options.RelativeThreshold >= 0 &&
currentValue > 0 && // Relative threshold is only supported for positive values.
(currentValue-cutoffValue)/currentValue < t.options.RelativeThreshold {
(cutoffValue-currentValue)/currentValue < t.options.RelativeThreshold {
return true
}
}
Expand Down
144 changes: 144 additions & 0 deletions solve_terminate_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
// © 2019-present nextmv.io inc

package nextroute

import (
"testing"
"time"
)

func TestShouldTerminate(t *testing.T) {
// Define the test cases
basicProgression := []ProgressionEntry{
{ElapsedSeconds: 2, Iterations: 100, Value: 150},
{ElapsedSeconds: 5, Iterations: 200, Value: 100},
{ElapsedSeconds: 8, Iterations: 300, Value: 70},
{ElapsedSeconds: 10, Iterations: 400, Value: 50},
}
cases := []struct {
opts PlateauOptions
progression []ProgressionEntry
iterations int
elapsed float64
shouldTerminate bool
}{
{ // no plateau tracking
opts: PlateauOptions{},
progression: basicProgression,
iterations: 500,
elapsed: 15,
shouldTerminate: false,
},
{ // 5 seconds plateau (stagnation case, relative threshold)
opts: PlateauOptions{
Duration: time.Duration(5) * time.Second,
RelativeThreshold: 0.1,
},
progression: basicProgression,
iterations: 500,
elapsed: 15,
shouldTerminate: true,
},
{ // 5 seconds plateau (stagnation case, absolute threshold)
opts: PlateauOptions{
Duration: time.Duration(5) * time.Second,
AbsoluteThreshold: 10,
},
progression: basicProgression,
iterations: 500,
elapsed: 15,
shouldTerminate: true,
},
{ // 5 iterations plateau (stagnation case, relative threshold)
opts: PlateauOptions{
Iterations: 5,
RelativeThreshold: 0.1,
},
progression: basicProgression,
iterations: 500,
elapsed: 15,
shouldTerminate: true,
},
{ // 5 iterations plateau (stagnation case, absolute threshold)
opts: PlateauOptions{
Iterations: 5,
AbsoluteThreshold: 10,
},
progression: basicProgression,
iterations: 500,
elapsed: 15,
shouldTerminate: true,
},
{ // 5 seconds plateau (no-stagnation case, relative threshold)
opts: PlateauOptions{
Duration: time.Duration(5) * time.Second,
RelativeThreshold: 0.1,
},
progression: basicProgression,
iterations: 500,
elapsed: 12, // catches two last entries
shouldTerminate: false,
},
{ // 5 seconds plateau (no-stagnation case, absolute threshold)
opts: PlateauOptions{
Duration: time.Duration(5) * time.Second,
AbsoluteThreshold: 10,
},
progression: basicProgression,
iterations: 500,
elapsed: 12, // catches two last entries
shouldTerminate: false,
},
{ // 5 iterations plateau (no-stagnation case, relative threshold)
opts: PlateauOptions{
Iterations: 200,
RelativeThreshold: 0.1,
},
progression: basicProgression,
iterations: 450,
elapsed: 15,
shouldTerminate: false,
},
{ // 5 iterations plateau (no-stagnation case, absolute threshold)
opts: PlateauOptions{
Iterations: 200,
AbsoluteThreshold: 10,
},
progression: basicProgression,
iterations: 450,
elapsed: 15,
shouldTerminate: false,
},
{ // 5 seconds plateau (non-significant improvement case, relative threshold)
opts: PlateauOptions{
Duration: time.Duration(5) * time.Second,
RelativeThreshold: 0.3,
},
progression: append(basicProgression, ProgressionEntry{ElapsedSeconds: 12, Iterations: 500, Value: 45}),
iterations: 500,
elapsed: 14, // catches two last entries
shouldTerminate: true,
},
{ // 5 seconds plateau (non-significant improvement case, absolute threshold)
opts: PlateauOptions{
Duration: time.Duration(5) * time.Second,
AbsoluteThreshold: 10,
},
progression: append(basicProgression, ProgressionEntry{ElapsedSeconds: 12, Iterations: 500, Value: 45}),
iterations: 500,
elapsed: 14, // catches two last entries
shouldTerminate: true,
},
}

// Run the tests
for _, c := range cases {
tracker := newPlateauTracker(c.opts)
for _, p := range c.progression {
tracker.onImprovement(p.ElapsedSeconds, p.Iterations, p.Value)
}
if got := tracker.ShouldTerminate(c.iterations, time.Duration(c.elapsed)*time.Second); got != c.shouldTerminate {
t.Errorf("ShouldTerminate() = %v, want %v", got, c.shouldTerminate)
}
}
}

0 comments on commit f5358eb

Please sign in to comment.