diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index de4c786..3dacb21 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -402,7 +402,16 @@ impl> StoreMaybeExpr for Var { } } +pub struct NormalDelayedElse(Option); +impl DelayedElse for NormalDelayedElse { + fn finish_else(self, else_expr: impl FnOnce() -> R) -> R { + self.0.unwrap_or_else(else_expr) + } +} + impl SelectMaybeExpr for bool { + type DelayedElse = NormalDelayedElse; + #[inline] fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R { if self { on() @@ -410,6 +419,7 @@ impl SelectMaybeExpr for bool { off() } } + #[inline] fn select(self, on: R, off: R) -> R { if self { on @@ -417,25 +427,105 @@ impl SelectMaybeExpr for bool { off } } + #[inline] + fn if_then_delayed(self, then: impl FnOnce() -> R) -> Self::DelayedElse { + NormalDelayedElse(if self { Some(then()) } else { None }) + } +} + +pub struct ExprDelayedElse { + cond: NodeRef, + then: R, +} +impl DelayedElse for ExprDelayedElse { + fn finish_else(self, else_: impl FnOnce() -> R) -> R { + let Self { cond, then } = self; + let then_nodes = then + .to_vec_nodes() + .into_iter() + .map(|x| x.get()) + .collect::>(); + let then_block = with_recorder(|r| { + let pools = r.pools.clone(); + let s = &mut r.scopes; + let then_block = s.pop().unwrap().finish(); + s.push(IrBuilder::new(pools)); + r.add_block_to_inaccessible(&then_block); + then_block + }); + let else_ = else_(); + let else_nodes = else_ + .to_vec_nodes() + .into_iter() + .map(|x| x.get()) + .collect::>(); + let else_block = with_recorder(|r| { + let s = &mut r.scopes; + let else_block = s.pop().unwrap().finish(); + r.add_block_to_inaccessible(&else_block); + else_block + }); + __current_scope(|b| { + b.if_(cond, then_block, else_block); + }); + assert_eq!(then_nodes.len(), else_nodes.len()); + let phis = then_nodes + .iter() + .zip(else_nodes.iter()) + .map(|(then, else_)| { + let incomings = vec![ + PhiIncoming { + value: *then, + block: then_block, + }, + PhiIncoming { + value: *else_, + block: else_block, + }, + ]; + assert_eq!(then.type_(), else_.type_()); + let phi = __current_scope(|b| b.phi(&incomings, then.type_().clone())); + phi.into() + }) + .collect::>(); + R::from_vec_nodes(phis) + } } + impl SelectMaybeExpr for Expr { + type DelayedElse = ExprDelayedElse; fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R { crate::lang::control_flow::if_then_else(self, on, off) } fn select(self, on: R, off: R) -> R { crate::lang::control_flow::select(self, on, off) } + fn if_then_delayed(self, then: impl FnOnce() -> R) -> Self::DelayedElse { + let cond = self.node().get(); + with_recorder(|r| { + let pools = r.pools.clone(); + let s = &mut r.scopes; + s.push(IrBuilder::new(pools)); + }); + let then = then(); + ExprDelayedElse { cond, then } + } } impl SelectMaybeExpr for Var { + type DelayedElse = ExprDelayedElse; fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R { crate::lang::control_flow::if_then_else(**self, on, off) } fn select(self, on: R, off: R) -> R { crate::lang::control_flow::select(**self, on, off) } + fn if_then_delayed(self, then: impl FnOnce() -> R) -> Self::DelayedElse { + (**self).if_then_delayed(then) + } } impl ActivateMaybeExpr for bool { + #[inline] fn activate(self, then: impl FnOnce()) { if self { then() @@ -455,6 +545,7 @@ impl ActivateMaybeExpr for Var { } impl LoopMaybeExpr for bool { + #[inline] fn while_loop(mut cond: impl FnMut() -> Self, mut body: impl FnMut()) { while cond() { body() @@ -470,15 +561,18 @@ impl LoopMaybeExpr for Expr { impl LazyBoolMaybeExpr for bool { type Bool = bool; + #[inline] fn and(self, other: impl FnOnce() -> bool) -> bool { self && other() } + #[inline] fn or(self, other: impl FnOnce() -> bool) -> bool { self || other() } } impl LazyBoolMaybeExpr, ExprType> for bool { type Bool = Expr; + #[inline] fn and(self, other: impl FnOnce() -> Expr) -> Self::Bool { if self { other() @@ -486,6 +580,7 @@ impl LazyBoolMaybeExpr, ExprType> for bool { false.expr() } } + #[inline] fn or(self, other: impl FnOnce() -> Expr) -> Self::Bool { if self { true.expr() diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index 216ec6c..a62e39f 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -352,9 +352,16 @@ pub trait StoreMaybeExpr { fn __store(self, value: V); } +pub trait DelayedElse { + fn finish_else(self, else_expr: impl FnOnce() -> R) -> R; +} + pub trait SelectMaybeExpr { + type DelayedElse: DelayedElse; + fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R; fn select(self, on: R, off: R) -> R; + fn if_then_delayed(self, then: impl FnOnce() -> R) -> Self::DelayedElse; } pub trait ActivateMaybeExpr { diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs index 271664a..5540096 100644 --- a/luisa_compute_track/src/lib.rs +++ b/luisa_compute_track/src/lib.rs @@ -102,7 +102,11 @@ impl VisitMut for TraceVisitor { file!(), line!(), column!(), - || <_ as #trait_path::SelectMaybeExpr<_>>::if_then_else(#cond, || #then_branch, || #else_branch)) + || { + let delayed_else = <_ as #trait_path::SelectMaybeExpr<_>>::if_then_delayed(#cond, || #then_branch); + <_ as #trait_path::DelayedElse<_>>::finish_else(delayed_else, || #else_branch) + } + ) } } else { *node = parse_quote_spanned! {span=>