From 02f519b842fb65886ccf13280f5b535babadac90 Mon Sep 17 00:00:00 2001 From: "Zhong, Zhicong" Date: Sun, 30 Jun 2024 18:27:21 -0700 Subject: [PATCH] tune config --- lib/gc/Analysis/MatmulConfigAnalysis.cpp | 286 ++++++++++++++--------- lib/gc/CAPI/CMakeLists.txt | 1 + 2 files changed, 172 insertions(+), 115 deletions(-) diff --git a/lib/gc/Analysis/MatmulConfigAnalysis.cpp b/lib/gc/Analysis/MatmulConfigAnalysis.cpp index de2067566..d0147ee2e 100644 --- a/lib/gc/Analysis/MatmulConfigAnalysis.cpp +++ b/lib/gc/Analysis/MatmulConfigAnalysis.cpp @@ -6,6 +6,7 @@ // //===----------------------------------------------------------------------===// +#include #include #include "gc/Analysis/MatmulConfigAnalysis.h" @@ -15,8 +16,6 @@ namespace gc { #define DEBUG_TYPE "matmul-config-analysis" -#define MAX_THREADS (1024U * 1024U) - llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, const MatmulConfig &config) { @@ -29,19 +28,36 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, return ss; } -std::vector getCandidate(uint32_t num, uint32_t floor, - uint32_t ceil) { +template +llvm::raw_ostream &operator<<(llvm::raw_ostream &ss, std::vector arry) { + ss << "["; + for (auto [idx, a] : llvm::enumerate(arry)) { + if (idx != 0) { + ss << ", "; + } + ss << a; + } + ss << "]"; + return ss; +} + +std::vector +getCandidate(uint32_t num, uint32_t floor, + uint32_t ceil = std::numeric_limits::max()) { + // factor std::vector candidates; for (uint32_t i = 1; i <= num; i++) { if (num % i == 0 && i <= ceil && i >= floor) { candidates.push_back(i); } } + // the pow of 2 auto candidate = 1U; while (candidate < num && candidate <= ceil && candidate >= floor) { candidates.push_back(candidate); candidate *= 2; } + std::sort(candidates.begin(), candidates.end()); auto last = std::unique(candidates.begin(), candidates.end()); candidates.erase(last, candidates.end()); return candidates; @@ -53,15 +69,6 @@ bool isValidConfig(const MatmulConfig &config, SystemDesc &sysDesc, config.innerMostKBlock == 0) { return false; } - if (config.MBlock % config.innerMostMBlock != 0 || - config.NBlock % config.innerMostNBlock != 0 || - config.KBlock % config.innerMostKBlock != 0) { - return false; - } - auto threads = sysDesc.getNumThreads(); - if (config.MThreads * config.NThreads * config.KThreads != threads) { - return false; - } if (shape[0] % config.innerMostMBlock != 0 || shape[1] % config.innerMostNBlock != 0 || @@ -72,14 +79,13 @@ bool isValidConfig(const MatmulConfig &config, SystemDesc &sysDesc, return true; } -double threadUtilizationCost(linalg::LinalgOp &linalgOp, - ArrayRef shape, - const MatmulConfig &config, SystemDesc &sysDesc) { - auto threads = sysDesc.getNumThreads(); - auto actualThreads = - (float)(config.MThreads * config.NThreads * config.KThreads); - return threads >= actualThreads ? threads / actualThreads - : actualThreads / threads; +bool validateThreads(ArrayRef threads, SystemDesc &sysDesc) { + auto numThreads = sysDesc.getNumThreads(); + auto actualThreads = 1U; + for (auto t : threads) { + actualThreads *= t; + } + return actualThreads == numThreads; } double hardwareEfficiencyCost(linalg::LinalgOp &linalgOp, @@ -103,9 +109,21 @@ double hardwareEfficiencyCost(linalg::LinalgOp &linalgOp, double workloadBalancedCost(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, SystemDesc &sysDesc) { - return 1; + auto M = shape[0], N = shape[1], K = shape[2]; + auto MTaskNum = llvm::divideCeil(M, config.MBlock); + auto NTaskNum = llvm::divideCeil(N, config.NBlock); + auto KTaskNum = llvm::divideCeil(K, config.KBlock); + auto cost = (MTaskNum % config.MThreads) * 1.0 / MTaskNum + + (NTaskNum % config.NThreads) * 1.0 / NTaskNum + + (KTaskNum % config.KThreads) * 1.0 / KTaskNum; + if (MTaskNum < config.MThreads || NTaskNum < config.NThreads || + KTaskNum < config.KThreads) { + auto threadNotFulllyUtilizedPenalty = 10.0; + cost *= threadNotFulllyUtilizedPenalty; + } + return cost; } - +constexpr unsigned bitPerByte = 8; double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, @@ -113,30 +131,34 @@ double memoryConsumptionOnThreadCost(linalg::LinalgOp &linalgOp, auto M = shape[0], N = shape[1], K = shape[2]; auto dtypeSize = DataLayout().getTypeSizeInBits( ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); - auto penalty = 2.0 * (dtypeSize / 8); + // if use K split, there will be one more final reduce and break the post + // fusion + + auto KSplitPenalty = 8.0 * (dtypeSize / bitPerByte); auto memoryConsumptionPerThread = M * K * 1.0 / config.MThreads / config.KThreads + K * N * 1.0 / config.KThreads / config.NThreads + - M * N * ((config.KThreads - 1) * penalty + 1.0) / config.MThreads / + M * N * ((config.KThreads - 1) * KSplitPenalty + 1.0) / config.MThreads / config.NThreads; return memoryConsumptionPerThread; } -double computationIntensityOnL1Cache(linalg::LinalgOp &linalgOp, +double computationIntensityOnL2Cache(linalg::LinalgOp &linalgOp, ArrayRef shape, const MatmulConfig &config, SystemDesc &sysDesc) { - auto L1Cache = sysDesc.getCacheSize(2); + double simulationPenalty = 0.7; + auto L2Cache = sysDesc.getCacheSize(2); auto dtypeSize = DataLayout().getTypeSizeInBits( ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType()); auto outOfCachePenalty = 1024; - double FLOPS = - 2.0 * config.innerMostMBlock * config.innerMostNBlock * config.KBlock; - double memoryConsumption = config.innerMostMBlock * config.innerMostNBlock + - config.innerMostNBlock * config.KBlock + - config.innerMostMBlock * config.KBlock; + double FLOPS = 2.0 * config.MBlock * config.NBlock * config.KBlock; + double memoryConsumption = config.MBlock * config.NBlock + + config.NBlock * config.KBlock + + config.MBlock * config.KBlock; double computationIntensity = FLOPS / memoryConsumption; - if (memoryConsumption * (dtypeSize / 8) > L1Cache) { + if (memoryConsumption * (dtypeSize / bitPerByte) > + L2Cache * simulationPenalty) { computationIntensity /= outOfCachePenalty; } return 1 / computationIntensity; @@ -149,7 +171,7 @@ using CostModelFn = std::vector filterConfigByCostModel(std::vector configs, linalg::LinalgOp &linalgOp, ArrayRef shape, - SystemDesc &sysDesc, const CostModelFn &costModel, + SystemDesc &sysDesc, CostModelFn costModel, float eliminationRatio = 0.5, float threshold = -1) { std::vector result; std::vector costs; @@ -169,13 +191,13 @@ filterConfigByCostModel(std::vector configs, result.push_back(configs[idx[i]]); } } - llvm::outs() << "thresholdCost is: " << thresholdCost + llvm::errs() << "thresholdCost is: " << thresholdCost << "\nbest with cost: " << costs[idx[0]] << "\n" << configs[idx[0]] << "\n worst with cost: " << costs[idx[configs.size() - 1]] << "\n" << configs[idx[configs.size() - 1]] << "\n"; - return !result.empty() ? result : configs; + return result.size() > 0 ? result : configs; } std::vector @@ -184,19 +206,25 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, ArrayRef givenInnermostBlock) { std::vector configs; auto threads = sysDesc.getNumThreads(); - auto MThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); - auto NThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); - auto KThreadsCandidates = getCandidate((uint32_t)threads, 1U, MAX_THREADS); - auto MBlockCandidates = - getCandidate((uint32_t)shape[0], 1U, (uint32_t)shape[0]); - auto NBlockCandidates = getCandidate((uint32_t)shape[1], 1U, shape[1]); - auto KBlockCandidates = getCandidate((uint32_t)shape[2], 1U, shape[2]); - auto innerMostMBlockCandidates = - getCandidate((uint32_t)shape[0], 1U, (uint32_t)shape[0]); - auto innerMostNBlockCandidates = - getCandidate((uint32_t)shape[1], 1U, (uint32_t)shape[1]); - auto innerMostKBlockCandidates = - getCandidate((uint32_t)shape[2], 1U, (uint32_t)shape[2]); + auto MThreadsCandidates = getCandidate((uint32_t)threads, 1U); + auto NThreadsCandidates = getCandidate((uint32_t)threads, 1U); + auto KThreadsCandidates = getCandidate((uint32_t)threads, 1U); + auto noSmallBlockNeedThreshold = 8 * 8U; + auto MBlockCandidates = getCandidate( + (uint32_t)shape[0], shape[0] > noSmallBlockNeedThreshold ? 8U : 1U, + (uint32_t)shape[0]); + auto NBlockCandidates = + getCandidate((uint32_t)shape[1], + shape[1] > noSmallBlockNeedThreshold ? 8U : 1U, shape[1]); + auto KBlockCandidates = + getCandidate((uint32_t)shape[2], + shape[2] > noSmallBlockNeedThreshold ? 8U : 1U, shape[2]); + auto innerMostMBlockCandidates = getCandidate( + (uint32_t)shape[0], shape[0] > noSmallBlockNeedThreshold ? 8U : 1U, 256U); + auto innerMostNBlockCandidates = getCandidate( + (uint32_t)shape[1], shape[1] > noSmallBlockNeedThreshold ? 8U : 1U, 256U); + auto innerMostKBlockCandidates = getCandidate( + (uint32_t)shape[2], shape[2] > noSmallBlockNeedThreshold ? 8U : 1U, 256U); if (givenInnermostBlock.size() == 3) { innerMostMBlockCandidates = givenInnermostBlock[0] != 0 @@ -211,38 +239,56 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, ? std::vector{givenInnermostBlock[2]} : innerMostKBlockCandidates; } - llvm::outs() << "MThreadsCandidates size: " << MThreadsCandidates.size() + llvm::errs() << "MThreadsCandidates size: " << MThreadsCandidates.size() + << MThreadsCandidates << "\n"; + llvm::errs() << "NThreadsCandidates size: " << NThreadsCandidates.size() + << NThreadsCandidates << "\n"; + llvm::errs() << "KThreadsCandidates size: " << KThreadsCandidates.size() + << KThreadsCandidates << "\n"; + llvm::errs() << "MBlockCandidates size: " << MBlockCandidates.size() + << MBlockCandidates << "\n"; + llvm::errs() << "NBlockCandidates size: " << NBlockCandidates.size() + << NBlockCandidates << "\n"; + llvm::errs() << "KBlockCandidates size: " << KBlockCandidates.size() + << KBlockCandidates << "\n"; + llvm::errs() << "innerMostMBlockCandidates size: " + << innerMostMBlockCandidates.size() << innerMostMBlockCandidates << "\n"; - llvm::outs() << "NThreadsCandidates size: " << NThreadsCandidates.size() + llvm::errs() << "innerMostNBlockCandidates size: " + << innerMostNBlockCandidates.size() << innerMostNBlockCandidates << "\n"; - llvm::outs() << "KThreadsCandidates size: " << KThreadsCandidates.size() + llvm::errs() << "innerMostKBlockCandidates size: " + << innerMostKBlockCandidates.size() << innerMostKBlockCandidates << "\n"; - llvm::outs() << "MBlockCandidates size: " << MBlockCandidates.size() << "\n"; - llvm::outs() << "NBlockCandidates size: " << NBlockCandidates.size() << "\n"; - llvm::outs() << "KBlockCandidates size: " << KBlockCandidates.size() << "\n"; - llvm::outs() << "innerMostMBlockCandidates size: " - << innerMostMBlockCandidates.size() << "\n"; - llvm::outs() << "innerMostNBlockCandidates size: " - << innerMostNBlockCandidates.size() << "\n"; - llvm::outs() << "innerMostKBlockCandidates size: " - << innerMostKBlockCandidates.size() << "\n"; for (auto MThreads : MThreadsCandidates) { for (auto NThreads : NThreadsCandidates) { for (auto KThreads : KThreadsCandidates) { + if (!validateThreads({MThreads, NThreads, KThreads}, sysDesc)) { + continue; + } for (auto MBlock : MBlockCandidates) { - for (auto NBlock : NBlockCandidates) { - for (auto KBlock : KBlockCandidates) { - for (auto innerMostMBlock : innerMostMBlockCandidates) { - for (auto innerMostNBlock : innerMostNBlockCandidates) { + for (auto innerMostMBlock : innerMostMBlockCandidates) { + if (MBlock % innerMostMBlock != 0 || + shape[0] % innerMostMBlock != 0) { + continue; + } + for (auto NBlock : NBlockCandidates) { + for (auto innerMostNBlock : innerMostNBlockCandidates) { + if (NBlock % innerMostNBlock != 0 || + shape[1] % innerMostNBlock != 0) { + continue; + } + for (auto KBlock : KBlockCandidates) { for (auto innerMostKBlock : innerMostKBlockCandidates) { + if (KBlock % innerMostKBlock != 0 || + shape[2] % innerMostKBlock != 0) { + continue; + } MatmulConfig config{ MBlock, NBlock, KBlock, MThreads, NThreads, KThreads, innerMostMBlock, innerMostNBlock, innerMostKBlock}; - - if (isValidConfig(config, sysDesc, shape)) { - configs.push_back(config); - } + configs.push_back(config); } } } @@ -252,9 +298,38 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc, } } } + llvm::errs() << "Finish generating candidates. ConfigCandidates size: " + << configs.size() << "\n"; return configs; } +bool readConfigFromAttrs(MatmulConfig &config, ArrayRef attrs) { + bool hasPredefinedConfig = false; + for (auto attr : attrs) { + if (attr.getName() == "KBlock") { + config.KBlock = cast(attr.getValue()).getInt(); + hasPredefinedConfig = true; + } else if (attr.getName() == "KThreads") { + config.KThreads = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "NBlock") { + config.NBlock = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "NThreads") { + config.NThreads = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "MBlock") { + config.MBlock = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "MThreads") { + config.MThreads = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "innerMostMBlock") { + config.innerMostMBlock = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "innerMostNBlock") { + config.innerMostNBlock = cast(attr.getValue()).getInt(); + } else if (attr.getName() == "innerMostKBlock") { + config.innerMostKBlock = cast(attr.getValue()).getInt(); + } + } + return hasPredefinedConfig; +} + /* thread utilization computation intensity @@ -269,7 +344,6 @@ previous matmul MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { SystemDesc sysDesc; if (auto linalgOp = dyn_cast(root)) { - // TODO: build a more complex heuristic to determine the best tiling auto oprandDimType = *getOprandDimType(linalgOp); // get the origin M,N,K size auto MDimTypeIdx = extractDimTypeIdx(oprandDimType[0], DimType::M); @@ -292,7 +366,6 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { K *= s; } } - // innermost Block, if the layout is blockied layout, the innermost block // will derived from the layout directly auto defaultBlock = 32; @@ -331,56 +404,39 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) { givenInnermostBlock.push_back(0); } - // Number of block - auto MNumBlock = M / config.innerMostMBlock; - auto NNumBlock = N / config.innerMostNBlock; - auto KNumBlock = K / config.innerMostKBlock; - - // Threads - config.MThreads = 32; - config.NThreads = 1; - config.KThreads = 1; - - // Block - config.MBlock = (int)llvm::divideCeil(MNumBlock, config.MThreads) * - config.innerMostMBlock; - config.NBlock = (int)llvm::divideCeil(NNumBlock, config.NThreads) * - config.innerMostNBlock; - config.KBlock = (int)llvm::divideCeil(KNumBlock, config.KThreads) * - config.innerMostKBlock; - config.MBlock = 128; - config.NBlock = 128; - config.KBlock = 128; - config.MThreads = 2; - config.NThreads = 2; - config.KThreads = 1; + llvm::errs() << "M: " << M << ", N: " << N << ", K: " << K << "\n"; - llvm::outs() << "M: " << M << ", N: " << N << ", K: " << K << "\n"; + SmallVector> costModelList = { + {workloadBalancedCost, "workloadBalancedCost", 1}, + {hardwareEfficiencyCost, "hardwareEfficiencyCost", -1}, + {computationIntensityOnL2Cache, "computationIntensityOnL2Cache", -1}, + {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost", -1}}; - SmallVector> costModelList = { - {threadUtilizationCost, "threadUtilizationCost"}, - {hardwareEfficiencyCost, "hardwareEfficiencyCost"}, - {workloadBalancedCost, "workloadBalancedCost"}, - {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost"}, - {computationIntensityOnL1Cache, "computationIntensityOnL1Cache"}}; + SmallVector attrs(linalgOp->getAttrs()); + bool hasPredefinedConfig = readConfigFromAttrs(config, attrs); - auto configCandidates = - prepareConfigCandidates(root, sysDesc, {M, N, K}, givenInnermostBlock); - - for (auto [fn, name] : costModelList) { - llvm::outs() << name << "\n\n"; - configCandidates = filterConfigByCostModel(configCandidates, linalgOp, - {M, N, K}, sysDesc, fn, 0.5); - llvm::outs() << "ConfigCandidates size: " << configCandidates.size() - << "\n"; - } - - if (!configCandidates.empty()) { - config = configCandidates[0]; + if (!hasPredefinedConfig) { + llvm::errs() << "No predefined config\n"; + auto configCandidates = prepareConfigCandidates(root, sysDesc, {M, N, K}, + givenInnermostBlock); + for (auto [fn, name, threshold] : costModelList) { + llvm::errs() << "\n" << name << "\n"; + configCandidates = filterConfigByCostModel( + configCandidates, linalgOp, {M, N, K}, sysDesc, fn, 0.5, threshold); + llvm::errs() << "ConfigCandidates size: " << configCandidates.size() + << "\n"; + } + if (configCandidates.size() > 0) { + config = configCandidates[0]; + } } - llvm::outs() << "Final config\nNumThreads: " << sysDesc.getNumThreads() + llvm::errs() << "Final config\nNumThreads: " << sysDesc.getNumThreads() << ", MatmulConfig: " << config << "\n"; + for (auto [fn, name, threshold] : costModelList) { + auto cost = fn(linalgOp, {M, N, K}, config, sysDesc); + llvm::errs() << name << ": " << cost << "\n"; + } } } } // namespace gc diff --git a/lib/gc/CAPI/CMakeLists.txt b/lib/gc/CAPI/CMakeLists.txt index 2da458bb5..35f2f54fb 100644 --- a/lib/gc/CAPI/CMakeLists.txt +++ b/lib/gc/CAPI/CMakeLists.txt @@ -5,5 +5,6 @@ add_mlir_public_c_api_library(GcCAPI MLIROneDNNGraph MLIRCPURuntimeDialect GCPasses + GCAnalysis MLIRCPURuntimeTransforms ) \ No newline at end of file