From 1dbb2e4d38bed79836900d3752364b5d39e14feb Mon Sep 17 00:00:00 2001 From: ReversedCausality Date: Fri, 27 Oct 2023 18:45:19 +0100 Subject: [PATCH] Added `IoTexel` derive for simple structs. --- luisa_compute/examples/custom_op.rs | 7 +++ luisa_compute_derive/src/lib.rs | 8 +++ luisa_compute_derive_impl/src/lib.rs | 83 +++++++++++++++++++++++----- 3 files changed, 83 insertions(+), 15 deletions(-) diff --git a/luisa_compute/examples/custom_op.rs b/luisa_compute/examples/custom_op.rs index 6383550..08be4bc 100644 --- a/luisa_compute/examples/custom_op.rs +++ b/luisa_compute/examples/custom_op.rs @@ -3,6 +3,13 @@ use std::env::current_exe; use luisa::lang::external::CpuFn; use luisa::prelude::*; use luisa_compute as luisa; + +#[derive(Clone, Copy, Value, IoTexel, Debug)] +#[repr(transparent)] +struct Foo { + x: u32, +} + #[derive(Clone, Copy, Value, Debug)] #[repr(C)] #[value_new(pub)] diff --git a/luisa_compute_derive/src/lib.rs b/luisa_compute_derive/src/lib.rs index 41e4ee2..a18c41d 100644 --- a/luisa_compute_derive/src/lib.rs +++ b/luisa_compute_derive/src/lib.rs @@ -3,6 +3,14 @@ use syn::__private::quote::quote; use syn::parse::{Parse, ParseStream}; use syn::spanned::Spanned; +/// Derives the `IoTexel` trait for a `#[repr(transparent)]` struct and a `Value` impl. +#[proc_macro_derive(IoTexel)] +pub fn derive_iotexel(item: TokenStream) -> TokenStream { + let item: syn::Item = syn::parse(item).unwrap(); + let compiler = luisa_compute_derive_impl::Compiler; + compiler.derive_iotexel(&item).into() +} + #[proc_macro_derive(Value, attributes(value_new))] pub fn derive_value(item: TokenStream) -> TokenStream { let item: syn::Item = syn::parse(item).unwrap(); diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index ed9add1..60408fa 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -37,6 +37,9 @@ impl Compiler { fn runtime_path(&self) -> TokenStream { quote!(::luisa_compute::runtime) } + fn resource_path(&self) -> TokenStream { + quote!(::luisa_compute::resource) + } fn value_attributes(&self, attribtes: &Vec) -> Option { let mut has_repr_c = false; let mut ordering = None; @@ -85,25 +88,18 @@ impl Compiler { let fields: Vec<_> = struct_ .fields .iter() - .map(|f| f) .filter(|f| { let attrs = &f.attrs; for attr in attrs { let meta = &attr.meta; - match meta { - syn::Meta::List(list) => { - for tok in list.tokens.clone().into_iter() { - match tok { - TokenTree::Ident(ident) => { - if ident == "exclude" || ident == "ignore" { - return false; - } - } - _ => {} + if let syn::Meta::List(list) = meta { + for tok in list.tokens.clone().into_iter() { + if let TokenTree::Ident(ident) = tok { + if ident == "exclude" || ident == "ignore" { + return false; } } } - _ => {} } } true @@ -147,7 +143,7 @@ impl Compiler { let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let name = &struct_.ident; let vis = &struct_.vis; - let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect(); + let fields: Vec<_> = struct_.fields.iter().collect(); let field_vis: Vec<_> = fields.iter().map(|f| &f.vis).collect(); let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect(); let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); @@ -209,6 +205,63 @@ impl Compiler { } ) } + pub fn derive_iotexel(&self, item: &Item) -> TokenStream { + match item { + Item::Struct(struct_) => self.derive_iotexel_for_struct(struct_), + _ => todo!(), + } + } + pub fn derive_iotexel_for_struct(&self, struct_: &ItemStruct) -> TokenStream { + let span = struct_.span(); + let resource_path = self.resource_path(); + let lang_path = self.lang_path(); + // Make sure that the struct has repr(transparent). + let mut has_repr_transparent = false; + for Attribute { meta, .. } in &struct_.attrs { + if let syn::Meta::List(list) = meta { + let path = &list.path; + if path.is_ident("repr") { + for tok in list.tokens.clone().into_iter() { + if let TokenTree::Ident(ident) = tok { + if ident == "transparent" { + has_repr_transparent = true; + } + } + } + } + } + } + if !has_repr_transparent { + panic!("Struct must have #[repr(transparent)] attribute"); + } + let generics = &struct_.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + let syn::Fields::Named(syn::FieldsNamed { named, .. }) = &struct_.fields else { + panic!("IoTexel derive currently only supports named fields.") + }; + assert_eq!(named.len(), 1); + let syn::Field { ident, ty, .. } = &named[0]; + let ident = ident.as_ref().unwrap(); + let struct_name = &struct_.ident; + let struct_comps_name = + syn::Ident::new(&format!("{}Comps", struct_name), struct_name.span()); + quote_spanned! {span=> + impl #impl_generics #resource_path::IoTexel for #struct_name<#ty_generics> #where_clause { + type RwType = <#ty as #resource_path::IoTexel>::RwType; + fn pixel_format(storage: #resource_path::PixelStorage) -> #resource_path::PixelFormat { + <#ty as #resource_path::IoTexel>::pixel_format(storage) + } + fn convert_from_read(texel: #lang_path::types::Expr) -> #lang_path::types::Expr { + #struct_name::from_comps_expr(#struct_comps_name { + #ident: <#ty as #resource_path::IoTexel>::convert_from_read(texel), + }) + } + fn convert_to_write(value: #lang_path::types::Expr) -> #lang_path::types::Expr { + <#ty as #resource_path::IoTexel>::convert_to_write(value.#ident) + } + } + } + } pub fn derive_value(&self, item: &Item) -> TokenStream { match item { Item::Struct(struct_) => self.derive_value_for_struct(struct_), @@ -318,7 +371,7 @@ impl Compiler { let marker_args = quote!(#(#marker_args),*); let name = &struct_.ident; let vis = &struct_.vis; - let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect(); + let fields: Vec<_> = struct_.fields.iter().collect(); let field_vis: Vec<_> = fields.iter().map(|f| &f.vis).collect(); let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect(); let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); @@ -555,7 +608,7 @@ impl Compiler { let span = struct_.span(); let lang_path = self.lang_path(); let name = &struct_.ident; - let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect(); + let fields: Vec<_> = struct_.fields.iter().collect(); let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect(); let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect(); quote_spanned!(span=>