Commit 23dfa97

fix comments
zhczhong committed Aug 1, 2024
1 parent 8f27e8f commit 23dfa97
Showing 4 changed files with 100 additions and 88 deletions.
57 changes: 38 additions & 19 deletions lib/gc/Analysis/MatmulConfigAnalysis.cpp
@@ -29,14 +29,9 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,

template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
- std::vector<T> arry) {
+ std::vector<T> array) {
ss << "[";
- for (auto [idx, a] : llvm::enumerate(arry)) {
-   if (idx != 0) {
-     ss << ", ";
-   }
-   ss << a;
- }
+ llvm::interleaveComma(array, ss);
ss << "]";
return ss;
}
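
For context, a minimal standalone sketch of the llvm::interleaveComma idiom this hunk switches to (assumes an LLVM build environment; not part of the patch):

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <vector>

int main() {
  std::vector<int> vals = {32, 64, 128};
  llvm::outs() << "[";
  llvm::interleaveComma(vals, llvm::outs()); // prints the elements separated by ", "
  llvm::outs() << "]\n";                     // output: [32, 64, 128]
  return 0;
}
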
@@ -174,24 +169,23 @@ std::vector<MatmulConfig>
filterConfigByCostModel(ArrayRef<MatmulConfig> configs,
linalg::LinalgOp &linalgOp, ArrayRef<uint32_t> shape,
SystemDesc &sysDesc, const CostModelFn &costModel,
- float eliminationRatio = 0.5, float threshold = -1) {
+ float preserveRatio = 0.5, float threshold = -1) {
std::vector<MatmulConfig> result;
std::vector<float> costs;
std::vector<size_t> idx;
- for (auto [i, config] : llvm::enumerate(configs)) {
+ for (auto &&[i, config] : llvm::enumerate(configs)) {
costs.push_back(costModel(linalgOp, shape, config, sysDesc));
idx.push_back(i);
}
std::stable_sort(idx.begin(), idx.end(), [&costs](size_t i1, size_t i2) {
return costs[i1] < costs[i2];
});
- double thresholdCost =
-     costs[idx[(size_t)(eliminationRatio * configs.size())]];
+ double thresholdCost = costs[idx[(size_t)(preserveRatio * configs.size())]];
thresholdCost =
threshold < thresholdCost && threshold > 0 ? threshold : thresholdCost;
- for (size_t i = 0; i < configs.size(); i++) {
-   if (costs[idx[i]] <= thresholdCost) {
-     result.push_back(configs[idx[i]]);
+ for (const auto &i : idx) {
+   if (costs[i] <= thresholdCost) {
+     result.push_back(configs[i]);
}
}
LLVM_DEBUG(llvm::dbgs() << "thresholdCost is: " << thresholdCost
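
As an aside, a simplified self-contained sketch of the filtering pattern used in filterConfigByCostModel: sort candidate indices by cost, take the cost at the preserveRatio cut-off as the acceptance threshold, and keep everything at or below it (plain C++, illustrative names, no MLIR dependencies):

#include <algorithm>
#include <cstddef>
#include <vector>

std::vector<int> filterByCost(const std::vector<int> &candidates,
                              const std::vector<float> &costs,
                              float preserveRatio = 0.5f) {
  if (candidates.empty())
    return {};
  std::vector<size_t> idx(candidates.size());
  for (size_t i = 0; i < idx.size(); ++i)
    idx[i] = i;
  // Sort indices so idx[0] points at the cheapest candidate.
  std::stable_sort(idx.begin(), idx.end(),
                   [&](size_t a, size_t b) { return costs[a] < costs[b]; });
  // The cost at the preserveRatio position becomes the threshold.
  float thresholdCost = costs[idx[(size_t)(preserveRatio * idx.size())]];
  std::vector<int> kept;
  for (size_t i : idx)
    if (costs[i] <= thresholdCost)
      kept.push_back(candidates[i]);
  return kept;
}
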
@@ -210,6 +204,11 @@ std::vector<MatmulConfig>
prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
ArrayRef<uint32_t> shape,
ArrayRef<uint32_t> givenInnermostBlock) {
+ if (shape.size() < 3) {
+   LLVM_DEBUG(llvm::dbgs()
+              << "The shape is invalid, no candidate is generated\n");
+   return {};
+ }
std::vector<MatmulConfig> configs;
uint32_t threads = sysDesc.getNumThreads();
std::vector<uint32_t> MThreadsCandidates =
Expand Down Expand Up @@ -290,10 +289,25 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
return configs;
}

+ bool validateConfig(const MatmulConfig &cfg) {
+   if (cfg.MThreads <= 0 || cfg.NThreads <= 0 || cfg.KThreads <= 0 ||
+       cfg.MBlock <= 0 || cfg.NBlock <= 0 || cfg.KBlock <= 0 ||
+       cfg.innerMostMBlock <= 0 || cfg.innerMostNBlock <= 0 ||
+       cfg.innerMostKBlock <= 0) {
+     return false;
+   }
+   if (cfg.MBlock % cfg.innerMostMBlock != 0 ||
+       cfg.NBlock % cfg.innerMostNBlock != 0 ||
+       cfg.KBlock % cfg.innerMostKBlock != 0) {
+     return false;
+   }
+   return true;
+ }
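
A worked example of the invariant the new validateConfig enforces: every field must be strictly positive, and each outer block size must be a whole multiple of its innermost block size. A standalone sketch with illustrative values (no project headers assumed):

#include <cassert>

// Mirrors the divisibility rule checked in validateConfig above.
static bool blocksCompatible(int block, int innerMostBlock) {
  return block > 0 && innerMostBlock > 0 && block % innerMostBlock == 0;
}

int main() {
  assert(blocksCompatible(64, 32));  // 64 = 2 * 32 -> accepted
  assert(!blocksCompatible(48, 32)); // 48 % 32 != 0 -> rejected
  return 0;
}
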

// read the config from the attributes for tuning
bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
size_t cfgItemCnt = 0;
- for (auto &attr : attrs) {
+ for (const auto &attr : attrs) {
if (attr.getName() == "KBlock") {
config.KBlock = cast<IntegerAttr>(attr.getValue()).getInt();
cfgItemCnt++;
Expand Down Expand Up @@ -323,7 +337,12 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
cfgItemCnt++;
}
}
- return cfgItemCnt == 9;
+ if (validateConfig(config)) {
+   return cfgItemCnt == 9;
+ } else {
+   LLVM_DEBUG(llvm::dbgs() << "The predefined config is invalid\n");
+   return false;
+ }
}

// Analyze the workload and system description to generate the default config
@@ -350,14 +369,14 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
SmallVector<unsigned> NDimTypeIdx =
extractDimTypeIdx(oprandDimType[1], DimType::N);
uint32_t M = 1U, N = 1U, K = 1U;
- for (auto [s, dimType] :
+ for (auto &&[s, dimType] :
llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(0)),
oprandDimType[0])) {
if (dimType == DimType::M) {
M *= s;
}
}
- for (auto [s, dimType] :
+ for (auto &&[s, dimType] :
llvm::zip(linalgOp.getShape(linalgOp.getDpsInputOperand(1)),
oprandDimType[1])) {
if (dimType == DimType::N) {
@@ -425,7 +444,7 @@ MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
SmallVector<uint32_t> shape = {M, N, K};
std::vector<MatmulConfig> configCandidates =
prepareConfigCandidates(root, sysDesc, shape, givenInnermostBlock);
- for (auto [fn, name, threshold] : costModelList) {
+ for (auto &&[fn, name, threshold] : costModelList) {
configCandidates = filterConfigByCostModel(
configCandidates, linalgOp, shape, sysDesc, fn, 0.5, threshold);
}
50 changes: 25 additions & 25 deletions lib/gc/Transforms/DeepTileContractionNamedOp.cpp
@@ -151,10 +151,10 @@ static void buildLinalgRegion(Operation *op, bool createTemporaryOp = false) {
// Check if the linalgOp need to be legalized to f32 accumulation type
static bool needToLegalizeDtype(linalg::LinalgOp linalgOp) {
mlir::Type dataType =
-     dyn_cast<mlir::RankedTensorType>(linalgOp.getDpsInputs()[0].getType())
+     dyn_cast<mlir::ShapedType>(linalgOp.getDpsInputs()[0].getType())
.getElementType();
mlir::Type resultType =
-     dyn_cast<mlir::RankedTensorType>(linalgOp.getDpsInits()[0].getType())
+     dyn_cast<mlir::ShapedType>(linalgOp.getDpsInits()[0].getType())
.getElementType();
return (dataType.isBF16() || dataType.isF16()) && dataType == resultType;
}
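
The switch from RankedTensorType to ShapedType throughout this file is presumably about generality: ShapedType is implemented by both RankedTensorType and MemRefType, so the element-type query keeps working on memref operands where a RankedTensorType cast would return null. A minimal helper sketch under that assumption (not part of the patch):

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// Returns true if the value's element type is a low-precision float.
static bool hasLowPrecisionElementType(mlir::Value v) {
  auto shaped = llvm::dyn_cast<mlir::ShapedType>(v.getType());
  if (!shaped)
    return false;
  mlir::Type elem = shaped.getElementType();
  return elem.isBF16() || elem.isF16();
}
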
@@ -372,7 +372,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
linalg::LinalgOp currentOp = linalgOp;

bool hasFullResult = !option.isPartialResult;
- for (auto [i, loopType] : llvm::enumerate(loopType)) {
+ for (auto &&[i, loopType] : llvm::enumerate(loopType)) {
ArrayRef<size_t> currentDim = loopDim[i];
ArrayRef<size_t> currentTileSize = nestedTileSizes[i];
if (loopType == OuterLoopGenerationOption::LoopType::ForOp) {
@@ -420,7 +420,7 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
cast<TilingInterface>(currentOp.getOperation()).getIterationDomain(b);
currentOp.getReductionDims(reductionDims);
bool tileOnReduction = false;
- for (auto [d, tile] : llvm::zip(currentDim, currentTileSize)) {
+ for (auto &&[d, tile] : llvm::zip(currentDim, currentTileSize)) {
if (llvm::find(reductionDims, d) != reductionDims.end() && tile != 0 &&
(!getConstantIntValue(loopRanges[d].size) ||
tile != static_cast<size_t>(
@@ -438,22 +438,23 @@ generateOuterLoop(RewriterBase &b, linalg::LinalgOp linalgOp,
OpBuilder::InsertionGuard guard(b);
b.setInsertionPoint(currentOp);
if (tileOnReduction) {
- for (auto [idx, tile] : llvm::enumerate(tileSizes)) {
+ for (auto &&[idx, tile] : llvm::enumerate(tileSizes)) {
if (isConstantIntValue(tile, 0) &&
llvm::find(reductionDims, idx) != reductionDims.end()) {
tileSizes[idx] = loopRanges[idx].size;
}
}
SmallVector<OpFoldResult> newParallelDims;
- for (size_t i = 0UL; i < reductionDims.size(); i++) {
-   newParallelDims.push_back(getAsIndexOpFoldResult(b.getContext(), i));
+ for (auto iter : llvm::enumerate(reductionDims)) {
+   newParallelDims.push_back(
+       getAsIndexOpFoldResult(b.getContext(), iter.index()));
}
FailureOr<linalg::ForallReductionTilingResult> tilingResult =
linalgX::tileReductionUsingForall(
b, cast<PartialReductionOpInterface>(currentOp.getOperation()),
{}, tileSizes, newParallelDims, std::nullopt);
if (failed(tilingResult) &&
-     tilingResult->parallelTiledOps.size() == 1UL)
+     llvm::hasSingleElement(tilingResult->parallelTiledOps))
return failure();
currentOp =
dyn_cast<linalg::LinalgOp>(tilingResult->parallelTiledOps.back());
@@ -585,7 +586,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
: cfg.NBlock;

// Outer loop tile size
- for (auto [tile, dim] :
+ for (auto &&[tile, dim] :
llvm::zip(SmallVector<size_t>{KParallelBlockSize, MParallelBlockSize,
NParallelBlockSize},
SmallVector<size_t>{KDimPos[0], MDimPos[0], NDimPos[0]})) {
@@ -596,27 +597,27 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
}

// Middle loop tile size
- for (auto [tile, dim] :
+ for (auto &&[tile, dim] :
llvm::zip(SmallVector<size_t>{MOuterBlockSize, NOuterBlockSize,
KOuterBlockSize},
SmallVector<size_t>{MDimPos[0], NDimPos[0], KDimPos[0]})) {
option.nestedTileSizes.emplace_back(SmallVector<size_t>{tile});
option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
option.loopDim.emplace_back(SmallVector<size_t>{dim});
}
- if (KDimPos.size() == 1) {
+ if (llvm::hasSingleElement(KDimPos)) {
option.nestedTileSizes.emplace_back(SmallVector<size_t>{cfg.KBlock});
option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
option.loopDim.emplace_back(SmallVector<size_t>{KDimPos.back()});
}
// Inner loop tile size
- if (MDimPos.size() == 1) {
+ if (llvm::hasSingleElement(MDimPos)) {
option.nestedTileSizes.emplace_back(
SmallVector<size_t>{cfg.innerMostMBlock});
option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
option.loopDim.emplace_back(SmallVector<size_t>{MDimPos.back()});
}
- if (NDimPos.size() == 1) {
+ if (llvm::hasSingleElement(NDimPos)) {
option.nestedTileSizes.emplace_back(
SmallVector<size_t>{cfg.innerMostNBlock});
option.loopType.emplace_back(OuterLoopGenerationOption::LoopType::ForOp);
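
Several hunks here replace `.size() == 1` with llvm::hasSingleElement, which works on any range by comparing iterators and never counts the whole container. A small sketch (LLVM ADT only; not part of the patch):

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include <cassert>

void hasSingleElementExample() {
  llvm::SmallVector<unsigned> dims = {2};
  assert(llvm::hasSingleElement(dims));  // exactly one element
  dims.push_back(3);
  assert(!llvm::hasSingleElement(dims)); // more than one element
}
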
@@ -656,7 +657,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
const linalg::ForallReductionTilingResult &result)
-> FailureOr<linalg::LinalgOp> {
ArrayRef<Value> initValue = result.initialValues;
- if (initValue.size() == 1 &&
+ if (llvm::hasSingleElement(initValue) &&
isa<linalg::FillOp>(initValue[0].getDefiningOp())) {
rewriter.replaceOp(initValue[0].getDefiningOp(),
dyn_cast<DestinationStyleOpInterface>(
@@ -706,7 +707,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
SmallVector<int64_t> AInnermostDims, BInnermostDims, CInnermostDims;
bool firstM = true, firstK = true, firstN = true;
if (MDimNum > 1) {
- for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[0])) {
+ for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[0])) {
if (iter == DimType::M && firstM) {
AInnermostDims.push_back(1);
firstM = false;
@@ -721,7 +722,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
}
firstM = true;
firstN = true;
- for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[2])) {
+ for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[2])) {
if (iter == DimType::M && firstM) {
CInnermostDims.push_back(1);
firstM = false;
@@ -745,7 +746,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
if (NDimNum > 1) {
firstN = true;
firstK = true;
- for (auto [idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
+ for (auto &&[idx, iter] : llvm::enumerate((*operandDimTypes)[1])) {
if (iter == DimType::N && firstN) {
BInnermostDims.push_back(1);
firstN = false;
@@ -768,13 +769,13 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(currentOp);
mlir::Type dataType =
-     dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[0].getType())
+     dyn_cast<mlir::ShapedType>(currentOp.getDpsInputs()[0].getType())
.getElementType();
mlir::Type weightType =
-     dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInputs()[1].getType())
+     dyn_cast<mlir::ShapedType>(currentOp.getDpsInputs()[1].getType())
.getElementType();
mlir::Type resultType =
-     dyn_cast<mlir::RankedTensorType>(currentOp.getDpsInits()[0].getType())
+     dyn_cast<mlir::ShapedType>(currentOp.getDpsInits()[0].getType())
.getElementType();

// update the extractSlice to static size, replace it with
Expand Down Expand Up @@ -821,9 +822,8 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
currentOp.getDpsInits()[0]);
// Create the brgemm op and replace the origin linalg op
linalg::LinalgOp matmul;
- if (dyn_cast<mlir::RankedTensorType>(weightOprand.getType())
-         .getShape()
-         .size() == 3) {
+ if (dyn_cast<mlir::ShapedType>(weightOprand.getType()).getShape().size() ==
+     3) {
matmul = rewriter.create<linalg::BatchReduceMatmulOp>(
loc, resultOprand.getType(), ValueRange{dataOprand, weightOprand},
resultOprand);
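
The rank test above picks linalg::BatchReduceMatmulOp when the weight operand is 3-D (the alternative branch is collapsed in this view). For reference, ShapedType::getRank() expresses the same check a little more directly; a sketch, not the patch's code:

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"

// True when the weight operand is 3-D and maps to a batch-reduce matmul.
static bool isBatchReduceWeight(mlir::Value weight) {
  auto shaped = llvm::dyn_cast<mlir::ShapedType>(weight.getType());
  return shaped && shaped.getRank() == 3;
}
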
@@ -843,7 +843,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
// fuse the low precision cast to the innermost body
rewriter.setInsertionPointAfter(currentOp);
Value cond;
- for (LoopLikeOpInterface loop : option.KLoopHandles) {
+ for (LoopLikeOpInterface &loop : option.KLoopHandles) {
Value induceVar = turnOpFoldResultIntoValue(
rewriter, loc, *loop.getSingleInductionVar());
Value upBound = turnOpFoldResultIntoValue(rewriter, loc,
@@ -903,7 +903,7 @@ struct DeepTileMatmul : public OpInterfaceRewritePattern<linalg::LinalgOp> {
Value cond;
arith::ConstantIndexOp zeroConst =
rewriter.create<arith::ConstantIndexOp>(loc, 0);
- for (LoopLikeOpInterface loop : option.KLoopHandles) {
+ for (LoopLikeOpInterface &loop : option.KLoopHandles) {
Value induceVar = loop.getLoopRegions().front()->front().getArgument(0);
Value currentCond = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, induceVar, zeroConst);
1 change: 0 additions & 1 deletion src/gc-opt/CMakeLists.txt
@@ -33,7 +33,6 @@ set(gc_opt_libs
${conversion_libs}
${MLIR_LINK_COMPONENTS}
GCPasses
- GCGPUPasses
GCAnalysis)

if(GC_USE_GPU)
