diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index c0c7720..e0d90d2 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -97,40 +97,47 @@ impl CallableParameter for BindlessArrayVar { } } impl KernelParameter for rtx::AccelVar { + type Arg = rtx::Accel; fn def_param(builder: &mut KernelBuilder) -> Self { builder.accel() } } pub trait KernelParameter { + type Arg: KernelArg; fn def_param(builder: &mut KernelBuilder) -> Self; } impl KernelParameter for Expr { + type Arg = T; fn def_param(builder: &mut KernelBuilder) -> Self { builder.uniform::() } } impl KernelParameter for BufferVar { + type Arg = Buffer; fn def_param(builder: &mut KernelBuilder) -> Self { builder.buffer() } } impl KernelParameter for Tex2dVar { + type Arg = Tex2d; fn def_param(builder: &mut KernelBuilder) -> Self { builder.tex2d() } } impl KernelParameter for Tex3dVar { + type Arg = Tex3d; fn def_param(builder: &mut KernelBuilder) -> Self { builder.tex3d() } } impl KernelParameter for BindlessArrayVar { + type Arg = BindlessArray; fn def_param(builder: &mut KernelBuilder) -> Self { builder.bindless_array() } @@ -139,6 +146,7 @@ impl KernelParameter for BindlessArrayVar { macro_rules! impl_kernel_param_for_tuple { ($first:ident $($rest:ident)*) => { impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelParameter for ($first, $($rest,)*) { + type Arg = ($first::Arg, $($rest::Arg),*); #[allow(non_snake_case)] fn def_param(builder: &mut KernelBuilder) -> Self { ($first::def_param(builder), $($rest::def_param(builder)),*) @@ -148,6 +156,7 @@ macro_rules! impl_kernel_param_for_tuple { }; ()=>{ impl KernelParameter for () { + type Arg = (); fn def_param(_: &mut KernelBuilder) -> Self { } }