From 1252b5fcc7ed56bb55e95745b83be6e556805397 Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Fri, 31 May 2024 12:38:01 +0100 Subject: [PATCH] feat: place return value witnesses directly after function arguments (#5142) # Description ## Problem\* Resolves #5104 ## Summary\* This PR preallocates some witnesses to hold the return values at the beginning of ACIR gen and then adds assertions to fill these witnesses with the return values. This ensures that the return values will be placed in the witness map directly after any function inputs (reasons for this being desirable are laid out in #5104) ## Additional Context ## Documentation\* Check one: - [x] No documentation needed. - [ ] Documentation included in this PR. - [ ] **[For Experimental Features]** Documentation to be submitted in a separate PR. # PR Checklist\* - [x] I have tested the changes locally. - [x] I have formatted the changes with [Prettier](https://prettier.io/) and/or `cargo fmt` on default settings. --------- Co-authored-by: jfecher --- .../src/ssa/acir_gen/acir_ir/acir_variable.rs | 13 +- .../ssa/acir_gen/acir_ir/generated_acir.rs | 8 -- .../noirc_evaluator/src/ssa/acir_gen/mod.rs | 122 +++++++++++------- 3 files changed, 75 insertions(+), 68 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs index 93200d4f841..93dd47afe68 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/acir_variable.rs @@ -255,7 +255,7 @@ impl AcirContext { } /// Converts an [`AcirVar`] to a [`Witness`] - fn var_to_witness(&mut self, var: AcirVar) -> Result { + pub(crate) fn var_to_witness(&mut self, var: AcirVar) -> Result { let expression = self.var_to_expression(var)?; let witness = if let Some(constant) = expression.to_const() { // Check if a witness has been assigned this value already, if so reuse it. @@ -1027,15 +1027,6 @@ impl AcirContext { Ok(remainder) } - /// Converts the `AcirVar` to a `Witness` if it hasn't been already, and appends it to the - /// `GeneratedAcir`'s return witnesses. - pub(crate) fn return_var(&mut self, acir_var: AcirVar) -> Result<(), InternalError> { - let return_var = self.get_or_create_witness_var(acir_var)?; - let witness = self.var_to_witness(return_var)?; - self.acir_ir.push_return_witness(witness); - Ok(()) - } - /// Constrains the `AcirVar` variable to be of type `NumericType`. pub(crate) fn range_constrain_var( &mut self, @@ -1538,9 +1529,11 @@ impl AcirContext { pub(crate) fn finish( mut self, inputs: Vec, + return_values: Vec, warnings: Vec, ) -> GeneratedAcir { self.acir_ir.input_witnesses = inputs; + self.acir_ir.return_witnesses = return_values; self.acir_ir.warnings = warnings; self.acir_ir } diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs index 6c79c0a228d..9a09e7c06ee 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/acir_ir/generated_acir.rs @@ -45,9 +45,6 @@ pub(crate) struct GeneratedAcir { opcodes: Vec>, /// All witness indices that comprise the final return value of the program - /// - /// Note: This may contain repeated indices, which is necessary for later mapping into the - /// abi's return type. pub(crate) return_witnesses: Vec, /// All witness indices which are inputs to the main function @@ -164,11 +161,6 @@ impl GeneratedAcir { fresh_witness } - - /// Adds a witness index to the program's return witnesses. - pub(crate) fn push_return_witness(&mut self, witness: Witness) { - self.return_witnesses.push(witness); - } } impl GeneratedAcir { diff --git a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs index 13677506d0b..6d7c5e570c1 100644 --- a/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs @@ -36,11 +36,7 @@ use acvm::acir::circuit::brillig::BrilligBytecode; use acvm::acir::circuit::{AssertionPayload, ErrorSelector, OpcodeLocation}; use acvm::acir::native_types::Witness; use acvm::acir::BlackBoxFunc; -use acvm::{ - acir::AcirField, - acir::{circuit::opcodes::BlockId, native_types::Expression}, - FieldElement, -}; +use acvm::{acir::circuit::opcodes::BlockId, acir::AcirField, FieldElement}; use fxhash::FxHashMap as HashMap; use im::Vector; use iter_extended::{try_vecmap, vecmap}; @@ -330,38 +326,10 @@ impl Ssa { bytecode: brillig.byte_code, }); - let runtime_types = self.functions.values().map(|function| function.runtime()); - for (acir, runtime_type) in acirs.iter_mut().zip(runtime_types) { - if matches!(runtime_type, RuntimeType::Acir(_)) { - generate_distinct_return_witnesses(acir); - } - } - Ok((acirs, brillig, self.error_selector_to_type)) } } -fn generate_distinct_return_witnesses(acir: &mut GeneratedAcir) { - // Create a witness for each return witness we have to guarantee that the return witnesses match the standard - // layout for serializing those types as if they were being passed as inputs. - // - // This is required for recursion as otherwise in situations where we cannot make use of the program's ABI - // (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're - // working with rather than following the standard ABI encoding rules. - // - // TODO: We're being conservative here by generating a new witness for every expression. - // This means that we're likely to get a number of constraints which are just renumbering witnesses. - // This can be tackled by: - // - Tracking the last assigned public input witness and only renumbering a witness if it is below this value. - // - Modifying existing constraints to rearrange their outputs so they are suitable - // - See: https://github.com/noir-lang/noir/pull/4467 - let distinct_return_witness = vecmap(acir.return_witnesses.clone(), |return_witness| { - acir.create_witness_for_expression(&Expression::from(return_witness)) - }); - - acir.return_witnesses = distinct_return_witness; -} - impl<'a> Context<'a> { fn new(shared_context: &'a mut SharedContext) -> Context<'a> { let mut acir_context = AcirContext::default(); @@ -422,6 +390,25 @@ impl<'a> Context<'a> { let dfg = &main_func.dfg; let entry_block = &dfg[main_func.entry_block()]; let input_witness = self.convert_ssa_block_params(entry_block.parameters(), dfg)?; + let num_return_witnesses = + self.get_num_return_witnesses(entry_block.unwrap_terminator(), dfg); + + // Create a witness for each return witness we have to guarantee that the return witnesses match the standard + // layout for serializing those types as if they were being passed as inputs. + // + // This is required for recursion as otherwise in situations where we cannot make use of the program's ABI + // (e.g. for `std::verify_proof` or the solidity verifier), we need extra knowledge about the program we're + // working with rather than following the standard ABI encoding rules. + // + // We allocate these witnesses now before performing ACIR gen for the rest of the program as the location of + // the function's return values can then be determined through knowledge of its ABI alone. + let return_witness_vars = + vecmap(0..num_return_witnesses, |_| self.acir_context.add_variable()); + + let return_witnesses = vecmap(&return_witness_vars, |return_var| { + let expr = self.acir_context.var_to_expression(*return_var).unwrap(); + expr.to_witness().expect("return vars should be witnesses") + }); self.data_bus = dfg.data_bus.to_owned(); let mut warnings = Vec::new(); @@ -429,8 +416,19 @@ impl<'a> Context<'a> { warnings.extend(self.convert_ssa_instruction(*instruction_id, dfg, ssa, brillig)?); } - warnings.extend(self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?); - Ok(self.acir_context.finish(input_witness, warnings)) + let (return_vars, return_warnings) = + self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?; + + // TODO: This is a naive method of assigning the return values to their witnesses as + // we're likely to get a number of constraints which are asserting one witness to be equal to another. + // + // We should search through the program and relabel these witnesses so we can remove this constraint. + for (witness_var, return_var) in return_witness_vars.iter().zip(return_vars) { + self.acir_context.assert_eq_var(*witness_var, return_var, None)?; + } + + warnings.extend(return_warnings); + Ok(self.acir_context.finish(input_witness, return_witnesses, warnings)) } fn convert_brillig_main( @@ -468,17 +466,13 @@ impl<'a> Context<'a> { )?; self.shared_context.insert_generated_brillig(main_func.id(), arguments, 0, code); - let output_vars: Vec<_> = output_values + let return_witnesses: Vec = output_values .iter() .flat_map(|value| value.clone().flatten()) - .map(|value| value.0) - .collect(); + .map(|(value, _)| self.acir_context.var_to_witness(value)) + .collect::>()?; - for acir_var in output_vars { - self.acir_context.return_var(acir_var)?; - } - - let generated_acir = self.acir_context.finish(witness_inputs, Vec::new()); + let generated_acir = self.acir_context.finish(witness_inputs, return_witnesses, Vec::new()); assert_eq!( generated_acir.opcodes().len(), @@ -1724,12 +1718,39 @@ impl<'a> Context<'a> { self.define_result(dfg, instruction, AcirValue::Var(result, typ)); } + /// Converts an SSA terminator's return values into their ACIR representations + fn get_num_return_witnesses( + &mut self, + terminator: &TerminatorInstruction, + dfg: &DataFlowGraph, + ) -> usize { + let return_values = match terminator { + TerminatorInstruction::Return { return_values, .. } => return_values, + // TODO(https://github.com/noir-lang/noir/issues/4616): Enable recursion on foldable/non-inlined ACIR functions + _ => unreachable!("ICE: Program must have a singular return"), + }; + + return_values.iter().fold(0, |acc, value_id| { + let is_databus = self + .data_bus + .return_data + .map_or(false, |return_databus| dfg[*value_id] == dfg[return_databus]); + + if is_databus { + // We do not return value for the data bus. + acc + } else { + acc + dfg.type_of_value(*value_id).flattened_size() + } + }) + } + /// Converts an SSA terminator's return values into their ACIR representations fn convert_ssa_return( &mut self, terminator: &TerminatorInstruction, dfg: &DataFlowGraph, - ) -> Result, RuntimeError> { + ) -> Result<(Vec, Vec), RuntimeError> { let (return_values, call_stack) = match terminator { TerminatorInstruction::Return { return_values, call_stack } => { (return_values, call_stack.clone()) @@ -1739,6 +1760,7 @@ impl<'a> Context<'a> { }; let mut has_constant_return = false; + let mut return_vars: Vec = Vec::new(); for value_id in return_values { let is_databus = self .data_bus @@ -1759,7 +1781,7 @@ impl<'a> Context<'a> { dfg, )?; } else { - self.acir_context.return_var(acir_var)?; + return_vars.push(acir_var); } } } @@ -1770,7 +1792,7 @@ impl<'a> Context<'a> { Vec::new() }; - Ok(warnings) + Ok((return_vars, warnings)) } /// Gets the cached `AcirVar` that was converted from the corresponding `ValueId`. If it does @@ -3079,8 +3101,8 @@ mod test { check_call_opcode( &func_with_nested_call_opcodes[1], 2, - vec![Witness(2), Witness(1)], - vec![Witness(3)], + vec![Witness(3), Witness(1)], + vec![Witness(4)], ); } @@ -3100,13 +3122,13 @@ mod test { for (expected_input, input) in expected_inputs.iter().zip(inputs) { assert_eq!( expected_input, input, - "Expected witness {expected_input:?} but got {input:?}" + "Expected input witness {expected_input:?} but got {input:?}" ); } for (expected_output, output) in expected_outputs.iter().zip(outputs) { assert_eq!( expected_output, output, - "Expected witness {expected_output:?} but got {output:?}" + "Expected output witness {expected_output:?} but got {output:?}" ); } }