Skip to content

Commit

Permalink
revert bytebuffer due to dx limitation
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 24, 2023
1 parent c1ca0f8 commit 15b8c51
Show file tree
Hide file tree
Showing 6 changed files with 295 additions and 23 deletions.
23 changes: 13 additions & 10 deletions luisa_compute/src/lang/ops/impls.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ where
log10 => Log10
}
fn is_finite(&self) -> Self::Bool {
!self.is_infinite().bitand(!self.is_nan())
(!self.is_infinite()).bitand(!self.is_nan())
}
fn is_infinite(&self) -> Self::Bool {
Func::IsInf.call(self.clone())
Expand All @@ -233,11 +233,11 @@ where
(self.sin(), self.cos())
}
}
impl<X: Linear> NormExpr for Expr<X>
impl<const N: usize, X: Floating> NormExpr for Expr<Vector<X, N>>
where
X::Scalar: Floating,
X: vector::VectorAlign<N>,
{
type Output = Expr<X::Scalar>;
type Output = Expr<X>;
impl_simple_fns! {
Self::Output,
norm => Length,
Expand Down Expand Up @@ -268,8 +268,11 @@ impl OuterProductExpr for Expr<Float4> {
Func::OuterProduct.call2(self.clone(), other.as_expr())
}
}
impl<X: Linear> ReduceExpr for Expr<X> {
type Output = Expr<X::Scalar>;
impl<const N: usize, X: VectorElement> ReduceExpr for Expr<Vector<X, N>>
where
X: vector::VectorAlign<N>,
{
type Output = Expr<X>;
impl_simple_fns! {
Self::Output,
reduce_max=>ReduceMax,
Expand All @@ -278,12 +281,12 @@ impl<X: Linear> ReduceExpr for Expr<X> {
reduce_sum=>ReduceSum
}
}
impl<X: Linear> DotExpr for Expr<X>
impl<const N: usize, X: Floating> DotExpr for Expr<Vector<X, N>>
where
X::Scalar: Floating,
X: vector::VectorAlign<N>,
{
type Value = X;
type Output = Expr<X::Scalar>;
type Value = Vector<X, N>;
type Output = Expr<X>;
fn dot(&self, other: impl AsExpr<Value = Self::Value>) -> Self::Output {
Func::Dot.call2(self.clone(), other.as_expr())
}
Expand Down
213 changes: 205 additions & 8 deletions luisa_compute/src/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,196 @@ use crate::runtime::*;

use api::{BufferDownloadCommand, BufferUploadCommand, INVALID_RESOURCE_HANDLE};
use libc::c_void;
/// Untyped device buffer whose length and view offsets are expressed in `u8` elements.
pub struct ByteBuffer {
/// Device passed to `submit_default_stream_and_sync` by the blocking copy helpers of the views.
pub(crate) device: Device,
/// Shared handle of the underlying buffer; cloned into the `ResourceTracker` of every command
/// built from a view of this buffer.
pub(crate) handle: Arc<BufferHandle>,
/// Number of bytes in the buffer; `view` asserts that every range ends at or before this value.
pub(crate) len: usize,
}
impl ByteBuffer {
    /// Returns the length of the buffer in bytes.
    pub fn len(&self) -> usize {
        self.len
    }
    /// Returns the API-level handle stored in the shared `BufferHandle`.
    #[inline]
    pub fn handle(&self) -> api::Buffer {
        self.handle.handle
    }
    /// Returns the native handle pointer stored in the shared `BufferHandle`.
    #[inline]
    pub fn native_handle(&self) -> *mut c_void {
        self.handle.native_handle
    }
    /// Uploads `data` into the whole buffer via [`ByteBufferView::copy_from`].
    ///
    /// # Panics
    /// Panics if `data.len() != self.len()`.
    #[inline]
    pub fn copy_from(&self, data: &[u8]) {
        self.view(..).copy_from(data);
    }
    /// Builds an upload command that copies `data` into the whole buffer.
    ///
    /// The command stores a raw pointer into `data`, so the returned
    /// `Command` is tied to the lifetime of `data` (mirroring
    /// [`ByteBuffer::copy_to_async`]) instead of to `&self`; this stops the
    /// slice from being dropped while the command is still pending.
    ///
    /// # Panics
    /// Panics if `data.len() != self.len()`.
    #[inline]
    pub fn copy_from_async<'a>(&self, data: &'a [u8]) -> Command<'a> {
        self.view(..).copy_from_async(data)
    }
    /// Downloads the whole buffer into `data` via [`ByteBufferView::copy_to`].
    ///
    /// # Panics
    /// Panics if `data.len() != self.len()`.
    #[inline]
    pub fn copy_to(&self, data: &mut [u8]) {
        self.view(..).copy_to(data);
    }
    /// Builds a download command that copies the whole buffer into `data`.
    ///
    /// # Panics
    /// Panics if `data.len() != self.len()`.
    #[inline]
    pub fn copy_to_async<'a>(&self, data: &'a mut [u8]) -> Command<'a> {
        self.view(..).copy_to_async(data)
    }
    /// Downloads the whole buffer into a newly allocated `Vec<u8>`.
    #[inline]
    pub fn copy_to_vec(&self) -> Vec<u8> {
        self.view(..).copy_to_vec()
    }
    /// Copies the whole buffer into `dst` via [`ByteBufferView::copy_to_buffer`].
    ///
    /// # Panics
    /// Panics if the two buffers differ in length.
    #[inline]
    pub fn copy_to_buffer(&self, dst: &ByteBuffer) {
        self.view(..).copy_to_buffer(dst.view(..));
    }
    /// Builds a buffer-to-buffer copy command from this buffer into `dst`.
    ///
    /// # Panics
    /// Panics if the two buffers differ in length.
    #[inline]
    pub fn copy_to_buffer_async<'a>(&'a self, dst: &'a ByteBuffer) -> Command<'a> {
        self.view(..).copy_to_buffer_async(dst.view(..))
    }
    /// Fills the buffer with `f(i)` for every byte index `i` in `0..self.len()`.
    #[inline]
    pub fn fill_fn<F: FnMut(usize) -> u8>(&self, f: F) {
        self.view(..).fill_fn(f);
    }
    /// Fills every byte of the buffer with `value`.
    #[inline]
    pub fn fill(&self, value: u8) {
        self.view(..).fill(value);
    }
    /// Creates a view over the byte range `range` of this buffer.
    ///
    /// # Panics
    /// Panics if the range start exceeds its end, if the end exceeds
    /// `self.len()`, or if converting an excluded start / included end to a
    /// half-open bound would overflow `usize`.
    pub fn view<S: RangeBounds<usize>>(&self, range: S) -> ByteBufferView<'_> {
        // Normalise both bounds to a half-open `[lower, upper)` interval.
        // `checked_add` keeps release builds from wrapping to a small value
        // that would slip past the assertions below.
        let lower = match range.start_bound() {
            std::ops::Bound::Included(&x) => x,
            std::ops::Bound::Excluded(&x) => {
                x.checked_add(1).expect("range start overflows usize")
            }
            std::ops::Bound::Unbounded => 0,
        };
        let upper = match range.end_bound() {
            std::ops::Bound::Included(&x) => {
                x.checked_add(1).expect("range end overflows usize")
            }
            std::ops::Bound::Excluded(&x) => x,
            std::ops::Bound::Unbounded => self.len,
        };
        assert!(lower <= upper);
        assert!(upper <= self.len);
        ByteBufferView {
            buffer: self,
            offset: lower,
            len: upper - lower,
        }
    }
    /// Creates a [`ByteBufferVar`] bound to the whole buffer.
    ///
    /// # Panics
    /// Panics under the same condition as [`ByteBufferVar::new`].
    pub fn var(&self) -> ByteBufferVar {
        ByteBufferVar::new(&self.view(..))
    }
}
/// Borrowed sub-range of a `ByteBuffer`, produced by `ByteBuffer::view`.
pub struct ByteBufferView<'a> {
/// Buffer this view borrows from.
pub(crate) buffer: &'a ByteBuffer,
/// Start of the view inside `buffer`; forwarded as `offset` / `src_offset` / `dst_offset`
/// in the copy commands.
pub(crate) offset: usize,
/// Number of bytes covered by the view.
pub(crate) len: usize,
}
impl<'a> ByteBufferView<'a> {
pub fn handle(&self) -> api::Buffer {
self.buffer.handle()
}
/// Builds a download command that copies this view into `data`.
///
/// The command records a raw pointer into `data` and is therefore tied to
/// the lifetime `'b` of that slice.
///
/// # Panics
/// Panics if `data.len()` differs from the view length.
pub fn copy_to_async<'b>(&'a self, data: &'b mut [u8]) -> Command<'b> {
    assert_eq!(data.len(), self.len);
    // Track the buffer handle so it is retained alongside the command.
    let mut tracker = ResourceTracker::new();
    tracker.add(self.buffer.handle.clone());
    let download = BufferDownloadCommand {
        buffer: self.handle(),
        offset: self.offset,
        size: data.len(),
        data: data.as_mut_ptr(),
    };
    Command {
        inner: api::Command::BufferDownload(download),
        marker: PhantomData,
        resource_tracker: tracker,
        callback: None,
    }
}
/// Downloads the bytes covered by this view into a newly allocated `Vec<u8>`
/// whose length equals the view length.
pub fn copy_to_vec(&self) -> Vec<u8> {
    // Zero-initialise the destination rather than building a `&mut [u8]` over
    // uninitialised capacity: `slice::from_raw_parts_mut` requires the memory
    // to hold initialised values, so the previous approach was unsound.
    let mut data = vec![0u8; self.len];
    self.copy_to(&mut data);
    data
}
/// Downloads this view into `data` by submitting the command from
/// `copy_to_async` through `submit_default_stream_and_sync`.
///
/// # Panics
/// Panics if `data.len()` differs from the view length.
pub fn copy_to(&self, data: &mut [u8]) {
    let download = self.copy_to_async(data);
    // NOTE(review): `copy_from` calls the same helper without `unsafe`;
    // confirm whether this block is still required.
    unsafe {
        submit_default_stream_and_sync(&self.buffer.device, [download]);
    }
}

pub type ByteBuffer = Buffer<u8>;
pub type ByteBufferView<'a> = BufferView<'a, u8>;
pub type ByteBufferVar = BufferVar<u8>;

/// Builds an upload command that copies `data` into this view.
///
/// NOTE(review): the command stores `data.as_ptr()` but is returned as
/// `Command<'static>`, so the borrow checker does not keep `data` alive
/// until the command runs — confirm whether `Command<'b>` was intended
/// (as in `copy_to_async`).
///
/// # Panics
/// Panics if `data.len()` differs from the view length.
pub fn copy_from_async<'b>(&'a self, data: &'b [u8]) -> Command<'static> {
    assert_eq!(data.len(), self.len);
    // Track the buffer handle so it is retained alongside the command.
    let mut tracker = ResourceTracker::new();
    tracker.add(self.buffer.handle.clone());
    let upload = BufferUploadCommand {
        buffer: self.handle(),
        offset: self.offset,
        size: data.len(),
        data: data.as_ptr(),
    };
    Command {
        inner: api::Command::BufferUpload(upload),
        marker: PhantomData,
        resource_tracker: tracker,
        callback: None,
    }
}
/// Uploads `data` into this view by submitting the command from
/// `copy_from_async` through `submit_default_stream_and_sync`.
///
/// # Panics
/// Panics if `data.len()` differs from the view length.
pub fn copy_from(&self, data: &[u8]) {
    let upload = self.copy_from_async(data);
    submit_default_stream_and_sync(&self.buffer.device, [upload]);
}
/// Fills the view with `f(i)` for each index `i` in `0..len`, where `i` is
/// relative to the start of the view.
///
/// The bytes are generated on the host into a temporary `Vec` and then
/// uploaded with `copy_from`.
pub fn fill_fn<F: FnMut(usize) -> u8>(&self, f: F) {
    let staged: Vec<u8> = (0..self.len).map(f).collect();
    self.copy_from(&staged);
}
/// Sets every byte of the view to `value`.
pub fn fill(&self, value: u8) {
    // Same bytes as `fill_fn(|_| value)`: a host-side vector of `len` copies.
    let staged = vec![value; self.len];
    self.copy_from(&staged);
}
/// Builds a buffer-to-buffer copy command from this view into `dst`.
///
/// Both buffer handles are added to the command's resource tracker.
///
/// # Panics
/// Panics if the two views differ in length.
pub fn copy_to_buffer_async(&self, dst: ByteBufferView<'a>) -> Command<'static> {
    assert_eq!(self.len, dst.len);
    let mut tracker = ResourceTracker::new();
    tracker.add(self.buffer.handle.clone());
    tracker.add(dst.buffer.handle.clone());
    let copy = api::BufferCopyCommand {
        src: self.handle(),
        src_offset: self.offset,
        dst: dst.handle(),
        dst_offset: dst.offset,
        size: self.len,
    };
    Command {
        inner: api::Command::BufferCopy(copy),
        marker: PhantomData,
        resource_tracker: tracker,
        callback: None,
    }
}
/// Copies this view into `dst` by submitting the command from
/// `copy_to_buffer_async` through `submit_default_stream_and_sync`.
///
/// # Panics
/// Panics if the two views differ in length.
pub fn copy_to_buffer(&self, dst: ByteBufferView<'a>) {
    let copy = self.copy_to_buffer_async(dst);
    submit_default_stream_and_sync(&self.buffer.device, [copy]);
}
}
/// Kernel-side variable for a byte buffer view, created by `ByteBufferVar::new`.
#[derive(Clone)]
pub struct ByteBufferVar {
/// Clone of the source buffer's shared handle; `new` always stores `Some(..)`.
/// NOTE(review): presumably held only to keep the buffer alive (hence `dead_code`) — confirm.
#[allow(dead_code)]
pub(crate) handle: Option<Arc<BufferHandle>>,
/// IR node returned by the recorder's `capture_or_get` for this buffer binding.
pub(crate) node: NodeRef,
}
impl ByteBufferVar {
/// Creates a kernel variable for `buffer`.
///
/// Builds a `Binding::Buffer` from the view's handle, length and offset, and
/// asks the thread-local recorder to capture it (or return the existing
/// node) via `capture_or_get`.
///
/// # Panics
/// Panics if the recorder's `lock` flag is not set, i.e. when called outside
/// of kernel recording.
pub fn new(buffer: &ByteBufferView<'_>) -> Self {
    let owner = &buffer.buffer.handle;
    let node = RECORDER.with(|cell| {
        let mut recorder = cell.borrow_mut();
        assert!(recorder.lock, "BufferVar must be created from within a kernel");
        let binding = Binding::Buffer(BufferBinding {
            handle: buffer.handle().0,
            size: buffer.len,
            offset: buffer.offset as u64,
        });
        recorder.capture_or_get(binding, owner, || {
            Node::new(CArc::new(Instruction::Buffer), Type::void())
        })
    });
    Self {
        node,
        handle: Some(owner.clone()),
    }
}
pub unsafe fn read_as<T: Value>(&self, index_bytes: impl IntoIndex) -> Expr<T> {
let i = index_bytes.to_u64();
Expr::<T>::from_node(__current_scope(|b| {
Expand Down Expand Up @@ -57,8 +241,6 @@ pub struct Buffer<T: Value> {
pub(crate) handle: Arc<BufferHandle>,
pub(crate) len: usize,
pub(crate) _marker: PhantomData<T>,
// big hack here
pub(crate) _is_byte_buffer: bool,
}
pub(crate) struct BufferHandle {
pub(crate) device: Device,
Expand Down Expand Up @@ -176,7 +358,6 @@ impl<T: Value> Buffer<T> {
handle: self.handle.clone(),
len: self.len,
_marker: PhantomData,
_is_byte_buffer: self._is_byte_buffer,
}
}
#[inline]
Expand Down Expand Up @@ -388,7 +569,23 @@ impl BindlessArray {
index: usize,
bufferview: &ByteBufferView<'a>,
) {
self.emplace_buffer_view_async(index, bufferview)
self.lock();
self.modifications
.borrow_mut()
.push(api::BindlessArrayUpdateModification {
slot: index,
buffer: api::BindlessArrayUpdateBuffer {
op: api::BindlessArrayUpdateOperation::Emplace,
handle: bufferview.handle(),
offset: bufferview.offset,
},
tex2d: api::BindlessArrayUpdateTexture::default(),
tex3d: api::BindlessArrayUpdateTexture::default(),
});
self.make_pending_slots();
let mut pending = self.pending_slots.borrow_mut();
pending[index].buffer = Some(bufferview.buffer.handle.clone());
self.unlock();
}
pub fn emplace_buffer_async<T: Value>(&self, index: usize, buffer: &Buffer<T>) {
self.emplace_buffer_view_async(index, &buffer.view(..))
Expand Down
Loading

0 comments on commit 15b8c51

Please sign in to comment.