Skip to content

Commit

Permalink
refactor: 💥 op==(U1, U2) now checks for the same type (old behavior…
Browse files Browse the repository at this point in the history
… available as `equivalent(U1, U2)`) + `convertible` now verifies associated `quantity_spec` as well
  • Loading branch information
mpusz committed Oct 9, 2024
1 parent e3ce507 commit 70a18fe
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 72 deletions.
11 changes: 6 additions & 5 deletions src/core/include/mp-units/bits/sudo_cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,13 @@ struct conversion_value_traits {
*/
template<Quantity To, typename FwdFrom, Quantity From = std::remove_cvref_t<FwdFrom>>
requires(castable(From::quantity_spec, To::quantity_spec)) &&
((From::unit == To::unit && std::constructible_from<typename To::rep, typename From::rep>) ||
(From::unit != To::unit)) // && scalable_with_<typename To::rep>))
(((equivalent(From::unit, To::unit)) && std::constructible_from<typename To::rep, typename From::rep>) ||
(!equivalent(From::unit, To::unit))) // && scalable_with_<typename To::rep>))
// TODO how to constrain the second part here?
[[nodiscard]] constexpr To sudo_cast(FwdFrom&& q)
{
constexpr auto q_unit = From::unit;
if constexpr (q_unit == To::unit) {
if constexpr (equivalent(q_unit, To::unit)) {
// no scaling of the number needed
return {static_cast<To::rep>(std::forward<FwdFrom>(q).numerical_value_is_an_implementation_detail_),
To::reference}; // this is the only (and recommended) way to do a truncating conversion on a number, so we
Expand Down Expand Up @@ -149,8 +149,9 @@ template<Quantity To, typename FwdFrom, Quantity From = std::remove_cvref_t<FwdF
template<QuantityPoint ToQP, typename FwdFromQP, QuantityPoint FromQP = std::remove_cvref_t<FwdFromQP>>
requires(castable(FromQP::quantity_spec, ToQP::quantity_spec)) &&
(detail::same_absolute_point_origins(ToQP::point_origin, FromQP::point_origin)) &&
((FromQP::unit == ToQP::unit && std::constructible_from<typename ToQP::rep, typename FromQP::rep>) ||
(FromQP::unit != ToQP::unit))
(((equivalent(FromQP::unit, ToQP::unit)) &&
std::constructible_from<typename ToQP::rep, typename FromQP::rep>) ||
(!equivalent(FromQP::unit, ToQP::unit)))
[[nodiscard]] constexpr QuantityPoint auto sudo_cast(FwdFromQP&& qp)
{
if constexpr (is_same_v<std::remove_const_t<decltype(ToQP::point_origin)>,
Expand Down
8 changes: 4 additions & 4 deletions src/core/include/mp-units/framework/quantity.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ concept CommonlyInvocableQuantities =
InvocableQuantities<Func, Q1, Q2, get_common_quantity_spec(Q1::quantity_spec, Q2::quantity_spec).character>;

template<auto R1, auto R2, typename Rep1, typename Rep2>
concept SameValueAs = SameReference<R1, R2> && std::same_as<Rep1, Rep2>;
concept SameValueAs = (equivalent(get_unit(R1), get_unit(R2))) && std::convertible_to<Rep1, Rep2>;

template<typename T>
using quantity_like_type = quantity<quantity_like_traits<T>::reference, typename quantity_like_traits<T>::rep>;
Expand Down Expand Up @@ -261,21 +261,21 @@ class quantity {

// data access
template<Unit U>
requires(U{} == unit)
requires(equivalent(U{}, unit))
[[nodiscard]] constexpr rep& numerical_value_ref_in(U) & noexcept
{
return numerical_value_is_an_implementation_detail_;
}

template<Unit U>
requires(U{} == unit)
requires(equivalent(U{}, unit))
[[nodiscard]] constexpr const rep& numerical_value_ref_in(U) const& noexcept
{
return numerical_value_is_an_implementation_detail_;
}

template<Unit U>
requires(U{} == unit)
requires(equivalent(U{}, unit))
constexpr const rep&& numerical_value_ref_in(U) const&& noexcept
#if __cpp_deleted_function
= delete("Can't form a reference to a temporary");
Expand Down
26 changes: 19 additions & 7 deletions src/core/include/mp-units/framework/unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ struct unit_less : std::bool_constant<type_name<Lhs>() < type_name<Rhs>()> {};
template<typename T1, typename T2>
using type_list_of_unit_less = expr_less<T1, T2, unit_less>;

template<typename From, typename To>
concept PotentiallyConvertibleTo = Unit<From> && Unit<To> &&
((AssociatedUnit<From> && AssociatedUnit<To> &&
implicitly_convertible(get_quantity_spec(From{}), get_quantity_spec(To{}))) ||
(!AssociatedUnit<From> && !AssociatedUnit<To>));

} // namespace detail

// TODO this should really be in the `details` namespace but is used in `chrono.h` (a part of mp_units.systems)
Expand All @@ -134,9 +140,11 @@ template<Unit From, Unit To>
{
if constexpr (is_same_v<From, To>)
return true;
else
else if constexpr (detail::PotentiallyConvertibleTo<From, To>)
return is_same_v<decltype(get_canonical_unit(from).reference_unit),
decltype(get_canonical_unit(to).reference_unit)>;
else
return false;
}

namespace detail {
Expand Down Expand Up @@ -192,12 +200,16 @@ struct unit_interface {
return expr_divide<derived_unit, struct one, type_list_of_unit_less>(lhs, rhs);
}

[[nodiscard]] friend consteval bool operator==(Unit auto lhs, Unit auto rhs)
template<Unit Lhs, Unit Rhs>
[[nodiscard]] friend consteval bool operator==(Lhs, Rhs)
{
return is_same_v<Lhs, Rhs>;
}

[[nodiscard]] friend consteval bool equivalent(Unit auto lhs, Unit auto rhs)
requires(convertible(lhs, rhs))
{
auto canonical_lhs = get_canonical_unit(lhs);
auto canonical_rhs = get_canonical_unit(rhs);
return convertible(canonical_lhs.reference_unit, canonical_rhs.reference_unit) &&
canonical_lhs.mag == canonical_rhs.mag;
return get_canonical_unit(lhs).mag == get_canonical_unit(rhs).mag;
}
};

Expand Down Expand Up @@ -662,7 +674,7 @@ template<Unit U1, Unit U2>
{
if constexpr (is_same_v<U1, U2>)
return u1;
else if constexpr (U1{} == U2{}) {
else if constexpr (equivalent(U1{}, U2{})) {
if constexpr (std::derived_from<U1, typename U2::_base_type_>)
return u1;
else if constexpr (std::derived_from<U2, typename U1::_base_type_>)
Expand Down
118 changes: 62 additions & 56 deletions test/static/unit_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "test_tools.h"
#include <mp-units/ext/type_traits.h>
#include <mp-units/framework.h>
#include <mp-units/systems/isq.h>
#include <mp-units/systems/si/prefixes.h>
#ifdef MP_UNITS_IMPORT_STD
import std;
Expand All @@ -38,41 +39,29 @@ using namespace mp_units::detail;
using one_ = struct one;
using percent_ = struct percent;

// base dimensions
// clang-format off
inline constexpr struct dim_length_ final : base_dimension<"L"> {} dim_length;
inline constexpr struct dim_mass_ final : base_dimension<"M"> {} dim_mass;
inline constexpr struct dim_time_ final : base_dimension<"T"> {} dim_time;
inline constexpr struct dim_thermodynamic_temperature_ final : base_dimension<symbol_text{u8"Θ", "O"}> {} dim_thermodynamic_temperature;

// quantities specification
QUANTITY_SPEC_(length, dim_length);
QUANTITY_SPEC_(mass, dim_mass);
QUANTITY_SPEC_(time, dim_time);
QUANTITY_SPEC_(thermodynamic_temperature, dim_thermodynamic_temperature);

// prefixes
template<PrefixableUnit U> struct milli_ final : prefixed_unit<"m", mag_power<10, -3>, U{}> {};
template<PrefixableUnit U> struct kilo_ final : prefixed_unit<"k", mag_power<10, 3>, U{}> {};
template<PrefixableUnit auto U> constexpr milli_<MP_UNITS_REMOVE_CONST(decltype(U))> milli;
template<PrefixableUnit auto U> constexpr kilo_<MP_UNITS_REMOVE_CONST(decltype(U))> kilo;

// base units
inline constexpr struct second_ final : named_unit<"s", kind_of<time>> {} second;
inline constexpr struct metre_ final : named_unit<"m", kind_of<length>> {} metre;
inline constexpr struct gram_ final : named_unit<"g", kind_of<mass>> {} gram;
inline constexpr struct second_ final : named_unit<"s", kind_of<isq::time>> {} second;
inline constexpr struct metre_ final : named_unit<"m", kind_of<isq::length>> {} metre;
inline constexpr struct gram_ final : named_unit<"g", kind_of<isq::mass>> {} gram;
inline constexpr auto kilogram = kilo<gram>;
inline constexpr struct kelvin_ final : named_unit<"K", kind_of<thermodynamic_temperature>> {} kelvin;
inline constexpr struct kelvin_ final : named_unit<"K", kind_of<isq::thermodynamic_temperature>> {} kelvin;

// hypothetical natural units for c=1
inline constexpr struct nu_second_ final : named_unit<"s"> {} nu_second;

// derived named units
inline constexpr struct radian_ final : named_unit<"rad", metre / metre> {} radian;
inline constexpr struct radian_ final : named_unit<"rad", metre / metre, kind_of<isq::angular_measure>> {} radian;
inline constexpr struct revolution_ final : named_unit<"rev", mag<2> * mag<pi> * radian> {} revolution;
inline constexpr struct steradian_ final : named_unit<"sr", square(metre) / square(metre)> {} steradian;
inline constexpr struct hertz_ final : named_unit<"Hz", inverse(second)> {} hertz;
inline constexpr struct becquerel_ final : named_unit<"Bq", inverse(second)> {} becquerel;
inline constexpr struct steradian_ final : named_unit<"sr", square(metre) / square(metre), kind_of<isq::solid_angular_measure>> {} steradian;
inline constexpr struct hertz_ final : named_unit<"Hz", inverse(second), kind_of<isq::frequency>> {} hertz;
inline constexpr struct becquerel_ final : named_unit<"Bq", inverse(second), kind_of<isq::activity>> {} becquerel;
inline constexpr struct newton_ final : named_unit<"N", kilogram * metre / square(second)> {} newton;
inline constexpr struct pascal_ final : named_unit<"Pa", newton / square(metre)> {} pascal;
inline constexpr struct joule_ final : named_unit<"J", newton * metre> {} joule;
Expand Down Expand Up @@ -140,7 +129,8 @@ static_assert(is_of_type<degree_Celsius, degree_Celsius_>);
static_assert(is_of_type<get_canonical_unit(degree_Celsius).reference_unit, kelvin_>);
static_assert(get_canonical_unit(degree_Celsius).mag == mag<1>);
static_assert(convertible(degree_Celsius, kelvin));
static_assert(degree_Celsius == kelvin);
static_assert(degree_Celsius != kelvin);
static_assert(equivalent(degree_Celsius, kelvin));

static_assert(is_of_type<radian, radian_>);
static_assert(is_of_type<get_canonical_unit(radian).reference_unit, one_>);
Expand All @@ -155,8 +145,8 @@ static_assert(radian != degree);
static_assert(is_of_type<steradian, steradian_>);
static_assert(is_of_type<get_canonical_unit(steradian).reference_unit, one_>);
static_assert(get_canonical_unit(steradian).mag == mag<1>);
static_assert(convertible(radian, steradian)); // !!!
static_assert(radian == steradian); // !!!
static_assert(!convertible(radian, steradian));
static_assert(radian != steradian);

static_assert(is_of_type<minute, minute_>);
static_assert(is_of_type<get_canonical_unit(minute).reference_unit, second_>);
Expand Down Expand Up @@ -414,50 +404,59 @@ concept invalid_operations = requires {
requires !requires { 2 == s; };
requires !requires { s < 2; };
requires !requires { 2 < s; };
requires !requires { s + time[second]; };
requires !requires { s - time[second]; };
requires !requires { s < time[second]; };
requires !requires { time[second] + s; };
requires !requires { time[second] - s; };
requires !requires { s + 1 * time[second]; };
requires !requires { s - 1 * time[second]; };
requires !requires { s * 1 * time[second]; };
requires !requires { s / 1 * time[second]; };
requires !requires { s == 1 * time[second]; };
requires !requires { s == 1 * time[second]; };
requires !requires { 1 * time[second] + s; };
requires !requires { 1 * time[second] - s; };
requires !requires { 1 * time[second] == s; };
requires !requires { 1 * time[second] < s; };
requires !requires { s + isq::time[second]; };
requires !requires { s - isq::time[second]; };
requires !requires { s < isq::time[second]; };
requires !requires { isq::time[second] + s; };
requires !requires { isq::time[second] - s; };
requires !requires { s + 1 * isq::time[second]; };
requires !requires { s - 1 * isq::time[second]; };
requires !requires { s * 1 * isq::time[second]; };
requires !requires { s / 1 * isq::time[second]; };
requires !requires { s == 1 * isq::time[second]; };
requires !requires { s == 1 * isq::time[second]; };
requires !requires { 1 * isq::time[second] + s; };
requires !requires { 1 * isq::time[second] - s; };
requires !requires { 1 * isq::time[second] == s; };
requires !requires { 1 * isq::time[second] < s; };
};
static_assert(invalid_operations<second>);

// comparisons of the same units
static_assert(second == second);
static_assert(metre / second == metre / second);
static_assert(milli<metre> / milli<second> == si::micro<metre> / si::micro<second>);
static_assert(milli<metre> / si::micro<second> == si::micro<metre> / si::nano<second>);
static_assert(si::micro<metre> / milli<second> == si::nano<metre> / si::micro<second>);
static_assert(milli<metre> * kilo<metre> == si::deci<metre> * si::deca<metre>);
static_assert(kilo<metre> * milli<metre> == si::deca<metre> * si::deci<metre>);
static_assert(milli<metre> / milli<second> != si::micro<metre> / si::micro<second>);
static_assert(equivalent(milli<metre> / milli<second>, si::micro<metre> / si::micro<second>));
static_assert(milli<metre> / si::micro<second> != si::micro<metre> / si::nano<second>);
static_assert(equivalent(milli<metre> / si::micro<second>, si::micro<metre> / si::nano<second>));
static_assert(si::micro<metre> / milli<second> != si::nano<metre> / si::micro<second>);
static_assert(equivalent(si::micro<metre> / milli<second>, si::nano<metre> / si::micro<second>));
static_assert(milli<metre> * kilo<metre> != si::deci<metre> * si::deca<metre>);
static_assert(equivalent(milli<metre> * kilo<metre>, si::deci<metre>* si::deca<metre>));
static_assert(kilo<metre> * milli<metre> != si::deca<metre> * si::deci<metre>);
static_assert(equivalent(kilo<metre> * milli<metre>, si::deca<metre>* si::deci<metre>));

// comparisons of equivalent units (named vs unnamed/derived)
static_assert(one / second == hertz);
static_assert(one / second != hertz);
static_assert(equivalent(one / second, hertz));
static_assert(convertible(one / second, hertz));

// comparisons of equivalent units of different quantities
static_assert(hertz == becquerel);
static_assert(convertible(hertz, becquerel));
static_assert(hertz != becquerel);
static_assert(!convertible(hertz, becquerel));

// comparisons of scaled units
static_assert(kilo<metre> == kilometre);
static_assert(mag<1000> * metre == kilo<metre>);
static_assert(mag<1000> * metre == kilometre);
static_assert(mag<1000> * metre != kilo<metre>);
static_assert(equivalent(mag<1000> * metre, kilo<metre>));
static_assert(mag<1000> * metre != kilometre);
static_assert(equivalent(mag<1000> * metre, kilometre));
static_assert(convertible(kilo<metre>, kilometre));
static_assert(convertible(mag<1000> * metre, kilo<metre>));
static_assert(convertible(mag<1000> * metre, kilometre));

static_assert(mag<60> * metre / second == metre / (mag_ratio<1, 60> * second));
static_assert(mag<60> * metre / second != metre / (mag_ratio<1, 60> * second));
static_assert(equivalent(mag<60> * metre / second, metre / (mag_ratio<1, 60> * second)));

static_assert(metre != kilometre);
static_assert(convertible(metre, kilometre));
Expand All @@ -474,20 +473,27 @@ static_assert(!convertible(metre, metre* metre));
static_assert(is_of_type<metre / metre, one_>);
static_assert(is_of_type<kilo<metre> / metre, derived_unit<kilo_<metre_>, per<metre_>>>);
static_assert(metre / metre == one);
static_assert(hertz * second == one);
static_assert(hertz * second != one);
static_assert(equivalent(hertz * second, one));
static_assert(one * one == one);
static_assert(is_of_type<one * one, one_>);
static_assert(one * percent == percent);
static_assert(percent * one == percent);
static_assert(is_of_type<one * percent, percent_>);
static_assert(is_of_type<percent * one, percent_>);

static_assert(hertz == one / second);
static_assert(newton == kilogram * metre / square(second));
static_assert(joule == kilogram * square(metre) / square(second));
static_assert(joule == newton * metre);
static_assert(watt == joule / second);
static_assert(watt == kilogram * square(metre) / cubic(second));
static_assert(hertz != one / second);
static_assert(equivalent(hertz, one / second));
static_assert(newton != kilogram * metre / square(second));
static_assert(equivalent(newton, kilogram* metre / square(second)));
static_assert(joule != kilogram * square(metre) / square(second));
static_assert(equivalent(joule, kilogram* square(metre) / square(second)));
static_assert(joule != newton * metre);
static_assert(equivalent(joule, newton* metre));
static_assert(watt != joule / second);
static_assert(equivalent(watt, joule / second));
static_assert(watt != kilogram * square(metre) / cubic(second));
static_assert(equivalent(watt, kilogram* square(metre) / cubic(second)));

// power
static_assert(is_same_v<decltype(pow<2>(metre)), decltype(metre * metre)>);
Expand Down

0 comments on commit 70a18fe

Please sign in to comment.