diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 68d9810..467a116 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1379,45 +1379,46 @@ impl Kernel { } } -pub trait AsKernelArg: KernelArg {} - -impl AsKernelArg for T {} - -impl AsKernelArg> for Buffer {} - -impl<'a, T: Value> AsKernelArg> for BufferView<'a, T> {} - -impl<'a, T: Value> AsKernelArg> for BufferView<'a, T> {} - -impl<'a, T: Value> AsKernelArg> for Buffer {} - -// impl AsKernelArg for ByteBuffer {} - -// impl<'a> AsKernelArg for ByteBufferView<'a> {} - -// impl<'a> AsKernelArg> for ByteBufferView<'a> {} - -// impl<'a> AsKernelArg> for ByteBuffer {} - -impl<'a, T: IoTexel> AsKernelArg> for Tex2dView<'a, T> {} +// A trait signifying that this argument can be used in place of an argument of type `Self::T`. +pub trait AsKernelArg: KernelArg { + type Output: KernelArg + 'static; +} -impl<'a, T: IoTexel> AsKernelArg> for Tex3dView<'a, T> {} +impl AsKernelArg for T { + type Output = T; +} -impl<'a, T: IoTexel> AsKernelArg> for Tex2dView<'a, T> {} +impl AsKernelArg for Buffer { + type Output = T; +} -impl<'a, T: IoTexel> AsKernelArg> for Tex3dView<'a, T> {} +impl<'a, T: Value> AsKernelArg for BufferView<'a, T> { + type Output = Buffer; +} -impl<'a, T: IoTexel> AsKernelArg> for Tex2d {} +impl<'a, T: IoTexel> AsKernelArg for Tex2dView<'a, T> { + type Output = Tex2d; +} -impl<'a, T: IoTexel> AsKernelArg> for Tex3d {} +impl<'a, T: IoTexel> AsKernelArg for Tex3dView<'a, T> { + type Output = Tex3d; +} -impl AsKernelArg> for Tex2d {} +impl AsKernelArg for Tex2d { + type Output = Tex2d; +} -impl AsKernelArg> for Tex3d {} +impl AsKernelArg for Tex3d { + type Output = Tex3d; +} -impl AsKernelArg for BindlessArray {} +impl AsKernelArg for BindlessArray { + type Output = BindlessArray; +} -impl AsKernelArg for Accel {} +impl AsKernelArg for Accel { + type Output = Accel; +} macro_rules! impl_call_for_callable { ( $($Ts:ident)*) => { @@ -1467,7 +1468,7 @@ macro_rules! impl_dispatch_for_kernel { impl <$($Ts: KernelArg+'static),*> Kernel { #[allow(non_snake_case)] #[allow(unused_mut)] - pub fn dispatch(&self, dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg<$Ts>),*) { + pub fn dispatch(&self, dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg),*) { let mut encoder = KernelArgEncoder::new(); $($Ts.encode(&mut encoder);)* self.inner.dispatch(encoder, dispatch_size) @@ -1476,7 +1477,7 @@ macro_rules! impl_dispatch_for_kernel { #[allow(unused_mut)] pub fn dispatch_async( &self, - dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg<$Ts>),* + dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg),* ) -> Command<'static, 'static> { let mut encoder = KernelArgEncoder::new(); $($Ts.encode(&mut encoder);)* diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index 3791fd2..c40e21c 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -141,7 +141,8 @@ impl Compiler { #(self.#field_names.encode(encoder);)* } } - impl #impl_generics #runtime_path::AsKernelArg<#name #ty_generics> for #name #ty_generics #where_clause { + impl #impl_generics #runtime_path::AsKernelArg for #name #ty_generics #where_clause { + type Output = #name #ty_generics; } ) }