diff --git a/src/cir.rs b/src/cir.rs index da1fecbb..c38a387d 100644 --- a/src/cir.rs +++ b/src/cir.rs @@ -5,23 +5,9 @@ use crate::interpreter::Execute; use crate::semantics::FieldCounter; use std::collections::HashMap; -#[derive(Debug)] -pub struct CirProgram { - pub(crate) instructions: Vec, -} - -impl CirProgram { - pub fn new() -> Self { - Self { - instructions: Vec::new(), - } - } -} - -impl Default for CirProgram { - fn default() -> Self { - Self::new() - } +pub enum CirProgram { + Instructions(Vec), + Predicate(Predicate), } #[derive(Debug)] @@ -63,10 +49,27 @@ pub trait Translate { impl Translate for Expression { type Output = CirProgram; fn translate(&self) -> Self::Output { - let mut cir = CirProgram::new(); - cir_translate_helper(self, &mut cir); - cir.instructions.shrink_to_fit(); // shrink the memory - cir + let mut cir_instructions = Vec::new(); + cir_translate_helper(self, &mut cir_instructions); + cir_instructions.shrink_to_fit(); + + if cir_instructions.len() == 1 { + // Avoid unnecessary cloning + let is_predicate = matches!(&cir_instructions[0], &CirInstruction::Predicate(_)); + + if is_predicate { + // Unwrap is safe here because we know that there is one and only one instruction + if let CirInstruction::Predicate(p) = cir_instructions.pop().unwrap() { + CirProgram::Predicate(p) + } else { + unreachable!() + } + } else { + CirProgram::Instructions(cir_instructions) + } + } else { + CirProgram::Instructions(cir_instructions) + } } } @@ -76,59 +79,59 @@ impl Translate for Expression { /// * reference to translated CIR /// This function returns: /// * index of translated IR -fn cir_translate_helper(exp: &Expression, cir: &mut CirProgram) -> usize { +fn cir_translate_helper(exp: &Expression, cir_instructions: &mut Vec) -> usize { match exp { Expression::Logical(logic_exp) => match logic_exp.as_ref() { LogicalExpression::And(l, r) => { let left = match l { Expression::Logical(_logic_exp) => { - CirOperand::Index(cir_translate_helper(l, cir)) + CirOperand::Index(cir_translate_helper(l, cir_instructions)) } Expression::Predicate(p) => CirOperand::Predicate(p.clone()), }; let right = match r { Expression::Logical(_logic_exp) => { - CirOperand::Index(cir_translate_helper(r, cir)) + CirOperand::Index(cir_translate_helper(r, cir_instructions)) } Expression::Predicate(p) => CirOperand::Predicate(p.clone()), }; let and_ins = AndIns { left, right }; - cir.instructions.push(CirInstruction::AndIns(and_ins)); + cir_instructions.push(CirInstruction::AndIns(and_ins)); } LogicalExpression::Or(l, r) => { let left = match l { Expression::Logical(_logic_exp) => { - CirOperand::Index(cir_translate_helper(l, cir)) + CirOperand::Index(cir_translate_helper(l, cir_instructions)) } Expression::Predicate(p) => CirOperand::Predicate(p.clone()), }; let right = match r { Expression::Logical(_logic_exp) => { - CirOperand::Index(cir_translate_helper(r, cir)) + CirOperand::Index(cir_translate_helper(r, cir_instructions)) } Expression::Predicate(p) => CirOperand::Predicate(p.clone()), }; let or_ins = OrIns { left, right }; - cir.instructions.push(CirInstruction::OrIns(or_ins)); + cir_instructions.push(CirInstruction::OrIns(or_ins)); } LogicalExpression::Not(r) => { let right: CirOperand = match r { Expression::Logical(_logic_exp) => { - CirOperand::Index(cir_translate_helper(r, cir)) + CirOperand::Index(cir_translate_helper(r, cir_instructions)) } Expression::Predicate(p) => CirOperand::Predicate(p.clone()), }; let not_ins = NotIns { right }; - cir.instructions.push(CirInstruction::NotIns(not_ins)); + cir_instructions.push(CirInstruction::NotIns(not_ins)); } }, Expression::Predicate(p) => { - cir.instructions.push(CirInstruction::Predicate(p.clone())); + cir_instructions.push(CirInstruction::Predicate(p.clone())); } } - cir.instructions.len() - 1 + cir_instructions.len() - 1 } fn execute_helper( @@ -173,7 +176,12 @@ fn execute_helper( impl Execute for CirProgram { fn execute(&self, ctx: &mut Context, m: &mut Match) -> bool { - execute_helper(&self.instructions, self.instructions.len() - 1, ctx, m) + match self { + CirProgram::Instructions(instructions) => { + execute_helper(instructions, instructions.len() - 1, ctx, m) + } + CirProgram::Predicate(p) => p.execute(ctx, m), + } } } @@ -241,15 +249,40 @@ impl FieldCounter for CirInstruction { impl FieldCounter for CirProgram { fn add_to_counter(&self, map: &mut HashMap) { - self.instructions + match self { + CirProgram::Instructions(instructions) => { + instructions .iter() .for_each(|instruction: &CirInstruction| instruction.add_to_counter(map)); + } + CirProgram::Predicate(p) => p.add_to_counter(map), + } } fn remove_from_counter(&self, map: &mut HashMap) { - self.instructions + match self { + CirProgram::Instructions(instructions) => { + instructions .iter() .for_each(|instruction: &CirInstruction| instruction.remove_from_counter(map)); + } + CirProgram::Predicate(p) => p.remove_from_counter(map), + } + } +} + +impl FieldCounter for Predicate { + fn add_to_counter(&self, map: &mut HashMap) { + *map.entry(self.lhs.var_name.clone()).or_default() += 1; + } + + fn remove_from_counter(&self, map: &mut HashMap) { + let val = map.get_mut(&self.lhs.var_name).unwrap(); + *val -= 1; + + if *val == 0 { + assert!(map.remove(&self.lhs.var_name).is_some()); + } } }