Skip to content

Commit

Permalink
Add the threshold of loop-unrolling
Browse files Browse the repository at this point in the history
  • Loading branch information
junjihashimoto committed Jul 16, 2024
1 parent 0ae5bed commit 91b572d
Showing 1 changed file with 21 additions and 10 deletions.
31 changes: 21 additions & 10 deletions experimental/wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace gpu {
// Loop-unrolling optimization with regex
//
// Note: Be cautious, as it does not correctly recognize comments or lexical tokens.
std::string loopUnrolling(const std::string& code) {
std::string loopUnrolling(const std::string& code, int threshold = 32) {
// This regex pattern matches a for loop with the following structure:
// for (var <varName>: u32 = <start>; <varName> < <end>; <varName>++) { <loopBody> }
std::regex forLoopPattern(R"(for\s*\(\s*var\s+(\w+):\s*u32\s*=\s*(\d+)\s*;\s*\1\s*<\s*(\d+)\s*;\s*\1\+\+\s*\)\s*\{\s*([^{}]*)\})");
Expand Down Expand Up @@ -52,17 +52,28 @@ std::string loopUnrolling(const std::string& code) {
int start = std::stoi(match[2]);
int end = std::stoi(match[3]);
std::string loopBody = match[4];
LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str());

std::string unrolledLoop;
for (int i = start; i < end; ++i) {
std::string unrolledIteration = loopBody;
std::regex varPattern(varName);
unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i));
unrolledLoop += unrolledIteration;
if (end - start > threshold ) {
std::string skippedLoop =
"for (var " +
std::string(match[1]) + ": u32 = " + std::string(match[2]) + ";"+
std::string(match[1]) + " < " + std::string(match[3]) + ";"+
std::string(match[1]) + "++) /* Skipped */ {"+
std::string(match[4]) +
"}";
LOG(kDefLog, kInfo, "Roll loop:%s", skippedLoop.c_str());
unrolledCode = unrolledCode.substr(0, match.position()) + skippedLoop + unrolledCode.substr(match.position() + match.length());
} else {
LOG(kDefLog, kInfo, "Unroll loop(var: %s, start:%d, end:%d, body:%s)", varName.c_str(), start, end, loopBody.c_str());
std::string unrolledLoop;
for (int i = start; i < end; ++i) {
std::string unrolledIteration = loopBody;
std::regex varPattern(varName);
unrolledIteration = std::regex_replace(unrolledIteration, varPattern, std::to_string(i));
unrolledLoop += unrolledIteration;
}
unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length());
}

unrolledCode = unrolledCode.substr(0, match.position()) + unrolledLoop + unrolledCode.substr(match.position() + match.length());
}

return unrolledCode;
Expand Down

0 comments on commit 91b572d

Please sign in to comment.