Skip to content

Commit

Permalink
fix multiple nested ctrs
Browse files Browse the repository at this point in the history
  • Loading branch information
tjjfvi committed Mar 14, 2024
1 parent c157e49 commit 2be113d
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 75 deletions.
4 changes: 2 additions & 2 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,10 @@ pub const MAX_ADT_FIELDS: usize = MAX_ARITY - 1;

impl Net {
pub fn trees(&self) -> impl Iterator<Item = &Tree> {
std::iter::once(&self.root).chain(self.redexes.iter().map(|(x, y)| [x, y]).flatten())
std::iter::once(&self.root).chain(self.redexes.iter().flat_map(|(x, y)| [x, y]))
}
pub fn trees_mut(&mut self) -> impl Iterator<Item = &mut Tree> {
std::iter::once(&mut self.root).chain(self.redexes.iter_mut().map(|(x, y)| [x, y]).flatten())
std::iter::once(&mut self.root).chain(self.redexes.iter_mut().flat_map(|(x, y)| [x, y]))
}
}

Expand Down
127 changes: 54 additions & 73 deletions src/transform/eta_reduce.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
use std::collections::HashMap;
use std::{collections::HashMap, ops::RangeFrom};

use crate::ast::{Net, Tree};

/// Converts (x y), (x y) into x, x
/// Converts `(x y), (x y)` into `x, x`.
impl Net {
pub fn eta_reduce(&mut self) {
let mut phase1 = Phase1::default();
phase1.walk_and_sort_net(self);
println!("{:?}", phase1.gaps);
let mut phase2 = Phase2 { gaps: phase1.gaps };
phase2.reduce_net(self);
for tree in self.trees() {
phase1.walk_tree(tree);
}
let mut phase2 = Phase2 { nodes: phase1.nodes, index: 0 .. };
for tree in self.trees_mut() {
phase2.reduce_tree(tree);
}
}
}

#[derive(Debug)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum NodeType {
Var(isize),
Ctr(u16),
Expand All @@ -24,103 +27,81 @@ enum NodeType {
#[derive(Default)]
struct Phase1<'a> {
vars: HashMap<&'a str, usize>,
gaps: Vec<NodeType>,
nodes: Vec<NodeType>,
}

impl<'a> Phase1<'a> {
fn walk_and_sort_net(&mut self, net: &'a mut Net) {
self.walk_and_sort_tree(&mut net.root);
for (a, b) in net.redexes.iter_mut() {
self.walk_and_sort_tree(a);
self.walk_and_sort_tree(b);
}
}
fn walk_and_sort_tree(&mut self, tree: &'a mut Tree) {
fn walk_tree(&mut self, tree: &'a Tree) {
match tree {
Tree::Ctr { lab, ports } => {
let last_port = ports.len() - 1;
for (idx, i) in ports.iter_mut().enumerate() {
for (idx, i) in ports.iter().enumerate() {
if idx != last_port {
self.gaps.push(NodeType::Ctr(*lab));
self.nodes.push(NodeType::Ctr(*lab));
}
self.walk_and_sort_tree(i);
self.walk_tree(i);
}
}
Tree::Var { nam } => {
let nam: &str = &*nam;
if let Some(i) = self.vars.get(nam) {
let j = self.gaps.len() as isize;
self.gaps.push(NodeType::Var(*i as isize - j));
self.gaps[*i] = NodeType::Var(j - *i as isize);
let j = self.nodes.len() as isize;
self.nodes.push(NodeType::Var(*i as isize - j));
self.nodes[*i] = NodeType::Var(j - *i as isize);
} else {
self.vars.insert(nam, self.gaps.len());
self.gaps.push(NodeType::Hole);
self.vars.insert(nam, self.nodes.len());
self.nodes.push(NodeType::Hole);
}
}
_ => {
self.gaps.push(NodeType::Other);
for i in tree.children_mut() {
self.walk_and_sort_tree(i);
self.nodes.push(NodeType::Other);
for i in tree.children() {
self.walk_tree(i);
}
}
}
}
}

struct Phase2 {
gaps: Vec<NodeType>,
nodes: Vec<NodeType>,
index: RangeFrom<usize>,
}

impl Phase2 {
fn reduce_net(&mut self, net: &mut Net) {
let mut index = 0;
self.reduce_tree(&mut net.root, &mut index);
for (a, b) in net.redexes.iter_mut() {
self.reduce_tree(a, &mut index);
self.reduce_tree(b, &mut index);
fn reduce_ctr(&mut self, lab: u16, ports: &mut Vec<Tree>, skip: usize) -> NodeType {
if skip == ports.len() {
return NodeType::Other;
}
}
fn reduce_tree(&mut self, tree: &mut Tree, index: &mut usize) {
match tree {
ctr @ Tree::Ctr { .. } => {
let Tree::Ctr { lab, ports } = ctr else { unreachable!() };
// reduce from the inside of the n-ary node to the outside
let mut indices = vec![];
let last_port = ports.len() - 1;
for (idx, i) in ports.iter_mut().enumerate() {
indices.push(*index);
if idx != last_port {
*index += 1;
}
self.reduce_tree(i, index);
}
while indices.len() > 1 {
let tail_var = indices.pop().unwrap();
let head_ctr = indices.pop().unwrap();
let head_var = head_ctr + 1;
if let (NodeType::Var(off1), NodeType::Var(off2)) = (&self.gaps[head_var], &self.gaps[tail_var]) {
if off1 == off2 {
if let NodeType::Ctr(other_lab) = &self.gaps[head_ctr.strict_add_signed(*off1)] {
if other_lab == lab {
indices.push(head_var);
ports.pop();
continue;
}
}
}
}
break;
}
if ports.len() == 1 {
*ctr = ports.pop().unwrap();
if skip == ports.len() - 1 {
return self.reduce_tree(&mut ports[skip]);
}
let head_index = self.index.next().unwrap();
let a = self.reduce_tree(&mut ports[skip]);
let b = self.reduce_ctr(lab, ports, skip + 1);
if a == b {
if let NodeType::Var(delta) = a {
if self.nodes[head_index.wrapping_add_signed(delta)] == NodeType::Ctr(lab) {
ports.pop();
return NodeType::Var(delta);
}
}
tree => {
*index += 1;
for i in tree.children_mut() {
self.reduce_tree(i, index);
}
}
NodeType::Ctr(lab)
}
fn reduce_tree(&mut self, tree: &mut Tree) -> NodeType {
if let Tree::Ctr { lab, ports } = tree {
let ty = self.reduce_ctr(*lab, ports, 0);
if ports.len() == 1 {
*tree = ports.pop().unwrap();
}
ty
} else {
let index = self.index.next().unwrap();
for i in tree.children_mut() {
self.reduce_tree(i);
}
self.nodes[index]
}
}
}
3 changes: 3 additions & 0 deletions tests/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,7 @@ pub fn test_eta() {
a
& (a c) ~ c
"###);
assert_display_snapshot!(parse_and_reduce("((a b) [a b])"), @"((a b) [a b])");
assert_display_snapshot!(parse_and_reduce("((a b c) b c)"), @"((a b) b)");
assert_display_snapshot!(parse_and_reduce("([(a b) (c d)] [(a b) (c d)])"), @"(a a)");
}

0 comments on commit 2be113d

Please sign in to comment.