Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
Feature/improve unit tests coverage (#189)
Browse files Browse the repository at this point in the history
Added more unit tests to the following modules:
- poly
- ast/query
- dsl
- dsl/cb
- compiler/step_selector
- compiler
- super_circuit

Covers some of these issues:
#157 #102 #105
  • Loading branch information
rutefig authored Mar 21, 2024
1 parent ad07077 commit d139e1b
Show file tree
Hide file tree
Showing 8 changed files with 924 additions and 68 deletions.
47 changes: 47 additions & 0 deletions src/frontend/dsl/cb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -716,4 +716,51 @@ mod tests {
matches!(v[1], Expr::Const(c) if c == 40u64.field())) &&
matches!(v[1], Expr::Const(c) if c == 10u64.field())));
}

#[test]
fn test_constraint_from_queriable() {
// Create a Queriable instance and convert it to a Constraint
let queriable = Queriable::StepTypeNext(StepTypeHandler::new("test_step".to_owned()));
let constraint: Constraint<Fr> = Constraint::from(queriable);

assert_eq!(constraint.annotation, "test_step");
assert!(
matches!(constraint.expr, Expr::Query(Queriable::StepTypeNext(s)) if
matches!(s, StepTypeHandler {id: _id, annotation: "test_step"}))
);
assert!(matches!(constraint.typing, Typing::Boolean));
}

#[test]
fn test_constraint_from_expr() {
// Create an expression and convert it to a Constraint
let expr = <u64 as ToExpr<Fr, Queriable<Fr>>>::expr(&10) * 20u64.expr();
let constraint: Constraint<Fr> = Constraint::from(expr);

// returns "10 * 20"
assert!(matches!(constraint.expr, Expr::Mul(v) if v.len() == 2 &&
matches!(v[0], Expr::Const(c) if c == 10u64.field()) &&
matches!(v[1], Expr::Const(c) if c == 20u64.field())));
assert!(matches!(constraint.typing, Typing::Unknown));
}

#[test]
fn test_constraint_from_int() {
// Create an integer and convert it to a Constraint
let constraint: Constraint<Fr> = Constraint::from(10);

// returns "10"
assert!(matches!(constraint.expr, Expr::Const(c) if c == 10u64.field()));
assert!(matches!(constraint.typing, Typing::Unknown));
}

#[test]
fn test_constraint_from_bool() {
// Create a boolean and convert it to a Constraint
let constraint: Constraint<Fr> = Constraint::from(true);

assert_eq!(constraint.annotation, "0x1");
assert!(matches!(constraint.expr, Expr::Const(c) if c == 1u64.field()));
assert!(matches!(constraint.typing, Typing::Unknown));
}
}
173 changes: 108 additions & 65 deletions src/frontend/dsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ impl<F, TraceArgs> CircuitContext<F, TraceArgs> {
self.circuit.last_step = Some(step_type.into().uuid());
}

/// Enforce the number of step instances by adding a constraint to the circuit. Takes a `usize`
/// parameter that represents the total number of steps.
pub fn pragma_num_steps(&mut self, num_steps: usize) {
self.circuit.num_steps = num_steps;
}
Expand Down Expand Up @@ -231,6 +233,7 @@ impl<F> StepTypeContext<F> {
}

/// DEPRECATED
// #[deprecated(note = "use step types setup for constraints instead")]
pub fn constr<C: Into<Constraint<F>>>(&mut self, constraint: C) {
println!("DEPRECATED constr: use setup for constraints in step types");

Expand All @@ -241,6 +244,7 @@ impl<F> StepTypeContext<F> {
}

/// DEPRECATED
#[deprecated(note = "use step types setup for constraints instead")]
pub fn transition<C: Into<Constraint<F>>>(&mut self, constraint: C) {
println!("DEPRECATED transition: use setup for constraints in step types");

Expand Down Expand Up @@ -430,28 +434,49 @@ pub mod sc;

#[cfg(test)]
mod tests {
use crate::sbpir::ForwardSignal;

use super::*;

fn setup_circuit_context<F, TraceArgs>() -> CircuitContext<F, TraceArgs>
where
F: Default,
TraceArgs: Default,
{
CircuitContext {
circuit: SBPIR::default(),
tables: Default::default(),
}
}

#[test]
fn test_disable_q_enable() {
fn test_circuit_default_initialization() {
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};

context.pragma_disable_q_enable();
// Assert default values
assert!(circuit.step_types.is_empty());
assert!(circuit.forward_signals.is_empty());
assert!(circuit.shared_signals.is_empty());
assert!(circuit.fixed_signals.is_empty());
assert!(circuit.exposed.is_empty());
assert!(circuit.annotations.is_empty());
assert!(circuit.trace.is_none());
assert!(circuit.first_step.is_none());
assert!(circuit.last_step.is_none());
assert!(circuit.num_steps == 0);
assert!(circuit.q_enable);
}

#[test]
fn test_disable_q_enable() {
let mut context = setup_circuit_context::<i32, i32>();
context.pragma_disable_q_enable();
assert!(!context.circuit.q_enable);
}

#[test]
fn test_set_num_steps() {
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

context.pragma_num_steps(3);
assert_eq!(context.circuit.num_steps, 3);
Expand All @@ -460,14 +485,29 @@ mod tests {
assert_eq!(context.circuit.num_steps, 0);
}

#[test]
fn test_set_first_step() {
let mut context = setup_circuit_context::<i32, i32>();

let step_type: StepTypeHandler = context.step_type("step_type");

context.pragma_first_step(step_type);
assert_eq!(context.circuit.first_step, Some(step_type.uuid()));
}

#[test]
fn test_set_last_step() {
let mut context = setup_circuit_context::<i32, i32>();

let step_type: StepTypeHandler = context.step_type("step_type");

context.pragma_last_step(step_type);
assert_eq!(context.circuit.last_step, Some(step_type.uuid()));
}

#[test]
fn test_forward() {
// create circuit context
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

// set forward signals
let forward_a: Queriable<i32> = context.forward("forward_a");
Expand All @@ -479,14 +519,21 @@ mod tests {
assert_eq!(context.circuit.forward_signals[1].uuid(), forward_b.uuid());
}

#[test]
fn test_adding_duplicate_signal_names() {
let mut context = setup_circuit_context::<i32, i32>();
context.forward("duplicate_name");
context.forward("duplicate_name");
// Assert how the system should behave. Does it override the previous signal, throw an
// error, or something else?
// TODO: Should we let the user know that they are adding a duplicate signal name? And let
// the circuit have two signals with the same name?
assert_eq!(context.circuit.forward_signals.len(), 2);
}

#[test]
fn test_forward_with_phase() {
// create circuit context
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

// set forward signals with specified phase
context.forward_with_phase("forward_a", 1);
Expand All @@ -500,12 +547,7 @@ mod tests {

#[test]
fn test_shared() {
// create circuit context
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

// set shared signal
let shared_a: Queriable<i32> = context.shared("shared_a");
Expand All @@ -517,12 +559,7 @@ mod tests {

#[test]
fn test_shared_with_phase() {
// create circuit context
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

// set shared signal with specified phase
context.shared_with_phase("shared_a", 2);
Expand All @@ -534,12 +571,7 @@ mod tests {

#[test]
fn test_fixed() {
// create circuit context
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

// set fixed signal
context.fixed("fixed_a");
Expand All @@ -550,12 +582,7 @@ mod tests {

#[test]
fn test_expose() {
// create circuit context
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

// set forward signal and step to expose
let forward_a: Queriable<i32> = context.forward("forward_a");
Expand All @@ -572,14 +599,21 @@ mod tests {
);
}

#[test]
#[ignore]
#[should_panic(expected = "Signal not found")]
fn test_expose_non_existing_signal() {
let mut context = setup_circuit_context::<i32, i32>();
let non_existing_signal =
Queriable::Forward(ForwardSignal::new_with_phase(0, "".to_owned()), false); // Create a signal not added to the circuit
context.expose(non_existing_signal, ExposeOffset::First);

todo!("remove the ignore after fixing the check for non existing signals")
}

#[test]
fn test_step_type() {
// create circuit context
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

// create a step type
let handler: StepTypeHandler = context.step_type("fibo_first_step");
Expand All @@ -593,12 +627,7 @@ mod tests {

#[test]
fn test_step_type_def() {
// create circuit context
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

// create a step type including its definition
let simple_step = context.step_type_def("simple_step", |context| {
Expand All @@ -619,12 +648,7 @@ mod tests {

#[test]
fn test_step_type_def_pass_handler() {
// create circuit context
let circuit: SBPIR<i32, i32> = SBPIR::default();
let mut context = CircuitContext {
circuit,
tables: Default::default(),
};
let mut context = setup_circuit_context::<i32, i32>();

// create a step type handler
let handler: StepTypeHandler = context.step_type("simple_step");
Expand All @@ -645,4 +669,23 @@ mod tests {
context.circuit.step_types[&simple_step.uuid()].uuid()
);
}

#[test]
fn test_trace() {
let mut context = setup_circuit_context::<i32, i32>();

// set trace function
context.trace(|_, _: i32| {});

// assert trace function was set
assert!(context.circuit.trace.is_some());
}

#[test]
#[should_panic(expected = "circuit cannot have more than one trace generator")]
fn test_setting_trace_multiple_times() {
let mut context = setup_circuit_context::<i32, i32>();
context.trace(|_, _| {});
context.trace(|_, _| {});
}
}
Loading

0 comments on commit d139e1b

Please sign in to comment.