Skip to content

Commit

Permalink
Merge pull request #555 from kroma-network/perf/optimize-create-openi…
Browse files Browse the repository at this point in the history
…ng-proof

perf(crypto): optimize `TwoAdicFri`
  • Loading branch information
chokobole authored Nov 4, 2024
2 parents 87116f9 + 2e0f7f0 commit d0ae6c9
Show file tree
Hide file tree
Showing 10 changed files with 224 additions and 21 deletions.
10 changes: 7 additions & 3 deletions tachyon/crypto/commitments/fri/two_adic_fri.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,13 @@ class TwoAdicFRI {
reduced_openings[log_num_rows];
CHECK_EQ(reduced_opening_for_log_num_rows.size(), num_rows);

math::RowMajorMatrix<F> block =
mat.topRows(num_rows >> config_.log_blowup);
ReverseMatrixIndexBits(block);
math::RowMajorMatrix<F> block;
block.resize(num_rows >> config_.log_blowup, mat.cols());
OMP_PARALLEL_FOR(size_t row = 0; row < num_rows >> config_.log_blowup;
++row) {
block.row(row) = mat.row(base::bits::ReverseBitsLen(
row, log_num_rows - config_.log_blowup));
}
std::vector<ExtF> reduced_rows = DotExtPowers(mat, alpha);

// TODO(ashjeong): Determine if using a matrix is a better fit.
Expand Down
45 changes: 34 additions & 11 deletions tachyon/math/base/semigroups.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,24 @@
std::declval<T>().Name##InPlace()))> \
: std::true_type {}

// Declares the detection trait |SupportsExp<Pow><T>|, which is
// |std::true_type| exactly when the expression |std::declval<T>().Exp<Pow>()|
// is well-formed and |std::false_type| otherwise.
//
// Like the other SUPPORTS_* macros in this file, the expansion deliberately
// does NOT end with a semicolon: the call site supplies it, e.g.
// |SUPPORTS_DEDICATED_EXP_OPERATOR(3);|. Ending the macro with its own ';'
// would leave a stray empty declaration at every use.
#define SUPPORTS_DEDICATED_EXP_OPERATOR(Pow)                               \
  template <typename T, typename = void>                                   \
  struct SupportsExp##Pow : std::false_type {};                            \
                                                                           \
  template <typename T>                                                    \
  struct SupportsExp##Pow<T, decltype(void(std::declval<T>().Exp##Pow()))> \
      : std::true_type {}

namespace tachyon::math {
namespace internal {

SUPPORTS_BINARY_OPERATOR(Mul);
SUPPORTS_UNARY_OPERATOR(SquareImpl);
SUPPORTS_BINARY_OPERATOR(Add);
SUPPORTS_UNARY_OPERATOR(DoubleImpl);
SUPPORTS_DEDICATED_EXP_OPERATOR(3);
SUPPORTS_DEDICATED_EXP_OPERATOR(5);
SUPPORTS_DEDICATED_EXP_OPERATOR(7);

template <typename T, typename = void>
struct SupportsSize : std::false_type {};
Expand Down Expand Up @@ -157,24 +168,36 @@ class MultiplicativeSemigroup {
return g;
else if constexpr (Power == 2)
return Square();
else if constexpr (Power == 3)
return Square() * g;
else if constexpr (Power == 4)
else if constexpr (Power == 3) {
if constexpr (internal::SupportsExp3<G>::value) {
return g.Exp3();
} else {
return Square() * g;
}
} else if constexpr (Power == 4) {
return Square().Square();
else if constexpr (Power == 5) {
MulResult g4 = Square();
g4.SquareInPlace();
return g4 * g;
} else if constexpr (Power == 5) {
if constexpr (internal::SupportsExp5<G>::value) {
return g.Exp5();
} else {
MulResult g4 = Square();
g4.SquareInPlace();
return g4 * g;
}
} else if constexpr (Power == 6) {
MulResult g2 = Square();
MulResult g4 = g2;
g4.SquareInPlace();
return g4 * g2;
} else if constexpr (Power == 7) {
MulResult g2 = Square();
MulResult g4 = g2;
g4.SquareInPlace();
return g4 * g2 * g;
if constexpr (internal::SupportsExp7<G>::value) {
return g.Exp7();
} else {
MulResult g2 = Square();
MulResult g4 = g2;
g4.SquareInPlace();
return g4 * g2 * g;
}
} else {
return DoPow(BigInt<1>(Power));
}
Expand Down
5 changes: 4 additions & 1 deletion tachyon/math/finite_fields/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ tachyon_cc_library(
tachyon_cc_library(
name = "packed_prime_field32_avx2",
hdrs = ["packed_prime_field32_avx2.h"],
deps = ["//tachyon/base:compiler_specific"],
deps = [
"//tachyon/base:compiler_specific",
"//tachyon/base/functional:callback",
],
)

tachyon_cc_library(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,16 @@ PackedBabyBearAVX2 PackedBabyBearAVX2::Mul(
return FromVector(math::Mul(ToVector(*this), ToVector(other)));
}

// Dedicated third-power routine: forwards the packed lanes to |math::Exp3|
// together with this field's |kP| and |kInv| vectors and repacks the result.
PackedBabyBearAVX2 PackedBabyBearAVX2::Exp3() const {
  const auto cubed = math::Exp3(ToVector(*this), kP, kInv);
  return FromVector(cubed);
}

// Dedicated fifth-power routine: forwards the packed lanes to |math::Exp5|
// together with this field's |kP| and |kInv| vectors and repacks the result.
PackedBabyBearAVX2 PackedBabyBearAVX2::Exp5() const {
  const auto fifth_power = math::Exp5(ToVector(*this), kP, kInv);
  return FromVector(fifth_power);
}

// Dedicated seventh-power routine: forwards the packed lanes to |math::Exp7|
// together with this field's |kP| and |kInv| vectors and repacks the result.
PackedBabyBearAVX2 PackedBabyBearAVX2::Exp7() const {
  const auto seventh_power = math::Exp7(ToVector(*this), kP, kInv);
  return FromVector(seventh_power);
}

} // namespace tachyon::math
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class TACHYON_EXPORT PackedBabyBearAVX2 final

// MultiplicativeSemigroup methods
PackedBabyBearAVX2 Mul(const PackedBabyBearAVX2& other) const;

PackedBabyBearAVX2 Exp3() const;
PackedBabyBearAVX2 Exp5() const;
PackedBabyBearAVX2 Exp7() const;
};

} // namespace tachyon::math
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,4 +107,16 @@ PackedKoalaBearAVX2 PackedKoalaBearAVX2::Mul(
return FromVector(math::Mul(ToVector(*this), ToVector(other)));
}

// Dedicated third-power routine: forwards the packed lanes to |math::Exp3|
// together with this field's |kP| and |kInv| vectors and repacks the result.
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Exp3() const {
  const auto cubed = math::Exp3(ToVector(*this), kP, kInv);
  return FromVector(cubed);
}

// Dedicated fifth-power routine: forwards the packed lanes to |math::Exp5|
// together with this field's |kP| and |kInv| vectors and repacks the result.
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Exp5() const {
  const auto fifth_power = math::Exp5(ToVector(*this), kP, kInv);
  return FromVector(fifth_power);
}

// Dedicated seventh-power routine: forwards the packed lanes to |math::Exp7|
// together with this field's |kP| and |kInv| vectors and repacks the result.
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Exp7() const {
  const auto seventh_power = math::Exp7(ToVector(*this), kP, kInv);
  return FromVector(seventh_power);
}

} // namespace tachyon::math
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ class TACHYON_EXPORT PackedKoalaBearAVX2 final

// MultiplicativeSemigroup methods
PackedKoalaBearAVX2 Mul(const PackedKoalaBearAVX2& other) const;

PackedKoalaBearAVX2 Exp3() const;
PackedKoalaBearAVX2 Exp5() const;
PackedKoalaBearAVX2 Exp7() const;
};

} // namespace tachyon::math
Expand Down
143 changes: 137 additions & 6 deletions tachyon/math/finite_fields/packed_prime_field32_avx2.h
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
// Copyright (c) 2022 The Plonky3 Authors
// Use of this source code is governed by a MIT/Apache-2.0 style license that
// can be found in the LICENSE-MIT.plonky3 and the LICENCE-APACHE.plonky3
// file.

#ifndef TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX2_H_
#define TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX2_H_

#include <immintrin.h>

#include "tachyon/base/compiler_specific.h"
#include "tachyon/base/functional/callback.h"

namespace tachyon::math {

Expand Down Expand Up @@ -166,11 +172,59 @@ ALWAYS_INLINE __m256i NegateMod32(__m256i val, __m256i p) {
//
// [1] Modern Computer Arithmetic, Richard Brent and Paul Zimmermann,
// Cambridge University Press, 2010, algorithm 2.7.
ALWAYS_INLINE __m256i MontyD(__m256i lhs, __m256i rhs, __m256i p, __m256i inv) {
__m256i prod = _mm256_mul_epu32(lhs, rhs);
__m256i q = _mm256_mul_epu32(prod, inv);

// We provide 2 variants of Montgomery reduction depending on if the inputs are
// unsigned or signed. The unsigned variant follows steps 1 and 2 in the above
// protocol to produce D in (-P, ..., P). For the signed variant we assume -PB/2
// < C < PB/2 and let Q := μ C mod B be the unique representative in [-B/2, ...,
// B/2 - 1]. The division in step 2 is clearly still exact and |C - Q P| <= |C|
// + |Q||P| < PB so D still lies in (-P, ..., P).

// Perform a partial Montgomery reduction on each 64 bit element, treating the
// operands as unsigned.
// Input must lie in {0, ..., 2³²P}.
// The output will lie in {-P, ..., P} and be stored in the upper 32 bits of
// each 64 bit element.
//
// |p| and |inv| are the modulus and the Montgomery constant supplied by the
// caller; only the low 32 bits of each 64 bit element of them are read by the
// multiplications below.
ALWAYS_INLINE __m256i PartialMontyRedUnsignedToSigned(__m256i input, __m256i p,
                                                      __m256i inv) {
  // |_mm256_mul_epu32| multiplies the low unsigned 32 bits of each 64 bit
  // element, so only the low 32 bits of |q| matter for the next product.
  __m256i q = _mm256_mul_epu32(input, inv);
  __m256i q_p = _mm256_mul_epu32(q, p);
  // By construction, the bottom 32 bits of input and q_p are equal.
  // Thus |_mm256_sub_epi32| and |_mm256_sub_epi64| should act identically.
  // However for some reason, the compiler gets confused if we use
  // |_mm256_sub_epi64| and outputs a load of nonsense, see:
  // https://godbolt.org/z/3W8M7Tv84.
  return _mm256_sub_epi32(input, q_p);
}
// Signed counterpart of the partial Montgomery reduction, applied to each
// 64 bit element.
// Input must lie in {-2³¹P, ..., 2³¹P}.
// The result lies in {-P, ..., P} and is left in the upper 32 bits of each
// 64 bit element.
ALWAYS_INLINE __m256i PartialMontyRedSignedToSigned(__m256i input, __m256i p,
                                                    __m256i inv) {
  // Both products use the signed 32x32 -> 64 bit multiply.
  const __m256i quotient = _mm256_mul_epi32(input, inv);
  const __m256i quotient_times_p = _mm256_mul_epi32(quotient, p);
  // For the signed case |_mm256_sub_epi32| and |_mm256_sub_epi64| compile to
  // essentially the same code; the 32 bit form is kept purely to mirror the
  // unsigned variant.
  return _mm256_sub_epi32(input, quotient_times_p);
}

// Multiply the field elements in the even index entries.
// |lhs[2i]|, |rhs[2i]| must be unsigned 32-bit integers such that
// |lhs[2i]| * |rhs[2i]| lies in {0, ..., 2³²P}.
// The output will lie in {-P, ..., P} and be stored in |output[2i + 1]|.
ALWAYS_INLINE __m256i MontyMul(__m256i lhs, __m256i rhs, __m256i p,
                               __m256i inv) {
  __m256i prod = _mm256_mul_epu32(lhs, rhs);
  // |prod| is an unsigned 64 bit product, so it must be reduced with the
  // unsigned variant. The signed variant would reinterpret the 32 bit quotient
  // as negative whenever its top bit is set, which shifts the result by +P and
  // pushes it outside the documented {-P, ..., P} range; the final
  // min(t, t + P) fix-up performed by callers would then return a
  // non-canonical value.
  return PartialMontyRedUnsignedToSigned(prod, p, inv);
}

// Signed multiplication of the field elements held in the even index entries.
// |lhs[2i]| and |rhs[2i]| are read as signed 32-bit integers and their product
// must lie in {-2³¹P, ..., 2³¹P}.
// The result lies in {-P, ..., P} and ends up in |output[2i + 1]|.
ALWAYS_INLINE __m256i MontyMulSigned(__m256i lhs, __m256i rhs, __m256i p,
                                     __m256i inv) {
  return PartialMontyRedSignedToSigned(_mm256_mul_epi32(lhs, rhs), p, inv);
}

ALWAYS_INLINE __m256i movehdup_epi32(__m256i x) {
Expand Down Expand Up @@ -205,8 +259,8 @@ ALWAYS_INLINE __m256i MontMulMod32(__m256i lhs, __m256i rhs, __m256i p,
__m256i lhs_odd = movehdup_epi32(lhs);
__m256i rhs_odd = movehdup_epi32(rhs);

__m256i d_evn = MontyD(lhs_evn, rhs_evn, p, inv);
__m256i d_odd = MontyD(lhs_odd, rhs_odd, p, inv);
__m256i d_evn = MontyMul(lhs_evn, rhs_evn, p, inv);
__m256i d_odd = MontyMul(lhs_odd, rhs_odd, p, inv);

__m256i d_evn_hi = movehdup_epi32(d_evn);
__m256i t = _mm256_blend_epi32(d_evn_hi, d_odd, 0b10101010);
Expand All @@ -215,6 +269,83 @@ ALWAYS_INLINE __m256i MontMulMod32(__m256i lhs, __m256i rhs, __m256i p,
return _mm256_min_epu32(t, u);
}

// Squares the field elements held in the even index entries.
// Inputs are read as signed 32-bit integers.
// Each result is a signed integer in (-P, ..., P) and is duplicated into both
// the even and the odd index of its 64 bit element.
ALWAYS_INLINE __m256i ShiftedSquare(__m256i input, __m256i p, __m256i inv) {
  // No extra bound on |input[i]²| is required: since 2³⁰ < P and any i32 has
  // magnitude at most 2³¹, we get |input[i]²| <= 2⁶² < 2³²P.
  const __m256i reduced =
      PartialMontyRedSignedToSigned(_mm256_mul_epi32(input, input), p, inv);
  // The reduction leaves its result in the odd indices; copy it down so the
  // even indices hold it as well.
  return movehdup_epi32(reduced);
}

// Apply callback to the even and odd indices of the input vector.
// callback should only depend on the 32 bit entries in the even indices.
// The output of callback must lie in (-P, ..., P) and be stored in the odd
// indices. The even indices of the output of callback will not be read. The
// input should conform to the requirements of |callback|.
// NOTE(chokobole): This is to suppress the error below.
// clang-format off
// error: ignoring attributes on template argument '__m256i(__m256i, __m256i, __m256i)' [-Werror=ignored-attributes]
// clang-format on
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
ALWAYS_INLINE __m256i ApplyFuncToEvenOdd(
    __m256i input, __m256i p, __m256i inv,
    base::RepeatingCallback<__m256i(__m256i, __m256i, __m256i)> callback) {
  // First pass: the even-indexed entries are already where |callback| reads
  // them.
  const __m256i result_from_even = callback.Run(input, p, inv);
  // Second pass: move the odd-indexed entries into the even positions so the
  // same |callback| can process them.
  const __m256i result_from_odd = callback.Run(movehdup_epi32(input), p, inv);

  // Each pass leaves its result in the odd indices. Shift the first result
  // down into the even indices, then take the odd indices (mask 0b10101010)
  // from the second result.
  const __m256i merged = _mm256_blend_epi32(movehdup_epi32(result_from_even),
                                            result_from_odd, 0b10101010);
  // Entries lie in (-P, ..., P); the unsigned minimum of x and x + P selects
  // the non-negative representative.
  const __m256i merged_plus_p = _mm256_add_epi32(merged, p);
  return _mm256_min_epu32(merged, merged_plus_p);
}
#pragma GCC diagnostic pop

// Computes the third power of the field elements in the even index entries.
// Inputs must be signed 32-bit integers in [-P, ..., P].
// Each result is a signed integer in (-P, ..., P) stored in the odd indices.
ALWAYS_INLINE __m256i DoExp3(__m256i input, __m256i p, __m256i inv) {
  // x³ = x² * x; |ShiftedSquare| leaves x² in the even indices, which is where
  // |MontyMulSigned| reads its operands.
  return MontyMulSigned(ShiftedSquare(input, p, inv), input, p, inv);
}

// Runs |DoExp3| over both the even and the odd index entries of |input| via
// |ApplyFuncToEvenOdd| and returns the combined vector.
ALWAYS_INLINE __m256i Exp3(__m256i input, __m256i p, __m256i inv) {
  const __m256i cubed = ApplyFuncToEvenOdd(input, p, inv, &DoExp3);
  return cubed;
}

// Computes the fifth power of the field elements in the even index entries.
// Inputs must be signed 32-bit integers in [-P, ..., P].
// Each result is a signed integer in (-P, ..., P) stored in the odd indices.
ALWAYS_INLINE __m256i DoExp5(__m256i input, __m256i p, __m256i inv) {
  // x⁵ = (x²)² * x.
  const __m256i pow2 = ShiftedSquare(input, p, inv);
  const __m256i pow4 = ShiftedSquare(pow2, p, inv);
  return MontyMulSigned(pow4, input, p, inv);
}

// Runs |DoExp5| over both the even and the odd index entries of |input| via
// |ApplyFuncToEvenOdd| and returns the combined vector.
ALWAYS_INLINE __m256i Exp5(__m256i input, __m256i p, __m256i inv) {
  const __m256i fifth_power = ApplyFuncToEvenOdd(input, p, inv, &DoExp5);
  return fifth_power;
}

// Computes the seventh power of the field elements in the even index entries.
// Inputs must lie in [-P, ..., P].
// Each result lies in (-P, ..., P) and is stored in the odd indices.
ALWAYS_INLINE __m256i DoExp7(__m256i input, __m256i p, __m256i inv) {
  // x⁷ = x⁴ * x³, with x⁴ = (x²)² and x³ = x² * x.
  const __m256i pow2 = ShiftedSquare(input, p, inv);
  const __m256i pow4 = ShiftedSquare(pow2, p, inv);
  // |MontyMulSigned| leaves x³ in the odd indices; copy it down to the even
  // indices so it can serve as an operand of the final multiplication.
  const __m256i pow3 = movehdup_epi32(MontyMulSigned(pow2, input, p, inv));
  return MontyMulSigned(pow4, pow3, p, inv);
}

// Runs |DoExp7| over both the even and the odd index entries of |input| via
// |ApplyFuncToEvenOdd| and returns the combined vector.
ALWAYS_INLINE __m256i Exp7(__m256i input, __m256i p, __m256i inv) {
  const __m256i seventh_power = ApplyFuncToEvenOdd(input, p, inv, &DoExp7);
  return seventh_power;
}

} // namespace tachyon::math

#endif // TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX2_H_
5 changes: 5 additions & 0 deletions tachyon/math/finite_fields/packed_prime_field32_avx512.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
// Copyright (c) 2022 The Plonky3 Authors
// Use of this source code is governed by a MIT/Apache-2.0 style license that
// can be found in the LICENSE-MIT.plonky3 and the LICENCE-APACHE.plonky3
// file.

#ifndef TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX512_H_
#define TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_AVX512_H_

Expand Down
5 changes: 5 additions & 0 deletions tachyon/math/finite_fields/packed_prime_field32_neon.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
// Copyright (c) 2022 The Plonky3 Authors
// Use of this source code is governed by a MIT/Apache-2.0 style license that
// can be found in the LICENSE-MIT.plonky3 and the LICENCE-APACHE.plonky3
// file.

#ifndef TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_NEON_H_
#define TACHYON_MATH_FINITE_FIELDS_PACKED_PRIME_FIELD32_NEON_H_

Expand Down

0 comments on commit d0ae6c9

Please sign in to comment.