Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add loop unrolling for matmul #11

Merged
merged 3 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 46 additions & 31 deletions examples/matmul/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "llmc/reference_impls.h" // for CPU reference implementation
#include "utils/array_utils.h" // show, isclose, randn, randint
#include "utils/logging.h" // LOG
#include "experimental/wgsl.h" // loopUnrolling

using namespace gpu;

Expand Down Expand Up @@ -180,15 +181,14 @@ fn main(
// Compute tile
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
let tmp = tileB[threadCol * {{BK}} + dotIdx];
for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) {

for (var residx: u32 = 0; residx < {{TM}}; residx++) {
threadResults[residx] += tileA[(threadRow + residx) * {{BK}} + dotIdx] * tmp;
}
}
workgroupBarrier();
}

for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) {
for (var residx: u32 = 0; residx < {{TM}}; residx++) {
c[cPtr + (threadRow + residx) * {{N}} + threadCol] = threadResults[residx];
}

Expand All @@ -200,7 +200,8 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M,
const size_t BK, const size_t BN,
const size_t TM,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
NumType precision = kf32,
bool unrolling = false) {
assert(BM % TM == 0);
assert(K % BK == 0);
assert(M % BM == 0);
Expand All @@ -218,7 +219,13 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M,
{"{{BK}}", toString(BK)},
{"{{BN}}", toString(BN)},
{"{{TM}}", toString(TM)}});
return ShaderCode{codeString, workgroupSize};
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return ShaderCode{unrolledCode, workgroupSize};
} else {
return ShaderCode{codeString, workgroupSize};
}
}

/**
Expand Down Expand Up @@ -262,9 +269,6 @@ fn main(
let threadRow: u32 = (localID.x / ({{BN}} / {{TN}})) * {{TM}};
let threadCol: u32 = (localID.x % ({{BN}} / {{TN}})) * {{TN}};

let numIterA: u32 = {{BM}} * {{BK}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));
let numIterB: u32 = {{BK}} * {{BN}} / ({{BM}} * {{BN}} / ({{TM}} * {{TN}}));

// aPtr and bPtr are the starting positions of the tiles in a and b,
// incremented in the bkidx loop.
// cPtr is the starting position of the tile in c which is fixed.
Expand All @@ -278,17 +282,13 @@ fn main(
// Load tile
// Load BM x BK by numThread(BM * BN / (TM * TN))
// The number of iteration == BM * BK / (BM * BN / (TM * TN))
for (var i: u32 = 0; i < numIterA; i++) {
let loadColA: u32 = (localID.x + i * numThread) % {{BK}};
let loadRowA: u32 = (localID.x + i * numThread) / {{BK}};
tileA[loadRowA * {{BK}} + loadColA] = a[aPtr + loadRowA * {{K}} + loadColA];
for (var idx: u32 = 0; idx < {{NUM_TILEA}}; idx++) {
tileA[localID.x + idx * numThread] = a[aPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + (localID.x + idx * numThread) % {{BK}}];
}
// Load BK x BN by numThread(BM * BN / (TM * TN))
// The number of iteration == BK * BN / (BM * BN / (TM * TN))
for (var i: u32 = 0; i < numIterB; i++) {
let loadColB: u32 = (localID.x + i * numThread) % {{BK}};
let loadRowB: u32 = (localID.x + i * numThread) / {{BK}};
tileB[loadRowB * {{BK}} + loadColB] = b[bPtr + loadRowB * {{K}} + loadColB];
for (var idx: u32 = 0; idx < {{NUM_TILEB}}; idx++) {
tileB[localID.x + idx * numThread] = b[bPtr + ((localID.x + idx * numThread) / {{BK}}) * {{K}} + ((localID.x + idx * numThread) % {{BK}})];
}

aPtr += {{BK}};
Expand All @@ -297,11 +297,11 @@ fn main(
workgroupBarrier();
// Compute tile
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
for (var i: u32 = 0; i < {{TM}}; i++) {
localM[i] = tileA[(threadRow + i) * {{BK}} + dotIdx];
for (var idx: u32 = 0; idx < {{TM}}; idx++) {
localM[idx] = tileA[(threadRow + idx) * {{BK}} + dotIdx];
}
for (var i: u32 = 0; i < {{TN}}; i++) {
localN[i] = tileB[(threadCol + i) * {{BK}} + dotIdx];
for (var idx: u32 = 0; idx < {{TN}}; idx++) {
localN[idx] = tileB[(threadCol + idx) * {{BK}} + dotIdx];
}
for (var resIdxM: u32 = 0; resIdxM < {{TM}}; resIdxM++) {
for (var resIdxN: u32 = 0; resIdxN < {{TN}}; resIdxN++) {
Expand All @@ -325,15 +325,15 @@ inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M,
const size_t BK, const size_t BN,
const size_t TM, const size_t TN,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
NumType precision = kf32,
bool unrolling = false) {
assert(BM % TM == 0);
assert(BN % TN == 0);
assert(K % BK == 0);
assert(M % BM == 0);
assert(N % BN == 0);
// # threads = tile A size == tile B size == # threads for computing C
//assert(/* tile A size */ BM * BK == /* tile B size */ BK * BN);
//assert(/* tile A size */ BM * BK == /* # of threads for C */ BM * BN / (TM * TN));
int num_threads = BM * BN / (TM * TN);
std::string codeString(shaderTemplate);
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
{"{{precision}}", toString(precision)},
Expand All @@ -344,8 +344,17 @@ inline ShaderCode createMatmul4(const char *shaderTemplate, const size_t M,
{"{{BK}}", toString(BK)},
{"{{BN}}", toString(BN)},
{"{{TM}}", toString(TM)},
{"{{TN}}", toString(TN)}});
return ShaderCode{codeString, workgroupSize};
{"{{TN}}", toString(TN)},
{"{{NUM_TILEA}}", toString(BM * BK / num_threads)},
{"{{NUM_TILEB}}", toString(BN * BK / num_threads)}
});
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return ShaderCode{unrolledCode, workgroupSize};
} else {
return ShaderCode{codeString, workgroupSize};
}
}

inline ShaderCode createNoOp(const char *shaderTemplate,
Expand Down Expand Up @@ -401,7 +410,7 @@ Kernel selectMatmul(Context &ctx, int version,
kernel =
createKernel(ctx, matmul, bindings,
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
} else if (version == 3) {
} else if (version == 3 || version == 5) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 4;
static constexpr size_t BN = BM;
Expand All @@ -415,10 +424,12 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
ShaderCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
/*wgSize*/ wgSize);
/*wgSize*/ wgSize,
kf32,
/*Loop unrolling*/ version == 5 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
} else if (version == 4) {
} else if (version == 4 || version == 6) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 16;
static constexpr size_t BN = 64;
Expand All @@ -431,10 +442,12 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
ShaderCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
/*wgSize*/ wgSize);
/*wgSize*/ wgSize,
kf32,
/*Loop unrolling*/ version == 6 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
} else if (version == 5) {
} else if (version == 7) {
Shape wgSize = {256, 1, 1};
Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
ShaderCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
Expand Down Expand Up @@ -513,7 +526,9 @@ int main() {
// 2 == tiling
// 3 == 1D blocktiling
// 4 == 2D blocktiling
// 5 == No-Op
// 5 == 1D blocktiling with loop unrolling
// 6 == 2D blocktiling with loop unrolling
// 7 == No-Op

size_t M, K, N; // Matrix dimensions
static constexpr int kTestSize = 2;
Expand Down
84 changes: 84 additions & 0 deletions experimental/wgsl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#ifndef GPU_CPP_WGSL_H
#define GPU_CPP_WGSL_H

#include <string>
#include <regex>
#include "utils/logging.h" // LOG

namespace gpu {

// Loop-unrolling optimization with regex
//
// Note: Be cautious, as it does not correctly recognize comments or lexical tokens.
std::string loopUnrolling(const std::string& code, int threshold = 32) {
// This regex pattern matches a for loop with the following structure:
// for (var <varName>: u32 = <start>; <varName> < <end>; <varName>++) { <loopBody> }
std::regex forLoopPattern(R"(for\s*\(\s*var\s+(\w+):\s*u32\s*=\s*(\d+)\s*;\s*\1\s*<\s*(\d+)\s*;\s*\1\+\+\s*\)\s*\{\s*([^{}]*)\})");
// Explanation of the regex:
// for\s*\( : Matches 'for (' with optional whitespace
// \s*var\s+ : Matches 'var ' with optional whitespace
// (\w+) : Captures the variable name (alphanumeric characters and underscores)
// :\s*u32\s*=\s* : Matches ': u32 = ' with optional whitespace
// (\d+) : Captures the start value (one or more digits)
// \s*;\s* : Matches ';' with optional whitespace
// \1\s*<\s* : Matches the captured variable name followed by '<' with optional whitespace
// (\d+) : Captures the end value (one or more digits)
// \s*;\s* : Matches ';' with optional whitespace
// \1\+\+\s* : Matches the captured variable name followed by '++' with optional whitespace
// \)\s*\{\s* : Matches ')' followed by '{' with optional whitespace
// ([^{}]*) : Captures the loop body (anything except '{' or '}')
// \} : Matches the closing '}'

// Example:
//
// Input code:
// for (var i: u32 = 0; i < 3; i++) { std::cout << i << std::endl; }
//
// Matches:
// varName = "i"
// start = "0"
// end = "3"
// loopBody = "std::cout << i << std::endl;"
//
// Unrolled:
// std::cout << 0 << std::endl;
// std::cout << 1 << std::endl;
// std::cout << 2 << std::endl;
//
std::smatch match;
std::string unrolledCode = code;
while (std::regex_search(unrolledCode, match, forLoopPattern)) {
std::string varName = match[1];
int start = std::stoi(match[2]);
int end = std::stoi(match[3]);
std::string loopBody = match[4];

if (end - start > threshold ) {
std::string skippedLoop =
"for (var " +
std::string(match[1]) + ": u32 = " + std::string(match[2]) + ";"+
std::string(match[1]) + " < " + std::string(match[3]) + ";"+
std::string(match[1]) + "++) /* Skipped */ {"+
std::string(match[4]) +
"}";
LOG(kDefLog, kInfo, "Roll loop:%s", skippedLoop.c_str());
unrolledCode = unrolledCode.substr(0, match.position()) + skippedLoop + unrolledCode.substr(match.position() + match.length());
} else {
LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str());
std::string unrolledLoop;
for (int i = start; i < end; ++i) {
std::string unrolledIteration = loopBody;
std::regex varPattern(varName);
unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i));
unrolledLoop += unrolledIteration;
}
unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length());
}
}

return unrolledCode;
}

} // namespace gpu

#endif
Loading