Skip to content

Commit

Permalink
refactor again, use thread_pool instead of ctxs
Browse files Browse the repository at this point in the history
  • Loading branch information
nulltea committed Sep 7, 2023
1 parent 3635b11 commit 0d5c09f
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 229 deletions.
4 changes: 2 additions & 2 deletions lightclient-circuits/src/committee_update_circuit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

// use crate::{
// gadget::crypto::{
// Fp2Point, FpPoint, G1Chip, G1Point, G2Chip, G2Point, HashChip, HashToCurveCache,
// Fp2Point, FpPoint, G1Chip, G1Point, G2Chip, G2Point, HashInstructions, HashToCurveCache,
// HashToCurveChip, Sha256ChipWide, SpreadConfig,
// },
// poseidon::{g1_array_poseidon, poseidon_sponge},
Expand Down Expand Up @@ -187,7 +187,7 @@
// fn sync_committee_root_ssz<'a, I: IntoIterator<Item = Vec<AssignedValue<F>>>>(
// ctx: &mut Context<F>,
// region: &mut Region<'_, F>,
// hasher: &'a impl HashChip<F>,
// hasher: &'a impl HashInstructions<F>,
// compressed_encodings: I,
// ) -> Result<Vec<AssignedValue<F>>, Error> {
// let mut pubkeys_hashes = compressed_encodings
Expand Down
63 changes: 28 additions & 35 deletions lightclient-circuits/src/gadget/crypto/hash2curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use itertools::Itertools;
use num_bigint::{BigInt, BigUint};
use pasta_curves::arithmetic::SqrtRatio;

use super::ShaContexts;
use super::{ShaContexts, ShaThreadBuilder};
use super::{
sha256::HashInstructions,
util::{fp2_sgn0, i2osp, strxor},
Expand Down Expand Up @@ -63,17 +63,16 @@ impl<'a, S: Spec, F: Field, HC: HashInstructions<F> + 'a> HashToCurveChip<'a, S,

pub fn hash_to_curve<C: HashCurveExt>(
&self,
ctx_base: &mut Context<F>,
ctx_sha: &mut ShaContexts<F>,
thread_pool: &mut ShaThreadBuilder<F>,
fp_chip: &FpChip<F, C::Fp>,
msg: HashInput<QuantumCell<F>>,
cache: &mut HashToCurveCache<F>,
) -> Result<G2Point<F>, Error>
where
C::Fq: FieldExtConstructor<C::Fp, 2>,
{
let u = self.hash_to_field::<C>(ctx_base, ctx_sha, fp_chip, msg, cache)?;
let p = self.map_to_curve::<C>(ctx_base, fp_chip, u, cache)?;
let u = self.hash_to_field::<C>(thread_pool, fp_chip, msg, cache)?;
let p = self.map_to_curve::<C>(thread_pool.main(), fp_chip, u, cache)?;
Ok(p)
}

Expand All @@ -87,8 +86,7 @@ impl<'a, S: Spec, F: Field, HC: HashInstructions<F> + 'a> HashToCurveChip<'a, S,
/// - https://github.com/succinctlabs/telepathy-circuits/blob/d5c7771/circuits/hash_to_field.circom#L11
fn hash_to_field<C: HashCurveExt>(
&self,
ctx_base: &mut Context<F>,
ctx_sha: &mut ShaContexts<F>,
thread_pool: &mut ShaThreadBuilder<F>,
fp_chip: &FpChip<F, C::Fp>,
msg: HashInput<QuantumCell<F>>,
cache: &mut HashToCurveCache<F>,
Expand All @@ -99,15 +97,14 @@ impl<'a, S: Spec, F: Field, HC: HashInstructions<F> + 'a> HashToCurveChip<'a, S,
let safe_types = SafeTypeChip::new(range);

// constants
let zero = ctx_base.load_zero();
let one = ctx_base.load_constant(F::one());
let zero = thread_pool.main().load_zero();
let one = thread_pool.main().load_constant(F::one());

let assigned_msg = msg.into_assigned(ctx_base).to_vec();
let assigned_msg = msg.into_assigned(thread_pool.main()).to_vec();

let len_in_bytes = 2 * G2_EXT_DEGREE * L;
let extended_msg = Self::expand_message_xmd(
ctx_base,
ctx_sha,
thread_pool,
self.hash_chip,
assigned_msg,
len_in_bytes,
Expand All @@ -117,12 +114,12 @@ impl<'a, S: Spec, F: Field, HC: HashInstructions<F> + 'a> HashToCurveChip<'a, S,
let limb_bases = cache.binary_bases.get_or_insert_with(|| {
C::limb_bytes_bases()
.into_iter()
.map(|base| ctx_base.load_constant(base))
.map(|base| thread_pool.main().load_constant(base))
.collect()
});

// 2^256
let two_pow_256 = fp_chip.load_constant_uint(ctx_base, BigUint::from(2u8).pow(256));
let two_pow_256 = fp_chip.load_constant_uint(thread_pool.main(), BigUint::from(2u8).pow(256));
let fq_bytes = C::BYTES_COMPRESSED / 2;

let mut fst = true;
Expand All @@ -141,20 +138,20 @@ impl<'a, S: Spec, F: Field, HC: HashInstructions<F> + 'a> HashToCurveChip<'a, S,
buf.to_vec(),
&fp_chip.limb_bases,
gate,
ctx_base,
thread_pool.main(),
);

buf[rem..].copy_from_slice(&tv[32..]);
let hi = decode_into_field_be::<F, C, _>(
buf.to_vec(),
&fp_chip.limb_bases,
gate,
ctx_base,
thread_pool.main(),
);

let lo_2_256 = fp_chip.mul_no_carry(ctx_base, lo, two_pow_256.clone());
let lo_2_356_hi = fp_chip.add_no_carry(ctx_base, lo_2_256, hi);
fp_chip.carry_mod(ctx_base, lo_2_356_hi)
let lo_2_256 = fp_chip.mul_no_carry(thread_pool.main(), lo, two_pow_256.clone());
let lo_2_356_hi = fp_chip.add_no_carry(thread_pool.main(), lo_2_256, hi);
fp_chip.carry_mod(thread_pool.main(), lo_2_356_hi)
})
.collect_vec(),
)
Expand Down Expand Up @@ -200,8 +197,7 @@ impl<'a, S: Spec, F: Field, HC: HashInstructions<F> + 'a> HashToCurveChip<'a, S,
/// - https://github.com/paulmillr/noble-curves/blob/bf70ba9/src/abstract/hash-to-curve.ts#L63
/// - https://github.com/succinctlabs/telepathy-circuits/blob/d5c7771/circuits/hash_to_field.circom#L139
fn expand_message_xmd(
ctx_base: &mut Context<F>,
ctx_sha: &mut ShaContexts<F>,
thread_pool: &mut ShaThreadBuilder<F>,
hash_chip: &HC,
msg: Vec<AssignedValue<F>>,
len_in_bytes: usize,
Expand All @@ -212,25 +208,25 @@ impl<'a, S: Spec, F: Field, HC: HashInstructions<F> + 'a> HashToCurveChip<'a, S,

// constants
// const MAX_INPUT_SIZE: usize = 192;
let zero = ctx_base.load_zero();
let one = ctx_base.load_constant(F::one());
let zero = thread_pool.main().load_zero();
let one = thread_pool.main().load_constant(F::one());

// assign DST bytes & cache them
let dst_len = ctx_base.load_constant(F::from(S::DST.len() as u64));
let dst_len = thread_pool.main().load_constant(F::from(S::DST.len() as u64));
let dst_prime = cache
.dst_with_len
.get_or_insert_with(|| {
S::DST
.iter()
.map(|&b| ctx_base.load_constant(F::from(b as u64)))
.map(|&b| thread_pool.main().load_constant(F::from(b as u64)))
.chain(iter::once(dst_len))
.collect()
})
.clone();

// padding and length strings
let z_pad = i2osp(0, HC::BLOCK_SIZE, |b| zero); // TODO: cache these
let l_i_b_str = i2osp(len_in_bytes as u128, 2, |b| ctx_base.load_constant(b));
let l_i_b_str = i2osp(len_in_bytes as u128, 2, |b| thread_pool.main().load_constant(b));

// compute blocks
let ell = len_in_bytes.div_ceil(HC::DIGEST_SIZE);
Expand All @@ -243,15 +239,14 @@ impl<'a, S: Spec, F: Field, HC: HashInstructions<F> + 'a> HashToCurveChip<'a, S,
.chain(dst_prime.clone());

let b_0 = hash_chip
.digest::<143>(ctx_base, ctx_sha, msg_prime.into(), false)?
.digest::<143>(thread_pool, msg_prime.into(), false)?
.output_bytes;

b_vals.insert(
0,
hash_chip
.digest::<77>(
ctx_base,
ctx_sha,
thread_pool,
b_0.into_iter()
.chain(iter::once(one))
.chain(dst_prime.clone())
Expand All @@ -262,16 +257,16 @@ impl<'a, S: Spec, F: Field, HC: HashInstructions<F> + 'a> HashToCurveChip<'a, S,
);

for i in 1..ell {
let preimg = strxor(b_0, b_vals[i - 1], gate, ctx_base)
let preimg = strxor(b_0, b_vals[i - 1], gate, thread_pool.main())
.into_iter()
.chain(iter::once(ctx_base.load_constant(F::from(i as u64 + 1))))
.chain(iter::once(thread_pool.main().load_constant(F::from(i as u64 + 1))))
.chain(dst_prime.clone())
.into();

b_vals.insert(
i,
hash_chip
.digest::<77>(ctx_base, ctx_sha, preimg, false)?
.digest::<77>(thread_pool, preimg, false)?
.output_bytes,
);
}
Expand Down Expand Up @@ -658,16 +653,14 @@ mod test {
let spread = SpreadChip::new(&range);

let sha256 = Sha256Chip::new(&range, spread);
let (ctx_base, mut ctx_sha) = builder.sha_contexts_pair();

let h2c_chip = HashToCurveChip::<Test, F, _>::new(&sha256);
let fp_chip = halo2_ecc::bls12_381::FpChip::<F>::new(&range, G2::LIMB_BITS, G2::NUM_LIMBS);

for input in input_vector {
let mut cache = HashToCurveCache::<F>::default();
let hp = h2c_chip.hash_to_curve::<G2>(
ctx_base,
&mut ctx_sha,
&mut builder,
&fp_chip,
input.clone().into_witness(),
&mut cache,
Expand Down
46 changes: 21 additions & 25 deletions lightclient-circuits/src/gadget/crypto/sha256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ pub trait HashInstructions<F: Field> {
/// `strict` flag indicates whether to perform range check on input bytes.
fn digest<const MAX_INPUT_SIZE: usize>(
&self,
ctx_base: &mut Context<F>,
ctx_sha: &mut ShaContexts<F>,
thread_pool: &mut ShaThreadBuilder<F>,
input: HashInput<QuantumCell<F>>,
strict: bool,
) -> Result<AssignedHashResult<F>, Error>;
Expand Down Expand Up @@ -77,8 +76,7 @@ impl<'a, F: Field> HashInstructions<F> for Sha256Chip<'a, F> {

fn digest<const MAX_INPUT_SIZE: usize>(
&self,
ctx_base: &mut Context<F>,
ctx_sha: &mut ShaContexts<F>,
thread_pool: &mut ShaThreadBuilder<F>,
input: HashInput<QuantumCell<F>>,
strict: bool,
) -> Result<AssignedHashResult<F>, Error> {
Expand All @@ -91,7 +89,7 @@ impl<'a, F: Field> HashInstructions<F> for Sha256Chip<'a, F> {
max_bytes
};

let assigned_input = input.into_assigned(ctx_base);
let assigned_input = input.into_assigned(thread_pool.main());

let mut assigned_input_bytes = assigned_input.to_vec();
let input_byte_size = assigned_input_bytes.len();
Expand All @@ -115,7 +113,7 @@ impl<'a, F: Field> HashInstructions<F> for Sha256Chip<'a, F> {
// remaining_byte_size,
// one_round_size * (max_round - num_round)
// );
let mut assign_byte = |byte: u8| ctx_base.load_witness(F::from(byte as u64));
let mut assign_byte = |byte: u8| thread_pool.main().load_witness(F::from(byte as u64));

assigned_input_bytes.push(assign_byte(0x80));

Expand All @@ -137,26 +135,26 @@ impl<'a, F: Field> HashInstructions<F> for Sha256Chip<'a, F> {

if strict {
for &assigned in assigned_input_bytes.iter() {
range.range_check(ctx_base, assigned, 8);
range.range_check(thread_pool.main(), assigned, 8);
}
}

let assigned_num_round = ctx_base.load_witness(F::from(num_round as u64));
let assigned_num_round = thread_pool.main().load_witness(F::from(num_round as u64));

// compute an initial state from the precomputed_input.
let mut last_state = INIT_STATE;

let mut assigned_last_state_vec = vec![last_state
.iter()
.map(|state| ctx_base.load_witness(F::from(*state as u64)))
.map(|state| thread_pool.main().load_witness(F::from(*state as u64)))
.collect_vec()];

let mut num_processed_input = 0;
while num_processed_input < max_processed_bytes {
let assigned_input_word_at_round =
&assigned_input_bytes[num_processed_input..(num_processed_input + one_round_size)];
let new_assigned_hs_out = sha256_compression(
ctx_base, ctx_sha,
thread_pool,
&self.spread,
assigned_input_word_at_round,
assigned_last_state_vec.last().unwrap(),
Expand All @@ -166,17 +164,17 @@ impl<'a, F: Field> HashInstructions<F> for Sha256Chip<'a, F> {
num_processed_input += one_round_size;
}

let zero = ctx_base.load_zero();
let zero = thread_pool.main().load_zero();
let mut output_h_out = vec![zero; 8];
for (n_round, assigned_state) in assigned_last_state_vec.into_iter().enumerate() {
let selector = gate.is_equal(
ctx_base,
thread_pool.main(),
QuantumCell::Constant(F::from(n_round as u64)),
assigned_num_round,
);
for i in 0..8 {
output_h_out[i] =
gate.select(ctx_base, assigned_state[i], output_h_out[i], selector)
gate.select(thread_pool.main(), assigned_state[i], output_h_out[i], selector)
}
}
let output_digest_bytes = output_h_out
Expand All @@ -185,21 +183,21 @@ impl<'a, F: Field> HashInstructions<F> for Sha256Chip<'a, F> {
let be_bytes = assigned_word.value().get_lower_32().to_be_bytes().to_vec();
let assigned_bytes = (0..4)
.map(|idx| {
let assigned = ctx_base.load_witness(F::from(be_bytes[idx] as u64));
range.range_check(ctx_base, assigned, 8);
let assigned = thread_pool.main().load_witness(F::from(be_bytes[idx] as u64));
range.range_check(thread_pool.main(), assigned, 8);
assigned
})
.collect_vec();
let mut sum = ctx_base.load_zero();
let mut sum = thread_pool.main().load_zero();
for (idx, assigned_byte) in assigned_bytes.iter().copied().enumerate() {
sum = gate.mul_add(
ctx_base,
thread_pool.main(),
assigned_byte,
QuantumCell::Constant(F::from(1u64 << (24 - 8 * idx))),
sum,
);
}
ctx_base.constrain_equal(&assigned_word, &sum);
thread_pool.main().constrain_equal(&assigned_word, &sum);
assigned_bytes
})
.collect_vec()
Expand Down Expand Up @@ -268,15 +266,13 @@ mod test {
let spread = SpreadChip::new(&range);

let sha256 = Sha256Chip::new(&range, spread);
let (ctx_base, mut ctx_sha) = builder.sha_contexts_pair();

for input in input_vector {
let _ = sha256
.digest::<64>(
ctx_base, &mut ctx_sha,
input.as_slice().into_witness(),
false,
)?;
let _ = sha256.digest::<64>(
&mut builder,
input.as_slice().into_witness(),
false,
)?;
}

builder.config(k, None);
Expand Down
Loading

0 comments on commit 0d5c09f

Please sign in to comment.