Skip to content

Commit

Permalink
Merge pull request #634 from chiphogg/chiphogg/mod#509
Browse files Browse the repository at this point in the history
Add helpers for modular arithmetic
  • Loading branch information
mpusz authored Nov 13, 2024
2 parents 9dd59e8 + b99faf0 commit 5cd07bc
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
96 changes: 96 additions & 0 deletions src/core/include/mp-units/ext/prime.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <mp-units/ext/algorithm.h>

#ifndef MP_UNITS_IN_MODULE_INTERFACE
#include <mp-units/ext/contracts.h>
#ifdef MP_UNITS_IMPORT_STD
import std;
#else
Expand All @@ -42,6 +43,101 @@ import std;

namespace mp_units::detail {

// (a + b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval std::uint64_t add_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n)
{
MP_UNITS_EXPECTS_DEBUG(a < n);
MP_UNITS_EXPECTS_DEBUG(b < n);
MP_UNITS_EXPECTS_DEBUG(n > 0u);

if (a >= n - b) {
return a - (n - b);
} else {
return a + b;
}
}

// (a - b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval std::uint64_t sub_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n)
{
MP_UNITS_EXPECTS_DEBUG(a < n);
MP_UNITS_EXPECTS_DEBUG(b < n);
MP_UNITS_EXPECTS_DEBUG(n > 0u);

if (a >= b) {
return a - b;
} else {
return n - (b - a);
}
}

// (a * b) % n.
//
// Precondition: (a < n).
// Precondition: (b < n).
// Precondition: (n > 0).
[[nodiscard]] consteval std::uint64_t mul_mod(std::uint64_t a, std::uint64_t b, std::uint64_t n)
{
MP_UNITS_EXPECTS_DEBUG(a < n);
MP_UNITS_EXPECTS_DEBUG(b < n);
MP_UNITS_EXPECTS_DEBUG(n > 0u);

if (b == 0u || a < std::numeric_limits<std::uint64_t>::max() / b) {
return (a * b) % n;
}

const std::uint64_t batch_size = n / a;
const std::uint64_t num_batches = b / batch_size;

return add_mod(
// Transform into "negative space" to make the first parameter as small as possible;
// then, transform back.
n - mul_mod(n % a, num_batches, n),

// Handle the leftover product (which is guaranteed to fit in the integer type).
(a * (b % batch_size)) % n,

n);
}

// (a / 2) % n.
//
// Precondition: (a < n).
// Precondition: (n % 2 == 1).
[[nodiscard]] consteval std::uint64_t half_mod_odd(std::uint64_t a, std::uint64_t n)
{
MP_UNITS_EXPECTS_DEBUG(a < n);
MP_UNITS_EXPECTS_DEBUG(n % 2 == 1);

return (a / 2u) + ((a % 2u == 0u) ? 0u : (n / 2u + 1u));
}

// (base ^ exp) % n.
[[nodiscard]] consteval std::uint64_t pow_mod(std::uint64_t base, std::uint64_t exp, std::uint64_t n)
{
std::uint64_t result = 1u;
base %= n;

while (exp > 0u) {
if (exp % 2u == 1u) {
result = mul_mod(result, base, n);
}

exp /= 2u;
base = mul_mod(base, base, n);
}

return result;
}

[[nodiscard]] consteval bool is_prime_by_trial_division(std::uintmax_t n)
{
for (std::uintmax_t f = 2; f * f <= n; f += 1 + (f % 2)) {
Expand Down
26 changes: 26 additions & 0 deletions test/static/prime_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ using namespace mp_units::detail;

namespace {

inline constexpr auto MAX_U64 = std::numeric_limits<std::uint64_t>::max();

template<std::size_t BasisSize, std::size_t... Is>
constexpr bool check_primes(std::index_sequence<Is...>)
{
Expand Down Expand Up @@ -78,4 +80,28 @@ static_assert(!wheel_factorizer<3>::is_prime(0));
static_assert(!wheel_factorizer<3>::is_prime(1));
static_assert(wheel_factorizer<3>::is_prime(2));

// Modular arithmetic.
static_assert(add_mod(1u, 2u, 5u) == 3u);
static_assert(add_mod(4u, 4u, 5u) == 3u);
static_assert(add_mod(MAX_U64 - 1u, MAX_U64 - 2u, MAX_U64) == MAX_U64 - 3u);

static_assert(sub_mod(2u, 1u, 5u) == 1u);
static_assert(sub_mod(1u, 2u, 5u) == 4u);
static_assert(sub_mod(MAX_U64 - 2u, MAX_U64 - 1u, MAX_U64) == MAX_U64 - 1u);
static_assert(sub_mod(1u, MAX_U64 - 1u, MAX_U64) == 2u);

static_assert(mul_mod(6u, 7u, 10u) == 2u);
static_assert(mul_mod(13u, 11u, 50u) == 43u);
static_assert(mul_mod(MAX_U64 / 2u, 10u, MAX_U64) == MAX_U64 - 5u);

static_assert(half_mod_odd(0u, 11u) == 0u);
static_assert(half_mod_odd(10u, 11u) == 5u);
static_assert(half_mod_odd(1u, 11u) == 6u);
static_assert(half_mod_odd(9u, 11u) == 10u);
static_assert(half_mod_odd(MAX_U64 - 1u, MAX_U64) == (MAX_U64 - 1u) / 2u);
static_assert(half_mod_odd(MAX_U64 - 2u, MAX_U64) == MAX_U64 - 1u);

static_assert(pow_mod(5u, 8u, 9u) == ((5u * 5u * 5u * 5u) * (5u * 5u * 5u * 5u)) % 9u);
static_assert(pow_mod(2u, 64u, MAX_U64) == 1u);

} // namespace

0 comments on commit 5cd07bc

Please sign in to comment.