diff --git a/examples/matmul/run.cpp b/examples/matmul/run.cpp index c5dbe2c..d9c5c1f 100644 --- a/examples/matmul/run.cpp +++ b/examples/matmul/run.cpp @@ -10,6 +10,7 @@ #include "llmc/reference_impls.h" // for CPU reference implementation #include "utils/array_utils.h" // show, isclose, randn, randint #include "utils/logging.h" // LOG +#include "experimental/wgsl.h" // loopUnrolling using namespace gpu; @@ -180,15 +181,14 @@ fn main( // Compute tile for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) { let tmp = tileB[threadCol * {{BK}} + dotIdx]; - for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) { - + for (var residx: u32 = 0; residx < {{TM}}; residx++) { threadResults[residx] += tileA[(threadRow + residx) * {{BK}} + dotIdx] * tmp; } } workgroupBarrier(); } - for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) { + for (var residx: u32 = 0; residx < {{TM}}; residx++) { c[cPtr + (threadRow + residx) * {{N}} + threadCol] = threadResults[residx]; } @@ -200,7 +200,8 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M, const size_t BK, const size_t BN, const size_t TM, const Shape &workgroupSize = {256, 1, 1}, - NumType precision = kf32) { + NumType precision = kf32, + bool unrolling = false) { assert(BM % TM == 0); assert(K % BK == 0); assert(M % BM == 0); @@ -218,7 +219,13 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M, {"{{BK}}", toString(BK)}, {"{{BN}}", toString(BN)}, {"{{TM}}", toString(TM)}}); - return ShaderCode{codeString, workgroupSize}; + if (unrolling) { + std::string unrolledCode = loopUnrolling(codeString); + LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); + return ShaderCode{unrolledCode, workgroupSize}; + } else { + return ShaderCode{codeString, workgroupSize}; + } } /** @@ -262,9 +269,6 @@ fn main( let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}}; let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}}; - let numIterA: u32 = {{BM}} * {{BK}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}})); - let numIterB: u32 = {{BK}} * {{BN}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}})); - // aPtr and bPtr are the starting positions of the tiles in a and b, // incremented in the bkidx loop. // cPtr is the starting position of the tile in c which is fixed. @@ -278,17 +282,13 @@ fn main( // Load tile // Load BM x BK by numThread(BM * BN / (TM * TN)) // The number of iteration == BM * BK / (BM * BN / (TM * TN)) - for (var i: u32 = 0; i < numIterA; i++) { - let loadColA: u32 = (localID.x + i * numThread) % {{BK}}; - let loadRowA: u32 = (localID.x + i * numThread) / {{BK}}; - tileA[loadRowA * {{BK}} + loadColA] = a[aPtr + loadRowA * {{K}} + loadColA]; + for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) { + tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}]; } // Load BK x BN by numThread(BM * BN / (TM * TN)) // The number of iteration == BK * BN / (BM * BN / (TM * TN)) - for (var i: u32 = 0; i < numIterB; i++) { - let loadColB: u32 = (localID.x + i * numThread) % {{BK}}; - let loadRowB: u32 = (localID.x + i * numThread) / {{BK}}; - tileB[loadRowB * {{BK}} + loadColB] = b[bPtr + loadRowB * {{K}} + loadColB]; + for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) { + tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})]; } aPtr += {{BK}}; @@ -297,11 +297,11 @@ fn main( workgroupBarrier(); // Compute tile for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) { - for (var i: u32 = 0; i < {{TM}}; i++) { - localM[i] = tileA[(threadRow + i) * {{BK}} + dotIdx]; + for (var idx: u32 = 0; idx < {{TM}}; idx++) { + localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx]; } - for (var i: u32 = 0; i < {{TN}}; i++) { - localN[i] = tileB[(threadCol + i) * {{BK}} + dotIdx]; + for (var idx: u32 = 0; idx < {{TN}}; idx++) { + localN[idx] = tileB[(threadCol + idx) * {{BK}} + dotIdx]; } for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) { for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) { @@ -325,15 +325,15 @@ inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M, const size_t BK, const size_t BN, const size_t TM, const size_t TN, const Shape &workgroupSize = {256, 1, 1}, - NumType precision = kf32) { + NumType precision = kf32, + bool unrolling = false) { assert(BM % TM == 0); assert(BN % TN == 0); assert(K % BK == 0); assert(M % BM == 0); assert(N % BN == 0); // # threads = tile A size == tile B size == # threads for computing C - //assert(/* tile A size */ BM * BK == /* tile B size */ BK * BN); - //assert(/* tile A size */ BM * BK == /* # of threads for C */ BM * BN / (TM * TN)); + int num_threads = BM * BN / (TM * TN); std::string codeString(shaderTemplate); replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)}, {"{{precision}}", toString(precision)}, @@ -344,8 +344,17 @@ inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M, {"{{BK}}", toString(BK)}, {"{{BN}}", toString(BN)}, {"{{TM}}", toString(TM)}, - {"{{TN}}", toString(TN)}}); - return ShaderCode{codeString, workgroupSize}; + {"{{TN}}", toString(TN)}, + {"{{NUM_TILEA}}", toString(BM * BK / num_threads)}, + {"{{NUM_TILEB}}", toString(BN * BK / num_threads)} + }); + if (unrolling) { + std::string unrolledCode = loopUnrolling(codeString); + LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); + return ShaderCode{unrolledCode, workgroupSize}; + } else { + return ShaderCode{codeString, workgroupSize}; + } } inline ShaderCode createNoOp(const char *shaderTemplate, @@ -401,7 +410,7 @@ Kernel selectMatmul(Context &ctx, int version, kernel = createKernel(ctx, matmul, bindings, /* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1})); - } else if (version == 3) { + } else if (version == 3 || version == 5) { static constexpr size_t BM = 64; static constexpr size_t BK = 4; static constexpr size_t BN = BM; @@ -415,10 +424,12 @@ Kernel selectMatmul(Context &ctx, int version, LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str()); LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str()); ShaderCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM, - /*wgSize*/ wgSize); + /*wgSize*/ wgSize, + kf32, + /*Loop unrolling*/ version == 5 ? true: false); kernel = createKernel(ctx, matmul, bindings, /*nWorkgroups*/ nWorkgroups); - } else if (version == 4) { + } else if (version == 4 || version == 6) { static constexpr size_t BM = 64; static constexpr size_t BK = 16; static constexpr size_t BN = 64; @@ -431,10 +442,12 @@ Kernel selectMatmul(Context &ctx, int version, LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str()); LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str()); ShaderCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN, - /*wgSize*/ wgSize); + /*wgSize*/ wgSize, + kf32, + /*Loop unrolling*/ version == 6 ? true: false); kernel = createKernel(ctx, matmul, bindings, /*nWorkgroups*/ nWorkgroups); - } else if (version == 5) { + } else if (version == 7) { Shape wgSize = {256, 1, 1}; Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1}); ShaderCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize); @@ -513,7 +526,9 @@ int main() { // 2 == tiling // 3 == 1D blocktiling // 4 == 2D blocktiling - // 5 == No-Op + // 5 == 1D blocktiling with loop unrolling + // 6 == 2D blocktiling with loop unrolling + // 7 == No-Op size_t M, K, N; // Matrix dimensions static constexpr int kTestSize = 2; diff --git a/experimental/wgsl.h b/experimental/wgsl.h new file mode 100644 index 0000000..e6521d8 --- /dev/null +++ b/experimental/wgsl.h @@ -0,0 +1,84 @@ +#ifndef GPU_CPP_WGSL_H +#define GPU_CPP_WGSL_H + +#include +#include +#include "utils/logging.h" // LOG + +namespace gpu { + +// Loop-unrolling optimization with regex +// +// Note: Be cautious, as it does not correctly recognize comments or lexical tokens. +std::string loopUnrolling(const std::string& code, int threshold = 32) { + // This regex pattern matches a for loop with the following structure: + // for (var : u32 = ; < ; ++) { } + std::regex forLoopPattern(R"(for\s*\(\s*var\s+(\w+):\s*u32\s*=\s*(\d+)\s*;\s*\1\s*<\s*(\d+)\s*;\s*\1\+\+\s*\)\s*\{\s*([^{}]*)\})"); + // Explanation of the regex: + // for\s*\( : Matches 'for (' with optional whitespace + // \s*var\s+ : Matches 'var ' with optional whitespace + // (\w+) : Captures the variable name (alphanumeric characters and underscores) + // :\s*u32\s*=\s* : Matches ': u32 = ' with optional whitespace + // (\d+) : Captures the start value (one or more digits) + // \s*;\s* : Matches ';' with optional whitespace + // \1\s*<\s* : Matches the captured variable name followed by '<' with optional whitespace + // (\d+) : Captures the end value (one or more digits) + // \s*;\s* : Matches ';' with optional whitespace + // \1\+\+\s* : Matches the captured variable name followed by '++' with optional whitespace + // \)\s*\{\s* : Matches ')' followed by '{' with optional whitespace + // ([^{}]*) : Captures the loop body (anything except '{' or '}') + // \} : Matches the closing '}' + + // Example: + // + // Input code: + // for (var i: u32 = 0; i < 3; i++) { std::cout << i << std::endl; } + // + // Matches: + // varName = "i" + // start = "0" + // end = "3" + // loopBody = "std::cout << i << std::endl;" + // + // Unrolled: + // std::cout << 0 << std::endl; + // std::cout << 1 << std::endl; + // std::cout << 2 << std::endl; + // + std::smatch match; + std::string unrolledCode = code; + while (std::regex_search(unrolledCode, match, forLoopPattern)) { + std::string varName = match[1]; + int start = std::stoi(match[2]); + int end = std::stoi(match[3]); + std::string loopBody = match[4]; + + if (end - start > threshold ) { + std::string skippedLoop = + "for (var " + + std::string(match[1]) + ": u32 = " + std::string(match[2]) + ";"+ + std::string(match[1]) + " < " + std::string(match[3]) + ";"+ + std::string(match[1]) + "++) /* Skipped */ {"+ + std::string(match[4]) + + "}"; + LOG(kDefLog, kInfo, "Roll loop:%s", skippedLoop.c_str()); + unrolledCode = unrolledCode.substr(0, match.position()) + skippedLoop + unrolledCode.substr(match.position() + match.length()); + } else { + LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str()); + std::string unrolledLoop; + for (int i = start; i < end; ++i) { + std::string unrolledIteration = loopBody; + std::regex varPattern(varName); + unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i)); + unrolledLoop += unrolledIteration; + } + unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length()); + } + } + + return unrolledCode; +} + +} // namespace gpu + +#endif