Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor of the withdrawal credentials balance aggregation proof circuits #345

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use itertools::Itertools;
use plonky2::{
field::extension::Extendable,
hash::hash_types::RichField,
Expand Down Expand Up @@ -37,7 +38,7 @@ impl<F: RichField + Extendable<D>, const D: usize> CircuitBuilderExtensions
let a_targets = a.to_targets();
let b_targets = b.to_targets();

let pairs = a_targets.iter().zip(b_targets.iter());
let pairs = a_targets.iter().zip_eq(b_targets.iter());

let targets = pairs.fold(vec![], |mut acc, (&a_target, &b_target)| {
acc.push(self._if(selector, a_target, b_target));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,17 @@ async fn async_main() -> Result<()> {
&proof,
hex::encode(bits_to_bytes(circuit_input.block_root.as_slice())),
withdrawal_credentials,
balance_verification_pis.range_total_value.to_u64().unwrap(),
balance_verification_pis.number_of_non_activated_validators,
balance_verification_pis.number_of_active_validators,
balance_verification_pis.number_of_exited_validators,
balance_verification_pis.number_of_slashed_validators,
balance_verification_pis
.accumulated_data
.balance
.to_u64()
.unwrap(),
balance_verification_pis
.accumulated_data
.non_activated_count,
balance_verification_pis.accumulated_data.active_count,
balance_verification_pis.accumulated_data.exited_count,
balance_verification_pis.accumulated_data.slashed_count,
)
.await?;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ impl<const WITHDRAWAL_CREDENTIALS_COUNT: usize> Circuit
assert_slot_is_in_epoch(builder, &input.slot, &balance_verification_pi.current_epoch);

let accumulated_balance_bits =
biguint_to_bits_target(builder, &balance_verification_pi.range_total_value);
biguint_to_bits_target(builder, &balance_verification_pi.accumulated_data.balance);

let flattened_withdrawal_credentials = balance_verification_pi
.withdrawal_credentials
Expand All @@ -134,15 +134,19 @@ impl<const WITHDRAWAL_CREDENTIALS_COUNT: usize> Circuit

let number_of_non_activated_validators_bits = target_to_le_bits(
builder,
balance_verification_pi.number_of_non_activated_validators,
balance_verification_pi.accumulated_data.non_activated_count,
);
let number_of_active_validators_bits = target_to_le_bits(
builder,
balance_verification_pi.accumulated_data.active_count,
);
let number_of_exited_validators_bits = target_to_le_bits(
builder,
balance_verification_pi.accumulated_data.exited_count,
);
let number_of_active_validators_bits =
target_to_le_bits(builder, balance_verification_pi.number_of_active_validators);
let number_of_exited_validators_bits =
target_to_le_bits(builder, balance_verification_pi.number_of_exited_validators);
let number_of_slashed_validators_bits = target_to_le_bits(
builder,
balance_verification_pi.number_of_slashed_validators,
balance_verification_pi.accumulated_data.slashed_count,
);

let mut public_inputs_hash = sha256(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
use crate::{
common_targets::ValidatorTarget,
common_targets::{SSZTarget, ValidatorTarget},
serializers::{serde_bool_array_to_hex_string, serde_bool_array_to_hex_string_nested},
utils::circuit::{
bool_arrays_are_equal,
hashing::merkle::{
poseidon::{hash_tree_root_poseidon, hash_validator_poseidon_or_zeroes},
sha256::hash_tree_root_sha256,
ssz::ssz_num_from_bits,
},
select_biguint,
validator_status::get_validator_status,
},
};
use circuit::Circuit;
use circuit_derive::{CircuitTarget, SerdeCircuitTarget};
use itertools::Itertools;
use circuit::{circuit_builder_extensions::CircuitBuilderExtensions, Circuit};
use circuit_derive::{CircuitTarget, PublicInputsReadable, SerdeCircuitTarget, TargetPrimitive};
use itertools::{izip, Itertools};

use plonky2::{
field::goldilocks_field::GoldilocksField,
hash::hash_types::HashOutTarget,
field::{extension::Extendable, goldilocks_field::GoldilocksField},
hash::hash_types::{HashOutTarget, RichField},
iop::target::{BoolTarget, Target},
plonk::{
circuit_builder::CircuitBuilder, circuit_data::CircuitConfig,
Expand All @@ -32,6 +30,16 @@ use crate::{
serializers::{biguint_to_str, parse_biguint},
};

#[derive(PublicInputsReadable, TargetPrimitive, SerdeCircuitTarget)]
pub struct AccumulatedValidatorsData {
#[serde(serialize_with = "biguint_to_str", deserialize_with = "parse_biguint")]
pub balance: BigUintTarget,
pub non_activated_count: Target,
pub active_count: Target,
pub exited_count: Target,
pub slashed_count: Target,
}

#[derive(CircuitTarget, SerdeCircuitTarget)]
#[serde(rename_all = "camelCase")]
pub struct ValidatorBalanceVerificationTargets<
Expand All @@ -58,10 +66,6 @@ pub struct ValidatorBalanceVerificationTargets<
#[serde(serialize_with = "biguint_to_str", deserialize_with = "parse_biguint")]
pub current_epoch: BigUintTarget,

#[target(out)]
#[serde(serialize_with = "biguint_to_str", deserialize_with = "parse_biguint")]
pub range_total_value: BigUintTarget,

#[target(out)]
#[serde(with = "serde_bool_array_to_hex_string")]
pub range_balances_root: Sha256Target,
Expand All @@ -70,16 +74,7 @@ pub struct ValidatorBalanceVerificationTargets<
pub range_validator_commitment: HashOutTarget,

#[target(out)]
pub number_of_non_activated_validators: Target,

#[target(out)]
pub number_of_active_validators: Target,

#[target(out)]
pub number_of_exited_validators: Target,

#[target(out)]
pub number_of_slashed_validators: Target,
pub accumulated_data: AccumulatedValidatorsData,
}

pub struct WithdrawalCredentialsBalanceAggregatorFirstLevel<
Expand Down Expand Up @@ -119,7 +114,7 @@ where
let validators_leaves = input
.validators
.iter()
.zip(input.non_zero_validator_leaves_mask)
.zip_eq(input.non_zero_validator_leaves_mask)
.map(|(validator, is_not_zero)| {
hash_validator_poseidon_or_zeroes(builder, &validator, is_not_zero)
})
Expand All @@ -128,82 +123,83 @@ where
let validators_hash_tree_root_poseidon =
hash_tree_root_poseidon(builder, &validators_leaves);

let mut range_total_value = builder.zero_biguint();
let mut number_of_non_activated_validators = builder.zero();
let mut number_of_active_validators = builder.zero();
let mut number_of_exited_validators = builder.zero();
let mut number_of_slashed_validators = builder.zero();
let accumulated_data = accumulate_data(
builder,
&input.validators,
&input.balances_leaves,
&input.withdrawal_credentials,
&input.current_epoch,
);

for i in 0..VALIDATORS_COUNT {
let mut validator_is_considered = builder._false();

for j in 0..WITHDRAWAL_CREDENTIALS_COUNT {
let is_equal_inner = bool_arrays_are_equal(
builder,
&input.validators[i].withdrawal_credentials,
&input.withdrawal_credentials[j],
);

validator_is_considered = builder.or(is_equal_inner, validator_is_considered);
}
Self::Target {
validators: input.validators,
non_zero_validator_leaves_mask: input.non_zero_validator_leaves_mask,
withdrawal_credentials: input.withdrawal_credentials,
balances_leaves: input.balances_leaves,
current_epoch: input.current_epoch,
range_balances_root,
range_validator_commitment: validators_hash_tree_root_poseidon,
accumulated_data,
}
}
}

let balance = ssz_num_from_bits(
fn accumulate_data<F: RichField + Extendable<D>, const D: usize>(
builder: &mut CircuitBuilder<F, D>,
validators: &[ValidatorTarget],
balance_leaves: &[SSZTarget],
withdrawal_credentials: &[Sha256Target],
current_epoch: &BigUintTarget,
) -> AccumulatedValidatorsData {
let considered_validators_mask = validators
.iter()
.map(|validator| {
withdrawal_credentials
.iter()
.fold(builder._false(), |acc, credentials| {
let credentials_match =
builder.targets_are_equal(&validator.withdrawal_credentials, credentials);
builder.or(acc, credentials_match)
})
})
.collect_vec();

let balances = balance_leaves
.into_iter()
.flatten()
.copied()
.collect_vec()
.chunks(64)
.into_iter()
.map(|balance_bits| ssz_num_from_bits(builder, balance_bits))
.collect_vec();

let zero_accumulated_data: AccumulatedValidatorsData = builder.zero_init();

izip!(validators, &balances, considered_validators_mask).fold(
zero_accumulated_data,
|acc, (validator, balance, is_considered)| {
let (is_non_activated, is_active, is_exited) = get_validator_status(
builder,
&input.balances_leaves[i / 4][((i % 4) * 64)..(((i % 4) * 64) + 64)],
&validator.activation_epoch,
&current_epoch,
&validator.exit_epoch,
);

let zero = builder.zero_biguint();

let (is_non_activated_validator, is_valid_validator, is_exited_validator) =
get_validator_status(
builder,
&input.validators[i].activation_epoch,
&input.current_epoch,
&input.validators[i].exit_epoch,
);

let will_be_counted = builder.and(validator_is_considered, is_valid_validator);

let current = select_biguint(builder, will_be_counted, &balance, &zero);

range_total_value = builder.add_biguint(&range_total_value, &current);

number_of_active_validators =
builder.add(number_of_active_validators, will_be_counted.target);
let should_sum_balance = builder.and(is_considered, is_active);

let will_be_counted = builder.and(validator_is_considered, is_non_activated_validator);
let mut summed_balance = builder.add_biguint(&acc.balance, balance);
summed_balance.limbs.pop().unwrap();

number_of_non_activated_validators =
builder.add(number_of_non_activated_validators, will_be_counted.target);

let will_be_counted = builder.and(validator_is_considered, is_exited_validator);

number_of_exited_validators =
builder.add(number_of_exited_validators, will_be_counted.target);

let validator_is_considered_and_is_slashed =
builder.and(validator_is_considered, input.validators[i].slashed);
number_of_slashed_validators = builder.add(
number_of_slashed_validators,
validator_is_considered_and_is_slashed.target,
);

range_total_value.limbs.pop();
}
let new_accumulated_data = AccumulatedValidatorsData {
balance: builder.select_target(should_sum_balance, &summed_balance, &acc.balance),
non_activated_count: builder.add(acc.non_activated_count, is_non_activated.target),
active_count: builder.add(acc.active_count, is_active.target),
exited_count: builder.add(acc.exited_count, is_exited.target),
slashed_count: builder.add(acc.slashed_count, validator.slashed.target),
};

Self::Target {
non_zero_validator_leaves_mask: input.non_zero_validator_leaves_mask,
range_total_value,
range_balances_root,
range_validator_commitment: validators_hash_tree_root_poseidon,
validators: input.validators,
balances_leaves: input.balances_leaves,
withdrawal_credentials: input.withdrawal_credentials,
current_epoch: input.current_epoch,
number_of_non_activated_validators,
number_of_active_validators,
number_of_exited_validators,
number_of_slashed_validators,
}
}
builder.select_target(is_considered, &new_accumulated_data, &acc)
},
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ use plonky2::{
};
use plonky2_crypto::biguint::CircuitBuilderBiguint;

use super::first_level::AccumulatedValidatorsData;

pub struct WithdrawalCredentialsBalanceAggregatorInnerLevel<
const VALIDATORS_COUNT: usize,
const WITHDRAWAL_CREDENTIALS_COUNT: usize,
Expand Down Expand Up @@ -72,31 +74,31 @@ where
&r_input.range_balances_root,
);

let number_of_non_activated_validators = builder.add(
l_input.number_of_non_activated_validators,
r_input.number_of_non_activated_validators,
);

let number_of_active_validators = builder.add(
l_input.number_of_active_validators,
r_input.number_of_active_validators,
);

let number_of_exited_validators = builder.add(
l_input.number_of_exited_validators,
r_input.number_of_exited_validators,
let mut accumulated_balance = builder.add_biguint(
&l_input.accumulated_data.balance,
&r_input.accumulated_data.balance,
);

let number_of_slashed_validators = builder.add(
l_input.number_of_slashed_validators,
r_input.number_of_slashed_validators,
);

let mut range_total_value =
builder.add_biguint(&l_input.range_total_value, &r_input.range_total_value);

// pop carry
range_total_value.limbs.pop();
accumulated_balance.limbs.pop();

let accumulated_data = AccumulatedValidatorsData {
balance: accumulated_balance,
non_activated_count: builder.add(
l_input.accumulated_data.non_activated_count,
r_input.accumulated_data.non_activated_count,
),
active_count: builder.add(
l_input.accumulated_data.active_count,
r_input.accumulated_data.active_count,
),
exited_count: builder.add(
l_input.accumulated_data.exited_count,
r_input.accumulated_data.exited_count,
),
slashed_count: builder.add(
l_input.accumulated_data.slashed_count,
r_input.accumulated_data.slashed_count,
),
};

for i in 0..WITHDRAWAL_CREDENTIALS_COUNT {
connect_bool_arrays(
Expand All @@ -115,14 +117,10 @@ where
>,
> {
current_epoch: l_input.current_epoch,
range_total_value,
range_balances_root,
withdrawal_credentials: l_input.withdrawal_credentials,
range_validator_commitment,
number_of_non_activated_validators,
number_of_active_validators,
number_of_exited_validators,
number_of_slashed_validators,
accumulated_data,
};

output_target.register_public_inputs(builder);
Expand Down
Loading