Skip to content

Commit

Permalink
Merge pull request #30 from iMplode-nZ/main
Browse files Browse the repository at this point in the history
Added `IoTexel` derive for simple structs.
  • Loading branch information
shiinamiyuki authored Oct 27, 2023
2 parents 4dea07a + 21bdd1c commit e2f661c
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 15 deletions.
1 change: 1 addition & 0 deletions luisa_compute/examples/custom_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use std::env::current_exe;
use luisa::lang::external::CpuFn;
use luisa::prelude::*;
use luisa_compute as luisa;

#[derive(Clone, Copy, Value, Debug)]
#[repr(C)]
#[value_new(pub)]
Expand Down
8 changes: 8 additions & 0 deletions luisa_compute_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,14 @@ use syn::__private::quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::spanned::Spanned;

/// Derives the `IoTexel` trait for a `#[repr(transparent)]` struct and a `Value` impl.
#[proc_macro_derive(IoTexel)]
pub fn derive_iotexel(item: TokenStream) -> TokenStream {
let item: syn::Item = syn::parse(item).unwrap();
let compiler = luisa_compute_derive_impl::Compiler;
compiler.derive_iotexel(&item).into()
}

#[proc_macro_derive(Value, attributes(value_new))]
pub fn derive_value(item: TokenStream) -> TokenStream {
let item: syn::Item = syn::parse(item).unwrap();
Expand Down
83 changes: 68 additions & 15 deletions luisa_compute_derive_impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ impl Compiler {
fn runtime_path(&self) -> TokenStream {
quote!(::luisa_compute::runtime)
}
fn resource_path(&self) -> TokenStream {
quote!(::luisa_compute::resource)
}
fn value_attributes(&self, attribtes: &Vec<Attribute>) -> Option<ValueNewOrdering> {
let mut has_repr_c = false;
let mut ordering = None;
Expand Down Expand Up @@ -85,25 +88,18 @@ impl Compiler {
let fields: Vec<_> = struct_
.fields
.iter()
.map(|f| f)
.filter(|f| {
let attrs = &f.attrs;
for attr in attrs {
let meta = &attr.meta;
match meta {
syn::Meta::List(list) => {
for tok in list.tokens.clone().into_iter() {
match tok {
TokenTree::Ident(ident) => {
if ident == "exclude" || ident == "ignore" {
return false;
}
}
_ => {}
if let syn::Meta::List(list) = meta {
for tok in list.tokens.clone().into_iter() {
if let TokenTree::Ident(ident) = tok {
if ident == "exclude" || ident == "ignore" {
return false;
}
}
}
_ => {}
}
}
true
Expand Down Expand Up @@ -147,7 +143,7 @@ impl Compiler {
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let name = &struct_.ident;
let vis = &struct_.vis;
let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect();
let fields: Vec<_> = struct_.fields.iter().collect();
let field_vis: Vec<_> = fields.iter().map(|f| &f.vis).collect();
let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
Expand Down Expand Up @@ -209,6 +205,63 @@ impl Compiler {
}
)
}
pub fn derive_iotexel(&self, item: &Item) -> TokenStream {
match item {
Item::Struct(struct_) => self.derive_iotexel_for_struct(struct_),
_ => todo!(),
}
}
pub fn derive_iotexel_for_struct(&self, struct_: &ItemStruct) -> TokenStream {
let span = struct_.span();
let resource_path = self.resource_path();
let lang_path = self.lang_path();
// Make sure that the struct has repr(transparent).
let mut has_repr_transparent = false;
for Attribute { meta, .. } in &struct_.attrs {
if let syn::Meta::List(list) = meta {
let path = &list.path;
if path.is_ident("repr") {
for tok in list.tokens.clone().into_iter() {
if let TokenTree::Ident(ident) = tok {
if ident == "transparent" {
has_repr_transparent = true;
}
}
}
}
}
}
if !has_repr_transparent {
panic!("Struct must have #[repr(transparent)] attribute");
}
let generics = &struct_.generics;
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
let syn::Fields::Named(syn::FieldsNamed { named, .. }) = &struct_.fields else {
panic!("IoTexel derive currently only supports named fields.")
};
assert_eq!(named.len(), 1);
let syn::Field { ident, ty, .. } = &named[0];
let ident = ident.as_ref().unwrap();
let struct_name = &struct_.ident;
let struct_comps_name =
syn::Ident::new(&format!("{}Comps", struct_name), struct_name.span());
quote_spanned! {span=>
impl #impl_generics #resource_path::IoTexel for #struct_name<#ty_generics> #where_clause {
type RwType = <#ty as #resource_path::IoTexel>::RwType;
fn pixel_format(storage: #resource_path::PixelStorage) -> #resource_path::PixelFormat {
<#ty as #resource_path::IoTexel>::pixel_format(storage)
}
fn convert_from_read(texel: #lang_path::types::Expr<Self::RwType>) -> #lang_path::types::Expr<Self> {
#struct_name::from_comps_expr(#struct_comps_name {
#ident: <#ty as #resource_path::IoTexel>::convert_from_read(texel),
})
}
fn convert_to_write(value: #lang_path::types::Expr<Self>) -> #lang_path::types::Expr<Self::RwType> {
<#ty as #resource_path::IoTexel>::convert_to_write(value.#ident)
}
}
}
}
pub fn derive_value(&self, item: &Item) -> TokenStream {
match item {
Item::Struct(struct_) => self.derive_value_for_struct(struct_),
Expand Down Expand Up @@ -318,7 +371,7 @@ impl Compiler {
let marker_args = quote!(#(#marker_args),*);
let name = &struct_.ident;
let vis = &struct_.vis;
let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect();
let fields: Vec<_> = struct_.fields.iter().collect();
let field_vis: Vec<_> = fields.iter().map(|f| &f.vis).collect();
let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
Expand Down Expand Up @@ -555,7 +608,7 @@ impl Compiler {
let span = struct_.span();
let lang_path = self.lang_path();
let name = &struct_.ident;
let fields: Vec<_> = struct_.fields.iter().map(|f| f).collect();
let fields: Vec<_> = struct_.fields.iter().collect();
let field_types: Vec<_> = fields.iter().map(|f| &f.ty).collect();
let field_names: Vec<_> = fields.iter().map(|f| f.ident.as_ref().unwrap()).collect();
quote_spanned!(span=>
Expand Down

0 comments on commit e2f661c

Please sign in to comment.