Skip to content

Commit

Permalink
Relaxed Fn restrictions on control flow.
Browse files Browse the repository at this point in the history
  • Loading branch information
entropylost committed Mar 16, 2024
1 parent 92cfa27 commit d36d6f2
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 25 deletions.
14 changes: 5 additions & 9 deletions luisa_compute/src/lang/control_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ pub fn return_() {

pub fn if_then_else<R: Aggregate>(
cond: Expr<bool>,
then: impl Fn() -> R,
else_: impl Fn() -> R,
then: impl FnOnce() -> R,
else_: impl FnOnce() -> R,
) -> R {
let cond = cond.node().get();
with_recorder(|r| {
Expand Down Expand Up @@ -199,11 +199,7 @@ pub fn select<A: Aggregate>(mask: Expr<bool>, a: A, b: A) -> A {
A::from_vec_nodes(ret)
}

pub fn generic_loop(
mut cond: impl FnMut() -> Expr<bool>,
mut body: impl FnMut(),
mut update: impl FnMut(),
) {
pub fn generic_loop(cond: impl FnOnce() -> Expr<bool>, body: impl FnOnce(), update: impl FnOnce()) {
with_recorder(|r| {
let pools = r.pools.clone();
let s = &mut r.scopes;
Expand Down Expand Up @@ -308,13 +304,13 @@ pub fn loop_(body: impl Fn()) {
});
}

pub fn for_unrolled<I: IntoIterator>(iter: I, body: impl Fn(I::Item)) {
pub fn for_unrolled<I: IntoIterator>(iter: I, mut body: impl FnMut(I::Item)) {
for i in iter {
body(i);
}
}

pub fn for_range<R: ForLoopRange>(r: R, body: impl Fn(Expr<R::Element>)) {
pub fn for_range<R: ForLoopRange>(r: R, body: impl FnMut(Expr<R::Element>)) {
let start = r.start().get();
let end = r.end().get();
let inc = |v: NodeRef| {
Expand Down
28 changes: 14 additions & 14 deletions luisa_compute/src/lang/ops/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ impl<V: Value, E: AsExpr<Value = V>> StoreMaybeExpr<E> for Var<V> {
}

impl<R> SelectMaybeExpr<R> for bool {
fn if_then_else(self, on: impl Fn() -> R, off: impl Fn() -> R) -> R {
fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R {
if self {
on()
} else {
Expand All @@ -419,7 +419,7 @@ impl<R> SelectMaybeExpr<R> for bool {
}
}
impl<R: Aggregate> SelectMaybeExpr<R> for Expr<bool> {
fn if_then_else(self, on: impl Fn() -> R, off: impl Fn() -> R) -> R {
fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R {
crate::lang::control_flow::if_then_else(self, on, off)
}
fn select(self, on: R, off: R) -> R {
Expand All @@ -428,28 +428,28 @@ impl<R: Aggregate> SelectMaybeExpr<R> for Expr<bool> {
}

impl<R: Aggregate> SelectMaybeExpr<R> for Var<bool> {
fn if_then_else(self, on: impl Fn() -> R, off: impl Fn() -> R) -> R {
fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R {
crate::lang::control_flow::if_then_else(**self, on, off)
}
fn select(self, on: R, off: R) -> R {
crate::lang::control_flow::select(**self, on, off)
}
}
impl ActivateMaybeExpr for bool {
fn activate(self, then: impl Fn()) {
fn activate(self, then: impl FnOnce()) {
if self {
then()
}
}
}
impl ActivateMaybeExpr for Expr<bool> {
fn activate(self, then: impl Fn()) {
fn activate(self, then: impl FnOnce()) {
crate::lang::control_flow::if_then_else(self, then, || {})
}
}

impl ActivateMaybeExpr for Var<bool> {
fn activate(self, then: impl Fn()) {
fn activate(self, then: impl FnOnce()) {
crate::lang::control_flow::if_then_else(self.load(), then, || {})
}
}
Expand All @@ -470,23 +470,23 @@ impl LoopMaybeExpr for Expr<bool> {

impl LazyBoolMaybeExpr<bool, ValueType> for bool {
type Bool = bool;
fn and(self, other: impl Fn() -> bool) -> bool {
fn and(self, other: impl FnOnce() -> bool) -> bool {
self && other()
}
fn or(self, other: impl Fn() -> bool) -> bool {
fn or(self, other: impl FnOnce() -> bool) -> bool {
self || other()
}
}
impl LazyBoolMaybeExpr<Expr<bool>, ExprType> for bool {
type Bool = Expr<bool>;
fn and(self, other: impl Fn() -> Expr<bool>) -> Self::Bool {
fn and(self, other: impl FnOnce() -> Expr<bool>) -> Self::Bool {
if self {
other()
} else {
false.expr()
}
}
fn or(self, other: impl Fn() -> Expr<bool>) -> Self::Bool {
fn or(self, other: impl FnOnce() -> Expr<bool>) -> Self::Bool {
if self {
true.expr()
} else {
Expand All @@ -496,21 +496,21 @@ impl LazyBoolMaybeExpr<Expr<bool>, ExprType> for bool {
}
impl LazyBoolMaybeExpr<bool, ExprType> for Expr<bool> {
type Bool = Expr<bool>;
fn and(self, other: impl Fn() -> bool) -> Self::Bool {
fn and(self, other: impl FnOnce() -> bool) -> Self::Bool {
let other = other().expr();
select(self, other, false.expr())
}
fn or(self, other: impl Fn() -> bool) -> Self::Bool {
fn or(self, other: impl FnOnce() -> bool) -> Self::Bool {
let other = other().expr();
select(self, true.expr(), other)
}
}
impl LazyBoolMaybeExpr<Expr<bool>, ExprType> for Expr<bool> {
type Bool = Expr<bool>;
fn and(self, other: impl Fn() -> Expr<bool>) -> Self::Bool {
fn and(self, other: impl FnOnce() -> Expr<bool>) -> Self::Bool {
crate::lang::control_flow::if_then_else(self, other, || false.expr())
}
fn or(self, other: impl Fn() -> Expr<bool>) -> Self::Bool {
fn or(self, other: impl FnOnce() -> Expr<bool>) -> Self::Bool {
crate::lang::control_flow::if_then_else(self, || true.expr(), other)
}
}
4 changes: 2 additions & 2 deletions luisa_compute/src/lang/ops/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -353,12 +353,12 @@ pub trait StoreMaybeExpr<V> {
}

pub trait SelectMaybeExpr<R> {
fn if_then_else(self, on: impl Fn() -> R, off: impl Fn() -> R) -> R;
fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R;
fn select(self, on: R, off: R) -> R;
}

pub trait ActivateMaybeExpr {
fn activate(self, then: impl Fn());
fn activate(self, then: impl FnOnce());
}

pub trait LoopMaybeExpr {
Expand Down

0 comments on commit d36d6f2

Please sign in to comment.