diff --git a/src/frontend/dsl/cb.rs b/src/frontend/dsl/cb.rs index 5b2f577f..d41b041b 100644 --- a/src/frontend/dsl/cb.rs +++ b/src/frontend/dsl/cb.rs @@ -716,4 +716,51 @@ mod tests { matches!(v[1], Expr::Const(c) if c == 40u64.field())) && matches!(v[1], Expr::Const(c) if c == 10u64.field()))); } + + #[test] + fn test_constraint_from_queriable() { + // Create a Queriable instance and convert it to a Constraint + let queriable = Queriable::StepTypeNext(StepTypeHandler::new("test_step".to_owned())); + let constraint: Constraint = Constraint::from(queriable); + + assert_eq!(constraint.annotation, "test_step"); + assert!( + matches!(constraint.expr, Expr::Query(Queriable::StepTypeNext(s)) if + matches!(s, StepTypeHandler {id: _id, annotation: "test_step"})) + ); + assert!(matches!(constraint.typing, Typing::Boolean)); + } + + #[test] + fn test_constraint_from_expr() { + // Create an expression and convert it to a Constraint + let expr = >>::expr(&10) * 20u64.expr(); + let constraint: Constraint = Constraint::from(expr); + + // returns "10 * 20" + assert!(matches!(constraint.expr, Expr::Mul(v) if v.len() == 2 && + matches!(v[0], Expr::Const(c) if c == 10u64.field()) && + matches!(v[1], Expr::Const(c) if c == 20u64.field()))); + assert!(matches!(constraint.typing, Typing::Unknown)); + } + + #[test] + fn test_constraint_from_int() { + // Create an integer and convert it to a Constraint + let constraint: Constraint = Constraint::from(10); + + // returns "10" + assert!(matches!(constraint.expr, Expr::Const(c) if c == 10u64.field())); + assert!(matches!(constraint.typing, Typing::Unknown)); + } + + #[test] + fn test_constraint_from_bool() { + // Create a boolean and convert it to a Constraint + let constraint: Constraint = Constraint::from(true); + + assert_eq!(constraint.annotation, "0x1"); + assert!(matches!(constraint.expr, Expr::Const(c) if c == 1u64.field())); + assert!(matches!(constraint.typing, Typing::Unknown)); + } } diff --git a/src/frontend/dsl/mod.rs b/src/frontend/dsl/mod.rs index 303e66b9..ce93ca46 100644 --- a/src/frontend/dsl/mod.rs +++ b/src/frontend/dsl/mod.rs @@ -154,6 +154,8 @@ impl CircuitContext { self.circuit.last_step = Some(step_type.into().uuid()); } + /// Enforce the number of step instances by adding a constraint to the circuit. Takes a `usize` + /// parameter that represents the total number of steps. pub fn pragma_num_steps(&mut self, num_steps: usize) { self.circuit.num_steps = num_steps; } @@ -231,6 +233,7 @@ impl StepTypeContext { } /// DEPRECATED + // #[deprecated(note = "use step types setup for constraints instead")] pub fn constr>>(&mut self, constraint: C) { println!("DEPRECATED constr: use setup for constraints in step types"); @@ -241,6 +244,7 @@ impl StepTypeContext { } /// DEPRECATED + #[deprecated(note = "use step types setup for constraints instead")] pub fn transition>>(&mut self, constraint: C) { println!("DEPRECATED transition: use setup for constraints in step types"); @@ -430,28 +434,49 @@ pub mod sc; #[cfg(test)] mod tests { + use crate::sbpir::ForwardSignal; + use super::*; + fn setup_circuit_context() -> CircuitContext + where + F: Default, + TraceArgs: Default, + { + CircuitContext { + circuit: SBPIR::default(), + tables: Default::default(), + } + } + #[test] - fn test_disable_q_enable() { + fn test_circuit_default_initialization() { let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; - context.pragma_disable_q_enable(); + // Assert default values + assert!(circuit.step_types.is_empty()); + assert!(circuit.forward_signals.is_empty()); + assert!(circuit.shared_signals.is_empty()); + assert!(circuit.fixed_signals.is_empty()); + assert!(circuit.exposed.is_empty()); + assert!(circuit.annotations.is_empty()); + assert!(circuit.trace.is_none()); + assert!(circuit.first_step.is_none()); + assert!(circuit.last_step.is_none()); + assert!(circuit.num_steps == 0); + assert!(circuit.q_enable); + } + #[test] + fn test_disable_q_enable() { + let mut context = setup_circuit_context::(); + context.pragma_disable_q_enable(); assert!(!context.circuit.q_enable); } #[test] fn test_set_num_steps() { - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); context.pragma_num_steps(3); assert_eq!(context.circuit.num_steps, 3); @@ -460,14 +485,29 @@ mod tests { assert_eq!(context.circuit.num_steps, 0); } + #[test] + fn test_set_first_step() { + let mut context = setup_circuit_context::(); + + let step_type: StepTypeHandler = context.step_type("step_type"); + + context.pragma_first_step(step_type); + assert_eq!(context.circuit.first_step, Some(step_type.uuid())); + } + + #[test] + fn test_set_last_step() { + let mut context = setup_circuit_context::(); + + let step_type: StepTypeHandler = context.step_type("step_type"); + + context.pragma_last_step(step_type); + assert_eq!(context.circuit.last_step, Some(step_type.uuid())); + } + #[test] fn test_forward() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set forward signals let forward_a: Queriable = context.forward("forward_a"); @@ -479,14 +519,21 @@ mod tests { assert_eq!(context.circuit.forward_signals[1].uuid(), forward_b.uuid()); } + #[test] + fn test_adding_duplicate_signal_names() { + let mut context = setup_circuit_context::(); + context.forward("duplicate_name"); + context.forward("duplicate_name"); + // Assert how the system should behave. Does it override the previous signal, throw an + // error, or something else? + // TODO: Should we let the user know that they are adding a duplicate signal name? And let + // the circuit have two signals with the same name? + assert_eq!(context.circuit.forward_signals.len(), 2); + } + #[test] fn test_forward_with_phase() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set forward signals with specified phase context.forward_with_phase("forward_a", 1); @@ -500,12 +547,7 @@ mod tests { #[test] fn test_shared() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set shared signal let shared_a: Queriable = context.shared("shared_a"); @@ -517,12 +559,7 @@ mod tests { #[test] fn test_shared_with_phase() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set shared signal with specified phase context.shared_with_phase("shared_a", 2); @@ -534,12 +571,7 @@ mod tests { #[test] fn test_fixed() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set fixed signal context.fixed("fixed_a"); @@ -550,12 +582,7 @@ mod tests { #[test] fn test_expose() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // set forward signal and step to expose let forward_a: Queriable = context.forward("forward_a"); @@ -572,14 +599,21 @@ mod tests { ); } + #[test] + #[ignore] + #[should_panic(expected = "Signal not found")] + fn test_expose_non_existing_signal() { + let mut context = setup_circuit_context::(); + let non_existing_signal = + Queriable::Forward(ForwardSignal::new_with_phase(0, "".to_owned()), false); // Create a signal not added to the circuit + context.expose(non_existing_signal, ExposeOffset::First); + + todo!("remove the ignore after fixing the check for non existing signals") + } + #[test] fn test_step_type() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // create a step type let handler: StepTypeHandler = context.step_type("fibo_first_step"); @@ -593,12 +627,7 @@ mod tests { #[test] fn test_step_type_def() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // create a step type including its definition let simple_step = context.step_type_def("simple_step", |context| { @@ -619,12 +648,7 @@ mod tests { #[test] fn test_step_type_def_pass_handler() { - // create circuit context - let circuit: SBPIR = SBPIR::default(); - let mut context = CircuitContext { - circuit, - tables: Default::default(), - }; + let mut context = setup_circuit_context::(); // create a step type handler let handler: StepTypeHandler = context.step_type("simple_step"); @@ -645,4 +669,23 @@ mod tests { context.circuit.step_types[&simple_step.uuid()].uuid() ); } + + #[test] + fn test_trace() { + let mut context = setup_circuit_context::(); + + // set trace function + context.trace(|_, _: i32| {}); + + // assert trace function was set + assert!(context.circuit.trace.is_some()); + } + + #[test] + #[should_panic(expected = "circuit cannot have more than one trace generator")] + fn test_setting_trace_multiple_times() { + let mut context = setup_circuit_context::(); + context.trace(|_, _| {}); + context.trace(|_, _| {}); + } } diff --git a/src/frontend/dsl/sc.rs b/src/frontend/dsl/sc.rs index f9ef60d5..8e3fdcd2 100644 --- a/src/frontend/dsl/sc.rs +++ b/src/frontend/dsl/sc.rs @@ -18,6 +18,7 @@ use crate::{ use super::{lb::LookupTableRegistry, CircuitContext}; +#[derive(Debug)] pub struct SuperCircuitContext { super_circuit: SuperCircuit, sub_circuit_phase1: Vec>, @@ -120,3 +121,205 @@ where ctx.compile() } + +#[cfg(test)] +mod tests { + use halo2curves::{bn256::Fr, ff::PrimeField}; + + use crate::{ + plonkish::compiler::{ + cell_manager::SingleRowCellManager, config, step_selector::SimpleStepSelectorBuilder, + }, + poly::ToField, + }; + + use super::*; + + #[test] + fn test_super_circuit_context_default() { + let ctx = SuperCircuitContext::::default(); + + assert_eq!( + format!("{:#?}", ctx.super_circuit), + format!("{:#?}", SuperCircuit::::default()) + ); + assert_eq!( + format!("{:#?}", ctx.sub_circuit_phase1), + format!("{:#?}", Vec::>::default()) + ); + assert_eq!(ctx.sub_circuit_phase1.len(), 0); + assert_eq!( + format!("{:#?}", ctx.tables), + format!("{:#?}", LookupTableRegistry::::default()) + ); + } + + #[test] + fn test_super_circuit_context_sub_circuit() { + let mut ctx = SuperCircuitContext::::default(); + + fn simple_circuit(ctx: &mut CircuitContext, _: ()) { + use crate::frontend::dsl::cb::*; + + let x = ctx.forward("x"); + let y = ctx.forward("y"); + + let step_type = ctx.step_type_def("sum should be 10", |ctx| { + ctx.setup(move |ctx| { + ctx.constr(eq(x + y, 10)); + }); + + ctx.wg(move |ctx, (x_value, y_value): (u32, u32)| { + ctx.assign(x, x_value.field()); + ctx.assign(y, y_value.field()); + }) + }); + + ctx.pragma_num_steps(1); + + ctx.trace(move |ctx, ()| { + ctx.add(&step_type, (2, 8)); + }) + } + + // simple circuit to check if the sum of two inputs are 10 + ctx.sub_circuit( + config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}), + simple_circuit, + (), + ); + + // ensure phase 1 was done correctly for the sub circuit + assert_eq!(ctx.sub_circuit_phase1.len(), 1); + assert_eq!(ctx.sub_circuit_phase1[0].columns.len(), 4); + assert_eq!( + ctx.sub_circuit_phase1[0].columns[0].annotation, + "srcm forward x" + ); + assert_eq!( + ctx.sub_circuit_phase1[0].columns[1].annotation, + "srcm forward y" + ); + assert_eq!(ctx.sub_circuit_phase1[0].columns[2].annotation, "q_enable"); + assert_eq!( + ctx.sub_circuit_phase1[0].columns[3].annotation, + "'step selector for sum should be 10'" + ); + assert_eq!(ctx.sub_circuit_phase1[0].forward_signals.len(), 2); + assert_eq!(ctx.sub_circuit_phase1[0].step_types.len(), 1); + assert_eq!(ctx.sub_circuit_phase1[0].compilation_phase, 1); + } + + #[test] + fn test_super_circuit_compile() { + let mut ctx = SuperCircuitContext::::default(); + + fn simple_circuit(ctx: &mut CircuitContext, _: ()) { + use crate::frontend::dsl::cb::*; + + let x = ctx.forward("x"); + let y = ctx.forward("y"); + + let step_type = ctx.step_type_def("sum should be 10", |ctx| { + ctx.setup(move |ctx| { + ctx.constr(eq(x + y, 10)); + }); + + ctx.wg(move |ctx, (x_value, y_value): (u32, u32)| { + ctx.assign(x, x_value.field()); + ctx.assign(y, y_value.field()); + }) + }); + + ctx.pragma_num_steps(1); + + ctx.trace(move |ctx, ()| { + ctx.add(&step_type, (2, 8)); + }) + } + + // simple circuit to check if the sum of two inputs are 10 + ctx.sub_circuit( + config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}), + simple_circuit, + (), + ); + + let super_circuit = ctx.compile(); + + assert_eq!(super_circuit.get_sub_circuits().len(), 1); + assert_eq!(super_circuit.get_sub_circuits()[0].columns.len(), 4); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[0].annotation, + "srcm forward x" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[1].annotation, + "srcm forward y" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[2].annotation, + "q_enable" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[3].annotation, + "'step selector for sum should be 10'" + ); + } + + #[test] + fn test_super_circuit_sub_circuit_with_ast() { + use crate::frontend::dsl::circuit; + let mut ctx = SuperCircuitContext::::default(); + + let simple_circuit_with_ast = circuit("simple circuit", |ctx| { + use crate::frontend::dsl::cb::*; + + let x = ctx.forward("x"); + let y = ctx.forward("y"); + + let step_type = ctx.step_type_def("sum should be 10", |ctx| { + ctx.setup(move |ctx| { + ctx.constr(eq(x + y, 10)); + }); + + ctx.wg(move |ctx, (x_value, y_value): (u32, u32)| { + ctx.assign(x, x_value.field()); + ctx.assign(y, y_value.field()); + }) + }); + + ctx.pragma_num_steps(1); + + ctx.trace(move |ctx, ()| { + ctx.add(&step_type, (2, 8)); + }); + }); + + ctx.sub_circuit_with_ast( + config(SingleRowCellManager {}, SimpleStepSelectorBuilder {}), + simple_circuit_with_ast, + ); + + let super_circuit = ctx.compile(); + + assert_eq!(super_circuit.get_sub_circuits().len(), 1); + assert_eq!(super_circuit.get_sub_circuits()[0].columns.len(), 4); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[0].annotation, + "srcm forward x" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[1].annotation, + "srcm forward y" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[2].annotation, + "q_enable" + ); + assert_eq!( + super_circuit.get_sub_circuits()[0].columns[3].annotation, + "'step selector for sum should be 10'" + ); + } +} diff --git a/src/plonkish/compiler/mod.rs b/src/plonkish/compiler/mod.rs index 1e40c710..d1a5c9d1 100644 --- a/src/plonkish/compiler/mod.rs +++ b/src/plonkish/compiler/mod.rs @@ -567,10 +567,69 @@ fn add_halo2_columns(unit: &mut CompilationUnit, ast: &astCircu } #[cfg(test)] -mod tests { - use super::*; +mod test { + use halo2_proofs::plonk::Any; use halo2curves::bn256::Fr; + use super::{cell_manager::SingleRowCellManager, step_selector::SimpleStepSelectorBuilder, *}; + + #[test] + fn test_compiler_config_initialization() { + let cell_manager = SingleRowCellManager::default(); + let step_selector_builder = SimpleStepSelectorBuilder::default(); + + let config = config(cell_manager.clone(), step_selector_builder.clone()); + + assert_eq!( + format!("{:#?}", config.cell_manager), + format!("{:#?}", cell_manager) + ); + assert_eq!( + format!("{:#?}", config.step_selector_builder), + format!("{:#?}", step_selector_builder) + ); + } + + #[test] + fn test_compile() { + let cell_manager = SingleRowCellManager::default(); + let step_selector_builder = SimpleStepSelectorBuilder::default(); + let config = config(cell_manager, step_selector_builder); + + let mock_ast_circuit = astCircuit::::default(); + + let (circuit, assignment_generator) = compile(config, &mock_ast_circuit); + + assert_eq!(circuit.columns.len(), 1); + assert_eq!(circuit.exposed.len(), 0); + assert_eq!(circuit.polys.len(), 0); + assert_eq!(circuit.lookups.len(), 0); + assert_eq!(circuit.fixed_assignments.len(), 1); + assert_eq!(circuit.ast_id, mock_ast_circuit.id); + + assert!(assignment_generator.is_none()); + } + + #[test] + fn test_compile_phase1() { + let cell_manager = SingleRowCellManager::default(); + let step_selector_builder = SimpleStepSelectorBuilder::default(); + let config = config(cell_manager, step_selector_builder); + + let mock_ast_circuit = astCircuit::::default(); + + let (unit, assignment_generator) = compile_phase1(config, &mock_ast_circuit); + + assert_eq!(unit.columns.len(), 1); + assert_eq!(unit.exposed.len(), 0); + assert_eq!(unit.polys.len(), 0); + assert_eq!(unit.lookups.len(), 0); + assert_eq!(unit.fixed_assignments.len(), 0); + assert_eq!(unit.ast_id, mock_ast_circuit.id); + + assert!(assignment_generator.is_none()); + } + #[test] #[should_panic] fn test_compile_phase2_before_phase1() { @@ -578,4 +637,19 @@ mod tests { compile_phase2(&mut unit); } + + #[test] + fn test_add_default_columns() { + let mock_ast_circuit = astCircuit::::default(); + + let mut unit = CompilationUnit::from(&mock_ast_circuit); + add_default_columns(&mut unit); + + assert_eq!(unit.columns.len(), 1); + assert_eq!(unit.exposed.len(), 0); + assert_eq!(unit.polys.len(), 0); + assert_eq!(unit.lookups.len(), 0); + assert_eq!(unit.fixed_assignments.len(), 0); + assert_eq!(unit.ast_id, mock_ast_circuit.id); + } } diff --git a/src/plonkish/compiler/step_selector.rs b/src/plonkish/compiler/step_selector.rs index b4d88a6f..fbe0eb53 100644 --- a/src/plonkish/compiler/step_selector.rs +++ b/src/plonkish/compiler/step_selector.rs @@ -289,6 +289,99 @@ mod tests { } } + #[test] + fn test_default_step_selector() { + let unit = mock_compilation_unit::(); + assert_eq!(unit.selector.columns.len(), 0); + assert_eq!(unit.selector.selector_expr.len(), 0); + assert_eq!(unit.selector.selector_expr_not.len(), 0); + assert_eq!(unit.selector.selector_assignment.len(), 0); + } + + #[test] + fn test_select_step_selector() { + let mut unit = mock_compilation_unit::(); + let step_type = Rc::new(StepType::new(Uuid::nil().as_u128(), "StepType".to_string())); + unit.step_types.insert(step_type.uuid(), step_type.clone()); + + let builder = SimpleStepSelectorBuilder {}; + builder.build(&mut unit); + + let selector = &unit.selector; + let constraint = PolyExpr::Const(Fr::ONE); + + let step_uuid = step_type.uuid(); + let selector_expr = selector + .selector_expr + .get(&step_uuid) + .expect("Step not found") + .clone(); + let expected_expr = PolyExpr::Mul(vec![selector_expr, constraint.clone()]); + + assert_eq!( + format!("{:#?}", selector.select(step_uuid, &constraint)), + format!("{:#?}", expected_expr) + ); + } + + #[test] + fn test_next_step_selector() { + let mut unit = mock_compilation_unit::(); + let step_type = Rc::new(StepType::new(Uuid::nil().as_u128(), "StepType".to_string())); + unit.step_types.insert(step_type.uuid(), step_type.clone()); + + let builder = SimpleStepSelectorBuilder {}; + builder.build(&mut unit); + + let selector = &unit.selector; + let step_uuid = step_type.uuid(); + let step_height = 1; + let expected_expr = selector + .selector_expr + .get(&step_uuid) + .expect("Step not found") + .clone() + .rotate(step_height); + + assert_eq!( + format!("{:#?}", selector.next_expr(step_uuid, step_height as u32)), + format!("{:#?}", expected_expr) + ); + } + + #[test] + fn test_unselect_step_selector() { + let mut unit = mock_compilation_unit::(); + let step_type = Rc::new(StepType::new(Uuid::nil().as_u128(), "StepType".to_string())); + unit.step_types.insert(step_type.uuid(), step_type.clone()); + + let builder = SimpleStepSelectorBuilder {}; + builder.build(&mut unit); + + let selector = &unit.selector; + let step_uuid = step_type.uuid(); + let expected_expr = selector + .selector_expr_not + .get(&step_uuid) + .expect("Step not found") + .clone(); + + assert_eq!( + format!("{:#?}", selector.unselect(step_uuid)), + format!("{:#?}", expected_expr) + ); + } + + #[test] + fn test_simple_step_selector_builder() { + let builder = SimpleStepSelectorBuilder {}; + let mut unit = mock_compilation_unit::(); + + add_step_types_to_unit(&mut unit, 2); + builder.build(&mut unit); + assert_common_tests(&unit, 2); + } + #[test] fn test_log_n_selector_builder_3_step_types() { let builder = LogNSelectorBuilder {}; diff --git a/src/plonkish/ir/sc.rs b/src/plonkish/ir/sc.rs index 05cf4663..9da33c11 100644 --- a/src/plonkish/ir/sc.rs +++ b/src/plonkish/ir/sc.rs @@ -1,4 +1,4 @@ -use std::{collections::HashMap, hash::Hash, rc::Rc}; +use std::{collections::HashMap, fmt::Debug, hash::Hash, rc::Rc}; use crate::{field::Field, sbpir::SBPIR, util::UUID, wit_gen::TraceWitness}; @@ -7,6 +7,7 @@ use super::{ Circuit, }; +#[derive(Debug)] pub struct SuperCircuit { sub_circuits: Vec>, mapping: MappingGenerator, @@ -75,6 +76,7 @@ impl SuperCircuit { pub type SuperAssignments = HashMap>; pub type SuperTraceWitness = HashMap>; +#[derive(Clone)] pub struct MappingContext { assignments: SuperAssignments, trace_witnesses: SuperTraceWitness, @@ -130,6 +132,12 @@ impl Clone for MappingGenerator { } } +impl std::fmt::Debug for MappingGenerator { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "MappingGenerator") + } +} + impl Default for MappingGenerator { fn default() -> Self { Self { @@ -160,3 +168,169 @@ impl MappingGenerator { ctx.get_trace_witnesses() } } + +#[cfg(test)] +mod test { + use halo2curves::bn256::Fr; + + use crate::{ + plonkish::{ + compiler::{cell_manager::Placement, step_selector::StepSelector}, + ir::Column, + }, + util::uuid, + wit_gen::{AutoTraceGenerator, TraceGenerator}, + }; + + use super::*; + + #[test] + fn test_default() { + let super_circuit: SuperCircuit = Default::default(); + + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits), + format!("{:#?}", Vec::>::default()) + ); + assert_eq!( + format!("{:#?}", super_circuit.mapping), + format!("{:#?}", MappingGenerator::::default()) + ); + } + + #[test] + fn test_add_sub_circuit() { + let mut super_circuit: SuperCircuit = Default::default(); + + fn simple_circuit() -> Circuit { + let columns = vec![Column::advice('a', 0)]; + let exposed = vec![(Column::advice('a', 0), 2)]; + let polys = vec![]; + let lookups = vec![]; + let fixed_assignments = Default::default(); + + Circuit { + columns, + exposed, + polys, + lookups, + fixed_assignments, + id: uuid(), + ast_id: uuid(), + } + } + + let sub_circuit = simple_circuit(); + + super_circuit.add_sub_circuit(sub_circuit.clone()); + + assert_eq!(super_circuit.sub_circuits.len(), 1); + assert_eq!(super_circuit.sub_circuits[0].id, sub_circuit.id); + assert_eq!(super_circuit.sub_circuits[0].ast_id, sub_circuit.ast_id); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].columns), + format!("{:#?}", sub_circuit.columns) + ); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].exposed), + format!("{:#?}", sub_circuit.exposed) + ); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].polys), + format!("{:#?}", sub_circuit.polys) + ); + assert_eq!( + format!("{:#?}", super_circuit.sub_circuits[0].lookups), + format!("{:#?}", sub_circuit.lookups) + ); + } + + #[test] + fn test_get_sub_circuits() { + fn simple_circuit() -> Circuit { + let columns = vec![Column::advice('a', 0)]; + let exposed = vec![(Column::advice('a', 0), 2)]; + let polys = vec![]; + let lookups = vec![]; + let fixed_assignments = Default::default(); + + Circuit { + columns, + exposed, + polys, + lookups, + fixed_assignments, + id: uuid(), + ast_id: uuid(), + } + } + + let super_circuit: SuperCircuit = SuperCircuit { + sub_circuits: vec![simple_circuit()], + mapping: Default::default(), + sub_circuit_asts: Default::default(), + }; + + let sub_circuits = super_circuit.get_sub_circuits(); + + assert_eq!(sub_circuits.len(), 1); + assert_eq!(sub_circuits[0].id, super_circuit.sub_circuits[0].id); + } + + #[test] + fn test_mapping_context_default() { + let ctx = MappingContext::::default(); + + assert_eq!( + format!("{:#?}", ctx.assignments), + format!("{:#?}", SuperAssignments::::default()) + ); + } + + fn simple_assignment_generator() -> AssignmentGenerator { + AssignmentGenerator::new( + vec![Column::advice('a', 0)], + Placement { + forward: HashMap::new(), + shared: HashMap::new(), + fixed: HashMap::new(), + steps: HashMap::new(), + columns: vec![], + base_height: 0, + }, + StepSelector::default(), + TraceGenerator::default(), + AutoTraceGenerator::default(), + 1, + uuid(), + ) + } + + #[test] + fn test_mapping_context_map() { + let mut ctx = MappingContext::::default(); + + assert_eq!(ctx.assignments.len(), 0); + + let gen = simple_assignment_generator(); + + ctx.map(&gen, ()); + + assert_eq!(ctx.assignments.len(), 1); + } + + #[test] + fn test_mapping_context_map_with_witness() { + let mut ctx = MappingContext::::default(); + + let gen = simple_assignment_generator(); + + let witness = TraceWitness:: { + step_instances: vec![], + }; + + ctx.map_with_witness(&gen, witness); + + assert_eq!(ctx.assignments.len(), 1); + } +} diff --git a/src/poly/mod.rs b/src/poly/mod.rs index fbc61bd0..01c12582 100644 --- a/src/poly/mod.rs +++ b/src/poly/mod.rs @@ -322,4 +322,69 @@ mod test { assert_eq!(experiment.eval(&assignments), None) } + + #[test] + fn test_degree_expr() { + use super::Expr::*; + + let expr: Expr = + (Query("a") * Query("a")) + (Query("c") * Query("d")) - Const(Fr::ONE); + + assert_eq!(expr.degree(), 2); + + let expr: Expr = + (Query("a") * Query("a")) + (Query("c") * Query("d")) * Query("e"); + + assert_eq!(expr.degree(), 3); + } + + #[test] + fn test_expr_sum() { + use super::Expr::*; + + let lhs: Expr = Query("a") + Query("b"); + + let rhs: Expr = Query("c") + Query("d"); + + assert_eq!( + format!("({:?} + {:?})", lhs, rhs), + format!("{:?}", Sum(vec![lhs, rhs])) + ); + } + + #[test] + fn test_expr_mul() { + use super::Expr::*; + + let lhs: Expr = Query("a") * Query("b"); + + let rhs: Expr = Query("c") * Query("d"); + + assert_eq!( + format!("({:?} * {:?})", lhs, rhs), + format!("{:?}", Mul(vec![lhs, rhs])) + ); + } + + #[test] + fn test_expr_neg() { + use super::Expr::*; + + let expr: Expr = Query("a") + Query("b"); + + assert_eq!( + format!("(-{:?})", expr), + format!("{:?}", Neg(Box::new(expr))) + ); + + let lhs: Expr = Query("a") * Query("b"); + let rhs: Expr = Query("c") + Query("d"); + + let expr: Expr = lhs.clone() - rhs.clone(); + + assert_eq!( + format!("{:?}", Sum(vec![lhs, Neg(Box::new(rhs))])), + format!("{:?}", expr) + ); + } } diff --git a/src/sbpir/query.rs b/src/sbpir/query.rs index 82cf8f45..5701b0d1 100644 --- a/src/sbpir/query.rs +++ b/src/sbpir/query.rs @@ -211,4 +211,161 @@ mod tests { let expr5: Expr> = Expr::Pow(Box::new(Expr::Const(a)), 2); assert_eq!(format!("{:?}", expr5), "(0xa)^2"); } + + #[test] + fn test_next_for_forward_signal() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, false); + let next_queriable = queriable.next(); + + assert_eq!(next_queriable, Queriable::Forward(forward_signal, true)); + } + + #[test] + #[should_panic(expected = "jarrl: cannot rotate next(forward)")] + fn test_next_for_forward_signal_panic() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, true); + let _ = queriable.next(); // This should panic + } + + #[test] + fn test_next_for_shared_signal() { + let shared_signal = SharedSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Shared(shared_signal, 0); + let next_queriable = queriable.next(); + + assert_eq!(next_queriable, Queriable::Shared(shared_signal, 1)); + } + + #[test] + fn test_next_for_fixed_signal() { + let fixed_signal = FixedSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Fixed(fixed_signal, 0); + let next_queriable = queriable.next(); + + assert_eq!(next_queriable, Queriable::Fixed(fixed_signal, 1)); + } + + #[test] + #[should_panic(expected = "can only next a forward, shared, fixed, or halo2 column")] + fn test_next_for_internal_signal_panic() { + let internal_signal = InternalSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Internal(internal_signal); + let _ = queriable.next(); // This should panic + } + + #[test] + fn test_prev_for_shared_signal() { + let shared_signal = SharedSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Shared(shared_signal, 1); + let prev_queriable = queriable.prev(); + + assert_eq!(prev_queriable, Queriable::Shared(shared_signal, 0)); + } + + #[test] + fn test_prev_for_fixed_signal() { + let fixed_signal = FixedSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Fixed(fixed_signal, 1); + let prev_queriable = queriable.prev(); + + assert_eq!(prev_queriable, Queriable::Fixed(fixed_signal, 0)); + } + + #[test] + #[should_panic(expected = "can only prev a shared or fixed column")] + fn test_prev_for_forward_signal_panic() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, true); + let _ = queriable.prev(); // This should panic + } + + #[test] + #[should_panic(expected = "can only prev a shared or fixed column")] + fn test_prev_for_internal_signal_panic() { + let internal_signal = InternalSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Internal(internal_signal); + let _ = queriable.prev(); // This should panic + } + + #[test] + fn test_rot_for_shared_signal() { + let shared_signal = SharedSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Shared(shared_signal, 1); + let rot_queriable = queriable.rot(2); + + assert_eq!(rot_queriable, Queriable::Shared(shared_signal, 3)); + } + + #[test] + fn test_rot_for_fixed_signal() { + let fixed_signal = FixedSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Fixed(fixed_signal, 1); + let rot_queriable = queriable.rot(2); + + assert_eq!(rot_queriable, Queriable::Fixed(fixed_signal, 3)); + } + + #[test] + #[should_panic(expected = "can only rot a shared or fixed column")] + fn test_rot_for_forward_signal_panic() { + let forward_signal = ForwardSignal { + id: 0, + phase: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Forward(forward_signal, true); + let _ = queriable.rot(2); // This should panic + } + + #[test] + #[should_panic(expected = "can only rot a shared or fixed column")] + fn test_rot_for_internal_signal_panic() { + let internal_signal = InternalSignal { + id: 0, + annotation: "", + }; + let queriable: Queriable = Queriable::Internal(internal_signal); + let _ = queriable.rot(2); // This should panic + } }