Skip to content

Commit

Permalink
add for_unrolled
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 24, 2023
1 parent 21cb900 commit 16caadf
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 31 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,18 @@ fn pow_unrolled(x:Expr<f32>, i:u32)->Expr<f32> {
**p
}
```
Of course, this can be tedious if you just want to unroll a loop, so we provide a `for_unrolled` function that unrolls a loop for you.
```rust
#[tracked]
fn pow_unrolled(x:Expr<f32>, i:u32)->Expr<f32> {
let p = 1.0f32.var();
for_unrolled(0..i, |_|{
p *= x;
});
**p
}
```



### Variables and Expressions
Expand Down
53 changes: 23 additions & 30 deletions luisa_compute/examples/mpm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,18 +118,16 @@ fn main() {
let affine =
Mat2::diag_expr(Float2::expr(stress, stress)) + P_MASS as f32 * C.var().read(p);
let vp = v.var().read(p);
escape!(for ii in 0..9 {
let (i, j) = (ii % 3, ii / 3);
track!({
let offset = Int2::expr(i as i32, j as i32);
let dpos = (offset.cast_f32() - fx) * DX;
let weight = w[i].x * w[j].y;
let vadd = weight * (P_MASS * vp + affine * dpos);
let idx = index((base + offset).cast_u32());
grid_v.var().atomic_fetch_add(idx * 2, vadd.x);
grid_v.var().atomic_fetch_add(idx * 2 + 1, vadd.y);
grid_m.var().atomic_fetch_add(idx, weight * P_MASS);
});
for_unrolled(0..9, |ii| {
let (i, j) = escape!((ii % 3, ii / 3));
let offset = Int2::expr(i as i32, j as i32);
let dpos = (offset.cast_f32() - fx) * DX;
let weight = w[i].x * w[j].y;
let vadd = weight * (P_MASS * vp + affine * dpos);
let idx = index((base + offset).cast_u32());
grid_v.var().atomic_fetch_add(idx * 2, vadd.x);
grid_v.var().atomic_fetch_add(idx * 2 + 1, vadd.y);
grid_m.var().atomic_fetch_add(idx, weight * P_MASS);
});
}),
);
Expand Down Expand Up @@ -183,24 +181,19 @@ fn main() {
let new_C = Var::<Mat2>::zeroed();
new_v.store(Float2::expr(0.0f32, 0.0f32));
new_C.store(Mat2::expr(Float2::expr(0., 0.), Float2::expr(0., 0.)));
escape!({
for ii in 0..9 {
let (i, j) = (ii % 3, ii / 3);
track!({
let offset = Int2::expr(i as i32, j as i32);
let dpos = (offset.cast_f32() - fx) * DX.expr();
let weight = w[i].x * w[j].y;
let idx = index((base + offset).cast_u32());
let g_v = Float2::expr(
grid_v.var().read(idx * 2u32),
grid_v.var().read(idx * 2u32 + 1u32),
);
new_v.store(new_v.load() + weight * g_v);
new_C.store(
new_C.load() + 4.0f32 * weight * g_v.outer_product(dpos) / (DX * DX),
);
});
}
for_unrolled(0..9, |ii| {
let (i, j) = escape!((ii % 3, ii / 3));

let offset = Int2::expr(i as i32, j as i32);
let dpos = (offset.cast_f32() - fx) * DX.expr();
let weight = w[i].x * w[j].y;
let idx = index((base + offset).cast_u32());
let g_v = Float2::expr(
grid_v.var().read(idx * 2u32),
grid_v.var().read(idx * 2u32 + 1u32),
);
new_v.store(new_v.load() + weight * g_v);
new_C.store(new_C.load() + 4.0f32 * weight * g_v.outer_product(dpos) / (DX * DX));
});

v.var().write(p, new_v);
Expand Down
6 changes: 6 additions & 0 deletions luisa_compute/src/lang/control_flow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,12 @@ pub fn loop_(body: impl Fn()) {
});
}

/// Invokes `body` once for every item produced by `iter`, in iteration order.
///
/// The iteration is an ordinary host-side Rust loop: `body` is simply called
/// again for each item, so whatever it does is repeated per item. An empty
/// iterator results in `body` never being called.
///
/// `body` is accepted as `FnMut`, so it may update captured host-side state
/// between iterations (e.g. collect per-iteration values into a `Vec`).
/// Closures that only implement `Fn` are still accepted unchanged.
///
/// # Examples
///
/// ```ignore
/// let mut indices = Vec::new();
/// for_unrolled(0..3, |i| indices.push(i));
/// assert_eq!(indices, vec![0, 1, 2]);
/// ```
pub fn for_unrolled<I: IntoIterator>(iter: I, body: impl FnMut(I::Item)) {
    iter.into_iter().for_each(body);
}

pub fn for_range<R: ForLoopRange>(r: R, body: impl Fn(Expr<R::Element>)) {
let start = r.start();
let end = r.end();
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub mod prelude {
pub use half::f16;

pub use crate::lang::control_flow::{
break_, continue_, for_range, return_, return_v, select, switch,
break_, continue_, for_range, for_unrolled, return_, return_v, select, switch,
};
pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, set_block_size};
pub use crate::lang::index::{IndexRead, IndexWrite};
Expand Down

0 comments on commit 16caadf

Please sign in to comment.