From dc0f7290760ec7536560f56ad771c56c5ae713ac Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sun, 24 Sep 2023 10:00:22 -0400 Subject: [PATCH] BIG HACK here: redirect Buffer -> ByteBuffer; revert if causes trouble --- luisa_compute/src/resource.rs | 460 +++++++++++++++------------- luisa_compute/src/runtime.rs | 37 +-- luisa_compute/src/runtime/kernel.rs | 45 +-- luisa_compute/tests/autodiff.rs | 6 + luisa_compute/tests/misc.rs | 16 + luisa_compute_sys/LuisaCompute | 2 +- luisa_compute_track/src/lib.rs | 17 +- 7 files changed, 332 insertions(+), 251 deletions(-) diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index ae822f9..08a579f 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -14,196 +14,234 @@ use crate::runtime::*; use api::{BufferDownloadCommand, BufferUploadCommand, INVALID_RESOURCE_HANDLE}; use libc::c_void; -pub struct ByteBuffer { - pub(crate) device: Device, - pub(crate) handle: Arc, - pub(crate) len: usize, -} -impl ByteBuffer { - pub fn len(&self) -> usize { - self.len - } - #[inline] - pub fn handle(&self) -> api::Buffer { - self.handle.handle - } - #[inline] - pub fn native_handle(&self) -> *mut c_void { - self.handle.native_handle - } - #[inline] - pub fn copy_from(&self, data: &[u8]) { - self.view(..).copy_from(data); - } - #[inline] - pub fn copy_from_async<'a>(&self, data: &[u8]) -> Command<'_> { - self.view(..).copy_from_async(data) - } - #[inline] - pub fn copy_to(&self, data: &mut [u8]) { - self.view(..).copy_to(data); - } - #[inline] - pub fn copy_to_async<'a>(&self, data: &'a mut [u8]) -> Command<'a> { - self.view(..).copy_to_async(data) - } - #[inline] - pub fn copy_to_vec(&self) -> Vec { - self.view(..).copy_to_vec() - } - #[inline] - pub fn copy_to_buffer(&self, dst: &ByteBuffer) { - self.view(..).copy_to_buffer(dst.view(..)); - } - #[inline] - pub fn copy_to_buffer_async<'a>(&'a self, dst: &'a ByteBuffer) -> Command<'a> { - self.view(..).copy_to_buffer_async(dst.view(..)) - } - #[inline] - pub fn fill_fn u8>(&self, f: F) { - self.view(..).fill_fn(f); - } - #[inline] - pub fn fill(&self, value: u8) { - self.view(..).fill(value); - } - pub fn view>(&self, range: S) -> ByteBufferView<'_> { - let lower = range.start_bound(); - let upper = range.end_bound(); - let lower = match lower { - std::ops::Bound::Included(&x) => x, - std::ops::Bound::Excluded(&x) => x + 1, - std::ops::Bound::Unbounded => 0, - }; - let upper = match upper { - std::ops::Bound::Included(&x) => x + 1, - std::ops::Bound::Excluded(&x) => x, - std::ops::Bound::Unbounded => self.len, - }; - assert!(lower <= upper); - assert!(upper <= self.len); - ByteBufferView { - buffer: self, - offset: lower, - len: upper - lower, - } - } - pub fn var(&self) -> ByteBufferVar { - ByteBufferVar::new(&self.view(..)) - } -} -pub struct ByteBufferView<'a> { - pub(crate) buffer: &'a ByteBuffer, - pub(crate) offset: usize, - pub(crate) len: usize, -} -impl<'a> ByteBufferView<'a> { - pub fn handle(&self) -> api::Buffer { - self.buffer.handle() - } - pub fn copy_to_async<'b>(&'a self, data: &'b mut [u8]) -> Command<'b> { - assert_eq!(data.len(), self.len); - let mut rt = ResourceTracker::new(); - rt.add(self.buffer.handle.clone()); - Command { - inner: api::Command::BufferDownload(BufferDownloadCommand { - buffer: self.handle(), - offset: self.offset, - size: data.len(), - data: data.as_mut_ptr() as *mut u8, - }), - marker: PhantomData, - resource_tracker: rt, - callback: None, - } - } - pub fn copy_to_vec(&self) -> Vec { - let mut data = Vec::with_capacity(self.len); - unsafe { - let slice = std::slice::from_raw_parts_mut(data.as_mut_ptr(), self.len); - self.copy_to(slice); - data.set_len(self.len); - } - data - } - pub fn copy_to(&self, data: &mut [u8]) { - unsafe { - submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_async(data)]); - } - } - pub fn copy_from_async<'b>(&'a self, data: &'b [u8]) -> Command<'static> { - assert_eq!(data.len(), self.len); - let mut rt = ResourceTracker::new(); - rt.add(self.buffer.handle.clone()); - Command { - inner: api::Command::BufferUpload(BufferUploadCommand { - buffer: self.handle(), - offset: self.offset, - size: data.len(), - data: data.as_ptr() as *const u8, - }), - marker: PhantomData, - resource_tracker: rt, - callback: None, - } - } - pub fn copy_from(&self, data: &[u8]) { - submit_default_stream_and_sync(&self.buffer.device, [self.copy_from_async(data)]); - } - pub fn fill_fn u8>(&self, f: F) { - self.copy_from(&(0..self.len).map(f).collect::>()); - } - pub fn fill(&self, value: u8) { - self.fill_fn(|_| value); - } - pub fn copy_to_buffer_async(&self, dst: ByteBufferView<'a>) -> Command<'static> { - assert_eq!(self.len, dst.len); - let mut rt = ResourceTracker::new(); - rt.add(self.buffer.handle.clone()); - rt.add(dst.buffer.handle.clone()); - Command { - inner: api::Command::BufferCopy(api::BufferCopyCommand { - src: self.handle(), - src_offset: self.offset, - dst: dst.handle(), - dst_offset: dst.offset, - size: self.len, - }), - marker: PhantomData, - resource_tracker: rt, - callback: None, - } - } - pub fn copy_to_buffer(&self, dst: ByteBufferView<'a>) { - submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(dst)]); - } -} -#[derive(Clone)] -pub struct ByteBufferVar { - #[allow(dead_code)] - pub(crate) handle: Option>, - pub(crate) node: NodeRef, -} -impl ByteBufferVar { - pub fn new(buffer: &ByteBufferView<'_>) -> Self { - let node = RECORDER.with(|r| { - let mut r = r.borrow_mut(); - assert!(r.lock, "BufferVar must be created from within a kernel"); - let binding = Binding::Buffer(BufferBinding { - handle: buffer.handle().0, - size: buffer.len, - offset: buffer.offset as u64, - }); - r.capture_or_get(binding, &buffer.buffer.handle, || { - Node::new(CArc::new(Instruction::Buffer), Type::void()) - }) - }); - Self { - node, - handle: Some(buffer.buffer.handle.clone()), - } - } +pub type ByteBuffer = Buffer; +pub type ByteBufferView<'a> = BufferView<'a, u8>; +pub type ByteBufferVar = BufferVar; + +// Uncomment if the alias blowup again... +// pub struct ByteBuffer { +// pub(crate) device: Device, +// pub(crate) handle: Arc, +// pub(crate) len: usize, +// } +// impl ByteBuffer { +// pub fn len(&self) -> usize { +// self.len +// } +// #[inline] +// pub fn handle(&self) -> api::Buffer { +// self.handle.handle +// } +// #[inline] +// pub fn native_handle(&self) -> *mut c_void { +// self.handle.native_handle +// } +// #[inline] +// pub fn copy_from(&self, data: &[u8]) { +// self.view(..).copy_from(data); +// } +// #[inline] +// pub fn copy_from_async<'a>(&self, data: &[u8]) -> Command<'_> { +// self.view(..).copy_from_async(data) +// } +// #[inline] +// pub fn copy_to(&self, data: &mut [u8]) { +// self.view(..).copy_to(data); +// } +// #[inline] +// pub fn copy_to_async<'a>(&self, data: &'a mut [u8]) -> Command<'a> { +// self.view(..).copy_to_async(data) +// } +// #[inline] +// pub fn copy_to_vec(&self) -> Vec { +// self.view(..).copy_to_vec() +// } +// #[inline] +// pub fn copy_to_buffer(&self, dst: &ByteBuffer) { +// self.view(..).copy_to_buffer(dst.view(..)); +// } +// #[inline] +// pub fn copy_to_buffer_async<'a>(&'a self, dst: &'a ByteBuffer) -> Command<'a> { +// self.view(..).copy_to_buffer_async(dst.view(..)) +// } +// #[inline] +// pub fn fill_fn u8>(&self, f: F) { +// self.view(..).fill_fn(f); +// } +// #[inline] +// pub fn fill(&self, value: u8) { +// self.view(..).fill(value); +// } +// pub fn view>(&self, range: S) -> ByteBufferView<'_> { +// let lower = range.start_bound(); +// let upper = range.end_bound(); +// let lower = match lower { +// std::ops::Bound::Included(&x) => x, +// std::ops::Bound::Excluded(&x) => x + 1, +// std::ops::Bound::Unbounded => 0, +// }; +// let upper = match upper { +// std::ops::Bound::Included(&x) => x + 1, +// std::ops::Bound::Excluded(&x) => x, +// std::ops::Bound::Unbounded => self.len, +// }; +// assert!(lower <= upper); +// assert!(upper <= self.len); +// ByteBufferView { +// buffer: self, +// offset: lower, +// len: upper - lower, +// } +// } +// pub fn var(&self) -> ByteBufferVar { +// ByteBufferVar::new(&self.view(..)) +// } +// } +// pub struct ByteBufferView<'a> { +// pub(crate) buffer: &'a ByteBuffer, +// pub(crate) offset: usize, +// pub(crate) len: usize, +// } +// impl<'a> ByteBufferView<'a> { +// pub fn handle(&self) -> api::Buffer { +// self.buffer.handle() +// } +// pub fn copy_to_async<'b>(&'a self, data: &'b mut [u8]) -> Command<'b> { +// assert_eq!(data.len(), self.len); +// let mut rt = ResourceTracker::new(); +// rt.add(self.buffer.handle.clone()); +// Command { +// inner: api::Command::BufferDownload(BufferDownloadCommand { +// buffer: self.handle(), +// offset: self.offset, +// size: data.len(), +// data: data.as_mut_ptr() as *mut u8, +// }), +// marker: PhantomData, +// resource_tracker: rt, +// callback: None, +// } +// } +// pub fn copy_to_vec(&self) -> Vec { +// let mut data = Vec::with_capacity(self.len); +// unsafe { +// let slice = std::slice::from_raw_parts_mut(data.as_mut_ptr(), self.len); +// self.copy_to(slice); +// data.set_len(self.len); +// } +// data +// } +// pub fn copy_to(&self, data: &mut [u8]) { +// unsafe { +// submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_async(data)]); +// } +// } + +// pub fn copy_from_async<'b>(&'a self, data: &'b [u8]) -> Command<'static> { +// assert_eq!(data.len(), self.len); +// let mut rt = ResourceTracker::new(); +// rt.add(self.buffer.handle.clone()); +// Command { +// inner: api::Command::BufferUpload(BufferUploadCommand { +// buffer: self.handle(), +// offset: self.offset, +// size: data.len(), +// data: data.as_ptr() as *const u8, +// }), +// marker: PhantomData, +// resource_tracker: rt, +// callback: None, +// } +// } +// pub fn copy_from(&self, data: &[u8]) { +// submit_default_stream_and_sync(&self.buffer.device, [self.copy_from_async(data)]); +// } +// pub fn fill_fn u8>(&self, f: F) { +// self.copy_from(&(0..self.len).map(f).collect::>()); +// } +// pub fn fill(&self, value: u8) { +// self.fill_fn(|_| value); +// } +// pub fn copy_to_buffer_async(&self, dst: ByteBufferView<'a>) -> Command<'static> { +// assert_eq!(self.len, dst.len); +// let mut rt = ResourceTracker::new(); +// rt.add(self.buffer.handle.clone()); +// rt.add(dst.buffer.handle.clone()); +// Command { +// inner: api::Command::BufferCopy(api::BufferCopyCommand { +// src: self.handle(), +// src_offset: self.offset, +// dst: dst.handle(), +// dst_offset: dst.offset, +// size: self.len, +// }), +// marker: PhantomData, +// resource_tracker: rt, +// callback: None, +// } +// } +// pub fn copy_to_buffer(&self, dst: ByteBufferView<'a>) { +// submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(dst)]); +// } +// } +// #[derive(Clone)] +// pub struct ByteBufferVar { +// #[allow(dead_code)] +// pub(crate) handle: Option>, +// pub(crate) node: NodeRef, +// } +// impl ByteBufferVar { +// pub fn new(buffer: &ByteBufferView<'_>) -> Self { +// let node = RECORDER.with(|r| { +// let mut r = r.borrow_mut(); +// assert!(r.lock, "BufferVar must be created from within a kernel"); +// let binding = Binding::Buffer(BufferBinding { +// handle: buffer.handle().0, +// size: buffer.len, +// offset: buffer.offset as u64, +// }); +// r.capture_or_get(binding, &buffer.buffer.handle, || { +// Node::new(CArc::new(Instruction::Buffer), Type::void()) +// }) +// }); +// Self { +// node, +// handle: Some(buffer.buffer.handle.clone()), +// } +// } +// pub unsafe fn read_as(&self, index_bytes: impl IntoIndex) -> Expr { +// let i = index_bytes.to_u64(); +// Expr::::from_node(__current_scope(|b| { +// b.call( +// Func::ByteBufferRead, +// &[self.node, i.node], +// ::type_(), +// ) +// })) +// } +// pub fn len_bytes_expr(&self) -> Expr { +// Expr::::from_node(__current_scope(|b| { +// b.call(Func::ByteBufferSize, &[self.node], ::type_()) +// })) +// } +// pub unsafe fn write_as( +// &self, +// index_bytes: impl IntoIndex, +// value: impl Into>, +// ) { +// let i = index_bytes.to_u64(); +// let value: Expr = value.into(); +// __current_scope(|b| { +// b.call( +// Func::ByteBufferWrite, +// &[self.node, i.node, value.node()], +// Type::void(), +// ) +// }); +// } +// } +impl BufferVar { pub unsafe fn read_as(&self, index_bytes: impl IntoIndex) -> Expr { let i = index_bytes.to_u64(); Expr::::from_node(__current_scope(|b| { @@ -235,7 +273,6 @@ impl ByteBufferVar { }); } } - pub struct Buffer { pub(crate) device: Device, pub(crate) handle: Arc, @@ -569,24 +606,31 @@ impl BindlessArray { index: usize, bufferview: &ByteBufferView<'a>, ) { - self.lock(); - self.modifications - .borrow_mut() - .push(api::BindlessArrayUpdateModification { - slot: index, - buffer: api::BindlessArrayUpdateBuffer { - op: api::BindlessArrayUpdateOperation::Emplace, - handle: bufferview.handle(), - offset: bufferview.offset, - }, - tex2d: api::BindlessArrayUpdateTexture::default(), - tex3d: api::BindlessArrayUpdateTexture::default(), - }); - self.make_pending_slots(); - let mut pending = self.pending_slots.borrow_mut(); - pending[index].buffer = Some(bufferview.buffer.handle.clone()); - self.unlock(); - } + self.emplace_buffer_view_async(index, bufferview) + } + // pub fn emplace_byte_buffer_view_async<'a>( + // &self, + // index: usize, + // bufferview: &ByteBufferView<'a>, + // ) { + // self.lock(); + // self.modifications + // .borrow_mut() + // .push(api::BindlessArrayUpdateModification { + // slot: index, + // buffer: api::BindlessArrayUpdateBuffer { + // op: api::BindlessArrayUpdateOperation::Emplace, + // handle: bufferview.handle(), + // offset: bufferview.offset, + // }, + // tex2d: api::BindlessArrayUpdateTexture::default(), + // tex3d: api::BindlessArrayUpdateTexture::default(), + // }); + // self.make_pending_slots(); + // let mut pending = self.pending_slots.borrow_mut(); + // pending[index].buffer = Some(bufferview.buffer.handle.clone()); + // self.unlock(); + // } pub fn emplace_buffer_async(&self, index: usize, buffer: &Buffer) { self.emplace_buffer_view_async(index, &buffer.view(..)) } diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index ba5bf16..c550134 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -174,9 +174,9 @@ impl Device { }; swapchain } - pub fn create_byte_buffer(&self, len: usize) -> ByteBuffer { + pub fn create_byte_buffer(&self, len: usize) -> Buffer { let buffer = self.inner.create_buffer(&Type::void(), len); - let buffer = ByteBuffer { + let buffer = Buffer { device: self.clone(), handle: Arc::new(BufferHandle { device: self.clone(), @@ -184,6 +184,7 @@ impl Device { native_handle: buffer.resource.native_handle, }), len, + _marker: PhantomData, }; buffer } @@ -1060,18 +1061,18 @@ impl KernelArg for Buffer { encoder.buffer(self); } } -impl KernelArg for ByteBuffer { - type Parameter = ByteBufferVar; - fn encode(&self, encoder: &mut KernelArgEncoder) { - encoder.byte_buffer(self); - } -} -impl<'a> KernelArg for ByteBufferView<'a> { - type Parameter = ByteBufferVar; - fn encode(&self, encoder: &mut KernelArgEncoder) { - encoder.byte_buffer_view(self); - } -} +// impl KernelArg for ByteBuffer { +// type Parameter = ByteBufferVar; +// fn encode(&self, encoder: &mut KernelArgEncoder) { +// encoder.byte_buffer(self); +// } +// } +// impl<'a> KernelArg for ByteBufferView<'a> { +// type Parameter = ByteBufferVar; +// fn encode(&self, encoder: &mut KernelArgEncoder) { +// encoder.byte_buffer_view(self); +// } +// } impl KernelArg for T { type Parameter = Expr; fn encode(&self, encoder: &mut KernelArgEncoder) { @@ -1367,13 +1368,13 @@ impl<'a, T: Value> AsKernelArg> for BufferView<'a, T> {} impl<'a, T: Value> AsKernelArg> for Buffer {} -impl AsKernelArg for ByteBuffer {} +// impl AsKernelArg for ByteBuffer {} -impl<'a> AsKernelArg for ByteBufferView<'a> {} +// impl<'a> AsKernelArg for ByteBufferView<'a> {} -impl<'a> AsKernelArg> for ByteBufferView<'a> {} +// impl<'a> AsKernelArg> for ByteBufferView<'a> {} -impl<'a> AsKernelArg> for ByteBuffer {} +// impl<'a> AsKernelArg> for ByteBuffer {} impl<'a, T: IoTexel> AsKernelArg> for Tex2dView<'a, T> {} diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 7b0c6c5..a39d9f3 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -70,14 +70,14 @@ impl CallableParameter for BufferVar { encoder.buffer(self) } } -impl CallableParameter for ByteBufferVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.byte_buffer() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.byte_buffer(self) - } -} +// impl CallableParameter for ByteBufferVar { +// fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { +// builder.byte_buffer() +// } +// fn encode(&self, encoder: &mut CallableArgEncoder) { +// encoder.byte_buffer(self) +// } +// } impl CallableParameter for Tex2dVar { fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { builder.tex2d() @@ -129,12 +129,13 @@ impl KernelParameter for BufferVar { builder.buffer() } } -impl KernelParameter for ByteBufferVar { - type Arg = ByteBuffer; - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.byte_buffer() - } -} + +// impl KernelParameter for ByteBufferVar { +// type Arg = ByteBuffer; +// fn def_param(builder: &mut KernelBuilder) -> Self { +// builder.byte_buffer() +// } +// } impl KernelParameter for Tex2dVar { type Arg = Tex2d; @@ -223,14 +224,14 @@ impl KernelBuilder { self.args.push(node); FromNode::from_node(node) } - pub fn byte_buffer(&mut self) -> ByteBufferVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Buffer), Type::void()), - ); - self.args.push(node); - ByteBufferVar { node, handle: None } - } + // pub fn byte_buffer(&mut self) -> ByteBufferVar { + // let node = new_node( + // __module_pools(), + // Node::new(CArc::new(Instruction::Buffer), Type::void()), + // ); + // self.args.push(node); + // ByteBufferVar { node, handle: None } + // } pub fn buffer(&mut self) -> BufferVar { let node = new_node( __module_pools(), diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs index 8fe8dd1..5f272fa 100644 --- a/luisa_compute/tests/autodiff.rs +++ b/luisa_compute/tests/autodiff.rs @@ -37,6 +37,12 @@ fn autodiff_helper]) -> Expr>( f: F, ) { let device = get_device(); + if device.name() == "dx" { + // DX has limit on writable buffers + if n_inputs > 8 { + return; + } + } let inputs = (0..n_inputs) .map(|_| device.create_buffer::(repeats)) .collect::>(); diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index 8cd3cbf..6558625 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -880,6 +880,22 @@ struct Big { a: [f32; 32], } #[test] +fn buffer_u8() { + let device = get_device(); + if device.name() == "dx" { + return; + } + let buf = device.create_buffer::(1024); + let kernel = Kernel::::new( + &device, + &track!(|| { + let tid = dispatch_id().x; + buf.write(tid, (tid & 0xff).as_u8()); + }), + ); + kernel.dispatch([1024, 1, 1]); +} +#[test] fn byte_buffer() { let device = get_device(); let buf = device.create_byte_buffer(1024); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 8283e71..1961bee 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 8283e71d4f9fcede23b16b6c65ffc44dc480ff24 +Subproject commit 1961bee345934a04bdfc668441d06d108cd8de66 diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs index 4f130a3..93fca1b 100644 --- a/luisa_compute_track/src/lib.rs +++ b/luisa_compute_track/src/lib.rs @@ -75,6 +75,7 @@ impl VisitMut for TraceVisitor { let flow_path = &self.flow_path; let trait_path = &self.trait_path; let span = node.span(); + match node { Expr::Assign(expr) => { let left = &expr.left; @@ -126,9 +127,21 @@ impl VisitMut for TraceVisitor { let body = &expr.body; let expr = &expr.expr; if let Expr::Range(range) = &**expr { - *node = parse_quote_spanned! {span=> - #flow_path::for_range(#range, |#pat| #body) + let attrs = &range.attrs; + // check if #[unroll] is present + let unroll = attrs.iter().any(|attr| { + attr.path().is_ident("unroll") + }); + if unroll { + *node = parse_quote_spanned! {span=> + #range.for_each(|#pat| #body) + } + } else { + *node = parse_quote_spanned! {span=> + #flow_path::for_range(#range, |#pat| #body) + } } + } } // Expr::Unary(op) => {