Merge pull request #517 from kroma-network/perf/optimize-fft-batch
perf: optimize fft batch
chokobole authored Aug 13, 2024
2 parents 7dcfd00 + ff014d7 commit 9a6ffb5
Showing 19 changed files with 98 additions and 175 deletions.
2 changes: 1 addition & 1 deletion tachyon/crypto/commitments/fri/two_adic_fri_config.h
@@ -63,7 +63,7 @@ std::vector<ExtF> FoldMatrix(const ExtF& beta,

// NOTE(ashjeong): |arity| is subject to change in the future
template <typename ExtF>
-ExtF FoldRow(size_t index, size_t log_num_rows, const ExtF& beta,
+ExtF FoldRow(size_t index, uint32_t log_num_rows, const ExtF& beta,
const std::vector<ExtF>& evals) {
using F = typename math::ExtensionFieldTraits<ExtF>::BaseField;
const size_t kArity = 2;
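As a reminder of what the fixed arity of 2 means here, a textbook sketch of the 2-to-1 FRI fold (FoldRow's exact evaluation-point indexing may differ): writing the polynomial as p(x) = p_e(x^2) + x * p_o(x^2), a pair of sibling evaluations folds under the challenge beta as

$$q(x^2) = \frac{p(x) + p(-x)}{2} + \beta \cdot \frac{p(x) - p(-x)}{2x} = p_e(x^2) + \beta\,p_o(x^2).$$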
8 changes: 4 additions & 4 deletions tachyon/crypto/commitments/fri/two_adic_fri_pcs.h
@@ -204,7 +204,7 @@ class TwoAdicFriPCS {
// Batch combination challenge
const ExtF alpha = challenger.template SampleExtElement<ExtF>();
VLOG(2) << "FRI(alpha): " << alpha.ToHexString(true);
-size_t log_global_max_num_rows =
+uint32_t log_global_max_num_rows =
proof.commit_phase_commits.size() + fri_.log_blowup;
return TwoAdicFriPCSVerify(
fri_, proof, challenger,
@@ -287,7 +287,7 @@ class TwoAdicFriPCS {
absl::flat_hash_map<ExtF, std::vector<ExtF>> ComputeInverseDenominators(
const std::vector<absl::Span<const math::RowMajorMatrix<F>>>&
matrices_by_round,
-const std::vector<Points>& points_by_round, const F& coset_shift) {
+const std::vector<Points>& points_by_round, F coset_shift) {
size_t num_rounds = matrices_by_round.size();

absl::flat_hash_map<ExtF, uint32_t> max_log_num_rows_for_point;
@@ -327,7 +327,7 @@ class TwoAdicFriPCS {
uint32_t log_num_rows = it->second;
std::vector<ExtF> temp = base::Map(
absl::MakeSpan(subgroup.data(), (size_t{1} << log_num_rows)),
-[&point](const F& x) { return ExtF(x) - point; });
+[&point](F x) { return ExtF(x) - point; });
CHECK(ExtF::BatchInverseInPlace(temp));
ret[point] = std::move(temp);
}
@@ -338,7 +338,7 @@
// https://hackmd.io/@vbuterin/barycentric_evaluation
template <typename Derived>
static std::vector<ExtF> InterpolateCoset(
-const Eigen::MatrixBase<Derived>& coset_evals, const F& shift,
+const Eigen::MatrixBase<Derived>& coset_evals, F shift,
const ExtF& point) {
size_t num_rows = static_cast<size_t>(coset_evals.rows());
size_t num_cols = static_cast<size_t>(coset_evals.cols());
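InterpolateCoset above follows the barycentric evaluation approach from the hackmd note linked in the hunk. As a hedged reminder of the formula it is based on (assuming the domain is the shifted two-adic subgroup {s·ω^i} of size n, with ω a primitive n-th root of unity; the code's exact factorization of the terms may differ):

$$p(z) = \frac{z^n - s^n}{n\,s^{n-1}} \sum_{i=0}^{n-1} \frac{y_i\,\omega^i}{z - s\,\omega^i}, \qquad y_i = p(s\,\omega^i).$$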
@@ -105,15 +105,15 @@ class FieldMerkleTreeMMCS final
std::vector<std::vector<F>>* openings,
Proof* proof) const {
size_t max_row_size = this->GetMaxRowSize(prover_data);
-size_t log_max_row_size = base::bits::Log2Ceiling(max_row_size);
+uint32_t log_max_row_size = base::bits::Log2Ceiling(max_row_size);

// TODO(chokobole): Is it able to be parallelized?
*openings = base::Map(
prover_data.leaves(),
[log_max_row_size, index](const math::RowMajorMatrix<F>& matrix) {
-size_t log_row_size =
+uint32_t log_row_size =
base::bits::Log2Ceiling(static_cast<size_t>(matrix.rows()));
-size_t bits_reduced = log_max_row_size - log_row_size;
+uint32_t bits_reduced = log_max_row_size - log_row_size;
size_t reduced_index = index >> bits_reduced;
return base::CreateVector(matrix.cols(),
[reduced_index, &matrix](size_t col) {
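The bits_reduced arithmetic in this hunk maps a query index chosen for the tallest matrix onto the corresponding row of a shorter one. A hedged worked example with made-up sizes: if the tallest matrix has 2^10 rows and another has 2^7, then

$$\text{bits\_reduced} = 10 - 7 = 3, \qquad \text{reduced\_index} = 731 \gg 3 = 91,$$

so query index 731 opens row 91 of the shorter matrix.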
4 changes: 2 additions & 2 deletions tachyon/math/base/semigroups.h
@@ -224,7 +224,7 @@ class MultiplicativeSemigroup {
constexpr static std::vector<MulResult> GetBitRevIndexSuccessivePowers(
size_t size, const G& generator, const G& c = G::One()) {
std::vector<MulResult> ret(size);
-size_t log_size = base::bits::CheckedLog2(size);
+uint32_t log_size = base::bits::CheckedLog2(size);
base::Parallelize(
ret, [log_size, &generator, &c, &ret](
absl::Span<G> chunk, size_t chunk_offset, size_t chunk_size) {
@@ -248,7 +248,7 @@
constexpr static std::vector<MulResult> GetBitRevIndexSuccessivePowersSerial(
size_t size, const G& generator, const G& c = G::One()) {
std::vector<MulResult> ret(size);
-size_t log_size = base::bits::CheckedLog2(size);
+uint32_t log_size = base::bits::CheckedLog2(size);
MulResult pow = c.IsOne() ? G::One() : c;
for (size_t idx = 0; idx < size - 1; ++idx) {
ret[base::bits::ReverseBitsLen(idx, log_size)] = pow;
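Both GetBitRevIndexSuccessivePowers overloads above scatter the successive powers c, c·g, c·g^2, … into bit-reversed slots. A minimal standalone sketch of that indexing, with uint64_t standing in for the group element and a hypothetical ReverseBitsLen mirroring base::bits::ReverseBitsLen (illustration only, not the library code):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reverses the low |bit_len| bits of |x| (stand-in for base::bits::ReverseBitsLen).
uint32_t ReverseBitsLen(uint32_t x, uint32_t bit_len) {
  uint32_t r = 0;
  for (uint32_t i = 0; i < bit_len; ++i) {
    r = (r << 1) | ((x >> i) & 1);
  }
  return r;
}

// Returns a vector v with v[ReverseBitsLen(i, log_size)] = c * g^i for
// i in [0, 2^log_size), roughly mirroring the serial overload above
// (group multiplication replaced by plain integer multiplication).
std::vector<uint64_t> BitRevPowers(uint64_t g, uint64_t c, uint32_t log_size) {
  size_t size = size_t{1} << log_size;
  std::vector<uint64_t> out(size);
  uint64_t pow = c;
  for (size_t i = 0; i < size; ++i) {
    out[ReverseBitsLen(static_cast<uint32_t>(i), log_size)] = pow;
    pow *= g;
  }
  return out;
}
```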
@@ -16,7 +16,7 @@ size_t DetermineMsmDivisionsForMemory(size_t scalar_t_mem_size,
size_t free_memory =
device::gpu::GpuMemLimitInfo(device::gpu::MemoryUsage::kHigh);
size_t shift = 0;
-size_t log_msm_size = base::bits::Log2Ceiling(msm_size);
+uint32_t log_msm_size = base::bits::Log2Ceiling(msm_size);

for (size_t number_of_divisions = 0; number_of_divisions < log_msm_size;
++number_of_divisions) {
4 changes: 2 additions & 2 deletions tachyon/math/finite_fields/prime_field_fallback.h
@@ -168,11 +168,11 @@ class PrimeField<_Config, std::enable_if_t<!_Config::kUseAsm &&
constexpr const uint64_t& operator[](size_t i) const { return value_[i]; }

constexpr bool operator==(const PrimeField& other) const {
-return ToBigInt() == other.ToBigInt();
+return value_ == other.value_;
}

constexpr bool operator!=(const PrimeField& other) const {
-return ToBigInt() != other.ToBigInt();
+return value_ != other.value_;
}

constexpr bool operator<(const PrimeField& other) const {
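The operator== / operator!= simplification here (and the identical one in prime_field_gpu_debug.h below) relies on the stored representation being canonical: each field element has exactly one reduced internal form (for a Montgomery representation, x ↦ x·R mod p is a bijection), so equal elements have identical limbs and the round-trip through ToBigInt() adds nothing. A minimal sketch of the idea with hypothetical types:

```cpp
#include <array>
#include <cstdint>

// Hypothetical element stored in a canonical internal form (e.g. reduced
// Montgomery form). Because the stored form is unique per field element,
// limb-wise comparison decides equality without converting back first.
struct Element {
  std::array<uint64_t, 4> value;

  bool operator==(const Element& other) const { return value == other.value; }
  bool operator!=(const Element& other) const { return value != other.value; }
};
```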
4 changes: 2 additions & 2 deletions tachyon/math/finite_fields/prime_field_gpu_debug.h
@@ -157,11 +157,11 @@ class PrimeFieldGpuDebug final
constexpr const uint64_t& operator[](size_t i) const { return value_[i]; }

constexpr bool operator==(const PrimeFieldGpuDebug& other) const {
-return ToBigInt() == other.ToBigInt();
+return value_ == other.value_;
}

constexpr bool operator!=(const PrimeFieldGpuDebug& other) const {
-return ToBigInt() != other.ToBigInt();
+return value_ != other.value_;
}

constexpr bool operator<(const PrimeFieldGpuDebug& other) const {
25 changes: 0 additions & 25 deletions tachyon/math/matrix/matrix_utils.h
@@ -52,31 +52,6 @@ MakeCirculant(const Eigen::MatrixBase<ArgType>& arg) {
CirculantFunctor<ArgType>(arg.derived()));
}

-// Packs a given row of a matrix. Results in a vector of packed fields and a
-// vector of remaining values if the number of cols is not a factor of the
-// packed field size.
-//
-// NOTE(ashjeong): |PackRowHorizontally| currently only
-// supports row-major matrices.
-template <typename PackedField, typename PrimeField, typename Expr,
-          int BlockRows, int BlockCols, bool InnerPanel>
-std::vector<PackedField*> PackRowHorizontally(
-    Eigen::Block<Expr, BlockRows, BlockCols, InnerPanel>& matrix_row,
-    std::vector<PrimeField*>& remaining_values) {
-  size_t num_packed = matrix_row.cols() / PackedField::N;
-  size_t remaining_start_idx = num_packed * PackedField::N;
-  remaining_values =
-      base::CreateVector(matrix_row.cols() - remaining_start_idx,
-                         [remaining_start_idx, &matrix_row](size_t col) {
-                           return reinterpret_cast<PrimeField*>(
-                               matrix_row.data() + remaining_start_idx + col);
-                         });
-  return base::CreateVector(num_packed, [&matrix_row](size_t col) {
-    return reinterpret_cast<PackedField*>(matrix_row.data() +
-                                          PackedField::N * col);
-  });
-}

// Creates a vector of packed fields for a given matrix row. If the length
// of the row is not a multiple of |PackedField::N|, the last |PackedField|
// element populates leftover values with |F::Zero()|.
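The packing helper that remains in this file (described by the comment just above) amounts to chunking a row into groups of PackedField::N and zero-padding the tail. A standalone sketch with plain integers standing in for field elements (hypothetical names, not the tachyon API):

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

constexpr size_t kN = 8;  // Hypothetical stand-in for PackedField::N.

// Packs |row| into chunks of kN values; when row.size() is not a multiple of
// kN, the last chunk is padded with zeros, mirroring the "populate leftover
// values with |F::Zero()|" behaviour described above.
std::vector<std::array<uint64_t, kN>> PackRowWithPadding(
    const std::vector<uint64_t>& row) {
  size_t num_chunks = (row.size() + kN - 1) / kN;
  std::vector<std::array<uint64_t, kN>> packed(num_chunks);
  for (size_t chunk = 0; chunk < num_chunks; ++chunk) {
    for (size_t j = 0; j < kN; ++j) {
      size_t idx = chunk * kN + j;
      packed[chunk][j] = idx < row.size() ? row[idx] : 0;  // 0 ~ F::Zero()
    }
  }
  return packed;
}
```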
39 changes: 0 additions & 39 deletions tachyon/math/matrix/matrix_utils_unittest.cc
@@ -21,45 +21,6 @@ TEST_F(MatrixUtilsTest, Circulant) {

class MatrixPackingTest : public FiniteFieldTest<PackedBabyBear> {};

-TEST_F(MatrixPackingTest, PackRowHorizontally) {
-  constexpr size_t N = PackedBabyBear::N;
-  constexpr size_t R = 3;
-
-  {
-    RowMajorMatrix<BabyBear> matrix =
-        RowMajorMatrix<BabyBear>::Random(2 * N, 2 * N);
-    auto mat_row = matrix.row(R);
-    std::vector<BabyBear*> remaining_values;
-    std::vector<PackedBabyBear*> packed_values =
-        PackRowHorizontally<PackedBabyBear>(mat_row, remaining_values);
-    ASSERT_TRUE(remaining_values.empty());
-    ASSERT_EQ(packed_values.size(), 2);
-    for (size_t i = 0; i < packed_values.size(); ++i) {
-      for (size_t j = 0; j < N; ++j) {
-        EXPECT_EQ((*packed_values[i])[j], matrix(R, i * N + j));
-      }
-    }
-  }
-  {
-    RowMajorMatrix<BabyBear> matrix =
-        RowMajorMatrix<BabyBear>::Random(2 * N - 1, 2 * N - 1);
-    auto mat_row = matrix.row(R);
-    std::vector<BabyBear*> remaining_values;
-    std::vector<PackedBabyBear*> packed_values =
-        PackRowHorizontally<PackedBabyBear>(mat_row, remaining_values);
-    ASSERT_EQ(remaining_values.size(), N - 1);
-    ASSERT_EQ(packed_values.size(), 1);
-    for (size_t i = 0; i < remaining_values.size(); ++i) {
-      EXPECT_EQ(*remaining_values[i], matrix(R, packed_values.size() * N + i));
-    }
-    for (size_t i = 0; i < packed_values.size(); ++i) {
-      for (size_t j = 0; j < N; ++j) {
-        EXPECT_EQ((*packed_values[i])[j], matrix(R, i * N + j));
-      }
-    }
-  }
-}

TEST_F(MatrixPackingTest, PackRowVerticallyWithPrimeField) {
constexpr size_t N = PackedBabyBear::N;
constexpr size_t R = 3;
2 changes: 1 addition & 1 deletion tachyon/math/polynomials/univariate/evaluations_utils.h
@@ -29,7 +29,7 @@ std::vector<F> SwapBitRevElements(const std::vector<F>& vals) {
// element at index 4(100) are swapped.
template <typename Container>
void SwapBitRevElementsInPlace(Container& container, size_t size,
-size_t log_len) {
+uint32_t log_len) {
TRACE_EVENT("Utils", "SwapBitRevElementsInPlace");
if (size <= 1) return;
OMP_PARALLEL_FOR(size_t idx = 1; idx < size; ++idx) {
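SwapBitRevElementsInPlace permutes a container by swapping each element with the one at its bit-reversed index; guarding the swap with idx < ridx avoids undoing a swap or swapping an element with itself. A minimal sequential sketch (the real code parallelizes the loop and uses base::bits::ReverseBitsLen; the helper below repeats the earlier sketch so this snippet stands alone):

```cpp
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

// Reverses the low |log_len| bits of |x| (stand-in for base::bits::ReverseBitsLen).
uint32_t ReverseBitsLen(uint32_t x, uint32_t log_len) {
  uint32_t r = 0;
  for (uint32_t i = 0; i < log_len; ++i) {
    r = (r << 1) | ((x >> i) & 1);
  }
  return r;
}

// In-place bit-reversal permutation of the first |size| elements, where
// |size| is assumed to be 1 << log_len.
template <typename T>
void SwapBitRevInPlace(std::vector<T>& vals, size_t size, uint32_t log_len) {
  if (size <= 1) return;
  for (size_t idx = 1; idx < size; ++idx) {
    size_t ridx = ReverseBitsLen(static_cast<uint32_t>(idx), log_len);
    if (idx < ridx) std::swap(vals[idx], vals[ridx]);
  }
}
```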