Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf(cir): special treatment for single-predicate expression #286

Open
wants to merge 1 commit into
base: reapply-#246
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 72 additions & 38 deletions src/cir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,9 @@ use crate::semantics::FieldCounter;
use std::collections::HashMap;

#[derive(Debug)]
pub struct CirProgram {
pub(crate) instructions: Vec<CirInstruction>,
}

impl CirProgram {
pub fn new() -> Self {
Self {
instructions: Vec::new(),
}
}
}

impl Default for CirProgram {
fn default() -> Self {
Self::new()
}
pub enum CirProgram {
Instructions(Vec<CirInstruction>),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use boxed slice instead Vec to emphasize that instructions won't grow over time?

Predicate(Predicate),
}

#[derive(Debug)]
Expand Down Expand Up @@ -63,10 +50,27 @@ pub trait Translate {
impl Translate for Expression {
type Output = CirProgram;
fn translate(&self) -> Self::Output {
let mut cir = CirProgram::new();
cir_translate_helper(self, &mut cir);
cir.instructions.shrink_to_fit(); // shrink the memory
cir
let mut cir_instructions = Vec::new();
cir_translate_helper(self, &mut cir_instructions);
cir_instructions.shrink_to_fit();

if cir_instructions.len() == 1 {
// Avoid unnecessary cloning
let is_predicate = matches!(&cir_instructions[0], &CirInstruction::Predicate(_));

if is_predicate {
// Unwrap is safe here because we know that there is one and only one instruction
if let CirInstruction::Predicate(p) = cir_instructions.pop().unwrap() {
CirProgram::Predicate(p)
} else {
unreachable!()
}
} else {
CirProgram::Instructions(cir_instructions)
}
} else {
CirProgram::Instructions(cir_instructions)
}
}
}

Expand All @@ -76,59 +80,59 @@ impl Translate for Expression {
/// * reference to translated CIR
/// This function returns:
/// * index of translated IR
fn cir_translate_helper(exp: &Expression, cir: &mut CirProgram) -> usize {
fn cir_translate_helper(exp: &Expression, cir_instructions: &mut Vec<CirInstruction>) -> usize {
match exp {
Expression::Logical(logic_exp) => match logic_exp.as_ref() {
LogicalExpression::And(l, r) => {
let left = match l {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(l, cir))
CirOperand::Index(cir_translate_helper(l, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};

let right = match r {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(r, cir))
CirOperand::Index(cir_translate_helper(r, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};
let and_ins = AndIns { left, right };
cir.instructions.push(CirInstruction::AndIns(and_ins));
cir_instructions.push(CirInstruction::AndIns(and_ins));
}
LogicalExpression::Or(l, r) => {
let left = match l {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(l, cir))
CirOperand::Index(cir_translate_helper(l, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};

let right = match r {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(r, cir))
CirOperand::Index(cir_translate_helper(r, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};
let or_ins = OrIns { left, right };
cir.instructions.push(CirInstruction::OrIns(or_ins));
cir_instructions.push(CirInstruction::OrIns(or_ins));
}
LogicalExpression::Not(r) => {
let right: CirOperand = match r {
Expression::Logical(_logic_exp) => {
CirOperand::Index(cir_translate_helper(r, cir))
CirOperand::Index(cir_translate_helper(r, cir_instructions))
}
Expression::Predicate(p) => CirOperand::Predicate(p.clone()),
};
let not_ins = NotIns { right };
cir.instructions.push(CirInstruction::NotIns(not_ins));
cir_instructions.push(CirInstruction::NotIns(not_ins));
}
},
Expression::Predicate(p) => {
cir.instructions.push(CirInstruction::Predicate(p.clone()));
cir_instructions.push(CirInstruction::Predicate(p.clone()));
}
}
cir.instructions.len() - 1
cir_instructions.len() - 1
}

fn execute_helper(
Expand Down Expand Up @@ -173,7 +177,12 @@ fn execute_helper(

impl Execute for CirProgram {
fn execute(&self, ctx: &mut Context, m: &mut Match) -> bool {
execute_helper(&self.instructions, self.instructions.len() - 1, ctx, m)
match self {
CirProgram::Instructions(instructions) => {
execute_helper(instructions, instructions.len() - 1, ctx, m)
}
CirProgram::Predicate(p) => p.execute(ctx, m),
}
}
}

Expand Down Expand Up @@ -241,15 +250,40 @@ impl FieldCounter for CirInstruction {

impl FieldCounter for CirProgram {
fn add_to_counter(&self, map: &mut HashMap<String, usize>) {
self.instructions
.iter()
.for_each(|instruction: &CirInstruction| instruction.add_to_counter(map));
match self {
CirProgram::Instructions(instructions) => {
instructions
.iter()
.for_each(|instruction: &CirInstruction| instruction.add_to_counter(map));
}
CirProgram::Predicate(p) => p.add_to_counter(map),
}
}

fn remove_from_counter(&self, map: &mut HashMap<String, usize>) {
self.instructions
.iter()
.for_each(|instruction: &CirInstruction| instruction.remove_from_counter(map));
match self {
CirProgram::Instructions(instructions) => {
instructions
.iter()
.for_each(|instruction: &CirInstruction| instruction.remove_from_counter(map));
}
CirProgram::Predicate(p) => p.remove_from_counter(map),
}
}
}

impl FieldCounter for Predicate {
fn add_to_counter(&self, map: &mut HashMap<String, usize>) {
*map.entry(self.lhs.var_name.clone()).or_default() += 1;
}

fn remove_from_counter(&self, map: &mut HashMap<String, usize>) {
let val = map.get_mut(&self.lhs.var_name).unwrap();
*val -= 1;

if *val == 0 {
assert!(map.remove(&self.lhs.var_name).is_some());
}
}
}

Expand Down
Loading