From 89bde9f194f9ae1535a0efd64b84dd6143de6238 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gustavo=20Gir=C3=A1ldez?= Date: Wed, 4 Oct 2023 17:59:28 -0400 Subject: [PATCH] Refactor Brillig solver struct to capture the running state In preparation to support step-by-step execution of Brillig opcodes when run inside an ACIR block. --- acvm-repo/acvm/src/pwg/brillig.rs | 130 ++++++++++++++++++++---------- acvm-repo/acvm/src/pwg/mod.rs | 17 +++- acvm-repo/brillig_vm/src/lib.rs | 4 + 3 files changed, 104 insertions(+), 47 deletions(-) diff --git a/acvm-repo/acvm/src/pwg/brillig.rs b/acvm-repo/acvm/src/pwg/brillig.rs index 9b0ecd87492..54181921426 100644 --- a/acvm-repo/acvm/src/pwg/brillig.rs +++ b/acvm-repo/acvm/src/pwg/brillig.rs @@ -14,28 +14,79 @@ use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError}; use super::{get_value, insert_value}; -pub(super) struct BrilligSolver; +pub(super) enum BrilligSolverStatus { + Finished, + InProgress, + ForeignCallWait(ForeignCallWaitInfo), +} -impl BrilligSolver { - pub(super) fn solve( - initial_witness: &mut WitnessMap, - brillig: &Brillig, - bb_solver: &B, +pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> { + witness: &'b mut WitnessMap, + brillig: &'b Brillig, + acir_index: usize, + vm: VM<'b, B>, +} + +impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> { + pub(super) fn build_or_skip( + initial_witness: &'b mut WitnessMap, + brillig: &'b Brillig, + bb_solver: &'b B, acir_index: usize, - ) -> Result, OpcodeResolutionError> { + ) -> Result, OpcodeResolutionError> { + if Self::should_skip(initial_witness, brillig)? { + Self::zero_out_brillig_outputs(initial_witness, brillig)?; + return Ok(None); + } + + let vm = Self::setup_vm(initial_witness, brillig, bb_solver)?; + Ok(Some( + Self { + witness: initial_witness, + brillig, + acir_index, + vm, + } + )) + } + + fn should_skip(witness: &mut WitnessMap, brillig: &Brillig) -> Result { // If the predicate is `None`, then we simply return the value 1 // If the predicate is `Some` but we cannot find a value, then we return stalled let pred_value = match &brillig.predicate { - Some(pred) => get_value(pred, initial_witness), + Some(pred) => get_value(pred, witness), None => Ok(FieldElement::one()), }?; // A zero predicate indicates the oracle should be skipped, and its outputs zeroed. - if pred_value.is_zero() { - Self::zero_out_brillig_outputs(initial_witness, brillig)?; - return Ok(None); + Ok(pred_value.is_zero()) + } + + /// Assigns the zero value to all outputs of the given [`Brillig`] bytecode. + fn zero_out_brillig_outputs( + initial_witness: &mut WitnessMap, + brillig: &Brillig, + ) -> Result<(), OpcodeResolutionError> { + for output in &brillig.outputs { + match output { + BrilligOutputs::Simple(witness) => { + insert_value(witness, FieldElement::zero(), initial_witness)?; + } + BrilligOutputs::Array(witness_arr) => { + for witness in witness_arr { + insert_value(witness, FieldElement::zero(), initial_witness)?; + } + } + } } + Ok(()) + } + fn setup_vm( + witness: &mut WitnessMap, + brillig: &Brillig, + bb_solver: &'b B, + ) -> Result, OpcodeResolutionError> { // Set input values let mut input_register_values: Vec = Vec::new(); let mut input_memory: Vec = Vec::new(); @@ -45,7 +96,7 @@ impl BrilligSolver { // If a certain expression is not solvable, we stall the ACVM and do not proceed with Brillig VM execution. for input in &brillig.inputs { match input { - BrilligInputs::Single(expr) => match get_value(expr, initial_witness) { + BrilligInputs::Single(expr) => match get_value(expr, witness) { Ok(value) => input_register_values.push(value.into()), Err(_) => { return Err(OpcodeResolutionError::OpcodeNotSolvable( @@ -57,7 +108,7 @@ impl BrilligSolver { // Attempt to fetch all array input values let memory_pointer = input_memory.len(); for expr in expr_arr.iter() { - match get_value(expr, initial_witness) { + match get_value(expr, witness) { Ok(value) => input_memory.push(value.into()), Err(_) => { return Err(OpcodeResolutionError::OpcodeNotSolvable( @@ -76,39 +127,32 @@ impl BrilligSolver { // Instantiate a Brillig VM given the solved input registers and memory // along with the Brillig bytecode, and any present foreign call results. let input_registers = Registers::load(input_register_values); - let mut vm = VM::new( + Ok(VM::new( input_registers, input_memory, brillig.bytecode.clone(), brillig.foreign_call_results.clone(), bb_solver, - ); + )) + } + pub(super) fn solve(&mut self) -> Result { // Run the Brillig VM on these inputs, bytecode, etc! - let vm_status = vm.process_opcodes(); + while matches!(self.vm.process_opcode(), VMStatus::InProgress) {} + self.finish_execution() + } + + pub(super) fn finish_execution(&mut self) -> Result { // Check the status of the Brillig VM. // It may be finished, in-progress, failed, or may be waiting for results of a foreign call. // Return the "resolution" to the caller who may choose to make subsequent calls // (when it gets foreign call results for example). + let vm_status = self.vm.get_status(); match vm_status { VMStatus::Finished => { - for (i, output) in brillig.outputs.iter().enumerate() { - let register_value = vm.get_registers().get(RegisterIndex::from(i)); - match output { - BrilligOutputs::Simple(witness) => { - insert_value(witness, register_value.to_field(), initial_witness)?; - } - BrilligOutputs::Array(witness_arr) => { - // Treat the register value as a pointer to memory - for (i, witness) in witness_arr.iter().enumerate() { - let value = &vm.get_memory()[register_value.to_usize() + i]; - insert_value(witness, value.to_field(), initial_witness)?; - } - } - } - } - Ok(None) + self.write_brillig_outputs()?; + Ok(BrilligSolverStatus::Finished) } VMStatus::InProgress => unreachable!("Brillig VM has not completed execution"), VMStatus::Failure { message, call_stack } => { @@ -117,31 +161,31 @@ impl BrilligSolver { call_stack: call_stack .iter() .map(|brillig_index| OpcodeLocation::Brillig { - acir_index, + acir_index: self.acir_index, brillig_index: *brillig_index, }) .collect(), }) } VMStatus::ForeignCallWait { function, inputs } => { - Ok(Some(ForeignCallWaitInfo { function, inputs })) + Ok(BrilligSolverStatus::ForeignCallWait(ForeignCallWaitInfo { function, inputs })) } } } - /// Assigns the zero value to all outputs of the given [`Brillig`] bytecode. - fn zero_out_brillig_outputs( - initial_witness: &mut WitnessMap, - brillig: &Brillig, - ) -> Result<(), OpcodeResolutionError> { - for output in &brillig.outputs { + fn write_brillig_outputs(&mut self) -> Result<(), OpcodeResolutionError> { + // Write VM execution results into the witness map + for (i, output) in self.brillig.outputs.iter().enumerate() { + let register_value = self.vm.get_registers().get(RegisterIndex::from(i)); match output { BrilligOutputs::Simple(witness) => { - insert_value(witness, FieldElement::zero(), initial_witness)?; + insert_value(witness, register_value.to_field(), self.witness)?; } BrilligOutputs::Array(witness_arr) => { - for witness in witness_arr { - insert_value(witness, FieldElement::zero(), initial_witness)?; + // Treat the register value as a pointer to memory + for (i, witness) in witness_arr.iter().enumerate() { + let value = &self.vm.get_memory()[register_value.to_usize() + i]; + insert_value(witness, value.to_field(), self.witness)?; } } } diff --git a/acvm-repo/acvm/src/pwg/mod.rs b/acvm-repo/acvm/src/pwg/mod.rs index 3fcf1088225..7fc94433da8 100644 --- a/acvm-repo/acvm/src/pwg/mod.rs +++ b/acvm-repo/acvm/src/pwg/mod.rs @@ -11,7 +11,7 @@ use acir::{ use acvm_blackbox_solver::BlackBoxResolutionError; use self::{ - arithmetic::ArithmeticSolver, brillig::BrilligSolver, directives::solve_directives, + arithmetic::ArithmeticSolver, brillig::{BrilligSolver, BrilligSolverStatus}, directives::solve_directives, memory_op::MemoryOpSolver, }; use crate::{BlackBoxFunctionSolver, Language}; @@ -258,13 +258,22 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> { solver.solve_memory_op(op, &mut self.witness_map, predicate) } Opcode::Brillig(brillig) => { - match BrilligSolver::solve( + let result = BrilligSolver::build_or_skip( &mut self.witness_map, brillig, self.backend, self.instruction_pointer, - ) { - Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call), + ); + match result { + Ok(Some(mut solver)) => { + match solver.solve() { + Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) => + return self.wait_for_foreign_call(foreign_call), + Ok(BrilligSolverStatus::InProgress) => + unreachable!("Brillig solver still in progress"), + res => res.map(|_| ()), + } + } res => res.map(|_| ()), } } diff --git a/acvm-repo/brillig_vm/src/lib.rs b/acvm-repo/brillig_vm/src/lib.rs index e2c8ae6521a..eb31000203a 100644 --- a/acvm-repo/brillig_vm/src/lib.rs +++ b/acvm-repo/brillig_vm/src/lib.rs @@ -112,6 +112,10 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> { status } + pub fn get_status(&self) -> VMStatus { + self.status.clone() + } + /// Sets the current status of the VM to Finished (completed execution). fn finish(&mut self) -> VMStatus { self.status(VMStatus::Finished)