Skip to content

Commit

Permalink
Merge pull request #238 from HigherOrderCO/bug/sc-514/unnecessary-lam…
Browse files Browse the repository at this point in the history
…bdas-sometimes-added-in-match

[sc-514] Unnecessary lambdas sometimes added in match var lifting
  • Loading branch information
imaqtkatt authored Mar 15, 2024
2 parents 140e0b9 + 87c5cd0 commit 81d6486
Show file tree
Hide file tree
Showing 24 changed files with 106 additions and 92 deletions.
85 changes: 52 additions & 33 deletions src/term/transform/simplify_matches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ pub enum SimplifyMatchErr {
impl Ctx<'_> {
pub fn simplify_matches(&mut self) -> Result<(), Diagnostics> {
self.info.start_pass();
let name_gen = &mut 0;

for (def_name, def) in self.book.defs.iter_mut() {
let res = def.simplify_matches(&self.book.ctrs, &self.book.adts);
let res = def.simplify_matches(&self.book.ctrs, &self.book.adts, name_gen);
self.info.take_rule_err(res, def_name.clone());
}

Expand All @@ -29,9 +30,14 @@ impl Ctx<'_> {
}

impl Definition {
pub fn simplify_matches(&mut self, ctrs: &Constructors, adts: &Adts) -> Result<(), SimplifyMatchErr> {
pub fn simplify_matches(
&mut self,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<(), SimplifyMatchErr> {
for rule in self.rules.iter_mut() {
rule.body.simplify_matches(ctrs, adts)?;
rule.body.simplify_matches(ctrs, adts, name_gen)?;
}
Ok(())
}
Expand All @@ -42,21 +48,29 @@ impl Term {
/// arbitrary patterns into matches on a single value, with only
/// simple (non-nested) patterns, and one rule for each constructor.
///
/// The `name_gen` is used to generate fresh variable names for
/// substitution to avoid name clashes.
///
/// See `[simplify_match_expression]` for more information.
pub fn simplify_matches(&mut self, ctrs: &Constructors, adts: &Adts) -> Result<(), SimplifyMatchErr> {
pub fn simplify_matches(
&mut self,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<(), SimplifyMatchErr> {
Term::recursive_call(move || {
match self {
Term::Mat { args, rules } => {
let extracted = extract_args(args);
let args = std::mem::take(args);
let rules = std::mem::take(rules);
let term = simplify_match_expression(args, rules, ctrs, adts)?;
let term = simplify_match_expression(args, rules, ctrs, adts, name_gen)?;
*self = bind_extracted_args(extracted, term);
}

_ => {
for child in self.children_mut() {
child.simplify_matches(ctrs, adts)?;
child.simplify_matches(ctrs, adts, name_gen)?;
}
}
}
Expand Down Expand Up @@ -92,16 +106,17 @@ fn simplify_match_expression(
rules: Vec<Rule>,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<Term, SimplifyMatchErr> {
let fst_row_irrefutable = rules[0].pats.iter().all(|p| p.is_wildcard());
let fst_col_type = infer_match_arg_type(&rules, 0, ctrs)?;

if fst_row_irrefutable {
irrefutable_fst_row_rule(args, rules, ctrs, adts)
irrefutable_fst_row_rule(args, rules, ctrs, adts, name_gen)
} else if fst_col_type == Type::Any {
var_rule(args, rules, ctrs, adts)
var_rule(args, rules, ctrs, adts, name_gen)
} else {
switch_rule(args, rules, fst_col_type, ctrs, adts)
switch_rule(args, rules, fst_col_type, ctrs, adts, name_gen)
}
}

Expand All @@ -113,17 +128,19 @@ fn irrefutable_fst_row_rule(
mut rules: Vec<Rule>,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<Term, SimplifyMatchErr> {
rules.truncate(1);

let Rule { pats, body: mut term } = rules.pop().unwrap();
term.simplify_matches(ctrs, adts)?;
term.simplify_matches(ctrs, adts, name_gen)?;

for (pat, arg) in pats.iter().zip(args.iter()) {
for bind in pat.binds().flatten() {
term.subst(bind, arg);
}
}

let term = pats.into_iter().zip(args).fold(term, |term, (pat, arg)| Term::Let {
pat,
val: Box::new(arg),
nxt: Box::new(term),
});
Ok(term)
}

Expand All @@ -137,21 +154,25 @@ fn var_rule(
rules: Vec<Rule>,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<Term, SimplifyMatchErr> {
let mut new_rules = vec![];
for mut rule in rules {
let rest = rule.pats.split_off(1);

let body =
Term::Let { pat: rule.pats.pop().unwrap(), val: Box::new(args[0].clone()), nxt: Box::new(rule.body) };
let pat = rule.pats.pop().unwrap();
let mut body = rule.body;
if let Pattern::Var(Some(nam)) = &pat {
body.subst(nam, &args[0]);
}

let new_rule = Rule { pats: rest, body };
new_rules.push(new_rule);
}

let rest = args.split_off(1);
let mut term = Term::Mat { args: rest, rules: new_rules };
term.simplify_matches(ctrs, adts)?;
term.simplify_matches(ctrs, adts, name_gen)?;
Ok(term)
}

Expand Down Expand Up @@ -206,6 +227,7 @@ fn switch_rule(
typ: Type,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<Term, SimplifyMatchErr> {
let mut new_rules = vec![];

Expand Down Expand Up @@ -253,10 +275,10 @@ fn switch_rule(
for ctr in adt_ctrs {
// Create the matched constructor and the name of the bound variables.
let Term::Var { nam: arg_nam } = &args[0] else { unreachable!() };
let nested_fields = switch_rule_nested_fields(arg_nam, &ctr);
let nested_fields = switch_rule_nested_fields(arg_nam, &ctr, name_gen);
let matched_ctr = switch_rule_matched_ctr(ctr.clone(), &nested_fields);
let mut body = switch_rule_submatch(&args, &rules, &matched_ctr, &nested_fields)?;
body.simplify_matches(ctrs, adts)?;
body.simplify_matches(ctrs, adts, name_gen)?;
let pats = vec![matched_ctr];
new_rules.push(Rule { pats, body });
}
Expand All @@ -265,13 +287,14 @@ fn switch_rule(
Ok(term)
}

fn switch_rule_nested_fields(arg_nam: &Name, ctr: &Pattern) -> Vec<Option<Name>> {
fn switch_rule_nested_fields(arg_nam: &Name, ctr: &Pattern, name_gen: &mut usize) -> Vec<Option<Name>> {
let mut nested_fields = vec![];
let old_vars = ctr.binds();
for old_var in old_vars {
*name_gen += 1;
let new_nam = if let Some(field) = old_var {
// Name of constructor field
Name::new(format!("{arg_nam}%{field}"))
Name::new(format!("{arg_nam}%{field}%{name_gen}"))
} else {
// Name of var pattern
arg_nam.clone()
Expand Down Expand Up @@ -339,21 +362,17 @@ fn switch_rule_submatch_arm(rule: &Rule, ctr: &Pattern, nested_fields: &[Option<
let body = rule.body.clone();
Some(Rule { pats, body })
} else if rule.pats[0].is_wildcard() {
// Var, reconstruct the value matched in the expression above.
// match x ... {var ...: Body; ...}
// becomes
// match x {
// (Ctr x%field0 ...): match x1 ... {
// x%field0 ...: let var = (Ctr x%field0 ...); Body;
// ... };
// ... }
// Use `subst` to replace the pattern variable in the body
// of the rule with the term that represents the matched constructor.
let mut body = rule.body.clone();
if let Pattern::Var(Some(nam)) = &rule.pats[0] {
body.subst(nam, &ctr.to_term());
}

let nested_var_pats = nested_fields.iter().cloned().map(Pattern::Var);
let old_pats = rule.pats[1 ..].iter().cloned();
let pats = nested_var_pats.chain(old_pats).collect_vec();

let body =
Term::Let { pat: rule.pats[0].clone(), val: Box::new(ctr.to_term()), nxt: Box::new(rule.body.clone()) };

Some(Rule { pats, body })
} else {
// Non-matching constructor. Don't include in submatch expression.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ input_file: tests/golden_tests/compile_file_o_all/list_merge_sort.hvm
@If$C1 = (* (a a))
@Map = ({4 @Map$C0 {4 @Unpack$C3_$_MergePair$C4_$_Map$C1 a}} a)
@Map$C0 = {4 a {4 {4 @Map$C0 {4 @Unpack$C3_$_MergePair$C4_$_Map$C1 (b c)}} ({5 (a d) b} {4 {4 d {4 c e}} {4 * e}})}}
@Merge$C0 = {4 {9 a {9 b c}} {4 {7 d {4 @Merge$C0 {4 @Merge$C1 (e (f (g h)))}}} ({15 (i (a {2 @If$C0 {2 @If$C1 ({4 {4 j {4 k l}} {4 * l}} ({4 {4 c {4 h m}} {4 * m}} n))}})) {15 o e}} ({13 i {13 j f}} ({11 {4 @Merge$C2 {4 @Merge$C3 (o ({4 {4 b {4 d p}} {4 * p}} k))}} g} n)))}}
@Merge$C0 = {4 {7 a {7 b c}} {4 {9 d {4 @Merge$C0 {4 @Merge$C1 (e (f (g h)))}}} ({11 (i (a {2 @If$C0 {2 @If$C1 ({4 {4 j {4 k l}} {4 * l}} ({4 {4 c {4 h m}} {4 * m}} n))}})) {11 o e}} ({13 i {13 j f}} ({15 {4 @Merge$C2 {4 @Merge$C3 (o ({4 {4 b {4 d p}} {4 * p}} k))}} g} n)))}}
@Merge$C1 = (* @Cons)
@Merge$C2 = {4 a {4 b (c ({4 @Merge$C0 {4 @Merge$C1 (c (a (b d)))}} d))}}
@Merge$C3 = (* (a a))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/compile_file_o_all/num_pattern_with_var.hvm
---
@Foo = ({2 (* #0) {2 @Foo$C3 a}} a)
@Foo$C2 = (a (* a))
@Foo$C3 = (?<((* #0) @Foo$C2) (* a)> a)
@Foo = ({2 (* #0) {2 @Foo$C1 a}} a)
@Foo$C1 = (?<(#0 (a a)) b> b)
@main = a
& @Foo ~ (@true (#3 a))
@true = {2 * {2 a a}}
2 changes: 1 addition & 1 deletion tests/snapshots/desugar_file__used_once_names.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/used_once_names.hvm
---
(foo) = λa λb λc let {d d_2} = c; (a b (d d_2))
(foo) = λa λb λc let {c c_2} = c; (a b (c c_2))

(main) = (foo 2 3 λa a)
4 changes: 2 additions & 2 deletions tests/snapshots/encode_pattern_match__adt_tup_era.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/adt_tup_era.hvm
---
TaggedScott:
(Foo) = λa #Tuple (a #Tuple λb #Tuple λc (#Tuple (b #Tuple λd #Tuple λ* λ* d) c))
(Foo) = λa #Tuple (a #Tuple λb #Tuple λ* #Tuple (b #Tuple λd #Tuple λ* d))

(Main) = (Foo (Pair 1 5))

(Pair) = λa λb #Tuple λc #Tuple (c a b)

Scott:
(Foo) = λa (a λb λc (b λd λ* λ* d c))
(Foo) = λa (a λb λ* (b λd λ* d))

(Main) = (Foo (Pair 1 5))

Expand Down
4 changes: 2 additions & 2 deletions tests/snapshots/encode_pattern_match__concat.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/concat.hvm
---
TaggedScott:
(String.concat) = λa λb (#String (a #String λc #String λd λe (String.cons c (String.concat d e)) λh h) b)
(String.concat) = λa λb (#String (a #String λc #String λd λe (String.cons c (String.concat d e)) λf f) b)

(main) = (String.concat (String.cons 97 (String.cons 98 String.nil)) (String.cons 99 (String.cons 100 String.nil)))

Expand All @@ -12,7 +12,7 @@ TaggedScott:
(String.nil) = #String λ* #String λb b

Scott:
(String.concat) = λa λb (a λc λd λe (String.cons c (String.concat d e)) λh h b)
(String.concat) = λa λb (a λc λd λe (String.cons c (String.concat d e)) λf f b)

(main) = (String.concat (String.cons 97 (String.cons 98 String.nil)) (String.cons 99 (String.cons 100 String.nil)))

Expand Down
4 changes: 2 additions & 2 deletions tests/snapshots/encode_pattern_match__concat_def.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/concat_def.hvm
---
TaggedScott:
(concat) = λa λb (#String (a #String λc #String λd λe (String.cons c (concat d e)) λi i) b)
(concat) = λa λb (#String (a #String λc #String λd λe (String.cons c (concat d e)) λf f) b)

(main) = (concat (String.cons 97 (String.cons 98 String.nil)) (String.cons 99 (String.cons 100 String.nil)))

Expand All @@ -12,7 +12,7 @@ TaggedScott:
(String.nil) = #String λ* #String λb b

Scott:
(concat) = λa λb (a λc λd λe (String.cons c (concat d e)) λi i b)
(concat) = λa λb (a λc λd λe (String.cons c (concat d e)) λf f b)

(main) = (concat (String.cons 97 (String.cons 98 String.nil)) (String.cons 99 (String.cons 100 String.nil)))

Expand Down
12 changes: 6 additions & 6 deletions tests/snapshots/encode_pattern_match__flatten_era_pat.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/flatten_era_pat.hvm
---
TaggedScott:
(Fn1) = λa λb (let (c, d) = a; λe (let (f, *) = d; λ* λ* f c e) b)
(Fn1) = λa λ* let (*, d) = a; let (e, *) = d; e

(Fn2) = λa let (b, c) = a; (let (d, e) = c; λf (let (g, *) = e; λ* λ* g f d) b)
(Fn2) = λa let (*, c) = a; let (*, e) = c; let (f, *) = e; f

(Fn3) = λa λb (let (c, d) = a; λe (match c { 0: λ* λ* 0; 1+: λi λ* λ* i } d e) b)
(Fn3) = λa λ* let (c, *) = a; match c { 0: 0; 1+: λf f }

(main) = (Fn2 ((1, 2), (3, (4, (5, 6)))) 0)

Scott:
(Fn1) = λa λb (let (c, d) = a; λe (let (f, *) = d; λ* λ* f c e) b)
(Fn1) = λa λ* let (*, d) = a; let (e, *) = d; e

(Fn2) = λa let (b, c) = a; (let (d, e) = c; λf (let (g, *) = e; λ* λ* g f d) b)
(Fn2) = λa let (*, c) = a; let (*, e) = c; let (f, *) = e; f

(Fn3) = λa λb (let (c, d) = a; λe (match c { 0: λ* λ* 0; 1+: λi λ* λ* i } d e) b)
(Fn3) = λa λ* let (c, *) = a; match c { 0: 0; 1+: λf f }

(main) = (Fn2 ((1, 2), (3, (4, (5, 6)))) 0)
20 changes: 10 additions & 10 deletions tests/snapshots/encode_pattern_match__list_merge_sort.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/list_merge_sort.hvm
---
TaggedScott:
(If) = λa λb λc (#Bool (a λd λ* d λ* λj j) b c)
(If) = λa λb λc (#Bool (a λd λ* d λ* λh h) b c)

(Pure) = λa (Cons a Nil)

(Map) = λa λb (#List_ (a #List_ λc #List_ λd λe let {f f_2} = e; (Cons (f c) (Map d f_2)) λ* Nil) b)
(Map) = λa λb (#List_ (a #List_ λc #List_ λd λe let {e e_2} = e; (Cons (e c) (Map d e_2)) λ* Nil) b)

(MergeSort) = λa λb (Unpack a (Map b Pure))

(Unpack) = λa λb (#List_ (b #List_ λc #List_ λd λe (#List_ (d #List_ λf #List_ λg λh λi let {o o_2} = h; (Unpack o (MergePair o_2 (Cons i (Cons f g)))) λ* λq q) e c) λ* Nil) a)
(Unpack) = λa λb (#List_ (b #List_ λc #List_ λd λe (#List_ (d #List_ λf #List_ λg λh let {h h_2} = h; λi (Unpack h (MergePair h_2 (Cons i (Cons f g)))) λ* λk k) e c) λ* Nil) a)

(MergePair) = λa λb (#List_ (b #List_ λc #List_ λd λe (#List_ (d #List_ λf #List_ λg λh λi let {m m_2} = h; (Cons (Merge m i f) (MergePair m_2 g)) λ* λo (Cons o Nil)) e c) λ* Nil) a)
(MergePair) = λa λb (#List_ (b #List_ λc #List_ λd λe (#List_ (d #List_ λf #List_ λg λh let {h h_2} = h; λi (Cons (Merge h i f) (MergePair h_2 g)) λ* λk (Cons k Nil)) e c) λ* Nil) a)

(Merge) = λa λb λc (#List_ (b #List_ λd #List_ λe λf λg (#List_ (g #List_ λh #List_ λi λj λk λl let {m m_2} = i; let {n n_2 n_3} = h; let {o o_2} = l; let {p p_2 p_3} = k; let {q q_2 q_3} = j; (If (q p n) (Cons p_2 (Merge q_2 o (Cons n_2 m))) (Cons n_3 (Merge q_3 (Cons p_3 o_2) m_2))) λ* λu λv (Cons u v)) f d e) λ* λcb cb) a c)
(Merge) = λa λb λc (#List_ (b #List_ λd #List_ λe λf λg (#List_ (g #List_ λh let {h h_2 h_3} = h; #List_ λi let {i i_2} = i; λj let {j j_2 j_3} = j; λk let {k k_2 k_3} = k; λl let {l l_2} = l; (If (j k h) (Cons k_2 (Merge j_2 l (Cons h_2 i))) (Cons h_3 (Merge j_3 (Cons k_3 l_2) i_2))) λ* λp λq (Cons p q)) f d e) λ* λt t) a c)

(Nil) = #List_ λ* #List_ λb b

Expand All @@ -26,19 +26,19 @@ TaggedScott:
(True) = #Bool λa #Bool λ* a

Scott:
(If) = λa λb λc (a λd λ* d λ* λj j b c)
(If) = λa λb λc (a λd λ* d λ* λh h b c)

(Pure) = λa (Cons a Nil)

(Map) = λa λb (a λc λd λe let {f f_2} = e; (Cons (f c) (Map d f_2)) λ* Nil b)
(Map) = λa λb (a λc λd λe let {e e_2} = e; (Cons (e c) (Map d e_2)) λ* Nil b)

(MergeSort) = λa λb (Unpack a (Map b Pure))

(Unpack) = λa λb (b λc λd λe (d λf λg λh λi let {o o_2} = h; (Unpack o (MergePair o_2 (Cons i (Cons f g)))) λ* λq q e c) λ* Nil a)
(Unpack) = λa λb (b λc λd λe (d λf λg λh let {h h_2} = h; λi (Unpack h (MergePair h_2 (Cons i (Cons f g)))) λ* λk k e c) λ* Nil a)

(MergePair) = λa λb (b λc λd λe (d λf λg λh λi let {m m_2} = h; (Cons (Merge m i f) (MergePair m_2 g)) λ* λo (Cons o Nil) e c) λ* Nil a)
(MergePair) = λa λb (b λc λd λe (d λf λg λh let {h h_2} = h; λi (Cons (Merge h i f) (MergePair h_2 g)) λ* λk (Cons k Nil) e c) λ* Nil a)

(Merge) = λa λb λc (b λd λe λf λg (g λh λi λj λk λl let {m m_2} = i; let {n n_2 n_3} = h; let {o o_2} = l; let {p p_2 p_3} = k; let {q q_2 q_3} = j; (If (q p n) (Cons p_2 (Merge q_2 o (Cons n_2 m))) (Cons n_3 (Merge q_3 (Cons p_3 o_2) m_2))) λ* λu λv (Cons u v) f d e) λ* λcb cb a c)
(Merge) = λa λb λc (b λd λe λf λg (g λh let {h h_2 h_3} = h; λi let {i i_2} = i; λj let {j j_2 j_3} = j; λk let {k k_2 k_3} = k; λl let {l l_2} = l; (If (j k h) (Cons k_2 (Merge j_2 l (Cons h_2 i))) (Cons h_3 (Merge j_3 (Cons k_3 l_2) i_2))) λ* λp λq (Cons p q) f d e) λ* λt t a c)

(Nil) = λ* λb b

Expand Down
12 changes: 6 additions & 6 deletions tests/snapshots/encode_pattern_match__match_many_args.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/match_many_args.hvm
---
TaggedScott:
(tail_tail) = λa #L (a #L λb #L λc (#L (c #L λ* #L λe λ* e λ* N) b) N)
(tail_tail) = λa #L (a #L λ* #L λc #L (c #L λ* #L λe e N) N)

(or) = λa λb (#Option (b #Option λc λ* c λf f) a)
(or) = λa λb (#Option (b #Option λc λ* c λe e) a)

(or2) = λa λb (#Option (a #Option λc λ* c λf f) b)
(or2) = λa λb (#Option (a #Option λc λ* c λe e) b)

(map) = λa λb (#Option (b #Option λc λd (d c) λ* None) a)

Expand All @@ -28,11 +28,11 @@ TaggedScott:
(C) = λa λb #L λc #L λ* #L (c a b)

Scott:
(tail_tail) = λa (a λb λc (c λ* λe λ* e λ* N b) N)
(tail_tail) = λa (a λ* λc (c λ* λe e N) N)

(or) = λa λb (b λc λ* c λf f a)
(or) = λa λb (b λc λ* c λe e a)

(or2) = λa λb (a λc λ* c λf f b)
(or2) = λa λb (a λc λ* c λe e b)

(map) = λa λb (b λc λd (d c) λ* None a)

Expand Down
Loading

0 comments on commit 81d6486

Please sign in to comment.