Skip to content

Commit

Permalink
Merge pull request #11 from junjihashimoto/feature/unrolling
Browse files Browse the repository at this point in the history
Add loop unrolling for matmul
  • Loading branch information
austinvhuang authored Jul 16, 2024
2 parents 3a3d4e7 + 91b572d commit af6e1e0
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 31 deletions.
77 changes: 46 additions & 31 deletions examples/matmul/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "llmc/reference_impls.h" // for CPU reference implementation
#include "utils/array_utils.h" // show, isclose, randn, randint
#include "utils/logging.h" // LOG
#include "experimental/wgsl.h" // loopUnrolling

using namespace gpu;

Expand Down Expand Up @@ -180,15 +181,14 @@ fn main(
// Compute tile
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
let tmp = tileB[threadCol * {{BK}} + dotIdx];
for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) {
for (var residx: u32 = 0; residx < {{TM}}; residx++) {
threadResults[residx] += tileA[(threadRow + residx) * {{BK}} + dotIdx] * tmp;
}
}
workgroupBarrier();
}
for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) {
for (var residx: u32 = 0; residx < {{TM}}; residx++) {
c[cPtr + (threadRow + residx) * {{N}} + threadCol] = threadResults[residx];
}
Expand All @@ -200,7 +200,8 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M,
const size_t BK, const size_t BN,
const size_t TM,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
NumType precision = kf32,
bool unrolling = false) {
assert(BM % TM == 0);
assert(K % BK == 0);
assert(M % BM == 0);
Expand All @@ -218,7 +219,13 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M,
{"{{BK}}", toString(BK)},
{"{{BN}}", toString(BN)},
{"{{TM}}", toString(TM)}});
return ShaderCode{codeString, workgroupSize};
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return ShaderCode{unrolledCode, workgroupSize};
} else {
return ShaderCode{codeString, workgroupSize};
}
}

/**
Expand Down Expand Up @@ -262,9 +269,6 @@ fn main(
let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};
let numIterA: u32 = {{BM}} * {{BK}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));
let numIterB: u32 = {{BK}} * {{BN}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));
// aPtr and bPtr are the starting positions of the tiles in a and b,
// incremented in the bkidx loop.
// cPtr is the starting position of the tile in c which is fixed.
Expand All @@ -278,17 +282,13 @@ fn main(
// Load tile
// Load BM x BK by numThread(BM * BN / (TM * TN))
// The number of iteration == BM * BK / (BM * BN / (TM * TN))
for (var i: u32 = 0; i < numIterA; i++) {
let loadColA: u32 = (localID.x + i * numThread) % {{BK}};
let loadRowA: u32 = (localID.x + i * numThread) / {{BK}};
tileA[loadRowA * {{BK}} + loadColA] = a[aPtr + loadRowA * {{K}} + loadColA];
for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
}
// Load BK x BN by numThread(BM * BN / (TM * TN))
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
for (var i: u32 = 0; i < numIterB; i++) {
let loadColB: u32 = (localID.x + i * numThread) % {{BK}};
let loadRowB: u32 = (localID.x + i * numThread) / {{BK}};
tileB[loadRowB * {{BK}} + loadColB] = b[bPtr + loadRowB * {{K}} + loadColB];
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
}
aPtr += {{BK}};
Expand All @@ -297,11 +297,11 @@ fn main(
workgroupBarrier();
// Compute tile
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
for (var i: u32 = 0; i < {{TM}}; i++) {
localM[i] = tileA[(threadRow + i) * {{BK}} + dotIdx];
for (var idx: u32 = 0; idx < {{TM}}; idx++) {
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
}
for (var i: u32 = 0; i < {{TN}}; i++) {
localN[i] = tileB[(threadCol + i) * {{BK}} + dotIdx];
for (var idx: u32 = 0; idx < {{TN}}; idx++) {
localN[idx] = tileB[(threadCol + idx) * {{BK}} + dotIdx];
}
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
Expand All @@ -325,15 +325,15 @@ inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M,
const size_t BK, const size_t BN,
const size_t TM, const size_t TN,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
NumType precision = kf32,
bool unrolling = false) {
assert(BM % TM == 0);
assert(BN % TN == 0);
assert(K % BK == 0);
assert(M % BM == 0);
assert(N % BN == 0);
// # threads = tile A size == tile B size == # threads for computing C
//assert(/* tile A size */ BM * BK == /* tile B size */ BK * BN);
//assert(/* tile A size */ BM * BK == /* # of threads for C */ BM * BN / (TM * TN));
int num_threads = BM * BN / (TM * TN);
std::string codeString(shaderTemplate);
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
{"{{precision}}", toString(precision)},
Expand All @@ -344,8 +344,17 @@ inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M,
{"{{BK}}", toString(BK)},
{"{{BN}}", toString(BN)},
{"{{TM}}", toString(TM)},
{"{{TN}}", toString(TN)}});
return ShaderCode{codeString, workgroupSize};
{"{{TN}}", toString(TN)},
{"{{NUM_TILEA}}", toString(BM * BK / num_threads)},
{"{{NUM_TILEB}}", toString(BN * BK / num_threads)}
});
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return ShaderCode{unrolledCode, workgroupSize};
} else {
return ShaderCode{codeString, workgroupSize};
}
}

inline ShaderCode createNoOp(const char *shaderTemplate,
Expand Down Expand Up @@ -401,7 +410,7 @@ Kernel selectMatmul(Context &ctx, int version,
kernel =
createKernel(ctx, matmul, bindings,
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
} else if (version == 3) {
} else if (version == 3 || version == 5) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 4;
static constexpr size_t BN = BM;
Expand All @@ -415,10 +424,12 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
ShaderCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
/*wgSize*/ wgSize);
/*wgSize*/ wgSize,
kf32,
/*Loop unrolling*/ version == 5 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
} else if (version == 4) {
} else if (version == 4 || version == 6) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 16;
static constexpr size_t BN = 64;
Expand All @@ -431,10 +442,12 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
ShaderCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
/*wgSize*/ wgSize);
/*wgSize*/ wgSize,
kf32,
/*Loop unrolling*/ version == 6 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
} else if (version == 5) {
} else if (version == 7) {
Shape wgSize = {256, 1, 1};
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
ShaderCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
Expand Down Expand Up @@ -513,7 +526,9 @@ int main() {
// 2 == tiling
// 3 == 1D blocktiling
// 4 == 2D blocktiling
// 5 == No-Op
// 5 == 1D blocktiling with loop unrolling
// 6 == 2D blocktiling with loop unrolling
// 7 == No-Op

size_t M, K, N; // Matrix dimensions
static constexpr int kTestSize = 2;
Expand Down
84 changes: 84 additions & 0 deletions experimental/wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#ifndef GPU_CPP_WGSL_H
#define GPU_CPP_WGSL_H

#include <string>
#include <regex>
#include "utils/logging.h" // LOG

namespace gpu {

// Loop-unrolling optimization with regex
//
// Note: Be cautious, as it does not correctly recognize comments or lexical tokens.
std::string loopUnrolling(const std::string& code, int threshold = 32) {
// This regex pattern matches a for loop with the following structure:
// for (var <varName>: u32 = <start>; <varName> < <end>; <varName>++) { <loopBody> }
std::regex forLoopPattern(R"(for\s*\(\s*var\s+(\w+):\s*u32\s*=\s*(\d+)\s*;\s*\1\s*<\s*(\d+)\s*;\s*\1\+\+\s*\)\s*\{\s*([^{}]*)\})");
// Explanation of the regex:
// for\s*\( : Matches 'for (' with optional whitespace
// \s*var\s+ : Matches 'var ' with optional whitespace
// (\w+) : Captures the variable name (alphanumeric characters and underscores)
// :\s*u32\s*=\s* : Matches ': u32 = ' with optional whitespace
// (\d+) : Captures the start value (one or more digits)
// \s*;\s* : Matches ';' with optional whitespace
// \1\s*<\s* : Matches the captured variable name followed by '<' with optional whitespace
// (\d+) : Captures the end value (one or more digits)
// \s*;\s* : Matches ';' with optional whitespace
// \1\+\+\s* : Matches the captured variable name followed by '++' with optional whitespace
// \)\s*\{\s* : Matches ')' followed by '{' with optional whitespace
// ([^{}]*) : Captures the loop body (anything except '{' or '}')
// \} : Matches the closing '}'

// Example:
//
// Input code:
// for (var i: u32 = 0; i < 3; i++) { std::cout << i << std::endl; }
//
// Matches:
// varName = "i"
// start = "0"
// end = "3"
// loopBody = "std::cout << i << std::endl;"
//
// Unrolled:
// std::cout << 0 << std::endl;
// std::cout << 1 << std::endl;
// std::cout << 2 << std::endl;
//
std::smatch match;
std::string unrolledCode = code;
while (std::regex_search(unrolledCode, match, forLoopPattern)) {
std::string varName = match[1];
int start = std::stoi(match[2]);
int end = std::stoi(match[3]);
std::string loopBody = match[4];

if (end - start > threshold ) {
std::string skippedLoop =
"for (var " +
std::string(match[1]) + ": u32 = " + std::string(match[2]) + ";"+
std::string(match[1]) + " < " + std::string(match[3]) + ";"+
std::string(match[1]) + "++) /* Skipped */ {"+
std::string(match[4]) +
"}";
LOG(kDefLog, kInfo, "Roll loop:%s", skippedLoop.c_str());
unrolledCode = unrolledCode.substr(0, match.position()) + skippedLoop + unrolledCode.substr(match.position() + match.length());
} else {
LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str());
std::string unrolledLoop;
for (int i = start; i < end; ++i) {
std::string unrolledIteration = loopBody;
std::regex varPattern(varName);
unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i));
unrolledLoop += unrolledIteration;
}
unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length());
}
}

return unrolledCode;
}

} // namespace gpu

#endif

0 comments on commit af6e1e0

Please sign in to comment.