Skip to content

Commit

Permalink
support dlti
Browse files Browse the repository at this point in the history
  • Loading branch information
zhczhong committed Aug 1, 2024
1 parent ae819fd commit 8f27e8f
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 31 deletions.
68 changes: 43 additions & 25 deletions include/gc/Analysis/MatmulConfigAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,64 +10,82 @@
#define MLIR_ANALYSIS_MATMULCONFIGANALYSIS_H

#include "gc/Dialect/Linalgx/LinalgxOps.h"
#include "mlir/Dialect/DLTI/DLTI.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <cstring>
#include "mlir/Interfaces/DataLayoutInterfaces.h"

namespace mlir {
namespace gc {

using namespace mlir;

// A mock for the target information
// TODO: replace it with upstream hardware description model
struct SystemDesc {

static int getPositiveIntFromStr(char *str, int defaultValue = 1) {
if (!str || strlen(str) == 0 || str[0] > '9' || str[0] < '0') {
return defaultValue;
}
auto val = std::stoi(str);
return val > 0 ? val : defaultValue;
}

// get runtime OMP_NUM_THREADS
uint32_t getNumThreads() {
char *numThreads = getenv("OMP_NUM_THREADS");
return getPositiveIntFromStr(numThreads, 1);
std::optional<Attribute> numThreads = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("num_threads"));
if (numThreads && isa<IntegerAttr>(*numThreads)) {
return dyn_cast<IntegerAttr>(*numThreads).getInt();
}
return 1;
}
// get cache size by cacheLevel
size_t getCacheSize(uint8_t cacheLevel) {
if (cacheLevel == 1) {
char *cacheSize = getenv("L1_CACHE_SIZE");
return getPositiveIntFromStr(cacheSize, 0);
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("L1_cache_size_in_bytes"));
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
}
} else if (cacheLevel == 2) {
char *cacheSize = getenv("L2_CACHE_SIZE");
return getPositiveIntFromStr(cacheSize, 0);
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("L2_cache_size_in_bytes"));
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
}
} else if (cacheLevel == 3) {
char *cacheSize = getenv("L3_CACHE_SIZE");
return getPositiveIntFromStr(cacheSize, 0);
std::optional<Attribute> cacheSize = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("L3_cache_size_in_bytes"));
if (cacheSize && isa<IntegerAttr>(*cacheSize)) {
return dyn_cast<IntegerAttr>(*cacheSize).getInt();
}
}
return 0;
}

// get the maximum vector length in bits
size_t getMaxVectorLength() {
char *maxVectorLanes = getenv("MAX_VECTOR_LENGTH");
return getPositiveIntFromStr(maxVectorLanes, 512);
std::optional<Attribute> maxVectorLength = layout.getDevicePropertyValue(
Builder(ctx).getStringAttr("CPU" /* device ID*/),
Builder(ctx).getStringAttr("max_vector_width"));
if (maxVectorLength && isa<IntegerAttr>(*maxVectorLength)) {
return dyn_cast<IntegerAttr>(*maxVectorLength).getInt();
}
return 512;
}

SystemDesc(ModuleOp m) : layout(m), ctx(m->getContext()) {}

private:
DataLayout layout;
MLIRContext *ctx;
};

// The configuration for matmul tiling
// TODO: support batch matmul
struct MatmulConfig {
  // The number of threads distributed to M, N, K
  uint32_t MThreads, NThreads, KThreads;
  // The outer block size for M, N, K which will be used to decide the loop
  // tile size in single thread
  uint32_t MBlock, NBlock, KBlock;
  // The innermost block size for M, N, K which will be directly converted to
  // brgemm.
  uint32_t innerMostMBlock, innerMostNBlock, innerMostKBlock;
};

enum DimType { Batch, M, N, K };
Expand Down
11 changes: 6 additions & 5 deletions lib/gc/Analysis/MatmulConfigAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ double vectorRegEfficiencyCost(linalg::LinalgOp &linalgOp,
size_t dtypeSize = DataLayout().getTypeSizeInBits(
ShapeAdaptor(linalgOp.getDpsInputs()[1].getType()).getElementType());
size_t maxVectorLength = sysDesc.getMaxVectorLength() / dtypeSize;
// TODO: take matrix register like amx into account
double cost = (maxVectorLength - config.innerMostMBlock % maxVectorLength) %
maxVectorLength * 1.0 / config.innerMostMBlock +
(maxVectorLength - config.innerMostKBlock % maxVectorLength) %
Expand Down Expand Up @@ -270,8 +271,8 @@ prepareConfigCandidates(Operation *root, SystemDesc &sysDesc,
continue;
}
MatmulConfig config{
MBlock, NBlock, KBlock,
MThreads, NThreads, KThreads,
MBlock, NBlock, KBlock,
innerMostMBlock, innerMostNBlock, innerMostKBlock};
configs.push_back(config);
}
Expand Down Expand Up @@ -311,13 +312,13 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
} else if (attr.getName() == "MThreads") {
config.MThreads = cast<IntegerAttr>(attr.getValue()).getInt();
cfgItemCnt++;
} else if (attr.getName() == "innerMostMBlock") {
} else if (attr.getName() == "innermostMBlock") {
config.innerMostMBlock = cast<IntegerAttr>(attr.getValue()).getInt();
cfgItemCnt++;
} else if (attr.getName() == "innerMostNBlock") {
} else if (attr.getName() == "innermostNBlock") {
config.innerMostNBlock = cast<IntegerAttr>(attr.getValue()).getInt();
cfgItemCnt++;
} else if (attr.getName() == "innerMostKBlock") {
} else if (attr.getName() == "innermostKBlock") {
config.innerMostKBlock = cast<IntegerAttr>(attr.getValue()).getInt();
cfgItemCnt++;
}
Expand All @@ -338,7 +339,7 @@ bool readConfigFromAttrs(MatmulConfig &config, ArrayRef<NamedAttribute> attrs) {
// previous matmul
MatmulConfigAnalysis::MatmulConfigAnalysis(Operation *root) {
if (auto linalgOp = dyn_cast<linalg::LinalgOp>(root)) {
SystemDesc sysDesc;
SystemDesc sysDesc(root->getParentOfType<ModuleOp>());
SmallVector<SmallVector<DimType>> oprandDimType =
*getOprandDimType(linalgOp);
// get the origin M,N,K size
Expand Down
2 changes: 1 addition & 1 deletion lib/gc/Transforms/DeepTileContractionNamedOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ static Operation *findParentFillOp(Value val) {
llvm::find(skipOpList, currentOp->getName().getStringRef()) !=
skipOpList.end() &&
!isa<linalg::FillOp>(currentOp)) {
currentOp = currentOp->getResult(0).getDefiningOp();
currentOp = currentOp->getOperand(0).getDefiningOp();
}
if (currentOp && isa<linalg::FillOp>(currentOp)) {
return currentOp;
Expand Down
2 changes: 2 additions & 0 deletions lib/gc/Transforms/TilingUtil.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
namespace mlir {
namespace linalgX {

// An enhancement for the upstream pass to support tiling reduction for MKmk
// like cases (with multiple reduction iterators).
FailureOr<linalg::ForallReductionTilingResult> tileReductionUsingForall(
RewriterBase &b, PartialReductionOpInterface op,
ArrayRef<OpFoldResult> threadNums, ArrayRef<OpFoldResult> tileSizes,
Expand Down
47 changes: 47 additions & 0 deletions test/gc/Transform/deepTileContractionNamedOp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -108,3 +108,50 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12
return %2 : tensor<4096x4096xbf16>
}

// -----

// Checks that the deep-tile pass picks up target information (cache sizes,
// num_threads, max_vector_width) from the module's dlti.target_system_spec
// "CPU" device entries rather than from environment variables, and still
// produces the expected nested forall/for tiling structure.
module attributes {
dlti.target_system_spec = #dlti.target_system_spec<
"CPU": #dlti.target_device_spec<
#dlti.dl_entry<"L1_cache_size_in_bytes", 49152 : i32>,
#dlti.dl_entry<"L2_cache_size_in_bytes", 2097152 : i32>,
#dlti.dl_entry<"L3_cache_size_in_bytes", 110100480 : i32>,
#dlti.dl_entry<"num_threads", 56 : i32>,
#dlti.dl_entry<"max_vector_width", 512 : i32>>
>} {
/// CHECK-LABEL: @matmul_2Dx4D_bf16_with_dlti
func.func @matmul_2Dx4D_bf16_with_dlti(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x128x16x32x2xbf16>) -> tensor<4096x4096xbf16> {
%cst_0 = arith.constant 0.000000e+00 : bf16
%0 = tensor.empty() : tensor<4096x4096xbf16>
%1 = linalg.fill ins(%cst_0 : bf16) outs(%0 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
// CHECK: scf.forall
// CHECK: tensor.extract_slice
// CHECK: scf.forall
// CHECK: tensor.extract_slice
// CHECK: scf.forall
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK: scf.for
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
// CHECK: scf.for
// CHECK: tensor.extract_slice
// CHECK: tensor.extract_slice
// CHECK: linalg.transpose
// CHECK: scf.if
// CHECK: linalg.fill
// CHECK: linalgx.batch_reduce_matmul_vnni
// CHECK: else
// CHECK: linalgx.batch_reduce_matmul_vnni
// CHECK: scf.forall.in_parallel
// CHECK: scf.forall.in_parallel
// CHECK: scf.forall.in_parallel
// CHECK: linalg.reduce
// CHECK: linalg.copy
%2 = linalgx.mm2d_vnni ins(%arg0, %arg1 : tensor<4096x4096xbf16>, tensor<128x128x16x32x2xbf16>) outs(%1 : tensor<4096x4096xbf16>) -> tensor<4096x4096xbf16>
return %2 : tensor<4096x4096xbf16>
}

}

0 comments on commit 8f27e8f

Please sign in to comment.