Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[sc-514] Unnecessary lambdas sometimes added in match var lifting #238

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 52 additions & 33 deletions src/term/transform/simplify_matches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ pub enum SimplifyMatchErr {
impl Ctx<'_> {
pub fn simplify_matches(&mut self) -> Result<(), Diagnostics> {
self.info.start_pass();
let name_gen = &mut 0;

for (def_name, def) in self.book.defs.iter_mut() {
let res = def.simplify_matches(&self.book.ctrs, &self.book.adts);
let res = def.simplify_matches(&self.book.ctrs, &self.book.adts, name_gen);
self.info.take_rule_err(res, def_name.clone());
}

Expand All @@ -29,9 +30,14 @@ impl Ctx<'_> {
}

impl Definition {
pub fn simplify_matches(&mut self, ctrs: &Constructors, adts: &Adts) -> Result<(), SimplifyMatchErr> {
pub fn simplify_matches(
&mut self,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<(), SimplifyMatchErr> {
for rule in self.rules.iter_mut() {
rule.body.simplify_matches(ctrs, adts)?;
rule.body.simplify_matches(ctrs, adts, name_gen)?;
}
Ok(())
}
Expand All @@ -42,21 +48,29 @@ impl Term {
/// arbitrary patterns into matches on a single value, with only
/// simple (non-nested) patterns, and one rule for each constructor.
///
/// The `name_gen` is used to generate fresh variable names for
/// substitution to avoid name clashes.
///
/// See `[simplify_match_expression]` for more information.
pub fn simplify_matches(&mut self, ctrs: &Constructors, adts: &Adts) -> Result<(), SimplifyMatchErr> {
pub fn simplify_matches(
&mut self,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
imaqtkatt marked this conversation as resolved.
Show resolved Hide resolved
) -> Result<(), SimplifyMatchErr> {
Term::recursive_call(move || {
match self {
Term::Mat { args, rules } => {
let extracted = extract_args(args);
let args = std::mem::take(args);
let rules = std::mem::take(rules);
let term = simplify_match_expression(args, rules, ctrs, adts)?;
let term = simplify_match_expression(args, rules, ctrs, adts, name_gen)?;
*self = bind_extracted_args(extracted, term);
}

_ => {
for child in self.children_mut() {
child.simplify_matches(ctrs, adts)?;
child.simplify_matches(ctrs, adts, name_gen)?;
}
}
}
Expand Down Expand Up @@ -92,16 +106,17 @@ fn simplify_match_expression(
rules: Vec<Rule>,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<Term, SimplifyMatchErr> {
let fst_row_irrefutable = rules[0].pats.iter().all(|p| p.is_wildcard());
let fst_col_type = infer_match_arg_type(&rules, 0, ctrs)?;

if fst_row_irrefutable {
irrefutable_fst_row_rule(args, rules, ctrs, adts)
irrefutable_fst_row_rule(args, rules, ctrs, adts, name_gen)
} else if fst_col_type == Type::Any {
var_rule(args, rules, ctrs, adts)
var_rule(args, rules, ctrs, adts, name_gen)
} else {
switch_rule(args, rules, fst_col_type, ctrs, adts)
switch_rule(args, rules, fst_col_type, ctrs, adts, name_gen)
}
}

Expand All @@ -113,17 +128,19 @@ fn irrefutable_fst_row_rule(
mut rules: Vec<Rule>,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<Term, SimplifyMatchErr> {
rules.truncate(1);

let Rule { pats, body: mut term } = rules.pop().unwrap();
term.simplify_matches(ctrs, adts)?;
term.simplify_matches(ctrs, adts, name_gen)?;

for (pat, arg) in pats.iter().zip(args.iter()) {
for bind in pat.binds().flatten() {
term.subst(bind, arg);
}
}

let term = pats.into_iter().zip(args).fold(term, |term, (pat, arg)| Term::Let {
pat,
val: Box::new(arg),
nxt: Box::new(term),
});
Ok(term)
}

Expand All @@ -137,21 +154,25 @@ fn var_rule(
rules: Vec<Rule>,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<Term, SimplifyMatchErr> {
let mut new_rules = vec![];
for mut rule in rules {
let rest = rule.pats.split_off(1);

let body =
Term::Let { pat: rule.pats.pop().unwrap(), val: Box::new(args[0].clone()), nxt: Box::new(rule.body) };
let pat = rule.pats.pop().unwrap();
let mut body = rule.body;
if let Pattern::Var(Some(nam)) = &pat {
body.subst(nam, &args[0]);
}

let new_rule = Rule { pats: rest, body };
new_rules.push(new_rule);
}

let rest = args.split_off(1);
let mut term = Term::Mat { args: rest, rules: new_rules };
term.simplify_matches(ctrs, adts)?;
term.simplify_matches(ctrs, adts, name_gen)?;
Ok(term)
}

Expand Down Expand Up @@ -206,6 +227,7 @@ fn switch_rule(
typ: Type,
ctrs: &Constructors,
adts: &Adts,
name_gen: &mut usize,
) -> Result<Term, SimplifyMatchErr> {
let mut new_rules = vec![];

Expand Down Expand Up @@ -253,10 +275,10 @@ fn switch_rule(
for ctr in adt_ctrs {
// Create the matched constructor and the name of the bound variables.
let Term::Var { nam: arg_nam } = &args[0] else { unreachable!() };
let nested_fields = switch_rule_nested_fields(arg_nam, &ctr);
let nested_fields = switch_rule_nested_fields(arg_nam, &ctr, name_gen);
let matched_ctr = switch_rule_matched_ctr(ctr.clone(), &nested_fields);
let mut body = switch_rule_submatch(&args, &rules, &matched_ctr, &nested_fields)?;
body.simplify_matches(ctrs, adts)?;
body.simplify_matches(ctrs, adts, name_gen)?;
let pats = vec![matched_ctr];
new_rules.push(Rule { pats, body });
}
Expand All @@ -265,13 +287,14 @@ fn switch_rule(
Ok(term)
}

fn switch_rule_nested_fields(arg_nam: &Name, ctr: &Pattern) -> Vec<Option<Name>> {
fn switch_rule_nested_fields(arg_nam: &Name, ctr: &Pattern, name_gen: &mut usize) -> Vec<Option<Name>> {
let mut nested_fields = vec![];
let old_vars = ctr.binds();
for old_var in old_vars {
*name_gen += 1;
let new_nam = if let Some(field) = old_var {
// Name of constructor field
Name::new(format!("{arg_nam}%{field}"))
Name::new(format!("{arg_nam}%{field}%{name_gen}"))
} else {
// Name of var pattern
arg_nam.clone()
Expand Down Expand Up @@ -339,21 +362,17 @@ fn switch_rule_submatch_arm(rule: &Rule, ctr: &Pattern, nested_fields: &[Option<
let body = rule.body.clone();
Some(Rule { pats, body })
} else if rule.pats[0].is_wildcard() {
// Var, reconstruct the value matched in the expression above.
// match x ... {var ...: Body; ...}
// becomes
// match x {
// (Ctr x%field0 ...): match x1 ... {
// x%field0 ...: let var = (Ctr x%field0 ...); Body;
// ... };
// ... }
// Use `subst` to replace the pattern variable in the body
// of the rule with the term that represents the matched constructor.
let mut body = rule.body.clone();
if let Pattern::Var(Some(nam)) = &rule.pats[0] {
body.subst(nam, &ctr.to_term());
imaqtkatt marked this conversation as resolved.
Show resolved Hide resolved
}

let nested_var_pats = nested_fields.iter().cloned().map(Pattern::Var);
let old_pats = rule.pats[1 ..].iter().cloned();
let pats = nested_var_pats.chain(old_pats).collect_vec();

let body =
Term::Let { pat: rule.pats[0].clone(), val: Box::new(ctr.to_term()), nxt: Box::new(rule.body.clone()) };

Some(Rule { pats, body })
} else {
// Non-matching constructor. Don't include in submatch expression.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ input_file: tests/golden_tests/compile_file_o_all/list_merge_sort.hvm
@If$C1 = (* (a a))
@Map = ({4 @Map$C0 {4 @Unpack$C3_$_MergePair$C4_$_Map$C1 a}} a)
@Map$C0 = {4 a {4 {4 @Map$C0 {4 @Unpack$C3_$_MergePair$C4_$_Map$C1 (b c)}} ({5 (a d) b} {4 {4 d {4 c e}} {4 * e}})}}
@Merge$C0 = {4 {9 a {9 b c}} {4 {7 d {4 @Merge$C0 {4 @Merge$C1 (e (f (g h)))}}} ({15 (i (a {2 @If$C0 {2 @If$C1 ({4 {4 j {4 k l}} {4 * l}} ({4 {4 c {4 h m}} {4 * m}} n))}})) {15 o e}} ({13 i {13 j f}} ({11 {4 @Merge$C2 {4 @Merge$C3 (o ({4 {4 b {4 d p}} {4 * p}} k))}} g} n)))}}
@Merge$C0 = {4 {7 a {7 b c}} {4 {9 d {4 @Merge$C0 {4 @Merge$C1 (e (f (g h)))}}} ({11 (i (a {2 @If$C0 {2 @If$C1 ({4 {4 j {4 k l}} {4 * l}} ({4 {4 c {4 h m}} {4 * m}} n))}})) {11 o e}} ({13 i {13 j f}} ({15 {4 @Merge$C2 {4 @Merge$C3 (o ({4 {4 b {4 d p}} {4 * p}} k))}} g} n)))}}
@Merge$C1 = (* @Cons)
@Merge$C2 = {4 a {4 b (c ({4 @Merge$C0 {4 @Merge$C1 (c (a (b d)))}} d))}}
@Merge$C3 = (* (a a))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,8 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/compile_file_o_all/num_pattern_with_var.hvm
---
@Foo = ({2 (* #0) {2 @Foo$C3 a}} a)
@Foo$C2 = (a (* a))
@Foo$C3 = (?<((* #0) @Foo$C2) (* a)> a)
@Foo = ({2 (* #0) {2 @Foo$C1 a}} a)
@Foo$C1 = (?<(#0 (a a)) b> b)
@main = a
& @Foo ~ (@true (#3 a))
@true = {2 * {2 a a}}
2 changes: 1 addition & 1 deletion tests/snapshots/desugar_file__used_once_names.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
source: tests/golden_tests.rs
input_file: tests/golden_tests/desugar_file/used_once_names.hvm
---
(foo) = λa λb λc let {d d_2} = c; (a b (d d_2))
(foo) = λa λb λc let {c c_2} = c; (a b (c c_2))

(main) = (foo 2 3 λa a)
4 changes: 2 additions & 2 deletions tests/snapshots/encode_pattern_match__adt_tup_era.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/adt_tup_era.hvm
---
TaggedScott:
(Foo) = λa #Tuple (a #Tuple λb #Tuple λc (#Tuple (b #Tuple λd #Tuple λ* λ* d) c))
(Foo) = λa #Tuple (a #Tuple λb #Tuple λ* #Tuple (b #Tuple λd #Tuple λ* d))

(Main) = (Foo (Pair 1 5))

(Pair) = λa λb #Tuple λc #Tuple (c a b)

Scott:
(Foo) = λa (a λb λc (b λd λ* λ* d c))
(Foo) = λa (a λb λ* (b λd λ* d))

(Main) = (Foo (Pair 1 5))

Expand Down
4 changes: 2 additions & 2 deletions tests/snapshots/encode_pattern_match__concat.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/concat.hvm
---
TaggedScott:
(String.concat) = λa λb (#String (a #String λc #String λd λe (String.cons c (String.concat d e)) λh h) b)
(String.concat) = λa λb (#String (a #String λc #String λd λe (String.cons c (String.concat d e)) λf f) b)

(main) = (String.concat (String.cons 97 (String.cons 98 String.nil)) (String.cons 99 (String.cons 100 String.nil)))

Expand All @@ -12,7 +12,7 @@ TaggedScott:
(String.nil) = #String λ* #String λb b

Scott:
(String.concat) = λa λb (a λc λd λe (String.cons c (String.concat d e)) λh h b)
(String.concat) = λa λb (a λc λd λe (String.cons c (String.concat d e)) λf f b)

(main) = (String.concat (String.cons 97 (String.cons 98 String.nil)) (String.cons 99 (String.cons 100 String.nil)))

Expand Down
4 changes: 2 additions & 2 deletions tests/snapshots/encode_pattern_match__concat_def.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/concat_def.hvm
---
TaggedScott:
(concat) = λa λb (#String (a #String λc #String λd λe (String.cons c (concat d e)) λi i) b)
(concat) = λa λb (#String (a #String λc #String λd λe (String.cons c (concat d e)) λf f) b)

(main) = (concat (String.cons 97 (String.cons 98 String.nil)) (String.cons 99 (String.cons 100 String.nil)))

Expand All @@ -12,7 +12,7 @@ TaggedScott:
(String.nil) = #String λ* #String λb b

Scott:
(concat) = λa λb (a λc λd λe (String.cons c (concat d e)) λi i b)
(concat) = λa λb (a λc λd λe (String.cons c (concat d e)) λf f b)

(main) = (concat (String.cons 97 (String.cons 98 String.nil)) (String.cons 99 (String.cons 100 String.nil)))

Expand Down
12 changes: 6 additions & 6 deletions tests/snapshots/encode_pattern_match__flatten_era_pat.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/flatten_era_pat.hvm
---
TaggedScott:
(Fn1) = λa λb (let (c, d) = a; λe (let (f, *) = d; λ* λ* f c e) b)
(Fn1) = λa λ* let (*, d) = a; let (e, *) = d; e

(Fn2) = λa let (b, c) = a; (let (d, e) = c; λf (let (g, *) = e; λ* λ* g f d) b)
(Fn2) = λa let (*, c) = a; let (*, e) = c; let (f, *) = e; f

(Fn3) = λa λb (let (c, d) = a; λe (match c { 0: λ* λ* 0; 1+: λi λ* λ* i } d e) b)
(Fn3) = λa λ* let (c, *) = a; match c { 0: 0; 1+: λf f }

(main) = (Fn2 ((1, 2), (3, (4, (5, 6)))) 0)

Scott:
(Fn1) = λa λb (let (c, d) = a; λe (let (f, *) = d; λ* λ* f c e) b)
(Fn1) = λa λ* let (*, d) = a; let (e, *) = d; e

(Fn2) = λa let (b, c) = a; (let (d, e) = c; λf (let (g, *) = e; λ* λ* g f d) b)
(Fn2) = λa let (*, c) = a; let (*, e) = c; let (f, *) = e; f

(Fn3) = λa λb (let (c, d) = a; λe (match c { 0: λ* λ* 0; 1+: λi λ* λ* i } d e) b)
(Fn3) = λa λ* let (c, *) = a; match c { 0: 0; 1+: λf f }

(main) = (Fn2 ((1, 2), (3, (4, (5, 6)))) 0)
20 changes: 10 additions & 10 deletions tests/snapshots/encode_pattern_match__list_merge_sort.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/list_merge_sort.hvm
---
TaggedScott:
(If) = λa λb λc (#Bool (a λd λ* d λ* λj j) b c)
(If) = λa λb λc (#Bool (a λd λ* d λ* λh h) b c)

(Pure) = λa (Cons a Nil)

(Map) = λa λb (#List_ (a #List_ λc #List_ λd λe let {f f_2} = e; (Cons (f c) (Map d f_2)) λ* Nil) b)
(Map) = λa λb (#List_ (a #List_ λc #List_ λd λe let {e e_2} = e; (Cons (e c) (Map d e_2)) λ* Nil) b)

(MergeSort) = λa λb (Unpack a (Map b Pure))

(Unpack) = λa λb (#List_ (b #List_ λc #List_ λd λe (#List_ (d #List_ λf #List_ λg λh λi let {o o_2} = h; (Unpack o (MergePair o_2 (Cons i (Cons f g)))) λ* λq q) e c) λ* Nil) a)
(Unpack) = λa λb (#List_ (b #List_ λc #List_ λd λe (#List_ (d #List_ λf #List_ λg λh let {h h_2} = h; λi (Unpack h (MergePair h_2 (Cons i (Cons f g)))) λ* λk k) e c) λ* Nil) a)

(MergePair) = λa λb (#List_ (b #List_ λc #List_ λd λe (#List_ (d #List_ λf #List_ λg λh λi let {m m_2} = h; (Cons (Merge m i f) (MergePair m_2 g)) λ* λo (Cons o Nil)) e c) λ* Nil) a)
(MergePair) = λa λb (#List_ (b #List_ λc #List_ λd λe (#List_ (d #List_ λf #List_ λg λh let {h h_2} = h; λi (Cons (Merge h i f) (MergePair h_2 g)) λ* λk (Cons k Nil)) e c) λ* Nil) a)

(Merge) = λa λb λc (#List_ (b #List_ λd #List_ λe λf λg (#List_ (g #List_ λh #List_ λi λj λk λl let {m m_2} = i; let {n n_2 n_3} = h; let {o o_2} = l; let {p p_2 p_3} = k; let {q q_2 q_3} = j; (If (q p n) (Cons p_2 (Merge q_2 o (Cons n_2 m))) (Cons n_3 (Merge q_3 (Cons p_3 o_2) m_2))) λ* λu λv (Cons u v)) f d e) λ* λcb cb) a c)
(Merge) = λa λb λc (#List_ (b #List_ λd #List_ λe λf λg (#List_ (g #List_ λh let {h h_2 h_3} = h; #List_ λi let {i i_2} = i; λj let {j j_2 j_3} = j; λk let {k k_2 k_3} = k; λl let {l l_2} = l; (If (j k h) (Cons k_2 (Merge j_2 l (Cons h_2 i))) (Cons h_3 (Merge j_3 (Cons k_3 l_2) i_2))) λ* λp λq (Cons p q)) f d e) λ* λt t) a c)

(Nil) = #List_ λ* #List_ λb b

Expand All @@ -26,19 +26,19 @@ TaggedScott:
(True) = #Bool λa #Bool λ* a

Scott:
(If) = λa λb λc (a λd λ* d λ* λj j b c)
(If) = λa λb λc (a λd λ* d λ* λh h b c)

(Pure) = λa (Cons a Nil)

(Map) = λa λb (a λc λd λe let {f f_2} = e; (Cons (f c) (Map d f_2)) λ* Nil b)
(Map) = λa λb (a λc λd λe let {e e_2} = e; (Cons (e c) (Map d e_2)) λ* Nil b)

(MergeSort) = λa λb (Unpack a (Map b Pure))

(Unpack) = λa λb (b λc λd λe (d λf λg λh λi let {o o_2} = h; (Unpack o (MergePair o_2 (Cons i (Cons f g)))) λ* λq q e c) λ* Nil a)
(Unpack) = λa λb (b λc λd λe (d λf λg λh let {h h_2} = h; λi (Unpack h (MergePair h_2 (Cons i (Cons f g)))) λ* λk k e c) λ* Nil a)

(MergePair) = λa λb (b λc λd λe (d λf λg λh λi let {m m_2} = h; (Cons (Merge m i f) (MergePair m_2 g)) λ* λo (Cons o Nil) e c) λ* Nil a)
(MergePair) = λa λb (b λc λd λe (d λf λg λh let {h h_2} = h; λi (Cons (Merge h i f) (MergePair h_2 g)) λ* λk (Cons k Nil) e c) λ* Nil a)

(Merge) = λa λb λc (b λd λe λf λg (g λh λi λj λk λl let {m m_2} = i; let {n n_2 n_3} = h; let {o o_2} = l; let {p p_2 p_3} = k; let {q q_2 q_3} = j; (If (q p n) (Cons p_2 (Merge q_2 o (Cons n_2 m))) (Cons n_3 (Merge q_3 (Cons p_3 o_2) m_2))) λ* λu λv (Cons u v) f d e) λ* λcb cb a c)
(Merge) = λa λb λc (b λd λe λf λg (g λh let {h h_2 h_3} = h; λi let {i i_2} = i; λj let {j j_2 j_3} = j; λk let {k k_2 k_3} = k; λl let {l l_2} = l; (If (j k h) (Cons k_2 (Merge j_2 l (Cons h_2 i))) (Cons h_3 (Merge j_3 (Cons k_3 l_2) i_2))) λ* λp λq (Cons p q) f d e) λ* λt t a c)

(Nil) = λ* λb b

Expand Down
12 changes: 6 additions & 6 deletions tests/snapshots/encode_pattern_match__match_many_args.hvm.snap
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ source: tests/golden_tests.rs
input_file: tests/golden_tests/encode_pattern_match/match_many_args.hvm
---
TaggedScott:
(tail_tail) = λa #L (a #L λb #L λc (#L (c #L λ* #L λe λ* e λ* N) b) N)
(tail_tail) = λa #L (a #L λ* #L λc #L (c #L λ* #L λe e N) N)

(or) = λa λb (#Option (b #Option λc λ* c λf f) a)
(or) = λa λb (#Option (b #Option λc λ* c λe e) a)

(or2) = λa λb (#Option (a #Option λc λ* c λf f) b)
(or2) = λa λb (#Option (a #Option λc λ* c λe e) b)

(map) = λa λb (#Option (b #Option λc λd (d c) λ* None) a)

Expand All @@ -28,11 +28,11 @@ TaggedScott:
(C) = λa λb #L λc #L λ* #L (c a b)

Scott:
(tail_tail) = λa (a λb λc (c λ* λe λ* e λ* N b) N)
(tail_tail) = λa (a λ* λc (c λ* λe e N) N)

(or) = λa λb (b λc λ* c λf f a)
(or) = λa λb (b λc λ* c λe e a)

(or2) = λa λb (a λc λ* c λf f b)
(or2) = λa λb (a λc λ* c λe e b)

(map) = λa λb (b λc λd (d c) λ* None a)

Expand Down
Loading
Loading