diff --git a/luisa_compute/examples/external_callable.rs b/luisa_compute/examples/external_callable.rs new file mode 100644 index 0000000..6096b58 --- /dev/null +++ b/luisa_compute/examples/external_callable.rs @@ -0,0 +1,36 @@ +use luisa::{prelude::*, runtime::ExternalCallable}; +use luisa_compute as luisa; +use std::env::current_exe; + +fn main() { + luisa::init_logger(); + let ctx = Context::new(current_exe().unwrap()); + let device = ctx.create_device("cuda"); + let x = device.create_buffer::(1024); + let y = device.create_buffer::(1024); + let z = device.create_buffer::(1024); + let time = device.create_buffer::(1024); + x.view(..).fill_fn(|i| i as f32); + y.view(..).fill_fn(|i| 1000.0 * i as f32); + let clock64 = ExternalCallable:: Expr>::new("clock64"); + let kernel = Kernel::)>::new( + &device, + &track!(|buf_z| { + let buf_x = x.var(); + let buf_y = y.var(); + let tid = dispatch_id().x; + let t0 = clock64.call(); + let x = buf_x.read(tid); + let y = buf_y.read(tid); + + buf_z.write(tid, x + y); + let t1 = clock64.call(); + time.write(tid, t1 - t0); + }), + ); + kernel.dispatch([1024, 1, 1], &z); + let z_data = z.view(..).copy_to_vec(); + println!("{:?}", &z_data[0..16]); + let time = time.copy_to_vec().iter().sum::() as f64 / 1024.0; + println!("avg time: {}", time); +} diff --git a/luisa_compute/src/lang/types/dynamic.rs b/luisa_compute/src/lang/types/dynamic.rs index e0b15a8..462010c 100644 --- a/luisa_compute/src/lang/types/dynamic.rs +++ b/luisa_compute/src/lang/types/dynamic.rs @@ -121,6 +121,9 @@ unsafe impl CallableRet for DynExpr { fn _from_return(node: NodeRef) -> Self { Self::from_node(node.into()) } + fn _return_type() -> CArc { + panic!("should not be called") + } } impl Aggregate for DynVar { diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index e0feb2f..68b623e 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -5,7 +5,6 @@ extern crate self as luisa_compute; use std::any::Any; use std::backtrace::Backtrace; use std::path::Path; -use std::ptr::null; use std::sync::Arc; pub mod lang; diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 8ac785a..d9e475e 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -741,6 +741,8 @@ impl Device { ) -> Kernel { let name = options.name.unwrap_or("".to_string()); let name = Arc::new(CString::new(name).unwrap()); + let native_include = options.native_include.unwrap_or("".to_string()); + let native_include = Arc::new(CString::new(native_include).unwrap()); let shader_options = api::ShaderOption { enable_cache: options.enable_cache, enable_fast_math: options.enable_fast_math, @@ -749,6 +751,7 @@ impl Device { max_registers: options.max_registers, compile_only: false, name: name.as_ptr(), + native_include: native_include.as_ptr(), }; let module = k.inner.module.clone(); let artifact = if options.async_compile { @@ -757,6 +760,7 @@ impl Device { module.clone(), shader_options, name, + native_include, )) } else { ShaderArtifact::Sync(self.inner.create_shader(&module, &shader_options)) @@ -1164,6 +1168,8 @@ pub(crate) struct AsyncShaderArtifact { // strange naming, huh? #[allow(dead_code)] name: Arc, + #[allow(dead_code)] + native_include: Arc, } pub(crate) enum ShaderArtifact { @@ -1177,9 +1183,14 @@ impl AsyncShaderArtifact { kernel: CArc, options: api::ShaderOption, name: Arc, + native_include: Arc, ) -> Arc<(Mutex, Condvar)> { let artifact = Arc::new(( - Mutex::new(AsyncShaderArtifact { shader: None, name }), + Mutex::new(AsyncShaderArtifact { + shader: None, + name, + native_include, + }), Condvar::new(), )); { @@ -1496,6 +1507,21 @@ impl RawKernel { } } +/// A callable written in native shader language. +pub struct ExternalCallable { + pub(crate) name: CBoxedSlice, + pub(crate) _marker: PhantomData, +} +impl ExternalCallable { + pub fn new(name: impl Into) -> Self { + let name: String = name.into(); + Self { + name: CBoxedSlice::new(name.into_bytes()), + _marker: PhantomData, + } + } +} + pub struct Callable { #[allow(dead_code)] pub(crate) inner: RawCallable, @@ -1777,6 +1803,21 @@ macro_rules! impl_call_for_callable { self.call_impl(std::rc::Rc::new(($($Ts,)*)), &encoder.args) } } + impl ExternalCallableR> { + #[allow(non_snake_case)] + #[allow(unused_mut)] + pub fn call(&self, $($Ts:$Ts),*) -> R { + let mut encoder = CallableArgEncoder::new(); + $($Ts.encode(&mut encoder);)* + CallableRet::_from_return(__current_scope(|b| { + b.call( + Func::External(self.name.clone()), + &encoder.args, + R::_return_type(), + ) + })) + } + } }; } impl_call_for_callable!(); diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 94e0175..0ca9c07 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -543,6 +543,11 @@ pub struct KernelBuildOptions { /// measure time spent during compilation pub time_trace: bool, pub name: Option, + /// Include code written in the native shading language. + /// If provided, backend will include this string into the generated + /// shader code. This field is useful for interoperation with external callables. + /// see also [`ExternalCallable`] + pub native_include: Option, } impl Default for KernelBuildOptions { @@ -560,6 +565,7 @@ impl Default for KernelBuildOptions { max_registers: 0, time_trace: false, name: None, + native_include: None, } } } @@ -574,6 +580,7 @@ pub trait StaticCallableBuildFn: CallableBuildFn {} pub unsafe trait CallableRet { fn _return(&self) -> CArc; fn _from_return(node: NodeRef) -> Self; + fn _return_type() -> CArc; } unsafe impl CallableRet for () { @@ -581,6 +588,9 @@ unsafe impl CallableRet for () { Type::void() } fn _from_return(_: NodeRef) -> Self {} + fn _return_type() -> CArc { + Type::void() + } } unsafe impl CallableRet for Expr { @@ -594,6 +604,9 @@ unsafe impl CallableRet for Expr { fn _from_return(node: NodeRef) -> Self { Self::from_node(node.into()) } + fn _return_type() -> CArc { + V::type_() + } } pub trait CallableSignature { diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 4c4abab..8a01e22 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 4c4abab9f787e5c8350bee4ff99794e3533bd6a6 +Subproject commit 8a01e22a8b202f749fa5a17241a38dd63027d03a