1D blocktiling matmul works, set as the default

austinvhuang committed Jul 9, 2024
1 parent a256079 commit dda3934
Showing 3 changed files with 99 additions and 68 deletions.
examples/matmul/run.cpp (160 changes: 97 additions & 63 deletions)
@@ -37,13 +37,11 @@ inline ShaderCode createMatmul1(const char *shaderTemplate, const size_t M,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
std::string codeString(shaderTemplate);

replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
{"{{precision}}", toString(precision)},
{"{{M}}", toString(M)},
{"{{K}}", toString(K)},
{"{{N}}", toString(N)}});

return ShaderCode{codeString, workgroupSize};
}

@@ -130,74 +128,66 @@ inline ShaderCode createMatmul2(const char *shaderTemplate, const size_t M,
*
*/
static const char *kShaderMatmul3 = R"(
-@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
-@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
-@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+@group(0) @binding(0) var<storage, read_write> a: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> b: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> c: array<{{precision}}>;
var<workgroup> tileA: array<{{precision}}, {{BM}} * {{BK}}>;
-var<workgroup> tileB: array<{{precision}}, {{BK}} * {{BN}}>;
+var<workgroup> tileB: array<{{precision}}, {{BN}} * {{BK}}>;
@compute @workgroup_size({{workgroupSize}})
fn main(
@builtin(global_invocation_id) globalID : vec3<u32>,
@builtin(local_invocation_id) localID : vec3<u32>,
@builtin(local_invocation_index) localIdx : u32,
-@builtin(workgroup_id) groupID : vec3<u32>) {
+@builtin(workgroup_id) groupid : vec3<u32>) {
var threadResults: array<{{precision}}, {{TM}}>;
-let cRow: u32 = groupID.x;
-let cCol: u32 = groupID.y;
+let cRow: u32 = groupid.x;
+let cCol: u32 = groupid.y;
-// Position of the first C element computed by the thread
-let threadRow: u32 = localID.x / {{BN}};
+// position of the first c element computed by the thread
+let threadRow: u32 = localID.x / {{BN}} * {{TM}};
let threadCol: u32 = localID.x % {{BN}};
-// Value of A to cache in As
+// value of a to cache in as
+// value of b to cache in bs (b is stored as b^t)
+// Both tiles are of width BK
let loadColA = localID.x % {{BK}};
let loadRowA = localID.x / {{BK}};
+let loadColB = loadColA;
+let loadRowB = loadRowA;
-// Value of B to cache in Bs (B is stored as B^T)
-let loadColB = localID.x % {{BK}};
-let loadRowB = localID.x / {{BK}};
-// aPtr and bPtr are the starting positions of the tiles in A and B,
-// incremented in the bkIdx loop.
-// cPtr is the starting position of the tile in C which is fixed.
+// aPtr and bPtr are the starting positions of the tiles in a and b,
+// incremented in the bkidx loop.
+// cPtr is the starting position of the tile in c which is fixed.
var aPtr = cRow * {{BM}} * {{K}};
-var bPtr = (cCol * {{BN}}) // cCol corresponds to the row in B^T
-           * {{K}}; // K columns per row (column-major)
-var cPtr = cRow * {{BM}} * {{N}} + cCol * {{BN}};
+var bPtr = (cCol * {{BN}}) * {{K}};
+let cPtr = cRow * {{BM}} * {{N}} + cCol * {{BN}};
-for (var bkIdx = 0; bkIdx < {{K}}; bkIdx += {{BK}}) {
-tileA[loadRowA * {{BK}} + loadColA] = A[aPtr + loadRowA * {{K}} + loadColA];
-tileB[loadRowB * {{BK}} + loadColB] = B[bPtr + loadRowB * {{K}} + loadColB];
+for (var bkidx = 0; bkidx < {{K}}; bkidx += {{BK}}) {
+// Load tile
+tileA[loadRowA * {{BK}} + loadColA] = a[aPtr + loadRowA * {{K}} + loadColA];
+tileB[loadRowB * {{BK}} + loadColB] = b[bPtr + loadRowB * {{K}} + loadColB];
aPtr += {{BK}};
bPtr += {{BK}};
workgroupBarrier();
+// Compute tile
for (var dotIdx: u32 = 0; dotIdx < {{BK}}; dotIdx = dotIdx + 1) {
let tmp = tileB[threadCol * {{BK}} + dotIdx];
-for (var resIdx: u32 = 0; resIdx < {{TM}}; resIdx = resIdx + 1) {
-let mask = {{precision}}(threadRow * {{TM}} + resIdx < {{BM}}
-                         && threadCol < {{BN}}
-                         && threadRow * {{TM}} + resIdx < {{M}}
-                         && cCol * {{BN}} + threadCol < {{N}}
-                         && cRow * {{BM}} + threadRow < {{M}}
-                         );
-threadResults[resIdx] += mask * tileA[(threadRow * {{TM}} + resIdx) * {{BK}} + dotIdx] * tmp;
+for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) {
+threadResults[residx] += tileA[(threadRow + residx) * {{BK}} + dotIdx] * tmp;
}
}
workgroupBarrier();
}
-for (var resIdx: u32 = 0; resIdx < {{TM}}; resIdx = resIdx + 1) {
-C[cPtr + (threadRow * {{TM}} + resIdx) * {{N}} + threadCol] = threadResults[resIdx];
+for (var residx: u32 = 0; residx < {{TM}}; residx = residx + 1) {
+c[cPtr + (threadRow + residx) * {{N}} + threadCol] = threadResults[residx];
}
}
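For intuition, here is a minimal CPU sketch of the 1D blocktiling scheme the shader above implements (a hypothetical helper, not part of this commit): each simulated thread owns one column of the BM x BN output tile and accumulates TM consecutive rows of it, with A row-major (M x K) and B stored transposed (N x K).

#include <cstddef>
#include <vector>

// Assumes M % BM == 0, N % BN == 0, K % BK == 0, BM % TM == 0.
template <size_t BM, size_t BK, size_t BN, size_t TM>
void matmulBlockTiledCPU(const float *a, const float *b, float *c,
                         size_t M, size_t K, size_t N) {
  std::vector<float> tileA(BM * BK), tileB(BN * BK);
  for (size_t cRow = 0; cRow < M / BM; ++cRow) {
    for (size_t cCol = 0; cCol < N / BN; ++cCol) {
      std::vector<float> acc(BM * BN, 0.0f); // per-workgroup accumulators
      for (size_t bk = 0; bk < K; bk += BK) {
        // Cooperative tile loads (one element per thread in the kernel).
        for (size_t r = 0; r < BM; ++r)
          for (size_t col = 0; col < BK; ++col)
            tileA[r * BK + col] = a[(cRow * BM + r) * K + bk + col];
        for (size_t r = 0; r < BN; ++r)
          for (size_t col = 0; col < BK; ++col)
            tileB[r * BK + col] = b[(cCol * BN + r) * K + bk + col];
        // One iteration per "thread": BM * BN / TM of them per workgroup.
        for (size_t t = 0; t < BM * BN / TM; ++t) {
          size_t threadRow = t / BN * TM; // first of the TM rows owned by t
          size_t threadCol = t % BN;      // the single column owned by t
          for (size_t dot = 0; dot < BK; ++dot) {
            float tmp = tileB[threadCol * BK + dot];
            for (size_t res = 0; res < TM; ++res)
              acc[(threadRow + res) * BN + threadCol] +=
                  tileA[(threadRow + res) * BK + dot] * tmp;
          }
        }
      }
      for (size_t r = 0; r < BM; ++r)
        for (size_t col = 0; col < BN; ++col)
          c[(cRow * BM + r) * N + cCol * BN + col] = acc[r * BN + col];
    }
  }
}

Hoisting tmp out of the inner loop is the point of the 1D blocktiling: one tileB read is reused for TM accumulations.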
@@ -209,6 +199,13 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M,
const size_t TM,
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32) {
+assert(BM % TM == 0);
+assert(K % BK == 0);
+assert(M % BM == 0);
+assert(N % BN == 0);
+// # threads = tile A size == tile B size == # threads for computing C
+assert(/* tile A size */ BM * BK == /* tile B size */ BK * BN);
+assert(/* tile A size */ BM * BK == /* # of threads for C */ BM * BN / TM);
std::string codeString(shaderTemplate);
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
{"{{precision}}", toString(precision)},
@@ -222,11 +219,35 @@ inline ShaderCode createMatmul3(const char *shaderTemplate, const size_t M,
return ShaderCode{codeString, workgroupSize};
}
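As a quick sanity check of these constraints with the defaults this commit switches to (BM = 64, BK = 4, BN = BM = 64, TM = BN / BK = 16) -- illustrative arithmetic only, not part of the commit:

// tile A size           = BM * BK      = 64 * 4       = 256 elements
// tile B size           = BK * BN      = 4 * 64       = 256 elements
// threads per workgroup = BM * BN / TM = 64 * 64 / 16 = 256
static_assert(64 % 16 == 0);           // TM divides BM
static_assert(64 * 4 == 4 * 64);       // tile A size == tile B size
static_assert(64 * 4 == 64 * 64 / 16); // one tile load per compute thread

So each of the 256 threads loads exactly one element of tileA and one of tileB per K-step, then computes 16 of the workgroup's 64 * 64 = 4096 outputs.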

+/**
+ * @brief No-Op shader with matmul bindings for performance testing
+ */
+static const char *kShaderNoOp = R"(
+@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
+@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
+@group(0) @binding(2) var<storage, read_write> C: array<{{precision}}>;
+@compute @workgroup_size({{workgroupSize}})
+fn main(
+@builtin(global_invocation_id) globalID : vec3<u32>) {
+}
+)";
+
+inline ShaderCode createNoOp(const char *shaderTemplate,
+const Shape &workgroupSize = {256, 1, 1},
+NumType precision = kf32) {
+std::string codeString(shaderTemplate);
+replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
+{"{{precision}}", toString(precision)}});
+return ShaderCode{codeString, workgroupSize};
+}
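The no-op kernel gives a floor for dispatch and synchronization cost: same bindings and workgroup count, empty body, so any time it takes is overhead rather than matmul compute. A hypothetical helper for turning the measured wall time into throughput (not part of this commit; names are illustrative):

// Assumes totalSeconds was measured around the nIter dispatches in runTest.
double effectiveGFLOPS(size_t M, size_t K, size_t N, size_t nIter,
                       double totalSeconds) {
  // One matmul performs 2 * M * K * N floating point ops (multiply + add).
  double flops = 2.0 * static_cast<double>(M) * K * N * nIter;
  return flops / totalSeconds / 1e9;
}

Subtracting the no-op timing before dividing isolates the kernel's compute throughput from the fixed per-dispatch overhead.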

void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
std::unique_ptr<float[]> &weightsPtr) {
std::mt19937 gen(314159);
randn(inputPtr.get(), M * K, gen);
randn(weightsPtr.get(), N * K, gen);
+// randint(inputPtr.get(), M * K, gen, 1, 2);
+// randint(weightsPtr.get(), N * K, gen, 1, 2);
LOG(kDefLog, kInfo, "%s", show<float>(inputPtr.get(), M, K, "Input").c_str());
LOG(kDefLog, kInfo, "%s",
show<float>(weightsPtr.get(), N, K, "Weights").c_str());
@@ -239,10 +260,11 @@ void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
std::unique_ptr<float[]> outputRefPtr = std::make_unique<float[]>(M * N);
ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
nullptr, 1, M, K, N);
-// LOG(kDefLog, kInfo, "Reference Output: %s",
-//     show<float>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
+LOG(kDefLog, kInfo, "Reference Output: %s",
+show<float>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
LOG(kDefLog, kInfo,
-isclose(outputPtr.get(), outputRefPtr.get(), M * N) ? "PASS" : "FAIL");
+isclose(outputPtr.get(), outputRefPtr.get(), M * N) ? "CPU Check: PASS"
+: "CPU Check: FAIL");
}

void runTest(int version, size_t M, size_t K, size_t N,
@@ -275,26 +297,37 @@ void runTest(int version, size_t M, size_t K, size_t N,
createKernel(ctx, matmul, Bindings{input, weights, output},
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
} else if (version == 3) {
-// TODO(avh): fails for larger block dimensions
-static constexpr size_t BM = 4; // 32;
-static constexpr size_t BK = 4; // 8;
-static constexpr size_t BN = 4; // 32;
-static constexpr size_t TM = 1; // 8;
+// BM * BN values per workgroup, TM rows per thread => BM * BN / TM threads
+static constexpr size_t BM = 64;
+static constexpr size_t BK = 4;
+static constexpr size_t BN = BM;
+static constexpr size_t TM =
+    BN / BK; // BM * BN / TM == BM * BK, therefore TM == BN / BK
+
+Shape wgSize = {BM * BN / TM, 1,
+                1}; // BM * BN values per workgroup, TM values per thread
+Shape nWorkgroups = {cdiv(M, BM), cdiv(N, BN), 1};
+LOG(kDefLog, kInfo, "M: %d, K: %d, N: %d", M, K, N);
+LOG(kDefLog, kInfo, "BM: %d, BK: %d, BN: %d, TM: %d", BM, BK, BN, TM);
+LOG(kDefLog, kInfo, "wgSize: ( %s )", toString(wgSize).c_str());
+LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
ShaderCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
-/*wgSize*/ {BM * BN / TM, 1, 1});
-kernel =
-createKernel(ctx, matmul, Bindings{input, weights, output},
-/*nWorkgroups*/ {cdiv(cdiv(M, BM), TM), cdiv(N, BN), 1});
-// /*nWorkgroups*/ cdiv({M, N, 1}, {BM, BN, 1}));
+/*wgSize*/ wgSize);
+kernel = createKernel(ctx, matmul, Bindings{input, weights, output},
+/*nWorkgroups*/ nWorkgroups);
+} else if (version == 4) {
+Shape wgSize = {256, 1, 1};
+Shape nWorkgroups = cdiv({M, N, 1}, {16, 16, 1});
+ShaderCode matmul = createNoOp(kShaderNoOp, /*wgsize*/ wgSize);
+kernel = createKernel(ctx, matmul, Bindings{input, weights, output},
+/*nWorkgroups*/ nWorkgroups);
}
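For reference, the version == 3 launch geometry with the kTestSize == 2 shapes used in main() below (M = 4096, K = 4096, N = 8192) works out as follows -- illustrative arithmetic only:

// wgSize      = {BM * BN / TM, 1, 1} = {64 * 64 / 16, 1, 1} = {256, 1, 1}
// nWorkgroups = {cdiv(4096, 64), cdiv(8192, 64), 1} = {64, 128, 1}
// 64 * 128 = 8192 workgroups x 256 threads, each thread writing TM = 16
// elements of C: 8192 * 256 * 16 = 4096 * 8192 output values, as required.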

// Dispatch kernel execution
LOG(kDefLog, kInfo, "Dispatching + waiting");

// pre-allocate promises and futures for async dispatch
// TODO(avh): implement a pooling mechanism for promises/futures in gpu.h
-constexpr size_t nIter = 4;
+constexpr size_t nIter = 10;
std::array<std::promise<void>, nIter> promises;
std::array<std::future<void>, nIter> futures;
for (int i = 0; i < nIter; i++) {
Expand Down Expand Up @@ -329,13 +362,17 @@ void runTest(int version, size_t M, size_t K, size_t N,
}
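The body of the dispatch loop above is collapsed in this view. A sketch of what each iteration looks like, assuming the dispatchKernel / wait / resetCommandBuffer helpers defined in gpu.h (timing code omitted):

// Pre-created promises/futures are consumed one per dispatch, since a
// std::promise is single-use.
for (int i = 0; i < nIter; i++) {
  futures[i] = promises[i].get_future();
  dispatchKernel(ctx, kernel, promises[i]);
  wait(ctx, futures[i]);
  resetCommandBuffer(ctx.device, kernel); // make the kernel dispatchable again
}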

int main() {
-static constexpr int kTestSize = 1;
-size_t M, K, N;
+int version = 3; // 1 == naive matmul
+// 2 == tiling
+// 3 == 1D blocktiling
+// 4 == No-Op
+size_t M, K, N; // Matrix dimensions
+static constexpr int kTestSize = 2;
if constexpr (kTestSize == 0) {
// Tiny test
-M = 16;
-K = 4;
-N = 8;
+M = 32;
+K = 32;
+N = 32;
} else if constexpr (kTestSize == 1) {
// Small test
M = 256;
@@ -347,9 +384,6 @@ int main() {
K = 4096;
N = 2 * 4096;
}
-int version = 3; // 1 == naive
-// 2 == tiling
-// 3 == 1D blocktiling (WIP)

std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
@@ -358,7 +392,7 @@ int main() {
initData(M, K, N, inputPtr, weightsPtr);
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);

-if constexpr (kTestSize <= 0) {
+if constexpr (kTestSize <= 1) {
// Check result with CPU reference implementation for tiny/small tests
checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
}
gpu.h (3 changes: 0 additions & 3 deletions)
@@ -828,7 +828,6 @@ inline Kernel createKernel(Context &ctx, const ShaderCode &shader,
op.buffers = std::make_unique<WGPUBuffer[]>(numBindings);
op.bufferSizes = std::make_unique<size_t[]>(numBindings);
op.numBindings = numBindings;
LOG(kDefLog, kInfo, "Create the bind group layout");
std::vector<WGPUBindGroupLayoutEntry> bgLayoutEntries(numBindings);
// Create layout entries for input buffers
for (size_t i = 0; i < numTensors; ++i) {
@@ -855,7 +854,6 @@ inline Kernel createKernel(Context &ctx, const ShaderCode &shader,
},
};
}
LOG(kDefLog, kInfo, "Create the bind group layout descriptor");
WGPUBindGroupLayoutDescriptor bgLayoutDesc = {
.entryCount = static_cast<uint32_t>(bgLayoutEntries.size()),
.entries = bgLayoutEntries.data(),
@@ -881,7 +879,6 @@ inline Kernel createKernel(Context &ctx, const ShaderCode &shader,
} else {
LOG(kDefLog, kInfo, "No params buffer needed");
}
LOG(kDefLog, kInfo, "Create the bind group");
std::vector<WGPUBindGroupEntry> bindGroupEntries(numBindings);
for (size_t i = 0; i < numTensors; ++i) {
bindGroupEntries[i] = WGPUBindGroupEntry{
utils/array_utils.h (4 changes: 2 additions & 2 deletions)
@@ -24,7 +24,7 @@
namespace gpu {

static constexpr int kShowMaxRows = 32;
-static constexpr int kShowMaxCols = 8;
+static constexpr int kShowMaxCols = 10;

template <typename numtype>
std::string show(const numtype *a, size_t rows, size_t cols,
Expand Down Expand Up @@ -68,7 +68,7 @@ std::string show(const numtype *a, size_t rows, size_t cols,
a[i * cols + j] == 0.0) {
sprintf(buffer, "%8.2f", a[i * cols + j]);
} else
sprintf(buffer, "%8.2e", a[i * cols + j]);
sprintf(buffer, "%10.2e", a[i * cols + j]);
} else {
throw std::runtime_error("Unsupported number type for show()");
}
