From 16caadf851fe874a124d00f50235e6fb01743669 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sun, 24 Sep 2023 04:24:48 -0400 Subject: [PATCH] add for_unrolled --- README.md | 12 ++++++ luisa_compute/examples/mpm.rs | 53 +++++++++++--------------- luisa_compute/src/lang/control_flow.rs | 6 +++ luisa_compute/src/lib.rs | 2 +- 4 files changed, 42 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 5787d587..545a6ef6 100644 --- a/README.md +++ b/README.md @@ -184,6 +184,18 @@ fn pow_unrolled(x:Expr, i:u32)->Expr { **p } ``` +Of course this can be tedius if you just want to unroll a loop. Thus we provide a `for_unrolled` function that unrolls a loop for you. +```rust +#[tracked] +fn pow_unrolled(x:Expr, i:u32)->Expr { + let p = 1.0f32.var(); + for_unrolled(0..i, |_|{ + p *= x; + }); + **p +} +``` + ### Variables and Expressions diff --git a/luisa_compute/examples/mpm.rs b/luisa_compute/examples/mpm.rs index 66f621d3..008f4139 100644 --- a/luisa_compute/examples/mpm.rs +++ b/luisa_compute/examples/mpm.rs @@ -118,18 +118,16 @@ fn main() { let affine = Mat2::diag_expr(Float2::expr(stress, stress)) + P_MASS as f32 * C.var().read(p); let vp = v.var().read(p); - escape!(for ii in 0..9 { - let (i, j) = (ii % 3, ii / 3); - track!({ - let offset = Int2::expr(i as i32, j as i32); - let dpos = (offset.cast_f32() - fx) * DX; - let weight = w[i].x * w[j].y; - let vadd = weight * (P_MASS * vp + affine * dpos); - let idx = index((base + offset).cast_u32()); - grid_v.var().atomic_fetch_add(idx * 2, vadd.x); - grid_v.var().atomic_fetch_add(idx * 2 + 1, vadd.y); - grid_m.var().atomic_fetch_add(idx, weight * P_MASS); - }); + for_unrolled(0..9, |ii| { + let (i, j) = escape!((ii % 3, ii / 3)); + let offset = Int2::expr(i as i32, j as i32); + let dpos = (offset.cast_f32() - fx) * DX; + let weight = w[i].x * w[j].y; + let vadd = weight * (P_MASS * vp + affine * dpos); + let idx = index((base + offset).cast_u32()); + grid_v.var().atomic_fetch_add(idx * 2, vadd.x); + grid_v.var().atomic_fetch_add(idx * 2 + 1, vadd.y); + grid_m.var().atomic_fetch_add(idx, weight * P_MASS); }); }), ); @@ -183,24 +181,19 @@ fn main() { let new_C = Var::::zeroed(); new_v.store(Float2::expr(0.0f32, 0.0f32)); new_C.store(Mat2::expr(Float2::expr(0., 0.), Float2::expr(0., 0.))); - escape!({ - for ii in 0..9 { - let (i, j) = (ii % 3, ii / 3); - track!({ - let offset = Int2::expr(i as i32, j as i32); - let dpos = (offset.cast_f32() - fx) * DX.expr(); - let weight = w[i].x * w[j].y; - let idx = index((base + offset).cast_u32()); - let g_v = Float2::expr( - grid_v.var().read(idx * 2u32), - grid_v.var().read(idx * 2u32 + 1u32), - ); - new_v.store(new_v.load() + weight * g_v); - new_C.store( - new_C.load() + 4.0f32 * weight * g_v.outer_product(dpos) / (DX * DX), - ); - }); - } + for_unrolled(0..9, |ii| { + let (i, j) = escape!((ii % 3, ii / 3)); + + let offset = Int2::expr(i as i32, j as i32); + let dpos = (offset.cast_f32() - fx) * DX.expr(); + let weight = w[i].x * w[j].y; + let idx = index((base + offset).cast_u32()); + let g_v = Float2::expr( + grid_v.var().read(idx * 2u32), + grid_v.var().read(idx * 2u32 + 1u32), + ); + new_v.store(new_v.load() + weight * g_v); + new_C.store(new_C.load() + 4.0f32 * weight * g_v.outer_product(dpos) / (DX * DX)); }); v.var().write(p, new_v); diff --git a/luisa_compute/src/lang/control_flow.rs b/luisa_compute/src/lang/control_flow.rs index 24ba1318..bef30901 100644 --- a/luisa_compute/src/lang/control_flow.rs +++ b/luisa_compute/src/lang/control_flow.rs @@ -286,6 +286,12 @@ pub fn loop_(body: impl Fn()) { }); } +pub fn for_unrolled(iter: I, body: impl Fn(I::Item)) { + for i in iter { + body(i); + } +} + pub fn for_range(r: R, body: impl Fn(Expr)) { let start = r.start(); let end = r.end(); diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 2476f2d8..549ece83 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -19,7 +19,7 @@ pub mod prelude { pub use half::f16; pub use crate::lang::control_flow::{ - break_, continue_, for_range, return_, return_v, select, switch, + break_, continue_, for_range, for_unrolled, return_, return_v, select, switch, }; pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, set_block_size}; pub use crate::lang::index::{IndexRead, IndexWrite};