diff --git a/crates/cubecl-lac/src/matmul/cmma/compute_loop.rs b/crates/cubecl-lac/src/matmul/cmma/compute_loop.rs index 2932b4d1..38d0ad62 100644 --- a/crates/cubecl-lac/src/matmul/cmma/compute_loop.rs +++ b/crates/cubecl-lac/src/matmul/cmma/compute_loop.rs @@ -117,7 +117,7 @@ pub mod tests { ) { let mut lhs = SharedMemory::::new(Comptime::get(m * k)); let mut rhs = SharedMemory::::new(Comptime::get(k * n)); - let mut accumulate = SharedMemory::::new(Comptime::get(m * n)); + for i in range(0u32, Comptime::get(m * k), Comptime::new(false)) { lhs[i] = lhs_tensor[i]; } @@ -125,7 +125,7 @@ pub mod tests { rhs[i] = rhs_tensor[i]; } for i in range(0u32, Comptime::get(m * n), Comptime::new(false)) { - accumulate[i] = F::new(0.); + accumulate_array[i] = F::new(0.); } let shared_memories = SharedMemories { lhs, rhs }; @@ -134,17 +134,17 @@ pub mod tests { compute_loop(shared_memories, accumulators, config); let offset = UNIT_POS_Y * UInt::new(512); - let slice = accumulate_array.slice_mut(offset, offset + UInt::new(256)); + let slice_0 = accumulate_array.slice_mut(offset, offset + UInt::new(256)); cmma::store::( - slice, + slice_0, &accumulators.first, UInt::new(16), cmma::MatrixLayout::RowMajor, ); - let slice = accumulate_array.slice_mut(offset + UInt::new(256), offset + UInt::new(512)); + let slice_1 = accumulate_array.slice_mut(offset + UInt::new(256), offset + UInt::new(512)); cmma::store::( - slice, + slice_1, &accumulators.second, UInt::new(16), cmma::MatrixLayout::RowMajor, @@ -233,16 +233,19 @@ pub mod tests { return; } - let lhs = range_tensor_f16::(16, 32, device); - let rhs = range_tensor_f16::(32, 32, device); - let results = create_empty::(16, 32, device); + let m = 16; + let k = 32; + let n = 32; + let lhs = range_tensor_f16::(m, k, device); + let rhs = range_tensor_f16::(k, n, device); + let results = create_empty::(m, n, device); let cube_dim = CubeDim::new(32, 1, 1); let cube_count = CubeCount::Static(1, 1, 1); let config = CmmaConfig { - block_size_m: UInt::new(16), - block_size_k: UInt::new(32), - block_size_n: UInt::new(32), + block_size_m: UInt::new(m as u32), + block_size_k: UInt::new(k as u32), + block_size_n: UInt::new(n as u32), tile_size: UInt::new(16), check_m_bounds: false, check_k_bounds: false, @@ -256,10 +259,10 @@ pub mod tests { cube_dim, TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape), TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape), - ArrayArg::new(&results, 512), - UInt::new(16), - UInt::new(32), - UInt::new(32), + ArrayArg::new(&results, m * n), + UInt::new(m as u32), + UInt::new(k as u32), + UInt::new(n as u32), config, );