Skip to content

Commit

Permalink
Implement else delaying for allowing better borrow checker.
Browse files Browse the repository at this point in the history
  • Loading branch information
entropylost committed Mar 16, 2024
1 parent dadc37f commit 8b4807d
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 1 deletion.
95 changes: 95 additions & 0 deletions luisa_compute/src/lang/ops/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -402,40 +402,130 @@ impl<V: Value, E: AsExpr<Value = V>> StoreMaybeExpr<E> for Var<V> {
}
}

pub struct NormalDelayedElse<R>(Option<R>);
impl<R> DelayedElse<R> for NormalDelayedElse<R> {
fn finish_else(self, else_expr: impl FnOnce() -> R) -> R {
self.0.unwrap_or_else(else_expr)
}
}

impl<R> SelectMaybeExpr<R> for bool {
type DelayedElse = NormalDelayedElse<R>;
#[inline]
fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R {
if self {
on()
} else {
off()
}
}
#[inline]
fn select(self, on: R, off: R) -> R {
if self {
on
} else {
off
}
}
#[inline]
fn if_then_delayed(self, then: impl FnOnce() -> R) -> Self::DelayedElse {
NormalDelayedElse(if self { Some(then()) } else { None })
}
}

pub struct ExprDelayedElse<R: Aggregate> {
cond: NodeRef,
then: R,
}
impl<R: Aggregate> DelayedElse<R> for ExprDelayedElse<R> {
fn finish_else(self, else_: impl FnOnce() -> R) -> R {
let Self { cond, then } = self;
let then_nodes = then
.to_vec_nodes()
.into_iter()
.map(|x| x.get())
.collect::<Vec<_>>();
let then_block = with_recorder(|r| {
let pools = r.pools.clone();
let s = &mut r.scopes;
let then_block = s.pop().unwrap().finish();
s.push(IrBuilder::new(pools));
r.add_block_to_inaccessible(&then_block);
then_block
});
let else_ = else_();
let else_nodes = else_
.to_vec_nodes()
.into_iter()
.map(|x| x.get())
.collect::<Vec<_>>();
let else_block = with_recorder(|r| {
let s = &mut r.scopes;
let else_block = s.pop().unwrap().finish();
r.add_block_to_inaccessible(&else_block);
else_block
});
__current_scope(|b| {
b.if_(cond, then_block, else_block);
});
assert_eq!(then_nodes.len(), else_nodes.len());
let phis = then_nodes
.iter()
.zip(else_nodes.iter())
.map(|(then, else_)| {
let incomings = vec![
PhiIncoming {
value: *then,
block: then_block,
},
PhiIncoming {
value: *else_,
block: else_block,
},
];
assert_eq!(then.type_(), else_.type_());
let phi = __current_scope(|b| b.phi(&incomings, then.type_().clone()));
phi.into()
})
.collect::<Vec<_>>();
R::from_vec_nodes(phis)
}
}

impl<R: Aggregate> SelectMaybeExpr<R> for Expr<bool> {
type DelayedElse = ExprDelayedElse<R>;
fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R {
crate::lang::control_flow::if_then_else(self, on, off)
}
fn select(self, on: R, off: R) -> R {
crate::lang::control_flow::select(self, on, off)
}
fn if_then_delayed(self, then: impl FnOnce() -> R) -> Self::DelayedElse {
let cond = self.node().get();
with_recorder(|r| {
let pools = r.pools.clone();
let s = &mut r.scopes;
s.push(IrBuilder::new(pools));
});
let then = then();
ExprDelayedElse { cond, then }
}
}

impl<R: Aggregate> SelectMaybeExpr<R> for Var<bool> {
type DelayedElse = ExprDelayedElse<R>;
fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R {
crate::lang::control_flow::if_then_else(**self, on, off)
}
fn select(self, on: R, off: R) -> R {
crate::lang::control_flow::select(**self, on, off)
}
fn if_then_delayed(self, then: impl FnOnce() -> R) -> Self::DelayedElse {
(**self).if_then_delayed(then)
}
}
impl ActivateMaybeExpr for bool {
#[inline]
fn activate(self, then: impl FnOnce()) {
if self {
then()
Expand All @@ -455,6 +545,7 @@ impl ActivateMaybeExpr for Var<bool> {
}

impl LoopMaybeExpr for bool {
#[inline]
fn while_loop(mut cond: impl FnMut() -> Self, mut body: impl FnMut()) {
while cond() {
body()
Expand All @@ -470,22 +561,26 @@ impl LoopMaybeExpr for Expr<bool> {

impl LazyBoolMaybeExpr<bool, ValueType> for bool {
type Bool = bool;
#[inline]
fn and(self, other: impl FnOnce() -> bool) -> bool {
self && other()
}
#[inline]
fn or(self, other: impl FnOnce() -> bool) -> bool {
self || other()
}
}
impl LazyBoolMaybeExpr<Expr<bool>, ExprType> for bool {
type Bool = Expr<bool>;
#[inline]
fn and(self, other: impl FnOnce() -> Expr<bool>) -> Self::Bool {
if self {
other()
} else {
false.expr()
}
}
#[inline]
fn or(self, other: impl FnOnce() -> Expr<bool>) -> Self::Bool {
if self {
true.expr()
Expand Down
7 changes: 7 additions & 0 deletions luisa_compute/src/lang/ops/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,16 @@ pub trait StoreMaybeExpr<V> {
fn __store(self, value: V);
}

pub trait DelayedElse<R> {
fn finish_else(self, else_expr: impl FnOnce() -> R) -> R;
}

pub trait SelectMaybeExpr<R> {
type DelayedElse: DelayedElse<R>;

fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R;
fn select(self, on: R, off: R) -> R;
fn if_then_delayed(self, then: impl FnOnce() -> R) -> Self::DelayedElse;
}

pub trait ActivateMaybeExpr {
Expand Down
6 changes: 5 additions & 1 deletion luisa_compute_track/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,11 @@ impl VisitMut for TraceVisitor {
file!(),
line!(),
column!(),
|| <_ as #trait_path::SelectMaybeExpr<_>>::if_then_else(#cond, || #then_branch, || #else_branch))
|| {
let delayed_else = <_ as #trait_path::SelectMaybeExpr<_>>::if_then_delayed(#cond, || #then_branch);
<_ as #trait_path::DelayedElse<_>>::finish_else(delayed_else, || #else_branch)
}
)
}
} else {
*node = parse_quote_spanned! {span=>
Expand Down

0 comments on commit 8b4807d

Please sign in to comment.