Skip to content

Commit

Permalink
Merge pull request #35 from iMplode-nZ/main
Browse files Browse the repository at this point in the history
Add rem_euclid, bool -> int convs, and fix clamp.
  • Loading branch information
shiinamiyuki authored Mar 20, 2024
2 parents 473e4bf + 4c63286 commit ed25171
Show file tree
Hide file tree
Showing 6 changed files with 188 additions and 30 deletions.
73 changes: 73 additions & 0 deletions luisa_compute/src/lang/ops/cast_impls.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

#[rustfmt::skip]mod impl_{
use crate::prelude::*;
use super::super::*;
Expand Down Expand Up @@ -970,4 +971,76 @@ impl Expr<Ubyte4> {
pub fn as_byte4(self) -> Expr<Byte4> { self.as_::<Byte4>() }
pub fn cast_i8(self) -> Expr<Byte4> { self.as_::<Byte4>() }
}
impl Expr<bool> {
pub fn as_i32(self) -> Expr<i32> { self.as_::<i32>() }
pub fn cast_i32(self) -> Expr<i32> { self.as_::<i32>() }
pub fn as_u32(self) -> Expr<u32> { self.as_::<u32>() }
pub fn cast_u32(self) -> Expr<u32> { self.as_::<u32>() }
pub fn as_i64(self) -> Expr<i64> { self.as_::<i64>() }
pub fn cast_i64(self) -> Expr<i64> { self.as_::<i64>() }
pub fn as_u64(self) -> Expr<u64> { self.as_::<u64>() }
pub fn cast_u64(self) -> Expr<u64> { self.as_::<u64>() }
pub fn as_i16(self) -> Expr<i16> { self.as_::<i16>() }
pub fn cast_i16(self) -> Expr<i16> { self.as_::<i16>() }
pub fn as_u16(self) -> Expr<u16> { self.as_::<u16>() }
pub fn cast_u16(self) -> Expr<u16> { self.as_::<u16>() }
pub fn as_i8(self) -> Expr<i8> { self.as_::<i8>() }
pub fn cast_i8(self) -> Expr<i8> { self.as_::<i8>() }
pub fn as_u8(self) -> Expr<u8> { self.as_::<u8>() }
pub fn cast_u8(self) -> Expr<u8> { self.as_::<u8>() }
}
impl Expr<Bool2> {
pub fn as_int2(self) -> Expr<Int2> { self.as_::<Int2>() }
pub fn cast_i32(self) -> Expr<Int2> { self.as_::<Int2>() }
pub fn as_uint2(self) -> Expr<Uint2> { self.as_::<Uint2>() }
pub fn cast_u32(self) -> Expr<Uint2> { self.as_::<Uint2>() }
pub fn as_long2(self) -> Expr<Long2> { self.as_::<Long2>() }
pub fn cast_i64(self) -> Expr<Long2> { self.as_::<Long2>() }
pub fn as_ulong2(self) -> Expr<Ulong2> { self.as_::<Ulong2>() }
pub fn cast_u64(self) -> Expr<Ulong2> { self.as_::<Ulong2>() }
pub fn as_short2(self) -> Expr<Short2> { self.as_::<Short2>() }
pub fn cast_i16(self) -> Expr<Short2> { self.as_::<Short2>() }
pub fn as_ushort2(self) -> Expr<Ushort2> { self.as_::<Ushort2>() }
pub fn cast_u16(self) -> Expr<Ushort2> { self.as_::<Ushort2>() }
pub fn as_byte2(self) -> Expr<Byte2> { self.as_::<Byte2>() }
pub fn cast_i8(self) -> Expr<Byte2> { self.as_::<Byte2>() }
pub fn as_ubyte2(self) -> Expr<Ubyte2> { self.as_::<Ubyte2>() }
pub fn cast_u8(self) -> Expr<Ubyte2> { self.as_::<Ubyte2>() }
}
impl Expr<Bool3> {
pub fn as_int3(self) -> Expr<Int3> { self.as_::<Int3>() }
pub fn cast_i32(self) -> Expr<Int3> { self.as_::<Int3>() }
pub fn as_uint3(self) -> Expr<Uint3> { self.as_::<Uint3>() }
pub fn cast_u32(self) -> Expr<Uint3> { self.as_::<Uint3>() }
pub fn as_long3(self) -> Expr<Long3> { self.as_::<Long3>() }
pub fn cast_i64(self) -> Expr<Long3> { self.as_::<Long3>() }
pub fn as_ulong3(self) -> Expr<Ulong3> { self.as_::<Ulong3>() }
pub fn cast_u64(self) -> Expr<Ulong3> { self.as_::<Ulong3>() }
pub fn as_short3(self) -> Expr<Short3> { self.as_::<Short3>() }
pub fn cast_i16(self) -> Expr<Short3> { self.as_::<Short3>() }
pub fn as_ushort3(self) -> Expr<Ushort3> { self.as_::<Ushort3>() }
pub fn cast_u16(self) -> Expr<Ushort3> { self.as_::<Ushort3>() }
pub fn as_byte3(self) -> Expr<Byte3> { self.as_::<Byte3>() }
pub fn cast_i8(self) -> Expr<Byte3> { self.as_::<Byte3>() }
pub fn as_ubyte3(self) -> Expr<Ubyte3> { self.as_::<Ubyte3>() }
pub fn cast_u8(self) -> Expr<Ubyte3> { self.as_::<Ubyte3>() }
}
impl Expr<Bool4> {
pub fn as_int4(self) -> Expr<Int4> { self.as_::<Int4>() }
pub fn cast_i32(self) -> Expr<Int4> { self.as_::<Int4>() }
pub fn as_uint4(self) -> Expr<Uint4> { self.as_::<Uint4>() }
pub fn cast_u32(self) -> Expr<Uint4> { self.as_::<Uint4>() }
pub fn as_long4(self) -> Expr<Long4> { self.as_::<Long4>() }
pub fn cast_i64(self) -> Expr<Long4> { self.as_::<Long4>() }
pub fn as_ulong4(self) -> Expr<Ulong4> { self.as_::<Ulong4>() }
pub fn cast_u64(self) -> Expr<Ulong4> { self.as_::<Ulong4>() }
pub fn as_short4(self) -> Expr<Short4> { self.as_::<Short4>() }
pub fn cast_i16(self) -> Expr<Short4> { self.as_::<Short4>() }
pub fn as_ushort4(self) -> Expr<Ushort4> { self.as_::<Ushort4>() }
pub fn cast_u16(self) -> Expr<Ushort4> { self.as_::<Ushort4>() }
pub fn as_byte4(self) -> Expr<Byte4> { self.as_::<Byte4>() }
pub fn cast_i8(self) -> Expr<Byte4> { self.as_::<Byte4>() }
pub fn as_ubyte4(self) -> Expr<Ubyte4> { self.as_::<Ubyte4>() }
pub fn cast_u8(self) -> Expr<Ubyte4> { self.as_::<Ubyte4>() }
}
}
109 changes: 83 additions & 26 deletions luisa_compute/src/lang/ops/gen_cast.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,95 @@
from typing import List
from itertools import permutations, product

prims = ['f32', 'i32', 'u32', 'f64', 'i64', 'u64', 'f16', 'i16', 'u16', 'i8', 'u8']
file = open('cast_impls.rs', 'w')
print('\n#[rustfmt::skip]mod impl_{\nuse crate::prelude::*;\nuse super::super::*;\n', file=file)
prims = ["f32", "i32", "u32", "f64", "i64", "u64", "f16", "i16", "u16", "i8", "u8"]
bool_convs = ["i32", "u32", "i64", "u64", "i16", "u16", "i8", "u8"]
file = open("cast_impls.rs", "w")
print(
"\n#[rustfmt::skip]mod impl_{\nuse crate::prelude::*;\nuse super::super::*;\n",
file=file,
)
v_name = {
'f32':'float',
'i32':'int',
'u32':'uint',
'f64':'double',
'i64':'long',
'u64':'ulong',
'f16':'half',
'i16':'short',
'u16':'ushort',
'i8':'byte',
'u8':'ubyte'
"f32": "float",
"i32": "int",
"u32": "uint",
"f64": "double",
"i64": "long",
"u64": "ulong",
"f16": "half",
"i16": "short",
"u16": "ushort",
"i8": "byte",
"u8": "ubyte",
}


def make_typename(t):
t = list(t)
t[0] = t[0].upper()
return ''.join(t)
t[0] = t[0].upper()
return "".join(t)


for p in prims:
print('impl Expr<{}> {{'.format(p), file=file)
print("impl Expr<{}> {{".format(p), file=file)
for q in prims:
if p != q:
print(' pub fn as_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}'.format(q), file=file)
print(' pub fn cast_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}'.format(q), file=file)
print('}', file=file)
for n in [2,3,4]:
print('impl Expr<{}{}> {{'.format(make_typename(v_name[p]),n), file=file)
print(
" pub fn as_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(
q
),
file=file,
)
print(
" pub fn cast_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(
q
),
file=file,
)
print("}", file=file)
for n in [2, 3, 4]:
print("impl Expr<{}{}> {{".format(make_typename(v_name[p]), n), file=file)
for q in prims:
if p != q:
print(' pub fn as_{2}{1}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}'.format(make_typename(v_name[q]),n, v_name[q]), file=file)
print(' pub fn cast_{3}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}'.format(make_typename(v_name[q]),n, v_name[q], q), file=file)
print('}', file=file)
print('}', file=file)
print(
" pub fn as_{2}{1}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format(
make_typename(v_name[q]), n, v_name[q]
),
file=file,
)
print(
" pub fn cast_{3}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format(
make_typename(v_name[q]), n, v_name[q], q
),
file=file,
)
print("}", file=file)

print("impl Expr<bool> {", file=file)
for q in bool_convs:
print(
" pub fn as_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(q),
file=file,
)
print(
" pub fn cast_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(q),
file=file,
)
print("}", file=file)
for n in [2, 3, 4]:
print("impl Expr<{}{}> {{".format(make_typename("bool"), n), file=file)
for q in bool_convs:
print(
" pub fn as_{2}{1}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format(
make_typename(v_name[q]), n, v_name[q]
),
file=file,
)
print(
" pub fn cast_{3}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format(
make_typename(v_name[q]), n, v_name[q], q
),
file=file,
)
print("}", file=file)

print("}", file=file)
12 changes: 11 additions & 1 deletion luisa_compute/src/lang/ops/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,22 @@ macro_rules! impl_simple_binop {
}

impl_ops_trait!([X: Linear] MinMaxExpr[MinMaxThis] for Expr<X> where [X::Scalar: Numeric] {
type Output = Expr<X::WithScalar<X::Scalar>>;
type Output = Expr<X>;

fn max_[_max_](self, other) { Func::Max.call2(self, other) }
fn min_[_min_](self, other) { Func::Min.call2(self, other) }
});

impl_ops_trait!([X: Linear] RemEuclidExpr[RemEuclidThis] for Expr<X> where [X::Scalar: Numeric] {
type Output = Expr<X>;

fn rem_euclid[_rem_euclid](self, other) {
track! {
((self % other) + other) % other
}
}
});

impl_ops_trait!([X: Linear] ClampExpr[ClampThis] for Expr<X> where [X::Scalar: Numeric] {
fn clamp[_clamp](self, min, max) { Func::Clamp.call3(self, min, max) }
});
Expand Down
16 changes: 14 additions & 2 deletions luisa_compute/src/lang/ops/spread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@ where
}
}

impl<T, S> RemEuclidExpr<S> for T
where
T: SpreadOps<S>,
Expr<T::Join>: RemEuclidThis,
{
type Output = <Expr<T::Join> as RemEuclidThis>::Output;
fn rem_euclid(self, other: S) -> Self::Output {
Expr::<T::Join>::_rem_euclid(Self::lift_self(self), Self::lift_other(other))
}
}

pub fn min<T, S>(x: T, y: S) -> <T as MinMaxExpr<S>>::Output
where
T: MinMaxExpr<S>,
Expand All @@ -236,12 +247,13 @@ where

impl<T: Value, S, U> ClampExpr<S, U> for Expr<T>
where
S: SpreadOps<U, Join = T>,
S: SpreadOps<Expr<T>, Join = T>,
U: SpreadOps<Expr<T>, Join = T>,
Expr<T>: ClampThis,
{
type Output = Expr<T>;
fn clamp(self, min: S, max: U) -> Self::Output {
Expr::<T>::_clamp(self, S::lift_self(min), S::lift_other(max))
Expr::<T>::_clamp(self, S::lift_self(min), U::lift_self(max))
}
}
impl<T, S> EqExpr<S> for T
Expand Down
6 changes: 6 additions & 0 deletions luisa_compute/src/lang/ops/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ ops_trait!(MinMaxExpr<T>[MinMaxThis] {
fn min_[_min_](self, other: T);
});

ops_trait!(RemEuclidExpr<T>[RemEuclidThis] {
type Output;

fn rem_euclid[_rem_euclid](self, other: T);
});

ops_trait!(ClampExpr<A, B>[ClampThis] {
fn clamp[_clamp](self, min: A, max: B);
});
Expand Down

0 comments on commit ed25171

Please sign in to comment.