Skip to content

Commit

Permalink
added as_expr_proxy, as_var_proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 21, 2023
2 parents 4886531 + 35465f0 commit 27335ce
Show file tree
Hide file tree
Showing 16 changed files with 746 additions and 539 deletions.
11 changes: 4 additions & 7 deletions luisa_compute/examples/vecadd.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use std::env::current_exe;

use luisa::lang::types::vector::alias::*;
use luisa::prelude::*;
use luisa_compute as luisa;

fn main() {
luisa::init_logger();
let args: Vec<String> = std::env::args().collect();
Expand All @@ -23,20 +23,17 @@ fn main() {
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&|buf_z| {
let kernel = device.create_kernel::<fn(Buffer<f32>)>(track!(&|buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let v = Float3::expr(1.0, 1.0, 1.0);
let iv = v.as_::<Int3>();
let vx = 2.0_f32.var(); // create a local mutable variable
// *vx.get_mut() += *vx + x;
vx.store(vx.load() + x); // store to vx
*vx += x; // store to vx
buf_z.write(tid, vx.load() + y);
});
}));
kernel.dispatch([1024, 1, 1], &z);
let z_data = z.view(..).copy_to_vec();
println!("{:?}", &z_data[0..16]);
Expand Down
28 changes: 18 additions & 10 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,16 @@ pub(crate) trait CallFuncTrait {
}
impl CallFuncTrait for Func {
fn call<T: Value, S: Value>(self, x: Expr<T>) -> Expr<S> {
let x = x.node();
Expr::<S>::from_node(__current_scope(|b| {
b.call(self, &[x.node()], <S as TypeOf>::type_())
b.call(self, &[x], <S as TypeOf>::type_())
}))
}
fn call2<T: Value, S: Value, U: Value>(self, x: Expr<T>, y: Expr<S>) -> Expr<U> {
let x = x.node();
let y = y.node();
Expr::<U>::from_node(__current_scope(|b| {
b.call(self, &[x.node(), y.node()], <U as TypeOf>::type_())
b.call(self, &[x, y], <U as TypeOf>::type_())
}))
}
fn call3<T: Value, S: Value, U: Value, V: Value>(
Expand All @@ -64,27 +67,32 @@ impl CallFuncTrait for Func {
y: Expr<S>,
z: Expr<U>,
) -> Expr<V> {
let x = x.node();
let y = y.node();
let z = z.node();
Expr::<V>::from_node(__current_scope(|b| {
b.call(
self,
&[x.node(), y.node(), z.node()],
<V as TypeOf>::type_(),
)
b.call(self, &[x, y, z], <V as TypeOf>::type_())
}))
}
fn call_void<T: Value>(self, x: Expr<T>) {
let x = x.node();
__current_scope(|b| {
b.call(self, &[x.node()], Type::void());
b.call(self, &[x], Type::void());
});
}
fn call2_void<T: Value, S: Value>(self, x: Expr<T>, y: Expr<S>) {
let x = x.node();
let y = y.node();
__current_scope(|b| {
b.call(self, &[x.node(), y.node()], Type::void());
b.call(self, &[x, y], Type::void());
});
}
fn call3_void<T: Value, S: Value, U: Value>(self, x: Expr<T>, y: Expr<S>, z: Expr<U>) {
let x = x.node();
let y = y.node();
let z = z.node();
__current_scope(|b| {
b.call(self, &[x.node(), y.node(), z.node()], Type::void());
b.call(self, &[x, y, z], Type::void());
});
}
}
Expand Down
7 changes: 4 additions & 3 deletions luisa_compute/src/lang/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::ops::*;
use super::types::core::{Floating, Integral, Numeric, Primitive, Signed};
use super::types::vector::{VectorAlign, VectorElement};

pub mod impls;
pub mod spread;
pub mod traits;
mod impls;
mod spread;
mod traits;

pub use spread::*;
pub use traits::*;

pub unsafe trait CastFrom<T: Primitive>: Primitive {}
Expand Down
Loading

0 comments on commit 27335ce

Please sign in to comment.