Skip to content

Commit

Permalink
fix: Runtime brillig bigint id assignment (noir-lang#5369)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

Resolves noir-lang#5368

## Summary\*

Switches BigInd id assignment to be done in runtime, to fix assignment
within control flow.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[For Experimental Features]** Documentation to be submitted in a
separate PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.
  • Loading branch information
sirasistant authored Jul 1, 2024
1 parent 7ffccf7 commit a8928dd
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 61 deletions.
81 changes: 66 additions & 15 deletions acvm-repo/brillig_vm/src/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
op: &BlackBoxOp,
solver: &Solver,
memory: &mut Memory<F>,
bigint_solver: &mut BigIntSolver,
bigint_solver: &mut BrilligBigintSolver,
) -> Result<(), BlackBoxResolutionError> {
match op {
BlackBoxOp::AES128Encrypt { inputs, iv, key, outputs } => {
Expand Down Expand Up @@ -270,38 +270,44 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
BlackBoxOp::BigIntAdd { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntAdd)?;

let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntAdd)?;
memory.write(*output, new_id.into());
Ok(())
}
BlackBoxOp::BigIntSub { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntSub)?;

let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntSub)?;
memory.write(*output, new_id.into());
Ok(())
}
BlackBoxOp::BigIntMul { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntMul)?;

let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntMul)?;
memory.write(*output, new_id.into());
Ok(())
}
BlackBoxOp::BigIntDiv { lhs, rhs, output } => {
let lhs = memory.read(*lhs).try_into().unwrap();
let rhs = memory.read(*rhs).try_into().unwrap();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_op(lhs, rhs, output, BlackBoxFunc::BigIntDiv)?;

let new_id = bigint_solver.bigint_op(lhs, rhs, BlackBoxFunc::BigIntDiv)?;
memory.write(*output, new_id.into());
Ok(())
}
BlackBoxOp::BigIntFromLeBytes { inputs, modulus, output } => {
let input = read_heap_vector(memory, inputs);
let input: Vec<u8> = input.iter().map(|x| x.try_into().unwrap()).collect();
let modulus = read_heap_vector(memory, modulus);
let modulus: Vec<u8> = modulus.iter().map(|x| x.try_into().unwrap()).collect();
let output = memory.read(*output).try_into().unwrap();
bigint_solver.bigint_from_bytes(&input, &modulus, output)?;

let new_id = bigint_solver.bigint_from_bytes(&input, &modulus)?;
memory.write(*output, new_id.into());

Ok(())
}
BlackBoxOp::BigIntToLeBytes { input, output } => {
Expand Down Expand Up @@ -381,6 +387,46 @@ pub(crate) fn evaluate_black_box<F: AcirField, Solver: BlackBoxFunctionSolver<F>
}
}

/// Wrapper over the generic bigint solver to automatically assign bigint ids in brillig
#[derive(Default, Debug, Clone, PartialEq, Eq)]
pub(crate) struct BrilligBigintSolver {
bigint_solver: BigIntSolver,
last_id: u32,
}

impl BrilligBigintSolver {
pub(crate) fn create_bigint_id(&mut self) -> u32 {
let output = self.last_id;
self.last_id += 1;
output
}

pub(crate) fn bigint_from_bytes(
&mut self,
inputs: &[u8],
modulus: &[u8],
) -> Result<u32, BlackBoxResolutionError> {
let id = self.create_bigint_id();
self.bigint_solver.bigint_from_bytes(inputs, modulus, id)?;
Ok(id)
}

pub(crate) fn bigint_to_bytes(&self, input: u32) -> Result<Vec<u8>, BlackBoxResolutionError> {
self.bigint_solver.bigint_to_bytes(input)
}

pub(crate) fn bigint_op(
&mut self,
lhs: u32,
rhs: u32,
func: BlackBoxFunc,
) -> Result<u32, BlackBoxResolutionError> {
let id = self.create_bigint_id();
self.bigint_solver.bigint_op(lhs, rhs, id, func)?;
Ok(id)
}
}

fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc {
match op {
BlackBoxOp::AES128Encrypt { .. } => BlackBoxFunc::AES128Encrypt,
Expand Down Expand Up @@ -414,10 +460,10 @@ mod test {
brillig::{BlackBoxOp, MemoryAddress},
FieldElement,
};
use acvm_blackbox_solver::{BigIntSolver, StubbedBlackBoxSolver};
use acvm_blackbox_solver::StubbedBlackBoxSolver;

use crate::{
black_box::{evaluate_black_box, to_u8_vec, to_value_vec},
black_box::{evaluate_black_box, to_u8_vec, to_value_vec, BrilligBigintSolver},
HeapArray, HeapVector, Memory,
};

Expand All @@ -439,8 +485,13 @@ mod test {
output: HeapArray { pointer: 2.into(), size: 32 },
};

evaluate_black_box(&op, &StubbedBlackBoxSolver, &mut memory, &mut BigIntSolver::default())
.unwrap();
evaluate_black_box(
&op,
&StubbedBlackBoxSolver,
&mut memory,
&mut BrilligBigintSolver::default(),
)
.unwrap();

let result = memory.read_slice(MemoryAddress(result_pointer), 32);

Expand Down
6 changes: 3 additions & 3 deletions acvm-repo/brillig_vm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ use acir::brillig::{
HeapVector, MemoryAddress, Opcode, ValueOrArray,
};
use acir::AcirField;
use acvm_blackbox_solver::{BigIntSolver, BlackBoxFunctionSolver};
use acvm_blackbox_solver::BlackBoxFunctionSolver;
use arithmetic::{evaluate_binary_field_op, evaluate_binary_int_op, BrilligArithmeticError};
use black_box::evaluate_black_box;
use black_box::{evaluate_black_box, BrilligBigintSolver};
use num_bigint::BigUint;

// Re-export `brillig`.
Expand Down Expand Up @@ -88,7 +88,7 @@ pub struct VM<'a, F, B: BlackBoxFunctionSolver<F>> {
/// The solver for blackbox functions
black_box_solver: &'a B,
// The solver for big integers
bigint_solver: BigIntSolver,
bigint_solver: BrilligBigintSolver,
}

impl<'a, F: AcirField, B: BlackBoxFunctionSolver<F>> VM<'a, F, B> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -243,13 +243,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString>(
[BrilligVariable::SingleAddr(output), BrilligVariable::SingleAddr(modulus_id)],
) = (function_arguments, function_results)
{
prepare_bigint_output(
brillig_context,
lhs_modulus,
rhs_modulus,
output,
modulus_id,
);
prepare_bigint_output(brillig_context, lhs_modulus, rhs_modulus, modulus_id);
brillig_context.black_box_op_instruction(BlackBoxOp::BigIntAdd {
lhs: lhs.address,
rhs: rhs.address,
Expand All @@ -267,13 +261,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString>(
[BrilligVariable::SingleAddr(output), BrilligVariable::SingleAddr(modulus_id)],
) = (function_arguments, function_results)
{
prepare_bigint_output(
brillig_context,
lhs_modulus,
rhs_modulus,
output,
modulus_id,
);
prepare_bigint_output(brillig_context, lhs_modulus, rhs_modulus, modulus_id);
brillig_context.black_box_op_instruction(BlackBoxOp::BigIntSub {
lhs: lhs.address,
rhs: rhs.address,
Expand All @@ -291,13 +279,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString>(
[BrilligVariable::SingleAddr(output), BrilligVariable::SingleAddr(modulus_id)],
) = (function_arguments, function_results)
{
prepare_bigint_output(
brillig_context,
lhs_modulus,
rhs_modulus,
output,
modulus_id,
);
prepare_bigint_output(brillig_context, lhs_modulus, rhs_modulus, modulus_id);
brillig_context.black_box_op_instruction(BlackBoxOp::BigIntMul {
lhs: lhs.address,
rhs: rhs.address,
Expand All @@ -315,13 +297,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString>(
[BrilligVariable::SingleAddr(output), BrilligVariable::SingleAddr(modulus_id)],
) = (function_arguments, function_results)
{
prepare_bigint_output(
brillig_context,
lhs_modulus,
rhs_modulus,
output,
modulus_id,
);
prepare_bigint_output(brillig_context, lhs_modulus, rhs_modulus, modulus_id);
brillig_context.black_box_op_instruction(BlackBoxOp::BigIntDiv {
lhs: lhs.address,
rhs: rhs.address,
Expand All @@ -341,8 +317,6 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString>(
{
let inputs_vector = convert_array_or_vector(brillig_context, inputs, bb_func);
let modulus_vector = convert_array_or_vector(brillig_context, modulus, bb_func);
let output_id = brillig_context.get_new_bigint_id();
brillig_context.const_instruction(*output, F::from(output_id as u128));
brillig_context.black_box_op_instruction(BlackBoxOp::BigIntFromLeBytes {
inputs: inputs_vector.to_heap_vector(),
modulus: modulus_vector.to_heap_vector(),
Expand Down Expand Up @@ -447,7 +421,6 @@ fn prepare_bigint_output<F: AcirField + DebugToString>(
brillig_context: &mut BrilligContext<F>,
lhs_modulus: &SingleAddrVariable,
rhs_modulus: &SingleAddrVariable,
output: &SingleAddrVariable,
modulus_id: &SingleAddrVariable,
) {
// Check moduli
Expand All @@ -464,8 +437,6 @@ fn prepare_bigint_output<F: AcirField + DebugToString>(
Some("moduli should be identical in BigInt operation".to_string()),
);
brillig_context.deallocate_register(condition);
// Set output id
let output_id = brillig_context.get_new_bigint_id();
brillig_context.const_instruction(*output, F::from(output_id as u128));

brillig_context.mov_instruction(modulus_id.address, lhs_modulus.address);
}
8 changes: 0 additions & 8 deletions compiler/noirc_evaluator/src/brillig/brillig_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,6 @@ pub(crate) struct BrilligContext<F> {
next_section: usize,
/// IR printer
debug_show: DebugShow,
/// Counter for generating bigint ids in unconstrained functions
bigint_new_id: u32,
}

impl<F: AcirField + DebugToString> BrilligContext<F> {
Expand All @@ -105,15 +103,9 @@ impl<F: AcirField + DebugToString> BrilligContext<F> {
section_label: 0,
next_section: 1,
debug_show: DebugShow::new(enable_debug_trace),
bigint_new_id: 0,
}
}

pub(crate) fn get_new_bigint_id(&mut self) -> u32 {
let result = self.bigint_new_id;
self.bigint_new_id += 1;
result
}
/// Adds a brillig instruction to the brillig byte code
fn push_opcode(&mut self, opcode: BrilligOpcode<F>) {
self.obj.push_opcode(opcode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ impl<F: AcirField + DebugToString> BrilligContext<F> {
section_label: 0,
next_section: 1,
debug_show: DebugShow::new(false),
bigint_new_id: 0,
};

context.codegen_entry_point(&arguments, &return_parameters);
Expand Down

0 comments on commit a8928dd

Please sign in to comment.