diff --git a/src/run/net.rs b/src/run/net.rs index 36a2e690..c4bb3a9b 100644 --- a/src/run/net.rs +++ b/src/run/net.rs @@ -23,11 +23,9 @@ impl<'h, M: Mode> Net<'h, M> { Net { linker: Linker::new(heap), tid: 0, tids: 1, trgs: vec![Trg::port(Port(0)); 1 << 16], root } } - /// Boots a net from a Ref. + /// Boots a net from a Def. pub fn boot(&mut self, def: &Def) { - let def = Port::new_ref(def); - trace!(self, def); - self.root.set_target(def); + self.call(Port::new_ref(def), self.root.as_var()); } } @@ -47,36 +45,6 @@ impl<'a, M: Mode> Net<'a, M> { count } - /// Expands [`Ref`] nodes in the tree connected to `root`. - #[inline(always)] - pub fn expand(&mut self) { - assert!(!M::LAZY); - fn go(net: &mut Net, wire: Wire, len: usize, key: usize) { - trace!(net.tracer, wire); - let port = wire.load_target(); - trace!(net.tracer, port); - if port == Port::LOCK { - return; - } - if port.tag() == Ctr { - let node = port.traverse_node(); - if len >= net.tids || key % 2 == 0 { - go(net, node.p1, len.saturating_mul(2), key / 2); - } - if len >= net.tids || key % 2 == 1 { - go(net, node.p2, len.saturating_mul(2), key / 2); - } - } else if port.tag() == Ref && port != Port::ERA { - let got = wire.swap_target(Port::LOCK); - if got != Port::LOCK { - trace!(net.tracer, port, wire); - net.call(port, Port::new_var(wire.addr())); - } - } - } - go(self, self.root.clone(), 1, self.tid); - } - // Lazy mode weak head normalizer #[inline(always)] fn weak_normal(&mut self, mut prev: Port, root: Wire) -> Port { @@ -146,7 +114,48 @@ impl<'a, M: Mode> Net<'a, M> { self.expand(); while !self.redexes.is_empty() { self.reduce(usize::MAX); - self.expand(); + } + } + } +} + +impl<'h, M: Mode> Net<'h, M> { + /// Expands [`Tag::Ref`] nodes in the tree connected to `root`. + pub fn expand(&mut self) { + assert!(!M::LAZY); + let (new_root, out_port) = self.create_wire(); + let old_root = std::mem::replace(&mut self.root, new_root); + self.link_wire_port(old_root, ExpandDef::new(out_port)); + } +} + +struct ExpandDef { + out: Port, +} + +impl ExpandDef { + fn new(out: Port) -> Port { + Port::new_ref(Box::leak(Box::new(Def::new(LabSet::ALL, ExpandDef { out })))) + } +} + +impl AsDef for ExpandDef { + unsafe fn call(def: *const Def, net: &mut Net, port: Port) { + if port.tag() == Tag::Ref && port != Port::ERA { + return net.call(port, Port::new_ref(Def::upcast(unsafe { &*def }))); + } + let def = *Box::from_raw(def as *mut Def); + match port.tag() { + Tag::Red => { + unreachable!() + } + Tag::Ref | Tag::Num | Tag::Var => net.link_port_port(def.data.out, port), + tag @ (Tag::Op2 | Tag::Op1 | Tag::Mat | Tag::Ctr) => { + let old = port.consume_node(); + let new = net.create_node(tag, old.lab); + net.link_port_port(def.data.out, new.p0); + net.link_wire_port(old.p1, ExpandDef::new(new.p1)); + net.link_wire_port(old.p2, ExpandDef::new(new.p2)); } } } diff --git a/src/run/parallel.rs b/src/run/parallel.rs index 9207a0eb..1fc97fde 100644 --- a/src/run/parallel.rs +++ b/src/run/parallel.rs @@ -24,6 +24,8 @@ impl<'h, M: Mode> Net<'h, M> { pub fn parallel_normal(&mut self) { assert!(!M::LAZY); + self.expand(); + const SHARE_LIMIT: usize = 1 << 12; // max share redexes per split const LOCAL_LIMIT: usize = 1 << 18; // max local rewrites per epoch @@ -77,7 +79,6 @@ impl<'h, M: Mode> Net<'h, M> { fn main(ctx: &mut ThreadContext) { loop { reduce(ctx); - ctx.net.expand(); if count(ctx) == 0 { break; }