diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index e205750a176..12b408760a7 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -26,38 +26,21 @@ pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> { } impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { - /// Constructs a solver for a Brillig block given the bytecode and initial - /// witness. If the block should be skipped entirely because its predicate - /// evaluates to false, zero out the block outputs and return Ok(None). - pub(super) fn build_or_skip<'w>( - initial_witness: &'w mut WitnessMap, - brillig: &'w Brillig, - bb_solver: &'b B, - acir_index: usize, - ) -> Result, OpcodeResolutionError> { - if Self::should_skip(initial_witness, brillig)? { - Self::zero_out_brillig_outputs(initial_witness, brillig)?; - return Ok(None); - } - - let vm = Self::build_vm(initial_witness, brillig, bb_solver)?; - Ok(Some(Self { vm, acir_index })) - } - - fn should_skip(witness: &WitnessMap, brillig: &Brillig) -> Result { - // If the predicate is `None`, then we simply return the value 1 + /// Evaluates if the Brillig block should be skipped entirely + pub(super) fn should_skip( + witness: &WitnessMap, + brillig: &Brillig, + ) -> Result { + // If the predicate is `None`, the block should never be skipped // If the predicate is `Some` but we cannot find a value, then we return stalled - let pred_value = match &brillig.predicate { - Some(pred) => get_value(pred, witness), - None => Ok(FieldElement::one()), - }?; - - // A zero predicate indicates the oracle should be skipped, and its outputs zeroed. - Ok(pred_value.is_zero()) + match &brillig.predicate { + Some(pred) => Ok(get_value(pred, witness)?.is_zero()), + None => Ok(false), + } } /// Assigns the zero value to all outputs of the given [`Brillig`] bytecode. - fn zero_out_brillig_outputs( + pub(super) fn zero_out_brillig_outputs( initial_witness: &mut WitnessMap, brillig: &Brillig, ) -> Result<(), OpcodeResolutionError> { @@ -76,11 +59,14 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { Ok(()) } - fn build_vm( - witness: &WitnessMap, + /// Constructs a solver for a Brillig block given the bytecode and initial + /// witness. + pub(super) fn new( + initial_witness: &mut WitnessMap, brillig: &Brillig, bb_solver: &'b B, - ) -> Result, OpcodeResolutionError> { + acir_index: usize, + ) -> Result { // Set input values let mut input_register_values: Vec = Vec::new(); let mut input_memory: Vec = Vec::new(); @@ -90,7 +76,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { // If a certain expression is not solvable, we stall the ACVM and do not proceed with Brillig VM execution. for input in &brillig.inputs { match input { - BrilligInputs::Single(expr) => match get_value(expr, witness) { + BrilligInputs::Single(expr) => match get_value(expr, initial_witness) { Ok(value) => input_register_values.push(value.into()), Err(_) => { return Err(OpcodeResolutionError::OpcodeNotSolvable( @@ -102,7 +88,7 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { // Attempt to fetch all array input values let memory_pointer = input_memory.len(); for expr in expr_arr.iter() { - match get_value(expr, witness) { + match get_value(expr, initial_witness) { Ok(value) => input_memory.push(value.into()), Err(_) => { return Err(OpcodeResolutionError::OpcodeNotSolvable( @@ -121,13 +107,14 @@ impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { // Instantiate a Brillig VM given the solved input registers and memory // along with the Brillig bytecode, and any present foreign call results. let input_registers = Registers::load(input_register_values); - Ok(VM::new( + let vm = VM::new( input_registers, input_memory, brillig.bytecode.clone(), brillig.foreign_call_results.clone(), bb_solver, - )) + ); + Ok(Self { vm, acir_index }) } pub(super) fn solve(&mut self) -> Result { diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index 1ac17b135c3..bd672906369 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -258,39 +258,10 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { let solver = self.block_solvers.entry(*block_id).or_default(); solver.solve_memory_op(op, &mut self.witness_map, predicate) } - Opcode::Brillig(brillig) => { - let witness = &mut self.witness_map; - // get the active Brillig solver, or try to build one if necessary - // (Brillig execution maybe bypassed by constraints) - let maybe_solver = match self.brillig_solver.as_mut() { - Some(solver) => Ok(Some(solver)), - None => BrilligSolver::build_or_skip( - witness, - brillig, - self.backend, - self.instruction_pointer, - ) - .map(|optional_solver| { - optional_solver.map(|solver| self.brillig_solver.insert(solver)) - }), - }; - match maybe_solver { - Ok(Some(solver)) => match solver.solve() { - Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) => { - return self.wait_for_foreign_call(foreign_call); - } - Ok(BrilligSolverStatus::InProgress) => { - unreachable!("Brillig solver still in progress") - } - Ok(BrilligSolverStatus::Finished) => { - // clear active Brillig solver and write execution outputs - self.brillig_solver.take().unwrap().finalize(witness, brillig) - } - res => res.map(|_| ()), - }, - res => res.map(|_| ()), - } - } + Opcode::Brillig(_) => match self.solve_brillig_opcode() { + Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call), + res => res.map(|_| ()), + }, }; match resolution { Ok(()) => { @@ -324,6 +295,40 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { } } } + + fn solve_brillig_opcode( + &mut self, + ) -> Result, OpcodeResolutionError> { + let brillig = match &self.opcodes[self.instruction_pointer] { + Opcode::Brillig(brillig) => brillig, + _ => unreachable!("Not executing a Brillig opcode"), + }; + let witness = &mut self.witness_map; + if BrilligSolver::::should_skip(witness, brillig)? { + BrilligSolver::::zero_out_brillig_outputs(witness, brillig).map(|_| None) + } else { + let mut solver = match self.brillig_solver.take() { + None => { + BrilligSolver::new(witness, brillig, self.backend, self.instruction_pointer)? + } + Some(solver) => solver, + }; + match solver.solve()? { + BrilligSolverStatus::ForeignCallWait(foreign_call) => { + _ = self.brillig_solver.insert(solver); + Ok(Some(foreign_call)) + } + BrilligSolverStatus::InProgress => { + unreachable!("Brillig solver still in progress") + } + BrilligSolverStatus::Finished => { + // Write execution outputs + solver.finalize(witness, brillig)?; + Ok(None) + } + } + } + } } // Returns the concrete value for a particular witness