diff --git a/luisa_compute/examples/custom_op.rs b/luisa_compute/examples/custom_op.rs index 1513bfe..6383550 100644 --- a/luisa_compute/examples/custom_op.rs +++ b/luisa_compute/examples/custom_op.rs @@ -1,6 +1,6 @@ use std::env::current_exe; -use luisa::lang::debug::CpuFn; +use luisa::lang::external::CpuFn; use luisa::prelude::*; use luisa_compute as luisa; #[derive(Clone, Copy, Value, Debug)] diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 61e87fe..cd071fe 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -35,6 +35,7 @@ pub mod ops; pub mod poly; pub mod soa; pub mod types; +pub mod external; pub(crate) trait CallFuncTrait { fn call(self, x: Expr) -> Expr; diff --git a/luisa_compute/src/lang/debug.rs b/luisa_compute/src/lang/debug.rs index 5b8f5da..20e2c4c 100644 --- a/luisa_compute/src/lang/debug.rs +++ b/luisa_compute/src/lang/debug.rs @@ -1,85 +1,8 @@ -use ir::CpuCustomOp; use std::ffi::CString; use std::fmt::Debug; -use std::sync::Arc; use crate::internal_prelude::*; -pub struct CpuFn { - op: CArc, - _marker: PhantomData, -} - -/* -Interestingly, Box::into_raw(Box) does not give a valid pointer. -*/ -struct ClosureContainer { - f: Arc, -} - -impl CpuFn { - pub fn new(f: F) -> Self { - let f_ptr = Box::into_raw(Box::new(ClosureContainer:: { f: Arc::new(f) })); - let op = CpuCustomOp { - data: f_ptr as *mut u8, - func: _trampoline::, - destructor: _drop::, - arg_type: T::type_(), - }; - Self { - op: CArc::new(op), - _marker: PhantomData, - } - } - pub fn call(&self, arg: impl AsExpr) -> Expr { - RECORDER.with(|r| { - let mut r = r.borrow_mut(); - assert!(r.lock); - assert_eq!( - r.device - .as_ref() - .unwrap() - .upgrade() - .unwrap() - .inner - .query("device_name") - .unwrap(), - "cpu", - "CpuFn can only be used in cpu backend" - ); - let addr = CArc::as_ptr(&self.op) as u64; - if let Some((_, op)) = r.cpu_custom_ops.get(&addr) { - assert_eq!(CArc::as_ptr(op), CArc::as_ptr(&self.op)); - } else { - let i = r.cpu_custom_ops.len(); - r.cpu_custom_ops.insert(addr, (i, self.op.clone())); - } - }); - Expr::::from_node(__current_scope(|b| { - b.call( - Func::CpuCustomOp(self.op.clone()), - &[arg.as_expr().node()], - T::type_(), - ) - })) - } -} - -extern "C" fn _trampoline(data: *mut u8, args: *mut u8) { - unsafe { - let container = &*(data as *const ClosureContainer); - let f = &container.f; - let args = &mut *(args as *mut T); - f(args); - } -} - -extern "C" fn _drop(data: *mut u8) { - unsafe { - let _ = Box::from_raw(data as *mut T); - } -} - #[macro_export] macro_rules! cpu_dbg { ($arg:expr) => {{ diff --git a/luisa_compute/src/lang/external.rs b/luisa_compute/src/lang/external.rs new file mode 100644 index 0000000..c26b138 --- /dev/null +++ b/luisa_compute/src/lang/external.rs @@ -0,0 +1,80 @@ +use std::sync::Arc; + +use luisa_compute_ir::ir::CpuCustomOp; + +use crate::internal_prelude::*; +pub struct CpuFn { + op: CArc, + _marker: PhantomData, +} + +/* +Interestingly, Box::into_raw(Box) does not give a valid pointer. +*/ +struct ClosureContainer { + f: Arc, +} + +/// A custom function that can be called inside a cpu kernel +impl CpuFn { + pub fn new(f: F) -> Self { + let f_ptr = Box::into_raw(Box::new(ClosureContainer:: { f: Arc::new(f) })); + let op = CpuCustomOp { + data: f_ptr as *mut u8, + func: _trampoline::, + destructor: _drop::, + arg_type: T::type_(), + }; + Self { + op: CArc::new(op), + _marker: PhantomData, + } + } + pub fn call(&self, arg: impl AsExpr) -> Expr { + RECORDER.with(|r| { + let mut r = r.borrow_mut(); + assert!(r.lock); + assert_eq!( + r.device + .as_ref() + .unwrap() + .upgrade() + .unwrap() + .inner + .query("device_name") + .unwrap(), + "cpu", + "CpuFn can only be used in cpu backend" + ); + let addr = CArc::as_ptr(&self.op) as u64; + if let Some((_, op)) = r.cpu_custom_ops.get(&addr) { + assert_eq!(CArc::as_ptr(op), CArc::as_ptr(&self.op)); + } else { + let i = r.cpu_custom_ops.len(); + r.cpu_custom_ops.insert(addr, (i, self.op.clone())); + } + }); + Expr::::from_node(__current_scope(|b| { + b.call( + Func::CpuCustomOp(self.op.clone()), + &[arg.as_expr().node()], + T::type_(), + ) + })) + } +} + +extern "C" fn _trampoline(data: *mut u8, args: *mut u8) { + unsafe { + let container = &*(data as *const ClosureContainer); + let f = &container.f; + let args = &mut *(args as *mut T); + f(args); + } +} + +extern "C" fn _drop(data: *mut u8) { + unsafe { + let _ = Box::from_raw(data as *mut T); + } +} diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index ade93e6..86a89ae 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -35,7 +35,7 @@ pub mod prelude { }; pub use crate::lang::types::vector::swizzle::*; pub use crate::lang::types::vector::VectorExprProxy; - pub use crate::lang::types::{AsExpr, Expr, Value, Var, SoaValue}; + pub use crate::lang::types::{AsExpr, Expr, SoaValue, Value, Var}; pub use crate::lang::Aggregate; pub use crate::resource::{IoTexel, StorageTexel, *}; pub use crate::runtime::api::StreamTag; @@ -52,7 +52,8 @@ pub mod prelude { } mod internal_prelude { - pub(crate) use crate::lang::debug::{__env_need_backtrace, is_cpu_backend, CpuFn}; + pub(crate) use crate::lang::debug::{__env_need_backtrace, is_cpu_backend}; + pub(crate) use crate::lang::external::CpuFn; pub(crate) use crate::lang::ir::ffi::*; pub(crate) use crate::lang::ir::{ new_node, register_type, BasicBlock, Const, Func, Instruction, IrBuilder, Node, @@ -60,7 +61,7 @@ mod internal_prelude { }; pub(crate) use crate::lang::ops::Linear; pub(crate) use crate::lang::types::vector::alias::*; - pub(crate) use crate::lang::types::{SoaBufferProxy, vector::*}; + pub(crate) use crate::lang::types::{vector::*, SoaBufferProxy}; #[allow(unused_imports)] pub(crate) use crate::lang::{ check_index_lt_usize, ir, CallFuncTrait, Recorder, __compose, __extract, __insert, diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index ecc1268..3015047 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit ecc1268f1a294b90becb7222f00256232fc7e140 +Subproject commit 30150473b3f772d59f416cb2fb5cd3e796c085a9