Skip to content

Commit

Permalink
Merge pull request #6248 from huoyaoyuan/net7.0/generic-math
Browse files Browse the repository at this point in the history
Use generic math for bindable numbers
  • Loading branch information
smoogipoo authored Apr 22, 2024
2 parents 2e85145 + 1a30d36 commit 69a26ce
Show file tree
Hide file tree
Showing 8 changed files with 70 additions and 222 deletions.
51 changes: 26 additions & 25 deletions osu.Framework.Tests/Visual/Bindables/TestSceneBindableNumbers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using System;
using System.Globalization;
using System.Numerics;
using osu.Framework.Bindables;
using osu.Framework.Graphics;
using osu.Framework.Graphics.Containers;
Expand Down Expand Up @@ -167,45 +168,45 @@ private void testFractionalPrecision()
private bool checkExact(decimal value) => checkExact(value, value);

private bool checkExact(decimal floatValue, decimal intValue)
=> bindableInt.Value == Convert.ToInt32(intValue)
&& bindableLong.Value == Convert.ToInt64(intValue)
&& bindableFloat.Value == Convert.ToSingle(floatValue)
&& bindableDouble.Value == Convert.ToDouble(floatValue);
=> bindableInt.Value == (int)intValue
&& bindableLong.Value == (long)intValue
&& bindableFloat.Value == (float)floatValue
&& bindableDouble.Value == (double)floatValue;

private void setMin<T>(T value)
private void setMin<T>(T value) where T : INumber<T>
{
bindableInt.MinValue = Convert.ToInt32(value);
bindableLong.MinValue = Convert.ToInt64(value);
bindableFloat.MinValue = Convert.ToSingle(value);
bindableDouble.MinValue = Convert.ToDouble(value);
bindableInt.MinValue = int.CreateTruncating(value);
bindableLong.MinValue = long.CreateTruncating(value);
bindableFloat.MinValue = float.CreateTruncating(value);
bindableDouble.MinValue = double.CreateTruncating(value);
}

private void setMax<T>(T value)
private void setMax<T>(T value) where T : INumber<T>
{
bindableInt.MaxValue = Convert.ToInt32(value);
bindableLong.MaxValue = Convert.ToInt64(value);
bindableFloat.MaxValue = Convert.ToSingle(value);
bindableDouble.MaxValue = Convert.ToDouble(value);
bindableInt.MaxValue = int.CreateTruncating(value);
bindableLong.MaxValue = long.CreateTruncating(value);
bindableFloat.MaxValue = float.CreateTruncating(value);
bindableDouble.MaxValue = double.CreateTruncating(value);
}

private void setValue<T>(T value)
private void setValue<T>(T value) where T : INumber<T>
{
bindableInt.Value = Convert.ToInt32(value);
bindableLong.Value = Convert.ToInt64(value);
bindableFloat.Value = Convert.ToSingle(value);
bindableDouble.Value = Convert.ToDouble(value);
bindableInt.Value = int.CreateTruncating(value);
bindableLong.Value = long.CreateTruncating(value);
bindableFloat.Value = float.CreateTruncating(value);
bindableDouble.Value = double.CreateTruncating(value);
}

private void setPrecision<T>(T precision)
private void setPrecision<T>(T precision) where T : INumber<T>
{
bindableInt.Precision = Convert.ToInt32(precision);
bindableLong.Precision = Convert.ToInt64(precision);
bindableFloat.Precision = Convert.ToSingle(precision);
bindableDouble.Precision = Convert.ToDouble(precision);
bindableInt.Precision = int.CreateTruncating(precision);
bindableLong.Precision = long.CreateTruncating(precision);
bindableFloat.Precision = float.CreateTruncating(precision);
bindableDouble.Precision = double.CreateTruncating(precision);
}

private partial class BindableDisplayContainer<T> : CompositeDrawable
where T : struct, IComparable<T>, IConvertible, IEquatable<T>
where T : struct, INumber<T>, IMinMaxValue<T>, IConvertible
{
public BindableDisplayContainer(BindableNumber<T> bindable)
{
Expand Down
204 changes: 23 additions & 181 deletions osu.Framework/Bindables/BindableNumber.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@
#nullable disable

using System;
using System.Diagnostics;
using System.Globalization;
using System.Numerics;
using JetBrains.Annotations;
using osu.Framework.Extensions.TypeExtensions;
using osu.Framework.Utils;

namespace osu.Framework.Bindables
{
public class BindableNumber<T> : RangeConstrainedBindable<T>, IBindableNumber<T>
where T : struct, IComparable<T>, IConvertible, IEquatable<T>
where T : struct, INumber<T>, IMinMaxValue<T>
{
[CanBeNull]
public event Action<T> PrecisionChanged;
Expand All @@ -40,10 +38,10 @@ public T Precision
get => precision;
set
{
if (precision.Equals(value))
if (precision == value)
return;

if (value.CompareTo(default) <= 0)
if (value <= T.Zero)
throw new ArgumentOutOfRangeException(nameof(Precision), value, "Must be greater than 0.");

SetPrecision(value, true, this);
Expand Down Expand Up @@ -76,102 +74,22 @@ public override T Value

private void setValue(T value)
{
if (Precision.CompareTo(DefaultPrecision) > 0)
if (Precision > DefaultPrecision)
{
// this rounding is purposefully performed on `decimal` to ensure that the resulting value is the closest possible floating-point
// number to actual real-world base-10 decimals, as that is the most common usage of precision.
decimal accurateResult = ClampValue(value, MinValue, MaxValue).ToDecimal(NumberFormatInfo.InvariantInfo);
accurateResult = Math.Round(accurateResult / Precision.ToDecimal(NumberFormatInfo.InvariantInfo)) * Precision.ToDecimal(NumberFormatInfo.InvariantInfo);
decimal accurateResult = decimal.CreateTruncating(T.Clamp(value, MinValue, MaxValue));
accurateResult = Math.Round(accurateResult / decimal.CreateTruncating(Precision)) * decimal.CreateTruncating(Precision);

base.Value = convertFromDecimal(accurateResult);
base.Value = T.CreateTruncating(accurateResult);
}
else
base.Value = value;
}

private T convertFromDecimal(decimal value)
{
if (typeof(T) == typeof(sbyte))
return (T)(object)Convert.ToSByte(value);
if (typeof(T) == typeof(byte))
return (T)(object)Convert.ToByte(value);
if (typeof(T) == typeof(short))
return (T)(object)Convert.ToInt16(value);
if (typeof(T) == typeof(ushort))
return (T)(object)Convert.ToUInt16(value);
if (typeof(T) == typeof(int))
return (T)(object)Convert.ToInt32(value);
if (typeof(T) == typeof(uint))
return (T)(object)Convert.ToUInt32(value);
if (typeof(T) == typeof(long))
return (T)(object)Convert.ToInt64(value);
if (typeof(T) == typeof(ulong))
return (T)(object)Convert.ToUInt64(value);
if (typeof(T) == typeof(float))
return (T)(object)Convert.ToSingle(value);
if (typeof(T) == typeof(double))
return (T)(object)Convert.ToDouble(value);

throw new InvalidCastException($"Cannot convert from decimal to {typeof(T).ReadableName()}");
}
protected override T DefaultMinValue => T.MinValue;

protected override T DefaultMinValue
{
get
{
Debug.Assert(Validation.IsSupportedBindableNumberType<T>());

if (typeof(T) == typeof(sbyte))
return (T)(object)sbyte.MinValue;
if (typeof(T) == typeof(byte))
return (T)(object)byte.MinValue;
if (typeof(T) == typeof(short))
return (T)(object)short.MinValue;
if (typeof(T) == typeof(ushort))
return (T)(object)ushort.MinValue;
if (typeof(T) == typeof(int))
return (T)(object)int.MinValue;
if (typeof(T) == typeof(uint))
return (T)(object)uint.MinValue;
if (typeof(T) == typeof(long))
return (T)(object)long.MinValue;
if (typeof(T) == typeof(ulong))
return (T)(object)ulong.MinValue;
if (typeof(T) == typeof(float))
return (T)(object)float.MinValue;

return (T)(object)double.MinValue;
}
}

protected override T DefaultMaxValue
{
get
{
Debug.Assert(Validation.IsSupportedBindableNumberType<T>());

if (typeof(T) == typeof(sbyte))
return (T)(object)sbyte.MaxValue;
if (typeof(T) == typeof(byte))
return (T)(object)byte.MaxValue;
if (typeof(T) == typeof(short))
return (T)(object)short.MaxValue;
if (typeof(T) == typeof(ushort))
return (T)(object)ushort.MaxValue;
if (typeof(T) == typeof(int))
return (T)(object)int.MaxValue;
if (typeof(T) == typeof(uint))
return (T)(object)uint.MaxValue;
if (typeof(T) == typeof(long))
return (T)(object)long.MaxValue;
if (typeof(T) == typeof(ulong))
return (T)(object)ulong.MaxValue;
if (typeof(T) == typeof(float))
return (T)(object)float.MaxValue;

return (T)(object)double.MaxValue;
}
}
protected override T DefaultMaxValue => T.MaxValue;

/// <summary>
/// The default <see cref="Precision"/>.
Expand All @@ -180,26 +98,12 @@ protected virtual T DefaultPrecision
{
get
{
if (typeof(T) == typeof(sbyte))
return (T)(object)(sbyte)1;
if (typeof(T) == typeof(byte))
return (T)(object)(byte)1;
if (typeof(T) == typeof(short))
return (T)(object)(short)1;
if (typeof(T) == typeof(ushort))
return (T)(object)(ushort)1;
if (typeof(T) == typeof(int))
return (T)(object)1;
if (typeof(T) == typeof(uint))
return (T)(object)1U;
if (typeof(T) == typeof(long))
return (T)(object)1L;
if (typeof(T) == typeof(ulong))
return (T)(object)1UL;
if (typeof(T) == typeof(float))
return (T)(object)float.Epsilon;
if (typeof(T) == typeof(double))
return (T)(object)double.Epsilon;

return (T)(object)double.Epsilon;
return T.One;
}
}

Expand Down Expand Up @@ -249,63 +153,11 @@ public override void UnbindEvents()
typeof(T) != typeof(float) &&
typeof(T) != typeof(double); // Will be **constant** after JIT.

public void Set<TNewValue>(TNewValue val) where TNewValue : struct,
IFormattable, IConvertible, IComparable<TNewValue>, IEquatable<TNewValue>
{
Debug.Assert(Validation.IsSupportedBindableNumberType<T>());

// Comparison between typeof(T) and type literals are treated as **constant** on value types.
// Code paths for other types will be eliminated.
if (typeof(T) == typeof(byte))
((BindableNumber<byte>)(object)this).Value = val.ToByte(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(sbyte))
((BindableNumber<sbyte>)(object)this).Value = val.ToSByte(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(ushort))
((BindableNumber<ushort>)(object)this).Value = val.ToUInt16(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(short))
((BindableNumber<short>)(object)this).Value = val.ToInt16(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(uint))
((BindableNumber<uint>)(object)this).Value = val.ToUInt32(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(int))
((BindableNumber<int>)(object)this).Value = val.ToInt32(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(ulong))
((BindableNumber<ulong>)(object)this).Value = val.ToUInt64(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(long))
((BindableNumber<long>)(object)this).Value = val.ToInt64(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(float))
((BindableNumber<float>)(object)this).Value = val.ToSingle(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(double))
((BindableNumber<double>)(object)this).Value = val.ToDouble(NumberFormatInfo.InvariantInfo);
}
public void Set<TNewValue>(TNewValue val) where TNewValue : struct, INumber<TNewValue>
=> Value = T.CreateTruncating(val);

public void Add<TNewValue>(TNewValue val) where TNewValue : struct,
IFormattable, IConvertible, IComparable<TNewValue>, IEquatable<TNewValue>
{
Debug.Assert(Validation.IsSupportedBindableNumberType<T>());

// Comparison between typeof(T) and type literals are treated as **constant** on value types.
// Code pathes for other types will be eliminated.
if (typeof(T) == typeof(byte))
((BindableNumber<byte>)(object)this).Value += val.ToByte(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(sbyte))
((BindableNumber<sbyte>)(object)this).Value += val.ToSByte(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(ushort))
((BindableNumber<ushort>)(object)this).Value += val.ToUInt16(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(short))
((BindableNumber<short>)(object)this).Value += val.ToInt16(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(uint))
((BindableNumber<uint>)(object)this).Value += val.ToUInt32(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(int))
((BindableNumber<int>)(object)this).Value += val.ToInt32(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(ulong))
((BindableNumber<ulong>)(object)this).Value += val.ToUInt64(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(long))
((BindableNumber<long>)(object)this).Value += val.ToInt64(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(float))
((BindableNumber<float>)(object)this).Value += val.ToSingle(NumberFormatInfo.InvariantInfo);
else if (typeof(T) == typeof(double))
((BindableNumber<double>)(object)this).Value += val.ToDouble(NumberFormatInfo.InvariantInfo);
}
public void Add<TNewValue>(TNewValue val) where TNewValue : struct, INumber<TNewValue>
=> Value += T.CreateTruncating(val);

/// <summary>
/// Sets the value of the bindable to Min + (Max - Min) * amt
Expand All @@ -314,8 +166,10 @@ public void Add<TNewValue>(TNewValue val) where TNewValue : struct,
/// </summary>
public void SetProportional(float amt, float snap = 0)
{
double min = MinValue.ToDouble(NumberFormatInfo.InvariantInfo);
double max = MaxValue.ToDouble(NumberFormatInfo.InvariantInfo);
// TODO: Use IFloatingPointIeee754<T>.Lerp when applicable

double min = double.CreateTruncating(MinValue);
double max = double.CreateTruncating(MaxValue);
double value = min + (max - min) * amt;
if (snap > 0)
value = Math.Round(value / snap) * snap;
Expand Down Expand Up @@ -350,20 +204,8 @@ public override bool IsDefault

protected override Bindable<T> CreateInstance() => new BindableNumber<T>();

protected sealed override T ClampValue(T value, T minValue, T maxValue) => max(minValue, min(maxValue, value));

protected sealed override bool IsValidRange(T min, T max) => min.CompareTo(max) <= 0;
protected sealed override T ClampValue(T value, T minValue, T maxValue) => T.Clamp(value, minValue, maxValue);

private static T max(T value1, T value2)
{
int comparison = value1.CompareTo(value2);
return comparison > 0 ? value1 : value2;
}

private static T min(T value1, T value2)
{
int comparison = value1.CompareTo(value2);
return comparison > 0 ? value2 : value1;
}
protected sealed override bool IsValidRange(T min, T max) => min <= max;
}
}
3 changes: 2 additions & 1 deletion osu.Framework/Bindables/BindableNumberWithCurrent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#nullable disable

using System;
using System.Numerics;

namespace osu.Framework.Bindables
{
Expand All @@ -12,7 +13,7 @@ namespace osu.Framework.Bindables
/// </summary>
/// <typeparam name="T">The type of our stored <see cref="Bindable{T}.Value"/>.</typeparam>
public class BindableNumberWithCurrent<T> : BindableNumber<T>, IBindableWithCurrent<T>
where T : struct, IComparable<T>, IConvertible, IEquatable<T>
where T : struct, INumber<T>, IMinMaxValue<T>
{
private BindableNumber<T> currentBound;

Expand Down
6 changes: 4 additions & 2 deletions osu.Framework/Bindables/RangeConstrainedBindable.cs
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,10 @@ public override void CopyTo(Bindable<T> them)
// as Value assignment (in the base call below) automatically clamps to [MinValue, MaxValue].
if (them is RangeConstrainedBindable<T> other)
{
other.MinValue = MinValue;
other.MaxValue = MaxValue;
// copy the bounds over without updating the current value, to avoid clamping on invalid ranges.
// there is no need to clamp `Value` after that directly - the `base.CopyTo()` call will change `Value` anyway.
other.SetMinValue(MinValue, false, this);
other.SetMaxValue(MaxValue, false, this);
}

base.CopyTo(them);
Expand Down
Loading

0 comments on commit 69a26ce

Please sign in to comment.