Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance of UR calculations #30868

Merged
merged 8 commits into from
Nov 26, 2024
38 changes: 32 additions & 6 deletions osu.Game.Benchmarks/BenchmarkUnstableRate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,54 @@
using System.Collections.Generic;
using BenchmarkDotNet.Attributes;
using osu.Framework.Utils;
using osu.Game.Rulesets.Objects;
using osu.Game.Beatmaps;
using osu.Game.Beatmaps.ControlPoints;
using osu.Game.Rulesets.Osu.Objects;
using osu.Game.Rulesets.Scoring;

namespace osu.Game.Benchmarks
{
public class BenchmarkUnstableRate : BenchmarkTest
{
private List<HitEvent> events = null!;
private readonly List<List<HitEvent>> incrementalEventLists = new List<List<HitEvent>>();

public override void SetUp()
{
base.SetUp();
events = new List<HitEvent>();

for (int i = 0; i < 1000; i++)
events.Add(new HitEvent(RNG.NextDouble(-200.0, 200.0), RNG.NextDouble(1.0, 2.0), HitResult.Great, new HitObject(), null, null));
var events = new List<HitEvent>();

for (int i = 0; i < 2048; i++)
{
// Ensure the object has hit windows populated.
var hitObject = new HitCircle();
hitObject.ApplyDefaults(new ControlPointInfo(), new BeatmapDifficulty());
events.Add(new HitEvent(RNG.NextDouble(-200.0, 200.0), RNG.NextDouble(1.0, 2.0), HitResult.Great, hitObject, null, null));

incrementalEventLists.Add(new List<HitEvent>(events));
}
}

[Benchmark]
public void CalculateUnstableRate()
{
_ = events.CalculateUnstableRate();
for (int i = 0; i < 2048; i++)
{
var events = incrementalEventLists[i];
_ = events.CalculateUnstableRate();
}
}

[Benchmark]
public void CalculateUnstableRateUsingIncrementalCalculation()
{
HitEventExtensions.UnstableRateCalculationResult? last = null;

for (int i = 0; i < 2048; i++)
{
var events = incrementalEventLists[i];
last = events.CalculateUnstableRate(last);
}
}
}
}
45 changes: 43 additions & 2 deletions osu.Game.Tests/NonVisual/Ranking/UnstableRateTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,53 @@ public class UnstableRateTest
public void TestDistributedHits()
{
var events = Enumerable.Range(-5, 11)
.Select(t => new HitEvent(t - 5, 1.0, HitResult.Great, new HitObject(), null, null));
.Select(t => new HitEvent(t - 5, 1.0, HitResult.Great, new HitObject(), null, null))
.ToList();

var unstableRate = new UnstableRate(events);

Assert.IsNotNull(unstableRate.Value);
Assert.IsTrue(Precision.AlmostEquals(unstableRate.Value.Value, 10 * Math.Sqrt(10)));
Assert.AreEqual(unstableRate.Value.Value, 10 * Math.Sqrt(10), Precision.DOUBLE_EPSILON);
}

[Test]
public void TestDistributedHitsIncrementalRewind()
{
var events = Enumerable.Range(-5, 11)
.Select(t => new HitEvent(t - 5, 1.0, HitResult.Great, new HitObject(), null, null))
.ToList();

HitEventExtensions.UnstableRateCalculationResult result = null;

for (int i = 0; i < events.Count; i++)
{
result = events.GetRange(0, i + 1)
.CalculateUnstableRate(result);
}

result = events.GetRange(0, 2).CalculateUnstableRate(result);

Assert.IsNotNull(result!.Result);
Assert.AreEqual(5, result.Result, Precision.DOUBLE_EPSILON);
}

[Test]
public void TestDistributedHitsIncremental()
{
var events = Enumerable.Range(-5, 11)
.Select(t => new HitEvent(t - 5, 1.0, HitResult.Great, new HitObject(), null, null))
.ToList();

HitEventExtensions.UnstableRateCalculationResult result = null;

for (int i = 0; i < events.Count; i++)
{
result = events.GetRange(0, i + 1)
.CalculateUnstableRate(result);
}

Assert.IsNotNull(result!.Result);
Assert.AreEqual(10 * Math.Sqrt(10), result.Result, Precision.DOUBLE_EPSILON);
}

[Test]
Expand Down
2 changes: 1 addition & 1 deletion osu.Game/Rulesets/Mods/ModAdaptiveSpeed.cs
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ private IEnumerable<HitObject> getAllApplicableHitObjects(IEnumerable<HitObject>
{
foreach (var hitObject in hitObjects)
{
if (!(hitObject.HitWindows is HitWindows.EmptyHitWindows))
if (hitObject.HitWindows != HitWindows.Empty)
yield return hitObject;

foreach (HitObject nested in getAllApplicableHitObjects(hitObject.NestedHitObjects))
Expand Down
62 changes: 50 additions & 12 deletions osu.Game/Rulesets/Scoring/HitEventExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using osu.Game.Rulesets.Objects;

namespace osu.Game.Rulesets.Scoring
{
Expand All @@ -20,32 +21,36 @@ public static class HitEventExtensions
/// A non-null <see langword="double"/> value if unstable rate could be calculated,
/// and <see langword="null"/> if unstable rate cannot be calculated due to <paramref name="hitEvents"/> being empty.
/// </returns>
public static double? CalculateUnstableRate(this IEnumerable<HitEvent> hitEvents)
public static UnstableRateCalculationResult? CalculateUnstableRate(this IReadOnlyList<HitEvent> hitEvents, UnstableRateCalculationResult? result = null)
{
Debug.Assert(hitEvents.All(ev => ev.GameplayRate != null));

int count = 0;
double mean = 0;
double sumOfSquares = 0;
result ??= new UnstableRateCalculationResult();

foreach (var e in hitEvents)
// Handle rewinding in the simplest way possible.
if (hitEvents.Count < result.EventCount + 1)
result = new UnstableRateCalculationResult();

for (int i = result.EventCount; i < hitEvents.Count; i++)
{
HitEvent e = hitEvents[i];

if (!AffectsUnstableRate(e))
continue;

count++;
result.EventCount++;

// Division by gameplay rate is to account for TimeOffset scaling with gameplay rate.
double currentValue = e.TimeOffset / e.GameplayRate!.Value;
double nextMean = mean + (currentValue - mean) / count;
sumOfSquares += (currentValue - mean) * (currentValue - nextMean);
mean = nextMean;
double nextMean = result.Mean + (currentValue - result.Mean) / result.EventCount;
result.SumOfSquares += (currentValue - result.Mean) * (currentValue - nextMean);
result.Mean = nextMean;
}

if (count == 0)
if (result.EventCount == 0)
return null;

return 10.0 * Math.Sqrt(sumOfSquares / count);
return result;
}

/// <summary>
Expand All @@ -65,6 +70,39 @@ public static class HitEventExtensions
return timeOffsets.Average();
}

public static bool AffectsUnstableRate(HitEvent e) => !(e.HitObject.HitWindows is HitWindows.EmptyHitWindows) && e.Result.IsHit();
public static bool AffectsUnstableRate(HitEvent e) => AffectsUnstableRate(e.HitObject, e.Result);
public static bool AffectsUnstableRate(HitObject hitObject, HitResult result) => hitObject.HitWindows != HitWindows.Empty && result.IsHit();

/// <summary>
/// Data type returned by <see cref="HitEventExtensions.CalculateUnstableRate"/> which allows efficient incremental processing.
/// </summary>
/// <remarks>
/// This should be passed back into future <see cref="HitEventExtensions.CalculateUnstableRate"/> calls as a parameter.
///
/// The optimisations used here rely on hit events being a consecutive sequence from a single gameplay session.
/// When a new gameplay session is started, any existing results should be disposed.
/// </remarks>
public class UnstableRateCalculationResult
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm sure people are going to want changes to this, but I'm not sure how strict review will be so I've just left it without any safeties initially.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rename NextProcessableIndex to EventCount probably (with the necessary off-by-one adjustments) but otherwise I'm not sure I have much of a problem with this...?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually doesn't require off-by-one adjustments since it already kinda is the event count (at least post-loop-execution).

Copy link
Contributor

@smoogipoo smoogipoo Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My only query is why this isn't a struct with readonly fields. Perhaps a readonly record struct even.

Minor nitpick, though.

Copy link
Member Author

@peppy peppy Nov 26, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can try a struct, but as a record the alloc overhead (cpu, not memory) outweighted the benefits, seemingly (see ea68d4b).

But maybe this was just local to the benchmark and not relevant in real-world scenarios.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If it has a significant overhead, I'd say leave it then.

{
/// <summary>
/// Total events processed. For internal incremental calculation use.
/// </summary>
public int EventCount;

/// <summary>
/// Last sum-of-squares value. For internal incremental calculation use.
/// </summary>
public double SumOfSquares;

/// <summary>
/// Last mean value. For internal incremental calculation use.
/// </summary>
public double Mean;

/// <summary>
/// The unstable rate.
/// </summary>
public double Result => EventCount == 0 ? 0 : 10.0 * Math.Sqrt(SumOfSquares / EventCount);
}
}
}
4 changes: 2 additions & 2 deletions osu.Game/Rulesets/Scoring/HitWindows.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public class HitWindows
/// An empty <see cref="HitWindows"/> with only <see cref="HitResult.Miss"/> and <see cref="HitResult.Perfect"/>.
/// No time values are provided (meaning instantaneous hit or miss).
/// </summary>
public static HitWindows Empty => new EmptyHitWindows();
public static HitWindows Empty { get; } = new EmptyHitWindows();

public HitWindows()
{
Expand Down Expand Up @@ -182,7 +182,7 @@ public double WindowFor(HitResult result)
/// </summary>
protected virtual DifficultyRange[] GetRanges() => base_ranges;

public class EmptyHitWindows : HitWindows
private class EmptyHitWindows : HitWindows
{
private static readonly DifficultyRange[] ranges =
{
Expand Down
16 changes: 11 additions & 5 deletions osu.Game/Screens/Play/HUD/UnstableRateCounter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public partial class UnstableRateCounter : RollingCounter<int>, ISerialisableDra
private const float alpha_when_invalid = 0.3f;
private readonly Bindable<bool> valid = new Bindable<bool>();

private HitEventExtensions.UnstableRateCalculationResult? unstableRateResult;

[Resolved]
private ScoreProcessor scoreProcessor { get; set; } = null!;

Expand All @@ -44,9 +46,6 @@ private void load(OsuColour colours)
DrawableCount.FadeTo(e.NewValue ? 1 : alpha_when_invalid, 1000, Easing.OutQuint));
}

private bool changesUnstableRate(JudgementResult judgement)
=> !(judgement.HitObject.HitWindows is HitWindows.EmptyHitWindows) && judgement.IsHit;

protected override void LoadComplete()
{
base.LoadComplete();
Expand All @@ -56,13 +55,20 @@ protected override void LoadComplete()
updateDisplay();
}

private void updateDisplay(JudgementResult _) => Scheduler.AddOnce(updateDisplay);
private void updateDisplay(JudgementResult result)
{
if (HitEventExtensions.AffectsUnstableRate(result.HitObject, result.Type))
Scheduler.AddOnce(updateDisplay);
}

private void updateDisplay()
{
double? unstableRate = scoreProcessor.HitEvents.CalculateUnstableRate();
unstableRateResult = scoreProcessor.HitEvents.CalculateUnstableRate(unstableRateResult);

double? unstableRate = unstableRateResult?.Result;

valid.Value = unstableRate != null;

if (unstableRate != null)
Current.Value = (int)Math.Round(unstableRate.Value);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ public partial class HitEventTimingDistributionGraph : CompositeDrawable
/// <param name="hitEvents">The <see cref="HitEvent"/>s to display the timing distribution of.</param>
public HitEventTimingDistributionGraph(IReadOnlyList<HitEvent> hitEvents)
{
this.hitEvents = hitEvents.Where(e => !(e.HitObject.HitWindows is HitWindows.EmptyHitWindows) && e.Result.IsBasic() && e.Result.IsHit()).ToList();
this.hitEvents = hitEvents.Where(e => e.HitObject.HitWindows != HitWindows.Empty && e.Result.IsBasic() && e.Result.IsHit()).ToList();
bins = Enumerable.Range(0, total_timing_distribution_bins).Select(_ => new Dictionary<HitResult, int>()).ToArray<IDictionary<HitResult, int>>();
}

Expand Down
4 changes: 2 additions & 2 deletions osu.Game/Screens/Ranking/Statistics/UnstableRate.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ public partial class UnstableRate : SimpleStatisticItem<double?>
/// Creates and computes an <see cref="UnstableRate"/> statistic.
/// </summary>
/// <param name="hitEvents">Sequence of <see cref="HitEvent"/>s to calculate the unstable rate based on.</param>
public UnstableRate(IEnumerable<HitEvent> hitEvents)
public UnstableRate(IReadOnlyList<HitEvent> hitEvents)
: base("Unstable Rate")
{
Value = hitEvents.CalculateUnstableRate();
Value = hitEvents.CalculateUnstableRate()?.Result;
}

protected override string DisplayValue(double? value) => value == null ? "(not available)" : value.Value.ToString(@"N2");
Expand Down
Loading