Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added IoTexel derive for simple structs. #30

Merged
merged 2 commits into from
Oct 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions luisa_compute/examples/custom_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::env::current_exe;
use luisa::lang::external::CpuFn;
use luisa::prelude::*;
use luisa_compute as luisa;

#[derive(Clone, Copy, Value, Debug)]
#[repr(C)]
#[value_new(pub)]
Expand Down
8 changes: 8 additions & 0 deletions luisa_compute_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ use syn::__private::quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;

/// Derives the `IoTexel` trait for a `#[repr(transparent)]` struct and a `Value` impl.
#[proc_macro_derive(IoTexel)]
pub fn derive_iotexel(item: TokenStream) -> TokenStream {
let item: syn::Item = syn::parse(item).unwrap();
let compiler = luisa_compute_derive_impl::Compiler;
compiler.derive_iotexel(&item).into()
}

#[proc_macro_derive(Value, attributes(value_new))]
pub fn derive_value(item: TokenStream) -> TokenStream {
let item: syn::Item = syn::parse(item).unwrap();
Expand Down
83 changes: 68 additions & 15 deletions luisa_compute_derive_impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ impl Compiler {
fn runtime_path(&self) -> TokenStream {
quote!(::luisa_compute::runtime)
}
fn resource_path(&self) -> TokenStream {
quote!(::luisa_compute::resource)
}
fn value_attributes(&self, attribtes: &Vec<Attribute>) -> Option<ValueNewOrdering> {
let mut has_repr_c = false;
let mut ordering = None;
Expand Down Expand Up @@ -85,25 +88,18 @@ impl Compiler {
let fields: Vec<_> = struct_
.fields
.iter()
.map(|f| f)
.filter(|f| {
let attrs = &f.attrs;
for attr in attrs {
let meta = &attr.meta;
match meta {
syn::Meta::List(list) => {
for tok in list.tokens.clone().into_iter() {
match tok {
TokenTree::Ident(ident) => {
if ident == "exclude" || ident == "ignore" {
return false;
}
}
_ => {}
if let syn::Meta::List(list) = meta {
for tok in list.tokens.clone().into_iter() {
if let TokenTree::Ident(ident) = tok {
if ident == "exclude" || ident == "ignore" {
return false;
}
}
}
_ => {}
}
}
true
Expand Down Expand Up @@ -147,7 +143,7 @@ impl Compiler {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let name = &struct_.ident;
let vis = &struct_.vis;
let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect();
let fields: Vec<_> = struct_.fields.iter().collect();
let field_vis: Vec<_> = fields.iter().map(|f| &f.vis).collect();
let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
Expand Down Expand Up @@ -209,6 +205,63 @@ impl Compiler {
}
)
}
pub fn derive_iotexel(&self, item: &Item) -> TokenStream {
match item {
Item::Struct(struct_) => self.derive_iotexel_for_struct(struct_),
_ => todo!(),
}
}
pub fn derive_iotexel_for_struct(&self, struct_: &ItemStruct) -> TokenStream {
let span = struct_.span();
let resource_path = self.resource_path();
let lang_path = self.lang_path();
// Make sure that the struct has repr(transparent).
let mut has_repr_transparent = false;
for Attribute { meta, .. } in &struct_.attrs {
if let syn::Meta::List(list) = meta {
let path = &list.path;
if path.is_ident("repr") {
for tok in list.tokens.clone().into_iter() {
if let TokenTree::Ident(ident) = tok {
if ident == "transparent" {
has_repr_transparent = true;
}
}
}
}
}
}
if !has_repr_transparent {
panic!("Struct must have #[repr(transparent)] attribute");
}
let generics = &struct_.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let syn::Fields::Named(syn::FieldsNamed { named, .. }) = &struct_.fields else {
panic!("IoTexel derive currently only supports named fields.")
};
assert_eq!(named.len(), 1);
let syn::Field { ident, ty, .. } = &named[0];
let ident = ident.as_ref().unwrap();
let struct_name = &struct_.ident;
let struct_comps_name =
syn::Ident::new(&format!("{}Comps", struct_name), struct_name.span());
quote_spanned! {span=>
impl #impl_generics #resource_path::IoTexel for #struct_name<#ty_generics> #where_clause {
type RwType = <#ty as #resource_path::IoTexel>::RwType;
fn pixel_format(storage: #resource_path::PixelStorage) -> #resource_path::PixelFormat {
<#ty as #resource_path::IoTexel>::pixel_format(storage)
}
fn convert_from_read(texel: #lang_path::types::Expr<Self::RwType>) -> #lang_path::types::Expr<Self> {
#struct_name::from_comps_expr(#struct_comps_name {
#ident: <#ty as #resource_path::IoTexel>::convert_from_read(texel),
})
}
fn convert_to_write(value: #lang_path::types::Expr<Self>) -> #lang_path::types::Expr<Self::RwType> {
<#ty as #resource_path::IoTexel>::convert_to_write(value.#ident)
}
}
}
}
pub fn derive_value(&self, item: &Item) -> TokenStream {
match item {
Item::Struct(struct_) => self.derive_value_for_struct(struct_),
Expand Down Expand Up @@ -318,7 +371,7 @@ impl Compiler {
let marker_args = quote!(#(#marker_args),*);
let name = &struct_.ident;
let vis = &struct_.vis;
let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect();
let fields: Vec<_> = struct_.fields.iter().collect();
let field_vis: Vec<_> = fields.iter().map(|f| &f.vis).collect();
let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
Expand Down Expand Up @@ -555,7 +608,7 @@ impl Compiler {
let span = struct_.span();
let lang_path = self.lang_path();
let name = &struct_.ident;
let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect();
let fields: Vec<_> = struct_.fields.iter().collect();
let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
quote_spanned!(span=>
Expand Down
Loading