diff --git a/src/term/check/ctrs_arities.rs b/src/term/check/ctrs_arities.rs index 8d426e6dd..d40e787a9 100644 --- a/src/term/check/ctrs_arities.rs +++ b/src/term/check/ctrs_arities.rs @@ -49,20 +49,15 @@ impl Pattern { let mut to_check = vec![self]; while let Some(pat) = to_check.pop() { - match pat { - Pattern::Ctr(name, args) => { - let expected = arities.get(name).unwrap(); - let found = args.len(); - if *expected != found { - return Err(MatchErr::CtrArityMismatch(name.clone(), found, *expected)); - } + if let Pattern::Ctr(name, args) = pat { + let expected = arities.get(name).unwrap(); + let found = args.len(); + if *expected != found { + return Err(MatchErr::CtrArityMismatch(name.clone(), found, *expected)); } - Pattern::Lst(els) | Pattern::Tup(els) => { - for el in els { - to_check.push(el); - } - } - Pattern::Var(..) | Pattern::Num(..) | Pattern::Str(_) => {} + } + for child in pat.children() { + to_check.push(child); } } Ok(()) diff --git a/src/term/check/unbound_pats.rs b/src/term/check/unbound_pats.rs index d25fd152a..5a605e099 100644 --- a/src/term/check/unbound_pats.rs +++ b/src/term/check/unbound_pats.rs @@ -46,16 +46,12 @@ impl Pattern { let mut unbounds = HashSet::new(); let mut check = vec![self]; while let Some(pat) = check.pop() { - match pat { - Pattern::Ctr(nam, args) => { - if !is_ctr(nam) { - unbounds.insert(nam.clone()); - } - check.extend(args.iter()); + if let Pattern::Ctr(nam, _) = pat { + if !is_ctr(nam) { + unbounds.insert(nam.clone()); } - Pattern::Tup(args) | Pattern::Lst(args) => args.iter().for_each(|arg| check.push(arg)), - Pattern::Var(_) | Pattern::Num(_) | Pattern::Str(_) => {} } + check.extend(pat.children()); } unbounds } diff --git a/src/term/mod.rs b/src/term/mod.rs index f6552bdf9..ac7231b03 100644 --- a/src/term/mod.rs +++ b/src/term/mod.rs @@ -846,6 +846,15 @@ impl Pattern { } } + pub fn children_mut(&mut self) -> ChildrenIter<&mut Pattern> { + match self { + Pattern::Ctr(_, els) | Pattern::Tup(els) | Pattern::Lst(els) => { + ChildrenIter::Many(Box::new(els.iter_mut())) + } + Pattern::Var(_) | Pattern::Num(_) | Pattern::Str(_) => ChildrenIter::zero(), + } + } + /// Returns an iterator over each subpattern in depth-first, left to right order. // TODO: Not lazy. pub fn iter(&self) -> impl DoubleEndedIterator + Clone { diff --git a/src/term/transform/definition_pruning.rs b/src/term/transform/definition_pruning.rs index 45a4ab713..8189315ab 100644 --- a/src/term/transform/definition_pruning.rs +++ b/src/term/transform/definition_pruning.rs @@ -97,20 +97,17 @@ impl Book { Some(name) => self.insert_ctrs_used(name, uses, adt_encoding), None => self.insert_used(def_name, used, uses, adt_encoding), }, - Term::Lst { els } => { + Term::Lst { .. } => { self.insert_ctrs_used(&Name::from(LIST), uses, adt_encoding); - for term in els { - to_find.push(term); - } } Term::Str { .. } => { self.insert_ctrs_used(&Name::from(STRING), uses, adt_encoding); } - _ => { - for child in term.children() { - to_find.push(child); - } - } + _ => {} + } + + for child in term.children() { + to_find.push(child); } } } diff --git a/src/term/transform/desugar_implicit_match_binds.rs b/src/term/transform/desugar_implicit_match_binds.rs index 81bdbaa81..5aa9a70dc 100644 --- a/src/term/transform/desugar_implicit_match_binds.rs +++ b/src/term/transform/desugar_implicit_match_binds.rs @@ -12,77 +12,62 @@ impl Book { impl Term { pub fn desugar_implicit_match_binds(&mut self, ctrs: &Constructors, adts: &Adts) { - let mut to_desugar = vec![self]; - - while let Some(term) = to_desugar.pop() { - match term { - Term::Mat { args, rules } => { - // Make all the matched terms variables - let mut match_args = vec![]; - for arg in args.iter_mut() { - if let Term::Var { nam } = arg { - match_args.push((nam.clone(), None)) - } else { - let nam = Name::new(format!("%matched_{}", match_args.len())); - let arg = std::mem::replace(arg, Term::Var { nam: nam.clone() }); - match_args.push((nam, Some(arg))); - } + Term::recursive_call(move || { + for child in self.children_mut() { + child.desugar_implicit_match_binds(ctrs, adts); + } + if let Term::Mat { args, rules } = self { + // Make all the matched terms variables + let mut match_args = vec![]; + for arg in args.iter_mut() { + if let Term::Var { nam } = arg { + match_args.push((nam.clone(), None)) + } else { + let nam = Name::new(format!("%matched_{}", match_args.len())); + let arg = std::mem::replace(arg, Term::Var { nam: nam.clone() }); + match_args.push((nam, Some(arg))); } + } - // Make implicit match binds explicit - for rule in rules.iter_mut() { - for ((nam, _), pat) in match_args.iter().zip(rule.pats.iter_mut()) { - match pat { - Pattern::Var(_) => (), - Pattern::Ctr(ctr_nam, pat_args) => { - let adt = &adts[ctrs.get(ctr_nam).unwrap()]; - let ctr_args = adt.ctrs.get(ctr_nam).unwrap(); - if pat_args.is_empty() && !ctr_args.is_empty() { - // Implicit ctr args - *pat_args = ctr_args - .iter() - .map(|field| Pattern::Var(Some(Name::new(format!("{nam}.{field}"))))) - .collect(); - } - } - Pattern::Num(NumCtr::Num(_)) => (), - Pattern::Num(NumCtr::Succ(_, Some(_))) => (), - Pattern::Num(NumCtr::Succ(n, p @ None)) => { - // Implicit num arg - *p = Some(Some(Name::new(format!("{nam}-{n}")))); + // Make implicit match binds explicit + for rule in rules.iter_mut() { + for ((nam, _), pat) in match_args.iter().zip(rule.pats.iter_mut()) { + match pat { + Pattern::Var(_) => (), + Pattern::Ctr(ctr_nam, pat_args) => { + let adt = &adts[ctrs.get(ctr_nam).unwrap()]; + let ctr_args = adt.ctrs.get(ctr_nam).unwrap(); + if pat_args.is_empty() && !ctr_args.is_empty() { + // Implicit ctr args + *pat_args = ctr_args + .iter() + .map(|field| Pattern::Var(Some(Name::new(format!("{nam}.{field}"))))) + .collect(); } - Pattern::Tup(..) => (), - Pattern::Lst(..) => (), - Pattern::Str(..) => (), } + Pattern::Num(NumCtr::Num(_)) => (), + Pattern::Num(NumCtr::Succ(_, Some(_))) => (), + Pattern::Num(NumCtr::Succ(n, p @ None)) => { + // Implicit num arg + *p = Some(Some(Name::new(format!("{nam}-{n}")))); + } + Pattern::Tup(..) => (), + Pattern::Lst(..) => (), + Pattern::Str(..) => (), } } - - // Add the binds to the extracted term vars. - *term = match_args.into_iter().rev().fold(std::mem::take(term), |nxt, (nam, val)| { - if let Some(val) = val { - // Non-Var term that was extracted. - Term::Let { pat: Pattern::Var(Some(nam)), val: Box::new(val), nxt: Box::new(nxt) } - } else { - nxt - } - }); - - // Add the next values to check - let mut term = term; - while let Term::Let { nxt, .. } = term { - term = nxt; - } - let Term::Mat { args: _, rules } = term else { unreachable!() }; - to_desugar.extend(rules.iter_mut().map(|r| &mut r.body)); } - _ => { - for child in term.children_mut() { - to_desugar.push(child); + // Add the binds to the extracted term vars. + *self = match_args.into_iter().rev().fold(std::mem::take(self), |nxt, (nam, val)| { + if let Some(val) = val { + // Non-Var term that was extracted. + Term::Let { pat: Pattern::Var(Some(nam)), val: Box::new(val), nxt: Box::new(nxt) } + } else { + nxt } - } + }); } - } + }) } } diff --git a/src/term/transform/desugar_let_destructors.rs b/src/term/transform/desugar_let_destructors.rs index 9dc441fb7..6d840f39f 100644 --- a/src/term/transform/desugar_let_destructors.rs +++ b/src/term/transform/desugar_let_destructors.rs @@ -13,20 +13,17 @@ impl Book { impl Term { pub fn desugar_let_destructors(&mut self) { - Term::recursive_call(move || match self { - // Only transform `let`s that are not on variables. - Term::Let { pat: Pattern::Var(_), .. } => { - for child in self.children_mut() { - child.desugar_let_destructors(); - } + Term::recursive_call(move || { + for child in self.children_mut() { + child.desugar_let_destructors(); } - Term::Let { pat, val, nxt } => { - let pat = std::mem::replace(pat, Pattern::Var(None)); - let mut val = std::mem::take(val); - let mut nxt = std::mem::take(nxt); - val.desugar_let_destructors(); - nxt.desugar_let_destructors(); + if let Term::Let { pat, val, nxt } = self + && !pat.is_wildcard() + { + let pat = std::mem::replace(pat, Pattern::Var(None)); + let val = std::mem::take(val); + let nxt = std::mem::take(nxt); let rules = vec![Rule { pats: vec![pat], body: *nxt }]; @@ -39,12 +36,6 @@ impl Term { Term::Let { pat, val, nxt: Box::new(Term::Mat { args, rules }) } }; } - - _ => { - for child in self.children_mut() { - child.desugar_let_destructors(); - } - } }) } } diff --git a/src/term/transform/encode_pattern_matching.rs b/src/term/transform/encode_pattern_matching.rs index e3ab94f03..978e31168 100644 --- a/src/term/transform/encode_pattern_matching.rs +++ b/src/term/transform/encode_pattern_matching.rs @@ -23,23 +23,18 @@ impl Book { impl Term { pub fn encode_simple_matches(&mut self, ctrs: &Constructors, adts: &Adts, adt_encoding: AdtEncoding) { - Term::recursive_call(move || match self { - Term::Mat { .. } => { + Term::recursive_call(move || { + for child in self.children_mut() { + child.encode_simple_matches(ctrs, adts, adt_encoding) + } + + if let Term::Mat { .. } = self { debug_assert!(self.is_simple_match(ctrs, adts), "{self}"); let Term::Mat { args, rules } = self else { unreachable!() }; - for rule in rules.iter_mut() { - rule.body.encode_simple_matches(ctrs, adts, adt_encoding); - } let arg = std::mem::take(&mut args[0]); let rules = std::mem::take(rules); *self = encode_match(arg, rules, ctrs, adt_encoding); } - - _ => { - for child in self.children_mut() { - child.encode_simple_matches(ctrs, adts, adt_encoding) - } - } }) } } diff --git a/src/term/transform/eta_reduction.rs b/src/term/transform/eta_reduction.rs index 83b66a06d..bb7b32fc5 100644 --- a/src/term/transform/eta_reduction.rs +++ b/src/term/transform/eta_reduction.rs @@ -14,34 +14,27 @@ impl Term { /// Eta-reduces a term and any subterms. /// Expects variables to be linear. pub fn eta_reduction(&mut self) { - Term::recursive_call(move || match self { - Term::Lam { tag: lam_tag, nam: Some(lam_var), bod } => { - bod.eta_reduction(); - match bod.as_mut() { - Term::App { tag: arg_tag, fun, arg: box Term::Var { nam: var_nam } } - if lam_var == var_nam && lam_tag == arg_tag => - { - *self = std::mem::take(fun.as_mut()); - } - _ => {} - } + Term::recursive_call(move || { + for child in self.children_mut() { + child.eta_reduction() } - Term::Chn { tag: chn_tag, nam: chn_var, bod } => { - bod.eta_reduction(); - match bod.as_mut() { - Term::App { tag: arg_tag, fun, arg: box Term::Lnk { nam: var_nam } } - if chn_var == var_nam && chn_tag == arg_tag => - { - *self = std::mem::take(fun.as_mut()); - } - _ => {} + match self { + Term::Lam { + tag: lam_tag, + nam: lam_var, + bod: box Term::App { tag: arg_tag, fun, arg: box Term::Var { nam: var_nam } }, + } if lam_var == var_nam && lam_tag == arg_tag => { + *self = std::mem::take(fun.as_mut()); } - } - - _ => { - for child in self.children_mut() { - child.eta_reduction() + Term::Chn { + tag: chn_tag, + nam: chn_var, + bod: box Term::App { tag: arg_tag, fun, arg: box Term::Lnk { nam: var_nam } }, + } if chn_var == var_nam && chn_tag == arg_tag => { + *self = std::mem::take(fun.as_mut()); } + + _ => {} } }) } diff --git a/src/term/transform/inline.rs b/src/term/transform/inline.rs index f924c5e80..a0deb26e1 100644 --- a/src/term/transform/inline.rs +++ b/src/term/transform/inline.rs @@ -25,18 +25,14 @@ impl Term { let mut to_inline = vec![self]; while let Some(term) = to_inline.pop() { - match term { - Term::Ref { nam: def_name } => { - if inlineables.contains(def_name) { - *term = defs.get(def_name).unwrap().rule().body.clone(); - } + if let Term::Ref { nam: def_name } = term { + if inlineables.contains(def_name) { + *term = defs.get(def_name).unwrap().rule().body.clone(); } + } - _ => { - for child in term.children_mut() { - to_inline.push(child); - } - } + for child in term.children_mut() { + to_inline.push(child); } } } diff --git a/src/term/transform/match_defs_to_term.rs b/src/term/transform/match_defs_to_term.rs index 65e33aa83..97c603ad9 100644 --- a/src/term/transform/match_defs_to_term.rs +++ b/src/term/transform/match_defs_to_term.rs @@ -10,10 +10,28 @@ impl Book { } impl Definition { - /// Converts a pattern matching function with multiple rules and args, into a single rule without pattern matching. + /// Converts a pattern matching function with multiple rules and args, into a + /// single rule without pattern matching. + /// /// Moves the pattern matching of the rules into a complex match expression. /// - /// Preconditions: Rule arities must be correct + /// Example: + /// + /// ```hvm + /// if True then else = then + /// if False then else = else + /// ``` + /// + /// becomes + /// + /// ```hvm + /// if = @%x0 @%x1 @%x2 match %x0, %x1, %x2 { + /// True then else: then + /// False then else: else + /// } + /// ``` + /// + /// Preconditions: Rule arities must be correct. pub fn convert_match_def_to_term(&mut self) { let rule = def_rules_to_match(std::mem::take(&mut self.rules)); self.rules = vec![rule]; diff --git a/src/term/transform/resolve_ctrs_in_pats.rs b/src/term/transform/resolve_ctrs_in_pats.rs index bafa526f6..39a8a155f 100644 --- a/src/term/transform/resolve_ctrs_in_pats.rs +++ b/src/term/transform/resolve_ctrs_in_pats.rs @@ -1,9 +1,12 @@ use crate::term::{Book, Name, Pattern, Term}; impl Book { - /// Resolve Constructor names inside rule patterns and match patterns. - /// When parsing a rule we don't have all the constructors yet, - /// so no way to know if a particular name belongs to a constructor or is a matched variable. + /// Resolve Constructor names inside rule patterns and match patterns, + /// converting `Pattern::Var(Some(nam))` into `Pattern::Ctr(nam, vec![])` + /// when the name is that of a constructor. + /// + /// When parsing a rule we don't have all the constructors yet, so no way to + /// know if a particular name belongs to a constructor or is a matched variable. /// Therefore we must do it later, here. pub fn resolve_ctrs_in_pats(&mut self) { let is_ctr = |nam: &Name| self.ctrs.contains_key(nam); @@ -23,21 +26,13 @@ impl Pattern { let mut to_resolve = vec![self]; while let Some(pat) = to_resolve.pop() { - match pat { - Pattern::Var(Some(nam)) => { - if is_ctr(nam) { - *pat = Pattern::Ctr(nam.clone(), vec![]); - } + if let Pattern::Var(Some(nam)) = pat { + if is_ctr(nam) { + *pat = Pattern::Ctr(nam.clone(), vec![]); } - Pattern::Ctr(_, args) | Pattern::Lst(args) | Pattern::Tup(args) => { - for arg in args { - to_resolve.push(arg); - } - } - Pattern::Var(None) => (), - Pattern::Num(_) => (), - Pattern::Str(_) => (), } + + to_resolve.extend(pat.children_mut()); } } } diff --git a/src/term/transform/resolve_refs.rs b/src/term/transform/resolve_refs.rs index 3128da306..311cde6d6 100644 --- a/src/term/transform/resolve_refs.rs +++ b/src/term/transform/resolve_refs.rs @@ -19,7 +19,12 @@ impl Display for ReferencedMainErr { impl Ctx<'_> { /// Decides if names inside a term belong to a Var or to a Ref. + /// Converts `Term::Var(nam)` into `Term::Ref(nam)` when the name + /// refers to a function definition and there is no variable in + /// scope shadowing that definition. + /// /// Precondition: Refs are encoded as vars, Constructors are resolved. + /// /// Postcondition: Refs are encoded as refs, with the correct def id. pub fn resolve_refs(&mut self) -> Result<(), Info> { self.info.start_pass(); @@ -50,32 +55,30 @@ impl Term { scope: &mut HashMap<&'a Name, usize>, ) -> Result<(), ReferencedMainErr> { Term::recursive_call(move || { - match self { - // If variable not defined, we check if it's a ref and swap if it is. - Term::Var { nam } => { - if is_var_in_scope(nam, scope) { - if let Some(main) = main { - if nam == main { - return Err(ReferencedMainErr); - } - } + if let Term::Var { nam } = self + && is_var_in_scope(nam, scope) + { + // If the variable is actually a reference to main, don't swap and return an error. + if let Some(main) = main + && nam == main + { + return Err(ReferencedMainErr); + } + + // If the variable is actually a reference to a function, swap the term. + if def_names.contains(nam) || CORE_BUILTINS.contains(&nam.0.as_ref()) { + *self = Term::r#ref(nam); + } + } - if def_names.contains(nam) || CORE_BUILTINS.contains(&nam.0.as_ref()) { - *self = Term::r#ref(nam); - } - } + for (child, binds) in self.children_mut_with_binds() { + let binds: Vec<_> = binds.collect(); + for bind in binds.iter() { + push_scope(bind.as_ref(), scope); } - _ => { - for (child, binds) in self.children_mut_with_binds() { - let binds: Vec<_> = binds.collect(); - for bind in binds.iter() { - push_scope(bind.as_ref(), scope); - } - child.resolve_refs(def_names, main, scope)?; - for bind in binds.iter() { - pop_scope(bind.as_ref(), scope); - } - } + child.resolve_refs(def_names, main, scope)?; + for bind in binds.iter() { + pop_scope(bind.as_ref(), scope); } } Ok(()) diff --git a/tests/golden_tests/compile_file/implicit_match_in_match_arg.hvm b/tests/golden_tests/compile_file/implicit_match_in_match_arg.hvm new file mode 100644 index 000000000..641753747 --- /dev/null +++ b/tests/golden_tests/compile_file/implicit_match_in_match_arg.hvm @@ -0,0 +1,4 @@ +main x = match (match x {0: 0; 1+: x-1}) { + 0: 0 + 1+x-2: x-2 +} \ No newline at end of file diff --git a/tests/golden_tests/compile_file/nested_ctr_wrong_arity.hvm b/tests/golden_tests/compile_file/nested_ctr_wrong_arity.hvm new file mode 100644 index 000000000..34c762d1b --- /dev/null +++ b/tests/golden_tests/compile_file/nested_ctr_wrong_arity.hvm @@ -0,0 +1,5 @@ +data Pair = (Pair fst snd) + +fst_fst (Pair (Pair fst) *) = fst + +main = (fst_fst (Pair (Pair 1 2) 3)) \ No newline at end of file diff --git a/tests/snapshots/compile_file__implicit_match_in_match_arg.hvm.snap b/tests/snapshots/compile_file__implicit_match_in_match_arg.hvm.snap new file mode 100644 index 000000000..c9c278a0c --- /dev/null +++ b/tests/snapshots/compile_file__implicit_match_in_match_arg.hvm.snap @@ -0,0 +1,5 @@ +--- +source: tests/golden_tests.rs +input_file: tests/golden_tests/compile_file/implicit_match_in_match_arg.hvm +--- +@main = (?<(#0 (a a)) ?<(#0 (b b)) c>> c) diff --git a/tests/snapshots/compile_file__nested_ctr_wrong_arity.hvm.snap b/tests/snapshots/compile_file__nested_ctr_wrong_arity.hvm.snap new file mode 100644 index 000000000..b1bb3e2e0 --- /dev/null +++ b/tests/snapshots/compile_file__nested_ctr_wrong_arity.hvm.snap @@ -0,0 +1,6 @@ +--- +source: tests/golden_tests.rs +input_file: tests/golden_tests/compile_file/nested_ctr_wrong_arity.hvm +--- +In definition 'fst_fst': + Constructor arity mismatch in pattern matching. Constructor 'Pair' expects 2 fields, found 1.