Skip to content

Commit

Permalink
trying to fix flaky test
Browse files (browse the repository at this point in the history)
  • Loading branch information
louisfd committed Jul 18, 2024
1 parent b8639e8 commit 7deaccb
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions crates/cubecl-lac/src/matmul/cmma/compute_loop.rs
Original file line number | Diff line number | Diff line change
Expand Up @@ -117,15 +117,15 @@ pub mod tests {
) {
let mut lhs = SharedMemory::<FC>::new(Comptime::get(m * k));
let mut rhs = SharedMemory::<FC>::new(Comptime::get(k * n));
let mut accumulate = SharedMemory::<F>::new(Comptime::get(m * n));

for i in range(0u32, Comptime::get(m * k), Comptime::new(false)) {
lhs[i] = lhs_tensor[i];
}
for i in range(0u32, Comptime::get(k * n), Comptime::new(false)) {
rhs[i] = rhs_tensor[i];
}
for i in range(0u32, Comptime::get(m * n), Comptime::new(false)) {
- accumulate[i] = F::new(0.);
+ accumulate_array[i] = F::new(0.);
}

let shared_memories = SharedMemories { lhs, rhs };
Expand All @@ -134,17 +134,17 @@ pub mod tests {
compute_loop(shared_memories, accumulators, config);

let offset = UNIT_POS_Y * UInt::new(512);
- let slice = accumulate_array.slice_mut(offset, offset + UInt::new(256));
+ let slice_0 = accumulate_array.slice_mut(offset, offset + UInt::new(256));
cmma::store::<F>(
- slice,
+ slice_0,
&accumulators.first,
UInt::new(16),
cmma::MatrixLayout::RowMajor,
);

- let slice = accumulate_array.slice_mut(offset + UInt::new(256), offset + UInt::new(512));
+ let slice_1 = accumulate_array.slice_mut(offset + UInt::new(256), offset + UInt::new(512));
cmma::store::<F>(
- slice,
+ slice_1,
&accumulators.second,
UInt::new(16),
cmma::MatrixLayout::RowMajor,
Expand Down Expand Up @@ -233,16 +233,19 @@ pub mod tests {
return;
}

- let lhs = range_tensor_f16::<R>(16, 32, device);
- let rhs = range_tensor_f16::<R>(32, 32, device);
- let results = create_empty::<R>(16, 32, device);
+ let m = 16;
+ let k = 32;
+ let n = 32;
+ let lhs = range_tensor_f16::<R>(m, k, device);
+ let rhs = range_tensor_f16::<R>(k, n, device);
+ let results = create_empty::<R>(m, n, device);
let cube_dim = CubeDim::new(32, 1, 1);
let cube_count = CubeCount::Static(1, 1, 1);

let config = CmmaConfig {
- block_size_m: UInt::new(16),
- block_size_k: UInt::new(32),
- block_size_n: UInt::new(32),
+ block_size_m: UInt::new(m as u32),
+ block_size_k: UInt::new(k as u32),
+ block_size_n: UInt::new(n as u32),
tile_size: UInt::new(16),
check_m_bounds: false,
check_k_bounds: false,
Expand All @@ -256,10 +259,10 @@ pub mod tests {
cube_dim,
TensorArg::new(&lhs.handle, &lhs.strides, &lhs.shape),
TensorArg::new(&rhs.handle, &rhs.strides, &rhs.shape),
- ArrayArg::new(&results, 512),
- UInt::new(16),
- UInt::new(32),
- UInt::new(32),
+ ArrayArg::new(&results, m * n),
+ UInt::new(m as u32),
+ UInt::new(k as u32),
+ UInt::new(n as u32),
config,
);

Expand Down

0 comments on commit 7deaccb

Please sign in to comment.