From df05a22b9bf3eff8a69834602d689f1fe138f05b Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Mar 2024 20:05:12 +0000 Subject: [PATCH 1/2] Added rem_euclid. Also fixed clamp, and implemented bool -> int convs. --- luisa_compute/src/lang/ops/cast_impls.rs | 73 +++++++++++++++ luisa_compute/src/lang/ops/gen_cast.py | 109 +++++++++++++++++------ luisa_compute/src/lang/ops/impls.rs | 12 ++- luisa_compute/src/lang/ops/spread.rs | 16 +++- luisa_compute/src/lang/ops/traits.rs | 6 ++ luisa_compute_sys/LuisaCompute | 2 +- 6 files changed, 188 insertions(+), 30 deletions(-) diff --git a/luisa_compute/src/lang/ops/cast_impls.rs b/luisa_compute/src/lang/ops/cast_impls.rs index b6c34864..97fe7ef8 100644 --- a/luisa_compute/src/lang/ops/cast_impls.rs +++ b/luisa_compute/src/lang/ops/cast_impls.rs @@ -1,3 +1,4 @@ + #[rustfmt::skip]mod impl_{ use crate::prelude::*; use super::super::*; @@ -970,4 +971,76 @@ impl Expr { pub fn as_byte4(self) -> Expr { self.as_::() } pub fn cast_i8(self) -> Expr { self.as_::() } } +impl Expr { + pub fn as_i32(self) -> Expr { self.as_::() } + pub fn cast_i32(self) -> Expr { self.as_::() } + pub fn as_u32(self) -> Expr { self.as_::() } + pub fn cast_u32(self) -> Expr { self.as_::() } + pub fn as_i64(self) -> Expr { self.as_::() } + pub fn cast_i64(self) -> Expr { self.as_::() } + pub fn as_u64(self) -> Expr { self.as_::() } + pub fn cast_u64(self) -> Expr { self.as_::() } + pub fn as_i16(self) -> Expr { self.as_::() } + pub fn cast_i16(self) -> Expr { self.as_::() } + pub fn as_u16(self) -> Expr { self.as_::() } + pub fn cast_u16(self) -> Expr { self.as_::() } + pub fn as_i8(self) -> Expr { self.as_::() } + pub fn cast_i8(self) -> Expr { self.as_::() } + pub fn as_u8(self) -> Expr { self.as_::() } + pub fn cast_u8(self) -> Expr { self.as_::() } +} +impl Expr { + pub fn as_int2(self) -> Expr { self.as_::() } + pub fn cast_i32(self) -> Expr { self.as_::() } + pub fn as_uint2(self) -> Expr { self.as_::() } + pub fn cast_u32(self) -> Expr { self.as_::() } + pub fn as_long2(self) -> Expr { self.as_::() } + pub fn cast_i64(self) -> Expr { self.as_::() } + pub fn as_ulong2(self) -> Expr { self.as_::() } + pub fn cast_u64(self) -> Expr { self.as_::() } + pub fn as_short2(self) -> Expr { self.as_::() } + pub fn cast_i16(self) -> Expr { self.as_::() } + pub fn as_ushort2(self) -> Expr { self.as_::() } + pub fn cast_u16(self) -> Expr { self.as_::() } + pub fn as_byte2(self) -> Expr { self.as_::() } + pub fn cast_i8(self) -> Expr { self.as_::() } + pub fn as_ubyte2(self) -> Expr { self.as_::() } + pub fn cast_u8(self) -> Expr { self.as_::() } +} +impl Expr { + pub fn as_int3(self) -> Expr { self.as_::() } + pub fn cast_i32(self) -> Expr { self.as_::() } + pub fn as_uint3(self) -> Expr { self.as_::() } + pub fn cast_u32(self) -> Expr { self.as_::() } + pub fn as_long3(self) -> Expr { self.as_::() } + pub fn cast_i64(self) -> Expr { self.as_::() } + pub fn as_ulong3(self) -> Expr { self.as_::() } + pub fn cast_u64(self) -> Expr { self.as_::() } + pub fn as_short3(self) -> Expr { self.as_::() } + pub fn cast_i16(self) -> Expr { self.as_::() } + pub fn as_ushort3(self) -> Expr { self.as_::() } + pub fn cast_u16(self) -> Expr { self.as_::() } + pub fn as_byte3(self) -> Expr { self.as_::() } + pub fn cast_i8(self) -> Expr { self.as_::() } + pub fn as_ubyte3(self) -> Expr { self.as_::() } + pub fn cast_u8(self) -> Expr { self.as_::() } +} +impl Expr { + pub fn as_int4(self) -> Expr { self.as_::() } + pub fn cast_i32(self) -> Expr { self.as_::() } + pub fn as_uint4(self) -> Expr { self.as_::() } + pub fn cast_u32(self) -> Expr { self.as_::() } + pub fn as_long4(self) -> Expr { self.as_::() } + pub fn cast_i64(self) -> Expr { self.as_::() } + pub fn as_ulong4(self) -> Expr { self.as_::() } + pub fn cast_u64(self) -> Expr { self.as_::() } + pub fn as_short4(self) -> Expr { self.as_::() } + pub fn cast_i16(self) -> Expr { self.as_::() } + pub fn as_ushort4(self) -> Expr { self.as_::() } + pub fn cast_u16(self) -> Expr { self.as_::() } + pub fn as_byte4(self) -> Expr { self.as_::() } + pub fn cast_i8(self) -> Expr { self.as_::() } + pub fn as_ubyte4(self) -> Expr { self.as_::() } + pub fn cast_u8(self) -> Expr { self.as_::() } +} } diff --git a/luisa_compute/src/lang/ops/gen_cast.py b/luisa_compute/src/lang/ops/gen_cast.py index a6810dbf..5e4e6814 100644 --- a/luisa_compute/src/lang/ops/gen_cast.py +++ b/luisa_compute/src/lang/ops/gen_cast.py @@ -1,38 +1,95 @@ from typing import List from itertools import permutations, product -prims = ['f32', 'i32', 'u32', 'f64', 'i64', 'u64', 'f16', 'i16', 'u16', 'i8', 'u8'] -file = open('cast_impls.rs', 'w') -print('\n#[rustfmt::skip]mod impl_{\nuse crate::prelude::*;\nuse super::super::*;\n', file=file) +prims = ["f32", "i32", "u32", "f64", "i64", "u64", "f16", "i16", "u16", "i8", "u8"] +bool_convs = ["i32", "u32", "i64", "u64", "i16", "u16", "i8", "u8"] +file = open("cast_impls.rs", "w") +print( + "\n#[rustfmt::skip]mod impl_{\nuse crate::prelude::*;\nuse super::super::*;\n", + file=file, +) v_name = { - 'f32':'float', - 'i32':'int', - 'u32':'uint', - 'f64':'double', - 'i64':'long', - 'u64':'ulong', - 'f16':'half', - 'i16':'short', - 'u16':'ushort', - 'i8':'byte', - 'u8':'ubyte' + "f32": "float", + "i32": "int", + "u32": "uint", + "f64": "double", + "i64": "long", + "u64": "ulong", + "f16": "half", + "i16": "short", + "u16": "ushort", + "i8": "byte", + "u8": "ubyte", } + + def make_typename(t): t = list(t) - t[0] = t[0].upper() - return ''.join(t) + t[0] = t[0].upper() + return "".join(t) + + for p in prims: - print('impl Expr<{}> {{'.format(p), file=file) + print("impl Expr<{}> {{".format(p), file=file) for q in prims: if p != q: - print(' pub fn as_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}'.format(q), file=file) - print(' pub fn cast_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}'.format(q), file=file) - print('}', file=file) - for n in [2,3,4]: - print('impl Expr<{}{}> {{'.format(make_typename(v_name[p]),n), file=file) + print( + " pub fn as_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format( + q + ), + file=file, + ) + print( + " pub fn cast_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format( + q + ), + file=file, + ) + print("}", file=file) + for n in [2, 3, 4]: + print("impl Expr<{}{}> {{".format(make_typename(v_name[p]), n), file=file) for q in prims: if p != q: - print(' pub fn as_{2}{1}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}'.format(make_typename(v_name[q]),n, v_name[q]), file=file) - print(' pub fn cast_{3}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}'.format(make_typename(v_name[q]),n, v_name[q], q), file=file) - print('}', file=file) -print('}', file=file) \ No newline at end of file + print( + " pub fn as_{2}{1}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format( + make_typename(v_name[q]), n, v_name[q] + ), + file=file, + ) + print( + " pub fn cast_{3}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format( + make_typename(v_name[q]), n, v_name[q], q + ), + file=file, + ) + print("}", file=file) + +print("impl Expr {", file=file) +for q in bool_convs: + print( + " pub fn as_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(q), + file=file, + ) + print( + " pub fn cast_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(q), + file=file, + ) +print("}", file=file) +for n in [2, 3, 4]: + print("impl Expr<{}{}> {{".format(make_typename("bool"), n), file=file) + for q in bool_convs: + print( + " pub fn as_{2}{1}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format( + make_typename(v_name[q]), n, v_name[q] + ), + file=file, + ) + print( + " pub fn cast_{3}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format( + make_typename(v_name[q]), n, v_name[q], q + ), + file=file, + ) + print("}", file=file) + +print("}", file=file) diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index 0d8fd89a..fdb3f194 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -67,12 +67,22 @@ macro_rules! impl_simple_binop { } impl_ops_trait!([X: Linear] MinMaxExpr[MinMaxThis] for Expr where [X::Scalar: Numeric] { - type Output = Expr>; + type Output = Expr; fn max_[_max_](self, other) { Func::Max.call2(self, other) } fn min_[_min_](self, other) { Func::Min.call2(self, other) } }); +impl_ops_trait!([X: Linear] RemEuclidExpr[RemEuclidThis] for Expr where [X::Scalar: Numeric] { + type Output = Expr; + + fn rem_euclid[_rem_euclid](self, other) { + track! { + ((self % other) + other) % other + } + } +}); + impl_ops_trait!([X: Linear] ClampExpr[ClampThis] for Expr where [X::Scalar: Numeric] { fn clamp[_clamp](self, min, max) { Func::Clamp.call3(self, min, max) } }); diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs index 9b55ba2c..3e4208b0 100644 --- a/luisa_compute/src/lang/ops/spread.rs +++ b/luisa_compute/src/lang/ops/spread.rs @@ -221,6 +221,17 @@ where } } +impl RemEuclidExpr for T +where + T: SpreadOps, + Expr: RemEuclidThis, +{ + type Output = as RemEuclidThis>::Output; + fn rem_euclid(self, other: S) -> Self::Output { + Expr::::_rem_euclid(Self::lift_self(self), Self::lift_other(other)) + } +} + pub fn min(x: T, y: S) -> >::Output where T: MinMaxExpr, @@ -236,12 +247,13 @@ where impl ClampExpr for Expr where - S: SpreadOps, + S: SpreadOps, + U: SpreadOps, Expr: ClampThis, { type Output = Expr; fn clamp(self, min: S, max: U) -> Self::Output { - Expr::::_clamp(self, S::lift_self(min), S::lift_other(max)) + Expr::::_clamp(self, S::lift_self(min), U::lift_self(max)) } } impl EqExpr for T diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index 5bb53238..d281eb51 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -170,6 +170,12 @@ ops_trait!(MinMaxExpr[MinMaxThis] { fn min_[_min_](self, other: T); }); +ops_trait!(RemEuclidExpr[RemEuclidThis] { + type Output; + + fn rem_euclid[_rem_euclid](self, other: T); +}); + ops_trait!(ClampExpr[ClampThis] { fn clamp[_clamp](self, min: A, max: B); }); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 37a02808..33ce816b 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 37a0280899fe80a1979431e3852cdc3952c03638 +Subproject commit 33ce816b586808a202bc5f4e5f0728ddd803934a From 4c63286d104040ce644825c38d44998fdac3ce18 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Mar 2024 20:35:37 +0000 Subject: [PATCH 2/2] Bugfix. --- luisa_compute/src/lang/ops/spread.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs index 3e4208b0..78e9b674 100644 --- a/luisa_compute/src/lang/ops/spread.rs +++ b/luisa_compute/src/lang/ops/spread.rs @@ -247,8 +247,8 @@ where impl ClampExpr for Expr where - S: SpreadOps, - U: SpreadOps, + S: SpreadOps, Join = T>, + U: SpreadOps, Join = T>, Expr: ClampThis, { type Output = Expr;