Skip to content

Commit

Permalink
Inform users that launch kernels cannot have return types and suggest…
Browse files Browse the repository at this point in the history
… what to do

Signed-off-by: Torstein Grindvik <[email protected]>
  • Loading branch information
Torstein Grindvik committed Nov 17, 2024
1 parent 8f4861e commit 9297e93
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion crates/cubecl-macros/src/parse/kernel.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use crate::{expression::Block, paths::prelude_type, scope::Context, statement::Pattern};
use darling::{ast::NestedMeta, util::Flag, FromMeta};
use proc_macro2::TokenStream;
use quote::ToTokens;
use std::iter;
use syn::{
parse_quote, punctuated::Punctuated, spanned::Spanned, visit_mut::VisitMut, Expr, FnArg,
Generics, Ident, ItemFn, Signature, TraitItemFn, Type, Visibility,
Generics, Ident, ItemFn, ReturnType, Signature, TraitItemFn, Type, Visibility,
};

use super::{desugar::Desugar, helpers::is_comptime_attr, statement::parse_pat};
Expand Down Expand Up @@ -232,6 +233,7 @@ impl KernelFn {
impl Launch {
pub fn from_item_fn(function: ItemFn, args: KernelArgs) -> syn::Result<Self> {
let runtime = prelude_type("Runtime");
let ret = function.sig.output.clone();

let vis = function.vis;
let func = KernelFn::from_sig_and_block(
Expand All @@ -243,6 +245,24 @@ impl Launch {
function.sig,
*function.block,
)?;

// Bail early if the user tries to have a return type in a launch kernel.
if args.is_launch() {
if let ReturnType::Type(arrow, ty) = &ret {
// Span both the arrow and the return type
let mut ts = arrow.to_token_stream();
ts.extend(ty.into_token_stream());

return Err(syn::Error::new_spanned(
ts,
format!(
"This is a launch kernel and cannot have a return type. Remove `-> {}`. Use mutable output arguments instead in order to get values out from kernels.",
ty.into_token_stream()
),
));
}
}

let mut kernel_generics = func.sig.generics.clone();
kernel_generics.params.push(parse_quote![__R: #runtime]);
let mut expand_generics = kernel_generics.clone();
Expand Down

0 comments on commit 9297e93

Please sign in to comment.