Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: optimize proof generation #509

Merged
merged 11 commits into from
Aug 9, 2024
4 changes: 2 additions & 2 deletions tachyon/base/parallelize.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ void Parallelize(Container& container, Callable callback,
template <typename Callable>
void Parallelize(size_t size, Callable callback,
std::optional<size_t> threshold = std::nullopt) {
size_t num_elements_per_thread = GetNumElementsPerThread(size, threshold);
ParallelizeByChunkSize(size, num_elements_per_thread, std::move(callback));
size_t size_per_thread = GetSizePerThread(size, threshold);
ParallelizeByChunkSize(size, size_per_thread, std::move(callback));
}

// Splits the |container| by |chunk_size| and maps each chunk using the provided
Expand Down
7 changes: 4 additions & 3 deletions tachyon/crypto/commitments/fri/two_adic_fri_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ std::vector<ExtF> FoldMatrix(const ExtF& beta,
// + (1/2 - β/2gᵢₙᵥⁱ)p(gⁿᐟ²⁺ⁱ)
size_t rows = static_cast<size_t>(mat.rows());
F w;
CHECK(F::GetRootOfUnity(1 << (base::bits::CheckedLog2(rows) + 1), &w));
CHECK(
F::GetRootOfUnity(size_t{1} << (base::bits::CheckedLog2(rows) + 1), &w));
ExtF w_inv = ExtF(unwrap(w.Inverse()));
// TODO(ashjeong): Implement a field function |TwoInv()|, since the inverse of
// 2 is computed often.
Expand Down Expand Up @@ -73,11 +74,11 @@ ExtF FoldRow(size_t index, size_t log_num_rows, const ExtF& beta,
const ExtF& e1 = evals[1];

F w;
CHECK(F::GetRootOfUnity(1 << (log_num_rows + kLogArity), &w));
CHECK(F::GetRootOfUnity(size_t{1} << (log_num_rows + kLogArity), &w));
ExtF subgroup_start =
ExtF(w.Pow(base::bits::ReverseBitsLen(index, log_num_rows)));

CHECK(F::GetRootOfUnity(1 << kLogArity, &w));
CHECK(F::GetRootOfUnity(size_t{1} << kLogArity, &w));
std::vector<ExtF> xs = ExtF::GetBitRevIndexSuccessivePowersSerial(
kArity, ExtF(w), subgroup_start);
// interpolate and evaluate at beta
Expand Down
12 changes: 6 additions & 6 deletions tachyon/crypto/commitments/fri/two_adic_fri_pcs.h
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ class TwoAdicFriPCS {
uint32_t rev_reduced_index = base::bits::ReverseBitsLen(
index >> bits_reduced, log_num_rows);
F w;
CHECK(F::GetRootOfUnity(1 << log_num_rows, &w));
CHECK(F::GetRootOfUnity(size_t{1} << log_num_rows, &w));
F x = F::FromMontgomery(F::Config::kSubgroupGenerator) *
w.Pow(rev_reduced_index);

Expand Down Expand Up @@ -316,18 +316,18 @@ class TwoAdicFriPCS {

// Compute the largest subgroup we will use, in bitrev order.
F w;
CHECK(F::GetRootOfUnity(1 << max_log_num_rows, &w));
CHECK(F::GetRootOfUnity(size_t{1} << max_log_num_rows, &w));
std::vector<F> subgroup = F::GetBitRevIndexSuccessivePowers(
1 << max_log_num_rows, w, coset_shift);
size_t{1} << max_log_num_rows, w, coset_shift);

absl::flat_hash_map<ExtF, std::vector<ExtF>> ret;
for (auto it = max_log_num_rows_for_point.begin();
it != max_log_num_rows_for_point.end(); ++it) {
const ExtF& point = it->first;
uint32_t log_num_rows = it->second;
std::vector<ExtF> temp =
base::Map(absl::MakeSpan(subgroup.data(), (1 << log_num_rows)),
[&point](const F& x) { return ExtF(x) - point; });
std::vector<ExtF> temp = base::Map(
absl::MakeSpan(subgroup.data(), (size_t{1} << log_num_rows)),
[&point](const F& x) { return ExtF(x) - point; });
CHECK(ExtF::BatchInverseInPlace(temp));
ret[point] = std::move(temp);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class TwoAdicFriPCSTest : public testing::Test {
std::vector<Domain> inner_domains(log_degrees.size());
std::vector<math::RowMajorMatrix<F>> inner_polys(log_degrees.size());
for (size_t i = 0; i < log_degrees.size(); ++i) {
size_t d = 1 << log_degrees[i];
size_t d = size_t{1} << log_degrees[i];
// TODO(ashjeong): make the latter number randomized from 0-10
size_t cols = 5;
inner_domains[i] = pcs_.GetNaturalDomainForDegree(d);
Expand Down
5 changes: 3 additions & 2 deletions tachyon/crypto/commitments/fri/two_adic_fri_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ F VerifyQuery(uint32_t index, uint32_t log_max_num_rows,
std::vector<std::vector<F>> evals = {{folded_eval, folded_eval}};
evals[0][index_sibling % 2] = steps[step_idx].opening.sibling_value;
CHECK(config.mmcs.VerifyOpeningProof(
steps[step_idx].commit, {math::Dimensions(2, 1 << log_folded_num_rows)},
index_pair, evals, steps[step_idx].opening.opening_proof));
steps[step_idx].commit,
{math::Dimensions(2, size_t{1} << log_folded_num_rows)}, index_pair,
evals, steps[step_idx].opening.opening_proof));
folded_eval = FoldRow(index_pair, log_folded_num_rows, steps[step_idx].beta,
evals[0]);
index = index_pair;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ class BinaryMerkleTree final
}
size_t i = base::bits::Log2Floor(leaves_size) -
base::bits::Log2Floor(leaves_size_for_parallelization_);
BuildTreeFromLeaves(
base::Range<size_t>((1 << i) - 1, (1 << (i + 1)) - 1));
BuildTreeFromLeaves(base::Range<size_t>((size_t{1} << i) - 1,
(size_t{1} << (i + 1)) - 1));
} else {
BuildTreeFromLeaves(
base::Range<size_t>(leaves_size - 1, (leaves_size << 1) - 1));
Expand Down
3 changes: 2 additions & 1 deletion tachyon/crypto/hashes/sponge/poseidon2/poseidon2_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ struct Poseidon2Config : public PoseidonConfigBase<F> {
ret.internal_diagonal_minus_one = math::Vector<F>(N + 1);
ret.internal_diagonal_minus_one[0] = F(PrimeField::Config::kModulus - 2);
for (size_t i = 1; i < N + 1; ++i) {
ret.internal_diagonal_minus_one[i] = F(1 << internal_shifts[i - 1]);
ret.internal_diagonal_minus_one[i] =
F(uint32_t{1} << internal_shifts[i - 1]);
}
} else {
ret.internal_shifts = math::Vector<uint8_t>(N);
Expand Down
1 change: 1 addition & 0 deletions tachyon/crypto/sumcheck/multilinear/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ tachyon_cc_library(
":sumcheck_prover_msg",
":sumcheck_proving_key",
":sumcheck_verifier_msg",
"//tachyon/base:parallelize",
"//tachyon/math/polynomials/multivariate:linear_combination",
"//tachyon/math/polynomials/univariate:univariate_evaluations",
],
Expand Down
47 changes: 21 additions & 26 deletions tachyon/crypto/sumcheck/multilinear/sumcheck_prover.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include <utility>
#include <vector>

#include "tachyon/base/parallelize.h"
#include "tachyon/crypto/sumcheck/multilinear/sumcheck_prover_msg.h"
#include "tachyon/crypto/sumcheck/multilinear/sumcheck_proving_key.h"
#include "tachyon/crypto/sumcheck/multilinear/sumcheck_verifier_msg.h"
Expand Down Expand Up @@ -91,36 +92,30 @@ class SumcheckProver {
size_t thread_nums = 1;
#endif
size_t size = size_t{1} << (num_variables_ - round_);
thread_nums = (thread_nums * kParallelFactor) <= size ? thread_nums : 1;

size_t chunk_size = (size + thread_nums - 1) / thread_nums;
size_t num_chunks = (size + chunk_size - 1) / chunk_size;

std::vector<std::vector<F>> finished_evaluations(
num_chunks, std::vector<F>(max_evaluations_ + 1, F::Zero()));

OMP_PARALLEL_FOR(size_t i = 0; i < num_chunks; ++i) {
size_t begin = i * chunk_size;
size_t len = (i == num_chunks - 1) ? size - begin : chunk_size;
std::vector<F> intermediate_evaluations(max_evaluations_ + 1, F::Zero());
for (size_t j = begin; j < begin + len; ++j) {
for (const Term& term : terms_) {
std::fill(intermediate_evaluations.begin(),
intermediate_evaluations.end(), term.coefficient);
EvaluateTermPerVariable(j, intermediate_evaluations, term);
for (size_t k = 0; k < max_evaluations_ + 1; ++k) {
finished_evaluations[i][k] += intermediate_evaluations[k];
std::vector<std::vector<F>> evals_vec = base::ParallelizeMap(
size,
[this](size_t len, size_t chunk_offset, size_t chunk_size) {
size_t begin = chunk_offset * chunk_size;
std::vector<F> ret(max_evaluations_ + 1, F::Zero());
std::vector<F> tmp(max_evaluations_ + 1);
for (size_t i = begin; i < begin + len; ++i) {
for (const Term& term : terms_) {
std::fill(tmp.begin(), tmp.end(), term.coefficient);
EvaluateTermPerVariable(i, tmp, term);
for (size_t j = 0; j < ret.size(); ++j) {
ret[j] += tmp[j];
}
}
}
}
}
}
for (size_t i = 1; i < num_chunks; ++i) {
return ret;
},
kParallelFactor * thread_nums);
for (size_t i = 1; i < evals_vec.size(); ++i) {
for (size_t j = 0; j < max_evaluations_ + 1; ++j) {
finished_evaluations[0][j] += finished_evaluations[i][j];
evals_vec[0][j] += evals_vec[i][j];
}
}
return {math::UnivariateEvaluations<F, MaxDegree>(
std::move(finished_evaluations[0]))};
return {math::UnivariateEvaluations<F, MaxDegree>(std::move(evals_vec[0]))};
}

// Receive message from verifier and run a prover round.
Expand Down
74 changes: 46 additions & 28 deletions tachyon/crypto/sumcheck/multilinear/sumcheck_verifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,38 +181,54 @@ F InterpolateUniPoly(const std::vector<F>& poly, const F& evaluation_point) {
#else
size_t thread_nums = 1;
#endif
thread_nums =
((thread_nums * kParallelFactor) <= poly_size) ? thread_nums : 1;

size_t chunk_size = (poly_size + thread_nums - 1) / thread_nums;
size_t num_chunks = (poly_size + chunk_size - 1) / chunk_size;
struct Result {
F product;
F denom_up;
std::vector<F> evals;
};

std::vector<F> products(num_chunks, F::One());
std::vector<F> denom_ups(num_chunks, F::One());
std::vector<std::vector<F>> list_of_evals(num_chunks);
OMP_PARALLEL_FOR(size_t i = 0; i < num_chunks; ++i) {
size_t begin = i * chunk_size;
size_t len = (i == num_chunks - 1) ? poly_size - begin : chunk_size;
list_of_evals[i].reserve(len);
F check = F(begin);
for (size_t j = begin; j < begin + len; ++j) {
const F difference = evaluation_point - check;
list_of_evals[i].push_back(difference);
products[i] *= difference;
if (j > 1) {
denom_ups[i] *= check;
}
check += F::One();
F product;
F denom_up;
std::vector<F> evals;
{
std::vector<Result> results = base::ParallelizeMap(
poly_size,
[&evaluation_point](size_t len, size_t chunk_offset,
size_t chunk_size) {
size_t begin = chunk_offset * chunk_size;
Result result;
result.product = F::One();
result.denom_up = F::One();
result.evals.reserve(len);
F check = F(begin);
for (size_t i = begin; i < begin + len; ++i) {
F difference = evaluation_point - check;
result.product *= difference;
result.evals.push_back(std::move(difference));
if (i > 1) {
TomTaehoonKim marked this conversation as resolved.
Show resolved Hide resolved
result.denom_up *= check;
}
check += F::One();
}
return result;
},
kParallelFactor * thread_nums);
product = std::move(results[0].product);
denom_up = std::move(results[0].denom_up);
size_t size = std::accumulate(results.begin(), results.end(), 0,
[](size_t acc, const Result& result) {
return acc + result.evals.size();
});
evals = std::move(results[0].evals);
batzor marked this conversation as resolved.
Show resolved Hide resolved
evals.reserve(size);
for (size_t i = 1; i < results.size(); ++i) {
evals.insert(evals.end(), results[i].evals.begin(),
results[i].evals.end());
product *= results[i].product;
denom_up *= results[i].denom_up;
}
}
F product = products[0];
F denom_up = denom_ups[0];
std::vector<F> evals = list_of_evals[0];
for (size_t i = 1; i < num_chunks; ++i) {
evals.insert(evals.end(), list_of_evals[i].begin(), list_of_evals[i].end());
product *= products[i];
denom_up *= denom_ups[i];
}

// Computing denom[i] = ∏ⱼ≠ᵢ(i - j) for a given i:
//
Expand All @@ -234,6 +250,8 @@ F InterpolateUniPoly(const std::vector<F>& poly, const F& evaluation_point) {
// 1,2,...,|poly_size| - i - 1)
//
// denom is stored as a fraction number to reduce field divisions.
// TODO(chokobole): Parallelize this computation when |poly_size| is large
// enough, as is done for the evaluation loop above.
F res = F::Zero();
F offset_up = F::One();

Expand Down
1 change: 1 addition & 0 deletions tachyon/math/base/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ tachyon_cc_library(
hdrs = ["groups.h"],
deps = [
":semigroups",
"//tachyon/base:parallelize",
"//tachyon/base/containers:container_util",
"//tachyon/base/types:always_false",
],
Expand Down
12 changes: 5 additions & 7 deletions tachyon/math/base/big_int.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,19 +60,17 @@ struct BigInt {
constexpr static size_t kBitNums = kByteNums * 8;

constexpr BigInt() = default;
constexpr explicit BigInt(int64_t value)
: BigInt(static_cast<uint64_t>(value)) {
DCHECK_GE(value, int64_t{0});
}
constexpr explicit BigInt(uint64_t value) { limbs[kSmallestLimbIdx] = value; }
constexpr explicit BigInt(int value) : BigInt(static_cast<uint64_t>(value)) {
template <typename T, std::enable_if_t<std::is_signed_v<T>>* = nullptr>
constexpr explicit BigInt(T value) {
DCHECK_GE(value, 0);
limbs[kSmallestLimbIdx] = value;
}
template <typename T, std::enable_if_t<std::is_unsigned_v<T>>* = nullptr>
constexpr explicit BigInt(T value) {
limbs[kSmallestLimbIdx] = value;
}
constexpr explicit BigInt(std::initializer_list<int> values) {
template <typename T, std::enable_if_t<std::is_signed_v<T>>* = nullptr>
constexpr explicit BigInt(std::initializer_list<T> values) {
DCHECK_LE(values.size(), N);
auto it = values.begin();
for (size_t i = 0; i < values.size(); ++i, ++it) {
Expand Down
19 changes: 8 additions & 11 deletions tachyon/math/base/groups.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include <vector>

#include "tachyon/base/containers/container_util.h"
#include "tachyon/base/openmp_util.h"
#include "tachyon/base/parallelize.h"
#include "tachyon/base/types/always_false.h"
#include "tachyon/math/base/semigroups.h"

Expand Down Expand Up @@ -88,20 +88,17 @@ class MultiplicativeGroup : public MultiplicativeSemigroup<G> {
size_t thread_nums = static_cast<size_t>(omp_get_max_threads());
if (size >=
size_t{1} << (thread_nums / kParallelBatchInverseDivisorThreshold)) {
size_t chunk_size = base::GetNumElementsPerThread(groups);
size_t num_chunks = (size + chunk_size - 1) / chunk_size;
std::atomic<bool> check_valid(true);
OMP_PARALLEL_FOR(size_t i = 0; i < num_chunks; ++i) {
size_t len = i == num_chunks - 1 ? size - i * chunk_size : chunk_size;
absl::Span<const G> groups_chunk(std::data(groups) + i * chunk_size,
len);
absl::Span<G> inverses_chunk(std::data(*inverses) + i * chunk_size,
len);
base::Parallelize(size, [&groups, inverses, &coeff, &check_valid](
size_t len, size_t chunk_offset,
size_t chunk_size) {
size_t start = chunk_offset * chunk_size;
absl::Span<const G> groups_chunk(&groups[start], len);
absl::Span<G> inverses_chunk(&(*inverses)[start], len);
if (UNLIKELY(!DoBatchInverse(groups_chunk, inverses_chunk, coeff))) {
check_valid.store(false, std::memory_order_relaxed);
continue;
}
}
});
if (UNLIKELY(!check_valid.load(std::memory_order_relaxed))) {
LOG(ERROR) << "Inverse of zero attempted";
return false;
Expand Down
2 changes: 1 addition & 1 deletion tachyon/math/base/semigroups.h
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ class AdditiveSemigroup {
const G& g = static_cast<const G&>(*this);
switch (scalar) {
case 0:
return AddResult::One();
return AddResult::Zero();
case 1:
return g;
case 2:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ tachyon_cc_library(
hdrs = ["pippenger_adapter.h"],
deps = [
":pippenger",
"//tachyon/base:parallelize",
"//tachyon/base:profiler",
],
)
Expand Down
Loading