From 15b8c510f61a2e8874ee22fd98f25103bfcebd01 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Sun, 24 Sep 2023 08:20:29 -0400 Subject: [PATCH] revert bytebuffer due to dx limitation --- luisa_compute/src/lang/ops/impls.rs | 23 +-- luisa_compute/src/resource.rs | 213 ++++++++++++++++++++++++++-- luisa_compute/src/runtime.rs | 49 ++++++- luisa_compute/src/runtime/kernel.rs | 22 +++ luisa_compute/tests/misc.rs | 9 ++ luisa_compute_sys/LuisaCompute | 2 +- 6 files changed, 295 insertions(+), 23 deletions(-) diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index 27ebb8b..e8799e8 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -209,7 +209,7 @@ where log10 => Log10 } fn is_finite(&self) -> Self::Bool { - !self.is_infinite().bitand(!self.is_nan()) + (!self.is_infinite()).bitand(!self.is_nan()) } fn is_infinite(&self) -> Self::Bool { Func::IsInf.call(self.clone()) @@ -233,11 +233,11 @@ where (self.sin(), self.cos()) } } -impl NormExpr for Expr +impl NormExpr for Expr> where - X::Scalar: Floating, + X: vector::VectorAlign, { - type Output = Expr; + type Output = Expr; impl_simple_fns! { Self::Output, norm => Length, @@ -268,8 +268,11 @@ impl OuterProductExpr for Expr { Func::OuterProduct.call2(self.clone(), other.as_expr()) } } -impl ReduceExpr for Expr { - type Output = Expr; +impl ReduceExpr for Expr> +where + X: vector::VectorAlign, +{ + type Output = Expr; impl_simple_fns! { Self::Output, reduce_max=>ReduceMax, @@ -278,12 +281,12 @@ impl ReduceExpr for Expr { reduce_sum=>ReduceSum } } -impl DotExpr for Expr +impl DotExpr for Expr> where - X::Scalar: Floating, + X: vector::VectorAlign, { - type Value = X; - type Output = Expr; + type Value = Vector; + type Output = Expr; fn dot(&self, other: impl AsExpr) -> Self::Output { Func::Dot.call2(self.clone(), other.as_expr()) } diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 1322076..ae822f9 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -14,12 +14,196 @@ use crate::runtime::*; use api::{BufferDownloadCommand, BufferUploadCommand, INVALID_RESOURCE_HANDLE}; use libc::c_void; +pub struct ByteBuffer { + pub(crate) device: Device, + pub(crate) handle: Arc, + pub(crate) len: usize, +} +impl ByteBuffer { + pub fn len(&self) -> usize { + self.len + } + #[inline] + pub fn handle(&self) -> api::Buffer { + self.handle.handle + } + #[inline] + pub fn native_handle(&self) -> *mut c_void { + self.handle.native_handle + } + #[inline] + pub fn copy_from(&self, data: &[u8]) { + self.view(..).copy_from(data); + } + #[inline] + pub fn copy_from_async<'a>(&self, data: &[u8]) -> Command<'_> { + self.view(..).copy_from_async(data) + } + #[inline] + pub fn copy_to(&self, data: &mut [u8]) { + self.view(..).copy_to(data); + } + #[inline] + pub fn copy_to_async<'a>(&self, data: &'a mut [u8]) -> Command<'a> { + self.view(..).copy_to_async(data) + } + #[inline] + pub fn copy_to_vec(&self) -> Vec { + self.view(..).copy_to_vec() + } + #[inline] + pub fn copy_to_buffer(&self, dst: &ByteBuffer) { + self.view(..).copy_to_buffer(dst.view(..)); + } + #[inline] + pub fn copy_to_buffer_async<'a>(&'a self, dst: &'a ByteBuffer) -> Command<'a> { + self.view(..).copy_to_buffer_async(dst.view(..)) + } + #[inline] + pub fn fill_fn u8>(&self, f: F) { + self.view(..).fill_fn(f); + } + #[inline] + pub fn fill(&self, value: u8) { + self.view(..).fill(value); + } + pub fn view>(&self, range: S) -> ByteBufferView<'_> { + let lower = range.start_bound(); + let upper = range.end_bound(); + let lower = match lower { + std::ops::Bound::Included(&x) => x, + std::ops::Bound::Excluded(&x) => x + 1, + std::ops::Bound::Unbounded => 0, + }; + let upper = match upper { + std::ops::Bound::Included(&x) => x + 1, + std::ops::Bound::Excluded(&x) => x, + std::ops::Bound::Unbounded => self.len, + }; + assert!(lower <= upper); + assert!(upper <= self.len); + ByteBufferView { + buffer: self, + offset: lower, + len: upper - lower, + } + } + pub fn var(&self) -> ByteBufferVar { + ByteBufferVar::new(&self.view(..)) + } +} +pub struct ByteBufferView<'a> { + pub(crate) buffer: &'a ByteBuffer, + pub(crate) offset: usize, + pub(crate) len: usize, +} +impl<'a> ByteBufferView<'a> { + pub fn handle(&self) -> api::Buffer { + self.buffer.handle() + } + pub fn copy_to_async<'b>(&'a self, data: &'b mut [u8]) -> Command<'b> { + assert_eq!(data.len(), self.len); + let mut rt = ResourceTracker::new(); + rt.add(self.buffer.handle.clone()); + Command { + inner: api::Command::BufferDownload(BufferDownloadCommand { + buffer: self.handle(), + offset: self.offset, + size: data.len(), + data: data.as_mut_ptr() as *mut u8, + }), + marker: PhantomData, + resource_tracker: rt, + callback: None, + } + } + pub fn copy_to_vec(&self) -> Vec { + let mut data = Vec::with_capacity(self.len); + unsafe { + let slice = std::slice::from_raw_parts_mut(data.as_mut_ptr(), self.len); + self.copy_to(slice); + data.set_len(self.len); + } + data + } + pub fn copy_to(&self, data: &mut [u8]) { + unsafe { + submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_async(data)]); + } + } -pub type ByteBuffer = Buffer; -pub type ByteBufferView<'a> = BufferView<'a, u8>; -pub type ByteBufferVar = BufferVar; - + pub fn copy_from_async<'b>(&'a self, data: &'b [u8]) -> Command<'static> { + assert_eq!(data.len(), self.len); + let mut rt = ResourceTracker::new(); + rt.add(self.buffer.handle.clone()); + Command { + inner: api::Command::BufferUpload(BufferUploadCommand { + buffer: self.handle(), + offset: self.offset, + size: data.len(), + data: data.as_ptr() as *const u8, + }), + marker: PhantomData, + resource_tracker: rt, + callback: None, + } + } + pub fn copy_from(&self, data: &[u8]) { + submit_default_stream_and_sync(&self.buffer.device, [self.copy_from_async(data)]); + } + pub fn fill_fn u8>(&self, f: F) { + self.copy_from(&(0..self.len).map(f).collect::>()); + } + pub fn fill(&self, value: u8) { + self.fill_fn(|_| value); + } + pub fn copy_to_buffer_async(&self, dst: ByteBufferView<'a>) -> Command<'static> { + assert_eq!(self.len, dst.len); + let mut rt = ResourceTracker::new(); + rt.add(self.buffer.handle.clone()); + rt.add(dst.buffer.handle.clone()); + Command { + inner: api::Command::BufferCopy(api::BufferCopyCommand { + src: self.handle(), + src_offset: self.offset, + dst: dst.handle(), + dst_offset: dst.offset, + size: self.len, + }), + marker: PhantomData, + resource_tracker: rt, + callback: None, + } + } + pub fn copy_to_buffer(&self, dst: ByteBufferView<'a>) { + submit_default_stream_and_sync(&self.buffer.device, [self.copy_to_buffer_async(dst)]); + } +} +#[derive(Clone)] +pub struct ByteBufferVar { + #[allow(dead_code)] + pub(crate) handle: Option>, + pub(crate) node: NodeRef, +} impl ByteBufferVar { + pub fn new(buffer: &ByteBufferView<'_>) -> Self { + let node = RECORDER.with(|r| { + let mut r = r.borrow_mut(); + assert!(r.lock, "BufferVar must be created from within a kernel"); + let binding = Binding::Buffer(BufferBinding { + handle: buffer.handle().0, + size: buffer.len, + offset: buffer.offset as u64, + }); + r.capture_or_get(binding, &buffer.buffer.handle, || { + Node::new(CArc::new(Instruction::Buffer), Type::void()) + }) + }); + Self { + node, + handle: Some(buffer.buffer.handle.clone()), + } + } pub unsafe fn read_as(&self, index_bytes: impl IntoIndex) -> Expr { let i = index_bytes.to_u64(); Expr::::from_node(__current_scope(|b| { @@ -57,8 +241,6 @@ pub struct Buffer { pub(crate) handle: Arc, pub(crate) len: usize, pub(crate) _marker: PhantomData, - // big hack here - pub(crate) _is_byte_buffer: bool, } pub(crate) struct BufferHandle { pub(crate) device: Device, @@ -176,7 +358,6 @@ impl Buffer { handle: self.handle.clone(), len: self.len, _marker: PhantomData, - _is_byte_buffer: self._is_byte_buffer, } } #[inline] @@ -388,7 +569,23 @@ impl BindlessArray { index: usize, bufferview: &ByteBufferView<'a>, ) { - self.emplace_buffer_view_async(index, bufferview) + self.lock(); + self.modifications + .borrow_mut() + .push(api::BindlessArrayUpdateModification { + slot: index, + buffer: api::BindlessArrayUpdateBuffer { + op: api::BindlessArrayUpdateOperation::Emplace, + handle: bufferview.handle(), + offset: bufferview.offset, + }, + tex2d: api::BindlessArrayUpdateTexture::default(), + tex3d: api::BindlessArrayUpdateTexture::default(), + }); + self.make_pending_slots(); + let mut pending = self.pending_slots.borrow_mut(); + pending[index].buffer = Some(bufferview.buffer.handle.clone()); + self.unlock(); } pub fn emplace_buffer_async(&self, index: usize, buffer: &Buffer) { self.emplace_buffer_view_async(index, &buffer.view(..)) diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index f83aecf..ba5bf16 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -176,7 +176,7 @@ impl Device { } pub fn create_byte_buffer(&self, len: usize) -> ByteBuffer { let buffer = self.inner.create_buffer(&Type::void(), len); - let buffer = Buffer:: { + let buffer = ByteBuffer { device: self.clone(), handle: Arc::new(BufferHandle { device: self.clone(), @@ -184,8 +184,6 @@ impl Device { native_handle: buffer.resource.native_handle, }), len, - _marker: PhantomData {}, - _is_byte_buffer: true, }; buffer } @@ -194,6 +192,13 @@ impl Device { todo!() } pub fn create_buffer(&self, count: usize) -> Buffer { + let name = self.name(); + if name == "dx" { + assert!( + std::mem::align_of::() >= 4, + "T must be aligned to 4 bytes on dx" + ); + } assert!( std::mem::size_of::() > 0, "size of T must be greater than 0" @@ -208,7 +213,6 @@ impl Device { }), _marker: PhantomData {}, len: count, - _is_byte_buffer: false, }; buffer } @@ -944,6 +948,9 @@ impl CallableArgEncoder { pub fn buffer(&mut self, buffer: &BufferVar) { self.args.push(buffer.node); } + pub fn byte_buffer(&mut self, buffer: &ByteBufferVar) { + self.args.push(buffer.node); + } pub fn tex2d(&mut self, tex2d: &Tex2dVar) { self.args.push(tex2d.node); } @@ -1007,6 +1014,20 @@ impl KernelArgEncoder { size: buffer.len * std::mem::size_of::(), })); } + pub fn byte_buffer(&mut self, buffer: &ByteBuffer) { + self.args.push(api::Argument::Buffer(api::BufferArgument { + buffer: buffer.handle.handle, + offset: 0, + size: buffer.len, + })); + } + pub fn byte_buffer_view(&mut self, buffer: &ByteBufferView) { + self.args.push(api::Argument::Buffer(api::BufferArgument { + buffer: buffer.handle(), + offset: buffer.offset, + size: buffer.len, + })); + } pub fn tex2d(&mut self, tex: &Tex2dView) { self.args.push(api::Argument::Texture(api::TextureArgument { texture: tex.handle(), @@ -1039,6 +1060,18 @@ impl KernelArg for Buffer { encoder.buffer(self); } } +impl KernelArg for ByteBuffer { + type Parameter = ByteBufferVar; + fn encode(&self, encoder: &mut KernelArgEncoder) { + encoder.byte_buffer(self); + } +} +impl<'a> KernelArg for ByteBufferView<'a> { + type Parameter = ByteBufferVar; + fn encode(&self, encoder: &mut KernelArgEncoder) { + encoder.byte_buffer_view(self); + } +} impl KernelArg for T { type Parameter = Expr; fn encode(&self, encoder: &mut KernelArgEncoder) { @@ -1334,6 +1367,14 @@ impl<'a, T: Value> AsKernelArg> for BufferView<'a, T> {} impl<'a, T: Value> AsKernelArg> for Buffer {} +impl AsKernelArg for ByteBuffer {} + +impl<'a> AsKernelArg for ByteBufferView<'a> {} + +impl<'a> AsKernelArg> for ByteBufferView<'a> {} + +impl<'a> AsKernelArg> for ByteBuffer {} + impl<'a, T: IoTexel> AsKernelArg> for Tex2dView<'a, T> {} impl<'a, T: IoTexel> AsKernelArg> for Tex3dView<'a, T> {} diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 1e6fe1b..7b0c6c5 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -70,6 +70,14 @@ impl CallableParameter for BufferVar { encoder.buffer(self) } } +impl CallableParameter for ByteBufferVar { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.byte_buffer() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.byte_buffer(self) + } +} impl CallableParameter for Tex2dVar { fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { builder.tex2d() @@ -121,6 +129,12 @@ impl KernelParameter for BufferVar { builder.buffer() } } +impl KernelParameter for ByteBufferVar { + type Arg = ByteBuffer; + fn def_param(builder: &mut KernelBuilder) -> Self { + builder.byte_buffer() + } +} impl KernelParameter for Tex2dVar { type Arg = Tex2d; @@ -209,6 +223,14 @@ impl KernelBuilder { self.args.push(node); FromNode::from_node(node) } + pub fn byte_buffer(&mut self) -> ByteBufferVar { + let node = new_node( + __module_pools(), + Node::new(CArc::new(Instruction::Buffer), Type::void()), + ); + self.args.push(node); + ByteBufferVar { node, handle: None } + } pub fn buffer(&mut self) -> BufferVar { let node = new_node( __module_pools(), diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index e7733c3..8cd3cbf 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -285,6 +285,9 @@ fn vec_cast() { #[test] fn bool_op() { let device = get_device(); + if device.name() == "dx" { + return; + } let x: Buffer = device.create_buffer(1024); let y: Buffer = device.create_buffer(1024); let and: Buffer = device.create_buffer(1024); @@ -329,6 +332,9 @@ fn bool_op() { #[test] fn bvec_op() { let device = get_device(); + if device.name() == "dx" { + return; + } let x: Buffer = device.create_buffer(1024); let y: Buffer = device.create_buffer(1024); let and: Buffer = device.create_buffer(1024); @@ -505,6 +511,9 @@ fn vec_permute() { #[test] fn if_phi() { let device = get_device(); + if device.name() == "dx" { + return; + } let x: Buffer = device.create_buffer(1024); let even: Buffer = device.create_buffer(1024); x.view(..).fill_fn(|i| i as i32); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index c92380d..8283e71 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit c92380d9d69a814383ad8a80b14fcacbffd24ef8 +Subproject commit 8283e71d4f9fcede23b16b6c65ffc44dc480ff24