From b2c79ef8617c7fd11b213d9cbf019a21ed8fe596 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Mon, 15 Jul 2024 17:40:07 +0900 Subject: [PATCH 1/3] Add loop unrolling for matmul --- examples/matmul/run.cpp | 104 ++++++++++++++++++++++++++++------------ 1 file changed, 73 insertions(+), 31 deletions(-) diff --git a/examples/matmul/run.cpp b/examples/matmul/run.cpp index c5dbe2c..f38a7d7 100644 --- a/examples/matmul/run.cpp +++ b/examples/matmul/run.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include "gpu.h" // createContext, createTensor, createKernel, dispatchKernel, // wait, resetCommandBuffer, toCPU @@ -13,6 +14,33 @@ using namespace gpu; +// For loop unroller function +std::string unroll(const std::string& code) { + std::regex forLoopPattern(R"(for\s*\(\s*var\s+(\w+):\s*u32\s*=\s*(\d+)\s*;\s*\1\s*<\s*(\d+)\s*;\s*\1\+\+\s*\)\s*\{\s*([^{}]*)\})"); + + std::smatch match; + std::string unrolledCode = code; + while (std::regex_search(unrolledCode, match, forLoopPattern)) { + std::string varName = match[1]; + int start = std::stoi(match[2]); + int end = std::stoi(match[3]); + std::string loopBody = match[4]; + LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str()); + + std::string unrolledLoop; + for (int i = start; i < end; ++i) { + std::string unrolledIteration = loopBody; + std::regex varPattern(varName); + unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i)); + unrolledLoop += unrolledIteration; + } + + unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length()); + } + + return unrolledCode; +} + static const char *kShaderMatmul1 = R"( @group(0) @binding(0) var A: array<{{precision}}>; @group(0) @binding(1) var B: array<{{precision}}>; @@ -180,15 +208,14 @@ fn main( // Compute tile for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) { let tmp = tileB[threadCol * {{BK}} + 
dotIdx]; - for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) { - + for (var residx: u32 = 0; residx < {{TM}}; residx++) { threadResults[residx] += tileA[(threadRow + residx) * {{BK}} + dotIdx] * tmp; } } workgroupBarrier(); } - for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) { + for (var residx: u32 = 0; residx < {{TM}}; residx++) { c[cPtr + (threadRow + residx) * {{N}} + threadCol] = threadResults[residx]; } @@ -200,7 +227,8 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M, const size_t BK, const size_t BN, const size_t TM, const Shape &workgroupSize = {256, 1, 1}, - NumType precision = kf32) { + NumType precision = kf32, + bool unrolling = false) { assert(BM % TM == 0); assert(K % BK == 0); assert(M % BM == 0); @@ -218,7 +246,13 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M, {"{{BK}}", toString(BK)}, {"{{BN}}", toString(BN)}, {"{{TM}}", toString(TM)}}); - return ShaderCode{codeString, workgroupSize}; + if (unrolling) { + std::string unrolledCode = unroll(codeString); + LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); + return ShaderCode{unrolledCode, workgroupSize}; + } else { + return ShaderCode{codeString, workgroupSize}; + } } /** @@ -262,9 +296,6 @@ fn main( let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}}; let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}}; - let numIterA: u32 = {{BM}} * {{BK}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}})); - let numIterB: u32 = {{BK}} * {{BN}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}})); - // aPtr and bPtr are the starting positions of the tiles in a and b, // incremented in the bkidx loop. // cPtr is the starting position of the tile in c which is fixed. 
@@ -278,17 +309,13 @@ fn main( // Load tile // Load BM x BK by numThread(BM * BN / (TM * TN)) // The number of iteration == BM * BK / (BM * BN / (TM * TN)) - for (var i: u32 = 0; i < numIterA; i++) { - let loadColA: u32 = (localID.x + i * numThread) % {{BK}}; - let loadRowA: u32 = (localID.x + i * numThread) / {{BK}}; - tileA[loadRowA * {{BK}} + loadColA] = a[aPtr + loadRowA * {{K}} + loadColA]; + for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) { + tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}]; } // Load BK x BN by numThread(BM * BN / (TM * TN)) // The number of iteration == BK * BN / (BM * BN / (TM * TN)) - for (var i: u32 = 0; i < numIterB; i++) { - let loadColB: u32 = (localID.x + i * numThread) % {{BK}}; - let loadRowB: u32 = (localID.x + i * numThread) / {{BK}}; - tileB[loadRowB * {{BK}} + loadColB] = b[bPtr + loadRowB * {{K}} + loadColB]; + for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) { + tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})]; } aPtr += {{BK}}; @@ -297,11 +324,11 @@ fn main( workgroupBarrier(); // Compute tile for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) { - for (var i: u32 = 0; i < {{TM}}; i++) { - localM[i] = tileA[(threadRow + i) * {{BK}} + dotIdx]; + for (var idx: u32 = 0; idx < {{TM}}; idx++) { + localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx]; } - for (var i: u32 = 0; i < {{TN}}; i++) { - localN[i] = tileB[(threadCol + i) * {{BK}} + dotIdx]; + for (var idx: u32 = 0; idx < {{TN}}; idx++) { + localN[idx] = tileB[(threadCol + idx) * {{BK}} + dotIdx]; } for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) { for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) { @@ -325,15 +352,15 @@ inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M, const size_t BK, const size_t BN, const size_t TM, const size_t TN, 
const Shape &workgroupSize = {256, 1, 1}, - NumType precision = kf32) { + NumType precision = kf32, + bool unrolling = false) { assert(BM % TM == 0); assert(BN % TN == 0); assert(K % BK == 0); assert(M % BM == 0); assert(N % BN == 0); // # threads = tile A size == tile B size == # threads for computing C - //assert(/* tile A size */ BM * BK == /* tile B size */ BK * BN); - //assert(/* tile A size */ BM * BK == /* # of threads for C */ BM * BN / (TM * TN)); + int num_threads = BM * BN / (TM * TN); std::string codeString(shaderTemplate); replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)}, {"{{precision}}", toString(precision)}, @@ -344,8 +371,17 @@ inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M, {"{{BK}}", toString(BK)}, {"{{BN}}", toString(BN)}, {"{{TM}}", toString(TM)}, - {"{{TN}}", toString(TN)}}); - return ShaderCode{codeString, workgroupSize}; + {"{{TN}}", toString(TN)}, + {"{{NUM_TILEA}}", toString(BM * BK / num_threads)}, + {"{{NUM_TILEB}}", toString(BN * BK / num_threads)} + }); + if (unrolling) { + std::string unrolledCode = unroll(codeString); + LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); + return ShaderCode{unrolledCode, workgroupSize}; + } else { + return ShaderCode{codeString, workgroupSize}; + } } inline ShaderCode createNoOp(const char *shaderTemplate, @@ -401,7 +437,7 @@ Kernel selectMatmul(Context &ctx, int version, kernel = createKernel(ctx, matmul, bindings, /* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1})); - } else if (version == 3) { + } else if (version == 3 || version == 5) { static constexpr size_t BM = 64; static constexpr size_t BK = 4; static constexpr size_t BN = BM; @@ -415,10 +451,12 @@ Kernel selectMatmul(Context &ctx, int version, LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str()); LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str()); ShaderCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM, - /*wgSize*/ 
wgSize); + /*wgSize*/ wgSize, + kf32, + /*Loop unrolling*/ version == 5 ? true: false); kernel = createKernel(ctx, matmul, bindings, /*nWorkgroups*/ nWorkgroups); - } else if (version == 4) { + } else if (version == 4 || version == 6) { static constexpr size_t BM = 64; static constexpr size_t BK = 16; static constexpr size_t BN = 64; @@ -431,10 +469,12 @@ Kernel selectMatmul(Context &ctx, int version, LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str()); LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str()); ShaderCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN, - /*wgSize*/ wgSize); + /*wgSize*/ wgSize, + kf32, + /*Loop unrolling*/ version == 6 ? true: false); kernel = createKernel(ctx, matmul, bindings, /*nWorkgroups*/ nWorkgroups); - } else if (version == 5) { + } else if (version == 7) { Shape wgSize = {256, 1, 1}; Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1}); ShaderCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize); @@ -513,7 +553,9 @@ int main() { // 2 == tiling // 3 == 1D blocktiling // 4 == 2D blocktiling - // 5 == No-Op + // 5 == 1D blocktiling with loop unrolling + // 6 == 2D blocktiling with loop unrolling + // 7 == No-Op size_t M, K, N; // Matrix dimensions static constexpr int kTestSize = 2; From 0ae5bed4ffd9baa5af0e1546b50512facb3c3fdf Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Tue, 16 Jul 2024 03:22:47 +0900 Subject: [PATCH 2/3] Move loopUnrolling function to experimental/wgsl.h --- examples/matmul/run.cpp | 33 ++----------------- experimental/wgsl.h | 73 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 76 insertions(+), 30 deletions(-) create mode 100644 experimental/wgsl.h diff --git a/examples/matmul/run.cpp b/examples/matmul/run.cpp index f38a7d7..d9c5c1f 100644 --- a/examples/matmul/run.cpp +++ b/examples/matmul/run.cpp @@ -3,7 +3,6 @@ #include #include #include -#include #include "gpu.h" // createContext, createTensor, createKernel, dispatchKernel, // wait, 
resetCommandBuffer, toCPU @@ -11,36 +10,10 @@ #include "llmc/reference_impls.h" // for CPU reference implementation #include "utils/array_utils.h" // show, isclose, randn, randint #include "utils/logging.h" // LOG +#include "experimental/wgsl.h" // loopUnrolling using namespace gpu; -// For loop unroller function -std::string unroll(const std::string& code) { - std::regex forLoopPattern(R"(for\s*\(\s*var\s+(\w+):\s*u32\s*=\s*(\d+)\s*;\s*\1\s*<\s*(\d+)\s*;\s*\1\+\+\s*\)\s*\{\s*([^{}]*)\})"); - - std::smatch match; - std::string unrolledCode = code; - while (std::regex_search(unrolledCode, match, forLoopPattern)) { - std::string varName = match[1]; - int start = std::stoi(match[2]); - int end = std::stoi(match[3]); - std::string loopBody = match[4]; - LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str()); - - std::string unrolledLoop; - for (int i = start; i < end; ++i) { - std::string unrolledIteration = loopBody; - std::regex varPattern(varName); - unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i)); - unrolledLoop += unrolledIteration; - } - - unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length()); - } - - return unrolledCode; -} - static const char *kShaderMatmul1 = R"( @group(0) @binding(0) var A: array<{{precision}}>; @group(0) @binding(1) var B: array<{{precision}}>; @@ -247,7 +220,7 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M, {"{{BN}}", toString(BN)}, {"{{TM}}", toString(TM)}}); if (unrolling) { - std::string unrolledCode = unroll(codeString); + std::string unrolledCode = loopUnrolling(codeString); LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); return ShaderCode{unrolledCode, workgroupSize}; } else { @@ -376,7 +349,7 @@ inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M, {"{{NUM_TILEB}}", toString(BN * BK / 
num_threads)} }); if (unrolling) { - std::string unrolledCode = unroll(codeString); + std::string unrolledCode = loopUnrolling(codeString); LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); return ShaderCode{unrolledCode, workgroupSize}; } else { diff --git a/experimental/wgsl.h b/experimental/wgsl.h new file mode 100644 index 0000000..47cd7e5 --- /dev/null +++ b/experimental/wgsl.h @@ -0,0 +1,73 @@ +#ifndef GPU_CPP_WGSL_H +#define GPU_CPP_WGSL_H + +#include +#include +#include "utils/logging.h" // LOG + +namespace gpu { + +// Loop-unrolling optimization with regex +// +// Note: Be cautious, as it does not correctly recognize comments or lexical tokens. +std::string loopUnrolling(const std::string& code) { + // This regex pattern matches a for loop with the following structure: + // for (var : u32 = ; < ; ++) { } + std::regex forLoopPattern(R"(for\s*\(\s*var\s+(\w+):\s*u32\s*=\s*(\d+)\s*;\s*\1\s*<\s*(\d+)\s*;\s*\1\+\+\s*\)\s*\{\s*([^{}]*)\})"); + // Explanation of the regex: + // for\s*\( : Matches 'for (' with optional whitespace + // \s*var\s+ : Matches 'var ' with optional whitespace + // (\w+) : Captures the variable name (alphanumeric characters and underscores) + // :\s*u32\s*=\s* : Matches ': u32 = ' with optional whitespace + // (\d+) : Captures the start value (one or more digits) + // \s*;\s* : Matches ';' with optional whitespace + // \1\s*<\s* : Matches the captured variable name followed by '<' with optional whitespace + // (\d+) : Captures the end value (one or more digits) + // \s*;\s* : Matches ';' with optional whitespace + // \1\+\+\s* : Matches the captured variable name followed by '++' with optional whitespace + // \)\s*\{\s* : Matches ')' followed by '{' with optional whitespace + // ([^{}]*) : Captures the loop body (anything except '{' or '}') + // \} : Matches the closing '}' + + // Example: + // + // Input code: + // for (var i: u32 = 0; i < 3; i++) { std::cout << i << std::endl; } + // + // Matches: + // varName = "i" + // 
start = "0" + // end = "3" + // loopBody = "std::cout << i << std::endl;" + // + // Unrolled: + // std::cout << 0 << std::endl; + // std::cout << 1 << std::endl; + // std::cout << 2 << std::endl; + // + std::smatch match; + std::string unrolledCode = code; + while (std::regex_search(unrolledCode, match, forLoopPattern)) { + std::string varName = match[1]; + int start = std::stoi(match[2]); + int end = std::stoi(match[3]); + std::string loopBody = match[4]; + LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str()); + + std::string unrolledLoop; + for (int i = start; i < end; ++i) { + std::string unrolledIteration = loopBody; + std::regex varPattern(varName); + unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i)); + unrolledLoop += unrolledIteration; + } + + unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length()); + } + + return unrolledCode; +} + +} // namespace gpu + +#endif From 91b572ddcd9674257efe680664c9a79b8c93cf00 Mon Sep 17 00:00:00 2001 From: Junji Hashimoto Date: Tue, 16 Jul 2024 10:19:14 +0900 Subject: [PATCH 3/3] Add the threshold of loop-unrolling --- experimental/wgsl.h | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/experimental/wgsl.h b/experimental/wgsl.h index 47cd7e5..e6521d8 100644 --- a/experimental/wgsl.h +++ b/experimental/wgsl.h @@ -10,7 +10,7 @@ namespace gpu { // Loop-unrolling optimization with regex // // Note: Be cautious, as it does not correctly recognize comments or lexical tokens. 
-std::string loopUnrolling(const std::string& code) {
+inline std::string loopUnrolling(const std::string& code, int threshold = 32) { // inline: defined in a header, avoids duplicate definitions across TUs
   // This regex pattern matches a for loop with the following structure:
   // for (var : u32 = ; < ; ++) { }
   std::regex forLoopPattern(R"(for\s*\(\s*var\s+(\w+):\s*u32\s*=\s*(\d+)\s*;\s*\1\s*<\s*(\d+)\s*;\s*\1\+\+\s*\)\s*\{\s*([^{}]*)\})");
@@ -52,17 +52,28 @@ std::string loopUnrolling(const std::string& code) {
     int start = std::stoi(match[2]);
     int end = std::stoi(match[3]);
     std::string loopBody = match[4];
-    LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str());
 
-    std::string unrolledLoop;
-    for (int i = start; i < end; ++i) {
-      std::string unrolledIteration = loopBody;
-      std::regex varPattern(varName);
-      unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i));
-      unrolledLoop += unrolledIteration;
+    if (end - start > threshold ) { // trip count above the threshold: keep the loop rolled
+      std::string skippedLoop = // re-emit the loop with a marker comment between ')' and '{' so forLoopPattern cannot match it again
+        "for (var " +
+        std::string(match[1]) + ": u32 = " + std::string(match[2]) + ";"+
+        std::string(match[1]) + " < " + std::string(match[3]) + ";"+
+        std::string(match[1]) + "++) /* Skipped */ {"+
+        std::string(match[4]) +
+        "}";
+      LOG(kDefLog, kInfo, "Roll loop:%s", skippedLoop.c_str());
+      unrolledCode = unrolledCode.substr(0, match.position()) + skippedLoop + unrolledCode.substr(match.position() + match.length());
+    } else {
+      LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str());
+      std::string unrolledLoop;
+      std::regex varPattern("\\b" + varName + "\\b"); // whole-word match, built once: a variable `i` must not rewrite the `i` inside `tileA`
+      for (int i = start; i < end; ++i) {
+        std::string unrolledIteration = loopBody;
+        unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i)); // substitute the literal index for the loop variable
+        unrolledLoop += unrolledIteration;
+      }
+      unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length());
     }
-
-    unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length());
   }
 
   return unrolledCode;