diff --git a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs index 019bace33a..fac6e52104 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/constant_folding.rs @@ -169,12 +169,17 @@ struct Context<'a> { /// Contains sets of values which are constrained to be equivalent to each other. /// - /// The mapping's structure is `side_effects_enabled_var => (constrained_value => simplified_value)`. + /// The mapping's structure is `side_effects_enabled_var => (constrained_value => [(block, simplified_value)])`. /// /// We partition the maps of constrained values according to the side-effects flag at the point /// at which the values are constrained. This prevents constraints which are only sometimes enforced /// being used to modify the rest of the program. - constraint_simplification_mappings: HashMap>, + /// + /// We also keep track of how a value was simplified to other values per block. That is, + /// a same ValueId could have been simplified to one value in one block and to another value + /// in another block. + constraint_simplification_mappings: + HashMap>>, // Cache of instructions without any side-effects along with their outputs. cached_instruction_results: InstructionResultCache, @@ -244,8 +249,15 @@ impl<'brillig> Context<'brillig> { id: InstructionId, side_effects_enabled_var: &mut ValueId, ) { - let constraint_simplification_mapping = self.get_constraint_map(*side_effects_enabled_var); - let instruction = Self::resolve_instruction(id, dfg, constraint_simplification_mapping); + let constraint_simplification_mapping = + self.constraint_simplification_mappings.get(side_effects_enabled_var); + let instruction = Self::resolve_instruction( + id, + block, + dfg, + &mut self.dom, + constraint_simplification_mapping, + ); let old_results = dfg.instruction_results(id).to_vec(); // If a copy of this instruction exists earlier in the block, then reuse the previous results. @@ -258,11 +270,13 @@ impl<'brillig> Context<'brillig> { return; } CacheResult::NeedToHoistToCommonBlock(dominator, _cached) => { - // Just change the block to insert in the common dominator instead. - // This will only move the current instance of the instruction right now. - // When constant folding is run a second time later on, it'll catch - // that the previous instance can be deduplicated to this instance. - block = dominator; + if instruction_can_be_hoisted(&instruction, dfg, self.use_constraint_info) { + // Just change the block to insert in the common dominator instead. + // This will only move the current instance of the instruction right now. + // When constant folding is run a second time later on, it'll catch + // that the previous instance can be deduplicated to this instance. + block = dominator; + } } } } @@ -307,8 +321,10 @@ impl<'brillig> Context<'brillig> { /// Fetches an [`Instruction`] by its [`InstructionId`] and fully resolves its inputs. fn resolve_instruction( instruction_id: InstructionId, + block: BasicBlockId, dfg: &DataFlowGraph, - constraint_simplification_mapping: &HashMap, + dom: &mut DominatorTree, + constraint_simplification_mapping: Option<&HashMap>>, ) -> Instruction { let instruction = dfg[instruction_id].clone(); @@ -319,19 +335,30 @@ impl<'brillig> Context<'brillig> { // constraints to the cache. fn resolve_cache( dfg: &DataFlowGraph, - cache: &HashMap, + dom: &mut DominatorTree, + cache: Option<&HashMap>>, value_id: ValueId, + block: BasicBlockId, ) -> ValueId { let resolved_id = dfg.resolve(value_id); - match cache.get(&resolved_id) { - Some(cached_value) => resolve_cache(dfg, cache, *cached_value), - None => resolved_id, + let Some(cached_values) = cache.and_then(|cache| cache.get(&resolved_id)) else { + return resolved_id; + }; + + for (cached_block, cached_value) in cached_values { + // We can only use the simplified value if it was simplified in a block that dominates the current one + if dom.dominates(*cached_block, block) { + return resolve_cache(dfg, dom, cache, *cached_value, block); + } } + + resolved_id } // Resolve any inputs to ensure that we're comparing like-for-like instructions. - instruction - .map_values(|value_id| resolve_cache(dfg, constraint_simplification_mapping, value_id)) + instruction.map_values(|value_id| { + resolve_cache(dfg, dom, constraint_simplification_mapping, value_id, block) + }) } /// Pushes a new [`Instruction`] into the [`DataFlowGraph`] which applies any optimizations @@ -377,26 +404,11 @@ impl<'brillig> Context<'brillig> { // to map from the more complex to the simpler value. if let Instruction::Constrain(lhs, rhs, _) = instruction { // These `ValueId`s should be fully resolved now. - match (&dfg[lhs], &dfg[rhs]) { - // Ignore trivial constraints - (Value::NumericConstant { .. }, Value::NumericConstant { .. }) => (), - - // Prefer replacing with constants where possible. - (Value::NumericConstant { .. }, _) => { - self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs); - } - (_, Value::NumericConstant { .. }) => { - self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs); - } - // Otherwise prefer block parameters over instruction results. - // This is as block parameters are more likely to be a single witness rather than a full expression. - (Value::Param { .. }, Value::Instruction { .. }) => { - self.get_constraint_map(side_effects_enabled_var).insert(rhs, lhs); - } - (Value::Instruction { .. }, Value::Param { .. }) => { - self.get_constraint_map(side_effects_enabled_var).insert(lhs, rhs); - } - (_, _) => (), + if let Some((complex, simple)) = simplify(dfg, lhs, rhs) { + self.get_constraint_map(side_effects_enabled_var) + .entry(complex) + .or_default() + .push((block, simple)); } } } @@ -420,7 +432,7 @@ impl<'brillig> Context<'brillig> { fn get_constraint_map( &mut self, side_effects_enabled_var: ValueId, - ) -> &mut HashMap { + ) -> &mut HashMap> { self.constraint_simplification_mappings.entry(side_effects_enabled_var).or_default() } @@ -611,6 +623,20 @@ impl<'brillig> Context<'brillig> { } } +fn instruction_can_be_hoisted( + instruction: &Instruction, + dfg: &mut DataFlowGraph, + deduplicate_with_predicate: bool, +) -> bool { + // These two can never be hoisted as they have a side-effect + // (though it's fine to de-duplicate them, just not fine to hoist them) + if matches!(instruction, Instruction::Constrain(..) | Instruction::RangeCheck { .. }) { + return false; + } + + instruction.can_be_deduplicated(dfg, deduplicate_with_predicate) +} + impl ResultCache { /// Records that an `Instruction` in block `block` produced the result values `results`. fn cache(&mut self, block: BasicBlockId, results: Vec) { @@ -687,6 +713,25 @@ fn value_id_to_calldata(value_id: ValueId, dfg: &DataFlowGraph, calldata: &mut V panic!("Expected ValueId to be numeric constant or array constant"); } +/// Check if one expression is simpler than the other. +/// Returns `Some((complex, simple))` if a simplification was found, otherwise `None`. +/// Expects the `ValueId`s to be fully resolved. +fn simplify(dfg: &DataFlowGraph, lhs: ValueId, rhs: ValueId) -> Option<(ValueId, ValueId)> { + match (&dfg[lhs], &dfg[rhs]) { + // Ignore trivial constraints + (Value::NumericConstant { .. }, Value::NumericConstant { .. }) => None, + + // Prefer replacing with constants where possible. + (Value::NumericConstant { .. }, _) => Some((rhs, lhs)), + (_, Value::NumericConstant { .. }) => Some((lhs, rhs)), + // Otherwise prefer block parameters over instruction results. + // This is as block parameters are more likely to be a single witness rather than a full expression. + (Value::Param { .. }, Value::Instruction { .. }) => Some((rhs, lhs)), + (Value::Instruction { .. }, Value::Param { .. }) => Some((lhs, rhs)), + (_, _) => None, + } +} + #[cfg(test)] mod test { use std::sync::Arc; @@ -1341,4 +1386,50 @@ mod test { let ssa = ssa.fold_constants_with_brillig(&brillig); assert_normalized_ssa_equals(ssa, expected); } + + #[test] + fn does_not_use_cached_constrain_in_block_that_is_not_dominated() { + let src = " + brillig(inline) fn main f0 { + b0(v0: Field, v1: Field): + v3 = eq v0, Field 0 + jmpif v3 then: b1, else: b2 + b1(): + v5 = eq v1, Field 1 + constrain v1 == Field 1 + jmp b2() + b2(): + v6 = eq v1, Field 0 + constrain v1 == Field 0 // This `v1` here was incorrectly replaced with `Field 1` + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, src); + } + + #[test] + fn does_not_hoist_constrain_to_common_ancestor() { + let src = " + brillig(inline) fn main f0 { + b0(v0: Field, v1: Field): + v3 = eq v0, Field 0 + jmpif v3 then: b1, else: b2 + b1(): + constrain v1 == Field 1 + jmp b2() + b2(): + jmpif v0 then: b3, else: b4 + b3(): + constrain v1 == Field 1 // This was incorrectly hoisted to b0 but this condition is not valid when going b0 -> b2 -> b4 + jmp b4() + b4(): + return + } + "; + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.fold_constants_using_constraints(); + assert_normalized_ssa_equals(ssa, src); + } }