Skip to content

Commit

Permalink
Merge pull request #553 from kroma-network/refac/refac-two-adic-fri
Browse files Browse the repository at this point in the history
refac: refactor `TwoAdicFri` verify
  • Loading branch information
chokobole authored Oct 28, 2024
2 parents dad8fb5 + bb6d337 commit 87116f9
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 22 deletions.
27 changes: 14 additions & 13 deletions tachyon/crypto/commitments/fri/fri_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ struct FRIConfig {
size_t Blowup() const { return size_t{1} << log_blowup; }
};

// NOTE(ashjeong): |kLogArity| is subject to change in the future
template <typename ExtF, typename Derived>
std::vector<ExtF> FoldMatrix(const ExtF& beta,
const Eigen::MatrixBase<Derived>& mat) {
using F = typename math::ExtensionFieldTraits<ExtF>::BaseField;
const size_t kLogArity = 1;
// We use the fact that
// pₑ(x²) = (p(x) + p(-x)) / 2
// pₒ(x²) = (p(x) - p(-x)) / (2x)
Expand All @@ -42,8 +44,8 @@ std::vector<ExtF> FoldMatrix(const ExtF& beta,
// + (1/2 - β/2gᵢₙᵥⁱ)p(gⁿᐟ²⁺ⁱ)
size_t rows = static_cast<size_t>(mat.rows());
F w;
CHECK(
F::GetRootOfUnity(size_t{1} << (base::bits::CheckedLog2(rows) + 1), &w));
CHECK(F::GetRootOfUnity(
size_t{1} << (base::bits::CheckedLog2(rows) + kLogArity), &w));
ExtF w_inv = ExtF(unwrap(w.Inverse()));
ExtF half_beta = beta * ExtF::TwoInv();

Expand All @@ -61,26 +63,25 @@ std::vector<ExtF> FoldMatrix(const ExtF& beta,
return ret;
}

// NOTE(ashjeong): |arity| is subject to change in the future
// NOTE(ashjeong): |kLogArity| is subject to change in the future
template <typename ExtF>
ExtF FoldRow(size_t index, uint32_t log_num_rows, const ExtF& beta,
const std::vector<ExtF>& evals) {
using F = typename math::ExtensionFieldTraits<ExtF>::BaseField;
const size_t kArity = 2;
const size_t kLogArity = 1;
const ExtF& e0 = evals[0];
const ExtF& e1 = evals[1];

F w;
CHECK(F::GetRootOfUnity(size_t{1} << (log_num_rows + kLogArity), &w));
ExtF subgroup_start =
ExtF(w.Pow(base::bits::ReverseBitsLen(index, log_num_rows)));
ExtF w_inv = ExtF(unwrap(w.Inverse()));
ExtF half_beta = beta * ExtF::TwoInv();
ExtF power =
ExtF(w_inv.Pow(base::bits::ReverseBitsLen(index, log_num_rows))) *
half_beta;

CHECK(F::GetRootOfUnity(size_t{1} << kLogArity, &w));
std::vector<ExtF> xs = ExtF::GetBitRevIndexSuccessivePowersSerial(
kArity, ExtF(w), subgroup_start);
// interpolate and evaluate at beta
return e0 + unwrap((beta - xs[0]) * (e1 - e0) / (xs[1] - xs[0]));
// result(g²ⁱ) = (1/2 + β/2gᵢₙᵥⁱ)p(gⁱ) + (1/2 - β/2gᵢₙᵥⁱ)p(gⁿᐟ²⁺ⁱ)
const ExtF& lo = evals[0];
const ExtF& hi = evals[1];
return (ExtF::TwoInv() + power) * lo + (ExtF::TwoInv() - power) * hi;
}

} // namespace tachyon::crypto
Expand Down
7 changes: 3 additions & 4 deletions tachyon/crypto/commitments/fri/simple_fri.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,14 +153,13 @@ class SimpleFRI final
// Pᵢ₊₁(X) = Pᵢ_even(X) + β * Pᵢ_odd(X)
//
// If the domain of Pᵢ(X) is Dᵢ = {ω⁰, ω¹, ..., ωⁿ⁻¹},
// then the domain of Pᵢ₊₁(X) is Dᵢ₊₁ = {ω⁰, ω¹, ..., ωᵏ⁻¹},
// where k = n / 2.
// then the domain of Pᵢ₊₁(X) is Dᵢ₊₁ = {ω⁰, ω², ..., ωⁿ⁻²}.
//
// As per the definition:
// Pᵢ₊₁(ωʲ) = Pᵢ_even(ωʲ) + β * Pᵢ_odd(ωʲ)
// Pᵢ₊₁(ω²ʲ) = Pᵢ_even(ω²ʲ) + β * Pᵢ_odd(ω²ʲ)
//
// Substituting Pᵢ_even and Pᵢ_odd:
// Pᵢ₊₁(ωʲ) = (Pᵢ(ωʲ) + Pᵢ(-ωʲ)) / 2 + β * (Pᵢ(ωʲ) - Pᵢ(-ωʲ)) / (2 * ωʲ)
// Pᵢ₊₁(ω²ʲ) = (Pᵢ(ωʲ) + Pᵢ(-ωʲ)) / 2 + β * (Pᵢ(ωʲ) - Pᵢ(-ωʲ)) / (2 * ωʲ)
// = ((1 + β * ω⁻ʲ) * Pᵢ(ωʲ) + (1 - β * ω⁻ʲ) * Pᵢ(-ωʲ)) / 2
size_t leaf_index = index % domain_size;
if (i == 0) {
Expand Down
2 changes: 1 addition & 1 deletion tachyon/math/base/semigroups.h
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class MultiplicativeSemigroup {
size_t size, const G& generator, const G& c = G::One()) {
std::vector<MulResult> ret(size);
uint32_t log_size = base::bits::CheckedLog2(size);
MulResult pow = c.IsOne() ? G::One() : c;
MulResult pow = c;
for (size_t idx = 0; idx < size - 1; ++idx) {
ret[base::bits::ReverseBitsLen(idx, log_size)] = pow;
pow *= generator;
Expand Down
11 changes: 7 additions & 4 deletions tachyon/math/finite_fields/extension_field_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ class ExtensionFieldBase {
using PackedField = typename PackedFieldTraits<BaseField>::PackedField;
constexpr uint32_t kDegree = T::ExtensionDegree();

// if |PackedField::N| = 8:
// |first_n_powers[i]| = {1, aᵢ, ..., a⁷ᵢ}
// if |PackedField::N| = 8,
// |first_n_powers| = [ 1, b, b², b³,..., b⁶, b⁷ ], where b is an extension
// field of |base|.
ExtendedPackedField first_n_powers;
T pow = T::One();
for (size_t i = 0; i < PackedField::N; ++i) {
Expand All @@ -80,15 +81,17 @@ class ExtensionFieldBase {
pow *= base;
}

// |multiplier[j]| = {a⁸ⱼ, a⁸ⱼ, ..., a⁸ⱼ, a⁸ⱼ}
// |multiplier| = [ b⁸, b⁸, ..., b⁸, b⁸ ], where b is an extension field of
// |base|. #|multiplier| = 8
ExtendedPackedField multiplier;
for (size_t i = 0; i < PackedField::N; ++i) {
for (uint32_t j = 0; j < kDegree; ++j) {
multiplier[j][i] = pow[j];
}
}

// |ret[i]| = {(a⁸ᵢ)ⁱ, aᵢ * (a⁸ᵢ)ⁱ, ..., a⁷ᵢ * (a⁸ᵢ)ⁱ}
// |ret[i]| = [ b⁸ⁱ, b⁸ⁱ⁺¹, ..., b⁸ⁱ⁺⁶, b⁸ⁱ⁺⁷ ], where b is an extension
// field of |base|.
std::vector<ExtendedPackedField> ret;
ret.reserve(size);
ret.emplace_back(first_n_powers);
Expand Down

0 comments on commit 87116f9

Please sign in to comment.