From de3d1fc9292e16defdb250a0ad7ac159281d8db6 Mon Sep 17 00:00:00 2001
From: FranchuFranchu
Date: Fri, 22 Mar 2024 15:48:17 -0300
Subject: [PATCH] Implement a separate `fast` redex stack for redexes that do
 not allocate extra memory, to prevent some OOMs.

---
 src/host/readback.rs        |  3 +-
 src/run/linker.rs           | 65 +++++++++++++++++++++++++++++++++----
 src/run/net.rs              |  3 +-
 src/run/parallel.rs         | 24 +++++++-------
 src/transform/pre_reduce.rs |  2 +-
 5 files changed, 75 insertions(+), 22 deletions(-)

diff --git a/src/host/readback.rs b/src/host/readback.rs
index 72126391..d3665330 100644
--- a/src/host/readback.rs
+++ b/src/host/readback.rs
@@ -18,8 +18,7 @@ impl Host {
     let mut net = Net::default();
     net.root = state.read_wire(rt_net.root.clone());
-
-    for (a, b) in &rt_net.redexes {
+    for (a, b) in rt_net.redexes.iter() {
       net.redexes.push((state.read_port(a.clone(), None), state.read_port(b.clone(), None)))
     }

diff --git a/src/run/linker.rs b/src/run/linker.rs
index b943129d..31d67b26 100644
--- a/src/run/linker.rs
+++ b/src/run/linker.rs
@@ -1,5 +1,3 @@
-use std::collections::VecDeque;
-
 use super::*;

 /// Stores extra data needed about the nodes when in lazy mode. (In strict mode,
@@ -20,8 +18,8 @@ pub(super) struct Header {
 /// non-atomically (because they must be locked).
 pub struct Linker<'h, M: Mode> {
   pub(super) allocator: Allocator<'h>,
-  pub redexes: VecDeque<(Port, Port)>,
   pub rwts: Rewrites,
+  pub redexes: RedexQueue,
   headers: IntMap<Addr, Header>,
   _mode: PhantomData<M>,
 }
@@ -32,7 +30,7 @@ impl<'h, M: Mode> Linker<'h, M> {
   pub fn new(heap: &'h Heap) -> Self {
     Linker {
       allocator: Allocator::new(heap),
-      redexes: VecDeque::new(),
+      redexes: RedexQueue::default(),
       rwts: Default::default(),
       headers: Default::default(),
       _mode: PhantomData,
@@ -86,12 +84,20 @@ impl<'h, M: Mode> Linker<'h, M> {
   /// Pushes an active pair to the redex queue; `a` and `b` must both be
   /// principal ports.
   #[inline(always)]
-  fn redux(&mut self, a: Port, b: Port) {
+  pub fn redux(&mut self, a: Port, b: Port) {
     trace!(self, a, b);
+    debug_assert!(!(a.is(Tag::Var) || a.is(Tag::Red) || b.is(Tag::Var) || b.is(Tag::Red)));
     if a.is_skippable() && b.is_skippable() {
       self.rwts.eras += 1;
     } else if !M::LAZY {
-      self.redexes.push_back((a, b));
+      // Prioritize redexes that do not allocate memory,
+      // to prevent OOM errors that can be avoided
+      // by reducing redexes in a different order (see #91)
+      if redex_would_shrink(&a, &b) {
+        self.redexes.fast.push((a, b));
+      } else {
+        self.redexes.slow.push((a, b));
+      }
     } else {
       self.set_header(a.clone(), b.clone());
       self.set_header(b.clone(), a.clone());
@@ -343,3 +349,50 @@ impl<'h, M: Mode> Linker<'h, M> {
     self.headers[&port.addr()].targ.clone()
   }
 }
+
+#[derive(Debug, Default)]
+pub struct RedexQueue {
+  pub(super) fast: Vec<(Port, Port)>,
+  pub(super) slow: Vec<(Port, Port)>,
+}
+
+impl RedexQueue {
+  /// Returns the highest-priority redex in the queue, if any.
+  #[inline(always)]
+  pub fn pop(&mut self) -> Option<(Port, Port)> {
+    self.fast.pop().or_else(|| self.slow.pop())
+  }
+  #[inline(always)]
+  pub fn len(&self) -> usize {
+    self.fast.len() + self.slow.len()
+  }
+  #[inline(always)]
+  pub fn is_empty(&self) -> bool {
+    self.fast.is_empty() && self.slow.is_empty()
+  }
+  #[inline(always)]
+  pub fn take(&mut self) -> impl Iterator<Item = (Port, Port)> {
+    std::mem::take(&mut self.fast).into_iter().chain(std::mem::take(&mut self.slow))
+  }
+  #[inline(always)]
+  pub fn iter(&self) -> impl Iterator<Item = &(Port, Port)> {
+    self.fast.iter().chain(self.slow.iter())
+  }
+  #[inline(always)]
+  pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut (Port, Port)> {
+    self.fast.iter_mut().chain(self.slow.iter_mut())
+  }
+  #[inline(always)]
+  pub fn clear(&mut self) {
+    self.fast.clear();
+    self.slow.clear();
+  }
+}
+
+/// Returns whether a redex can be reduced without allocating memory.
+fn redex_would_shrink(a: &Port, b: &Port) -> bool {
+  (*a == Port::ERA || *b == Port::ERA)
+    || (!(a.tag() == Tag::Ref || b.tag() == Tag::Ref)
+      && ((a.tag() == Tag::Ctr && b.tag() == Tag::Ctr && a.lab() == b.lab())
+        || (a.tag() == Tag::Num || b.tag() == Tag::Num)))
+}
diff --git a/src/run/net.rs b/src/run/net.rs
index 7422be98..edd73b35 100644
--- a/src/run/net.rs
+++ b/src/run/net.rs
@@ -44,7 +44,8 @@ impl<'a, M: Mode> Net<'a, M> {
   pub fn reduce(&mut self, limit: usize) -> usize {
     assert!(!M::LAZY);
     let mut count = 0;
-    while let Some((a, b)) = self.redexes.pop_front() {
+
+    while let Some((a, b)) = self.redexes.pop() {
       self.interact(a, b);
       count += 1;
       if count >= limit {
diff --git a/src/run/parallel.rs b/src/run/parallel.rs
index de74b682..aec41c3d 100644
--- a/src/run/parallel.rs
+++ b/src/run/parallel.rs
@@ -3,7 +3,8 @@ use super::*;
 impl<'h, M: Mode> Net<'h, M> {
   /// Forks the net into `tids` child nets, for parallel operation.
   pub fn fork(&mut self, tids: usize) -> impl Iterator<Item = Self> + '_ {
-    let mut redexes = std::mem::take(&mut self.redexes).into_iter();
+    let redexes_len = self.redexes.len();
+    let mut redexes = self.redexes.take();
     (0 .. tids).map(move |tid| {
       let heap_size = (self.heap.0.len() / tids) & !63; // round down to needed alignment
       let heap_start = heap_size * tid;
@@ -14,8 +15,8 @@ impl<'h, M: Mode> Net<'h, M> {
       net.tid = tid;
       net.tids = tids;
       net.tracer.set_tid(tid);
-      let count = redexes.len() / (tids - tid);
-      net.redexes.extend((&mut redexes).take(count));
+      let count = redexes_len / (tids - tid);
+      (&mut redexes).take(count).for_each(|i| net.redux(i.0, i.1));
       net
     })
   }
@@ -37,7 +38,7 @@ impl<'h, M: Mode> Net<'h, M> {
       net: Net<'a, M>,                        // thread's own net object
       delta: &'a AtomicRewrites,              // global delta rewrites
       share: &'a Vec<(AtomicU64, AtomicU64)>, // global share buffer
-      rlens: &'a Vec<AtomicUsize>,            // global redex lengths
+      rlens: &'a Vec<AtomicUsize>,            // global redex lengths (only counting shareable ones)
       total: &'a AtomicUsize,                 // total redex length
       barry: Arc<Barrier>,                    // synchronization barrier
     }
@@ -71,7 +72,6 @@ impl<'h, M: Mode> Net<'h, M> {
     });

     // Clear redexes and sum stats
-    self.redexes.clear();
     delta.add_to(&mut self.rwts);

     // Main reduction loop
@@ -106,7 +106,7 @@ impl<'h, M: Mode> Net<'h, M> {
       ctx.barry.wait();
       ctx.total.store(0, Relaxed);
       ctx.barry.wait();
-      ctx.rlens[ctx.tid].store(ctx.net.redexes.len(), Relaxed);
+      ctx.rlens[ctx.tid].store(ctx.net.redexes.slow.len(), Relaxed);
       ctx.total.fetch_add(ctx.net.redexes.len(), Relaxed);
       ctx.barry.wait();
       ctx.total.load(Relaxed)
@@ -120,7 +120,7 @@ impl<'h, M: Mode> Net<'h, M> {
       let shift = (1 << (plog2 - 1)) >> (ctx.tick % plog2);
       let a_tid = ctx.tid;
       let b_tid = if side == 1 { a_tid - shift } else { a_tid + shift };
-      let a_len = ctx.net.redexes.len();
+      let a_len = ctx.net.redexes.slow.len();
       let b_len = ctx.rlens[b_tid].load(Relaxed);
       let send = if a_len > b_len { (a_len - b_len) / 2 } else { 0 };
       let recv = if b_len > a_len { (b_len - a_len) / 2 } else { 0 };
       let send = std::cmp::min(send, SHARE_LIMIT);
       let recv = std::cmp::min(recv, SHARE_LIMIT);
       for i in 0 .. send {
         let init = a_len - send * 2;
-        let rdx0 = ctx.net.redexes[init + i * 2 + 0].clone();
-        let rdx1 = ctx.net.redexes[init + i * 2 + 1].clone();
+        let rdx0 = ctx.net.redexes.slow[init + i * 2 + 0].clone();
+        let rdx1 = ctx.net.redexes.slow[init + i * 2 + 1].clone();
         //let init = 0;
         //let ref0 = ctx.net.redexes.get_unchecked_mut(init + i * 2 + 0);
         //let rdx0 = *ref0;
         //*ref0 = (Ptr(0), Ptr(0));
         //let ref1 = ctx.net.redexes.get_unchecked_mut(init + i * 2 + 1);
         //let rdx1 = *ref1;
         //*ref1 = (Ptr(0), Ptr(0));
         let targ = ctx.share.get_unchecked(b_tid * SHARE_LIMIT + i);
-        ctx.net.redexes[init + i] = rdx0;
+        ctx.net.redexes.slow[init + i] = rdx0;
         targ.0.store(rdx1.0.0, Relaxed);
         targ.1.store(rdx1.1.0, Relaxed);
       }
-      ctx.net.redexes.truncate(a_len - send);
+      ctx.net.redexes.slow.truncate(a_len - send);
       ctx.barry.wait();
       for i in 0 .. recv {
         let got = ctx.share.get_unchecked(a_tid * SHARE_LIMIT + i);
-        ctx.net.redexes.push_back((Port(got.0.load(Relaxed)), Port(got.1.load(Relaxed))));
+        ctx.net.redexes.slow.push((Port(got.0.load(Relaxed)), Port(got.1.load(Relaxed))));
       }
     }
   }
diff --git a/src/transform/pre_reduce.rs b/src/transform/pre_reduce.rs
index cc968e34..b8cf6916 100644
--- a/src/transform/pre_reduce.rs
+++ b/src/transform/pre_reduce.rs
@@ -143,7 +143,7 @@ impl<'a> State<'a> {
     self.rewrites += rt.rwts;

     // Move interactions with inert defs back into the net redexes array
-    rt.redexes.extend(self.captured_redexes.lock().unwrap().drain(..));
+    self.captured_redexes.lock().unwrap().drain(..).for_each(|r| rt.redux(r.0, r.1));

     let net = self.host.readback(&mut rt);
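
Note: the scheduling idea in this patch can be illustrated with a small
standalone sketch. This is not hvm-core code; `Work`, `shrinks`, and `Queue`
are hypothetical stand-ins for the patch's `(Port, Port)` redexes,
`redex_would_shrink`, and `RedexQueue`, whose `pop` drains `fast` before
touching `slow`.

#[derive(Debug, Clone, Copy)]
enum Work {
  Annihilate, // same-label ctr/ctr pair: frees nodes, never allocates
  Erase,      // era against a node: frees nodes, never allocates
  Commute,    // different-label ctr/ctr pair: allocates four new nodes
}

// Stand-in for `redex_would_shrink`: true for work that cannot grow the net.
fn shrinks(w: Work) -> bool {
  matches!(w, Work::Annihilate | Work::Erase)
}

#[derive(Default)]
struct Queue {
  fast: Vec<Work>, // non-allocating redexes, reduced first
  slow: Vec<Work>, // allocating redexes, reduced only when `fast` is empty
}

impl Queue {
  fn push(&mut self, w: Work) {
    if shrinks(w) { self.fast.push(w) } else { self.slow.push(w) }
  }
  // Mirrors `RedexQueue::pop`: highest-priority (non-allocating) work first.
  fn pop(&mut self) -> Option<Work> {
    self.fast.pop().or_else(|| self.slow.pop())
  }
}

fn main() {
  let mut q = Queue::default();
  q.push(Work::Commute);
  q.push(Work::Annihilate);
  q.push(Work::Erase);
  // Prints Erase, Annihilate, Commute: memory is reclaimed before any
  // allocating interaction runs, lowering peak heap usage whenever the
  // reduction order is otherwise free (the situation described in #91).
  while let Some(w) = q.pop() {
    println!("{w:?}");
  }
}

This also suggests why the parallel runtime above only balances `slow`
redexes between threads: `fast` redexes free memory and carry no allocation
risk, so they are cheapest to burn down locally before any work sharing
happens.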