diff --git a/examples/matmul/run.cpp b/examples/matmul/run.cpp index 59cb6ba..06e9297 100644 --- a/examples/matmul/run.cpp +++ b/examples/matmul/run.cpp @@ -11,11 +11,35 @@ #include "utils/array_utils.h" // show, isclose, randn, randint #include "utils/logging.h" // LOG #include "experimental/wgsl.h" // loopUnrolling +#include "numeric_types/half.h" using namespace gpu; const std::string versionToStr(int version); +void matmulf16_forward_cpu(half* out, + const half* inp, const half* weight, const half* bias, + int B, int T, int C, int OC) { + // OC is short for "output channels" + // inp is (B,T,C), weight is (OC, C) + // out will be (B,T,OC) +#pragma omp parallel for collapse(2) + for (int b = 0; b < B; b++) { + for (int t = 0; t < T; t++) { + half* out_bt = out + b * T * OC + t * OC; + const half* inp_bt = inp + b * T * C + t * C; + for (int o = 0; o < OC; o++) { + float val = (bias != NULL) ? halfToFloat(bias[o]) : 0.0f; + const half* wrow = weight + o*C; + for (int i = 0; i < C; i++) { + val += halfToFloat(inp_bt[i]) * halfToFloat(wrow[i]); + } + out_bt[o] = val; + } + } + } +} + static const char *kShaderMatmul1 = R"( @group(0) @binding(0) var A: array<{{precision}}>; @group(0) @binding(1) var B: array<{{precision}}>; @@ -47,7 +71,7 @@ inline KernelCode createMatmul1(const char *shaderTemplate, const size_t M, {"{{M}}", toString(M)}, {"{{K}}", toString(K)}, {"{{N}}", toString(N)}}); - return {codeString, workgroupSize}; + return {codeString, workgroupSize, precision}; } // Shared memory cache-blocking @@ -108,7 +132,7 @@ inline KernelCode createMatmul2(const char *shaderTemplate, const size_t M, {"{{N}}", toString(N)}, {"{{tileSize}}", toString(static_cast(sqrt(workgroupSize[0])))}}); - return {codeString, workgroupSize}; + return {codeString, workgroupSize, precision}; } /* 1D block-tiling @@ -224,9 +248,9 @@ inline KernelCode createMatmul3(const char *shaderTemplate, const size_t M, if (unrolling) { std::string unrolledCode = loopUnrolling(codeString); // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); - return {unrolledCode, workgroupSize}; + return {unrolledCode, workgroupSize, precision}; } else { - return {codeString, workgroupSize}; + return {codeString, workgroupSize, precision}; } } @@ -340,9 +364,9 @@ inline KernelCode createMatmul4(const char *shaderTemplate, const size_t M, if (unrolling) { std::string unrolledCode = loopUnrolling(codeString); // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); - return {unrolledCode, workgroupSize}; + return {unrolledCode, workgroupSize, precision}; } else { - return {codeString, workgroupSize}; + return {codeString, workgroupSize, precision}; } } @@ -462,9 +486,9 @@ inline KernelCode createMatmulWithVectorization(const char *shaderTemplate, cons if (unrolling) { std::string unrolledCode = loopUnrolling(codeString); // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); - return {unrolledCode, workgroupSize}; + return {unrolledCode, workgroupSize, precision}; } else { - return {codeString, workgroupSize}; + return {codeString, workgroupSize, precision}; } } @@ -582,7 +606,7 @@ inline KernelCode createMatmulWithTranspose(const char *shaderTemplate, const si }); std::string unrolledCode = loopUnrolling(codeString); // LOG(kDefLog, kInfo, "Unrolled code:\n%s", unrolledCode.c_str()); - return {unrolledCode, workgroupSize}; + return {unrolledCode, workgroupSize, precision}; } /** @@ -604,7 +628,7 @@ inline KernelCode createNoOp(const char *shaderTemplate, std::string codeString(shaderTemplate); replaceAll(codeString, {{"{{workgroupSize}}", toString(workgroupSize)}, {"{{precision}}", toString(precision)}}); - return {codeString, workgroupSize}; + return {codeString, workgroupSize, precision}; } void initData(size_t M, size_t K, size_t N, std::unique_ptr &inputPtr, @@ -619,15 +643,33 @@ void initData(size_t M, size_t K, size_t N, std::unique_ptr &inputPtr, show(weightsPtr.get(), N, K, "Weights").c_str()); } -void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr &inputPtr, - std::unique_ptr &weightsPtr, - std::unique_ptr &outputPtr) { +void initData(size_t M, size_t K, size_t N, std::unique_ptr &inputPtr, + std::unique_ptr &weightsPtr) { + std::mt19937 gen(314159); + randn(inputPtr.get(), M * K, gen); + randn(weightsPtr.get(), N * K, gen); + // randint(inputPtr.get(), M * K, gen, 1, 2); + // randint(weightsPtr.get(), N * K, gen, 1, 2); + LOG(kDefLog, kInfo, "%s", show(inputPtr.get(), M, K, "Input").c_str()); + LOG(kDefLog, kInfo, "%s", + show(weightsPtr.get(), N, K, "Weights").c_str()); +} + +template +void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr &inputPtr, + std::unique_ptr &weightsPtr, + std::unique_ptr &outputPtr) { LOG(kDefLog, kInfo, "Computing CPU reference implementation"); - std::unique_ptr outputRefPtr = std::make_unique(M * N); - ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(), - nullptr, 1, M, K, N); + std::unique_ptr outputRefPtr = std::make_unique(M * N); + if constexpr (std::is_same::value) { + ref::matmul_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(), + nullptr, 1, M, K, N); + } else if constexpr (std::is_same::value) { + matmulf16_forward_cpu(outputRefPtr.get(), inputPtr.get(), weightsPtr.get(), + nullptr, 1, M, K, N); + } LOG(kDefLog, kInfo, "Reference Output: %s", - show(outputRefPtr.get(), M, N, "Output (Reference)").c_str()); + show(outputRefPtr.get(), M, N, "Output (Reference)").c_str()); LOG(kDefLog, kInfo, isclose(outputPtr.get(), outputRefPtr.get(), M * N) ? "CPU Check: PASS" : "CPU Check: FAIL"); @@ -635,7 +677,7 @@ void checkCPU(size_t M, size_t K, size_t N, std::unique_ptr &inputPtr, Kernel selectMatmul(Context &ctx, int version, const Bindings &bindings, - size_t M, size_t K, size_t N) { + size_t M, size_t K, size_t N, NumType numtype) { Kernel kernel; if (version == 1) { Shape wgSize = {256, 1, 1}; @@ -647,13 +689,13 @@ Kernel selectMatmul(Context &ctx, int version, Shape wgSize = {16, 16, 1}; LOG(kDefLog, kInfo, "wgSize: %s", toString(wgSize).c_str()); KernelCode matmul = - createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize); + createMatmul1(kShaderMatmul1, M, K, N, /*wgsize*/ wgSize, numtype); kernel = createKernel(ctx, matmul, bindings, /*nWorkgroups*/ cdiv({M, N, 1}, wgSize)); } else if (version == 3) { static constexpr size_t tileSize = 16; KernelCode matmul = createMatmul2(kShaderMatmul2, M, K, N, - /*wgSize*/ {tileSize * tileSize, 1, 1}); + /*wgSize*/ {tileSize * tileSize, 1, 1}, numtype); kernel = createKernel(ctx, matmul, bindings, /* nWorkgroups*/ cdiv({M, N, 1}, {tileSize, tileSize, 1})); @@ -672,7 +714,7 @@ Kernel selectMatmul(Context &ctx, int version, LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str()); KernelCode matmul = createMatmul3(kShaderMatmul3, M, K, N, BM, BK, BN, TM, /*wgSize*/ wgSize, - kf32, + numtype, /*Loop unrolling*/ version == 6 ? true: false); kernel = createKernel(ctx, matmul, bindings, /*nWorkgroups*/ nWorkgroups); @@ -690,11 +732,11 @@ Kernel selectMatmul(Context &ctx, int version, LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str()); KernelCode matmul = createMatmul4(kShaderMatmul4, M, K, N, BM, BK, BN, TM, TN, /*wgSize*/ wgSize, - kf32, + numtype, /*Loop unrolling*/ version == 7 ? true: false); kernel = createKernel(ctx, matmul, bindings, /*nWorkgroups*/ nWorkgroups); - } else if (version == 8) { + } else if (version == 8 || version == 10) { static constexpr size_t BM = 64; static constexpr size_t BK = 8; static constexpr size_t BN = 64; @@ -708,11 +750,11 @@ Kernel selectMatmul(Context &ctx, int version, LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str()); KernelCode matmul = createMatmulWithVectorization(kShaderMatmulWithVectorization, M, K, N, BM, BK, BN, TM, TN, /*wgSize*/ wgSize, - kf32, + numtype, /*Loop unrolling*/ true); kernel = createKernel(ctx, matmul, bindings, /*nWorkgroups*/ nWorkgroups); - } else if (version == 9) { + } else if (version == 9 || version == 11) { static constexpr size_t BM = 64; static constexpr size_t BK = 8; static constexpr size_t BN = 64; @@ -726,23 +768,36 @@ Kernel selectMatmul(Context &ctx, int version, LOG(kDefLog, kInfo, "nWorkgroups: ( %s )", toString(nWorkgroups).c_str()); KernelCode matmul = createMatmulWithTranspose(kShaderMatmulWithTranspose, M, K, N, BM, BK, BN, TM, TN, /*wgSize*/ wgSize, - kf32); + numtype); kernel = createKernel(ctx, matmul, bindings, /*nWorkgroups*/ nWorkgroups); } return kernel; } +template void runTest(int version, size_t M, size_t K, size_t N, - std::unique_ptr &inputPtr, - std::unique_ptr &weightsPtr, - std::unique_ptr &outputPtr) { + std::unique_ptr &inputPtr, + std::unique_ptr &weightsPtr, + std::unique_ptr &outputPtr, + NumType numtype) { + if constexpr (std::is_same::value) { + assert(numtype == kf32); + } else if constexpr (std::is_same::value) { + assert(numtype == kf16); + } // Allocate GPU buffers and copy data - Context ctx = createContext(); - Tensor input = createTensor(ctx, Shape{M, K}, kf32, inputPtr.get()); - Tensor weights = - createTensor(ctx, Shape{N, K}, kf32, weightsPtr.get()); // column-major + Context ctx = createContext( + {}, {}, + /*device descriptor, enabling f16 in WGSL*/ + { + .requiredFeatureCount = 1, + .requiredFeatures = std::array{WGPUFeatureName_ShaderF16}.data(), + }); + + Tensor input = createTensor(ctx, Shape{M, K}, numtype, inputPtr.get()); + Tensor weights = createTensor(ctx, Shape{N, K}, numtype, weightsPtr.get()); // column-major constexpr size_t nIter = 30; @@ -756,8 +811,8 @@ void runTest(int version, size_t M, size_t K, size_t N, std::array outputs; for (int i = 0; i < nIter; i++) { futures[i] = promises[i].get_future(); - outputs[i] = createTensor(ctx, Shape{M, N}, kf32); - kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N); + outputs[i] = createTensor(ctx, Shape{M, N}, numtype); + kernels[i] = selectMatmul(ctx, version, {input, weights, outputs[i]}, M, K, N, numtype); } printf("[ Press enter to start tests ... ]\n"); @@ -785,9 +840,9 @@ void runTest(int version, size_t M, size_t K, size_t N, 1000000000.0 * static_cast(nIter); LOG(kDefLog, kInfo, "Copying result to CPU"); - toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(float)); + toCPU(ctx, outputs[0], outputPtr.get(), M * N * sizeof(precision)); LOG(kDefLog, kInfo, "%s", - show(outputPtr.get(), M, N, "Output[0]").c_str()); + show(outputPtr.get(), M, N, "Output[0]").c_str()); LOG(kDefLog, kInfo, "\n\n====================================================================" "============\nExecution Time: (M = %d, K = %d, N = %d) x %d iterations " @@ -798,33 +853,62 @@ void runTest(int version, size_t M, size_t K, size_t N, M, K, N, nIter, duration.count() / static_cast(nIter) / 1000.0 /* us -> ms */, gflops); } +template +void runTestWithCheck(int version, size_t M, size_t K, size_t N, + bool transposedInput, int kTestSize, NumType numtype) { + std::unique_ptr inputPtr = std::make_unique(M * K); + std::unique_ptr weightsPtr = std::make_unique(N * K); + std::unique_ptr outputPtr = std::make_unique(M * N); + + initData(M, K, N, inputPtr, weightsPtr); + if (transposedInput) { + std::unique_ptr transposedWeightPtr = std::make_unique(K * N); + transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K); + runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr, numtype); + } else { + runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr, numtype); + } + + if (kTestSize <= 1) { + // Check result with CPU reference implementation for tiny/small tests + checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr); + } +} + const std::string versionToStr(int version){ switch (version) { - case 1: return "No-Op"; - case 2: return "naive matmul"; - case 3: return "tiling"; - case 4: return "1D blocktiling"; - case 5: return "2D blocktiling"; - case 6: return "1D blocktiling with loop unrolling"; - case 7: return "2D blocktiling with loop unrolling"; - case 8: return "2D blocktiling with loop unrolling and vectorization"; - case 9: return "2D blocktiling with loop unrolling, vectorization and transpose"; + case 1: return "f32: No-Op"; + case 2: return "f32: naive matmul"; + case 3: return "f32: tiling"; + case 4: return "f32: 1D blocktiling"; + case 5: return "f32: 2D blocktiling"; + case 6: return "f32: 1D blocktiling with loop unrolling"; + case 7: return "f32: 2D blocktiling with loop unrolling"; + case 8: return "f32: 2D blocktiling with loop unrolling and vectorization"; + case 9: return "f32: 2D blocktiling with loop unrolling, vectorization and transpose"; + case 10: return "f16: 2D blocktiling with loop unrolling and vectorization"; + case 11: return "f16: 2D blocktiling with loop unrolling, vectorization and transpose"; default: return "Not specified"; } } int main() { char* version_str = getenv("MATMUL_VERSION"); - int version = version_str == NULL ? 9 : atoi(version_str); - // 1 == No-Op - // 2 == naive matmul - // 3 == tiling - // 4 == 1D blocktiling - // 5 == 2D blocktiling - // 6 == 1D blocktiling with loop unrolling - // 7 == 2D blocktiling with loop unrolling - // 8 == 2D blocktiling with loop unrolling and vectorization - // 9 == 2D blocktiling with loop unrolling, vectorization and transpose (default) + int version = version_str == NULL ? 10 : atoi(version_str); + // 1 == f32: No-Op + // 2 == f32: naive matmul + // 3 == f32: tiling + // 4 == f32: 1D blocktiling + // 5 == f32: 2D blocktiling + // 6 == f32: 1D blocktiling with loop unrolling + // 7 == f32: 2D blocktiling with loop unrolling + // 8 == f32: 2D blocktiling with loop unrolling and vectorization + // 9 == f32: 2D blocktiling with loop unrolling, vectorization and transpose + // 10 == f16: 2D blocktiling with loop unrolling and vectorization (default) + // 11 == f16: 2D blocktiling with loop unrolling, vectorization and transpose + bool enableF16 = version == 10 || version ==11; + bool transposedInput = version == 9 || version == 11; + NumType numtype = enableF16 ? kf16 : kf32; size_t M, K, N; // Matrix dimensions char* kTestSize_str = getenv("MATMUL_SIZE"); @@ -846,24 +930,10 @@ int main() { N = 2 * 4096; } - std::unique_ptr inputPtr = std::make_unique(M * K); - std::unique_ptr weightsPtr = std::make_unique(N * K); - std::unique_ptr outputPtr = std::make_unique(M * N); - bool transposedInput = version == 9; - - initData(M, K, N, inputPtr, weightsPtr); - if (transposedInput) { - std::unique_ptr transposedWeightPtr = std::make_unique(K * N); - transpose(weightsPtr.get(), transposedWeightPtr.get(), N, K); - runTest(version, M, K, N, inputPtr, transposedWeightPtr, outputPtr); + if (enableF16) { + runTestWithCheck(version, M, K, N, transposedInput, kTestSize, numtype); } else { - runTest(version, M, K, N, inputPtr, weightsPtr, outputPtr); - } - - - if (kTestSize <= 1) { - // Check result with CPU reference implementation for tiny/small tests - checkCPU(M, K, N, inputPtr, weightsPtr, outputPtr); + runTestWithCheck(version, M, K, N, transposedInput, kTestSize, numtype); } LOG(kDefLog, kInfo, "Done."); diff --git a/gpu.h b/gpu.h index 3ac4499..0636b0c 100644 --- a/gpu.h +++ b/gpu.h @@ -318,6 +318,9 @@ struct KernelCode { const Shape &workgroupSize = {256, 1, 1}, NumType precision = kf32) : data(pData), workgroupSize(workgroupSize), precision(precision) { + if (precision == kf16) { + data = "enable f16;\n" + data; + } replaceAll(data, "{{workgroupSize}}", toString(workgroupSize)); replaceAll(data, "{{precision}}", toString(precision)); LOG(kDefLog, kInfo, "Shader code:\n%s", data.c_str()); diff --git a/utils/array_utils.h b/utils/array_utils.h index 3b9faad..c199a8a 100644 --- a/utils/array_utils.h +++ b/utils/array_utils.h @@ -20,6 +20,7 @@ #include #include "utils/logging.h" +#include "numeric_types/half.h" namespace gpu { @@ -52,12 +53,14 @@ std::string show(const numtype *a, size_t rows, size_t cols, } // spacing as log10 of max value int spacing = 1; - numtype max = *std::max_element(a, a + rows * cols); if constexpr (std::is_same::value) { + int max = *std::max_element(a, a + rows * cols); spacing = std::max(0, (int)log10(max + .01)) + 2; } else if constexpr (std::is_same::value) { // spacing = std::max(0, (int)log10(max + .01)) + 1; spacing = 8; // scientific notation + } else if constexpr (std::is_same::value) { + spacing = 8; } else { throw std::runtime_error("Unsupported number type for show()"); } @@ -82,6 +85,14 @@ std::string show(const numtype *a, size_t rows, size_t cols, snprintf(buffer, 16, "%9.2f", a[i * cols + j]); } else snprintf(buffer, 16, "%10.2e", a[i * cols + j]); + } else if constexpr (std::is_same::value) { + float tmp = halfToFloat(a[i * cols + j]); + if (std::abs(tmp) < 1000 && + std::abs(tmp) > 0.01 || + tmp == 0.0) { + snprintf(buffer, 16, "%9.2f", tmp); + } else + snprintf(buffer, 16, "%10.2e", tmp); } else { throw std::runtime_error("Unsupported number type for show()"); } @@ -199,14 +210,22 @@ void randint(std::array &a, std::mt19937 &gen, int min = -1, * @param mean The mean of the Gaussian distribution. * @param std The standard deviation of the Gaussian distribution. */ -void randn(float *a, size_t N, std::mt19937 &gen, float mean = 0.0, - float std = 1.0) { +inline void randn(float *a, size_t N, std::mt19937 &gen, float mean = 0.0, + float std = 1.0) { std::normal_distribution dist(mean, std); for (int i = 0; i < N; i++) { a[i] = static_cast(dist(gen)); } } +inline void randn(half *a, size_t N, std::mt19937 &gen, float mean = 0.0, + float std = 1.0) { + std::normal_distribution dist(mean, std); + for (int i = 0; i < N; i++) { + a[i] = halfFromFloat(dist(gen)); + } +} + /** * @brief Overload of `randn()` for std::array. * @param a The array to populate. @@ -254,6 +273,14 @@ inline void transpose(float *input, float *output, size_t M, size_t N) { } } +inline void transpose(half *input, half *output, size_t M, size_t N) { + for (size_t i = 0; i < M; i++) { + for (size_t j = 0; j < N; j++) { + output[j * M + i] = input[i * N + j]; + } + } +} + /** * @brief Flip a matrix horizontally or vertically. * @param a The matrix to flip. @@ -285,7 +312,7 @@ inline void flip(float *a, size_t R, size_t C, bool horizontal = true) { * @param tol The tolerance for closeness. * @return bool True if the arrays are close, false otherwise. */ -bool isclose(float *a, float *b, size_t n, float tol = 1e-3) { +inline bool isclose(float *a, float *b, size_t n, float tol = 1e-3) { for (size_t i = 0; i < n; i++) { if (std::abs(a[i] - b[i]) > tol || std::isnan(a[i]) || std::isnan(b[i])) { LOG(kDefLog, kError, "Mismatch at index %d: %f != %f", i, a[i], b[i]); @@ -295,6 +322,18 @@ bool isclose(float *a, float *b, size_t n, float tol = 1e-3) { return true; } +inline bool isclose(half *a, half *b, size_t n, float tol = 1) { + for (size_t i = 0; i < n; i++) { + float ai = halfToFloat(a[i]); + float bi = halfToFloat(b[i]); + if (std::abs(ai - bi) > tol || std::isnan(ai) || std::isnan(bi)) { + LOG(kDefLog, kError, "Mismatch at index %d: %f != %f", i, ai, bi); + return false; + } + } + return true; +} + } // namespace gpu #endif