Skip to content

Commit

Permalink
Merge pull request #523 from kroma-network/perf/optimize-packed-prime…
Browse files Browse the repository at this point in the history
…-field-operations

perf: optimize packed prime field operations
  • Loading branch information
chokobole authored Aug 21, 2024
2 parents eaa0faa + c366fb3 commit c83ea94
Show file tree
Hide file tree
Showing 15 changed files with 65 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class Poseidon2Plonky3InternalMatrix {
if constexpr (PrimeField::Config::kUseMontgomery) {
static_assert(PrimeField::Config::kModulusBits <= 32);
for (F& f : v) {
f *= F::Broadcast(PrimeField::FromMontgomery(1));
f *= F::RawOne();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ __m256i kZero;
__m256i kOne;
__m256i kMinusOne;
__m256i kTwoInv;
__m256i kRawOne;

__m256i ToVector(const PackedBabyBearAVX2& packed) {
return _mm256_loadu_si256(
Expand Down Expand Up @@ -58,6 +59,7 @@ void PackedBabyBearAVX2::Init() {
kOne = _mm256_set1_epi32(BabyBear::Config::kOne);
kMinusOne = _mm256_set1_epi32(BabyBear::Config::kMinusOne);
kTwoInv = _mm256_set1_epi32(BabyBear::Config::kTwoInv);
kRawOne = _mm256_set1_epi32(1);
}

// static
Expand All @@ -74,6 +76,9 @@ PackedBabyBearAVX2 PackedBabyBearAVX2::MinusOne() {
// static
PackedBabyBearAVX2 PackedBabyBearAVX2::TwoInv() { return FromVector(kTwoInv); }

// static
PackedBabyBearAVX2 PackedBabyBearAVX2::RawOne() { return FromVector(kRawOne); }

// static
PackedBabyBearAVX2 PackedBabyBearAVX2::Broadcast(const PrimeField& value) {
return FromVector(_mm256_set1_epi32(value.value()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class TACHYON_EXPORT PackedBabyBearAVX2 final

static PackedBabyBearAVX2 TwoInv();

static PackedBabyBearAVX2 RawOne();

static PackedBabyBearAVX2 Broadcast(const PrimeField& value);

// AdditiveSemigroup methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ __m512i kZero;
__m512i kOne;
__m512i kMinusOne;
__m512i kTwoInv;
__m512i kRawOne;

__m512i ToVector(const PackedBabyBearAVX512& packed) {
return _mm512_loadu_si512(packed.values().data());
Expand Down Expand Up @@ -58,6 +59,7 @@ void PackedBabyBearAVX512::Init() {
kOne = _mm512_set1_epi32(BabyBear::Config::kOne);
kMinusOne = _mm512_set1_epi32(BabyBear::Config::kMinusOne);
kTwoInv = _mm512_set1_epi32(BabyBear::Config::kTwoInv);
kRawOne = _mm512_set1_epi32(1);
}

// static
Expand All @@ -76,6 +78,11 @@ PackedBabyBearAVX512 PackedBabyBearAVX512::TwoInv() {
return FromVector(kTwoInv);
}

// static
PackedBabyBearAVX512 PackedBabyBearAVX512::RawOne() {
return FromVector(kRawOne);
}

// static
PackedBabyBearAVX512 PackedBabyBearAVX512::Broadcast(const PrimeField& value) {
return FromVector(_mm512_set1_epi32(value.value()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class TACHYON_EXPORT PackedBabyBearAVX512 final

static PackedBabyBearAVX512 TwoInv();

static PackedBabyBearAVX512 RawOne();

static PackedBabyBearAVX512 Broadcast(const PrimeField& value);

// AdditiveSemigroup methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ uint32x4_t kZero;
uint32x4_t kOne;
uint32x4_t kMinusOne;
uint32x4_t kTwoInv;
uint32x4_t kRawOne;

uint32x4_t ToVector(const PackedBabyBearNeon& packed) {
return vld1q_u32(reinterpret_cast<const uint32_t*>(packed.values().data()));
Expand Down Expand Up @@ -60,6 +61,7 @@ void PackedBabyBearNeon::Init() {
kOne = vdupq_n_u32(BabyBear::Config::kOne);
kMinusOne = vdupq_n_u32(BabyBear::Config::kMinusOne);
kTwoInv = vdupq_n_u32(BabyBear::Config::kTwoInv);
kRawOne = vdupq_n_u32(1);
}

// static
Expand All @@ -76,6 +78,9 @@ PackedBabyBearNeon PackedBabyBearNeon::MinusOne() {
// static
PackedBabyBearNeon PackedBabyBearNeon::TwoInv() { return FromVector(kTwoInv); }

// static
PackedBabyBearNeon PackedBabyBearNeon::RawOne() { return FromVector(kRawOne); }

// static
PackedBabyBearNeon PackedBabyBearNeon::Broadcast(const PrimeField& value) {
return FromVector(vdupq_n_u32(value.value()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class TACHYON_EXPORT PackedBabyBearNeon final

static PackedBabyBearNeon TwoInv();

static PackedBabyBearNeon RawOne();

static PackedBabyBearNeon Broadcast(const PrimeField& value);

// AdditiveSemigroup methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ __m256i kZero;
__m256i kOne;
__m256i kMinusOne;
__m256i kTwoInv;
__m256i kRawOne;

__m256i ToVector(const PackedKoalaBearAVX2& packed) {
return _mm256_loadu_si256(
Expand Down Expand Up @@ -58,6 +59,7 @@ void PackedKoalaBearAVX2::Init() {
kOne = _mm256_set1_epi32(KoalaBear::Config::kOne);
kMinusOne = _mm256_set1_epi32(KoalaBear::Config::kMinusOne);
kTwoInv = _mm256_set1_epi32(KoalaBear::Config::kTwoInv);
kRawOne = _mm256_set1_epi32(1);
}

// static
Expand All @@ -76,6 +78,11 @@ PackedKoalaBearAVX2 PackedKoalaBearAVX2::TwoInv() {
return FromVector(kTwoInv);
}

// static
PackedKoalaBearAVX2 PackedKoalaBearAVX2::RawOne() {
return FromVector(kRawOne);
}

// static
PackedKoalaBearAVX2 PackedKoalaBearAVX2::Broadcast(const PrimeField& value) {
return FromVector(_mm256_set1_epi32(value.value()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class TACHYON_EXPORT PackedKoalaBearAVX2 final

static PackedKoalaBearAVX2 TwoInv();

static PackedKoalaBearAVX2 RawOne();

static PackedKoalaBearAVX2 Broadcast(const PrimeField& value);

// AdditiveSemigroup methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ __m512i kZero;
__m512i kOne;
__m512i kMinusOne;
__m512i kTwoInv;
__m512i kRawOne;

__m512i ToVector(const PackedKoalaBearAVX512& packed) {
return _mm512_loadu_si512(packed.values().data());
Expand Down Expand Up @@ -58,6 +59,7 @@ void PackedKoalaBearAVX512::Init() {
kOne = _mm512_set1_epi32(KoalaBear::Config::kOne);
kMinusOne = _mm512_set1_epi32(KoalaBear::Config::kMinusOne);
kTwoInv = _mm512_set1_epi32(KoalaBear::Config::kTwoInv);
kRawOne = _mm512_set1_epi32(1);
}

// static
Expand All @@ -78,6 +80,11 @@ PackedKoalaBearAVX512 PackedKoalaBearAVX512::TwoInv() {
return FromVector(kTwoInv);
}

// static
PackedKoalaBearAVX512 PackedKoalaBearAVX512::RawOne() {
return FromVector(kRawOne);
}

// static
PackedKoalaBearAVX512 PackedKoalaBearAVX512::Broadcast(
const PrimeField& value) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ class TACHYON_EXPORT PackedKoalaBearAVX512 final

static PackedKoalaBearAVX512 TwoInv();

static PackedKoalaBearAVX512 RawOne();

static PackedKoalaBearAVX512 Broadcast(const PrimeField& value);

// AdditiveSemigroup methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ uint32x4_t kZero;
uint32x4_t kOne;
uint32x4_t kMinusOne;
uint32x4_t kTwoInv;
uint32x4_t kRawOne;

uint32x4_t ToVector(const PackedKoalaBearNeon& packed) {
return vld1q_u32(reinterpret_cast<const uint32_t*>(packed.values().data()));
Expand Down Expand Up @@ -60,6 +61,7 @@ void PackedKoalaBearNeon::Init() {
kOne = vdupq_n_u32(KoalaBear::Config::kOne);
kMinusOne = vdupq_n_u32(KoalaBear::Config::kMinusOne);
kTwoInv = vdupq_n_u32(KoalaBear::Config::kTwoInv);
kRawOne = vdupq_n_u32(1);
}

// static
Expand All @@ -78,6 +80,11 @@ PackedKoalaBearNeon PackedKoalaBearNeon::TwoInv() {
return FromVector(kTwoInv);
}

// static
PackedKoalaBearNeon PackedKoalaBearNeon::RawOne() {
return FromVector(kRawOne);
}

// static
PackedKoalaBearNeon PackedKoalaBearNeon::Broadcast(const PrimeField& value) {
return FromVector(vdupq_n_u32(value.value()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class TACHYON_EXPORT PackedKoalaBearNeon final

static PackedKoalaBearNeon TwoInv();

static PackedKoalaBearNeon RawOne();

static PackedKoalaBearNeon Broadcast(const PrimeField& value);

// AdditiveSemigroup methods
Expand Down
4 changes: 2 additions & 2 deletions tachyon/math/finite_fields/packed_prime_field_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ class PackedPrimeFieldBase : public Field<Derived> {
// MultiplicativeGroup methods
std::optional<Derived> Inverse() const {
Derived ret;
CHECK(PrimeField::BatchInverse(values_, &ret.values_));
CHECK(PrimeField::BatchInverseSerial(values_, &ret.values_));
return ret;
}

[[nodiscard]] std::optional<Derived*> InverseInPlace() {
CHECK(PrimeField::BatchInverseInPlace(values_));
CHECK(PrimeField::BatchInverseInPlaceSerial(values_));
return static_cast<Derived*>(this);
}

Expand Down
12 changes: 12 additions & 0 deletions tachyon/math/finite_fields/packed_prime_field_unittest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,18 @@ TYPED_TEST(PackedPrimeFieldTest, TwoInv) {
}
}

TYPED_TEST(PackedPrimeFieldTest, RawOne) {
using PackedPrimeField = TypeParam;
using PrimeField = typename PackedFieldTraits<PackedPrimeField>::Field;

if constexpr (PrimeField::Config::kUseMontgomery) {
EXPECT_EQ(PackedPrimeField::RawOne(),
PackedPrimeField::Broadcast(PrimeField::FromMontgomery(1)));
} else {
GTEST_SKIP() << "RawOne() doesn't exist";
}
}

TYPED_TEST(PackedPrimeFieldTest, Broadcast) {
using PackedPrimeField = TypeParam;
using PrimeField = typename PackedPrimeField::PrimeField;
Expand Down

0 comments on commit c83ea94

Please sign in to comment.