Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add rem_euclid, bool -> int convs, and fix clamp. #35

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 73 additions & 0 deletions luisa_compute/src/lang/ops/cast_impls.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

#[rustfmt::skip]mod impl_{
use crate::prelude::*;
use super::super::*;
Expand Down Expand Up @@ -970,4 +971,76 @@ impl Expr<Ubyte4> {
pub fn as_byte4(self) -> Expr<Byte4> { self.as_::<Byte4>() }
pub fn cast_i8(self) -> Expr<Byte4> { self.as_::<Byte4>() }
}
impl Expr<bool> {
pub fn as_i32(self) -> Expr<i32> { self.as_::<i32>() }
pub fn cast_i32(self) -> Expr<i32> { self.as_::<i32>() }
pub fn as_u32(self) -> Expr<u32> { self.as_::<u32>() }
pub fn cast_u32(self) -> Expr<u32> { self.as_::<u32>() }
pub fn as_i64(self) -> Expr<i64> { self.as_::<i64>() }
pub fn cast_i64(self) -> Expr<i64> { self.as_::<i64>() }
pub fn as_u64(self) -> Expr<u64> { self.as_::<u64>() }
pub fn cast_u64(self) -> Expr<u64> { self.as_::<u64>() }
pub fn as_i16(self) -> Expr<i16> { self.as_::<i16>() }
pub fn cast_i16(self) -> Expr<i16> { self.as_::<i16>() }
pub fn as_u16(self) -> Expr<u16> { self.as_::<u16>() }
pub fn cast_u16(self) -> Expr<u16> { self.as_::<u16>() }
pub fn as_i8(self) -> Expr<i8> { self.as_::<i8>() }
pub fn cast_i8(self) -> Expr<i8> { self.as_::<i8>() }
pub fn as_u8(self) -> Expr<u8> { self.as_::<u8>() }
pub fn cast_u8(self) -> Expr<u8> { self.as_::<u8>() }
}
impl Expr<Bool2> {
pub fn as_int2(self) -> Expr<Int2> { self.as_::<Int2>() }
pub fn cast_i32(self) -> Expr<Int2> { self.as_::<Int2>() }
pub fn as_uint2(self) -> Expr<Uint2> { self.as_::<Uint2>() }
pub fn cast_u32(self) -> Expr<Uint2> { self.as_::<Uint2>() }
pub fn as_long2(self) -> Expr<Long2> { self.as_::<Long2>() }
pub fn cast_i64(self) -> Expr<Long2> { self.as_::<Long2>() }
pub fn as_ulong2(self) -> Expr<Ulong2> { self.as_::<Ulong2>() }
pub fn cast_u64(self) -> Expr<Ulong2> { self.as_::<Ulong2>() }
pub fn as_short2(self) -> Expr<Short2> { self.as_::<Short2>() }
pub fn cast_i16(self) -> Expr<Short2> { self.as_::<Short2>() }
pub fn as_ushort2(self) -> Expr<Ushort2> { self.as_::<Ushort2>() }
pub fn cast_u16(self) -> Expr<Ushort2> { self.as_::<Ushort2>() }
pub fn as_byte2(self) -> Expr<Byte2> { self.as_::<Byte2>() }
pub fn cast_i8(self) -> Expr<Byte2> { self.as_::<Byte2>() }
pub fn as_ubyte2(self) -> Expr<Ubyte2> { self.as_::<Ubyte2>() }
pub fn cast_u8(self) -> Expr<Ubyte2> { self.as_::<Ubyte2>() }
}
impl Expr<Bool3> {
pub fn as_int3(self) -> Expr<Int3> { self.as_::<Int3>() }
pub fn cast_i32(self) -> Expr<Int3> { self.as_::<Int3>() }
pub fn as_uint3(self) -> Expr<Uint3> { self.as_::<Uint3>() }
pub fn cast_u32(self) -> Expr<Uint3> { self.as_::<Uint3>() }
pub fn as_long3(self) -> Expr<Long3> { self.as_::<Long3>() }
pub fn cast_i64(self) -> Expr<Long3> { self.as_::<Long3>() }
pub fn as_ulong3(self) -> Expr<Ulong3> { self.as_::<Ulong3>() }
pub fn cast_u64(self) -> Expr<Ulong3> { self.as_::<Ulong3>() }
pub fn as_short3(self) -> Expr<Short3> { self.as_::<Short3>() }
pub fn cast_i16(self) -> Expr<Short3> { self.as_::<Short3>() }
pub fn as_ushort3(self) -> Expr<Ushort3> { self.as_::<Ushort3>() }
pub fn cast_u16(self) -> Expr<Ushort3> { self.as_::<Ushort3>() }
pub fn as_byte3(self) -> Expr<Byte3> { self.as_::<Byte3>() }
pub fn cast_i8(self) -> Expr<Byte3> { self.as_::<Byte3>() }
pub fn as_ubyte3(self) -> Expr<Ubyte3> { self.as_::<Ubyte3>() }
pub fn cast_u8(self) -> Expr<Ubyte3> { self.as_::<Ubyte3>() }
}
impl Expr<Bool4> {
pub fn as_int4(self) -> Expr<Int4> { self.as_::<Int4>() }
pub fn cast_i32(self) -> Expr<Int4> { self.as_::<Int4>() }
pub fn as_uint4(self) -> Expr<Uint4> { self.as_::<Uint4>() }
pub fn cast_u32(self) -> Expr<Uint4> { self.as_::<Uint4>() }
pub fn as_long4(self) -> Expr<Long4> { self.as_::<Long4>() }
pub fn cast_i64(self) -> Expr<Long4> { self.as_::<Long4>() }
pub fn as_ulong4(self) -> Expr<Ulong4> { self.as_::<Ulong4>() }
pub fn cast_u64(self) -> Expr<Ulong4> { self.as_::<Ulong4>() }
pub fn as_short4(self) -> Expr<Short4> { self.as_::<Short4>() }
pub fn cast_i16(self) -> Expr<Short4> { self.as_::<Short4>() }
pub fn as_ushort4(self) -> Expr<Ushort4> { self.as_::<Ushort4>() }
pub fn cast_u16(self) -> Expr<Ushort4> { self.as_::<Ushort4>() }
pub fn as_byte4(self) -> Expr<Byte4> { self.as_::<Byte4>() }
pub fn cast_i8(self) -> Expr<Byte4> { self.as_::<Byte4>() }
pub fn as_ubyte4(self) -> Expr<Ubyte4> { self.as_::<Ubyte4>() }
pub fn cast_u8(self) -> Expr<Ubyte4> { self.as_::<Ubyte4>() }
}
}
109 changes: 83 additions & 26 deletions luisa_compute/src/lang/ops/gen_cast.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,95 @@
from typing import List
from itertools import permutations, product

prims = ['f32', 'i32', 'u32', 'f64', 'i64', 'u64', 'f16', 'i16', 'u16', 'i8', 'u8']
file = open('cast_impls.rs', 'w')
print('\n#[rustfmt::skip]mod impl_{\nuse crate::prelude::*;\nuse super::super::*;\n', file=file)
prims = ["f32", "i32", "u32", "f64", "i64", "u64", "f16", "i16", "u16", "i8", "u8"]
bool_convs = ["i32", "u32", "i64", "u64", "i16", "u16", "i8", "u8"]
file = open("cast_impls.rs", "w")
print(
"\n#[rustfmt::skip]mod impl_{\nuse crate::prelude::*;\nuse super::super::*;\n",
file=file,
)
v_name = {
'f32':'float',
'i32':'int',
'u32':'uint',
'f64':'double',
'i64':'long',
'u64':'ulong',
'f16':'half',
'i16':'short',
'u16':'ushort',
'i8':'byte',
'u8':'ubyte'
"f32": "float",
"i32": "int",
"u32": "uint",
"f64": "double",
"i64": "long",
"u64": "ulong",
"f16": "half",
"i16": "short",
"u16": "ushort",
"i8": "byte",
"u8": "ubyte",
}


def make_typename(t):
t = list(t)
t[0] = t[0].upper()
return ''.join(t)
t[0] = t[0].upper()
return "".join(t)


for p in prims:
print('impl Expr<{}> {{'.format(p), file=file)
print("impl Expr<{}> {{".format(p), file=file)
for q in prims:
if p != q:
print(' pub fn as_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}'.format(q), file=file)
print(' pub fn cast_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}'.format(q), file=file)
print('}', file=file)
for n in [2,3,4]:
print('impl Expr<{}{}> {{'.format(make_typename(v_name[p]),n), file=file)
print(
" pub fn as_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(
q
),
file=file,
)
print(
" pub fn cast_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(
q
),
file=file,
)
print("}", file=file)
for n in [2, 3, 4]:
print("impl Expr<{}{}> {{".format(make_typename(v_name[p]), n), file=file)
for q in prims:
if p != q:
print(' pub fn as_{2}{1}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}'.format(make_typename(v_name[q]),n, v_name[q]), file=file)
print(' pub fn cast_{3}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}'.format(make_typename(v_name[q]),n, v_name[q], q), file=file)
print('}', file=file)
print('}', file=file)
print(
" pub fn as_{2}{1}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format(
make_typename(v_name[q]), n, v_name[q]
),
file=file,
)
print(
" pub fn cast_{3}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format(
make_typename(v_name[q]), n, v_name[q], q
),
file=file,
)
print("}", file=file)

print("impl Expr<bool> {", file=file)
for q in bool_convs:
print(
" pub fn as_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(q),
file=file,
)
print(
" pub fn cast_{0}(self) -> Expr<{0}> {{ self.as_::<{0}>() }}".format(q),
file=file,
)
print("}", file=file)
for n in [2, 3, 4]:
print("impl Expr<{}{}> {{".format(make_typename("bool"), n), file=file)
for q in bool_convs:
print(
" pub fn as_{2}{1}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format(
make_typename(v_name[q]), n, v_name[q]
),
file=file,
)
print(
" pub fn cast_{3}(self) -> Expr<{0}{1}> {{ self.as_::<{0}{1}>() }}".format(
make_typename(v_name[q]), n, v_name[q], q
),
file=file,
)
print("}", file=file)

print("}", file=file)
12 changes: 11 additions & 1 deletion luisa_compute/src/lang/ops/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,12 +67,22 @@ macro_rules! impl_simple_binop {
}

impl_ops_trait!([X: Linear] MinMaxExpr[MinMaxThis] for Expr<X> where [X::Scalar: Numeric] {
type Output = Expr<X::WithScalar<X::Scalar>>;
type Output = Expr<X>;

fn max_[_max_](self, other) { Func::Max.call2(self, other) }
fn min_[_min_](self, other) { Func::Min.call2(self, other) }
});

impl_ops_trait!([X: Linear] RemEuclidExpr[RemEuclidThis] for Expr<X> where [X::Scalar: Numeric] {
type Output = Expr<X>;

fn rem_euclid[_rem_euclid](self, other) {
track! {
((self % other) + other) % other
}
}
});

impl_ops_trait!([X: Linear] ClampExpr[ClampThis] for Expr<X> where [X::Scalar: Numeric] {
fn clamp[_clamp](self, min, max) { Func::Clamp.call3(self, min, max) }
});
Expand Down
16 changes: 14 additions & 2 deletions luisa_compute/src/lang/ops/spread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@ where
}
}

impl<T, S> RemEuclidExpr<S> for T
where
T: SpreadOps<S>,
Expr<T::Join>: RemEuclidThis,
{
type Output = <Expr<T::Join> as RemEuclidThis>::Output;
fn rem_euclid(self, other: S) -> Self::Output {
Expr::<T::Join>::_rem_euclid(Self::lift_self(self), Self::lift_other(other))
}
}

pub fn min<T, S>(x: T, y: S) -> <T as MinMaxExpr<S>>::Output
where
T: MinMaxExpr<S>,
Expand All @@ -236,12 +247,13 @@ where

impl<T: Value, S, U> ClampExpr<S, U> for Expr<T>
where
S: SpreadOps<U, Join = T>,
S: SpreadOps<Expr<T>, Join = T>,
U: SpreadOps<Expr<T>, Join = T>,
Expr<T>: ClampThis,
{
type Output = Expr<T>;
fn clamp(self, min: S, max: U) -> Self::Output {
Expr::<T>::_clamp(self, S::lift_self(min), S::lift_other(max))
Expr::<T>::_clamp(self, S::lift_self(min), U::lift_self(max))
}
}
impl<T, S> EqExpr<S> for T
Expand Down
6 changes: 6 additions & 0 deletions luisa_compute/src/lang/ops/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,12 @@ ops_trait!(MinMaxExpr<T>[MinMaxThis] {
fn min_[_min_](self, other: T);
});

ops_trait!(RemEuclidExpr<T>[RemEuclidThis] {
type Output;

fn rem_euclid[_rem_euclid](self, other: T);
});

ops_trait!(ClampExpr<A, B>[ClampThis] {
fn clamp[_clamp](self, min: A, max: B);
});
Expand Down
Loading