Skip to content

Commit

Permalink
Refactor Brillig solver struct to capture the running state
Browse files Browse the repository at this point in the history
In preparation to support step-by-step execution of Brillig opcodes when run
inside an ACIR block.
  • Loading branch information
ggiraldez committed Oct 6, 2023
1 parent 2c87b27 commit 89bde9f
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 47 deletions.
130 changes: 87 additions & 43 deletions acvm-repo/acvm/src/pwg/brillig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,79 @@ use crate::{pwg::OpcodeNotSolvable, OpcodeResolutionError};

use super::{get_value, insert_value};

pub(super) struct BrilligSolver;
pub(super) enum BrilligSolverStatus {
Finished,
InProgress,
ForeignCallWait(ForeignCallWaitInfo),
}

impl BrilligSolver {
pub(super) fn solve<B: BlackBoxFunctionSolver>(
initial_witness: &mut WitnessMap,
brillig: &Brillig,
bb_solver: &B,
pub(super) struct BrilligSolver<'b, B: BlackBoxFunctionSolver> {
witness: &'b mut WitnessMap,
brillig: &'b Brillig,
acir_index: usize,
vm: VM<'b, B>,
}

impl<'b, B: BlackBoxFunctionSolver> BrilligSolver<'b, B> {
pub(super) fn build_or_skip(
initial_witness: &'b mut WitnessMap,
brillig: &'b Brillig,
bb_solver: &'b B,
acir_index: usize,
) -> Result<Option<ForeignCallWaitInfo>, OpcodeResolutionError> {
) -> Result<Option<Self>, OpcodeResolutionError> {
if Self::should_skip(initial_witness, brillig)? {
Self::zero_out_brillig_outputs(initial_witness, brillig)?;
return Ok(None);
}

let vm = Self::setup_vm(initial_witness, brillig, bb_solver)?;
Ok(Some(
Self {
witness: initial_witness,
brillig,
acir_index,
vm,
}
))
}

fn should_skip(witness: &mut WitnessMap, brillig: &Brillig) -> Result<bool, OpcodeResolutionError> {
// If the predicate is `None`, then we simply return the value 1
// If the predicate is `Some` but we cannot find a value, then we return stalled
let pred_value = match &brillig.predicate {
Some(pred) => get_value(pred, initial_witness),
Some(pred) => get_value(pred, witness),
None => Ok(FieldElement::one()),
}?;

// A zero predicate indicates the oracle should be skipped, and its outputs zeroed.
if pred_value.is_zero() {
Self::zero_out_brillig_outputs(initial_witness, brillig)?;
return Ok(None);
Ok(pred_value.is_zero())
}

/// Assigns the zero value to all outputs of the given [`Brillig`] bytecode.
fn zero_out_brillig_outputs(
initial_witness: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
for output in &brillig.outputs {
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, FieldElement::zero(), initial_witness)?;
}
BrilligOutputs::Array(witness_arr) => {
for witness in witness_arr {
insert_value(witness, FieldElement::zero(), initial_witness)?;
}
}
}
}
Ok(())
}

fn setup_vm(
witness: &mut WitnessMap,
brillig: &Brillig,
bb_solver: &'b B,
) -> Result<VM<'b, B>, OpcodeResolutionError> {
// Set input values
let mut input_register_values: Vec<Value> = Vec::new();
let mut input_memory: Vec<Value> = Vec::new();
Expand All @@ -45,7 +96,7 @@ impl BrilligSolver {
// If a certain expression is not solvable, we stall the ACVM and do not proceed with Brillig VM execution.
for input in &brillig.inputs {
match input {
BrilligInputs::Single(expr) => match get_value(expr, initial_witness) {
BrilligInputs::Single(expr) => match get_value(expr, witness) {
Ok(value) => input_register_values.push(value.into()),
Err(_) => {
return Err(OpcodeResolutionError::OpcodeNotSolvable(
Expand All @@ -57,7 +108,7 @@ impl BrilligSolver {
// Attempt to fetch all array input values
let memory_pointer = input_memory.len();
for expr in expr_arr.iter() {
match get_value(expr, initial_witness) {
match get_value(expr, witness) {
Ok(value) => input_memory.push(value.into()),
Err(_) => {
return Err(OpcodeResolutionError::OpcodeNotSolvable(
Expand All @@ -76,39 +127,32 @@ impl BrilligSolver {
// Instantiate a Brillig VM given the solved input registers and memory
// along with the Brillig bytecode, and any present foreign call results.
let input_registers = Registers::load(input_register_values);
let mut vm = VM::new(
Ok(VM::new(
input_registers,
input_memory,
brillig.bytecode.clone(),
brillig.foreign_call_results.clone(),
bb_solver,
);
))
}

pub(super) fn solve(&mut self) -> Result<BrilligSolverStatus, OpcodeResolutionError> {
// Run the Brillig VM on these inputs, bytecode, etc!
let vm_status = vm.process_opcodes();
while matches!(self.vm.process_opcode(), VMStatus::InProgress) {}

self.finish_execution()
}

pub(super) fn finish_execution(&mut self) -> Result<BrilligSolverStatus, OpcodeResolutionError> {
// Check the status of the Brillig VM.
// It may be finished, in-progress, failed, or may be waiting for results of a foreign call.
// Return the "resolution" to the caller who may choose to make subsequent calls
// (when it gets foreign call results for example).
let vm_status = self.vm.get_status();
match vm_status {
VMStatus::Finished => {
for (i, output) in brillig.outputs.iter().enumerate() {
let register_value = vm.get_registers().get(RegisterIndex::from(i));
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, register_value.to_field(), initial_witness)?;
}
BrilligOutputs::Array(witness_arr) => {
// Treat the register value as a pointer to memory
for (i, witness) in witness_arr.iter().enumerate() {
let value = &vm.get_memory()[register_value.to_usize() + i];
insert_value(witness, value.to_field(), initial_witness)?;
}
}
}
}
Ok(None)
self.write_brillig_outputs()?;
Ok(BrilligSolverStatus::Finished)
}
VMStatus::InProgress => unreachable!("Brillig VM has not completed execution"),
VMStatus::Failure { message, call_stack } => {
Expand All @@ -117,31 +161,31 @@ impl BrilligSolver {
call_stack: call_stack
.iter()
.map(|brillig_index| OpcodeLocation::Brillig {
acir_index,
acir_index: self.acir_index,
brillig_index: *brillig_index,
})
.collect(),
})
}
VMStatus::ForeignCallWait { function, inputs } => {
Ok(Some(ForeignCallWaitInfo { function, inputs }))
Ok(BrilligSolverStatus::ForeignCallWait(ForeignCallWaitInfo { function, inputs }))
}
}
}

/// Assigns the zero value to all outputs of the given [`Brillig`] bytecode.
fn zero_out_brillig_outputs(
initial_witness: &mut WitnessMap,
brillig: &Brillig,
) -> Result<(), OpcodeResolutionError> {
for output in &brillig.outputs {
fn write_brillig_outputs(&mut self) -> Result<(), OpcodeResolutionError> {
// Write VM execution results into the witness map
for (i, output) in self.brillig.outputs.iter().enumerate() {
let register_value = self.vm.get_registers().get(RegisterIndex::from(i));
match output {
BrilligOutputs::Simple(witness) => {
insert_value(witness, FieldElement::zero(), initial_witness)?;
insert_value(witness, register_value.to_field(), self.witness)?;
}
BrilligOutputs::Array(witness_arr) => {
for witness in witness_arr {
insert_value(witness, FieldElement::zero(), initial_witness)?;
// Treat the register value as a pointer to memory
for (i, witness) in witness_arr.iter().enumerate() {
let value = &self.vm.get_memory()[register_value.to_usize() + i];
insert_value(witness, value.to_field(), self.witness)?;
}
}
}
Expand Down
17 changes: 13 additions & 4 deletions acvm-repo/acvm/src/pwg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use acir::{
use acvm_blackbox_solver::BlackBoxResolutionError;

use self::{
arithmetic::ArithmeticSolver, brillig::BrilligSolver, directives::solve_directives,
arithmetic::ArithmeticSolver, brillig::{BrilligSolver, BrilligSolverStatus}, directives::solve_directives,
memory_op::MemoryOpSolver,
};
use crate::{BlackBoxFunctionSolver, Language};
Expand Down Expand Up @@ -258,13 +258,22 @@ impl<'backend, B: BlackBoxFunctionSolver> ACVM<'backend, B> {
solver.solve_memory_op(op, &mut self.witness_map, predicate)
}
Opcode::Brillig(brillig) => {
match BrilligSolver::solve(
let result = BrilligSolver::build_or_skip(
&mut self.witness_map,
brillig,
self.backend,
self.instruction_pointer,
) {
Ok(Some(foreign_call)) => return self.wait_for_foreign_call(foreign_call),
);
match result {
Ok(Some(mut solver)) => {
match solver.solve() {
Ok(BrilligSolverStatus::ForeignCallWait(foreign_call)) =>
return self.wait_for_foreign_call(foreign_call),
Ok(BrilligSolverStatus::InProgress) =>
unreachable!("Brillig solver still in progress"),
res => res.map(|_| ()),
}
}
res => res.map(|_| ()),
}
}
Expand Down
4 changes: 4 additions & 0 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ impl<'bb_solver, B: BlackBoxFunctionSolver> VM<'bb_solver, B> {
status
}

pub fn get_status(&self) -> VMStatus {
self.status.clone()
}

/// Sets the current status of the VM to Finished (completed execution).
fn finish(&mut self) -> VMStatus {
self.status(VMStatus::Finished)
Expand Down

0 comments on commit 89bde9f

Please sign in to comment.