Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(crypto): optimize TwoAdicFri #555

Merged
merged 3 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions tachyon/crypto/commitments/fri/two_adic_fri.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,13 @@ class TwoAdicFRI {
reduced_openings[log_num_rows];
CHECK_EQ(reduced_opening_for_log_num_rows.size(), num_rows);

math::RowMajorMatrix<F> block =
mat.topRows(num_rows >> config_.log_blowup);
ReverseMatrixIndexBits(block);
math::RowMajorMatrix<F> block;
block.resize(num_rows >> config_.log_blowup, mat.cols());
OMP_PARALLEL_FOR(size_t row = 0; row < num_rows >> config_.log_blowup;
++row) {
block.row(row) = mat.row(base::bits::ReverseBitsLen(
row, log_num_rows - config_.log_blowup));
}
std::vector<ExtF> reduced_rows = DotExtPowers(mat, alpha);

// TODO(ashjeong): Determine if using a matrix is a better fit.
Expand Down
45 changes: 34 additions & 11 deletions tachyon/math/base/semigroups.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,24 @@
std::declval<T>().Name##InPlace()))> \
: std::true_type {}

// Declares, in the scope where the macro is expanded, a detection trait named
// |SupportsExp<Pow>| (e.g. |SupportsExp3|). |SupportsExp<Pow><T>::value| is
// true when the expression |std::declval<T>().Exp<Pow>()| is well-formed and
// false otherwise.
//
// The expansion deliberately does not end with a semicolon so that every use
// is written as |SUPPORTS_DEDICATED_EXP_OPERATOR(N);|, matching the other
// SUPPORTS_* macros and avoiding a stray empty declaration.
#define SUPPORTS_DEDICATED_EXP_OPERATOR(Pow)                                \
  template <typename T, typename = void>                                    \
  struct SupportsExp##Pow : std::false_type {};                             \
                                                                            \
  template <typename T>                                                     \
  struct SupportsExp##Pow<T, decltype(void(std::declval<T>().Exp##Pow()))>  \
      : std::true_type {}

namespace tachyon::math {
namespace internal {

SUPPORTS_BINARY_OPERATOR(Mul);
SUPPORTS_UNARY_OPERATOR(SquareImpl);
SUPPORTS_BINARY_OPERATOR(Add);
SUPPORTS_UNARY_OPERATOR(DoubleImpl);
SUPPORTS_DEDICATED_EXP_OPERATOR(3);
SUPPORTS_DEDICATED_EXP_OPERATOR(5);
SUPPORTS_DEDICATED_EXP_OPERATOR(7);

template <typename T, typename = void>
struct SupportsSize : std::false_type {};
Expand Down Expand Up @@ -157,24 +168,36 @@ class MultiplicativeSemigroup {
return g;
else if constexpr (Power == 2)
return Square();
else if constexpr (Power == 3)
return Square() * g;
else if constexpr (Power == 4)
else if constexpr (Power == 3) {
if constexpr (internal::SupportsExp3<G>::value) {
return g.Exp3();
} else {
return Square() * g;
}
} else if constexpr (Power == 4) {
return Square().Square();
else if constexpr (Power == 5) {
MulResult g4 = Square();
g4.SquareInPlace();
return g4 * g;
} else if constexpr (Power == 5) {
if constexpr (internal::SupportsExp5<G>::value) {
return g.Exp5();
} else {
MulResult g4 = Square();
g4.SquareInPlace();
return g4 * g;
}
} else if constexpr (Power == 6) {
MulResult g2 = Square();
MulResult g4 = g2;
g4.SquareInPlace();
return g4 * g2;
} else if constexpr (Power == 7) {
MulResult g2 = Square();
MulResult g4 = g2;
g4.SquareInPlace();
return g4 * g2 * g;
if constexpr (internal::SupportsExp7<G>::value) {
return g.Exp7();
} else {
MulResult g2 = Square();
MulResult g4 = g2;
g4.SquareInPlace();
return g4 * g2 * g;
}
} else {
return DoPow(BigInt<1>(Power));
}
Expand Down
5 changes: 4 additions & 1 deletion tachyon/math/finite_fields/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ tachyon_cc_library(
tachyon_cc_library(
name = "packed_prime_field32_avx2",
hdrs = ["packed_prime_field32_avx2.h"],
deps = ["//tachyon/base:compiler_specific"],
deps = [
"//tachyon/base:compiler_specific",
"//tachyon/base/functional:callback",
],
)

tachyon_cc_library(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,16 @@ PackedBabyBearAVX2 PackedBabyBearAVX2::Mul(
return FromVector(math::Mul(ToVector(*this), ToVector(other)));
}

// Returns this packed value raised to the third power, delegating to the
// dedicated vectorized |math::Exp3| kernel.
PackedBabyBearAVX2 PackedBabyBearAVX2::Exp3() const {
  const auto lanes = ToVector(*this);
  return FromVector(math::Exp3(lanes, kP, kInv));
}

// Returns this packed value raised to the fifth power, delegating to the
// dedicated vectorized |math::Exp5| kernel.
PackedBabyBearAVX2 PackedBabyBearAVX2::Exp5() const {
  const auto lanes = ToVector(*this);
  return FromVector(math::Exp5(lanes, kP, kInv));
}

// Returns this packed value raised to the seventh power, delegating to the
// dedicated vectorized |math::Exp7| kernel.
PackedBabyBearAVX2 PackedBabyBearAVX2::Exp7() const {
  const auto lanes = ToVector(*this);
  return FromVector(math::Exp7(lanes, kP, kInv));
}

} // namespace tachyon::math
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class TACHYON_EXPORT PackedBabyBearAVX2 final

// MultiplicativeSemigroup methods
PackedBabyBearAVX2 Mul(const PackedBabyBearAVX2& other) const;

PackedBabyBearAVX2 Exp3() const;
PackedBabyBearAVX2 Exp5() const;
PackedBabyBearAVX2 Exp7() const;
};

} // namespace tachyon::math
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,16 @@ PackedKoalaBearAVX2 PackedKoalaBearAVX2::Mul(
return FromVector(math::Mul(ToVector(*this), ToVector(other)));
}

// Returns this packed value raised to the third power, delegating to the
// dedicated vectorized |math::Exp3| kernel.
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Exp3() const {
  const auto lanes = ToVector(*this);
  return FromVector(math::Exp3(lanes, kP, kInv));
}

// Returns this packed value raised to the fifth power, delegating to the
// dedicated vectorized |math::Exp5| kernel.
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Exp5() const {
  const auto lanes = ToVector(*this);
  return FromVector(math::Exp5(lanes, kP, kInv));
}

// Returns this packed value raised to the seventh power, delegating to the
// dedicated vectorized |math::Exp7| kernel.
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Exp7() const {
  const auto lanes = ToVector(*this);
  return FromVector(math::Exp7(lanes, kP, kInv));
}

} // namespace tachyon::math
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class TACHYON_EXPORT PackedKoalaBearAVX2 final

// MultiplicativeSemigroup methods
PackedKoalaBearAVX2 Mul(const PackedKoalaBearAVX2& other) const;

PackedKoalaBearAVX2 Exp3() const;
PackedKoalaBearAVX2 Exp5() const;
PackedKoalaBearAVX2 Exp7() const;
};

} // namespace tachyon::math
Expand Down
143 changes: 137 additions & 6 deletions tachyon/math/finite_fields/packed_prime_field32_avx2.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
// Copyright (c) 2022 The Plonky3 Authors
// Use of this source code is governed by an MIT/Apache-2.0 style license that
// can be found in the LICENSE-MIT.plonky3 and the LICENSE-APACHE.plonky3
// files.

#ifndef TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX2_H_
#define TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX2_H_

#include <immintrin.h>

#include "tachyon/base/compiler_specific.h"
#include "tachyon/base/functional/callback.h"

namespace tachyon::math {

Expand Down Expand Up @@ -166,11 +172,59 @@ ALWAYS_INLINE __m256i NegateMod32(__m256i val, __m256i p) {
//
// [1] Modern Computer Arithmetic, Richard Brent and Paul Zimmermann,
// Cambridge University Press, 2010, algorithm 2.7.
ALWAYS_INLINE __m256i MontyD(__m256i lhs, __m256i rhs, __m256i p, __m256i inv) {
__m256i prod = _mm256_mul_epu32(lhs, rhs);
__m256i q = _mm256_mul_epu32(prod, inv);

// We provide 2 variants of Montgomery reduction depending on whether the inputs
// are unsigned or signed. The unsigned variant follows steps 1 and 2 in the
// above protocol to produce D in (-P, ..., P). For the signed variant we assume
// -PB/2 < C < PB/2 and let Q := μ C mod B be the unique representative in
// [-B/2, ..., B/2 - 1]. The division in step 2 is clearly still exact and
// |C - Q P| <= |C| + |Q||P| < PB so D still lies in (-P, ..., P).

// Performs a partial Montgomery reduction on each 64-bit element of |input|.
//
// |input|: unsigned 64-bit values that must lie in {0, ..., 2³²P}.
// |p|:     the modulus P, read from the low 32 bits of every 64-bit element.
// |inv|:   the Montgomery constant μ, read the same way.
//
// The output will lie in {-P, ..., P} and be stored in the upper 32 bits of
// each 64-bit element; the lower 32 bits are not meaningful.
ALWAYS_INLINE __m256i PartialMontyRedUnsignedToSigned(__m256i input, __m256i p,
                                                      __m256i inv) {
  // Only the low 32 bits of each 64-bit operand take part in
  // |_mm256_mul_epu32|, so |q| is effectively μ * input mod 2³² once it is
  // multiplied again below.
  __m256i q = _mm256_mul_epu32(input, inv);
  __m256i q_p = _mm256_mul_epu32(q, p);
  // By construction, the bottom 32 bits of input and q_p are equal.
  // Thus |_mm256_sub_epi32| and |_mm256_sub_epi64| should act identically.
  // However for some reason, the compiler gets confused if we use
  // |_mm256_sub_epi64| and outputs a load of nonsense, see:
  // https://godbolt.org/z/3W8M7Tv84.
  return _mm256_sub_epi32(input, q_p);
}
// Signed counterpart of the partial Montgomery reduction, applied to each
// 64-bit element of |input|.
// Input must lie in {-2³¹P, ..., 2³¹P}.
// The output will lie in {-P, ..., P} and be stored in the upper 32 bits.
ALWAYS_INLINE __m256i PartialMontyRedSignedToSigned(__m256i input, __m256i p,
                                                    __m256i inv) {
  const __m256i quotient = _mm256_mul_epi32(input, inv);
  // The compiler output is essentially the same for |_mm256_sub_epi32| and
  // |_mm256_sub_epi64| here; the 32-bit form is used purely to stay
  // consistent with the unsigned variant.
  return _mm256_sub_epi32(input, _mm256_mul_epi32(quotient, p));
}

// Multiplies the field elements held in the even-index 32-bit entries.
// |lhs[2i]|, |rhs[2i]| must be unsigned 32-bit integers such that
// |lhs[2i]| * |rhs[2i]| lies in {0, ..., 2³²P}.
// The output will lie in {-P, ..., P} and be stored in |output[2i + 1]|.
ALWAYS_INLINE __m256i MontyMul(__m256i lhs, __m256i rhs, __m256i p,
                               __m256i inv) {
  // |_mm256_mul_epu32| produces an unsigned 64-bit product, so it must be
  // reduced with the unsigned variant: the signed reduction only accepts
  // inputs in {-2³¹P, ..., 2³¹P}, which is narrower than the {0, ..., 2³²P}
  // range this function promises to handle.
  __m256i prod = _mm256_mul_epu32(lhs, rhs);
  return PartialMontyRedUnsignedToSigned(prod, p, inv);
}

// Signed multiplication of the field elements in the even index entries.
// |lhs[2i]|, |rhs[2i]| must be signed 32-bit integers such that
// |lhs[2i]| * |rhs[2i]| lies in {-2³¹P, ..., 2³¹P}.
// The output will lie in {-P, ..., P} stored in |output[2i + 1]|.
ALWAYS_INLINE __m256i MontyMulSigned(__m256i lhs, __m256i rhs, __m256i p,
                                     __m256i inv) {
  // Signed 32x32 -> 64 bit product, followed by the signed reduction.
  return PartialMontyRedSignedToSigned(_mm256_mul_epi32(lhs, rhs), p, inv);
}

ALWAYS_INLINE __m256i movehdup_epi32(__m256i x) {
Expand Down Expand Up @@ -205,8 +259,8 @@ ALWAYS_INLINE __m256i MontMulMod32(__m256i lhs, __m256i rhs, __m256i p,
__m256i lhs_odd = movehdup_epi32(lhs);
__m256i rhs_odd = movehdup_epi32(rhs);

__m256i d_evn = MontyD(lhs_evn, rhs_evn, p, inv);
__m256i d_odd = MontyD(lhs_odd, rhs_odd, p, inv);
__m256i d_evn = MontyMul(lhs_evn, rhs_evn, p, inv);
__m256i d_odd = MontyMul(lhs_odd, rhs_odd, p, inv);

__m256i d_evn_hi = movehdup_epi32(d_evn);
__m256i t = _mm256_blend_epi32(d_evn_hi, d_odd, 0b10101010);
Expand All @@ -215,6 +269,83 @@ ALWAYS_INLINE __m256i MontMulMod32(__m256i lhs, __m256i rhs, __m256i p,
return _mm256_min_epu32(t, u);
}

// Squares the field elements in the even index entries.
// Inputs must be signed 32-bit integers.
// Outputs will be a signed integer in (-P, ..., P) copied into both the even
// and odd indices.
ALWAYS_INLINE __m256i ShiftedSquare(__m256i input, __m256i p, __m256i inv) {
  // Note that we do not need a restriction on the size of |input[i]²|:
  // 2³⁰ < P and |i32| <= 2³¹, hence 0 <= |input[i]²| <= 2⁶² < 2³²P.
  __m256i square = _mm256_mul_epi32(input, input);
  // The square is non-negative and only bounded by 2³²P, which is exactly the
  // precondition of the unsigned reduction. The signed reduction requires
  // |C| <= 2³¹P, which 2⁶² may exceed, so it must not be used here.
  __m256i square_red = PartialMontyRedUnsignedToSigned(square, p, inv);
  // The reduced value sits in the odd (upper) 32 bits; duplicate it into the
  // even entries so it can be fed straight into the next multiplication.
  return movehdup_epi32(square_red);
}

// Runs |callback| over both the even and the odd 32-bit entries of |input| and
// merges the two results into a single vector.
//
// |callback| must only depend on the 32-bit entries at even indices, must
// produce values in (-P, ..., P) and must store them at the odd indices; the
// even indices of its output are never read. |input| has to satisfy whatever
// requirements |callback| itself imposes.
//
// Every lane of the merged vector |t| is returned as the unsigned minimum of
// |t| and |t + p|.
// NOTE(review): |callback| is a type-erased |base::RepeatingCallback| invoked
// inside an ALWAYS_INLINE helper; whether the compiler can see through it is
// not visible from here -- worth confirming on hot paths.
// NOTE(chokobole): This is to suppress the error below.
// clang-format off
// error: ignoring attributes on template argument '__m256i(__m256i, __m256i, __m256i)' [-Werror=ignored-attributes]
// clang-format on
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
ALWAYS_INLINE __m256i ApplyFuncToEvenOdd(
    __m256i input, __m256i p, __m256i inv,
    base::RepeatingCallback<__m256i(__m256i, __m256i, __m256i)> callback) {
  // First pass sees the even entries as they are; for the second pass the odd
  // entries are moved down into the even positions.
  __m256i even_result = callback.Run(input, p, inv);
  __m256i odd_result = callback.Run(movehdup_epi32(input), p, inv);
  // Both results live at odd indices. Shift the even-pass result and take the
  // odd lanes (mask 0b10101010) from the odd-pass result.
  __m256i merged = _mm256_blend_epi32(movehdup_epi32(even_result), odd_result,
                                      0b10101010);
  return _mm256_min_epu32(merged, _mm256_add_epi32(merged, p));
}
#pragma GCC diagnostic pop

// Computes the third power of the field elements in the even index entries.
// Inputs must be signed 32-bit integers in [-P, ..., P].
// Outputs will be a signed integer in (-P, ..., P) stored in the odd indices.
ALWAYS_INLINE __m256i DoExp3(__m256i input, __m256i p, __m256i inv) {
  // x³ = x² * x, where x² is already placed in the even entries.
  return MontyMulSigned(ShiftedSquare(input, p, inv), input, p, inv);
}

// Applies |DoExp3| to the even and the odd 32-bit entries of |input| through
// |ApplyFuncToEvenOdd| and returns the merged vector. |p| and |inv| are
// forwarded unchanged to |DoExp3|, so |input| must meet its requirements.
ALWAYS_INLINE __m256i Exp3(__m256i input, __m256i p, __m256i inv) {
return ApplyFuncToEvenOdd(input, p, inv, &DoExp3);
}

// Computes the fifth power of the field elements in the even index entries.
// Inputs must be signed 32-bit integers in [-P, ..., P].
// Outputs will be a signed integer in (-P, ..., P) stored in the odd indices.
ALWAYS_INLINE __m256i DoExp5(__m256i input, __m256i p, __m256i inv) {
  // x⁵ = (x²)² * x.
  const __m256i pow2 = ShiftedSquare(input, p, inv);
  const __m256i pow4 = ShiftedSquare(pow2, p, inv);
  return MontyMulSigned(pow4, input, p, inv);
}

// Applies |DoExp5| to the even and the odd 32-bit entries of |input| through
// |ApplyFuncToEvenOdd| and returns the merged vector. |p| and |inv| are
// forwarded unchanged to |DoExp5|, so |input| must meet its requirements.
ALWAYS_INLINE __m256i Exp5(__m256i input, __m256i p, __m256i inv) {
return ApplyFuncToEvenOdd(input, p, inv, &DoExp5);
}

// Computes the seventh power of the field elements in the even index entries.
// Inputs must lie in [-P, ..., P].
// Outputs will also lie in (-P, ..., P) stored in the odd indices.
ALWAYS_INLINE __m256i DoExp7(__m256i input, __m256i p, __m256i inv) {
  // x⁷ = x⁴ * x³, with x⁴ = (x²)² and x³ = x² * x.
  const __m256i pow2 = ShiftedSquare(input, p, inv);
  const __m256i pow4 = ShiftedSquare(pow2, p, inv);
  // |MontyMulSigned| leaves its result in the odd entries; move it into the
  // even entries so it can serve as an operand of the final product.
  const __m256i pow3 = movehdup_epi32(MontyMulSigned(pow2, input, p, inv));
  return MontyMulSigned(pow4, pow3, p, inv);
}

// Applies |DoExp7| to the even and the odd 32-bit entries of |input| through
// |ApplyFuncToEvenOdd| and returns the merged vector. |p| and |inv| are
// forwarded unchanged to |DoExp7|, so |input| must meet its requirements.
ALWAYS_INLINE __m256i Exp7(__m256i input, __m256i p, __m256i inv) {
return ApplyFuncToEvenOdd(input, p, inv, &DoExp7);
}

} // namespace tachyon::math

#endif // TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX2_H_
5 changes: 5 additions & 0 deletions tachyon/math/finite_fields/packed_prime_field32_avx512.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
// Copyright (c) 2022 The Plonky3 Authors
// Use of this source code is governed by an MIT/Apache-2.0 style license that
// can be found in the LICENSE-MIT.plonky3 and the LICENSE-APACHE.plonky3
// files.

#ifndef TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX512_H_
#define TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX512_H_

Expand Down
5 changes: 5 additions & 0 deletions tachyon/math/finite_fields/packed_prime_field32_neon.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
// Copyright (c) 2022 The Plonky3 Authors
// Use of this source code is governed by an MIT/Apache-2.0 style license that
// can be found in the LICENSE-MIT.plonky3 and the LICENSE-APACHE.plonky3
// files.

#ifndef TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_NEON_H_
#define TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_NEON_H_

Expand Down