Skip to content

Commit

Permalink
Merge pull request #27 from iMplode-nZ/main
Browse files Browse the repository at this point in the history
Added `repr(u32)` enum Value derives.
  • Loading branch information
shiinamiyuki authored Oct 20, 2023
2 parents dd04749 + 4ea0e1e commit 65a1a9a
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 11 deletions.
6 changes: 3 additions & 3 deletions luisa_compute/src/lang/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ pub trait SoaValue: Value {
pub trait ExprProxy: Copy + 'static {
type Value: Value<Expr = Self>;

fn from_expr(expr: Expr<Self::Value>) -> Self;
fn as_expr_from_proxy(&self) -> &Expr<Self::Value>;
fn from_expr(expr: Expr<Self::Value>) -> Self;
}

/// A trait for implementing remote impls on top of an [`Var`] using [`Deref`].
Expand All @@ -71,15 +71,15 @@ pub trait ExprProxy: Copy + 'static {
/// impls.
pub trait VarProxy: Copy + 'static + Deref<Target = Expr<Self::Value>> {
type Value: Value<Var = Self>;
fn as_var_from_proxy(&self) -> &Var<Self::Value>;

fn as_var_from_proxy(&self) -> &Var<Self::Value>;
fn from_var(expr: Var<Self::Value>) -> Self;
}

pub unsafe trait AtomicRefProxy: Copy + 'static {
type Value: Value<AtomicRef = Self>;
fn as_atomic_ref_from_proxy(&self) -> &AtomicRef<Self::Value>;

fn as_atomic_ref_from_proxy(&self) -> &AtomicRef<Self::Value>;
fn from_atomic_ref(expr: AtomicRef<Self::Value>) -> Self;
}

Expand Down
13 changes: 8 additions & 5 deletions luisa_compute/src/lang/types/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ impl<T: Primitive> Value for T {
type AtomicRef = PrimitiveAtomicRef<T>;

fn expr(self) -> Expr<Self> {
let node = __current_scope(|s| -> NodeRef { s.const_(self.const_()) });
let node = __current_scope(|s| s.const_(self.const_()));
Expr::<Self>::from_node(node.into())
}
}
Expand All @@ -288,7 +288,10 @@ macro_rules! impl_atomic {
lower_atomic_ref(
self.node().get(),
Func::AtomicCompareExchange,
&[expected.as_expr().node().get(), desired.as_expr().node().get()],
&[
expected.as_expr().node().get(),
desired.as_expr().node().get(),
],
)
}
pub fn exchange(&self, operand: impl AsExpr<Value = $t>) -> Expr<$t> {
Expand Down Expand Up @@ -375,9 +378,9 @@ fn lower_atomic_ref<T: Value>(node: NodeRef, op: Func, args: &[NodeRef]) -> Expr
.chain(args.iter())
.map(|n| *n)
.collect::<Vec<_>>();
Expr::<T>::from_node(__current_scope(|b| {
b.call(op, &new_args, <T as TypeOf>::type_())
}).into())
Expr::<T>::from_node(
__current_scope(|b| b.call(op, &new_args, <T as TypeOf>::type_())).into(),
)
}
_ => unreachable!("{:?}", inst),
},
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute_derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use syn::spanned::Spanned;

#[proc_macro_derive(Value, attributes(value_new))]
pub fn derive_value(item: TokenStream) -> TokenStream {
let item: syn::ItemStruct = syn::parse(item).unwrap();
let item: syn::Item = syn::parse(item).unwrap();
let compiler = luisa_compute_derive_impl::Compiler;
compiler.derive_value(&item).into()
}
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute_derive_impl/src/bin/derive-debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ fn main() {
)
.unwrap();
println!("{:?}", item.to_token_stream());
let out = compiler.derive_value(&item);
let out = compiler.derive_value_for_struct(&item);
println!("{:?}", out.to_string());
}
84 changes: 83 additions & 1 deletion luisa_compute_derive_impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,89 @@ impl Compiler {
}
)
}
pub fn derive_value(&self, struct_: &ItemStruct) -> TokenStream {
pub fn derive_value(&self, item: &Item) -> TokenStream {
match item {
Item::Struct(struct_) => self.derive_value_for_struct(struct_),
Item::Enum(enum_) => self.derive_value_for_enum(enum_),
_ => todo!(),
}
}
pub fn derive_value_for_enum(&self, enum_: &ItemEnum) -> TokenStream {
let repr = enum_
.attrs
.iter()
.find_map(|attr| {
let meta = &attr.meta;
match meta {
syn::Meta::List(list) => {
let path = &list.path;
if path.is_ident("repr") {
list.parse_args::<Ident>().ok()
} else {
None
}
}
_ => None,
}
})
.expect("Enum must have repr attribute.");
let span = enum_.span();
let lang_path = self.lang_path();
let name = &enum_.ident;
let expr_proxy_name = syn::Ident::new(&format!("{}Expr", name), name.span());
let var_proxy_name = syn::Ident::new(&format!("{}Var", name), name.span());
let atomic_ref_proxy_name = syn::Ident::new(&format!("{}AtomicRef", name), name.span());
let as_repr = syn::Ident::new(&format!("as_{}", repr), repr.span());
if !(["bool", "u8", "u16", "u32", "u64", "i8", "i16", "i32", "i64"]
.contains(&&*repr.to_string()))
{
panic!("Enum repr must be one of bool, u8, u16, u32, u64, i8, i16, i32, i64");
}
quote_spanned! {span=>
impl #lang_path::types::Value for #name {
type Expr = #expr_proxy_name;
type Var = #var_proxy_name;
type AtomicRef = #atomic_ref_proxy_name;

fn expr(self) -> Expr<Self> {
let node = #lang_path::__current_scope(|s| s.const_(<#repr as #lang_path::types::core::Primitive>::const_(&(self as #repr))));
<Expr::<Self> as #lang_path::FromNode>::from_node(node.into())
}
}
impl #lang_path::ir::TypeOf for #name {
fn type_() -> #lang_path::ir::CArc<#lang_path::ir::Type> {
<#repr as #lang_path::ir::TypeOf>::type_()
}
}

::luisa_compute::impl_simple_expr_proxy!(#expr_proxy_name for #name);
::luisa_compute::impl_simple_var_proxy!(#var_proxy_name for #name);
::luisa_compute::impl_simple_atomic_ref_proxy!(#atomic_ref_proxy_name for #name);

impl #expr_proxy_name {
pub fn #as_repr(&self) -> #lang_path::types::Expr<#repr> {
use #lang_path::ToNode;
use #lang_path::types::ExprProxy;
#lang_path::FromNode::from_node(self.as_expr_from_proxy().node())
}
}
impl #var_proxy_name {
pub fn #as_repr(&self) -> #lang_path::types::Var<#repr> {
use #lang_path::ToNode;
use #lang_path::types::VarProxy;
#lang_path::FromNode::from_node(self.as_var_from_proxy().node())
}
}
impl #atomic_ref_proxy_name {
pub fn #as_repr(&self) -> #lang_path::types::AtomicRef<#repr> {
use #lang_path::ToNode;
use #lang_path::types::AtomicRefProxy;
#lang_path::FromNode::from_node(self.as_atomic_ref_from_proxy().node())
}
}
}
}
pub fn derive_value_for_struct(&self, struct_: &ItemStruct) -> TokenStream {
let ordering = self.value_attributes(&struct_.attrs);
let span = struct_.span();
let lang_path = self.lang_path();
Expand Down

0 comments on commit 65a1a9a

Please sign in to comment.