Skip to content

Commit

Permalink
Merge pull request #39 from junjihashimoto/feature/matmul-f16
Browse files Browse the repository at this point in the history
Add matmul with float16
  • Loading branch information
austinvhuang authored Aug 8, 2024
2 parents 305c25a + 23dd96e commit 228084f
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 77 deletions.
216 changes: 143 additions & 73 deletions examples/matmul/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,35 @@
#include "utils/array_utils.h" // show, isclose, randn, randint
#include "utils/logging.h" // LOG
#include "experimental/wgsl.h" // loopUnrolling
#include "numeric_types/half.h"

using namespace gpu;

const std::string versionToStr(int version);

/**
 * CPU reference matmul for half-precision buffers.
 *
 * For every (b, t) row, computes out[b,t,o] = bias[o] + sum_i inp[b,t,i] * weight[o,i].
 * Each dot product is accumulated in a float and only the final sum is stored
 * back into a half.
 *
 * @param out    Output buffer, written as (B, T, OC).
 * @param inp    Input buffer, read as (B, T, C).
 * @param weight Weight buffer, read as (OC, C): row o holds the C weights of
 *               output channel o.
 * @param bias   Optional bias read at indices [0, OC); pass nullptr to start
 *               each sum from 0.
 * @param B      Size of the outermost (b) dimension.
 * @param T      Size of the second (t) dimension.
 * @param C      Number of input channels (length of each dot product).
 * @param OC     Number of output channels.
 */
void matmulf16_forward_cpu(half* out,
                           const half* inp, const half* weight, const half* bias,
                           int B, int T, int C, int OC) {
  // OC is short for "output channels"
  // inp is (B,T,C), weight is (OC, C)
  // out will be (B,T,OC)
  #pragma omp parallel for collapse(2)
  for (int b = 0; b < B; b++) {
    for (int t = 0; t < T; t++) {
      // Offsets are computed in size_t so that large B*T*OC / B*T*C products
      // cannot overflow int.
      const size_t row = static_cast<size_t>(b) * T + t;
      half* out_bt = out + row * OC;
      const half* inp_bt = inp + row * C;
      for (int o = 0; o < OC; o++) {
        // Accumulate in float; the sum starts from the bias when one is given.
        float val = (bias != nullptr) ? halfToFloat(bias[o]) : 0.0f;
        const half* wrow = weight + static_cast<size_t>(o) * C;
        for (int i = 0; i < C; i++) {
          val += halfToFloat(inp_bt[i]) * halfToFloat(wrow[i]);
        }
        // Stored through half's conversion from float.
        out_bt[o] = val;
      }
    }
  }
}

static const char *kShaderMatmul1 = R"(
@group(0) @binding(0) var<storage, read_write> A: array<{{precision}}>;
@group(0) @binding(1) var<storage, read_write> B: array<{{precision}}>;
Expand Down Expand Up @@ -47,7 +71,7 @@ inline KernelCode createMatmul1(const char *shaderTemplate, const size_t M,
{"{{M}}", toString(M)},
{"{{K}}", toString(K)},
{"{{N}}", toString(N)}});
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}

// Shared memory cache-blocking
Expand Down Expand Up @@ -108,7 +132,7 @@ inline KernelCode createMatmul2(const char *shaderTemplate, const size_t M,
{"{{N}}", toString(N)},
{"{{tileSize}}",
toString(static_cast<size_t>(sqrt(workgroupSize[0])))}});
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}

/* 1D block-tiling
Expand Down Expand Up @@ -224,9 +248,9 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M,
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return {unrolledCode, workgroupSize};
return {unrolledCode, workgroupSize, precision};
} else {
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}
}

Expand Down Expand Up @@ -340,9 +364,9 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M,
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return {unrolledCode, workgroupSize};
return {unrolledCode, workgroupSize, precision};
} else {
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}
}

Expand Down Expand Up @@ -462,9 +486,9 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons
if (unrolling) {
std::string unrolledCode = loopUnrolling(codeString);
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return {unrolledCode, workgroupSize};
return {unrolledCode, workgroupSize, precision};
} else {
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}
}

Expand Down Expand Up @@ -582,7 +606,7 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si
});
std::string unrolledCode = loopUnrolling(codeString);
// LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str());
return {unrolledCode, workgroupSize};
return {unrolledCode, workgroupSize, precision};
}

/**
Expand All @@ -604,7 +628,7 @@ inline KernelCode createNoOp(const char *shaderTemplate,
std::string codeString(shaderTemplate);
replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)},
{"{{precision}}", toString(precision)}});
return {codeString, workgroupSize};
return {codeString, workgroupSize, precision};
}

void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
Expand All @@ -619,23 +643,41 @@ void initData(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
show<float>(weightsPtr.get(), N, K, "Weights").c_str());
}

void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<float[]> &inputPtr,
std::unique_ptr<float[]> &weightsPtr,
std::unique_ptr<float[]> &outputPtr) {
/**
 * Half-precision overload: fills the input (M * K elements) and weight
 * (N * K elements) buffers via randn from a fixed-seed generator, then logs
 * both buffers.
 *
 * @param M          Row count used when showing the input.
 * @param K          Column count used when showing both buffers.
 * @param N          Row count used when showing the weights.
 * @param inputPtr   Buffer that receives M * K values.
 * @param weightsPtr Buffer that receives N * K values.
 */
void initData(size_t M, size_t K, size_t N, std::unique_ptr<half[]> &inputPtr,
              std::unique_ptr<half[]> &weightsPtr) {
  // Constant seed: the generator starts from the same state on every call.
  std::mt19937 gen(314159);

  half *const inputData = inputPtr.get();
  half *const weightData = weightsPtr.get();

  // Input is drawn first, then the weights, from the same generator.
  randn(inputData, M * K, gen);
  randn(weightData, N * K, gen);
  // Alternative initialisation, kept for reference:
  // randint(inputData, M * K, gen, 1, 2);
  // randint(weightData, N * K, gen, 1, 2);

  LOG(kDefLog, kInfo, "%s", show<half>(inputData, M, K, "Input").c_str());
  LOG(kDefLog, kInfo, "%s", show<half>(weightData, N, K, "Weights").c_str());
}

template<class precision=float>
void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr<precision[]> &inputPtr,
std::unique_ptr<precision[]> &weightsPtr,
std::unique_ptr<precision[]> &outputPtr) {
LOG(kDefLog, kInfo, "Computing CPU reference implementation");
std::unique_ptr<float[]> outputRefPtr = std::make_unique<float[]>(M * N);
ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
nullptr, 1, M, K, N);
std::unique_ptr<precision[]> outputRefPtr = std::make_unique<precision[]>(M * N);
if constexpr (std::is_same<precision, float>::value) {
ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
nullptr, 1, M, K, N);
} else if constexpr (std::is_same<precision, half>::value) {
matmulf16_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(),
nullptr, 1, M, K, N);
}
LOG(kDefLog, kInfo, "Reference Output: %s",
show<float>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
show<precision>(outputRefPtr.get(), M, N, "Output (Reference)").c_str());
LOG(kDefLog, kInfo,
isclose(outputPtr.get(), outputRefPtr.get(), M * N) ? "CPU Check: PASS"
: "CPU Check: FAIL");
}

Kernel selectMatmul(Context &ctx, int version,
const Bindings</* input, weights, output */ 3> &bindings,
size_t M, size_t K, size_t N) {
size_t M, size_t K, size_t N, NumType numtype) {
Kernel kernel;
if (version == 1) {
Shape wgSize = {256, 1, 1};
Expand All @@ -647,13 +689,13 @@ Kernel selectMatmul(Context &ctx, int version,
Shape wgSize = {16, 16, 1};
LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str());
KernelCode matmul =
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize);
createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ cdiv({M, N, 1}, wgSize));
} else if (version == 3) {
static constexpr size_t tileSize = 16;
KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N,
/*wgSize*/ {tileSize * tileSize, 1, 1});
/*wgSize*/ {tileSize * tileSize, 1, 1}, numtype);
kernel =
createKernel(ctx, matmul, bindings,
/* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1}));
Expand All @@ -672,7 +714,7 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM,
/*wgSize*/ wgSize,
kf32,
numtype,
/*Loop unrolling*/ version == 6 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
Expand All @@ -690,11 +732,11 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN,
/*wgSize*/ wgSize,
kf32,
numtype,
/*Loop unrolling*/ version == 7 ? true: false);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
} else if (version == 8) {
} else if (version == 8 || version == 10) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
static constexpr size_t BN = 64;
Expand All @@ -708,11 +750,11 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmulWithVectorization(kShaderMatmulWithVectorization, M, K, N, BM, BK, BN, TM, TN,
/*wgSize*/ wgSize,
kf32,
numtype,
/*Loop unrolling*/ true);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
} else if (version == 9) {
} else if (version == 9 || version == 11) {
static constexpr size_t BM = 64;
static constexpr size_t BK = 8;
static constexpr size_t BN = 64;
Expand All @@ -726,23 +768,36 @@ Kernel selectMatmul(Context &ctx, int version,
LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str());
KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN,
/*wgSize*/ wgSize,
kf32);
numtype);
kernel = createKernel(ctx, matmul, bindings,
/*nWorkgroups*/ nWorkgroups);
}
return kernel;
}

template<class precision=float>
void runTest(int version, size_t M, size_t K, size_t N,
std::unique_ptr<float[]> &inputPtr,
std::unique_ptr<float[]> &weightsPtr,
std::unique_ptr<float[]> &outputPtr) {
std::unique_ptr<precision[]> &inputPtr,
std::unique_ptr<precision[]> &weightsPtr,
std::unique_ptr<precision[]> &outputPtr,
NumType numtype) {
if constexpr (std::is_same<precision, float>::value) {
assert(numtype == kf32);
} else if constexpr (std::is_same<precision, half>::value) {
assert(numtype == kf16);
}

// Allocate GPU buffers and copy data
Context ctx = createContext();
Tensor input = createTensor(ctx, Shape{M, K}, kf32, inputPtr.get());
Tensor weights =
createTensor(ctx, Shape{N, K}, kf32, weightsPtr.get()); // column-major
Context ctx = createContext(
{}, {},
/*device descriptor, enabling f16 in WGSL*/
{
.requiredFeatureCount = 1,
.requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data(),
});

Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get());
Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major

constexpr size_t nIter = 30;

Expand All @@ -756,8 +811,8 @@ void runTest(int version, size_t M, size_t K, size_t N,
std::array<Tensor, nIter> outputs;
for (int i = 0; i < nIter; i++) {
futures[i] = promises[i].get_future();
outputs[i] = createTensor(ctx, Shape{M, N}, kf32);
kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N);
outputs[i] = createTensor(ctx, Shape{M, N}, numtype);
kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N, numtype);
}

printf("[ Press enter to start tests ... ]\n");
Expand Down Expand Up @@ -785,9 +840,9 @@ void runTest(int version, size_t M, size_t K, size_t N,
1000000000.0 * static_cast<float>(nIter);

LOG(kDefLog, kInfo, "Copying result to CPU");
toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(float));
toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(precision));
LOG(kDefLog, kInfo, "%s",
show<float>(outputPtr.get(), M, N, "Output[0]").c_str());
show<precision>(outputPtr.get(), M, N, "Output[0]").c_str());

LOG(kDefLog, kInfo, "\n\n===================================================================="
"============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations "
Expand All @@ -798,33 +853,62 @@ void runTest(int version, size_t M, size_t K, size_t N,
M, K, N, nIter, duration.count() / static_cast<double>(nIter) / 1000.0 /* us -> ms */, gflops);
}

/**
 * Allocates host buffers of element type `precision`, initialises input and
 * weights, runs the selected kernel version through runTest, and for small
 * test sizes compares the result with the CPU reference via checkCPU.
 *
 * @tparam precision      Host element type of the buffers.
 * @param version         Kernel version forwarded to runTest.
 * @param M, K, N         Matrix dimensions: input holds M * K elements,
 *                        weights N * K, output M * N.
 * @param transposedInput When true, runTest receives a transposed copy of the
 *                        weights instead of the original buffer.
 * @param kTestSize       The CPU check runs only when this is <= 1.
 * @param numtype         Numeric type tag forwarded to runTest.
 */
template<class precision=float>
void runTestWithCheck(int version, size_t M, size_t K, size_t N,
                      bool transposedInput, int kTestSize, NumType numtype) {
  auto inputPtr = std::make_unique<precision[]>(M * K);
  auto weightsPtr = std::make_unique<precision[]>(N * K);
  auto outputPtr = std::make_unique<precision[]>(M * N);

  initData(M, K, N, inputPtr, weightsPtr);

  if (!transposedInput) {
    runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr, numtype);
  } else {
    // The transposed copy lives only for the duration of the run; the CPU
    // check below still uses the original weights.
    auto transposedWeightPtr = std::make_unique<precision[]>(K * N);
    transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
    runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr, numtype);
  }

  // Only tiny/small tests are checked against the CPU reference implementation.
  if (kTestSize > 1) {
    return;
  }
  checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
}

const std::string versionToStr(int version){
switch (version) {
case 1: return "No-Op";
case 2: return "naive matmul";
case 3: return "tiling";
case 4: return "1D blocktiling";
case 5: return "2D blocktiling";
case 6: return "1D blocktiling with loop unrolling";
case 7: return "2D blocktiling with loop unrolling";
case 8: return "2D blocktiling with loop unrolling and vectorization";
case 9: return "2D blocktiling with loop unrolling, vectorization and transpose";
case 1: return "f32: No-Op";
case 2: return "f32: naive matmul";
case 3: return "f32: tiling";
case 4: return "f32: 1D blocktiling";
case 5: return "f32: 2D blocktiling";
case 6: return "f32: 1D blocktiling with loop unrolling";
case 7: return "f32: 2D blocktiling with loop unrolling";
case 8: return "f32: 2D blocktiling with loop unrolling and vectorization";
case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose";
case 10: return "f16: 2D blocktiling with loop unrolling and vectorization";
case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose";
default: return "Not specified";
}
}

int main() {
char* version_str = getenv("MATMUL_VERSION");
int version = version_str == NULL ? 9 : atoi(version_str);
// 1 == No-Op
// 2 == naive matmul
// 3 == tiling
// 4 == 1D blocktiling
// 5 == 2D blocktiling
// 6 == 1D blocktiling with loop unrolling
// 7 == 2D blocktiling with loop unrolling
// 8 == 2D blocktiling with loop unrolling and vectorization
// 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default)
int version = version_str == NULL ? 10 : atoi(version_str);
// 1 == f32: No-Op
// 2 == f32: naive matmul
// 3 == f32: tiling
// 4 == f32: 1D blocktiling
// 5 == f32: 2D blocktiling
// 6 == f32: 1D blocktiling with loop unrolling
// 7 == f32: 2D blocktiling with loop unrolling
// 8 == f32: 2D blocktiling with loop unrolling and vectorization
// 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose
// 10 == f16: 2D blocktiling with loop unrolling and vectorization (default)
// 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose
bool enableF16 = version == 10 || version ==11;
bool transposedInput = version == 9 || version == 11;
NumType numtype = enableF16 ? kf16 : kf32;

size_t M, K, N; // Matrix dimensions
char* kTestSize_str = getenv("MATMUL_SIZE");
Expand All @@ -846,24 +930,10 @@ int main() {
N = 2 * 4096;
}

std::unique_ptr<float[]> inputPtr = std::make_unique<float[]>(M * K);
std::unique_ptr<float[]> weightsPtr = std::make_unique<float[]>(N * K);
std::unique_ptr<float[]> outputPtr = std::make_unique<float[]>(M * N);
bool transposedInput = version == 9;

initData(M, K, N, inputPtr, weightsPtr);
if (transposedInput) {
std::unique_ptr<float[]> transposedWeightPtr = std::make_unique<float[]>(K * N);
transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K);
runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr);
if (enableF16) {
runTestWithCheck<half>(version, M, K, N, transposedInput, kTestSize, numtype);
} else {
runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr);
}


if (kTestSize <= 1) {
// Check result with CPU reference implementation for tiny/small tests
checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr);
runTestWithCheck<float>(version, M, K, N, transposedInput, kTestSize, numtype);
}

LOG(kDefLog, kInfo, "Done.");
Expand Down
3 changes: 3 additions & 0 deletions gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,9 @@ struct KernelCode {
const Shape &workgroupSize = {256, 1, 1},
NumType precision = kf32)
: data(pData), workgroupSize(workgroupSize), precision(precision) {
if (precision == kf16) {
data = "enable f16;\n" + data;
}
replaceAll(data, "{{workgroupSize}}", toString(workgroupSize));
replaceAll(data, "{{precision}}", toString(precision));
LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str());
Expand Down
Loading

0 comments on commit 228084f

Please sign in to comment.