From c3d317e18c99371c4248c12e193ee51315c268f4 Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Tue, 19 Sep 2023 08:38:36 -0400 Subject: [PATCH] update submod --- luisa_compute/Cargo.toml | 2 ++ luisa_compute/examples/autodiff.rs | 8 ++++---- luisa_compute/src/lang.rs | 1 + luisa_compute/src/lang/soa.rs | 9 +++++++++ luisa_compute/src/runtime.rs | 5 +++++ luisa_compute_derive_impl/src/lib.rs | 1 + luisa_compute_sys/LuisaCompute | 2 +- luisa_compute_track/src/lib.rs | 4 +++- 8 files changed, 26 insertions(+), 6 deletions(-) create mode 100644 luisa_compute/src/lang/soa.rs diff --git a/luisa_compute/Cargo.toml b/luisa_compute/Cargo.toml index dbf69d2..39baf2b 100644 --- a/luisa_compute/Cargo.toml +++ b/luisa_compute/Cargo.toml @@ -43,3 +43,5 @@ dx = ["luisa_compute_sys/dx"] strict = ["luisa_compute_sys/strict"] remote = ["luisa_compute_sys/remote"] cpu = ["luisa_compute_sys/cpu"] + + diff --git a/luisa_compute/examples/autodiff.rs b/luisa_compute/examples/autodiff.rs index 720e858..c0ba51e 100644 --- a/luisa_compute/examples/autodiff.rs +++ b/luisa_compute/examples/autodiff.rs @@ -33,13 +33,13 @@ fn main() { let buf_y = y.var(); let x = buf_x.read(tid); let y = buf_y.read(tid); - let f = track!(|x: Expr, y: Expr| { - if x > y { + let f = |x: Expr, y: Expr| { + track!(if x > y { x * y } else { y * x + (x / 32.0 * PI).sin() - } - }); + }) + }; autodiff(|| { requires_grad(x); requires_grad(y); diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index dae2d7e..bfc46c9 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -34,6 +34,7 @@ pub mod ops; pub mod poly; pub mod swizzle; pub mod types; +pub mod soa; #[allow(dead_code)] pub(crate) static KERNEL_ID: AtomicUsize = AtomicUsize::new(0); diff --git a/luisa_compute/src/lang/soa.rs b/luisa_compute/src/lang/soa.rs new file mode 100644 index 0000000..d7d4e62 --- /dev/null +++ b/luisa_compute/src/lang/soa.rs @@ -0,0 +1,9 @@ +use luisa_compute_ir::{ir::Type, CArc}; + +use crate::prelude::*; +/** A buffer with SOA layout. + */ +pub struct SoaBuffer { + inner: ByteBuffer, + _marker: std::marker::PhantomData, +} diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 90dc02b..cf2c604 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -16,6 +16,7 @@ use raw_window_handle::HasRawWindowHandle; use winit::window::Window; use crate::internal_prelude::*; +use crate::lang::soa::SoaBuffer; use ir::{ CallableModule, CallableModuleRef, Capture, CpuCustomOp, KernelModule, Module, ModuleFlags, ModuleKind, ModulePools, @@ -180,6 +181,10 @@ impl Device { }; buffer } + pub fn create_soa_buffer(&self, count: usize) -> SoaBuffer { + // let inner = self.create_byte_buffer(len) + todo!() + } pub fn create_buffer(&self, count: usize) -> Buffer { assert!( std::mem::size_of::() > 0, diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index eca0e83..ca83977 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -109,6 +109,7 @@ impl Compiler { } ) } + pub fn derive_value(&self, struct_: &ItemStruct) -> TokenStream { self.check_repr_c(&struct_.attrs); let span = struct_.span(); diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index b8b284f..f4e8dab 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit b8b284f5c8b8ee4470298f362a1cd7f9a7c79698 +Subproject commit f4e8dabfd1c5d8ce0923b1f81f2dc8ef3ea8f68e diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs index 3129809..09ab1ad 100644 --- a/luisa_compute_track/src/lib.rs +++ b/luisa_compute_track/src/lib.rs @@ -171,7 +171,9 @@ impl VisitMut for TraceVisitor { #[proc_macro] pub fn track(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - track_impl(parse_macro_input!(input as Expr)).into() + let input: TokenStream = input.into(); + let block_input: proc_macro::TokenStream = quote!({ #input }).into(); + track_impl(parse_macro_input!(block_input as Expr)).into() } fn track_impl(mut ast: Expr) -> TokenStream {