diff --git a/src/core/include/mp-units/framework/expression_template.h b/src/core/include/mp-units/framework/expression_template.h index c0ec574b1..72a2db0e2 100644 --- a/src/core/include/mp-units/framework/expression_template.h +++ b/src/core/include/mp-units/framework/expression_template.h @@ -334,7 +334,7 @@ struct expr_fractions : decltype(expr_fractions_impl>( // expr_make_spec template typename To> -[[nodiscard]] consteval auto expr_make_spec_impl() +[[nodiscard]] MP_UNITS_CONSTEVAL auto expr_make_spec_impl() { constexpr std::size_t num = type_list_size; constexpr std::size_t den = type_list_size; @@ -359,7 +359,7 @@ template typename Pred, template typename To> -[[nodiscard]] consteval auto get_optimized_expression() +[[nodiscard]] MP_UNITS_CONSTEVAL auto get_optimized_expression() { using num_list = expr_consolidate; using den_list = expr_consolidate; @@ -380,7 +380,7 @@ template typename To, typename OneType, template typename Pred, typename Lhs, typename Rhs> -[[nodiscard]] consteval auto expr_multiply(Lhs, Rhs) +[[nodiscard]] MP_UNITS_CONSTEVAL auto expr_multiply(Lhs, Rhs) { if constexpr (is_same_v) { return Rhs{}; diff --git a/src/core/include/mp-units/framework/magnitude.h b/src/core/include/mp-units/framework/magnitude.h index 8fb22eb81..d16325283 100644 --- a/src/core/include/mp-units/framework/magnitude.h +++ b/src/core/include/mp-units/framework/magnitude.h @@ -307,6 +307,119 @@ template return checked_square(int_power(base, exp / 2)); } +template +[[nodiscard]] consteval std::optional checked_int_pow(T base, std::uintmax_t exp) +{ + T result = T{1}; + while (exp > 0u) { + if (exp % 2u == 1u) { + if (base > std::numeric_limits::max() / result) { + return std::nullopt; + } + result *= base; + } + + exp /= 2u; + + if (base > std::numeric_limits::max() / base) { + return (exp == 0u) ? std::make_optional(result) : std::nullopt; + } + base *= base; + } + return result; +} + +template +[[nodiscard]] consteval std::optional root(T x, std::uintmax_t n) +{ + // The "zeroth root" would be mathematically undefined. + if (n == 0) { + return std::nullopt; + } + + // The "first root" is trivial. + if (n == 1) { + return x; + } + + // We only support nontrivial roots of floating point types. + if (!std::is_floating_point::value) { + return std::nullopt; + } + + // Handle negative numbers: only odd roots are allowed. + if (x < 0) { + if (n % 2 == 0) { + return std::nullopt; + } else { + const auto negative_result = root(-x, n); + if (!negative_result.has_value()) { + return std::nullopt; + } + return static_cast(-negative_result.value()); + } + } + + // Handle special cases of zero and one. + if (x == 0 || x == 1) { + return x; + } + + // Handle numbers bewtween 0 and 1. + if (x < 1) { + const auto inverse_result = root(T{1} / x, n); + if (!inverse_result.has_value()) { + return std::nullopt; + } + return static_cast(T{1} / inverse_result.value()); + } + + // + // At this point, error conditions are finished, and we can proceed with the "core" algorithm. + // + + // Always use `long double` for intermediate computations. We don't ever expect people to be + // calling this at runtime, so we want maximum accuracy. + long double lo = 1.0; + long double hi = static_cast(x); + + // Do a binary search to find the closest value such that `checked_int_pow` recovers the input. + // + // Because we know `n > 1`, and `x > 1`, and x^n is monotonically increasing, we know that + // `checked_int_pow(lo, n) < x < checked_int_pow(hi, n)`. We will preserve this as an + // invariant. + while (lo < hi) { + long double mid = lo + (hi - lo) / 2; + + auto result = checked_int_pow(mid, n); + + if (!result.has_value()) { + return std::nullopt; + } + + // Early return if we get lucky with an exact answer. + if (result.value() == x) { + return static_cast(mid); + } + + // Check for stagnation. + if (mid == lo || mid == hi) { + break; + } + + // Preserve the invariant that `checked_int_pow(lo, n) < x < checked_int_pow(hi, n)`. + if (result.value() < x) { + lo = mid; + } else { + hi = mid; + } + } + + // Pick whichever one gets closer to the target. + const auto lo_diff = x - checked_int_pow(lo, n).value(); + const auto hi_diff = checked_int_pow(hi, n).value() - x; + return static_cast(lo_diff < hi_diff ? lo : hi); +} template [[nodiscard]] consteval widen_t compute_base_power(MagnitudeSpec auto el) @@ -317,9 +430,6 @@ template // Note that since this function should only be called at compile time, the point of these // terminations is to act as "static_assert substitutes", not to actually terminate at runtime. const auto exp = get_exponent(el); - if (exp.den != 1) { - std::abort(); // Rational powers not yet supported - } if (exp.num < 0) { if constexpr (std::is_integral_v) { @@ -329,8 +439,19 @@ template } } - auto power = exp.num; - return int_power(static_cast>(get_base_value(el)), power); + const auto pow_result = + checked_int_pow(static_cast>(get_base_value(el)), static_cast(exp.num)); + if (pow_result.has_value()) { + const auto final_result = + (exp.den > 1) ? root(pow_result.value(), static_cast(exp.den)) : pow_result; + if (final_result.has_value()) { + return final_result.value(); + } else { + std::abort(); // Root computation failed. + } + } else { + std::abort(); // Power computation failed. + } } // A converter for the value member variable of magnitude (below). diff --git a/test/runtime/CMakeLists.txt b/test/runtime/CMakeLists.txt index 0be597b38..d79ed54b3 100644 --- a/test/runtime/CMakeLists.txt +++ b/test/runtime/CMakeLists.txt @@ -23,8 +23,14 @@ find_package(Catch2 3 REQUIRED) add_executable( - unit_tests_runtime distribution_test.cpp fixed_string_test.cpp fmt_test.cpp math_test.cpp atomic_test.cpp - truncation_test.cpp + unit_tests_runtime + distribution_test.cpp + fixed_string_test.cpp + fmt_test.cpp + math_test.cpp + atomic_test.cpp + truncation_test.cpp + quantity_test.cpp ) if(${projectPrefix}BUILD_CXX_MODULES) target_compile_definitions(unit_tests_runtime PUBLIC ${projectPrefix}MODULES) diff --git a/test/runtime/quantity_test.cpp b/test/runtime/quantity_test.cpp new file mode 100644 index 000000000..151845a7a --- /dev/null +++ b/test/runtime/quantity_test.cpp @@ -0,0 +1,75 @@ +// The MIT License (MIT) +// +// Copyright (c) 2024 Chip Hogg +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include +#ifdef MP_UNITS_IMPORT_STD +import std; +#else +#include +#include +#endif +#ifdef MP_UNITS_MODULES +import mp_units; +#else +#include +#include +#endif + +using namespace mp_units; +using namespace mp_units::si::unit_symbols; + +namespace { + +template +constexpr bool within_4_ulps(T a, T b) +{ + static_assert(std::is_floating_point_v); + auto walk_ulps = [](T x, int n) { + while (n > 0) { + x = std::nextafter(x, std::numeric_limits::infinity()); + --n; + } + while (n < 0) { + x = std::nextafter(x, -std::numeric_limits::infinity()); + ++n; + } + return x; + }; + + return (walk_ulps(a, -4) <= b) && (b <= walk_ulps(a, 4)); +} + +} // namespace + +// conversion requiring radical magnitudes +TEST_CASE("unit conversions support radical magnitudes", "[conversion][radical]") +{ + REQUIRE(within_4_ulps(sqrt((1.0 * m) * (1.0 * km)).numerical_value_in(m), sqrt(1000.0))); +} + +// Reproducing issue #474 exactly: +TEST_CASE("Issue 474 is fixed", "[conversion][radical]") +{ + constexpr auto val_issue_474 = 8.0 * si::si2019::boltzmann_constant * 1000.0 * K / (std::numbers::pi * 10 * Da); + REQUIRE(within_4_ulps(sqrt(val_issue_474).numerical_value_in(m / s), + sqrt(val_issue_474.numerical_value_in(m * m / s / s)))); +} diff --git a/test/static/quantity_test.cpp b/test/static/quantity_test.cpp index 9b3abbb0a..d8ddfe609 100644 --- a/test/static/quantity_test.cpp +++ b/test/static/quantity_test.cpp @@ -199,7 +199,6 @@ static_assert(std::convertible_to, quantity, quantity>); static_assert(std::convertible_to, quantity>); - /////////////////////// // obtaining a number ///////////////////////