From 51f8b47071fd3c37412bbdf5c5b35723309b5bbe Mon Sep 17 00:00:00 2001 From: Nick Thompson Date: Sat, 9 Dec 2023 13:16:57 -0800 Subject: [PATCH] Add fma to math.h --- src/utility/include/mp-units/math.h | 24 ++++++++++++++++++++++++ test/unit_test/runtime/math_test.cpp | 6 ++++++ test/unit_test/static/math_test.cpp | 5 +++++ 3 files changed, 35 insertions(+) diff --git a/src/utility/include/mp-units/math.h b/src/utility/include/mp-units/math.h index 1a6aaf25b..c613fea39 100644 --- a/src/utility/include/mp-units/math.h +++ b/src/utility/include/mp-units/math.h @@ -136,6 +136,30 @@ template return {static_cast(abs(q.numerical_value_ref_in(q.unit))), R}; } +/** + * @brief Computes the fma of 3 quantities + * + * @param a: Multiplicand + * @param x: Multiplicand + * @param b: Addend + * @return Quantity: The nearest floating point representable to ax+b + */ +template +[[nodiscard]] constexpr quantity fma(const quantity& a, const quantity& x, + const quantity& b) noexcept + requires requires { + fma(a.numerical_value_ref_in(a.unit), x.numerical_value_ref_in(x.unit), b.numerical_value_ref_in(b.unit)); + } || requires { + std::fma(a.numerical_value_ref_in(a.unit), x.numerical_value_ref_in(x.unit), b.numerical_value_ref_in(b.unit)); + } +{ + using std::fma; + return {static_cast( + fma(a.numerical_value_ref_in(a.unit), x.numerical_value_ref_in(x.unit), b.numerical_value_ref_in(b.unit))), + R}; +} + + /** * @brief Returns the epsilon of the quantity * diff --git a/test/unit_test/runtime/math_test.cpp b/test/unit_test/runtime/math_test.cpp index e468225f4..d7518de71 100644 --- a/test/unit_test/runtime/math_test.cpp +++ b/test/unit_test/runtime/math_test.cpp @@ -62,6 +62,12 @@ TEST_CASE("'cbrt()' on quantity changes the value and the dimension accordingly" REQUIRE(cbrt(8 * isq::volume[m3]) == 2 * isq::length[m]); } +TEST_CASE("'fma()' on quantity changes the value and the dimension accordingly", "[math][cbrt]") +{ + REQUIRE(fma(1.0 * isq::length[m], 2.0, 2.0 * isq::length[m]) == 4.0 * isq::length[m]); +} + + TEST_CASE("'pow()' on quantity changes the value and the dimension accordingly", "[math][pow]") { REQUIRE(pow<1, 4>(16 * isq::area[m2]) == sqrt(4 * isq::length[m])); diff --git a/test/unit_test/static/math_test.cpp b/test/unit_test/static/math_test.cpp index 980f68e08..7d8779271 100644 --- a/test/unit_test/static/math_test.cpp +++ b/test/unit_test/static/math_test.cpp @@ -40,6 +40,11 @@ template #if __cpp_lib_constexpr_cmath || MP_UNITS_COMP_GCC +static_assert(compare(fma(2 * m, 3 * m, 1 * m2), 7 * m2)); +static_assert(compare(fma(2.0 * s, 3.0 * Hz, 1.0), 7.0)); +static_assert(compare(fma(2 * s, 3 * Hz, 1), 7)); +static_assert(compare(fma(2.0, 3.0*m, 1.0*m), 7*m); +static_assert(compare(fma(2.0*m, 3.0, 1.0*m), 7*m)); static_assert(compare(pow<0>(2 * m), 1 * one)); static_assert(compare(pow<1>(2 * m), 2 * m)); static_assert(compare(pow<2>(2 * m), 4 * pow<2>(m), 4 * m2));