From 75485314e6e0f41e715da7dacdb48ecf9dd67b75 Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 31 Jul 2024 10:37:07 -0400 Subject: [PATCH 1/2] Apply the mapping after the inputs --- crates/cubecl-macros/src/codegen_function/launch.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/crates/cubecl-macros/src/codegen_function/launch.rs b/crates/cubecl-macros/src/codegen_function/launch.rs index 660e4f32..57ce23b0 100644 --- a/crates/cubecl-macros/src/codegen_function/launch.rs +++ b/crates/cubecl-macros/src/codegen_function/launch.rs @@ -299,19 +299,13 @@ impl Codegen { let mut inputs: std::collections::BTreeMap> = std::collections::BTreeMap::new(); let mut outputs: std::collections::BTreeMap> = std::collections::BTreeMap::new(); - for mapping in self.settings.mappings.iter() { - if !inputs.contains_key(&mapping.pos_input) { - inputs.insert( - mapping.pos_input, - #register_input_call(&mut builder, &self.settings, mapping.pos_input), - ); - } + #register_input + for mapping in self.settings.mappings.iter() { let input = inputs.get(&mapping.pos_input).unwrap(); outputs.insert(mapping.pos_output, input.clone()); } - #register_input #register_output }; From de615cbf8627956e72c28032545afc733c93661d Mon Sep 17 00:00:00 2001 From: nathaniel Date: Wed, 31 Jul 2024 12:22:50 -0400 Subject: [PATCH 2/2] Fix inplace --- crates/cubecl-macros/src/codegen_function/launch.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/crates/cubecl-macros/src/codegen_function/launch.rs b/crates/cubecl-macros/src/codegen_function/launch.rs index 57ce23b0..49ed4dcb 100644 --- a/crates/cubecl-macros/src/codegen_function/launch.rs +++ b/crates/cubecl-macros/src/codegen_function/launch.rs @@ -300,12 +300,6 @@ impl Codegen { let mut outputs: std::collections::BTreeMap> = std::collections::BTreeMap::new(); #register_input - - for mapping in self.settings.mappings.iter() { - let input = inputs.get(&mapping.pos_input).unwrap(); - outputs.insert(mapping.pos_output, input.clone()); - } - #register_output }; @@ -319,6 +313,13 @@ impl Codegen { }); } + tokens.extend(quote::quote! { + for mapping in self.settings.mappings.iter() { + let input = inputs.get(&mapping.pos_input).unwrap(); + outputs.insert(mapping.pos_output, input.clone()); + } + }); + if num_outputs > 0 { tokens.extend(quote::quote! { for i in 0..#num_outputs {