diff --git a/luisa_compute/examples/vecadd.rs b/luisa_compute/examples/vecadd.rs index 0c1b0bb2..448023d5 100644 --- a/luisa_compute/examples/vecadd.rs +++ b/luisa_compute/examples/vecadd.rs @@ -24,7 +24,6 @@ fn main() { let z = device.create_buffer::(1024); x.view(..).fill_fn(|i| i as f32); y.view(..).fill_fn(|i| 1000.0 * i as f32); - let a = RefCell::new(1.0f32); let kernel = device.create_kernel::)>(track!(&|buf_z| { // z is pass by arg let buf_x = x.var(); // x and y are captured diff --git a/luisa_compute/src/printer.rs b/luisa_compute/src/printer.rs index cde50f67..eab2eb67 100644 --- a/luisa_compute/src/printer.rs +++ b/luisa_compute/src/printer.rs @@ -130,8 +130,8 @@ impl Printer { if_!( offset - .lt(data.len().cast::()) - .bitand((offset.add(1 + args.count as u32)).le(data.len().cast::())), + .lt(data.len_expr().cast::()) + .bitand((offset.add(1 + args.count as u32)).le(data.len_expr().cast::())), { data.atomic_fetch_add(0, 1); data.write(offset, item_id); diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 5288036e..386f1017 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -15,206 +15,12 @@ use crate::runtime::*; use api::{BufferDownloadCommand, BufferUploadCommand, INVALID_RESOURCE_HANDLE}; use libc::c_void; -pub struct ByteBuffer { - pub(crate) device: Device, - pub(crate) handle: Arc, - pub(crate) len: usize, -} -impl ByteBuffer { - pub fn len(&self) -> usize { - self.len - } - #[inline] - pub fn handle(&self) -> api::Buffer { - self.handle.handle - } - #[inline] - pub fn native_handle(&self) -> *mut c_void { - self.handle.native_handle - } - #[inline] - pub fn copy_from(&self, data: &[u8]) { - self.view(..).copy_from(data); - } - #[inline] - pub fn copy_from_async<'a>(&self, data: &[u8]) -> Command<'_> { - self.view(..).copy_from_async(data) - } - #[inline] - pub fn copy_to(&self, data: &mut [u8]) { - self.view(..).copy_to(data); - } - #[inline] - pub fn copy_to_async<'a>(&self, data: &'a mut [u8]) -> Command<'a> { - self.view(..).copy_to_async(data) - } - #[inline] - pub fn copy_to_vec(&self) -> Vec { - self.view(..).copy_to_vec() - } - #[inline] - pub fn copy_to_buffer(&self, dst: &ByteBuffer) { - self.view(..).copy_to_buffer(dst.view(..)); - } - #[inline] - pub fn copy_to_buffer_async<'a>(&'a self, dst: &'a ByteBuffer) -> Command<'a> { - self.view(..).copy_to_buffer_async(dst.view(..)) - } - #[inline] - pub fn fill_fn u8>(&self, f: F) { - self.view(..).fill_fn(f); - } - #[inline] - pub fn fill(&self, value: u8) { - self.view(..).fill(value); - } - pub fn view>(&self, range: S) -> ByteBufferView<'_> { - let lower = range.start_bound(); - let upper = range.end_bound(); - let lower = match lower { - std::ops::Bound::Included(&x) => x, - std::ops::Bound::Excluded(&x) => x + 1, - std::ops::Bound::Unbounded => 0, - }; - let upper = match upper { - std::ops::Bound::Included(&x) => x + 1, - std::ops::Bound::Excluded(&x) => x, - std::ops::Bound::Unbounded => self.len, - }; - assert!(lower <= upper); - assert!(upper <= self.len); - ByteBufferView { - buffer: self, - offset: lower, - len: upper - lower, - } - } - pub fn var(&self) -> ByteBufferVar { - ByteBufferVar::new(&self.view(..)) - } -} -pub struct ByteBufferView<'a> { - pub(crate) buffer: &'a ByteBuffer, - pub(crate) offset: usize, - pub(crate) len: usize, -} -impl<'a> ByteBufferView<'a> { - pub fn handle(&self) -> api::Buffer { - self.buffer.handle() - } - pub fn copy_to_async<'b>(&'a self, data: &'b mut [u8]) -> Command<'b> { - assert_eq!(data.len(), self.len); - let mut rt = ResourceTracker::new(); - rt.add(self.buffer.handle.clone()); - Command { - inner: api::Command::BufferDownload(BufferDownloadCommand { - buffer: self.handle(), - offset: self.offset, - size: data.len(), - data: data.as_mut_ptr() as *mut u8, - }), - marker: PhantomData, - resource_tracker: rt, - callback: None, - } - } - pub fn copy_to_vec(&self) -> Vec { - let mut data = Vec::with_capacity(self.len); - unsafe { - let slice = std::slice::from_raw_parts_mut(data.as_mut_ptr(), self.len); - self.copy_to(slice); - data.set_len(self.len); - } - data - } - pub fn copy_to(&self, data: &mut [u8]) { - unsafe { - submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_async(data)]); - } - } +pub type ByteBuffer = Buffer; +pub type ByteBufferView<'a> = BufferView<'a, u8>; +pub type ByteBufferVar = BufferVar; - pub fn copy_from_async<'b>(&'a self, data: &'b [u8]) -> Command<'static> { - assert_eq!(data.len(), self.len); - let mut rt = ResourceTracker::new(); - rt.add(self.buffer.handle.clone()); - Command { - inner: api::Command::BufferUpload(BufferUploadCommand { - buffer: self.handle(), - offset: self.offset, - size: data.len(), - data: data.as_ptr() as *const u8, - }), - marker: PhantomData, - resource_tracker: rt, - callback: None, - } - } - pub fn copy_from(&self, data: &[u8]) { - submit_default_stream_and_sync(&self.buffer.device, [self.copy_from_async(data)]); - } - pub fn fill_fn u8>(&self, f: F) { - self.copy_from(&(0..self.len).map(f).collect::>()); - } - pub fn fill(&self, value: u8) { - self.fill_fn(|_| value); - } - pub fn copy_to_buffer_async(&self, dst: ByteBufferView<'a>) -> Command<'static> { - assert_eq!(self.len, dst.len); - let mut rt = ResourceTracker::new(); - rt.add(self.buffer.handle.clone()); - rt.add(dst.buffer.handle.clone()); - Command { - inner: api::Command::BufferCopy(api::BufferCopyCommand { - src: self.handle(), - src_offset: self.offset, - dst: dst.handle(), - dst_offset: dst.offset, - size: self.len, - }), - marker: PhantomData, - resource_tracker: rt, - callback: None, - } - } - pub fn copy_to_buffer(&self, dst: ByteBufferView<'a>) { - submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(dst)]); - } -} -#[derive(Clone)] -pub struct ByteBufferVar { - #[allow(dead_code)] - pub(crate) handle: Option>, - pub(crate) node: NodeRef, -} impl ByteBufferVar { - pub fn new(buffer: &ByteBufferView<'_>) -> Self { - let node = RECORDER.with(|r| { - let mut r = r.borrow_mut(); - assert!(r.lock, "BufferVar must be created from within a kernel"); - let binding = Binding::Buffer(BufferBinding { - handle: buffer.handle().0, - size: buffer.len, - offset: buffer.offset as u64, - }); - if let Some((_, node, _, _)) = r.captured_buffer.get(&binding) { - *node - } else { - let node = new_node( - r.pools.as_ref().unwrap(), - Node::new(CArc::new(Instruction::Buffer), Type::void()), - ); - let i = r.captured_buffer.len(); - r.captured_buffer - .insert(binding, (i, node, binding, buffer.buffer.handle.clone())); - node - } - }); - Self { - node, - handle: Some(buffer.buffer.handle.clone()), - } - } - pub fn read(&self, index_bytes: impl IntoIndex) -> Expr { + pub unsafe fn read_as(&self, index_bytes: impl IntoIndex) -> Expr { let i = index_bytes.to_u64(); Expr::::from_node(__current_scope(|b| { b.call( @@ -224,12 +30,16 @@ impl ByteBufferVar { ) })) } - pub fn len(&self) -> Expr { + pub fn len_bytes_expr(&self) -> Expr { Expr::::from_node(__current_scope(|b| { b.call(Func::ByteBufferSize, &[self.node], ::type_()) })) } - pub fn write(&self, index_bytes: impl IntoIndex, value: impl Into>) { + pub unsafe fn write_as( + &self, + index_bytes: impl IntoIndex, + value: impl Into>, + ) { let i = index_bytes.to_u64(); let value: Expr = value.into(); __current_scope(|b| { @@ -241,11 +51,14 @@ impl ByteBufferVar { }); } } + pub struct Buffer { pub(crate) device: Device, pub(crate) handle: Arc, pub(crate) len: usize, pub(crate) _marker: PhantomData, + // big hack here + pub(crate) _is_byte_buffer: bool, } pub(crate) struct BufferHandle { pub(crate) device: Device, @@ -363,6 +176,7 @@ impl Buffer { handle: self.handle.clone(), len: self.len, _marker: PhantomData, + _is_byte_buffer: self._is_byte_buffer, } } #[inline] @@ -574,23 +388,7 @@ impl BindlessArray { index: usize, bufferview: &ByteBufferView<'a>, ) { - self.lock(); - self.modifications - .borrow_mut() - .push(api::BindlessArrayUpdateModification { - slot: index, - buffer: api::BindlessArrayUpdateBuffer { - op: api::BindlessArrayUpdateOperation::Emplace, - handle: bufferview.handle(), - offset: bufferview.offset, - }, - tex2d: api::BindlessArrayUpdateTexture::default(), - tex3d: api::BindlessArrayUpdateTexture::default(), - }); - self.make_pending_slots(); - let mut pending = self.pending_slots.borrow_mut(); - pending[index].buffer = Some(bufferview.buffer.handle.clone()); - self.unlock(); + self.emplace_buffer_view_async(index, bufferview) } pub fn emplace_buffer_async(&self, index: usize, buffer: &Buffer) { self.emplace_buffer_view_async(index, &buffer.view(..)) @@ -1346,7 +1144,7 @@ impl IndexRead for BindlessBufferVar { fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::::from_node(__current_scope(|b| { @@ -1359,7 +1157,7 @@ impl IndexRead for BindlessBufferVar { } } impl BindlessBufferVar { - pub fn len(&self) -> Expr { + pub fn len_expr(&self) -> Expr { let stride = (T::type_().size() as u64).expr(); Expr::::from_node(__current_scope(|b| { b.call( @@ -1390,7 +1188,7 @@ impl ToNode for BindlessByteBufferVar { } } impl BindlessByteBufferVar { - pub fn read(&self, index_bytes: impl IntoIndex) -> Expr { + pub unsafe fn read_as(&self, index_bytes: impl IntoIndex) -> Expr { let i = index_bytes.to_u64(); Expr::::from_node(__current_scope(|b| { b.call( @@ -1694,7 +1492,7 @@ impl IndexRead for BufferVar { fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::::from_node(__current_scope(|b| { b.call(Func::BufferRead, &[self.node, ToNode::node(&i)], T::type_()) @@ -1706,7 +1504,7 @@ impl IndexWrite for BufferVar { let i = i.to_u64(); let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } __current_scope(|b| { b.call( @@ -1732,7 +1530,10 @@ impl BufferVar { } else { let node = new_node( r.pools.as_ref().unwrap(), - Node::new(CArc::new(Instruction::Buffer), T::type_()), + Node::new( + CArc::new(Instruction::Buffer), + T::type_() + ), ); let i = r.captured_buffer.len(); r.captured_buffer @@ -1752,7 +1553,7 @@ impl BufferVar { b.call_no_append(Func::AtomicRef, &[self.node, i.node()], T::type_()) })) } - pub fn len(&self) -> Expr { + pub fn len_expr(&self) -> Expr { FromNode::from_node( __current_scope(|b| b.call(Func::BufferSize, &[self.node], u64::type_())).into(), ) @@ -1770,7 +1571,7 @@ macro_rules! impl_atomic { let i = i.to_u64(); let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1794,7 +1595,7 @@ macro_rules! impl_atomic { let expected = expected.as_expr(); let desired = desired.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1812,7 +1613,7 @@ macro_rules! impl_atomic { let i = i.to_u64(); let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1830,7 +1631,7 @@ macro_rules! impl_atomic { let i = i.to_u64(); let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1848,7 +1649,7 @@ macro_rules! impl_atomic { let i = i.to_u64(); let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1866,7 +1667,7 @@ macro_rules! impl_atomic { let i = i.to_u64(); let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1890,7 +1691,7 @@ macro_rules! impl_atomic_bit { let i = i.to_u64(); let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1908,7 +1709,7 @@ macro_rules! impl_atomic_bit { let i = i.to_u64(); let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1926,7 +1727,7 @@ macro_rules! impl_atomic_bit { let i = i.to_u64(); let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.lt(self.len())); + lc_assert!(i.lt(self.len_expr())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 5593ba9b..218afa93 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -170,7 +170,7 @@ impl Device { } pub fn create_byte_buffer(&self, len: usize) -> ByteBuffer { let buffer = self.inner.create_buffer(&Type::void(), len); - let buffer = ByteBuffer { + let buffer = Buffer:: { device: self.clone(), handle: Arc::new(BufferHandle { device: self.clone(), @@ -178,6 +178,8 @@ impl Device { native_handle: buffer.resource.native_handle, }), len, + _marker: PhantomData {}, + _is_byte_buffer: true, }; buffer } @@ -200,6 +202,7 @@ impl Device { }), _marker: PhantomData {}, len: count, + _is_byte_buffer:false, }; buffer } @@ -951,9 +954,6 @@ impl CallableArgEncoder { pub fn buffer(&mut self, buffer: &BufferVar) { self.args.push(buffer.node); } - pub fn byte_buffer(&mut self, buffer: &ByteBufferVar) { - self.args.push(buffer.node); - } pub fn tex2d(&mut self, tex2d: &Tex2dVar) { self.args.push(tex2d.node); } @@ -1017,20 +1017,6 @@ impl KernelArgEncoder { size: buffer.len * std::mem::size_of::(), })); } - pub fn byte_buffer(&mut self, buffer: &ByteBuffer) { - self.args.push(api::Argument::Buffer(api::BufferArgument { - buffer: buffer.handle.handle, - offset: 0, - size: buffer.len, - })); - } - pub fn byte_buffer_view(&mut self, buffer: &ByteBufferView) { - self.args.push(api::Argument::Buffer(api::BufferArgument { - buffer: buffer.handle(), - offset: buffer.offset, - size: buffer.len, - })); - } pub fn tex2d(&mut self, tex: &Tex2dView) { self.args.push(api::Argument::Texture(api::TextureArgument { texture: tex.handle(), @@ -1063,18 +1049,6 @@ impl KernelArg for Buffer { encoder.buffer(self); } } -impl KernelArg for ByteBuffer { - type Parameter = ByteBufferVar; - fn encode(&self, encoder: &mut KernelArgEncoder) { - encoder.byte_buffer(self); - } -} -impl<'a> KernelArg for ByteBufferView<'a> { - type Parameter = ByteBufferVar; - fn encode(&self, encoder: &mut KernelArgEncoder) { - encoder.byte_buffer_view(self); - } -} impl KernelArg for T { type Parameter = Expr; fn encode(&self, encoder: &mut KernelArgEncoder) { @@ -1309,14 +1283,6 @@ impl<'a, T: Value> AsKernelArg> for BufferView<'a, T> {} impl<'a, T: Value> AsKernelArg> for Buffer {} -impl AsKernelArg for ByteBuffer {} - -impl<'a> AsKernelArg for ByteBufferView<'a> {} - -impl<'a> AsKernelArg> for ByteBufferView<'a> {} - -impl<'a> AsKernelArg> for ByteBuffer {} - impl<'a, T: IoTexel> AsKernelArg> for Tex2dView<'a, T> {} impl<'a, T: IoTexel> AsKernelArg> for Tex3dView<'a, T> {} diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index ed2f41ab..ea118b9c 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -70,14 +70,6 @@ impl CallableParameter for BufferVar { encoder.buffer(self) } } -impl CallableParameter for ByteBufferVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.byte_buffer() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.byte_buffer(self) - } -} impl CallableParameter for Tex2dVar { fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { builder.tex2d() @@ -119,11 +111,7 @@ impl KernelParameter for Expr { builder.uniform::() } } -impl KernelParameter for ByteBufferVar { - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.byte_buffer() - } -} + impl KernelParameter for BufferVar { fn def_param(builder: &mut KernelBuilder) -> Self { builder.buffer() @@ -212,14 +200,6 @@ impl KernelBuilder { self.args.push(node); FromNode::from_node(node) } - pub fn byte_buffer(&mut self) -> ByteBufferVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Buffer), Type::void()), - ); - self.args.push(node); - ByteBufferVar { node, handle: None } - } pub fn buffer(&mut self) -> BufferVar { let node = new_node( __module_pools(), diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index a0cbf139..f10b83a4 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -278,7 +278,7 @@ fn test_var_replace() { let kernel = device.create_kernel::(&track!(|| { let tid = dispatch_id().x; let x = xs.var().read(tid).var(); - *x = Int4::expr(1,2,3,4); + *x = Int4::expr(1, 2, 3, 4); let y = **x; *x.y = 10; *x.z = 20; @@ -758,26 +758,26 @@ fn byte_buffer() { let i2 = push!(i32, 0i32); let i3 = push!(f32, 1f32); device - .create_kernel::(&track!(|| { + .create_kernel::(&track!(|| unsafe { let buf = buf.var(); let i0 = i0 as u64; let i1 = i1 as u64; let i2 = i2 as u64; let i3 = i3 as u64; - let v0 = buf.read::(i0).var(); - let v1 = buf.read::(i1).var(); - let v2 = buf.read::(i2).var(); - let v3 = buf.read::(i3).var(); + let v0 = buf.read_as::(i0).var(); + let v1 = buf.read_as::(i1).var(); + let v2 = buf.read_as::(i2).var(); + let v3 = buf.read_as::(i3).var(); *v0 = Float3::expr(1.0, 2.0, 3.0); for_range(0u32..32u32, |i| { v1.a.write(i, i.as_f32() * 2.0); }); *v2 = 1i32.expr(); *v3 = 2.0.expr(); - buf.write::(i0, v0.load()); - buf.write::(i1, v1.load()); - buf.write::(i2, v2.load()); - buf.write::(i3, v3.load()); + buf.write_as::(i0, v0.load()); + buf.write_as::(i1, v1.load()); + buf.write_as::(i2, v2.load()); + buf.write_as::(i3, v3.load()); })) .dispatch([1, 1, 1]); let data = buf.copy_to_vec(); @@ -833,27 +833,27 @@ fn bindless_byte_buffer() { let i2 = push!(i32, 0i32); let i3 = push!(f32, 1f32); device - .create_kernel::(&track!(|out: ByteBufferVar| { + .create_kernel::(&track!(|out: ByteBufferVar| unsafe { let heap = heap.var(); let buf = heap.byte_address_buffer(0u32); let i0 = i0 as u64; let i1 = i1 as u64; let i2 = i2 as u64; let i3 = i3 as u64; - let v0 = buf.read::(i0).var(); - let v1 = buf.read::(i1).var(); - let v2 = buf.read::(i2).var(); - let v3 = buf.read::(i3).var(); + let v0 = buf.read_as::(i0).var(); + let v1 = buf.read_as::(i1).var(); + let v2 = buf.read_as::(i2).var(); + let v3 = buf.read_as::(i3).var(); *v0 = Float3::expr(1.0, 2.0, 3.0); for_range(0u32..32u32, |i| { v1.a.write(i, i.as_f32() * 2.0); }); *v2 = 1i32.expr(); *v3 = 2.0.expr(); - out.write::(i0, v0.load()); - out.write::(i1, v1.load()); - out.write::(i2, v2.load()); - out.write::(i3, v3.load()); + out.write_as::(i0, v0.load()); + out.write_as::(i1, v1.load()); + out.write_as::(i2, v2.load()); + out.write_as::(i3, v3.load()); })) .dispatch([1, 1, 1], &out); let data = out.copy_to_vec(); @@ -952,5 +952,4 @@ fn atomic() { } assert_eq!(foo_max, expected_foo_max); assert_eq!(foo_min, expected_foo_min); - } diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 1272ab89..d55d75f4 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 1272ab89c8cddb73189f439ef5fcd6c1737b45d2 +Subproject commit d55d75f4ce7419ecd482d2375a2a1317ddc6a9b1