diff --git a/luisa_compute/examples/vecadd.rs b/luisa_compute/examples/vecadd.rs index 7eb0d80..603abf7 100644 --- a/luisa_compute/examples/vecadd.rs +++ b/luisa_compute/examples/vecadd.rs @@ -1,8 +1,8 @@ use std::env::current_exe; +use luisa::lang::types::vector::alias::*; use luisa::prelude::*; use luisa_compute as luisa; - fn main() { luisa::init_logger(); let args: Vec = std::env::args().collect(); @@ -23,20 +23,17 @@ fn main() { let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let kernel = device.create_kernel::)>(&|buf_z| { + let kernel = device.create_kernel::)>(track!(&|buf_z| { // z is pass by arg let buf_x = x.var(); // x and y are captured let buf_y = y.var(); let tid = dispatch_id().x; let x = buf_x.read(tid); let y = buf_y.read(tid); - let v = Float3::expr(1.0, 1.0, 1.0); - let iv = v.as_::(); let vx = 2.0_f32.var(); // create a local mutable variable - // *vx.get_mut() += *vx + x; - vx.store(vx.load() + x); // store to vx + *vx += x; // store to vx buf_z.write(tid, vx.load() + y); - }); + })); kernel.dispatch([1024, 1, 1], &z); let z_data = z.view(..).copy_to_vec(); println!("{:?}", &z_data[0..16]); diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index a7f231c..a207fdb 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -49,13 +49,16 @@ pub(crate) trait CallFuncTrait { } impl CallFuncTrait for Func { fn call(self, x: Expr) -> Expr { + let x = x.node(); Expr::::from_node(__current_scope(|b| { - b.call(self, &[x.node()], ::type_()) + b.call(self, &[x], ::type_()) })) } fn call2(self, x: Expr, y: Expr) -> Expr { + let x = x.node(); + let y = y.node(); Expr::::from_node(__current_scope(|b| { - b.call(self, &[x.node(), y.node()], ::type_()) + b.call(self, &[x, y], ::type_()) })) } fn call3( @@ -64,27 +67,32 @@ impl CallFuncTrait for Func { y: Expr, z: Expr, ) -> Expr { + let x = x.node(); + let y = y.node(); + let z = z.node(); Expr::::from_node(__current_scope(|b| { - b.call( - self, - &[x.node(), y.node(), z.node()], - ::type_(), - ) + b.call(self, &[x, y, z], ::type_()) })) } fn call_void(self, x: Expr) { + let x = x.node(); __current_scope(|b| { - b.call(self, &[x.node()], Type::void()); + b.call(self, &[x], Type::void()); }); } fn call2_void(self, x: Expr, y: Expr) { + let x = x.node(); + let y = y.node(); __current_scope(|b| { - b.call(self, &[x.node(), y.node()], Type::void()); + b.call(self, &[x, y], Type::void()); }); } fn call3_void(self, x: Expr, y: Expr, z: Expr) { + let x = x.node(); + let y = y.node(); + let z = z.node(); __current_scope(|b| { - b.call(self, &[x.node(), y.node(), z.node()], Type::void()); + b.call(self, &[x, y, z], Type::void()); }); } } diff --git a/luisa_compute/src/lang/ops.rs b/luisa_compute/src/lang/ops.rs index b325637..59fcd64 100644 --- a/luisa_compute/src/lang/ops.rs +++ b/luisa_compute/src/lang/ops.rs @@ -4,10 +4,11 @@ use std::ops::*; use super::types::core::{Floating, Integral, Numeric, Primitive, Signed}; use super::types::vector::{VectorAlign, VectorElement}; -pub mod impls; -pub mod spread; -pub mod traits; +mod impls; +mod spread; +mod traits; +pub use spread::*; pub use traits::*; pub unsafe trait CastFrom: Primitive {} diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index 46ba44b..f2631b1 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -1,3 +1,5 @@ +use crate::lang::types::{ExprType, ValueType}; + use super::*; impl Expr { @@ -5,7 +7,11 @@ impl Expr { where Y::Scalar: CastFrom, { - assert_eq!(X::N, Y::N, "Cannot cast between scalars/vectors of different dimensions."); + assert_eq!( + X::N, + Y::N, + "Cannot cast between scalars/vectors of different dimensions." + ); Func::Cast.call(self) } pub fn cast(self) -> Expr> @@ -31,15 +37,6 @@ macro_rules! impl_ops_trait { } )* } - impl<$($bounds)*> $TraitExpr for $T where $($where)* { - type Output = Self; - - $( - fn $fn($sl, $($arg: Self),*) -> Self { - <$T as $TraitThis>::$fn_this($sl, $($arg),*) - } - )* - } }; ( [$($bounds:tt)*] $TraitExpr:ident [$TraitThis:ident] for $T:ty where [$($where:tt)*] { @@ -57,20 +54,23 @@ macro_rules! impl_ops_trait { } )* } - impl<$($bounds)*> $TraitExpr for $T where $($where)* { - type Output = $Output; - - $( - fn $fn($sl, $($arg: Self),*) -> Self::Output { - <$T as $TraitThis>::$fn_this($sl, $($arg),*) - } - )* - } } } +macro_rules! impl_simple_binop { + ( + [$($bounds:tt)*] $TraitExpr:ident [$TraitThis:ident] for $T:ty where [$($where:tt)*]: $fn:ident [$fn_this:ident] ($func:ident) + ) => { + impl_ops_trait!([$($bounds)*] $TraitExpr [$TraitThis] for $T where [$($where)*] { + fn $fn[$fn_this](self, other) { Func::$func.call2(self, other) } + }); + } +} + impl_ops_trait!([X: Linear] MinMaxExpr[MinMaxThis] for Expr where [X::Scalar: Numeric] { - fn max[_max](self, other) { Func::Max.call2(self, other) } - fn min[_min](self, other) { Func::Min.call2(self, other) } + type Output = Expr>; + + fn max_expr[_max_expr](self, other) { Func::Max.call2(self, other) } + fn min_expr[_min_expr](self, other) { Func::Min.call2(self, other) } }); impl_ops_trait!([X: Linear] ClampExpr[ClampThis] for Expr where [X::Scalar: Numeric] { @@ -102,97 +102,16 @@ impl_ops_trait!([X: Linear] CmpExpr[CmpThis] for Expr where [X::Scalar: Numer fn ge[_ge](self, other) { Func::Ge.call2(self, other) } }); -impl Add for Expr -where - X::Scalar: Numeric, -{ - type Output = Self; - fn add(self, other: Self) -> Self { - Func::Add.call2(self, other) - } -} -impl Sub for Expr -where - X::Scalar: Numeric, -{ - type Output = Self; - fn sub(self, other: Self) -> Self { - Func::Sub.call2(self, other) - } -} -impl Mul for Expr -where - X::Scalar: Numeric, -{ - type Output = Self; - fn mul(self, other: Self) -> Self { - Func::Mul.call2(self, other) - } -} -impl Div for Expr -where - X::Scalar: Numeric, -{ - type Output = Self; - fn div(self, other: Self) -> Self { - Func::Div.call2(self, other) - } -} -impl Rem for Expr -where - X::Scalar: Numeric, -{ - type Output = Self; - fn rem(self, other: Self) -> Self { - Func::Rem.call2(self, other) - } -} - -impl BitAnd for Expr -where - X::Scalar: Integral, -{ - type Output = Self; - fn bitand(self, other: Self) -> Self { - Func::BitAnd.call2(self, other) - } -} -impl BitOr for Expr -where - X::Scalar: Integral, -{ - type Output = Self; - fn bitor(self, other: Self) -> Self { - Func::BitOr.call2(self, other) - } -} -impl BitXor for Expr -where - X::Scalar: Integral, -{ - type Output = Self; - fn bitxor(self, other: Self) -> Self { - Func::BitXor.call2(self, other) - } -} -impl Shl for Expr -where - X::Scalar: Integral, -{ - type Output = Self; - fn shl(self, other: Self) -> Self { - Func::Shl.call2(self, other) - } -} -impl Shr for Expr -where - X::Scalar: Integral, -{ - type Output = Self; - fn shr(self, other: Self) -> Self { - Func::Shr.call2(self, other) - } -} +impl_simple_binop!([X: Linear] AddExpr[AddThis] for Expr where [X::Scalar: Numeric]: add[_add](Add)); +impl_simple_binop!([X: Linear] SubExpr[SubThis] for Expr where [X::Scalar: Numeric]: sub[_sub](Sub)); +impl_simple_binop!([X: Linear] MulExpr[MulThis] for Expr where [X::Scalar: Numeric]: mul[_mul](Mul)); +impl_simple_binop!([X: Linear] DivExpr[DivThis] for Expr where [X::Scalar: Numeric]: div[_div](Div)); +impl_simple_binop!([X: Linear] RemExpr[RemThis] for Expr where [X::Scalar: Numeric]: rem[_rem](Rem)); +impl_simple_binop!([X: Linear] BitAndExpr[BitAndThis] for Expr where [X::Scalar: Integral]: bitand[_bitand](BitAnd)); +impl_simple_binop!([X: Linear] BitOrExpr[BitOrThis] for Expr where [X::Scalar: Integral]: bitor[_bitor](BitOr)); +impl_simple_binop!([X: Linear] BitXorExpr[BitXorThis] for Expr where [X::Scalar: Integral]: bitxor[_bitxor](BitXor)); +impl_simple_binop!([X: Linear] ShlExpr[ShlThis] for Expr where [X::Scalar: Integral]: shl[_shl](Shl)); +impl_simple_binop!([X: Linear] ShrExpr[ShrThis] for Expr where [X::Scalar: Integral]: shr[_shr](Shr)); impl Neg for Expr where @@ -266,7 +185,7 @@ where log10 => Log10 } fn is_finite(&self) -> Self::Bool { - !self.is_infinite() & !self.is_nan() + !self.is_infinite().bitand(!self.is_nan()) } fn is_infinite(&self) -> Self::Bool { Func::IsInf.call(self.clone()) @@ -275,10 +194,10 @@ where Func::IsNan.call(self.clone()) } fn sqr(&self) -> Self { - self.clone() * self.clone() + self.clone().mul(self.clone()) } fn cube(&self) -> Self { - self.clone() * self.clone() * self.clone() + self.clone().mul(self.clone()).mul(self.clone()) } fn recip(&self) -> Self { todo!() @@ -309,7 +228,7 @@ impl_ops_trait!([X: Linear] FloatArcTan2Expr[FloatArcTan2This] for Expr where }); impl_ops_trait!([X: Linear] FloatLogExpr[FloatLogThis] for Expr where [X::Scalar: Floating] { - fn log[_log](self, base) { self.ln() / base.ln()} + fn log[_log](self, base) { self.ln().div(base.ln()) } }); impl_ops_trait!([X: Linear] FloatPowfExpr[FloatPowfThis] for Expr where [X::Scalar: Floating] { @@ -401,7 +320,7 @@ impl LoopMaybeExpr for Expr { } } -impl LazyBoolMaybeExpr for bool { +impl LazyBoolMaybeExpr for bool { type Bool = bool; fn and(self, other: impl FnOnce() -> bool) -> bool { self && other() @@ -410,7 +329,7 @@ impl LazyBoolMaybeExpr for bool { self || other() } } -impl LazyBoolMaybeExpr> for bool { +impl LazyBoolMaybeExpr, ExprType> for bool { type Bool = Expr; fn and(self, other: impl FnOnce() -> Expr) -> Self::Bool { if self { @@ -427,7 +346,7 @@ impl LazyBoolMaybeExpr> for bool { } } } -impl LazyBoolMaybeExpr for Expr { +impl LazyBoolMaybeExpr for Expr { type Bool = Expr; fn and(self, other: impl FnOnce() -> bool) -> Self::Bool { if other() { @@ -444,7 +363,7 @@ impl LazyBoolMaybeExpr for Expr { } } } -impl LazyBoolMaybeExpr for Expr { +impl LazyBoolMaybeExpr, ExprType> for Expr { type Bool = Expr; fn and(self, other: impl FnOnce() -> Expr) -> Self::Bool { crate::lang::control_flow::if_then_else(self, other, || false.expr()) @@ -453,65 +372,3 @@ impl LazyBoolMaybeExpr for Expr { crate::lang::control_flow::if_then_else(self, || true.expr(), other) } } - -impl EqMaybeExpr for T -where - T: EqExpr, -{ - type Bool = >::Output; - fn __eq(self, other: S) -> Self::Bool { - self.eq(other) - } - fn __ne(self, other: S) -> Self::Bool { - self.ne(other) - } -} -impl EqMaybeExpr for T -where - T: PartialEq, -{ - type Bool = bool; - fn __eq(self, other: S) -> Self::Bool { - self == other - } - fn __ne(self, other: S) -> Self::Bool { - self != other - } -} - -impl CmpMaybeExpr for T -where - T: CmpExpr, -{ - type Bool = >::Output; - fn __lt(self, other: S) -> Self::Bool { - self.lt(other) - } - fn __le(self, other: S) -> Self::Bool { - self.le(other) - } - fn __gt(self, other: S) -> Self::Bool { - self.gt(other) - } - fn __ge(self, other: S) -> Self::Bool { - self.ge(other) - } -} -impl CmpMaybeExpr for T -where - T: PartialOrd, -{ - type Bool = bool; - fn __lt(self, other: S) -> Self::Bool { - self < other - } - fn __le(self, other: S) -> Self::Bool { - self <= other - } - fn __gt(self, other: S) -> Self::Bool { - self > other - } - fn __ge(self, other: S) -> Self::Bool { - self >= other - } -} diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs index e42e18d..c85500f 100644 --- a/luisa_compute/src/lang/ops/spread.rs +++ b/luisa_compute/src/lang/ops/spread.rs @@ -1,3 +1,6 @@ +use crate::lang::types::core::PrimitiveVar; +use crate::lang::types::VarProxy; + use super::*; use traits::*; @@ -43,10 +46,11 @@ macro_rules! impl_spread { macro_rules! call_linear_fn_spread { ($f:ident [$($bounds:tt)*]($T:ty)) => { - // The T to other value impls mess up the `Ord` impls. - $f!(@sym [$($bounds)*] Expr<$T>: |x| x, $T: |x| x.expr() => Expr<$T>); + $f!(@sym [$($bounds)*] Expr<$T>: |x| x, Expr<$T>: |x| x => Expr<$T>); + + $f!([$($bounds)*] $T: |x| x.expr(), Expr<$T>: |x| x => Expr<$T>); $f!(['a, $($bounds)*] &'a $T: |x| x.expr(), Expr<$T>: |x| x => Expr<$T>); - $f!(@sym ['b, $($bounds)*] &'b Expr<$T>: |x| x.clone(), $T: |x| x.expr() => Expr<$T>); + $f!(['b, $($bounds)*] $T: |x| x.expr(), &'b Expr<$T>: |x| x.clone() => Expr<$T>); $f!(['a, 'b, $($bounds)*] &'a $T: |x| x.expr(), &'b Expr<$T>: |x| x.clone() => Expr<$T>); $f!(['b, $($bounds)*] Expr<$T>: |x| x, &'b Expr<$T>: |x| x.clone() => Expr<$T>); @@ -55,10 +59,10 @@ macro_rules! call_linear_fn_spread { $f!(@sym [$($bounds)*] Var<$T>: |x| x.load(), Var<$T>: |x| x.load() => Expr<$T>); $f!(@sym ['a, 'b, $($bounds)*] &'a Var<$T>: |x| x.load(), &'b Var<$T>: |x| x.load() => Expr<$T>); - $f!(@sym [$($bounds)*] Var<$T>: |x| x.load(), $T: |x| x.expr() => Expr<$T>); + $f!([$($bounds)*] $T: |x| x.expr(), Var<$T>: |x| x.load() => Expr<$T>); $f!(['a, $($bounds)*] &'a $T: |x| x.expr(), Var<$T>: |x| x.load() => Expr<$T>); - $f!(@sym ['b, $($bounds)*] &'b Var<$T>: |x| x.load(), $T: |x| x.expr() => Expr<$T>); - $f!(['a, 'b, $($bounds)*] &'a $T: |x| x.expr(), &'b Var<$T>: |x| x.load() => Expr<$T>); + $f!(['b, $($bounds)*] $T: |x| x.expr(), &'b Var<$T>: |x| x.load() => Expr<$T>); + $f!(['a, 'b, $($bounds)*] &'a $T: |x| x.expr(), &'b Var<$T>: |x| x.load() => Expr<$T>); $f!(['a, $($bounds)*] &'a Expr<$T>: |x| x.clone(), Var<$T>: |x| x.load() => Expr<$T>); $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| x.clone(), &'b Var<$T>: |x| x.load() => Expr<$T>); @@ -74,11 +78,11 @@ call_linear_fn_spread!(impl_spread[T]); macro_rules! call_vector_fn_spread { ($f:ident [$($bounds:tt)*]($N:tt, $T:ty) $Vt:ty, $Vsplat:path) => { - $f!(@sym [$($bounds)*] Expr<$Vt>: |x| x, $T: |x| $Vsplat(x) => Expr<$Vt>); + $f!([$($bounds)*] $T: |x| $Vsplat(x), Expr<$Vt>: |x| x => Expr<$Vt>); $f!(['a, $($bounds)*] &'a $T: |x| $Vsplat(*x), Expr<$Vt>: |x| x => Expr<$Vt>); $f!([$($bounds)*] Expr<$T>: |x| $Vsplat(x), Expr<$Vt>: |x| x => Expr<$Vt>); $f!(['a, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), Expr<$Vt>: |x| x => Expr<$Vt>); - $f!(@sym ['b, $($bounds)*] &'b Expr<$Vt>: |x| x.clone(), $T: |x| $Vsplat(x) => Expr<$Vt>); + $f!(['b, $($bounds)*] $T: |x| $Vsplat(x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); $f!(['a, 'b, $($bounds)*] &'a $T: |x| $Vsplat(*x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); $f!(['b, $($bounds)*] Expr<$T>: |x| $Vsplat(x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); @@ -88,11 +92,11 @@ macro_rules! call_vector_fn_spread { $f!(['b, $($bounds)*] Expr<$T>: |x| $Vsplat(x), &'b $Vt: |x| x.expr() => Expr<$Vt>); $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), &'b $Vt: |x| x.expr() => Expr<$Vt>); - $f!(@sym [$($bounds)*] Var<$Vt>: |x| x.load(), $T: |x| $Vsplat(x) => Expr<$Vt>); + $f!([$($bounds)*] $T: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); $f!(['a, $($bounds)*] &'a $T: |x| $Vsplat(*x), Var<$Vt>: |x| x.load() => Expr<$Vt>); $f!([$($bounds)*] Expr<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); $f!(['a, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); - $f!(@sym ['b, $($bounds)*] &'b Var<$Vt>: |x| x.load(), $T: |x| $Vsplat(x) => Expr<$Vt>); + $f!(['b, $($bounds)*] $T: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); $f!(['a, 'b, $($bounds)*] &'a $T: |x| $Vsplat(*x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); $f!(['b, $($bounds)*] Expr<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); @@ -112,218 +116,239 @@ macro_rules! call_vector_fn_spread { call_vector_fn_spread!(impl_spread[N, T]); -mod trait_impls { - use super::*; - impl MinMaxExpr for T - where - T: SpreadOps, - Expr: MinMaxThis, - { - type Output = Expr; - fn max(self, other: S) -> Self::Output { - Expr::::_max(Self::lift_self(self), Self::lift_other(other)) +macro_rules! impl_simple_binop_spread { + ($TraitExpr:ident [$TraitThis:ident]: $fn:ident[$fn_this:ident]) => { + impl $TraitExpr for T + where + T: SpreadOps, + Expr: $TraitThis, + { + type Output = Expr; + fn $fn(self, other: S) -> Self::Output { + Expr::::$fn_this(Self::lift_self(self), Self::lift_other(other)) + } } - fn min(self, other: S) -> Self::Output { - Expr::::_min(Self::lift_self(self), Self::lift_other(other)) + }; +} + +macro_rules! impl_var_assign { + ($TraitExpr:ident: $fn:ident) => { + impl $TraitExpr for Var + where + T::Var: $TraitExpr, + { + fn $fn(self, other: S) { + >::$fn(self.deref().clone(), other); + } } - } - impl ClampExpr for Expr - where - S: SpreadOps, - Expr: ClampThis, - { - type Output = Expr; - fn clamp(self, min: S, max: U) -> Self::Output { - Expr::::_clamp(self, S::lift_self(min), S::lift_other(max)) + }; +} + +macro_rules! impl_assignop_spread { + ([$($bounds:tt)*] $TraitExpr:ident [$TraitOrigExpr:ident] for $X:ty[$V:ty]: $assign_fn:ident [$fn:ident]) => { + impl<_Other, $($bounds)*> $TraitExpr<_Other> for $X + where + Expr<$V>: $TraitOrigExpr, + Expr<$V>: SpreadOps<_Other, Join = $V> + Sized, + { + fn $assign_fn(self, other: _Other) { + self.as_var_from_proxy().store( + as $TraitOrigExpr>::$fn( + self.deref().clone(), + as SpreadOps<_Other>>::lift_other(other) + ) + ); + } } } - impl EqExpr for T - where - T: SpreadOps, - Expr: EqThis, - { - type Output = as EqThis>::Output; - fn eq(self, other: S) -> Self::Output { - Expr::::_eq(Self::lift_self(self), Self::lift_other(other)) - } - fn ne(self, other: S) -> Self::Output { - Expr::::_ne(Self::lift_self(self), Self::lift_other(other)) - } +} +impl_var_assign!(AddAssignExpr: add_assign); +impl_var_assign!(SubAssignExpr: sub_assign); +impl_var_assign!(MulAssignExpr: mul_assign); +impl_var_assign!(DivAssignExpr: div_assign); +impl_var_assign!(RemAssignExpr: rem_assign); +impl_var_assign!(BitAndAssignExpr: bitand_assign); +impl_var_assign!(BitOrAssignExpr: bitor_assign); +impl_var_assign!(BitXorAssignExpr: bitxor_assign); +impl_var_assign!(ShlAssignExpr: shl_assign); +impl_var_assign!(ShrAssignExpr: shr_assign); + +macro_rules! impl_assignops { + ([$($bounds:tt)*] $X:ty[$V:ty]) => { + impl_assignop_spread!([$($bounds)*] AddAssignExpr[AddThis] for $X[$V]: add_assign[_add]); + impl_assignop_spread!([$($bounds)*] SubAssignExpr[SubThis] for $X[$V]: sub_assign[_sub]); + impl_assignop_spread!([$($bounds)*] MulAssignExpr[MulThis] for $X[$V]: mul_assign[_mul]); + impl_assignop_spread!([$($bounds)*] DivAssignExpr[DivThis] for $X[$V]: div_assign[_div]); + impl_assignop_spread!([$($bounds)*] RemAssignExpr[RemThis] for $X[$V]: rem_assign[_rem]); + impl_assignop_spread!([$($bounds)*] BitAndAssignExpr[BitAndThis] for $X[$V]: bitand_assign[_bitand]); + impl_assignop_spread!([$($bounds)*] BitOrAssignExpr[BitOrThis] for $X[$V]: bitor_assign[_bitor]); + impl_assignop_spread!([$($bounds)*] BitXorAssignExpr[BitXorThis] for $X[$V]: bitxor_assign[_bitxor]); + impl_assignop_spread!([$($bounds)*] ShlAssignExpr[ShlThis] for $X[$V]: shl_assign[_shl]); + impl_assignop_spread!([$($bounds)*] ShrAssignExpr[ShrThis] for $X[$V]: shr_assign[_shr]); } - impl CmpExpr for T - where - T: SpreadOps, - Expr: CmpThis, - { - type Output = as CmpThis>::Output; - fn lt(self, other: S) -> Self::Output { - Expr::::_lt(Self::lift_self(self), Self::lift_other(other)) - } - fn le(self, other: S) -> Self::Output { - Expr::::_le(Self::lift_self(self), Self::lift_other(other)) - } - fn gt(self, other: S) -> Self::Output { - Expr::::_gt(Self::lift_self(self), Self::lift_other(other)) - } - fn ge(self, other: S) -> Self::Output { - Expr::::_ge(Self::lift_self(self), Self::lift_other(other)) - } +} +impl_assignops!([T: Primitive] PrimitiveVar[T]); +impl_assignops!([T: VectorAlign<2, VectorVar = VectorVarProxy2>] VectorVarProxy2[Vector]); +impl_assignops!([T: VectorAlign<3, VectorVar = VectorVarProxy3>] VectorVarProxy3[Vector]); +impl_assignops!([T: VectorAlign<4, VectorVar = VectorVarProxy4>] VectorVarProxy4[Vector]); + +impl_simple_binop_spread!(AddExpr[AddThis]: add[_add]); +impl_simple_binop_spread!(SubExpr[SubThis]: sub[_sub]); +impl_simple_binop_spread!(MulExpr[MulThis]: mul[_mul]); +impl_simple_binop_spread!(DivExpr[DivThis]: div[_div]); +impl_simple_binop_spread!(RemExpr[RemThis]: rem[_rem]); +impl_simple_binop_spread!(BitAndExpr[BitAndThis]: bitand[_bitand]); +impl_simple_binop_spread!(BitOrExpr[BitOrThis]: bitor[_bitor]); +impl_simple_binop_spread!(BitXorExpr[BitXorThis]: bitxor[_bitxor]); +impl_simple_binop_spread!(ShlExpr[ShlThis]: shl[_shl]); +impl_simple_binop_spread!(ShrExpr[ShrThis]: shr[_shr]); + +impl MinMaxExpr for T +where + T: SpreadOps, + Expr: MinMaxThis, +{ + type Output = as MinMaxThis>::Output; + fn min_expr(self, other: S) -> Self::Output { + Expr::::_min_expr(Self::lift_self(self), Self::lift_other(other)) } - impl FloatMulAddExpr for Expr - where - S: SpreadOps, - Expr: FloatMulAddThis, - { - type Output = Expr; - fn mul_add(self, mul: S, add: U) -> Self::Output { - Expr::::_mul_add(self, S::lift_self(mul), S::lift_other(add)) - } + fn max_expr(self, other: S) -> Self::Output { + Expr::::_max_expr(Self::lift_self(self), Self::lift_other(other)) } - impl FloatCopySignExpr for T - where - T: SpreadOps, - Expr: FloatCopySignThis, - { - type Output = Expr; - fn copy_sign(self, sign: S) -> Self::Output { - Expr::::_copy_sign(Self::lift_self(self), Self::lift_other(sign)) - } +} + +pub fn min(x: T, y: S) -> >::Output +where + T: MinMaxExpr, +{ + x.min_expr(y) +} +pub fn max(x: T, y: S) -> >::Output +where + T: MinMaxExpr, +{ + x.max_expr(y) +} + +impl ClampExpr for Expr +where + S: SpreadOps, + Expr: ClampThis, +{ + type Output = Expr; + fn clamp(self, min: S, max: U) -> Self::Output { + Expr::::_clamp(self, S::lift_self(min), S::lift_other(max)) } - impl FloatStepExpr for T - where - T: SpreadOps, - Expr: FloatStepThis, - { - type Output = Expr; - fn step(self, edge: S) -> Self::Output { - Expr::::_step(Self::lift_self(self), Self::lift_other(edge)) - } +} +impl EqExpr for T +where + T: SpreadOps, + Expr: EqThis, +{ + type Output = as EqThis>::Output; + fn eq(self, other: S) -> Self::Output { + Expr::::_eq(Self::lift_self(self), Self::lift_other(other)) } - impl FloatSmoothStepExpr for Expr - where - S: SpreadOps, - Expr: FloatSmoothStepThis, - { - type Output = Expr; - fn smooth_step(self, edge0: S, edge1: U) -> Self::Output { - Expr::::_smooth_step(self, S::lift_self(edge0), S::lift_other(edge1)) - } + fn ne(self, other: S) -> Self::Output { + Expr::::_ne(Self::lift_self(self), Self::lift_other(other)) } - impl FloatArcTan2Expr for T - where - T: SpreadOps, - Expr: FloatArcTan2This, - { - type Output = Expr; - fn atan2(self, other: S) -> Self::Output { - Expr::::_atan2(Self::lift_self(self), Self::lift_other(other)) - } +} +impl CmpExpr for T +where + T: SpreadOps, + Expr: CmpThis, +{ + type Output = as CmpThis>::Output; + fn lt(self, other: S) -> Self::Output { + Expr::::_lt(Self::lift_self(self), Self::lift_other(other)) } - impl FloatLogExpr for T - where - T: SpreadOps, - Expr: FloatLogThis, - { - type Output = Expr; - fn log(self, base: S) -> Self::Output { - Expr::::_log(Self::lift_self(self), Self::lift_other(base)) - } + fn le(self, other: S) -> Self::Output { + Expr::::_le(Self::lift_self(self), Self::lift_other(other)) } - impl FloatPowfExpr for T - where - T: SpreadOps, - Expr: FloatPowfThis, - { - type Output = Expr; - fn powf(self, exponent: S) -> Self::Output { - Expr::::_powf(Self::lift_self(self), Self::lift_other(exponent)) - } + fn gt(self, other: S) -> Self::Output { + Expr::::_gt(Self::lift_self(self), Self::lift_other(other)) } - impl FloatLerpExpr for Expr - where - S: SpreadOps, - Expr: FloatLerpThis, - { - type Output = Expr; - fn lerp(self, other: S, frac: U) -> Self::Output { - Expr::::_lerp(self, S::lift_self(other), S::lift_other(frac)) - } + fn ge(self, other: S) -> Self::Output { + Expr::::_ge(Self::lift_self(self), Self::lift_other(other)) } } -macro_rules! impl_spread_op { - ([ $($bounds:tt)* ]: $Op:ident::$op_fn:ident for $T:ty, $S:ty) => { - impl<$($bounds)*> $Op <$S> for $T where $T: SpreadOps<$S>, Expr<<$T as SpreadOps<$S>>::Join>: $Op { - type Output = >::Join> as $Op>::Output; - fn $op_fn (self, other: $S) -> Self::Output { - >::Join> as $Op>::$op_fn (<$T as SpreadOps<$S>>::lift_self(self), <$T as SpreadOps<$S>>::lift_other(other)) - } - } +impl FloatMulAddExpr for Expr +where + S: SpreadOps, + Expr: FloatMulAddThis, +{ + type Output = Expr; + fn mul_add(self, mul: S, add: U) -> Self::Output { + Expr::::_mul_add(self, S::lift_self(mul), S::lift_other(add)) } } - -macro_rules! impl_num_spread_single { - ([ $($bounds:tt)* ] $T:ty, $S:ty) => { - impl_spread_op!( [ $($bounds)* ]: Add::add for $T, $S); - impl_spread_op!( [ $($bounds)* ]: Sub::sub for $T, $S); - impl_spread_op!( [ $($bounds)* ]: Mul::mul for $T, $S); - impl_spread_op!( [ $($bounds)* ]: Div::div for $T, $S); - impl_spread_op!( [ $($bounds)* ]: Rem::rem for $T, $S); +impl FloatCopySignExpr for T +where + T: SpreadOps, + Expr: FloatCopySignThis, +{ + type Output = Expr; + fn copy_sign(self, sign: S) -> Self::Output { + Expr::::_copy_sign(Self::lift_self(self), Self::lift_other(sign)) } } -macro_rules! impl_int_spread_single { - ([ $($bounds:tt)* ] $T:ty, $S:ty) => { - impl_spread_op!([ $($bounds)* ]: BitAnd::bitand for $T, $S); - impl_spread_op!([ $($bounds)* ]: BitOr::bitor for $T, $S); - impl_spread_op!([ $($bounds)* ]: BitXor::bitxor for $T, $S); - impl_spread_op!([ $($bounds)* ]: Shl::shl for $T, $S); - impl_spread_op!([ $($bounds)* ]: Shr::shr for $T, $S); +impl FloatStepExpr for T +where + T: SpreadOps, + Expr: FloatStepThis, +{ + type Output = Expr; + fn step(self, edge: S) -> Self::Output { + Expr::::_step(Self::lift_self(self), Self::lift_other(edge)) } } - -macro_rules! impl_num_spread { - (@sym [$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { - impl_num_spread_single!([$($bounds)*] $T, $S); - }; - ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { - impl_num_spread_single!([$($bounds)*] $T, $S); - impl_num_spread_single!([$($bounds)*] $S, $T); +impl FloatSmoothStepExpr for Expr +where + S: SpreadOps, + Expr: FloatSmoothStepThis, +{ + type Output = Expr; + fn smooth_step(self, edge0: S, edge1: U) -> Self::Output { + Expr::::_smooth_step(self, S::lift_self(edge0), S::lift_other(edge1)) } } -macro_rules! impl_int_spread { - (@sym [$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { - impl_int_spread_single!([$($bounds)*] $T, $S); - }; - ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { - impl_int_spread_single!([$($bounds)*] $T, $S); - impl_int_spread_single!([$($bounds)*] $S, $T); +impl FloatArcTan2Expr for T +where + T: SpreadOps, + Expr: FloatArcTan2This, +{ + type Output = Expr; + fn atan2(self, other: S) -> Self::Output { + Expr::::_atan2(Self::lift_self(self), Self::lift_other(other)) } } -macro_rules! call_spreads { - ($f:ident: $($T:ty),+) => { - $( - call_linear_fn_spread!($f []($T)); - call_vector_fn_spread!($f [](2, $T)); - call_vector_fn_spread!($f [](3, $T)); - call_vector_fn_spread!($f [](4, $T)); - )+ - }; +impl FloatLogExpr for T +where + T: SpreadOps, + Expr: FloatLogThis, +{ + type Output = Expr; + fn log(self, base: S) -> Self::Output { + Expr::::_log(Self::lift_self(self), Self::lift_other(base)) + } } -call_spreads!(impl_num_spread: f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64); -call_spreads!(impl_int_spread: bool, i8, i16, i32, i64, u8, u16, u32, u64); - -#[allow(dead_code)] -mod tests { - use super::*; - fn test() { - let x = 10.0f32; - let y = Vector::<_, 2>::splat(20.0f32); - let x = x.expr(); - - let w = (&x.var()).min(&0.0_f32.expr()); - let z = 10_f32.max(5_f32); - let i = 15_u32; - let j = z.log(10.0); - let _ = 1_u32.max(2_u32); - println!("{:?}", w); +impl FloatPowfExpr for T +where + T: SpreadOps, + Expr: FloatPowfThis, +{ + type Output = Expr; + fn powf(self, exponent: S) -> Self::Output { + Expr::::_powf(Self::lift_self(self), Self::lift_other(exponent)) + } +} +impl FloatLerpExpr for Expr +where + S: SpreadOps, + Expr: FloatLerpThis, +{ + type Output = Expr; + fn lerp(self, other: S, frac: U) -> Self::Output { + Expr::::_lerp(self, S::lift_self(other), S::lift_other(frac)) } } diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index a4561e1..2746600 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -1,10 +1,89 @@ +use crate::lang::types::{ExprType, TrackingType, ValueType}; + use super::*; // The double trait implementation is necessary as the compiler infinite loops // when trying to resolve the Expr: SpreadOps>> bound. macro_rules! ops_trait { ( - $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident] { + $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident, $TraitOrig:ident$(($OrigOutput:path))? => $TraitMaybe:ident ] { + $(type $o:ident;)? + $( + fn $fn:ident [$fn_this:ident, $orig_fn:expr => $fn_maybe:ident] ($self:ident, $($arg:ident: $S:ident),*); + )+ + } + ) => { + ops_trait!( + @XVARS(X, $TraitExpr<$($T),*>, $TraitOrig<$($T),*>) + $TraitExpr<$($T),*> [ $TraitThis, $TraitOrig$(($OrigOutput))? => $TraitMaybe ] { + $(type $o;)? + $( + fn $fn [$fn_this, $orig_fn => $fn_maybe] ($self, $($arg: $S),*); + )+ + } + ); + }; + ( + @XVARS($X:ident, $EXPANDED_EXPR:path, $EXPANDED_ORIG:path) + + $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident, $TraitOrig:ident => $TraitMaybe:ident ] { + $(type $o:ident;)? + $( + fn $fn:ident [$fn_this:ident, $orig_fn:expr => $fn_maybe:ident] ($self:ident, $($arg:ident: $S:ident),*); + )+ + } + ) => { + ops_trait!( + @XVARS($X, $EXPANDED_EXPR, $EXPANDED_ORIG) + $TraitExpr<$($T),*> [ $TraitThis, $TraitOrig(<$X as $EXPANDED_ORIG>::Output) => $TraitMaybe ] { + $(type $o;)? + $( + fn $fn [$fn_this, $orig_fn => $fn_maybe] ($self, $($arg: $S),*); + )+ + } + ); + }; + ( + @XVARS($X:ident, $EXPANDED_EXPR:path, $EXPANDED_ORIG:path) + + $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident, $TraitOrig:ident($($OrigOutput:tt)*) => $TraitMaybe:ident ] { + $(type $o:ident;)? + $( + fn $fn:ident [$fn_this:ident, $orig_fn:expr => $fn_maybe:ident] ($self:ident, $($arg:ident: $S:ident),*); + )+ + } + ) => { + ops_trait!($TraitExpr <$($T),*> [ $TraitThis ] { + $(type $o;)? + $( + fn $fn [$fn_this] ($self, $($arg: $S),*); + )+ + }); + pub trait $TraitMaybe<$($T,)* Ty: TrackingType> { + type Output; + $( + fn $fn_maybe($self, $($arg: $S),*) -> Self::Output; + )* + } + impl<$X $(,$T)*> $TraitMaybe<$($T,)* ExprType> for $X where $X: $EXPANDED_EXPR { + type Output = <$X as $EXPANDED_EXPR>::Output; + $( + fn $fn_maybe($self, $($arg: $S),*) -> Self::Output { + <$X as $EXPANDED_EXPR>::$fn($self, $($arg),*) + } + )* + } + impl<$X $(,$T)*> $TraitMaybe<$($T,)* ValueType> for $X where $X: $EXPANDED_ORIG { + type Output = $($OrigOutput)*; + $( + fn $fn_maybe($self, $($arg: $S),*) -> Self::Output { + $orig_fn + } + )* + } + }; + ( + $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident ] { $( fn $fn:ident [$fn_this:ident] (self, $($arg:ident: $S:ident),*); )+ @@ -24,7 +103,7 @@ macro_rules! ops_trait { } }; ( - $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident] { + $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident ] { type Output; $( fn $fn:ident [$fn_this:ident] (self, $($arg:ident: $S:ident),*); @@ -47,9 +126,48 @@ macro_rules! ops_trait { } } +macro_rules! simple_binop_trait { + ($TraitExpr:ident [$TraitThis:ident, $TraitOrig:ident => $TraitMaybe:ident]: $fn:ident [$fn_this: ident, $fn_orig:ident => $fn_maybe:ident]) => { + ops_trait!( + $TraitExpr[$TraitThis, $TraitOrig => $TraitMaybe] { + fn $fn[$fn_this, >::$fn_orig(self, rhs) => $fn_maybe](self, rhs: T); + } + ); + } +} + +macro_rules! assignop_trait { + ($TraitExpr:ident [$TraitOrig:ident => $TraitMaybe:ident]: $fn:ident [$fn_orig:ident => $fn_maybe:ident]) => { + pub trait $TraitExpr { + fn $fn(self, other: T); + } + pub trait $TraitMaybe { + fn $fn_maybe(self, other: T); + } + impl $TraitMaybe for X + where + X: $TraitExpr, + { + fn $fn_maybe(self, other: T) { + >::$fn(self, other) + } + } + impl $TraitMaybe for &mut X + where + X: $TraitOrig, + { + fn $fn_maybe(self, other: T) { + >::$fn_orig(self, other) + } + } + }; +} + ops_trait!(MinMaxExpr[MinMaxThis] { - fn max[_max](self, other: T); - fn min[_min](self, other: T); + type Output; + + fn max_expr[_max_expr](self, other: T); + fn min_expr[_min_expr](self, other: T); }); ops_trait!(ClampExpr[ClampThis] { @@ -60,22 +178,32 @@ pub trait AbsExpr { fn abs(&self) -> Self; } -ops_trait!(EqExpr[EqThis] { +ops_trait!(EqExpr[EqThis, PartialEq(bool) => EqMaybeExpr] { type Output; - fn eq[_eq](self, other: T); - fn ne[_ne](self, other: T); + fn eq[_eq, self == other => __eq](self, other: T); + fn ne[_ne, self != other => __ne](self, other: T); }); -ops_trait!(CmpExpr[CmpThis] { +ops_trait!(CmpExpr[CmpThis, PartialOrd(bool) => CmpMaybeExpr] { type Output; - fn lt[_lt](self, other: T); - fn le[_le](self, other: T); - fn gt[_gt](self, other: T); - fn ge[_ge](self, other: T); + fn lt[_lt, self < other => __lt](self, other: T); + fn le[_le, self <= other => __le](self, other: T); + fn gt[_gt, self > other => __gt](self, other: T); + fn ge[_ge, self >= other => __ge](self, other: T); }); +simple_binop_trait!(AddExpr[AddThis, Add => AddMaybeExpr]: add[_add, add => __add]); +simple_binop_trait!(SubExpr[SubThis, Sub => SubMaybeExpr]: sub[_sub, sub => __sub]); +simple_binop_trait!(MulExpr[MulThis, Mul => MulMaybeExpr]: mul[_mul, mul => __mul]); +simple_binop_trait!(DivExpr[DivThis, Div => DivMaybeExpr]: div[_div, div => __div]); +simple_binop_trait!(RemExpr[RemThis, Rem => RemMaybeExpr]: rem[_rem, rem => __rem]); +simple_binop_trait!(BitAndExpr[BitAndThis, BitAnd => BitAndMaybeExpr]: bitand[_bitand, bitand => __bitand]); +simple_binop_trait!(BitOrExpr[BitOrThis, BitOr => BitOrMaybeExpr]: bitor[_bitor, bitor => __bitor]); +simple_binop_trait!(BitXorExpr[BitXorThis, BitXor => BitXorMaybeExpr]: bitxor[_bitxor, bitxor => __bitxor]); +simple_binop_trait!(ShlExpr[ShlThis, Shl => ShlMaybeExpr]: shl[_shl, shl => __shl]); +simple_binop_trait!(ShrExpr[ShrThis, Shr => ShrMaybeExpr]: shr[_shr, shr => __shr]); pub trait IntExpr { fn rotate_right(&self, n: Expr) -> Self; fn rotate_left(&self, n: Expr) -> Self; @@ -156,6 +284,19 @@ ops_trait!(FloatLerpExpr[FloatLerpThis] { fn lerp[_lerp](self, other: A, frac: B); }); +assignop_trait!(AddAssignExpr[AddAssign => AddAssignMaybeExpr]: add_assign[add_assign => __add_assign]); +assignop_trait!(SubAssignExpr[SubAssign => SubAssignMaybeExpr]: sub_assign[sub_assign => __sub_assign]); +assignop_trait!(MulAssignExpr[MulAssign => MulAssignMaybeExpr]: mul_assign[mul_assign => __mul_assign]); +assignop_trait!(DivAssignExpr[DivAssign => DivAssignMaybeExpr]: div_assign[div_assign => __div_assign]); +assignop_trait!(RemAssignExpr[RemAssign => RemAssignMaybeExpr]: rem_assign[rem_assign => __rem_assign]); +assignop_trait!(BitAndAssignExpr[BitAndAssign => BitAndAssignMaybeExpr]: bitand_assign[bitand_assign => __bitand_assign]); +assignop_trait!(BitOrAssignExpr[BitOrAssign => BitOrAssignMaybeExpr]: bitor_assign[bitor_assign => __bitor_assign]); +assignop_trait!(BitXorAssignExpr[BitXorAssign => BitXorAssignMaybeExpr]: bitxor_assign[bitxor_assign => __bitxor_assign]); +assignop_trait!(ShlAssignExpr[ShlAssign => ShlAssignMaybeExpr]: shl_assign[shl_assign => __shl_assign]); +assignop_trait!(ShrAssignExpr[ShrAssign => ShrAssignMaybeExpr]: shr_assign[shr_assign => __shr_assign]); + +// Traits for track!. + pub trait StoreMaybeExpr { fn store(self, value: V); } @@ -173,22 +314,8 @@ pub trait LoopMaybeExpr { fn while_loop(cond: impl FnMut() -> Self, body: impl FnMut()); } -pub trait LazyBoolMaybeExpr { +pub trait LazyBoolMaybeExpr { type Bool; fn and(self, other: impl FnOnce() -> T) -> Self::Bool; fn or(self, other: impl FnOnce() -> T) -> Self::Bool; } - -pub trait EqMaybeExpr { - type Bool; - fn __eq(self, other: T) -> Self::Bool; - fn __ne(self, other: T) -> Self::Bool; -} - -pub trait CmpMaybeExpr { - type Bool; - fn __lt(self, other: T) -> Self::Bool; - fn __le(self, other: T) -> Self::Bool; - fn __gt(self, other: T) -> Self::Bool; - fn __ge(self, other: T) -> Self::Bool; -} diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index 3b663d2..e40d2d8 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -53,6 +53,7 @@ pub trait ExprProxy: Copy + 'static { type Value: Value; fn from_expr(expr: Expr) -> Self; + fn as_expr_from_proxy(&self) -> &Expr; } /// A trait for implementing remote impls on top of an [`Var`] using [`Deref`]. @@ -62,6 +63,7 @@ pub trait ExprProxy: Copy + 'static { /// impls. pub trait VarProxy: Copy + 'static { type Value: Value; + fn as_var_from_proxy(&self) -> &Var; fn from_var(expr: Var) -> Self; } @@ -240,16 +242,32 @@ impl Var { b.local_zero_init(::type_()) })) } - pub fn load(&self) -> Expr { - __current_scope(|b| { - let nodes = self.to_vec_nodes(); - let mut ret = vec![]; - for node in nodes { - ret.push(b.call(Func::Load, &[node], node.type_().clone())); + pub fn _ref<'a>(self) -> &'a Self { + RECORDER.with(|r| { + let r = r.borrow(); + let v: &Var = r.arena.alloc(self); + unsafe { + let v: &'a Var = std::mem::transmute(v); + v } - Expr::::from_nodes(&mut ret.into_iter()) }) } + pub fn load(&self) -> Expr { + Expr::::from_nodes( + &mut __current_scope(|b| { + let nodes = self.to_vec_nodes(); + let mut ret = vec![]; + for node in nodes { + ret.push(b.call(Func::Load, &[node], node.type_().clone())); + } + ret + }) + .into_iter(), + ) + } + pub fn store(&self, value: impl AsExpr) { + crate::lang::_store(self, &value.as_expr()); + } } pub fn _deref_proxy(proxy: &P) -> &Expr { @@ -269,6 +287,9 @@ macro_rules! impl_simple_expr_proxy { fn from_expr(expr: $crate::lang::types::Expr<$t>) -> Self { Self(expr) } + fn as_expr_from_proxy(&self) -> &$crate::lang::types::Expr<$t> { + &self.0 + } } } } @@ -284,6 +305,9 @@ macro_rules! impl_simple_var_proxy { fn from_var(var: $crate::lang::types::Var<$t>) -> Self { Self(var) } + fn as_var_from_proxy(&self) -> &$crate::lang::types::Var<$t> { + &self.0 + } } impl $(< $($bounds)* >)? std::ops::Deref for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? { type Target = $crate::lang::types::Expr<$t>; @@ -355,7 +379,7 @@ pub trait AsExpr: Tracked { } impl AsExpr for T { - fn as_expr(&self) -> Expr { + fn as_expr(&self) -> Expr { self.expr() } } diff --git a/luisa_compute/src/lang/types/alignment.rs b/luisa_compute/src/lang/types/alignment.rs index 466caeb..04f4e6f 100644 --- a/luisa_compute/src/lang/types/alignment.rs +++ b/luisa_compute/src/lang/types/alignment.rs @@ -1,4 +1,3 @@ -use super::*; use std::hash::Hash; pub trait Alignment: Default + Copy + Hash + Eq + 'static { diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs index aed9554..f282759 100644 --- a/luisa_compute/src/lang/types/array.rs +++ b/luisa_compute/src/lang/types/array.rs @@ -24,7 +24,9 @@ impl Index for ArrayExpr { let i = i.to_u64(); // TODO: Add need_runtime_check()? - lc_assert!(i.lt((N as u64).expr())); + if need_runtime_check() { + lc_assert!(i.lt((N as u64).expr())); + } Expr::::from_node(__current_scope(|b| { b.call(Func::ExtractElement, &[self.0.node, i.node()], T::type_()) diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index aabdbd7..69164ca 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -14,7 +14,8 @@ mod element; mod impls; pub mod swizzle; -use swizzle::*; +pub use impls::*; +pub use swizzle::*; pub trait VectorElement: VectorAlign<2> + VectorAlign<3> + VectorAlign<4> {} impl + VectorAlign<3> + VectorAlign<4>> VectorElement for T {} @@ -69,78 +70,18 @@ impl FromNode for DoubledProxyData { } } -pub trait VectorExprProxy { - const N: usize; - type T: Primitive; - fn node(&self) -> NodeRef; - fn _permute2(&self, x: u32, y: u32) -> Expr> - where - Self::T: VectorAlign<2>, - { - assert!(x < Self::N as u32); - assert!(y < Self::N as u32); - let x = x.expr(); - let y = y.expr(); - Expr::>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[self.node(), x.node(), y.node()], - Vec2::::type_(), - ) - })) - } - fn _permute3(&self, x: u32, y: u32, z: u32) -> Expr> - where - Self::T: VectorAlign<3>, - { - assert!(x < Self::N as u32); - assert!(y < Self::N as u32); - assert!(z < Self::N as u32); - let x = x.expr(); - let y = y.expr(); - let z = z.expr(); - Expr::>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[self.node(), x.node(), y.node(), z.node()], - Vec3::::type_(), - ) - })) - } - fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Expr> - where - Self::T: VectorAlign<4>, - { - assert!(x < Self::N as u32); - assert!(y < Self::N as u32); - assert!(z < Self::N as u32); - assert!(w < Self::N as u32); - let x = x.expr(); - let y = y.expr(); - let z = z.expr(); - let w = w.expr(); - Expr::>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[self.node(), x.node(), y.node(), z.node(), w.node()], - Vec4::::type_(), - ) - })) - } -} - macro_rules! vector_proxies { ($N:literal [ $($c:ident),* ]: $ExprName:ident, $VarName:ident) => { #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct $ExprName> { - _node: NodeRef, + self_: Expr>, $(pub $c: Expr),* } #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct $VarName> { - _node: NodeRef, + self_: Var>, $(pub $c: Var),* } @@ -156,16 +97,19 @@ macro_rules! vector_proxies { if i >= $N { i = 0; } )* Self{ - _node: e.node(), + self_: e, $($c),* } } + fn as_expr_from_proxy(&self)->&Expr { + &self.self_ + } } impl> VectorExprProxy for $ExprName { const N: usize = $N; type T = T; fn node(&self) -> NodeRef { - self._node + self.self_.node() } } impl>> VarProxy for $VarName { @@ -180,10 +124,13 @@ macro_rules! vector_proxies { if i >= $N { i = 0; } )* Self{ - _node: e.node(), + self_: e, $($c),* } } + fn as_var_from_proxy(&self)->&Var { + &self.self_ + } } impl>> Deref for $VarName { type Target = Expr>; diff --git a/luisa_compute/src/lang/types/vector/impls.rs b/luisa_compute/src/lang/types/vector/impls.rs index 40c3db1..1cad7fd 100644 --- a/luisa_compute/src/lang/types/vector/impls.rs +++ b/luisa_compute/src/lang/types/vector/impls.rs @@ -1,4 +1,6 @@ use super::*; +use crate::lang::index::IntoIndex; +use std::ops::Index; impl, const N: usize> From<[T; N]> for Vector { fn from(elements: [T; N]) -> Self { @@ -56,7 +58,7 @@ where } macro_rules! impl_sized { - ($Vn:ident($N: literal): $($xs:ident),+) => { + ($Vn:ident($N: literal), $Vexpr:ident, $Vvar:ident : $($xs:ident),+) => { impl> $Vn { pub fn new($($xs: T),+) -> Self { Self { @@ -68,8 +70,125 @@ macro_rules! impl_sized { Self::expr_from_elements([$($xs.as_expr()),+]) } } + impl> $Vexpr { + pub fn dot(&self, other: impl AsExpr>) -> Expr { + Expr::::from_node(__current_scope(|s| { + s.call( + Func::Dot, + &[self.node(), other.as_expr().node()], + T::type_(), + ) + })) + } + } + impl, X: IntoIndex> Index for $Vexpr { + type Output = Expr; + fn index(&self, i: X) -> &Self::Output { + let i = i.to_u64(); + + if need_runtime_check() { + lc_assert!(i.lt(($N as u64).expr())); + } + + Expr::::from_node(__current_scope(|s| { + s.call( + Func::ExtractElement, + &[self.node(), i.node()], + T::type_(), + ) + }))._ref() + } + } + impl, X: IntoIndex> Index for $Vvar { + type Output = Var; + fn index(&self, i: X) -> &Self::Output { + let i = i.to_u64(); + + if need_runtime_check() { + lc_assert!(i.lt(($N as u64).expr())); + } + + Var::::from_node(__current_scope(|s| { + s.call( + Func::GetElementPtr, + &[self.self_.node(), i.node()], + T::type_(), + ) + }))._ref() + } + } + } +} +impl_sized!(Vec2(2), VectorExprProxy2, VectorVarProxy2: x, y); +impl_sized!(Vec3(3), VectorExprProxy3, VectorVarProxy3: x, y, z); +impl_sized!(Vec4(4), VectorExprProxy4, VectorVarProxy4: x, y, z, w); + +pub trait VectorExprProxy { + const N: usize; + type T: Primitive; + fn node(&self) -> NodeRef; + fn _permute2(&self, x: u32, y: u32) -> Expr> + where + Self::T: VectorAlign<2>, + { + assert!(x < Self::N as u32); + assert!(y < Self::N as u32); + let x = x.expr(); + let y = y.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node(), x.node(), y.node()], + Vec2::::type_(), + ) + })) + } + fn _permute3(&self, x: u32, y: u32, z: u32) -> Expr> + where + Self::T: VectorAlign<3>, + { + assert!(x < Self::N as u32); + assert!(y < Self::N as u32); + assert!(z < Self::N as u32); + let x = x.expr(); + let y = y.expr(); + let z = z.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node(), x.node(), y.node(), z.node()], + Vec3::::type_(), + ) + })) + } + fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Expr> + where + Self::T: VectorAlign<4>, + { + assert!(x < Self::N as u32); + assert!(y < Self::N as u32); + assert!(z < Self::N as u32); + assert!(w < Self::N as u32); + let x = x.expr(); + let y = y.expr(); + let z = z.expr(); + let w = w.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node(), x.node(), y.node(), z.node(), w.node()], + Vec4::::type_(), + ) + })) + } + fn length(&self) -> Expr { + Expr::::from_node(__current_scope(|s| { + s.call(Func::Length, &[self.node()], Self::T::type_()) + })) + } + fn length_squared(&self) -> Expr { + Expr::::from_node(__current_scope(|s| { + s.call(Func::LengthSquared, &[self.node()], Self::T::type_()) + })) } } -impl_sized!(Vec2(2): x, y); -impl_sized!(Vec3(3): x, y, z); -impl_sized!(Vec4(4): x, y, z, w); diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 8f5bec4..0edd4d1 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -13,6 +13,8 @@ pub mod resource; pub mod rtx; pub mod runtime; +pub use crate::lang::ops::{max, min}; + pub mod prelude { pub use half::f16; @@ -21,12 +23,17 @@ pub mod prelude { }; pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, set_block_size}; pub use crate::lang::index::{IndexRead, IndexWrite}; - pub use crate::lang::ops::*; - pub use crate::lang::types::vector::alias::*; - pub use crate::lang::types::vector::swizzle::*; - pub use crate::lang::types::vector::{ - Mat2, Mat3, Mat4, SquareMatrix, Vec2, Vec3, Vec4, Vector, + pub use crate::lang::ops::{ + AbsExpr, ActivateMaybeExpr, AddAssignExpr, AddExpr, BitAndAssignExpr, BitAndExpr, + BitOrAssignExpr, BitOrExpr, BitXorAssignExpr, BitXorExpr, ClampExpr, CmpExpr, + DivAssignExpr, DivExpr, EqExpr, FloatArcTan2Expr, FloatCopySignExpr, FloatExpr, + FloatLerpExpr, FloatLogExpr, FloatMulAddExpr, FloatPowfExpr, FloatPowiExpr, + FloatSmoothStepExpr, FloatStepExpr, IntExpr, LazyBoolMaybeExpr, LoopMaybeExpr, MinMaxExpr, + MulAssignExpr, MulExpr, RemAssignExpr, RemExpr, SelectMaybeExpr, ShlAssignExpr, ShlExpr, + ShrAssignExpr, ShrExpr, SubAssignExpr, SubExpr, }; + pub use crate::lang::types::vector::swizzle::*; + pub use crate::lang::types::vector::VectorExprProxy; pub use crate::lang::types::{AsExpr, Expr, Value, Var}; pub use crate::lang::Aggregate; pub use crate::resource::{IoTexel, StorageTexel, *}; @@ -47,6 +54,8 @@ mod internal_prelude { new_node, register_type, BasicBlock, Const, Func, Instruction, IrBuilder, Node, PhiIncoming, Pooled, Type, TypeOf, INVALID_REF, }; + pub(crate) use crate::lang::ops::Linear; + pub(crate) use crate::lang::types::vector::alias::*; pub(crate) use crate::lang::types::vector::*; pub(crate) use crate::lang::{ ir, CallFuncTrait, Recorder, __compose, __extract, __insert, __module_pools, diff --git a/luisa_compute/src/printer.rs b/luisa_compute/src/printer.rs index a41b988..f0520f6 100644 --- a/luisa_compute/src/printer.rs +++ b/luisa_compute/src/printer.rs @@ -128,14 +128,15 @@ impl Printer { let item_id = items.len() as u32; if_!( - offset.lt(data.len().cast::()) - & (offset + 1 + args.count as u32).le(data.len().cast::()), + offset + .lt(data.len().cast::()) + .bitand((offset.add(1 + args.count as u32)).le(data.len().cast::())), { data.atomic_fetch_add(0, 1); data.write(offset, item_id); let mut cnt = 0; for (i, pack_fn) in args.pack_fn.iter().enumerate() { - pack_fn(offset + 1 + cnt, &data); + pack_fn(offset.add(1 + cnt), &data); cnt += args.count_per_arg[i] as u32; } } diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index ffbd899..3aba148 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -1274,6 +1274,12 @@ impl Tex2d { pub fn format(&self) -> PixelFormat { self.handle.format } + pub fn read(&self, uv: impl Into>) -> Expr { + self.var().read(uv) + } + pub fn write(&self, uv: impl Into>, v: impl Into>) { + self.var().write(uv, v) + } } impl Tex3d { pub fn view(&self, level: u32) -> Tex3dView { @@ -1294,6 +1300,12 @@ impl Tex3d { pub fn format(&self) -> PixelFormat { self.handle.format } + pub fn read(&self, uv: impl Into>) -> Expr { + self.var().read(uv) + } + pub fn write(&self, uv: impl Into>, v: impl Into>) { + self.var().write(uv, v) + } } #[derive(Clone)] pub struct BufferVar { @@ -1680,13 +1692,9 @@ impl IndexRead for BufferVar { if need_runtime_check() { lc_assert!(i.lt(self.len())); } - __current_scope(|b| { - FromNode::from_node(b.call( - Func::BufferRead, - &[self.node, ToNode::node(&i)], - T::type_(), - )) - }) + Expr::::from_node(__current_scope(|b| { + b.call(Func::BufferRead, &[self.node, ToNode::node(&i)], T::type_()) + })) } } impl IndexWrite for BufferVar { diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index 7e6589d..776508c 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -278,6 +278,9 @@ impl Compiler { } } + fn as_expr_from_proxy(&self) -> &#lang_path::types::Expr<#name> { + &self.self_ + } } // #[allow(unused_parens)] // impl #impl_generics #lang_path::FromNode for #var_proxy_name #ty_generics #where_clause { @@ -303,6 +306,9 @@ impl Compiler { #(#field_names),* } } + fn as_var_from_proxy(&self) -> &#lang_path::types::Var<#name> { + &self.self_ + } } #[allow(unused_parens)] impl #impl_generics std::ops::Deref for #var_proxy_name #ty_generics #where_clause { diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs index f36fbbc..c7a2b8d 100644 --- a/luisa_compute_track/src/lib.rs +++ b/luisa_compute_track/src/lib.rs @@ -54,7 +54,7 @@ impl VisitMut for TraceVisitor { }) = &**left { *node = parse_quote_spanned! {span=> - <_ as #trait_path::StoreMaybeExpr>::store(#expr, #right) + <_ as #trait_path::StoreMaybeExpr<_>>::store(#expr, #right) } } } @@ -97,36 +97,113 @@ impl VisitMut for TraceVisitor { } } Expr::Binary(expr) => { + let left = &expr.left; + let right = &expr.right; + + if let Expr::Unary(ExprUnary { + op: UnOp::Deref(_), + expr: left, + .. + }) = &**left + { + let op_fn_str = match &expr.op { + BinOp::AddAssign(_) => "__add_assign", + BinOp::SubAssign(_) => "__sub_assign", + BinOp::MulAssign(_) => "__mul_assign", + BinOp::DivAssign(_) => "__div_assign", + BinOp::RemAssign(_) => "__rem_assign", + BinOp::BitAndAssign(_) => "__bitand_assign", + BinOp::BitOrAssign(_) => "__bitor_assign", + BinOp::BitXorAssign(_) => "__bitxor_assign", + BinOp::ShlAssign(_) => "__shl_assign", + BinOp::ShrAssign(_) => "__shr_assign", + _ => "", + }; + let op_trait_str = match &expr.op { + BinOp::AddAssign(_) => "AddAssignMaybeExpr", + BinOp::SubAssign(_) => "SubAssignMaybeExpr", + BinOp::MulAssign(_) => "MulAssignMaybeExpr", + BinOp::DivAssign(_) => "DivAssignMaybeExpr", + BinOp::RemAssign(_) => "RemAssignMaybeExpr", + BinOp::BitAndAssign(_) => "BitAndAssignMaybeExpr", + BinOp::BitOrAssign(_) => "BitOrAssignMaybeExpr", + BinOp::BitXorAssign(_) => "BitXorAssignMaybeExpr", + BinOp::ShlAssign(_) => "ShlAssignMaybeExpr", + BinOp::ShrAssign(_) => "ShrAssignMaybeExpr", + _ => "", + }; + if !op_fn_str.is_empty() { + let op_fn = Ident::new(op_fn_str, expr.op.span()); + let op_trait = Ident::new(op_trait_str, expr.op.span()); + *node = parse_quote_spanned! {span=> + <_ as #trait_path::#op_trait<_, _>>::#op_fn(#left, #right) + }; + visit_expr_mut(self, node); + return; + } + } + let left = if let Expr::Paren(ExprParen { expr, .. }) = &**left { + expr + } else { + left + }; + let right = if let Expr::Paren(ExprParen { expr, .. }) = &**right { + expr + } else { + right + }; let op_fn_str = match &expr.op { - BinOp::Eq(_) => "eq", - BinOp::Ne(_) => "ne", + BinOp::Add(_) => "__add", + BinOp::Sub(_) => "__sub", + BinOp::Mul(_) => "__mul", + BinOp::Div(_) => "__div", + BinOp::Rem(_) => "__rem", + BinOp::BitAnd(_) => "__bitand", + BinOp::BitOr(_) => "__bitor", + BinOp::BitXor(_) => "__bitxor", + BinOp::Shl(_) => "__shl", + BinOp::Shr(_) => "__shr", BinOp::And(_) => "and", BinOp::Or(_) => "or", - BinOp::Lt(_) => "lt", - BinOp::Le(_) => "le", - BinOp::Ge(_) => "ge", - BinOp::Gt(_) => "gt", + BinOp::Eq(_) => "__eq", + BinOp::Ne(_) => "__ne", + BinOp::Lt(_) => "__lt", + BinOp::Le(_) => "__le", + BinOp::Ge(_) => "__ge", + BinOp::Gt(_) => "__gt", + + _ => "", + }; + let op_trait_str = match &expr.op { + BinOp::Add(_) => "AddMaybeExpr", + BinOp::Sub(_) => "SubMaybeExpr", + BinOp::Mul(_) => "MulMaybeExpr", + BinOp::Div(_) => "DivMaybeExpr", + BinOp::Rem(_) => "RemMaybeExpr", + BinOp::BitAnd(_) => "BitAndMaybeExpr", + BinOp::BitOr(_) => "BitOrMaybeExpr", + BinOp::BitXor(_) => "BitXorMaybeExpr", + BinOp::Shl(_) => "ShlMaybeExpr", + BinOp::Shr(_) => "ShrMaybeExpr", + BinOp::And(_) | BinOp::Or(_) => "LazyBoolMaybeExpr", + BinOp::Eq(_) | BinOp::Ne(_) => "EqMaybeExpr", + BinOp::Lt(_) | BinOp::Le(_) | BinOp::Ge(_) | BinOp::Gt(_) => "CmpMaybeExpr", _ => "", }; if !op_fn_str.is_empty() { - let left = &expr.left; - let right = &expr.right; let op_fn = Ident::new(op_fn_str, expr.op.span()); - if op_fn_str == "eq" || op_fn_str == "ne" { + let op_trait = Ident::new(op_trait_str, expr.op.span()); + if let BinOp::And(_) | BinOp::Or(_) = &expr.op { *node = parse_quote_spanned! {span=> - <_ as #trait_path::EqMaybeExpr<_, _>>::#op_fn(#left, #right) - } - } else if op_fn_str == "and" || op_fn_str == "or" { - *node = parse_quote_spanned! {span=> - <_ as #trait_path::LazyBoolMaybeExpr<_>>::#op_fn(#left, || #right) - } + <_ as #trait_path::#op_trait<_, _>>::#op_fn(#left, || #right) + }; } else { *node = parse_quote_spanned! {span=> - <_ as #trait_path::CmpMaybeExpr<_, _>>::#op_fn(#left, #right) - } + <_ as #trait_path::#op_trait<_, _>>::#op_fn(#left, #right) + }; } } }