diff --git a/compiler/noirc_driver/src/lib.rs b/compiler/noirc_driver/src/lib.rs index a7e7e2d4e2..2646b13a33 100644 --- a/compiler/noirc_driver/src/lib.rs +++ b/compiler/noirc_driver/src/lib.rs @@ -141,6 +141,10 @@ pub struct CompileOptions { #[arg(long)] pub skip_brillig_constraints_check: bool, + /// Flag to turn off preprocessing functions during SSA passes. + #[arg(long)] + pub skip_preprocess_fns: bool, + /// Setting to decide on an inlining strategy for Brillig functions. /// A more aggressive inliner should generate larger programs but more optimized /// A less aggressive inliner should generate smaller programs @@ -679,6 +683,7 @@ pub fn compile_no_check( emit_ssa: if options.emit_ssa { Some(context.package_build_path.clone()) } else { None }, skip_underconstrained_check: options.skip_underconstrained_check, skip_brillig_constraints_check: options.skip_brillig_constraints_check, + skip_preprocess_fns: options.skip_preprocess_fns, inliner_aggressiveness: options.inliner_aggressiveness, max_bytecode_increase_percent: options.max_bytecode_increase_percent, }; diff --git a/compiler/noirc_evaluator/src/ssa.rs b/compiler/noirc_evaluator/src/ssa.rs index ed515bbe98..dcf401ab58 100644 --- a/compiler/noirc_evaluator/src/ssa.rs +++ b/compiler/noirc_evaluator/src/ssa.rs @@ -71,6 +71,9 @@ pub struct SsaEvaluatorOptions { /// Skip the missing Brillig call constraints check pub skip_brillig_constraints_check: bool, + /// Skip preprocessing functions. + pub skip_preprocess_fns: bool, + /// The higher the value, the more inlined Brillig functions will be. pub inliner_aggressiveness: i64, @@ -150,15 +153,24 @@ pub(crate) fn optimize_into_acir( /// Run all SSA passes. 
fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result { Ok(builder - .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions") + .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (1st)") .run_pass(Ssa::defunctionalize, "Defunctionalization") .run_pass(Ssa::remove_paired_rc, "Removing Paired rc_inc & rc_decs") + .run_pass( + |ssa| { + if options.skip_preprocess_fns { + return ssa; + } + ssa.preprocess_functions(options.inliner_aggressiveness) + }, + "Preprocessing Functions", + ) .run_pass(|ssa| ssa.inline_functions(options.inliner_aggressiveness), "Inlining (1st)") // Run mem2reg with the CFG separated into blocks .run_pass(Ssa::mem2reg, "Mem2Reg (1st)") .run_pass(Ssa::simplify_cfg, "Simplifying (1st)") .run_pass(Ssa::as_slice_optimization, "`as_slice` optimization") - .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions") + .run_pass(Ssa::remove_unreachable_functions, "Removing Unreachable Functions (2nd)") .try_run_pass( Ssa::evaluate_static_assert_and_assert_constant, "`static_assert` and `assert_constant`", @@ -188,7 +200,7 @@ fn optimize_all(builder: SsaBuilder, options: &SsaEvaluatorOptions) -> Result RuntimeType::Acir(InlineType::InlineAlways), RuntimeType::Brillig(_) => RuntimeType::Brillig(InlineType::InlineAlways), diff --git a/compiler/noirc_evaluator/src/ssa/opt/hint.rs b/compiler/noirc_evaluator/src/ssa/opt/hint.rs index 1326c2cc01..3f913614e7 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/hint.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/hint.rs @@ -19,6 +19,7 @@ mod tests { emit_ssa: None, skip_underconstrained_check: true, skip_brillig_constraints_check: true, + skip_preprocess_fns: true, inliner_aggressiveness: 0, max_bytecode_increase_percent: None, }; diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 7554ad64a9..d7066371d5 100644 --- 
a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -2,9 +2,10 @@ //! The purpose of this pass is to inline the instructions of each function call //! within the function caller. If all function calls are known, there will only //! be a single function remaining when the pass finishes. -use std::collections::{BTreeSet, HashSet, VecDeque}; +use std::collections::{BTreeMap, BTreeSet, HashSet, VecDeque}; use acvm::acir::AcirField; +use im::HashMap; use iter_extended::{btree_map, vecmap}; use crate::ssa::{ @@ -19,7 +20,6 @@ use crate::ssa::{ }, ssa_gen::Ssa, }; -use fxhash::FxHashMap as HashMap; /// An arbitrary limit to the maximum number of recursive call /// frames at any point in time. @@ -46,59 +46,85 @@ impl Ssa { /// This step should run after runtime separation, since it relies on the runtime of the called functions being final. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn inline_functions(self, aggressiveness: i64) -> Ssa { - let inline_sources = get_functions_to_inline_into(&self, false, aggressiveness); - Self::inline_functions_inner(self, &inline_sources, false) + let inline_infos = compute_inline_infos(&self, false, aggressiveness); + Self::inline_functions_inner(self, &inline_infos, false) } - // Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points + /// Run the inlining pass where functions marked with `InlineType::NoPredicates` as not entry points pub(crate) fn inline_functions_with_no_predicates(self, aggressiveness: i64) -> Ssa { - let inline_sources = get_functions_to_inline_into(&self, true, aggressiveness); - Self::inline_functions_inner(self, &inline_sources, true) + let inline_infos = compute_inline_infos(&self, true, aggressiveness); + Self::inline_functions_inner(self, &inline_infos, true) } fn inline_functions_inner( mut self, - inline_sources: &BTreeSet, + inline_infos: &InlineInfos, inline_no_predicates_functions: bool, ) 
-> Ssa { - // Note that we clear all functions other than those in `inline_sources`. - // If we decide to do partial inlining then we should change this to preserve those functions which still exist. - self.functions = btree_map(inline_sources, |entry_point| { - let should_inline_call = - |_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool { - let callee = &ssa.functions[&called_func_id]; - let caller_runtime = ssa.functions[entry_point].runtime(); - - match callee.runtime() { - RuntimeType::Acir(inline_type) => { - // If the called function is acir, we inline if it's not an entry point - - // If we have not already finished the flattening pass, functions marked - // to not have predicates should be preserved. - let preserve_function = - !inline_no_predicates_functions && callee.is_no_predicates(); - !inline_type.is_entry_point() && !preserve_function - } - RuntimeType::Brillig(_) => { - if caller_runtime.is_acir() { - // We never inline a brillig function into an ACIR function. - return false; - } - - // Avoid inlining recursive functions. - !inline_sources.contains(&called_func_id) - } - } - }; + let inline_targets = + inline_infos.iter().filter_map(|(id, info)| info.is_inline_target().then_some(*id)); + // NOTE: Functions are processed independently of each other, with the final mapping replacing the original, + // instead of inlining the "leaf" functions, moving up towards the entry point. + self.functions = btree_map(inline_targets, |entry_point| { + let function = &self.functions[&entry_point]; let new_function = - InlineContext::new(&self, *entry_point).inline_all(&self, &should_inline_call); - (*entry_point, new_function) + function.inlined(&self, inline_no_predicates_functions, inline_infos); + (entry_point, new_function) }); self } } +impl Function { + /// Create a new function which has the functions called by this one inlined into its body. 
+ pub(super) fn inlined( + &self, + ssa: &Ssa, + inline_no_predicates_functions: bool, + inline_infos: &InlineInfos, + ) -> Function { + let caller_runtime = self.runtime(); + + let should_inline_call = + |_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool { + // Do not inline self-recursive functions on the top level. + // Inlining a self-recursive function works when there is something to inline into + // by importing all the recursive blocks, but for the entry function there is no wrapper. + if called_func_id == self.id() { + return false; + } + let callee = &ssa.functions[&called_func_id]; + + match callee.runtime() { + RuntimeType::Acir(inline_type) => { + // If the called function is acir, we inline if it's not an entry point + + // If we have not already finished the flattening pass, functions marked + // to not have predicates should be preserved. + let preserve_function = + !inline_no_predicates_functions && callee.is_no_predicates(); + + !inline_type.is_entry_point() && !preserve_function + } + RuntimeType::Brillig(_) => { + if caller_runtime.is_acir() { + // We never inline a brillig function into an ACIR function. + return false; + } + // We inline if the called function wasn't ruled out as too costly or recursive. + inline_infos + .get(&called_func_id) + .map(|info| info.should_inline) + .unwrap_or_default() + } + } + }; + + InlineContext::new(ssa, self.id()).inline_all(ssa, &should_inline_call) + } +} + /// The context for the function inlining pass. /// /// This works using an internal FunctionBuilder to build a new main function from scratch. @@ -148,6 +174,8 @@ struct PerFunctionContext<'function> { } /// Utility function to find out the direct calls of a function. +/// +/// Returns the function IDs from all `Call` instructions without deduplication.
fn called_functions_vec(func: &Function) -> Vec { let mut called_function_ids = Vec::new(); for block_id in func.reachable_blocks() { @@ -165,32 +193,61 @@ fn called_functions_vec(func: &Function) -> Vec { called_function_ids } -/// Utility function to find out the deduplicated direct calls of a function. +/// Utility function to find out the deduplicated direct calls made from a function. fn called_functions(func: &Function) -> BTreeSet { called_functions_vec(func).into_iter().collect() } +/// Information about a function to aid the decision about whether to inline it or not. +/// The final decision depends on what we're inlining it into. +#[derive(Default, Debug)] +pub(super) struct InlineInfo { + is_brillig_entry_point: bool, + is_acir_entry_point: bool, + is_recursive: bool, + should_inline: bool, + weight: i64, + cost: i64, +} + +impl InlineInfo { + /// Functions which are to be retained, not inlined. + pub(super) fn is_inline_target(&self) -> bool { + self.is_brillig_entry_point + || self.is_acir_entry_point + || self.is_recursive + || !self.should_inline + } +} + +type InlineInfos = BTreeMap; + /// The functions we should inline into (and that should be left in the final program) are: /// - main /// - Any Brillig function called from Acir /// - Some Brillig functions depending on aggressiveness and some metrics /// - Any Acir functions with a [fold inline type][InlineType::Fold], -fn get_functions_to_inline_into( +/// +/// The returned `InlineInfos` won't have every function in it, only the ones which the algorithm visited. 
+pub(super) fn compute_inline_infos( ssa: &Ssa, inline_no_predicates_functions: bool, aggressiveness: i64, -) -> BTreeSet { - let mut brillig_entry_points = BTreeSet::default(); - let mut acir_entry_points = BTreeSet::default(); - - if matches!(ssa.main().runtime(), RuntimeType::Brillig(_)) { - brillig_entry_points.insert(ssa.main_id); - } else { - acir_entry_points.insert(ssa.main_id); - } +) -> InlineInfos { + let mut inline_infos = InlineInfos::default(); + + inline_infos.insert( + ssa.main_id, + InlineInfo { + is_acir_entry_point: ssa.main().runtime().is_acir(), + is_brillig_entry_point: ssa.main().runtime().is_brillig(), + ..Default::default() + }, + ); + // Handle ACIR functions. for (func_id, function) in ssa.functions.iter() { - if matches!(function.runtime(), RuntimeType::Brillig(_)) { + if function.runtime().is_brillig() { continue; } @@ -198,83 +255,197 @@ fn get_functions_to_inline_into( // to not have predicates should be preserved. let preserve_function = !inline_no_predicates_functions && function.is_no_predicates(); if function.runtime().is_entry_point() || preserve_function { - acir_entry_points.insert(*func_id); + inline_infos.entry(*func_id).or_default().is_acir_entry_point = true; } - for called_function_id in called_functions(function) { - if matches!(ssa.functions[&called_function_id].runtime(), RuntimeType::Brillig(_)) { - brillig_entry_points.insert(called_function_id); + // Any Brillig function called from ACIR is an entry into the Brillig VM. 
+ for called_func_id in called_functions(function) { + if ssa.functions[&called_func_id].runtime().is_brillig() { + inline_infos.entry(called_func_id).or_default().is_brillig_entry_point = true; + } + } + } - let times_called = compute_times_called(ssa); + let callers = compute_callers(ssa); + let times_called = compute_times_called(&callers); - let brillig_functions_to_retain: BTreeSet<_> = compute_functions_to_retain( + mark_brillig_functions_to_retain( ssa, - &brillig_entry_points, - &times_called, inline_no_predicates_functions, aggressiveness, + &times_called, + &mut inline_infos, ); - acir_entry_points - .into_iter() - .chain(brillig_entry_points) - .chain(brillig_functions_to_retain) + inline_infos +} + +/// Compute the times each function is called from any other function. +fn compute_times_called( + callers: &BTreeMap>, +) -> HashMap { + callers + .iter() + .map(|(callee, callers)| { + let total_calls = callers.values().sum(); + (*callee, total_calls) + }) .collect() } -fn compute_times_called(ssa: &Ssa) -> HashMap { +/// Compute for each function the set of functions that call it, and how many times they do so. fn compute_callers(ssa: &Ssa) -> BTreeMap> { ssa.functions .iter() - .flat_map(|(_caller_id, function)| { - let called_functions_vec = called_functions_vec(function); - called_functions_vec.into_iter() + .flat_map(|(caller_id, function)| { + let called_functions = called_functions_vec(function); + called_functions.into_iter().map(|callee_id| (*caller_id, callee_id)) }) - .chain(std::iter::once(ssa.main_id)) - .fold(HashMap::default(), |mut map, func_id| { - *map.entry(func_id).or_insert(0) += 1; - map + .fold( + // Make sure an entry exists even for ones that don't get called.
+ ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(), + |mut acc, (caller_id, callee_id)| { + let callers = acc.entry(callee_id).or_default(); + *callers.entry(caller_id).or_default() += 1; + acc + }, + ) +} + +/// Compute for each function the set of functions called by it, and how many times it does so. +fn compute_callees(ssa: &Ssa) -> BTreeMap> { + ssa.functions + .iter() + .flat_map(|(caller_id, function)| { + let called_functions = called_functions_vec(function); + called_functions.into_iter().map(|callee_id| (*caller_id, callee_id)) }) + .fold( + // Make sure an entry exists even for ones that don't call anything. + ssa.functions.keys().map(|id| (*id, BTreeMap::new())).collect(), + |mut acc, (caller_id, callee_id)| { + let callees = acc.entry(caller_id).or_default(); + *callees.entry(callee_id).or_default() += 1; + acc + }, + ) } -fn should_retain_recursive( +/// Compute something like a topological order of the functions, starting with the ones +/// that do not call any other functions, going towards the entry points. When cycles +/// are detected, take the one which are called by the most to break the ties. +/// +/// This can be used to simplify the most often called functions first. +/// +/// Returns the functions paired with their transitive weight, which accumulates +/// the weight of all the functions they call. +pub(super) fn compute_bottom_up_order(ssa: &Ssa) -> Vec<(FunctionId, usize)> { + let mut order = Vec::new(); + let mut visited = HashSet::new(); + + // Call graph which we'll repeatedly prune to find the "leaves". + let mut callees = compute_callees(ssa); + let callers = compute_callers(ssa); + + // Number of times a function is called, to break cycles in the call graph. + let mut times_called = compute_times_called(&callers).into_iter().collect::>(); + // Sort by number of calls ascending, so popping yields the next most called; break ties by ID. 
+ times_called.sort_by_key(|(id, cnt)| (*cnt, *id)); + + // Start with the weight of the functions in isolation, then accumulate as we pop off the ones they call. + let mut weights = ssa + .functions + .iter() + .map(|(id, f)| (*id, compute_function_own_weight(f))) + .collect::>(); + + // Seed the queue with functions that don't call anything. + let mut queue = callees + .iter() + .filter_map(|(id, callees)| callees.is_empty().then_some(*id)) + .collect::>(); + + loop { + while let Some(id) = queue.pop_front() { + let weight = weights[&id]; + order.push((id, weight)); + visited.insert(id); + // Update the callers of this function. + for (caller, call_count) in &callers[&id] { + // Update the weight of the caller with the weight of this function. + weights[caller] = weights[caller].saturating_add(call_count.saturating_mul(weight)); + // Remove this function from the callees of the caller. + let callees = callees.get_mut(caller).unwrap(); + callees.remove(&id); + // If the caller doesn't call any other function, enqueue it. + if callees.is_empty() && !visited.contains(caller) { + queue.push_back(*caller); + } + } + } + // If we ran out of the queue, maybe there is a cycle; take the next most called function. 
+ while let Some((id, _)) = times_called.pop() { + if !visited.contains(&id) { + queue.push_back(id); + break; + } + } + if times_called.is_empty() && queue.is_empty() { + assert_eq!(order.len(), callers.len()); + return order; + } + } +} + +/// Traverse the call graph starting from a given function, marking functions to be retained if they are: +/// * recursive functions, or +/// * the cost of inlining outweighs the cost of not doing so +fn mark_functions_to_retain_recursive( ssa: &Ssa, - func: FunctionId, - times_called: &HashMap, - should_retain_function: &mut HashMap, - mut explored_functions: im::HashSet, inline_no_predicates_functions: bool, aggressiveness: i64, + times_called: &HashMap, + inline_infos: &mut InlineInfos, + mut explored_functions: im::HashSet, + func: FunctionId, ) { - // We have already decided on this function - if should_retain_function.get(&func).is_some() { + // Check if we have set any of the fields this method touches. + let decided = |inline_infos: &InlineInfos| { + inline_infos + .get(&func) + .map(|info| info.is_recursive || info.should_inline || info.weight != 0) + .unwrap_or_default() + }; + + // Check if we have already decided on this function + if decided(inline_infos) { return; } - // Recursive, this function won't be inlined + + // If recursive, this function won't be inlined if explored_functions.contains(&func) { - should_retain_function.insert(func, (true, 0)); + inline_infos.entry(func).or_default().is_recursive = true; return; } explored_functions.insert(func); - // Decide on dependencies first - let called_functions = called_functions(&ssa.functions[&func]); - for function in called_functions.iter() { - should_retain_recursive( + // Decide on dependencies first, so we know their weight.
+ let called_functions = called_functions_vec(&ssa.functions[&func]); + for callee in &called_functions { + mark_functions_to_retain_recursive( ssa, - *function, - times_called, - should_retain_function, - explored_functions.clone(), inline_no_predicates_functions, aggressiveness, + times_called, + inline_infos, + explored_functions.clone(), + *callee, ); } + // We could have decided on this function while deciding on dependencies - // If the function is recursive - if should_retain_function.get(&func).is_some() { + // if the function is recursive. + if decided(inline_infos) { return; } @@ -282,13 +453,18 @@ fn should_retain_recursive( // We compute the weight (roughly the number of instructions) of the function after inlining // And the interface cost of the function (the inherent cost at the callsite, roughly the number of args and returns) // We then can compute an approximation of the cost of inlining vs the cost of retaining the function - // We do this computation using saturating i64s to avoid overflows - let inlined_function_weights: i64 = called_functions.iter().fold(0, |acc, called_function| { - let (should_retain, weight) = should_retain_function[called_function]; - if should_retain { - acc + // We do this computation using saturating i64s to avoid overflows, + // and because we want to calculate a difference which can be negative. + + // Total weight of functions called by this one, unless we decided not to inline them. + // Callees which appear multiple times would be inlined multiple times. + let inlined_function_weights: i64 = called_functions.iter().fold(0, |acc, callee| { + let info = &inline_infos[callee]; + // If the callee is not going to be inlined then we can ignore its cost. 
+ if info.should_inline { + acc.saturating_add(info.weight) } else { - acc.saturating_add(weight) + acc } }); @@ -301,54 +477,50 @@ fn should_retain_recursive( let inline_cost = times_called.saturating_mul(this_function_weight); let retain_cost = times_called.saturating_mul(interface_cost) + this_function_weight; + let net_cost = inline_cost.saturating_sub(retain_cost); let runtime = ssa.functions[&func].runtime(); // We inline if the aggressiveness is higher than inline cost minus the retain cost // If aggressiveness is infinite, we'll always inline // If aggressiveness is 0, we'll inline when the inline cost is lower than the retain cost // If aggressiveness is minus infinity, we'll never inline (other than in the mandatory cases) - let should_inline = ((inline_cost.saturating_sub(retain_cost)) < aggressiveness) + let should_inline = (net_cost < aggressiveness) || runtime.is_inline_always() || (runtime.is_no_predicates() && inline_no_predicates_functions); - should_retain_function.insert(func, (!should_inline, this_function_weight)); + let info = inline_infos.entry(func).or_default(); + info.should_inline = should_inline; + info.weight = this_function_weight; + info.cost = net_cost; } -fn compute_functions_to_retain( +/// Mark Brillig functions that should not be inlined because they are recursive or expensive. 
+fn mark_brillig_functions_to_retain( ssa: &Ssa, - entry_points: &BTreeSet, - times_called: &HashMap, inline_no_predicates_functions: bool, aggressiveness: i64, -) -> BTreeSet { - let mut should_retain_function = HashMap::default(); + times_called: &HashMap, + inline_infos: &mut BTreeMap, +) { + let brillig_entry_points = inline_infos + .iter() + .filter_map(|(id, info)| info.is_brillig_entry_point.then_some(*id)) + .collect::>(); - for entry_point in entry_points.iter() { - should_retain_recursive( + for entry_point in brillig_entry_points { + mark_functions_to_retain_recursive( ssa, - *entry_point, - times_called, - &mut should_retain_function, - im::HashSet::default(), inline_no_predicates_functions, aggressiveness, + times_called, + inline_infos, + im::HashSet::default(), + entry_point, ); } - - should_retain_function - .into_iter() - .filter_map( - |(func_id, (should_retain, _))| { - if should_retain { - Some(func_id) - } else { - None - } - }, - ) - .collect() } +/// Compute a weight of a function based on the number of instructions in its reachable blocks. fn compute_function_own_weight(func: &Function) -> usize { let mut weight = 0; for block_id in func.reachable_blocks() { @@ -359,6 +531,7 @@ fn compute_function_own_weight(func: &Function) -> usize { weight } +/// Compute interface cost of a function based on the number of inputs and outputs. 
fn compute_function_interface_cost(func: &Function) -> usize { func.parameters().len() + func.returns().len() } @@ -433,7 +606,7 @@ impl InlineContext { if self.recursion_level > RECURSION_LIMIT { panic!( - "Attempted to recur more than {RECURSION_LIMIT} times during inlining function '{}': {}", source_function.name(), source_function + "Attempted to recur more than {RECURSION_LIMIT} times during inlining function '{}':\n{}", source_function.name(), source_function ); } @@ -899,6 +1072,7 @@ mod test { map::Id, types::{NumericType, Type}, }, + Ssa, }; #[test] @@ -1171,26 +1345,25 @@ mod test { #[test] #[should_panic( - expected = "Attempted to recur more than 1000 times during inlining function 'main': acir(inline) fn main f0 {" + expected = "Attempted to recur more than 1000 times during inlining function 'foo':\nacir(inline) fn foo f1 {" )] fn unconditional_recursion() { - // fn main f1 { - // b0(): - // call f1() - // return - // } - let main_id = Id::test_new(0); - let mut builder = FunctionBuilder::new("main".into(), main_id); - - let main = builder.import_function(main_id); - let results = builder.insert_call(main, Vec::new(), vec![]).to_vec(); - builder.terminate_with_return(results); - - let ssa = builder.finish(); - assert_eq!(ssa.functions.len(), 1); + let src = " + acir(inline) fn main f0 { + b0(): + call f1() + return + } + acir(inline) fn foo f1 { + b0(): + call f1() + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + assert_eq!(ssa.functions.len(), 2); - let inlined = ssa.inline_functions(i64::MAX); - assert_eq!(inlined.functions.len(), 0); + let _ = ssa.inline_functions(i64::MAX); } #[test] diff --git a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs index 125cf3a12c..224916c95e 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/loop_invariant.rs @@ -36,7 +36,7 @@ impl Ssa { } impl Function { - fn 
loop_invariant_code_motion(&mut self) { + pub(super) fn loop_invariant_code_motion(&mut self) { Loops::find_all(self).hoist_loop_invariants(self); } } diff --git a/compiler/noirc_evaluator/src/ssa/opt/mod.rs b/compiler/noirc_evaluator/src/ssa/opt/mod.rs index f97d36f084..44796e2531 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/mod.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/mod.rs @@ -17,6 +17,7 @@ mod loop_invariant; mod make_constrain_not_equal; mod mem2reg; mod normalize_value_ids; +mod preprocess_fns; mod rc; mod remove_bit_shifts; mod remove_enable_side_effects; diff --git a/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs new file mode 100644 index 0000000000..5671844fcf --- /dev/null +++ b/compiler/noirc_evaluator/src/ssa/opt/preprocess_fns.rs @@ -0,0 +1,51 @@ +//! Pre-process functions before inlining them into others. + +use crate::ssa::Ssa; + +use super::inlining; + +impl Ssa { + /// Run pre-processing steps on functions in isolation. + pub(crate) fn preprocess_functions(mut self, aggressiveness: i64) -> Ssa { + // Bottom-up order, starting with the "leaf" functions, so we inline already optimized code into the ones that call them. + let bottom_up = inlining::compute_bottom_up_order(&self); + + // As a heuristic to avoid optimizing functions near the entry point, find a cutoff weight. + let total_weight = bottom_up.iter().fold(0usize, |acc, (_, w)| acc.saturating_add(*w)); + let mean_weight = total_weight / bottom_up.len(); + let cutoff_weight = mean_weight; + + // Preliminary inlining decisions. + // Functions which are inline targets will be processed in later passes. + // Here we want to treat the functions which will be inlined into them. 
+ let inline_infos = inlining::compute_inline_infos(&self, false, aggressiveness); + + for (id, _) in bottom_up + .into_iter() + .filter(|(id, _)| { + inline_infos.get(id).map(|info| !info.is_inline_target()).unwrap_or(true) + }) + .filter(|(_, weight)| *weight < cutoff_weight) + { + let function = &self.functions[&id]; + let mut function = function.inlined(&self, false, &inline_infos); + // Help unrolling determine bounds. + function.as_slice_optimization(); + // Prepare for unrolling + function.loop_invariant_code_motion(); + // We might not be able to unroll all loops without fully inlining them, so ignore errors. + let _ = function.unroll_loops_iteratively(); + // Reduce the number of redundant stores/loads after unrolling + function.mem2reg(); + // Try to reduce the number of blocks. + function.simplify_function(); + // Remove leftover instructions. + function.dead_instruction_elimination(true, false); + + // Put it back into the SSA, so the next functions can pick it up. + self.functions.insert(id, function); + } + + self + } +} diff --git a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs index cae7735b2c..a6e5c96d63 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/unrolling.rs @@ -57,46 +57,40 @@ impl Ssa { /// fewer SSA instructions, but that can still result in more Brillig opcodes. #[tracing::instrument(level = "trace", skip(self))] pub(crate) fn unroll_loops_iteratively( - mut self: Ssa, + mut self, max_bytecode_increase_percent: Option, ) -> Result { - for (_, function) in self.functions.iter_mut() { - // Take a snapshot of the function to compare byte size increase, - // but only if the setting indicates we have to, otherwise skip it. 
- let orig_func_and_max_incr_pct = max_bytecode_increase_percent - .filter(|_| function.runtime().is_brillig()) - .map(|max_incr_pct| (function.clone(), max_incr_pct)); - - // Try to unroll loops first: - let (mut has_unrolled, mut unroll_errors) = function.try_unroll_loops(); - - // Keep unrolling until no more errors are found - while !unroll_errors.is_empty() { - let prev_unroll_err_count = unroll_errors.len(); - - // Simplify the SSA before retrying - simplify_between_unrolls(function); - - // Unroll again - let (new_unrolled, new_errors) = function.try_unroll_loops(); - unroll_errors = new_errors; - has_unrolled |= new_unrolled; - - // If we didn't manage to unroll any more loops, exit - if unroll_errors.len() >= prev_unroll_err_count { - return Err(unroll_errors.swap_remove(0)); - } - } - - if has_unrolled { - if let Some((orig_function, max_incr_pct)) = orig_func_and_max_incr_pct { - // DIE is run at the end of our SSA optimizations, so we mark all globals as in use here. - let used_globals = &self.globals.dfg.values_iter().map(|(id, _)| id).collect(); - let (_, brillig_globals) = - convert_ssa_globals(false, &self.globals, used_globals); + let mut global_cache = None; + + for function in self.functions.values_mut() { + let is_brillig = function.runtime().is_brillig(); + + // Take a snapshot in case we have to restore it. + let orig_function = + (max_bytecode_increase_percent.is_some() && is_brillig).then(|| function.clone()); + + // We must be able to unroll ACIR loops at this point, so exit on failure to unroll. + let has_unrolled = function.unroll_loops_iteratively()?; + + // Check if the size increase is acceptable + // This is here now instead of in `Function::unroll_loops_iteratively` because we'd need + // more finessing to convince the borrow checker that it's okay to share a read-only reference + // to the globals and a mutable reference to the function at the same time, both part of the `Ssa`. 
+ if has_unrolled && is_brillig { + if let Some(max_incr_pct) = max_bytecode_increase_percent { + if global_cache.is_none() { + // DIE is run at the end of our SSA optimizations, so we mark all globals as in use here. + let used_globals = + &self.globals.dfg.values_iter().map(|(id, _)| id).collect(); + let (_, brillig_globals) = + convert_ssa_globals(false, &self.globals, used_globals); + global_cache = Some(brillig_globals); + } + let brillig_globals = global_cache.as_ref().unwrap(); - let new_size = brillig_bytecode_size(function, &brillig_globals); - let orig_size = brillig_bytecode_size(&orig_function, &brillig_globals); + let orig_function = orig_function.expect("took snapshot to compare"); + let new_size = brillig_bytecode_size(function, brillig_globals); + let orig_size = brillig_bytecode_size(&orig_function, brillig_globals); if !is_new_size_ok(orig_size, new_size, max_incr_pct) { *function = orig_function; } @@ -108,6 +102,38 @@ impl Ssa { } impl Function { + /// Try to unroll loops in the function. + /// + /// Returns an `Err` if it cannot be done, for example because the loop bounds + /// cannot be determined at compile time. This can happen during pre-processing, + /// but it should still leave the function in a partially unrolled, but valid state. + /// + /// If successful, returns a flag indicating whether any loops have been unrolled. 
+ pub(super) fn unroll_loops_iteratively(&mut self) -> Result { + // Try to unroll loops first: + let (mut has_unrolled, mut unroll_errors) = self.try_unroll_loops(); + + // Keep unrolling until no more errors are found + while !unroll_errors.is_empty() { + let prev_unroll_err_count = unroll_errors.len(); + + // Simplify the SSA before retrying + simplify_between_unrolls(self); + + // Unroll again + let (new_unrolled, new_errors) = self.try_unroll_loops(); + unroll_errors = new_errors; + has_unrolled |= new_unrolled; + + // If we didn't manage to unroll any more loops, exit + if unroll_errors.len() >= prev_unroll_err_count { + return Err(unroll_errors.swap_remove(0)); + } + } + + Ok(has_unrolled) + } + // Loop unrolling in brillig can lead to a code explosion currently. // This can also be true for ACIR, but we have no alternative to unrolling in ACIR. // Brillig also generally prefers smaller code rather than faster code, diff --git a/compiler/noirc_frontend/src/tests.rs b/compiler/noirc_frontend/src/tests.rs index 29ba7b0fab..087e34fcc6 100644 --- a/compiler/noirc_frontend/src/tests.rs +++ b/compiler/noirc_frontend/src/tests.rs @@ -3457,6 +3457,11 @@ fn arithmetic_generics_rounding_fail_on_struct() { #[test] fn unconditional_recursion_fail() { + // These examples are self recursive top level functions, which would actually + // not be inlined in the SSA (there is nothing to inline into but self), so it + // wouldn't panic due to infinite recursion, but the errors asserted here + // come from the compilation checks, which does static analysis to catch the + // problem before it even has a chance to cause a panic. let srcs = vec![ r#" fn main() { diff --git a/cspell.json b/cspell.json index ed9f7427c6..a42b90d2e8 100644 --- a/cspell.json +++ b/cspell.json @@ -35,6 +35,7 @@ "bunx", "bytecount", "cachix", + "callees", "callsite", "callsites", "callstack",