Skip to content

Commit

Permalink
misc tweaks to adt encoding
Browse files Browse the repository at this point in the history
  • Loading branch information
tjjfvi committed Mar 14, 2024
1 parent 0c54a08 commit d8eba0c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 38 deletions.
61 changes: 23 additions & 38 deletions src/transform/encode_adts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,61 +4,46 @@ impl Tree {
/// Encode scott-encoded ADTs into optimized compact ADT nodes
pub fn encode_scott_adts(&mut self) {
maybe_grow(|| match self {
ctr @ Tree::Ctr { .. } => {
let Tree::Ctr { lab, mut ports } = std::mem::take(ctr) else { unreachable!() };
fn get_field_and_ret(lab: &u16, ports: &[Tree]) -> Option<(usize, String)> {
let mut fields_idx = None;
let ret_var = match ports.last() {
Some(Tree::Var { nam: ret_var }) => ret_var.clone(),
_ => {
return None;
}
};
&mut Tree::Ctr { lab, ref mut ports } => {
fn get_variant_index(lab: u16, ports: &[Tree]) -> Option<usize> {
let Some(Tree::Var { nam: ret_var }) = ports.last() else { return None };

for (idx, i) in ports.iter().take(ports.len() - 1).enumerate() {
let mut variant_index = None;
for (idx, i) in ports[0 .. ports.len() - 1].iter().enumerate() {
match i {
Tree::Era => (),
Tree::Ctr { lab: inner_lab, ports } if lab == inner_lab && fields_idx.is_none() => {
// Ensure that the last port
// is the return variable
if match ports.last() {
Some(Tree::Var { nam }) => *nam != ret_var,
_ => true,
} {
Tree::Era => {}
Tree::Ctr { lab: inner_lab, ports } if *inner_lab == lab && variant_index.is_none() => {
// Ensure that the last port is the return variable
let Some(Tree::Var { nam }) = ports.last() else { return None };
if nam != ret_var {
return None;
}
fields_idx = Some(idx);
}
// Nilary field.
Tree::Var { nam } if ret_var == *nam && fields_idx.is_none() => {
fields_idx = Some(idx);
variant_index = Some(idx);
}
_ => {
// Does not encode an ADT.
return None;
// Nilary variant
Tree::Var { nam } if nam == ret_var && variant_index.is_none() => {
variant_index = Some(idx);
}
// Does not encode an ADT.
_ => return None,
}
}
fields_idx.map(move |x| (x, ret_var))
variant_index
}

if let Some((fields_idx, ret_var)) = get_field_and_ret(&lab, &ports) {
let fields = match ports.swap_remove(fields_idx) {
if let Some(variant_index) = get_variant_index(lab, ports) {
let fields = match ports.swap_remove(variant_index) {
Tree::Ctr { ports: mut fields, .. } => {
fields.pop();
fields
}
Tree::Var { nam } if nam == ret_var => {
vec![]
}
Tree::Var { .. } => vec![],
_ => unreachable!(),
};
let new = Tree::Adt { lab, variant_index: fields_idx, variant_count: ports.len(), fields };
*ctr = new;
} else {
*ctr = Tree::Ctr { lab, ports };
ctr.children_mut().for_each(Tree::encode_scott_adts);
*self = Tree::Adt { lab, variant_index, variant_count: ports.len(), fields };
}

self.children_mut().for_each(Tree::encode_scott_adts);
}
other => other.children_mut().for_each(Tree::encode_scott_adts),
})
Expand Down
1 change: 1 addition & 0 deletions tests/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,5 @@ pub fn test_adt_encoding() {
assert_display_snapshot!(parse_and_encode("(* ((a (b (c R))) R))"), @"(:1:2 a b c)");
assert_display_snapshot!(parse_and_encode("{4 * {4 {4 a {4 b {4 c R}}} R}}"), @"{4:1:2 a b c}");
assert_display_snapshot!(parse_and_encode("(* x x)"), @"(:1:2)");
assert_display_snapshot!(parse_and_encode("(((((* x x) x) * x) x) * x)"), @"(:0:2 (:0:2 (:1:2)))");
}

0 comments on commit d8eba0c

Please sign in to comment.