From e5981d5a776b96eb6bc3e2120d1b95e77a2c7b24 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Thu, 2 Nov 2023 01:58:33 -0400 Subject: [PATCH] added external memory --- luisa_compute/src/runtime.rs | 48 ++++++++++++++++------------------ luisa_compute_sys/LuisaCompute | 2 +- 2 files changed, 23 insertions(+), 27 deletions(-) diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 81b59e9..4756c6d 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -1,4 +1,4 @@ -use std::any::Any; +use std::any::{Any, TypeId}; use std::cell::{Cell, RefCell}; use std::collections::{HashMap, HashSet}; use std::env; @@ -9,6 +9,7 @@ use std::path::PathBuf; use std::rc::Rc; use std::sync::{Arc, Weak}; +use libc::c_void; use parking_lot::lock_api::RawMutex as RawMutexTrait; use parking_lot::{Condvar, Mutex, RawMutex, RwLock}; @@ -177,31 +178,9 @@ impl Device { } /// Creates an **unintialized** buffer of `len` bytes. + /// Alias of [`Device::create_buffer::`]. pub fn create_byte_buffer(&self, len: usize) -> Buffer { - let name = self.name(); - if name == "dx" { - assert!( - len < u32::MAX as usize, - "numer of bytes must be less than u32::MAX on dx" - ); - } - let buffer = self.inner.create_buffer(&Type::void(), len); - let handle = Arc::new(BufferHandle { - device: self.clone(), - handle: api::Buffer(buffer.resource.handle), - native_handle: buffer.resource.native_handle, - }); - let buffer = Buffer { - handle: handle.clone(), - full_view: BufferView { - device: self.clone(), - handle: Arc::downgrade(&handle), - offset: 0, - len, - _marker: PhantomData, - }, - }; - buffer + self.create_buffer(len) } /// Creates an **unintialized** buffer of `count` elements of type `T` in SOA layout. @@ -233,6 +212,9 @@ impl Device { /// Creates an **unintialized** buffer of `count` elements of type `T`. pub fn create_buffer(&self, count: usize) -> Buffer { + self._create_buffer(std::ptr::null_mut(), count) + } + fn _create_buffer(&self, ext_mem: *mut c_void, count: usize) -> Buffer { let name = self.name(); if name == "dx" { assert!( @@ -248,7 +230,12 @@ impl Device { std::mem::size_of::() > 0, "size of T must be greater than 0" ); - let buffer = self.inner.create_buffer(&T::type_(), count); + let ty = if TypeId::of::() == TypeId::of::() { + Type::void() + } else { + ::type_() + }; + let buffer = self.inner.create_buffer(&ty, count, ext_mem); let handle = Arc::new(BufferHandle { device: self.clone(), handle: api::Buffer(buffer.resource.handle), @@ -266,6 +253,15 @@ impl Device { }; buffer } + + /// Imports an external buffer of `count` elements of type `T`. + pub unsafe fn create_buffer_from_external_memory( + &self, + data: *mut T, + count: usize, + ) -> Buffer { + self._create_buffer(data as *mut c_void, count) + } pub fn create_buffer_from_slice(&self, data: &[T]) -> Buffer { let buffer = self.create_buffer(data.len()); buffer.view(..).copy_from(data); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index ad49600..0554132 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit ad4960095abf6db31e2352a1ff5c06b5f5a2b472 +Subproject commit 0554132b41bceeb86d0f0d5aa27af97462e30998