Skip to content

Commit

Permalink
perf(cir): special treatment for single-predicate expression
Browse files Browse the repository at this point in the history
  • Loading branch information
Oyami-Srk committed Dec 30, 2024
1 parent 76638f2 commit 2c89705
Showing 1 changed file with 72 additions and 39 deletions.
111 changes: 72 additions & 39 deletions src/cir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,9 @@ use crate::interpreter::Execute;
use crate::semantics::FieldCounter;
use std::collections::HashMap;

#[derive(Debug)]
pub struct CirProgram {
pub(crate) instructions: Vec<CirInstruction>,
}

impl CirProgram {
pub fn new() -> Self {
Self {
instructions: Vec::new(),
}
}
}

impl Default for CirProgram {
fn default() -> Self {
Self::new()
}
pub enum CirProgram {
Instructions(Vec<CirInstruction>),
Predicate(Predicate),
}

#[derive(Debug)]
Expand Down Expand Up @@ -63,10 +49,27 @@ pub trait Translate {
impl Translate for Expression {
type Output = CirProgram;
fn translate(&self) -> Self::Output {
let mut cir = CirProgram::new();
cir_translate_helper(self, &mut cir);
cir.instructions.shrink_to_fit(); // shrink the memory
cir
let mut cir_instructions = Vec::new();
cir_translate_helper(self, &mut cir_instructions);
cir_instructions.shrink_to_fit();

if cir_instructions.len() == 1 {
// Avoid unnecessary cloning
let is_predicate = matches!(&cir_instructions[0], &CirInstruction::Predicate(_));

if is_predicate {
// Unwrap is safe here because we know that there is one and only one instruction
if let CirInstruction::Predicate(p) = cir_instructions.pop().unwrap() {
CirProgram::Predicate(p)
} else {
unreachable!()
}
} else {
CirProgram::Instructions(cir_instructions)
}
} else {
CirProgram::Instructions(cir_instructions)
}
}
}

Expand All @@ -76,59 +79,59 @@ impl Translate for Expression {
/// * reference to translated CIR
/// This function returns:
/// * index of translated IR
fn cir_translate_helper(exp: &Expression, cir: &mut CirProgram) -> usize {
fn cir_translate_helper(exp: &Expression, cir_instructions: &mut Vec<CirInstruction>) -> usize {
match exp {
Expression::Logical(logic_exp) => match logic_exp.as_ref() {
LogicalExpression::And(l, r) => {
let left = match l {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(l, cir))
CirOperand::Index(cir_translate_helper(l, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};

let right = match r {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(r, cir))
CirOperand::Index(cir_translate_helper(r, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};
let and_ins = AndIns { left, right };
cir.instructions.push(CirInstruction::AndIns(and_ins));
cir_instructions.push(CirInstruction::AndIns(and_ins));
}
LogicalExpression::Or(l, r) => {
let left = match l {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(l, cir))
CirOperand::Index(cir_translate_helper(l, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};

let right = match r {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(r, cir))
CirOperand::Index(cir_translate_helper(r, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};
let or_ins = OrIns { left, right };
cir.instructions.push(CirInstruction::OrIns(or_ins));
cir_instructions.push(CirInstruction::OrIns(or_ins));
}
LogicalExpression::Not(r) => {
let right: CirOperand = match r {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(r, cir))
CirOperand::Index(cir_translate_helper(r, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};
let not_ins = NotIns { right };
cir.instructions.push(CirInstruction::NotIns(not_ins));
cir_instructions.push(CirInstruction::NotIns(not_ins));
}
},
Expression::Predicate(p) => {
cir.instructions.push(CirInstruction::Predicate(p.clone()));
cir_instructions.push(CirInstruction::Predicate(p.clone()));
}
}
cir.instructions.len() - 1
cir_instructions.len() - 1
}

fn execute_helper(
Expand Down Expand Up @@ -173,7 +176,12 @@ fn execute_helper(

impl Execute for CirProgram {
fn execute(&self, ctx: &mut Context, m: &mut Match) -> bool {
execute_helper(&self.instructions, self.instructions.len() - 1, ctx, m)
match self {
CirProgram::Instructions(instructions) => {
execute_helper(instructions, instructions.len() - 1, ctx, m)
}
CirProgram::Predicate(p) => p.execute(ctx, m),
}
}
}

Expand Down Expand Up @@ -241,15 +249,40 @@ impl FieldCounter for CirInstruction {

impl FieldCounter for CirProgram {
fn add_to_counter(&self, map: &mut HashMap<String, usize>) {
self.instructions
.iter()
.for_each(|instruction: &CirInstruction| instruction.add_to_counter(map));
match self {
CirProgram::Instructions(instructions) => {
instructions
.iter()
.for_each(|instruction: &CirInstruction| instruction.add_to_counter(map));
}
CirProgram::Predicate(p) => p.add_to_counter(map),
}
}

fn remove_from_counter(&self, map: &mut HashMap<String, usize>) {
match self {
CirProgram::Instructions(instructions) => {
instructions
.iter()
.for_each(|instruction: &CirInstruction| instruction.remove_from_counter(map));
}
CirProgram::Predicate(p) => p.remove_from_counter(map),
}
}
}

impl FieldCounter for Predicate {
fn add_to_counter(&self, map: &mut HashMap<String, usize>) {
*map.entry(self.lhs.var_name.clone()).or_default() += 1;
}

fn remove_from_counter(&self, map: &mut HashMap<String, usize>) {
self.instructions
.iter()
.for_each(|instruction: &CirInstruction| instruction.remove_from_counter(map));
let val = map.get_mut(&self.lhs.var_name).unwrap();
*val -= 1;

if *val == 0 {
assert!(map.remove(&self.lhs.var_name).is_some());
}
}
}

Expand Down

0 comments on commit 2c89705

Please sign in to comment.