From 89d6aa1a3ed5f7165ac6c6e047fe096d36b7e080 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Fri, 29 Sep 2023 10:12:39 -0400 Subject: [PATCH] refactored command lifetime --- luisa_compute/examples/async.rs | 52 ++++++++++++++++++++++++ luisa_compute/src/resource.rs | 24 +++++------ luisa_compute/src/rtx.rs | 15 ++++--- luisa_compute/src/runtime.rs | 71 ++++++++++++++++----------------- luisa_compute_sys/LuisaCompute | 2 +- 5 files changed, 108 insertions(+), 56 deletions(-) create mode 100644 luisa_compute/examples/async.rs diff --git a/luisa_compute/examples/async.rs b/luisa_compute/examples/async.rs new file mode 100644 index 0000000..4f9842b --- /dev/null +++ b/luisa_compute/examples/async.rs @@ -0,0 +1,52 @@ +use std::env::current_exe; + +use luisa::prelude::*; +use luisa_compute as luisa; + +fn main() { + let ctx = Context::new(current_exe().unwrap()); + let device = ctx.create_device("cpu"); + let x = device.create_buffer::(128); + let y = device.create_buffer::(128); + let x_data = (0..x.len()).map(|i| i as f32).collect::>(); + let stream = device.default_stream(); + + // this should not compile + // stream.with_scope(|s| { + // let cmd = { + // let tmp_data = (0..y.len()).map(|i| i as f32).collect::>(); + // y.copy_from_async(&tmp_data) + // }; + // s.submit([cmd]); + // }); + + // also should not compile + // { + // let s: Scope<'static> = stream.scope(); + // let cmd = { + // let tmp_data = (0..y.len()).map(|i| i as f32).collect::>(); + // y.copy_from_async(&tmp_data) + // }; + // s.submit([cmd]); + // } + + { + let s = stream.scope(); + { + let tmp_data = (0..y.len()).map(|i| i as f32).collect::>(); + // nested lifetime should also be fine + { + let tmp_data = (0..y.len()).map(|i| i as f32).collect::>(); + s.submit([y.copy_from_async(&tmp_data)]); + }; + s.submit([y.copy_from_async(&tmp_data)]); + }; + + s.submit([x.copy_from_async(&x_data)]); + } + + stream.with_scope(|s| { + let tmp_data = (0..y.len()).map(|i| i as f32).collect::>(); + s.submit([x.copy_from_async(&x_data), y.copy_from_async(&tmp_data)]); + }); +} diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 9cfe65f..654d20f 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -333,7 +333,7 @@ impl<'a, T: Value> BufferView<'a, T> { pub fn handle(&self) -> api::Buffer { self.buffer.handle() } - pub fn copy_to_async<'b>(&'a self, data: &'b mut [T]) -> Command<'b> { + pub fn copy_to_async<'b>(&'a self, data: &'b mut [T]) -> Command<'b, 'b> { assert_eq!(data.len(), self.len); let mut rt = ResourceTracker::new(); rt.add(self.buffer.handle.clone()); @@ -364,7 +364,7 @@ impl<'a, T: Value> BufferView<'a, T> { } } - pub fn copy_from_async<'b>(&'a self, data: &'b [T]) -> Command<'static> { + pub fn copy_from_async<'b>(&'a self, data: &'b [T]) -> Command<'b, 'static> { assert_eq!(data.len(), self.len); let mut rt = ResourceTracker::new(); rt.add(self.buffer.handle.clone()); @@ -389,7 +389,7 @@ impl<'a, T: Value> BufferView<'a, T> { pub fn fill(&self, value: T) { self.fill_fn(|_| value); } - pub fn copy_to_buffer_async(&self, dst: BufferView<'a, T>) -> Command<'static> { + pub fn copy_to_buffer_async(&self, dst: BufferView<'a, T>) -> Command<'static, 'static> { assert_eq!(self.len, dst.len); let mut rt = ResourceTracker::new(); rt.add(self.buffer.handle.clone()); @@ -434,7 +434,7 @@ impl Buffer { self.view(..).copy_from(data); } #[inline] - pub fn copy_from_async<'a>(&self, data: &'a [T]) -> Command<'a> { + pub fn copy_from_async<'a>(&self, data: &'a [T]) -> Command<'a, 'static> { self.view(..).copy_from_async(data) } #[inline] @@ -442,7 +442,7 @@ impl Buffer { self.view(..).copy_to(data); } #[inline] - pub fn copy_to_async<'a>(&self, data: &'a mut [T]) -> Command<'a> { + pub fn copy_to_async<'a>(&self, data: &'a mut [T]) -> Command<'a, 'a> { self.view(..).copy_to_async(data) } #[inline] @@ -454,7 +454,7 @@ impl Buffer { self.view(..).copy_to_buffer(dst.view(..)); } #[inline] - pub fn copy_to_buffer_async<'a>(&'a self, dst: &'a Buffer) -> Command<'a> { + pub fn copy_to_buffer_async<'a>(&'a self, dst: &'a Buffer) -> Command<'static, 'static> { self.view(..).copy_to_buffer_async(dst.view(..)) } #[inline] @@ -845,7 +845,7 @@ impl BindlessArray { pub fn update(&self) { submit_default_stream_and_sync(&self.device, [self.update_async()]); } - pub fn update_async<'a>(&'a self) -> Command<'a> { + pub fn update_async<'a>(&'a self) -> Command<'a, 'a> { // What lifetime should this be? self.lock(); let mut rt = ResourceTracker::new(); let modifications = Arc::new(std::mem::replace( @@ -1161,7 +1161,7 @@ impl Tex3d { macro_rules! impl_tex_view { ($name:ident) => { impl<'a, T: IoTexel> $name<'a, T> { - pub fn copy_to_async>(&'a self, data: &'a mut [U]) -> Command<'a> { + pub fn copy_to_async>(&'a self, data: &'a mut [U]) -> Command<'a, 'a> { assert_eq!(data.len(), self.texel_count() as usize); assert_eq!(self.tex.handle.storage, U::pixel_storage()); let mut rt = ResourceTracker::new(); @@ -1192,7 +1192,7 @@ macro_rules! impl_tex_view { self.copy_to(&mut data); data } - pub fn copy_from_async<'b, U: StorageTexel>(&'a self, data: &'b [U]) -> Command<'b> { + pub fn copy_from_async<'b, U: StorageTexel>(&'a self, data: &'b [U]) -> Command<'b, 'static> { assert_eq!(data.len(), self.texel_count() as usize); assert_eq!(self.tex.handle.storage, U::pixel_storage()); let mut rt = ResourceTracker::new(); @@ -1219,7 +1219,7 @@ macro_rules! impl_tex_view { pub fn copy_to_buffer_async<'b, U: StorageTexel + Value>( &'a self, buffer_view: &'b BufferView, - ) -> Command<'static> { + ) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); rt.add(self.tex.handle.clone()); rt.add(buffer_view.buffer.handle.clone()); @@ -1251,7 +1251,7 @@ macro_rules! impl_tex_view { pub fn copy_from_buffer_async<'b, U: StorageTexel + Value>( &'a self, buffer_view: BufferView, - ) -> Command<'static> { + ) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); rt.add(self.tex.handle.clone()); rt.add(buffer_view.buffer.handle.clone()); @@ -1280,7 +1280,7 @@ macro_rules! impl_tex_view { [self.copy_from_buffer_async(buffer_view)], ); } - pub fn copy_to_texture_async(&'a self, other: $name) -> Command<'static> { + pub fn copy_to_texture_async(&'a self, other: $name) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); rt.add(self.tex.handle.clone()); rt.add(other.tex.handle.clone()); diff --git a/luisa_compute/src/rtx.rs b/luisa_compute/src/rtx.rs index 4159fef..e2f2ba1 100644 --- a/luisa_compute/src/rtx.rs +++ b/luisa_compute/src/rtx.rs @@ -73,7 +73,7 @@ impl ProceduralPrimitive { pub fn native_handle(&self) -> *mut std::ffi::c_void { self.handle.native_handle } - pub fn build_async<'a>(&self, request: AccelBuildRequest) -> Command<'a> { + pub fn build_async(&self, request: AccelBuildRequest) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); rt.add(self.handle.clone()); Command { @@ -124,7 +124,7 @@ impl Mesh { pub fn native_handle(&self) -> *mut std::ffi::c_void { self.handle.native_handle } - pub fn build_async<'a>(&self, request: AccelBuildRequest) -> Command<'a> { + pub fn build_async(&self, request: AccelBuildRequest) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); rt.add(self.handle.clone()); Command { @@ -287,7 +287,7 @@ impl Accel { pub fn build(&self, request: api::AccelBuildRequest) { submit_default_stream_and_sync(&self.handle.device, [self.build_async(request)]) } - pub fn build_async<'a>(&'a self, request: api::AccelBuildRequest) -> Command<'a> { + pub fn build_async(&self, request: api::AccelBuildRequest) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); let instance_handles = self.instance_handles.read(); rt.add(self.handle.clone()); @@ -388,7 +388,10 @@ pub enum HitType { Procedural = 2, } -pub fn offset_ray_origin(p: impl AsExpr, n: impl AsExpr) -> Expr { +pub fn offset_ray_origin( + p: impl AsExpr, + n: impl AsExpr, +) -> Expr { lazy_static! { static ref F: Callable, Expr) -> Expr> = Callable::, Expr) -> Expr>::new_static(|p, n| { @@ -402,8 +405,8 @@ pub fn offset_ray_origin(p: impl AsExpr, n: impl AsExpr = p.as_expr(); + let n: Expr = n.as_expr(); F.call(p, n) } pub type Index = [u32; 3]; diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 30fa142..e1be3b4 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -703,21 +703,15 @@ impl<'a> Scope<'a> { self } #[inline] - fn command_list(&self) -> CommandList<'a> { - CommandList::<'a> { - marker: PhantomData {}, - commands: Vec::new(), - } - } - #[inline] - pub fn submit(&self, commands: impl IntoIterator>) -> &Self { + pub fn submit<'cmd>(&self, commands: impl IntoIterator>) -> &Self { self.submit_with_callback(commands, || {}) } - fn submit_impl(&self, commands: Vec>, callback: F) { + fn submit_impl<'cmd, F: FnOnce() + Send + 'static>( + &self, + commands: Vec>, + callback: F, + ) { self.synchronized.set(false); - let mut command_buffer = self.command_list(); - command_buffer.extend(commands); - let commands = command_buffer.commands; let api_commands = commands.iter().map(|c| c.inner).collect::>(); let ctx = CommandCallbackCtx { commands, @@ -725,7 +719,7 @@ impl<'a> Scope<'a> { }; let ptr = Box::into_raw(Box::new(ctx)); extern "C" fn trampoline<'a, F: FnOnce() + Send + 'static>(ptr: *mut u8) { - let ctx = unsafe { *Box::from_raw(ptr as *mut CommandCallbackCtx<'a, F>) }; + let ctx = unsafe { *Box::from_raw(ptr as *mut CommandCallbackCtx<'static, 'a, F>) }; (ctx.f)(); } self.handle.device().dispatch( @@ -735,9 +729,9 @@ impl<'a> Scope<'a> { ) } #[inline] - pub fn submit_with_callback( + pub fn submit_with_callback<'cmd, F: FnOnce() + Send + 'static>( &self, - commands: impl IntoIterator>, + commands: impl IntoIterator>, callback: F, ) -> &Self { let mut iter = commands.into_iter(); @@ -847,28 +841,17 @@ impl Stream { } } -pub(crate) struct CommandList<'a> { - marker: PhantomData<&'a ()>, - commands: Vec>, -} - -struct CommandCallbackCtx<'a, F: FnOnce() + Send + 'static> { +struct CommandCallbackCtx<'cmd, 'scope, F: FnOnce() + Send + 'static> { #[allow(dead_code)] - commands: Vec>, + commands: Vec>, f: F, } -impl<'a> CommandList<'a> { - pub fn extend>>(&mut self, commands: I) { - self.commands.extend(commands); - } - #[allow(dead_code)] - pub fn push(&mut self, command: Command<'a>) { - self.commands.push(command); - } -} - -pub fn submit_default_stream_and_sync<'a, I: IntoIterator>>( +pub fn submit_default_stream_and_sync< + 'cmd, + 'scope, + I: IntoIterator>, +>( device: &Device, commands: I, ) { @@ -879,17 +862,31 @@ pub fn submit_default_stream_and_sync<'a, I: IntoIterator>>( }) } +/// 'from_data is the lifetime of the data that is copied from +/// 'to_data is the lifetime of the data that is copied to. It is also the lifetime of [`Scope<'a>`] +/// Commands are created by resources and submitted to a [`Scope<'a>`] via `scope.submit` and `scope.submit_with_callback`. #[must_use] -pub struct Command<'a> { +pub struct Command<'from_data, 'to_data> { #[allow(dead_code)] pub(crate) inner: api::Command, // is this really necessary? - pub(crate) marker: PhantomData<&'a ()>, + pub(crate) marker: PhantomData<(&'from_data (), &'to_data ())>, pub(crate) callback: Option>, #[allow(dead_code)] pub(crate) resource_tracker: ResourceTracker, } +impl<'cmd, 'scope> Command<'cmd, 'scope> { + pub unsafe fn lift(self) -> Command<'static, 'static> { + Command { + inner: self.inner, + marker: PhantomData {}, + callback: self.callback, + resource_tracker: self.resource_tracker, + } + } +} + pub(crate) struct AsyncShaderArtifact { shader: Option, // strange naming, huh? @@ -1175,7 +1172,7 @@ impl RawKernel { self: &Arc, args: KernelArgEncoder, dispatch_size: [u32; 3], - ) -> Command<'static> { + ) -> Command<'static, 'static> { let mut rt = ResourceTracker::new(); rt.add(Arc::new(args.uniform_data)); rt.add(self.clone()); @@ -1458,7 +1455,7 @@ macro_rules! impl_dispatch_for_kernel { pub fn dispatch_async( &self, dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg<$Ts>),* - ) -> Command<'static> { + ) -> Command<'static, 'static> { let mut encoder = KernelArgEncoder::new(); $($Ts.encode(&mut encoder);)* self.inner.dispatch_async(encoder, dispatch_size) diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index fbb048b..bcc49cd 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit fbb048bd662bab74f373935af495f23e655cab48 +Subproject commit bcc49cdeb6004b62cd7cfc68010902e47682b5af