Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

devirgo style on phase 1 #83

Merged
merged 3 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions gkr-graph/src/circuit_graph_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ impl<E: ExtensionField> CircuitGraphBuilder<E> {
),
};
let old_num_instances = self.witness.node_witnesses[*id].n_instances();
// TODO find way to avoid expensive clone for wit_in
let new_instances = match pred {
PredType::PredWire(_) => {
let new_size = (old_num_instances * out[0].len()) / num_instances;
Expand Down
93 changes: 68 additions & 25 deletions gkr-graph/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ impl<E: ExtensionField> IOPProverState<E> {
expected_max_thread_id: usize,
) -> Result<IOPProof<E>, GKRGraphError> {
assert_eq!(target_evals.0.len(), circuit.targets.len());
assert_eq!(circuit_witness.node_witnesses.len(), circuit.nodes.len());

let mut output_evals = vec![vec![]; circuit.nodes.len()];
let mut wit_out_evals = circuit
Expand All @@ -36,10 +37,42 @@ impl<E: ExtensionField> IOPProverState<E> {
let gkr_proofs = izip!(&circuit.nodes, &circuit_witness.node_witnesses)
.rev()
.map(|(node, witness)| {
// println!("expected_max_thread_id {:?}", expected_max_thread_id);
let max_thread_id = witness.n_instances().min(expected_max_thread_id);
// println!("max_thread_id {:?}", max_thread_id);
let timer = std::time::Instant::now();

// sanity check for witness poly evaluation
if cfg!(debug_assertions) {

// TODO figure out a way to do sanity check on output_evals
// it doens't work for now because output evaluation
// might only take partial range of output layer witness
// assert!(output_evals[node.id].len() <= 1);
// if !output_evals[node.id].is_empty() {
// debug_assert_eq!(
// witness
// .output_layer_witness_ref()
// .instances
// .as_slice()
// .original_mle()
// .evaluate(&point_and_eval.point),
// point_and_eval.eval,
// "node_id {} output eval failed",
// node.id,
// );
// }

for (witness_id, point_and_eval) in wit_out_evals[node.id].iter().enumerate() {
let mle = witness.witness_out_ref()[witness_id]
.instances
.as_slice()
.original_mle();
debug_assert_eq!(
mle.evaluate(&point_and_eval.point),
point_and_eval.eval,
"node_id {} output eval failed",
node.id,
);
}
}
let (proof, input_claim) = GKRProverState::prove_parallel(
&node.circuit,
witness,
Expand All @@ -48,6 +81,7 @@ impl<E: ExtensionField> IOPProverState<E> {
max_thread_id,
transcript,
);

// println!(
// "Proving node {}, label {}, num_instances:{}, took {}s",
// node.id,
Expand All @@ -56,52 +90,61 @@ impl<E: ExtensionField> IOPProverState<E> {
// timer.elapsed().as_secs_f64()
// );

izip!(&node.preds, input_claim.point_and_evals)
izip!(&node.preds, &input_claim.point_and_evals)
.enumerate()
.for_each(|(wire_id, (pred, point_and_eval))| match pred {
.for_each(|(wire_id, (pred_type, point_and_eval))| match pred_type {
PredType::Source => {
debug_assert_eq!(
witness.witness_in_ref()[wire_id as usize]
// sanity check for input poly evaluation
if cfg!(debug_assertions) {
let input_layer_poly = witness.witness_in_ref()[wire_id]
.instances
.as_slice()
.original_mle()
.evaluate(&point_and_eval.point),
point_and_eval.eval
);
.original_mle();
debug_assert_eq!(
input_layer_poly.evaluate(&point_and_eval.point),
point_and_eval.eval,
"mismatch at node.id {:?} wire_id {:?}, input_claim.point_and_evals.point {:?}, node.preds {:?}",
node.id,
wire_id,
input_claim.point_and_evals[0].point,
node.preds
);
}
}
PredType::PredWire(out) | PredType::PredWireDup(out) => {
let point = match pred {
PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => {
let point = match pred_type {
PredType::PredWire(_) => point_and_eval.point.clone(),
PredType::PredWireDup(out) => {
let node_id = match out {
let pred_node_id = match out {
NodeOutputType::OutputLayer(id) => id,
NodeOutputType::WireOut(id, _) => id,
};
// Suppose the new point is
// [single_instance_slice ||
// new_instance_index_slice]. The old point
// is [single_instance_slices ||
// new_instance_index_slices[(new_instance_num_vars
// - old_instance_num_vars)..]]
let old_instance_num_vars = circuit_witness.node_witnesses
[*node_id]
// new_instance_index_slices[(instance_num_vars
// - pred_instance_num_vars)..]]
let pred_instance_num_vars = circuit_witness.node_witnesses
[*pred_node_id]
.instance_num_vars();
let new_instance_num_vars = witness.instance_num_vars();
let num_vars =
point_and_eval.point.len() - new_instance_num_vars;
let instance_num_vars = witness.instance_num_vars();
let num_vars = point_and_eval.point.len() - instance_num_vars;
[
point_and_eval.point[..num_vars].to_vec(),
point_and_eval.point[num_vars
+ (new_instance_num_vars - old_instance_num_vars)..]
+ (instance_num_vars - pred_instance_num_vars)..]
.to_vec(),
]
.concat()
}
_ => unreachable!(),
};
match out {
NodeOutputType::OutputLayer(id) => output_evals[*id]
.push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)),
match pred_out {
NodeOutputType::OutputLayer(id) => {
output_evals[*id]
.push(PointAndEval::new_from_ref(&point, &point_and_eval.eval))
},
NodeOutputType::WireOut(id, wire_id) => {
let evals = &mut wit_out_evals[*id][*wire_id as usize];
assert!(
Expand Down
90 changes: 47 additions & 43 deletions gkr-graph/src/verifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,52 +50,56 @@ impl<E: ExtensionField> IOPVerifierState<E> {

let new_instance_num_vars = aux_info.instance_num_vars[node.id];

izip!(&node.preds, input_claim.point_and_evals).for_each(|(pred, point_and_eval)| {
match pred {
PredType::Source => {
// TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations` for later PCS open?
}
PredType::PredWire(out) | PredType::PredWireDup(out) => {
let old_point = match pred {
PredType::PredWire(_) => point_and_eval.point.clone(),
PredType::PredWireDup(out) => {
let node_id = match out {
NodeOutputType::OutputLayer(id) => *id,
NodeOutputType::WireOut(id, _) => *id,
};
// Suppose the new point is
// [single_instance_slice ||
// new_instance_index_slice]. The old point
// is [single_instance_slices ||
// new_instance_index_slices[(new_instance_num_vars
// - old_instance_num_vars)..]]
let old_instance_num_vars = aux_info.instance_num_vars[node_id];
let num_vars = point_and_eval.point.len() - new_instance_num_vars;
[
point_and_eval.point[..num_vars].to_vec(),
point_and_eval.point[num_vars
+ (new_instance_num_vars - old_instance_num_vars)..]
.to_vec(),
]
.concat()
}
_ => unreachable!(),
};
match out {
NodeOutputType::OutputLayer(id) => output_evals[*id]
.push(PointAndEval::new_from_ref(&old_point, &point_and_eval.eval)),
NodeOutputType::WireOut(id, wire_id) => {
let evals = &mut wit_out_evals[*id][*wire_id as usize];
assert!(
evals.point.is_empty() && evals.eval.is_zero_vartime(),
"unimplemented",
);
*evals = PointAndEval::new(old_point, point_and_eval.eval);
izip!(&node.preds, input_claim.point_and_evals).for_each(
|(pred_type, point_and_eval)| {
match pred_type {
PredType::Source => {
// TODO: collect `(proof.point.clone(), *eval)` as `TargetEvaluations`
// for later PCS open?
}
PredType::PredWire(pred_out) | PredType::PredWireDup(pred_out) => {
let point = match pred_type {
PredType::PredWire(_) => point_and_eval.point.clone(),
PredType::PredWireDup(out) => {
let node_id = match out {
NodeOutputType::OutputLayer(id) => *id,
NodeOutputType::WireOut(id, _) => *id,
};
// Suppose the new point is
// [single_instance_slice ||
// new_instance_index_slice]. The old point
// is [single_instance_slices ||
// new_instance_index_slices[(new_instance_num_vars
// - old_instance_num_vars)..]]
let old_instance_num_vars = aux_info.instance_num_vars[node_id];
let num_vars =
point_and_eval.point.len() - new_instance_num_vars;
[
point_and_eval.point[..num_vars].to_vec(),
point_and_eval.point[num_vars
+ (new_instance_num_vars - old_instance_num_vars)..]
.to_vec(),
]
.concat()
}
_ => unreachable!(),
};
match pred_out {
NodeOutputType::OutputLayer(id) => output_evals[*id]
.push(PointAndEval::new_from_ref(&point, &point_and_eval.eval)),
NodeOutputType::WireOut(id, wire_id) => {
let evals = &mut wit_out_evals[*id][*wire_id as usize];
assert!(
evals.point.is_empty() && evals.eval.is_zero_vartime(),
"unimplemented",
);
*evals = PointAndEval::new(point, point_and_eval.eval);
}
}
}
}
}
});
},
);
}

Ok(())
Expand Down
4 changes: 1 addition & 3 deletions gkr/benches/keccak256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ use gkr::gadgets::keccak256::{keccak256_circuit, prove_keccak256, verify_keccak2
use goldilocks::GoldilocksExt2;
use sumcheck::util::is_power_of_2;

// cargo bench --bench keccak256 --features parallel --features flamegraph --package gkr -- --profile-time <secs>
cfg_if::cfg_if! {
if #[cfg(feature = "flamegraph")] {
criterion_group! {
Expand Down Expand Up @@ -48,8 +47,7 @@ fn bench_keccak256(c: &mut Criterion) {

#[cfg(feature = "non_pow2_rayon_thread")]
{
use sumcheck::local_thread_pool::create_local_pool_once;
use sumcheck::util::ceil_log2;
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
let max_thread_id = 1 << ceil_log2(RAYON_NUM_THREADS);
create_local_pool_once(1 << ceil_log2(RAYON_NUM_THREADS), true);
max_thread_id
Expand Down
3 changes: 1 addition & 2 deletions gkr/examples/keccak256.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ fn main() {

#[cfg(feature = "non_pow2_rayon_thread")]
{
use sumcheck::local_thread_pool::create_local_pool_once;
use sumcheck::util::ceil_log2;
use sumcheck::{local_thread_pool::create_local_pool_once, util::ceil_log2};
max_thread_id = 1 << ceil_log2(max_thread_id);
create_local_pool_once(max_thread_id, true);
}
Expand Down
13 changes: 7 additions & 6 deletions gkr/src/circuit/circuit_layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ impl<E: ExtensionField> Circuit<E> {
});
let segment = (
wire_ids_in_layer[in_cell_ids[0]],
wire_ids_in_layer[in_cell_ids[in_cell_ids.len() - 1]] + 1,
wire_ids_in_layer[in_cell_ids[in_cell_ids.len() - 1]] + 1, /* + 1 for exclusive
* last index */
);
match ty {
InType::Witness(wit_id) => {
Expand Down Expand Up @@ -258,9 +259,10 @@ impl<E: ExtensionField> Circuit<E> {
.push(output_subsets.update_wire_id(old_layer_id, old_wire_id));
}
OutType::AssertConst(constant) => {
let new_wire_id = output_subsets.update_wire_id(old_layer_id, old_wire_id);
output_assert_const.push(GateCIn {
idx_in: [],
idx_out: output_subsets.update_wire_id(old_layer_id, old_wire_id),
idx_out: new_wire_id,
scalar: ConstantType::Field(i64_to_field(constant)),
});
}
Expand Down Expand Up @@ -288,8 +290,7 @@ impl<E: ExtensionField> Circuit<E> {
} else {
let last_layer = &layers[(layer_id - 1) as usize];
if !last_layer.is_linear() || !layer.copy_to.is_empty() {
curr_sc_steps
.extend([SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2]);
curr_sc_steps.extend([SumcheckStepType::Phase1Step1]);
}
}

Expand Down Expand Up @@ -900,7 +901,7 @@ mod tests {
// Single input witness, therefore no input phase 2 steps.
assert_eq!(
circuit.layers[2].sumcheck_steps,
vec![SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2,]
vec![SumcheckStepType::Phase1Step1]
);
// There are only one incoming evals since the last layer is linear, and
// no subset evals. Therefore, there are no phase1 steps.
Expand Down Expand Up @@ -931,7 +932,7 @@ mod tests {
// Single input witness, therefore no input phase 2 steps.
assert_eq!(
circuit.layers[1].sumcheck_steps,
vec![SumcheckStepType::Phase1Step1, SumcheckStepType::Phase1Step2]
vec![SumcheckStepType::Phase1Step1]
);
// Output layer, single output witness, therefore no output phase 1 steps.
assert_eq!(
Expand Down
Loading
Loading