Skip to content

Commit

Permalink
Implement a separate fast redex stack for redexes that do not alloc…
Browse files Browse the repository at this point in the history
…ate extra memory to prevent some OOMs.
  • Loading branch information
FranchuFranchu committed Mar 22, 2024
1 parent 38b4ac0 commit de3d1fc
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 22 deletions.
3 changes: 1 addition & 2 deletions src/host/readback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ impl Host {
let mut net = Net::default();

net.root = state.read_wire(rt_net.root.clone());

for (a, b) in &rt_net.redexes {
for (a, b) in rt_net.redexes.iter() {
net.redexes.push((state.read_port(a.clone(), None), state.read_port(b.clone(), None)))
}

Expand Down
65 changes: 59 additions & 6 deletions src/run/linker.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::VecDeque;

use super::*;

/// Stores extra data needed about the nodes when in lazy mode. (In strict mode,
Expand All @@ -20,8 +18,8 @@ pub(super) struct Header {
/// non-atomically (because they must be locked).
pub struct Linker<'h, M: Mode> {
pub(super) allocator: Allocator<'h>,
pub redexes: VecDeque<(Port, Port)>,
pub rwts: Rewrites,
pub redexes: RedexQueue,
headers: IntMap<Addr, Header>,
_mode: PhantomData<M>,
}
Expand All @@ -32,7 +30,7 @@ impl<'h, M: Mode> Linker<'h, M> {
pub fn new(heap: &'h Heap) -> Self {
Linker {
allocator: Allocator::new(heap),
redexes: VecDeque::new(),
redexes: RedexQueue::default(),
rwts: Default::default(),
headers: Default::default(),
_mode: PhantomData,
Expand Down Expand Up @@ -86,12 +84,20 @@ impl<'h, M: Mode> Linker<'h, M> {
/// Pushes an active pair to the redex queue; `a` and `b` must both be
/// principal ports.
#[inline(always)]
fn redux(&mut self, a: Port, b: Port) {
pub fn redux(&mut self, a: Port, b: Port) {
trace!(self, a, b);
debug_assert!(!(a.is(Tag::Var) || a.is(Tag::Red) || b.is(Tag::Var) || b.is(Tag::Red)));
if a.is_skippable() && b.is_skippable() {
self.rwts.eras += 1;
} else if !M::LAZY {
self.redexes.push_back((a, b));
// Prioritize redexes that do not allocate memory,
// to prevent OOM errors that can be avoided
// by reducing redexes in a different order (see #91)
if redex_would_shrink(&a, &b) {
self.redexes.fast.push((a, b));
} else {
self.redexes.slow.push((a, b));
}
} else {
self.set_header(a.clone(), b.clone());
self.set_header(b.clone(), a.clone());
Expand Down Expand Up @@ -343,3 +349,50 @@ impl<'h, M: Mode> Linker<'h, M> {
self.headers[&port.addr()].targ.clone()
}
}

#[derive(Debug, Default)]
pub struct RedexQueue {
pub(super) fast: Vec<(Port, Port)>,
pub(super) slow: Vec<(Port, Port)>,
}

impl RedexQueue {
/// Returns the highest-priority redex in the queue, if any
#[inline(always)]
pub fn pop(&mut self) -> Option<(Port, Port)> {
self.fast.pop().or_else(|| self.slow.pop())
}
#[inline(always)]
pub fn len(&self) -> usize {
self.fast.len() + self.slow.len()
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.fast.is_empty() && self.slow.is_empty()
}
#[inline(always)]
pub fn take(&mut self) -> impl Iterator<Item = (Port, Port)> {
std::mem::take(&mut self.fast).into_iter().chain(std::mem::take(&mut self.slow))
}
#[inline(always)]
pub fn iter(&self) -> impl Iterator<Item = &(Port, Port)> {
self.fast.iter().chain(self.slow.iter())
}
#[inline(always)]
pub fn iter_mut(&mut self) -> impl Iterator<Item = &mut (Port, Port)> {
self.fast.iter_mut().chain(self.slow.iter_mut())
}
#[inline(always)]
pub fn clear(&mut self) {
self.fast.clear();
self.slow.clear();
}
}

// Returns whether a redex does not allocate memory
fn redex_would_shrink(a: &Port, b: &Port) -> bool {
(*a == Port::ERA || *b == Port::ERA)
|| (!(a.tag() == Tag::Ref || b.tag() == Tag::Ref)
&& (((a.tag() == Tag::Ctr && b.tag() == Tag::Ctr) || a.lab() == b.lab())
|| (a.tag() == Tag::Num || b.tag() == Tag::Num)))
}
3 changes: 2 additions & 1 deletion src/run/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ impl<'a, M: Mode> Net<'a, M> {
pub fn reduce(&mut self, limit: usize) -> usize {
assert!(!M::LAZY);
let mut count = 0;
while let Some((a, b)) = self.redexes.pop_front() {

while let Some((a, b)) = self.redexes.pop() {
self.interact(a, b);
count += 1;
if count >= limit {
Expand Down
24 changes: 12 additions & 12 deletions src/run/parallel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ use super::*;
impl<'h, M: Mode> Net<'h, M> {
/// Forks the net into `tids` child nets, for parallel operation.
pub fn fork(&mut self, tids: usize) -> impl Iterator<Item = Self> + '_ {
let mut redexes = std::mem::take(&mut self.redexes).into_iter();
let redexes_len = self.redexes.len();
let mut redexes = self.redexes.take();
(0 .. tids).map(move |tid| {
let heap_size = (self.heap.0.len() / tids) & !63; // round down to needed alignment
let heap_start = heap_size * tid;
Expand All @@ -14,8 +15,8 @@ impl<'h, M: Mode> Net<'h, M> {
net.tid = tid;
net.tids = tids;
net.tracer.set_tid(tid);
let count = redexes.len() / (tids - tid);
net.redexes.extend((&mut redexes).take(count));
let count = redexes_len / (tids - tid);
(&mut redexes).take(count).for_each(|i| net.redux(i.0, i.1));
net
})
}
Expand All @@ -37,7 +38,7 @@ impl<'h, M: Mode> Net<'h, M> {
net: Net<'a, M>, // thread's own net object
delta: &'a AtomicRewrites, // global delta rewrites
share: &'a Vec<(AtomicU64, AtomicU64)>, // global share buffer
rlens: &'a Vec<AtomicUsize>, // global redex lengths
rlens: &'a Vec<AtomicUsize>, // global redex lengths (only counting shareable ones)
total: &'a AtomicUsize, // total redex length
barry: Arc<Barrier>, // synchronization barrier
}
Expand Down Expand Up @@ -71,7 +72,6 @@ impl<'h, M: Mode> Net<'h, M> {
});

// Clear redexes and sum stats
self.redexes.clear();
delta.add_to(&mut self.rwts);

// Main reduction loop
Expand Down Expand Up @@ -106,7 +106,7 @@ impl<'h, M: Mode> Net<'h, M> {
ctx.barry.wait();
ctx.total.store(0, Relaxed);
ctx.barry.wait();
ctx.rlens[ctx.tid].store(ctx.net.redexes.len(), Relaxed);
ctx.rlens[ctx.tid].store(ctx.net.redexes.slow.len(), Relaxed);
ctx.total.fetch_add(ctx.net.redexes.len(), Relaxed);
ctx.barry.wait();
ctx.total.load(Relaxed)
Expand All @@ -120,16 +120,16 @@ impl<'h, M: Mode> Net<'h, M> {
let shift = (1 << (plog2 - 1)) >> (ctx.tick % plog2);
let a_tid = ctx.tid;
let b_tid = if side == 1 { a_tid - shift } else { a_tid + shift };
let a_len = ctx.net.redexes.len();
let a_len = ctx.net.redexes.slow.len();
let b_len = ctx.rlens[b_tid].load(Relaxed);
let send = if a_len > b_len { (a_len - b_len) / 2 } else { 0 };
let recv = if b_len > a_len { (b_len - a_len) / 2 } else { 0 };
let send = std::cmp::min(send, SHARE_LIMIT);
let recv = std::cmp::min(recv, SHARE_LIMIT);
for i in 0 .. send {
let init = a_len - send * 2;
let rdx0 = ctx.net.redexes[init + i * 2 + 0].clone();
let rdx1 = ctx.net.redexes[init + i * 2 + 1].clone();
let rdx0 = ctx.net.redexes.slow[init + i * 2 + 0].clone();
let rdx1 = ctx.net.redexes.slow[init + i * 2 + 1].clone();
//let init = 0;
//let ref0 = ctx.net.redexes.get_unchecked_mut(init + i * 2 + 0);
//let rdx0 = *ref0;
Expand All @@ -138,15 +138,15 @@ impl<'h, M: Mode> Net<'h, M> {
//let rdx1 = *ref1;
//*ref1 = (Ptr(0), Ptr(0));
let targ = ctx.share.get_unchecked(b_tid * SHARE_LIMIT + i);
ctx.net.redexes[init + i] = rdx0;
ctx.net.redexes.slow[init + i] = rdx0;
targ.0.store(rdx1.0.0, Relaxed);
targ.1.store(rdx1.1.0, Relaxed);
}
ctx.net.redexes.truncate(a_len - send);
ctx.net.redexes.slow.truncate(a_len - send);
ctx.barry.wait();
for i in 0 .. recv {
let got = ctx.share.get_unchecked(a_tid * SHARE_LIMIT + i);
ctx.net.redexes.push_back((Port(got.0.load(Relaxed)), Port(got.1.load(Relaxed))));
ctx.net.redexes.slow.push((Port(got.0.load(Relaxed)), Port(got.1.load(Relaxed))));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/transform/pre_reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ impl<'a> State<'a> {
self.rewrites += rt.rwts;

// Move interactions with inert defs back into the net redexes array
rt.redexes.extend(self.captured_redexes.lock().unwrap().drain(..));
self.captured_redexes.lock().unwrap().drain(..).for_each(|r| rt.redux(r.0, r.1));

let net = self.host.readback(&mut rt);

Expand Down

0 comments on commit de3d1fc

Please sign in to comment.