diff --git a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs index fdc241cf0..d36d85055 100644 --- a/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/cubecl-wgpu/src/compiler/wgsl/instructions.rs @@ -483,16 +483,57 @@ impl Display for Instruction { Instruction::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f), Instruction::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f), Instruction::NotEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "!=", f), - Instruction::Assign { input, out } => { - if input.elem().is_atomic() { - f.write_fmt(format_args!("let {out} = &{input};\n")) - } else if input.item() != out.item() { - let item = out.item(); - f.write_fmt(format_args!("{out} = {item}({input});\n")) - } else { - f.write_fmt(format_args!("{out} = {input};\n")) + Instruction::Assign { input, out } => match out.item() { + Item::Vec4(elem) => { + let input0 = input.index(0); + let input1 = input.index(1); + let input2 = input.index(2); + let input3 = input.index(3); + + f.write_fmt(format_args!( + "{out} = vec4( + {elem}({input0}), + {elem}({input1}), + {elem}({input2}), + {elem}({input3}), +); +" + )) } - } + Item::Vec3(elem) => { + let input0 = input.index(0); + let input1 = input.index(1); + let input2 = input.index(2); + + f.write_fmt(format_args!( + "{out} = vec3( + {elem}({input0}), + {elem}({input1}), + {elem}({input2}), +); +" + )) + } + Item::Vec2(elem) => { + let input0 = input.index(0); + let input1 = input.index(1); + + f.write_fmt(format_args!( + "{out} = vec2( + {elem}({input0}), + {elem}({input1}), +); +" + )) + } + Item::Scalar(elem) => { + if elem.is_atomic() { + f.write_fmt(format_args!("let {out} = &{input};\n")) + } else { + f.write_fmt(format_args!("{out} = {elem}({input});\n")) + } + } + }, Instruction::Stride { dim, position, out } => f.write_fmt(format_args!( "{out} = info[({position}u * rank_2) + {dim} + 1u];\n" )),