diff --git a/au/magnitude.hh b/au/magnitude.hh index e7cd880e..c0a585dc 100644 --- a/au/magnitude.hh +++ b/au/magnitude.hh @@ -270,7 +270,7 @@ namespace detail { enum class MagRepresentationOutcome { OK, ERR_NON_INTEGER_IN_INTEGER_TYPE, - ERR_RATIONAL_POWERS, + ERR_INVALID_ROOT, ERR_CANNOT_FIT, }; @@ -316,10 +316,101 @@ constexpr MagRepresentationOrError checked_int_pow(T base, std::uintmax_t exp return result; } -template +template +constexpr MagRepresentationOrError root(T x, std::uintmax_t n) { + // The "zeroth root" would be mathematically undefined. + if (n == 0) { + return {MagRepresentationOutcome::ERR_INVALID_ROOT}; + } + + // The "first root" is trivial. + if (n == 1) { + return {MagRepresentationOutcome::OK, x}; + } + + // We only support nontrivial roots of floating point types. + if (!std::is_floating_point::value) { + return {MagRepresentationOutcome::ERR_NON_INTEGER_IN_INTEGER_TYPE}; + } + + // Handle negative numbers: only odd roots are allowed. + if (x < 0) { + if (n % 2 == 0) { + return {MagRepresentationOutcome::ERR_INVALID_ROOT}; + } else { + const auto negative_result = root(-x, n); + if (negative_result.outcome != MagRepresentationOutcome::OK) { + return {negative_result.outcome}; + } + return {MagRepresentationOutcome::OK, static_cast(-negative_result.value)}; + } + } + + // Handle special cases of zero and one. + if (x == 0 || x == 1) { + return {MagRepresentationOutcome::OK, x}; + } + + // Handle numbers bewtween 0 and 1. + if (x < 1) { + const auto inverse_result = root(T{1} / x, n); + if (inverse_result.outcome != MagRepresentationOutcome::OK) { + return {inverse_result.outcome}; + } + return {MagRepresentationOutcome::OK, static_cast(T{1} / inverse_result.value)}; + } + + // + // At this point, error conditions are finished, and we can proceed with the "core" algorithm. + // + + // Always use `long double` for intermediate computations. We don't ever expect people to be + // calling this at runtime, so we want maximum accuracy. + long double lo = 1.0; + long double hi = x; + + // Do a binary search to find the closest value such that `checked_int_pow` recovers the input. + // + // Because we know `n > 1`, and `x > 1`, and x^n is monotonically increasing, we know that + // `checked_int_pow(lo, n) < x < checked_int_pow(hi, n)`. We will preserve this as an + // invariant. + while (lo < hi) { + long double mid = lo + (hi - lo) / 2; + + auto result = checked_int_pow(mid, n); + + if (result.outcome != MagRepresentationOutcome::OK) { + return {result.outcome}; + } + + // Early return if we get lucky with an exact answer. + if (result.value == x) { + return {MagRepresentationOutcome::OK, static_cast(mid)}; + } + + // Check for stagnation. + if (mid == lo || mid == hi) { + break; + } + + // Preserve the invariant that `checked_int_pow(lo, n) < x < checked_int_pow(hi, n)`. + if (result.value < x) { + lo = mid; + } else { + hi = mid; + } + } + + // Pick whichever one gets closer to the target. + const auto lo_diff = x - checked_int_pow(lo, n).value; + const auto hi_diff = checked_int_pow(hi, n).value - x; + return {MagRepresentationOutcome::OK, static_cast(lo_diff < hi_diff ? lo : hi)}; +} + +template constexpr MagRepresentationOrError> base_power_value(B base) { if (N < 0) { - const auto inverse_result = base_power_value(base); + const auto inverse_result = base_power_value(base); if (inverse_result.outcome != MagRepresentationOutcome::OK) { return inverse_result; } @@ -329,7 +420,12 @@ constexpr MagRepresentationOrError> base_power_value(B base) { }; } - return checked_int_pow(static_cast>(base), static_cast(N)); + const auto power_result = + checked_int_pow(static_cast>(base), static_cast(N)); + if (power_result.outcome != MagRepresentationOutcome::OK) { + return {power_result.outcome}; + } + return (D > 1) ? root(power_result.value, D) : power_result; } template @@ -393,15 +489,10 @@ constexpr MagRepresentationOrError get_value_result(Magnitude) { return {MagRepresentationOutcome::ERR_NON_INTEGER_IN_INTEGER_TYPE}; } - // Computing values for rational base powers is something we would _like_ to support, but we - // need a `constexpr` implementation of `powl()` first. - if (!all({(ExpT::den == 1)...})) { - return {MagRepresentationOutcome::ERR_RATIONAL_POWERS}; - } - // Force the expression to be evaluated in a constexpr context. constexpr auto widened_result = - product({base_power_value::num / ExpT::den)>(BaseT::value())...}); + product({base_power_value::num, static_cast(ExpT::den)>( + BaseT::value())...}); if ((widened_result.outcome != MagRepresentationOutcome::OK) || !safe_to_cast_to(widened_result.value)) { @@ -433,8 +524,8 @@ constexpr T get_value(Magnitude m) { static_assert(result.outcome != MagRepresentationOutcome::ERR_NON_INTEGER_IN_INTEGER_TYPE, "Cannot represent non-integer in integral destination type"); - static_assert(result.outcome != MagRepresentationOutcome::ERR_RATIONAL_POWERS, - "Computing values for rational powers not yet supported"); + static_assert(result.outcome != MagRepresentationOutcome::ERR_INVALID_ROOT, + "Could not compute root for rational power of base"); static_assert(result.outcome != MagRepresentationOutcome::ERR_CANNOT_FIT, "Value outside range of destination type"); diff --git a/au/magnitude_test.cc b/au/magnitude_test.cc index 5e72a101..1f9005fb 100644 --- a/au/magnitude_test.cc +++ b/au/magnitude_test.cc @@ -20,6 +20,8 @@ #include "gtest/gtest.h" using ::testing::DoubleEq; +using ::testing::Eq; +using ::testing::FloatEq; using ::testing::StaticAssertTypeEq; namespace au { @@ -208,6 +210,11 @@ TEST(GetValue, ImpossibleRequestsArePreventedAtCompileTime) { // get_value(sqrt_2); } +TEST(GetValue, HandlesRoots) { + constexpr auto sqrt_2 = get_value(root<2>(mag<2>())); + EXPECT_DOUBLE_EQ(sqrt_2 * sqrt_2, 2.0); +} + TEST(GetValue, WorksForEmptyPack) { constexpr auto one = Magnitude<>{}; EXPECT_THAT(get_value(one), SameTypeAndValue(1)); @@ -270,6 +277,15 @@ MATCHER(CannotFit, "") { return (arg.outcome == MagRepresentationOutcome::ERR_CANNOT_FIT) && (arg.value == 0); } +MATCHER(NonIntegerInIntegerType, "") { + return (arg.outcome == MagRepresentationOutcome::ERR_NON_INTEGER_IN_INTEGER_TYPE) && + (arg.value == 0); +} + +MATCHER(InvalidRoot, "") { + return (arg.outcome == MagRepresentationOutcome::ERR_INVALID_ROOT) && (arg.value == 0); +} + template auto FitsAndMatchesValue(ValueMatcher &&matcher) { return ::testing::AllOf( @@ -298,6 +314,106 @@ TEST(CheckedIntPow, FindsAppropriateLimits) { EXPECT_THAT(checked_int_pow(10.0, 309), CannotFit()); } +TEST(Root, ReturnsErrorForIntegralType) { + EXPECT_THAT(root(4, 2), NonIntegerInIntegerType()); + EXPECT_THAT(root(uint8_t{125}, 3), NonIntegerInIntegerType()); +} + +TEST(Root, ReturnsErrorForZerothRoot) { + EXPECT_THAT(root(4.0, 0), InvalidRoot()); + EXPECT_THAT(root(125.0, 0), InvalidRoot()); +} + +TEST(Root, NegativeRootsWorkForOddPowersOnly) { + EXPECT_THAT(root(-4.0, 2), InvalidRoot()); + EXPECT_THAT(root(-125.0, 3), FitsAndProducesValue(-5.0)); + EXPECT_THAT(root(-10000.0, 4), InvalidRoot()); +} + +TEST(Root, AnyRootOfOneIsOne) { + for (const std::uintmax_t r : {1u, 2u, 3u, 4u, 5u, 6u, 7u, 8u, 9u}) { + EXPECT_THAT(root(1.0, r), FitsAndProducesValue(1.0)); + } +} + +TEST(Root, AnyRootOfZeroIsZero) { + for (const std::uintmax_t r : {1u, 2u, 3u, 4u, 5u, 6u, 7u, 8u, 9u}) { + EXPECT_THAT(root(0.0, r), FitsAndProducesValue(0.0)); + } +} + +TEST(Root, OddRootOfNegativeOneIsItself) { + EXPECT_THAT(root(-1.0, 1), FitsAndProducesValue(-1.0)); + EXPECT_THAT(root(-1.0, 2), InvalidRoot()); + EXPECT_THAT(root(-1.0, 3), FitsAndProducesValue(-1.0)); + EXPECT_THAT(root(-1.0, 4), InvalidRoot()); + EXPECT_THAT(root(-1.0, 5), FitsAndProducesValue(-1.0)); +} + +TEST(Root, RecoversExactValueWherePossible) { + { + const auto sqrt_4f = root(4.0f, 2); + EXPECT_THAT(sqrt_4f.outcome, Eq(MagRepresentationOutcome::OK)); + EXPECT_THAT(sqrt_4f.value, SameTypeAndValue(2.0f)); + } + + { + const auto cbrt_125L = root(125.0L, 3); + EXPECT_THAT(cbrt_125L.outcome, Eq(MagRepresentationOutcome::OK)); + EXPECT_THAT(cbrt_125L.value, SameTypeAndValue(5.0L)); + } +} + +TEST(Root, HandlesArgumentsBetweenOneAndZero) { + EXPECT_THAT(root(0.25, 2), FitsAndProducesValue(0.5)); + EXPECT_THAT(root(0.0001, 4), FitsAndMatchesValue(DoubleEq(0.1))); +} + +TEST(Root, ResultIsVeryCloseToStdPowForPureRoots) { + for (const double x : {55.5, 123.456, 789.012, 3456.789, 12345.6789, 5.67e25}) { + for (const auto r : {2u, 3u, 4u, 5u, 6u, 7u, 8u, 9u}) { + const auto double_result = root(x, r); + EXPECT_THAT(double_result.outcome, Eq(MagRepresentationOutcome::OK)); + EXPECT_THAT(double_result.value, DoubleEq(static_cast(std::pow(x, 1.0L / r)))); + + const auto float_result = root(static_cast(x), r); + EXPECT_THAT(float_result.outcome, Eq(MagRepresentationOutcome::OK)); + EXPECT_THAT(float_result.value, FloatEq(static_cast(std::pow(x, 1.0L / r)))); + } + } +} + +TEST(Root, ResultAtLeastAsGoodAsStdPowForRationalPowers) { + struct RationalPower { + std::uintmax_t num; + std::uintmax_t den; + }; + + auto result_via_root = [](double x, RationalPower power) { + return static_cast( + root(checked_int_pow(static_cast(x), power.num).value, power.den).value); + }; + + auto result_via_std_pow = [](double x, RationalPower power) { + return static_cast( + std::pow(static_cast(x), + static_cast(power.num) / static_cast(power.den))); + }; + + auto round_trip_error = [](double x, RationalPower power, auto func) { + const auto round_trip_result = func(func(x, power), {power.den, power.num}); + return std::abs(round_trip_result - x); + }; + + for (const auto base : {2.0, 3.1415, 98.6, 1.2e-10, 5.5e15}) { + for (const auto power : std::vector{{5, 2}, {2, 3}, {7, 4}}) { + const auto error_from_root = round_trip_error(base, power, result_via_root); + const auto error_from_std_pow = round_trip_error(base, power, result_via_std_pow); + EXPECT_LE(error_from_root, error_from_std_pow); + } + } +} + TEST(GetValueResult, HandlesNumbersTooBigForUintmax) { EXPECT_THAT(get_value_result(pow<64>(mag<2>())), CannotFit()); } diff --git a/au/quantity_test.cc b/au/quantity_test.cc index 553ce3eb..8e582333 100644 --- a/au/quantity_test.cc +++ b/au/quantity_test.cc @@ -263,18 +263,20 @@ TEST(Quantity, HandlesBaseDimensionsWithFractionalExponents) { } TEST(Quantity, HandlesMagnitudesWithFractionalExponents) { - using RootKiloFeet = decltype(root<2>(Kilo{})); - constexpr auto x = make_quantity(3); + constexpr auto x = sqrt(kilo(feet))(3.0); // We can retrieve the value in the same unit (regardless of the scale's fractional powers). - EXPECT_EQ(x.in(RootKiloFeet{}), 3); + EXPECT_EQ(x.in(sqrt(kilo(feet))), 3.0); // We can retrieve the value in a *different* unit, which *also* has fractional powers, as long // as their *ratio* has no fractional powers. - EXPECT_EQ(x.in(root<2>(Milli{})), 3'000); + EXPECT_EQ(x.in(sqrt(milli(feet))), 3'000.0); + + // We can also retrieve the value in a different unit whose ratio *does* have fractional powers. + EXPECT_NEAR(x.in(sqrt(feet)), 94.86833, 1e-5); // Squaring the fractional base power gives us an exact non-fractional dimension and scale. - EXPECT_EQ(x * x, kilo(feet)(9)); + EXPECT_EQ(x * x, kilo(feet)(9.0)); } // A custom "Quantity-equivalent" type, whose interop with Quantity we'll provide below.