forked from LuisaGroup/luisa-compute-rs
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path callable_advanced.rs
63 lines (60 loc) · 2.02 KB
/
callable_advanced.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
use luisa::lang::types::dynamic::*;
use luisa::prelude::*;
use luisa_compute as luisa;
use std::env::current_exe;
/// Example: a `DynCallable` whose implementation is selected by the runtime
/// type of its arguments, invoked from a single kernel with both `f32` and
/// `i32` operands.
///
/// Usage: `callable_advanced [backend]`. The optional `backend` name is passed
/// straight to `Context::create_device`; when it is omitted, `"cpu"` is used.
///
/// With `x[i] = i` and `y[i] = 1000 * i`, the kernel writes
/// `z[i] = x[i] + y[i]` through the `f32` branch of `add`, and
/// `w[i] = x[i].as_::<i32>() + y[i].as_::<i32>()` through the `i32` branch.
/// The first 16 elements of `z` and `w` are then printed.
///
/// # Panics
///
/// Panics if more than one command-line argument is given, if the current
/// executable path cannot be resolved, or if `add` is called with operands
/// whose types do not match one of its supported branches.
fn main() {
    // Element count of every buffer and the number of threads dispatched.
    // A single constant keeps buffer sizes and the launch size in sync.
    const N: usize = 1024;

    luisa::init_logger();
    let args: Vec<String> = std::env::args().collect();
    assert!(
        args.len() <= 2,
        "Usage: {} <backend>. <backend>: cpu, cuda, dx, metal, remote",
        args[0]
    );
    let backend = args.get(1).map(String::as_str).unwrap_or("cpu");
    let ctx = Context::new(current_exe().expect("failed to resolve the current executable path"));
    let device = ctx.create_device(backend);

    // `add` picks its body from the type of the left operand (f32 or i32);
    // the right operand is expected to have the same type.
    let add = DynCallable::<fn(DynExpr, DynExpr) -> DynExpr>::new(
        &device,
        Box::new(|a: DynExpr, b: DynExpr| -> DynExpr {
            if let Some(a) = a.downcast::<f32>() {
                let b = b
                    .downcast::<f32>()
                    .expect("add: lhs is f32, so rhs must be f32 too");
                DynExpr::new(track!(a + b))
            } else if let Some(a) = a.downcast::<i32>() {
                let b = b
                    .downcast::<i32>()
                    .expect("add: lhs is i32, so rhs must be i32 too");
                DynExpr::new(track!(a + b))
            } else {
                unreachable!("add: only f32 and i32 operands are supported")
            }
        }),
    );

    let x = device.create_buffer::<f32>(N);
    let y = device.create_buffer::<f32>(N);
    let z = device.create_buffer::<f32>(N);
    let w = device.create_buffer::<i32>(N);
    x.view(..).fill_fn(|i| i as f32);
    y.view(..).fill_fn(|i| 1000.0 * i as f32);

    // `x`, `y` and `w` are referenced from inside the kernel closure; only `z`
    // is passed as the kernel's argument at dispatch time.
    let kernel = Kernel::<fn(Buffer<f32>)>::new(
        &device,
        &track!(|buf_z| {
            let buf_x = x.var();
            let buf_y = y.var();
            let tid = dispatch_id().x;
            let lhs = buf_x.read(tid);
            let rhs = buf_y.read(tid);
            // f32 branch of `add`.
            buf_z.write(tid, add.call(lhs.into(), rhs.into()).get::<f32>());
            // i32 branch of `add`, fed with the operands converted via `as_::<i32>()`.
            w.var().write(
                tid,
                add.call(lhs.as_::<i32>().into(), rhs.as_::<i32>().into())
                    .get::<i32>(),
            );
        }),
    );
    kernel.dispatch([N as u32, 1, 1], &z);

    let z_data = z.view(..).copy_to_vec();
    println!("{:?}", &z_data[0..16]);
    let w_data = w.view(..).copy_to_vec();
    println!("{:?}", &w_data[0..16]);
}