diff --git a/src/ast.rs b/src/ast.rs index 0e9adb7f..17560ce7 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -146,10 +146,10 @@ pub const MAX_ADT_FIELDS: usize = MAX_ARITY - 1; impl Net { pub fn trees(&self) -> impl Iterator { - std::iter::once(&self.root).chain(self.redexes.iter().map(|(x, y)| [x, y]).flatten()) + std::iter::once(&self.root).chain(self.redexes.iter().flat_map(|(x, y)| [x, y])) } pub fn trees_mut(&mut self) -> impl Iterator { - std::iter::once(&mut self.root).chain(self.redexes.iter_mut().map(|(x, y)| [x, y]).flatten()) + std::iter::once(&mut self.root).chain(self.redexes.iter_mut().flat_map(|(x, y)| [x, y])) } } diff --git a/src/transform/eta_reduce.rs b/src/transform/eta_reduce.rs index 2a7f091d..67e2b4ad 100644 --- a/src/transform/eta_reduce.rs +++ b/src/transform/eta_reduce.rs @@ -1,19 +1,22 @@ -use std::collections::HashMap; +use std::{collections::HashMap, ops::RangeFrom}; use crate::ast::{Net, Tree}; -/// Converts (x y), (x y) into x, x +/// Converts `(x y), (x y)` into `x, x`. impl Net { pub fn eta_reduce(&mut self) { let mut phase1 = Phase1::default(); - phase1.walk_and_sort_net(self); - println!("{:?}", phase1.gaps); - let mut phase2 = Phase2 { gaps: phase1.gaps }; - phase2.reduce_net(self); + for tree in self.trees() { + phase1.walk_tree(tree); + } + let mut phase2 = Phase2 { nodes: phase1.nodes, index: 0 .. }; + for tree in self.trees_mut() { + phase2.reduce_tree(tree); + } } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] enum NodeType { Var(isize), Ctr(u16), @@ -24,43 +27,36 @@ enum NodeType { #[derive(Default)] struct Phase1<'a> { vars: HashMap<&'a str, usize>, - gaps: Vec, + nodes: Vec, } impl<'a> Phase1<'a> { - fn walk_and_sort_net(&mut self, net: &'a mut Net) { - self.walk_and_sort_tree(&mut net.root); - for (a, b) in net.redexes.iter_mut() { - self.walk_and_sort_tree(a); - self.walk_and_sort_tree(b); - } - } - fn walk_and_sort_tree(&mut self, tree: &'a mut Tree) { + fn walk_tree(&mut self, tree: &'a Tree) { match tree { Tree::Ctr { lab, ports } => { let last_port = ports.len() - 1; - for (idx, i) in ports.iter_mut().enumerate() { + for (idx, i) in ports.iter().enumerate() { if idx != last_port { - self.gaps.push(NodeType::Ctr(*lab)); + self.nodes.push(NodeType::Ctr(*lab)); } - self.walk_and_sort_tree(i); + self.walk_tree(i); } } Tree::Var { nam } => { let nam: &str = &*nam; if let Some(i) = self.vars.get(nam) { - let j = self.gaps.len() as isize; - self.gaps.push(NodeType::Var(*i as isize - j)); - self.gaps[*i] = NodeType::Var(j - *i as isize); + let j = self.nodes.len() as isize; + self.nodes.push(NodeType::Var(*i as isize - j)); + self.nodes[*i] = NodeType::Var(j - *i as isize); } else { - self.vars.insert(nam, self.gaps.len()); - self.gaps.push(NodeType::Hole); + self.vars.insert(nam, self.nodes.len()); + self.nodes.push(NodeType::Hole); } } _ => { - self.gaps.push(NodeType::Other); - for i in tree.children_mut() { - self.walk_and_sort_tree(i); + self.nodes.push(NodeType::Other); + for i in tree.children() { + self.walk_tree(i); } } } @@ -68,59 +64,44 @@ impl<'a> Phase1<'a> { } struct Phase2 { - gaps: Vec, + nodes: Vec, + index: RangeFrom, } impl Phase2 { - fn reduce_net(&mut self, net: &mut Net) { - let mut index = 0; - self.reduce_tree(&mut net.root, &mut index); - for (a, b) in net.redexes.iter_mut() { - self.reduce_tree(a, &mut index); - self.reduce_tree(b, &mut index); + fn reduce_ctr(&mut self, lab: u16, ports: &mut Vec, skip: usize) -> NodeType { + if skip == ports.len() { + return NodeType::Other; } - } - fn reduce_tree(&mut self, tree: &mut Tree, index: &mut usize) { - match tree { - ctr @ Tree::Ctr { .. } => { - let Tree::Ctr { lab, ports } = ctr else { unreachable!() }; - // reduce from the inside of the n-ary node to the outside - let mut indices = vec![]; - let last_port = ports.len() - 1; - for (idx, i) in ports.iter_mut().enumerate() { - indices.push(*index); - if idx != last_port { - *index += 1; - } - self.reduce_tree(i, index); - } - while indices.len() > 1 { - let tail_var = indices.pop().unwrap(); - let head_ctr = indices.pop().unwrap(); - let head_var = head_ctr + 1; - if let (NodeType::Var(off1), NodeType::Var(off2)) = (&self.gaps[head_var], &self.gaps[tail_var]) { - if off1 == off2 { - if let NodeType::Ctr(other_lab) = &self.gaps[head_ctr.strict_add_signed(*off1)] { - if other_lab == lab { - indices.push(head_var); - ports.pop(); - continue; - } - } - } - } - break; - } - if ports.len() == 1 { - *ctr = ports.pop().unwrap(); + if skip == ports.len() - 1 { + return self.reduce_tree(&mut ports[skip]); + } + let head_index = self.index.next().unwrap(); + let a = self.reduce_tree(&mut ports[skip]); + let b = self.reduce_ctr(lab, ports, skip + 1); + if a == b { + if let NodeType::Var(delta) = a { + if self.nodes[head_index.wrapping_add_signed(delta)] == NodeType::Ctr(lab) { + ports.pop(); + return NodeType::Var(delta); } } - tree => { - *index += 1; - for i in tree.children_mut() { - self.reduce_tree(i, index); - } + } + NodeType::Ctr(lab) + } + fn reduce_tree(&mut self, tree: &mut Tree) -> NodeType { + if let Tree::Ctr { lab, ports } = tree { + let ty = self.reduce_ctr(*lab, ports, 0); + if ports.len() == 1 { + *tree = ports.pop().unwrap(); + } + ty + } else { + let index = self.index.next().unwrap(); + for i in tree.children_mut() { + self.reduce_tree(i); } + self.nodes[index] } } } diff --git a/tests/transform.rs b/tests/transform.rs index b81aeeec..ef4b16eb 100644 --- a/tests/transform.rs +++ b/tests/transform.rs @@ -78,4 +78,7 @@ pub fn test_eta() { a & (a c) ~ c "###); + assert_display_snapshot!(parse_and_reduce("((a b) [a b])"), @"((a b) [a b])"); + assert_display_snapshot!(parse_and_reduce("((a b c) b c)"), @"((a b) b)"); + assert_display_snapshot!(parse_and_reduce("([(a b) (c d)] [(a b) (c d)])"), @"(a a)"); }