diff --git a/luisa_compute/examples/polymorphism.rs b/luisa_compute/examples/polymorphism.rs index 4ed1d5d..788b695 100644 --- a/luisa_compute/examples/polymorphism.rs +++ b/luisa_compute/examples/polymorphism.rs @@ -34,7 +34,17 @@ impl Area for SquareExpr { impl_polymorphic!(Area, Square); fn main() { let ctx = Context::new(current_exe().unwrap()); - let device = ctx.create_device("cpu"); + let args: Vec = std::env::args().collect(); + assert!( + args.len() <= 2, + "Usage: {} . : cpu, cuda, dx, metal, remote", + args[0] + ); + let device = ctx.create_device(if args.len() == 2 { + args[1].as_str() + } else { + "cpu" + }); let circles = device.create_buffer::(2); circles .view(..) diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 08a579f..a228d50 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -1397,6 +1397,27 @@ impl IndexRead for BindlessBufferVar { })) } } +impl IndexWrite for BindlessBufferVar { + fn write>(&self, i: I, value: V) { + let i = i.to_u64(); + if need_runtime_check() { + lc_assert!(i.lt(self.len_expr())); + } + let value = value.as_expr(); + __current_scope(|b| { + b.call( + Func::BindlessBufferWrite, + &[ + self.array, + self.buffer_index.node(), + ToNode::node(&i), + value.node(), + ], + Type::void(), + ) + }); + } +} impl BindlessBufferVar { pub fn len_expr(&self) -> Expr { let stride = (T::type_().size() as u64).expr(); @@ -1789,6 +1810,11 @@ impl BufferVar { __current_scope(|b| b.call(Func::BufferSize, &[self.node], u64::type_())).into(), ) } + pub fn len_expr_u32(&self) -> Expr { + FromNode::from_node( + __current_scope(|b| b.call(Func::BufferSize, &[self.node], u32::type_())).into(), + ) + } } macro_rules! impl_atomic { diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 182a085..2ad0783 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -175,6 +175,13 @@ impl Device { swapchain } pub fn create_byte_buffer(&self, len: usize) -> Buffer { + let name = self.name(); + if name == "dx" { + assert!( + len < u32::MAX as usize, + "numer of bytes must be less than u32::MAX on dx" + ); + } let buffer = self.inner.create_buffer(&Type::void(), len); let buffer = Buffer { device: self.clone(), @@ -199,6 +206,10 @@ impl Device { std::mem::align_of::() >= 4, "T must be aligned to 4 bytes on dx" ); + assert!( + count < u32::MAX as usize, + "count must be less than u32::MAX on dx" + ); } assert!( std::mem::size_of::() > 0, diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index dd32337..ceab4f2 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -874,7 +874,7 @@ fn uniform() { let expected = (x.len() as f32 - 1.0) * x.len() as f32 * 0.5 * 6.0; assert!((actual - expected).abs() < 1e-4); } -#[derive(Clone, Copy, Debug, Value)] +#[derive(Clone, Copy, Debug, Value, Default)] #[repr(C)] struct Big { a: [f32; 32], @@ -1279,6 +1279,27 @@ fn dispatch_async() { drop(kernel); } +#[test] +fn buffer_size() { + let device = get_device(); + let x = device.create_buffer::(1024); + x.fill(Big::default()); + let out = device.create_buffer::(1); + out.fill(2); + device + .create_kernel::(&track!(|| { + lc_assert!((x.len() as u64).eq(x.var().len_expr())); + lc_assert!((x.len() as u32).eq(x.var().len_expr().cast_u32())); + let tid = dispatch_id().x; + if tid == 0 { + out.write(0, x.var().len_expr_u32()); + } + })) + .dispatch([1024, 1, 1]); + let out = out.view(..).copy_to_vec(); + assert_eq!(out[0], 1024); +} + #[test] #[tracked] fn test_tracked() { diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index c291315..1b2b323 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit c2913151806bbe6175ef06692283bc85804e469f +Subproject commit 1b2b3238d2596ba0b8337f72e82a7d7640253d0b