From d52401daf24c4ae8374ff28ac19ee74dde8c9670 Mon Sep 17 00:00:00 2001
From: Thomas Faingnaert
Date: Thu, 7 Dec 2023 14:09:18 +0100
Subject: [PATCH] Check size limits of LocalArray

Extracted from #179
---
 src/matmul.jl | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)

diff --git a/src/matmul.jl b/src/matmul.jl
index 3c4bde0d..f3dfd34f 100644
--- a/src/matmul.jl
+++ b/src/matmul.jl
@@ -34,6 +34,30 @@ function matmul(conf::Config, a, b, c, d;
         conf.block_shape.K ≥ 2 * conf.compute_op_shape.K || throw(ConfigError("Need at least two stages to use a pipelined kernel, i.e. BLOCK_K ≥ 2 * OPERATOR_K"))
     end
 
+    # Check LocalArray size limit of 32 elements.
+    if kernel == Kernel.matmul_singlestage
+        num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
+        num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N
+
+        num_fragments_m * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
+    end
+
+    if kernel == Kernel.matmul_pipelined
+        num_fragments_m = conf.compute_warp.M ÷ conf.compute_op_shape.M
+        num_fragments_n = conf.compute_warp.N ÷ conf.compute_op_shape.N
+
+        a_frag_i = (conf.block_shape.M * conf.block_shape.K) ÷ (conf.mem_a_warp.M * conf.mem_a_warp.K * conf.warps_per_block)
+        a_frag_j = (conf.mem_a_warp.M * conf.mem_a_warp.K) ÷ (conf.mem_a_thread.M * conf.mem_a_thread.K * 32)
+        b_frag_i = (conf.block_shape.K * conf.block_shape.N) ÷ (conf.mem_b_warp.K * conf.mem_b_warp.N * conf.warps_per_block)
+        b_frag_j = (conf.mem_b_warp.K * conf.mem_b_warp.N) ÷ (conf.mem_b_thread.K * conf.mem_b_thread.N * 32)
+
+        num_fragments_m * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
+        a_frag_i * a_frag_j < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
+        b_frag_i * b_frag_j < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
+        2 * num_fragments_m < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
+        2 * num_fragments_n < 32 || throw(ConfigError("Config exceeds LocalArray size limit of 32 elements!"))
+    end
+
     hostkernel = @cuda launch=false kernel(args...)
 
     attributes(hostkernel.fun)[CUDA.FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES] = shmem