[RFC] Local tests for linalg.packed_matmul/linalg.packing_map implementations #138

Draft: wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions include/gc/Dialect/Linalgx/CMakeLists.txt
@@ -1,3 +1,7 @@
set(LLVM_TARGET_DEFINITIONS LinalgxDialect.td)
mlir_tablegen(LinalgxOpsAttributes.h.inc -gen-attrdef-decls -attrdefs-dialect=linalgx)
mlir_tablegen(LinalgxOpsAttributes.cpp.inc -gen-attrdef-defs -attrdefs-dialect=linalgx)

add_mlir_dialect(LinalgxOps linalgx)
set(LLVM_TARGET_DEFINITIONS LinalgxStructuredOps.td)
mlir_tablegen(LinalgxStructuredOps.h.inc -gen-op-decls)
46 changes: 46 additions & 0 deletions include/gc/Dialect/Linalgx/LinalgxDialect.td
@@ -10,6 +10,11 @@
#define LINALGX_DIALECT

include "mlir/IR/OpBase.td"
include "mlir/IR/EnumAttr.td"
include "mlir/IR/BuiltinTypes.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/IR/CommonTypeConstraints.td"

//===----------------------------------------------------------------------===//
// Linalgx dialect definition.
@@ -32,6 +37,7 @@ def LinalgxDialect : Dialect {
"tensor::TensorDialect",
];

let useDefaultAttributePrinterParser = 1;
let extraClassDeclaration = [{
/// Attribute name used to memoize indexing maps for named ops.
constexpr const static ::llvm::StringLiteral
@@ -47,4 +53,44 @@ def LinalgxDialect {
}];
}

class Linalgx_Attr<string name, string attrMnemonic, list<Trait> traits = []>
: AttrDef<LinalgxDialect, name, traits> {
let mnemonic = attrMnemonic;
}

def PackingMapAttr : Linalgx_Attr<"PackingMap", "packing_map"> {
let summary = "An attribute mapping dims between a pair of matmul inputs/outputs";
let description = [{
A map between the dims of two matmul operands: the side with a single dim is
the unpacked src, and the other side lists the dst dims it is packed into.
}];

let cppNamespace = "::mlir::linalgx";
let parameters = (ins ArrayRefParameter<"uint64_t">:$first,
ArrayRefParameter<"uint64_t">:$second);

let assemblyFormat = "`<` `[` $first `]` `->` `[` $second `]` `>`";

let extraClassDeclaration = [{
/// The side with a single dim is the packing src; the other side is the dst.
/// Index of $first is 0; index of $second is 1.
unsigned getPackingSrcIndex() {
return getFirst().size() == 1 ? 0 : 1;
}
unsigned getPackingDstIndex() {
return getFirst().size() == 1 ? 1 : 0;
}
/// SrcDims.size() == 1; DstDims.size() >= 1
ArrayRef<uint64_t> getPackingSrcDims() {
return getPackingSrcIndex() == 0 ? getFirst()
: getSecond();
}
ArrayRef<uint64_t> getPackingDstDims() {
return getPackingDstIndex() == 0 ? getFirst()
: getSecond();
}
}];
}

def PackingMapArrayAttr : TypedArrayAttrBase<PackingMapAttr,
"packing_map array attr.">;

#endif // LINALGX_DIALECT
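For reference, a minimal sketch of the concrete syntax this assemblyFormat yields (dim indices and shapes here are hypothetical): a map packing the single K dim of a 64x64 operand into dims 0 and 2 of a 32x64x2 operand, plus an array of such maps as PackingMapArrayAttr expects.

// Src side has one dim (index 1); dst side lists the packed dims (0 and 2).
#linalgx.packing_map<[1] -> [0, 2]>
// A packing_map array attr:
[#linalgx.packing_map<[0] -> [0]>, #linalgx.packing_map<[1] -> [0, 2]>]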
3 changes: 3 additions & 0 deletions include/gc/Dialect/Linalgx/LinalgxOps.h
@@ -19,6 +19,9 @@
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#define GET_ATTRDEF_CLASSES
#include "gc/Dialect/Linalgx/LinalgxOpsAttributes.h.inc"

#define GET_OP_CLASSES
#include "gc/Dialect/Linalgx/LinalgxOps.h.inc"

56 changes: 56 additions & 0 deletions include/gc/Dialect/Linalgx/LinalgxStructuredOps.td
@@ -211,6 +211,62 @@ def Linalgx_Mm4DVnniOp
}];
}

def Linalgx_PackedMatmulOp
: LinalgxStructuredBase_Op<"packed_matmul", [AttrSizedOperandSegments]> {
let summary = "matmul with packed data format";
let description = [{
Use m_packing, n_packing and k_packing to define the dim mappings between the packed shapes of C[M, N] = A[M, K] * B[K, N].
}];
let arguments = (ins
Variadic<TensorOrMemref>:$inputs,
Variadic<TensorOrMemref>:$outputs,
PackingMapArrayAttr:$m_packing,
PackingMapArrayAttr:$n_packing,
PackingMapArrayAttr:$k_packing
);
let results = (outs Variadic<TensorOrMemref>:$results);
let regions = (region AnyRegion:$region);

let skipDefaultBuilders = 1;
let builders = [
OpBuilder<
(ins
"TypeRange":$resultTensorTypes,
"ValueRange":$inputs,
"ValueRange":$outputs,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),
[{
buildStructuredOp($_builder, $_state, resultTensorTypes,
inputs, outputs, attributes, PackedMatmulOp::getRegionBuilder());
}]>
];

let hasCustomAssemblyFormat = 1;
let hasFolder = 1;
let hasVerifier = 1;

let extraClassDeclaration = structuredOpsBaseDecls # [{
// Declare functions necessary for LinalgStructuredInterface.
SmallVector<utils::IteratorType> getIteratorTypesArray();
ArrayAttr getIndexingMaps();
static unsigned getNumRegionArgs() { return 3; }
std::string getLibraryCallName() {
return "op_has_no_registered_library_name";
}

// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }

static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
getRegionBuilder() {
return regionBuilder;
}
}];
}
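As a usage sketch, assuming the printed form follows the standard named-structured-op syntax from parseNamedStructuredOp (shapes, SSA names, and the f32 element type are hypothetical): a matmul where only B is VNNI-packed.

// A: plain [M=64, K=64]; B: packed [KB=32, N=64, kb=2]; C: plain [M=64, N=64].
%0 = linalgx.packed_matmul
       {m_packing = [#linalgx.packing_map<[0] -> [0]>],    // A dim 0 <-> C dim 0
        n_packing = [#linalgx.packing_map<[1] -> [1]>],    // B dim 1 <-> C dim 1
        k_packing = [#linalgx.packing_map<[1] -> [0, 2]>]} // A dim 1 <-> B dims 0, 2
       ins(%a, %b : tensor<64x64xf32>, tensor<32x64x2xf32>)
       outs(%c : tensor<64x64xf32>) -> tensor<64x64xf32>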

def Linalgx_BatchReduceMatmulVnniOp
: LinalgxStructuredBase_Op<"batch_reduce_matmul_vnni", [AttrSizedOperandSegments]> {
let summary = "Batch reduced matmul with 3d batch input and vnni packed weights";
8 changes: 8 additions & 0 deletions lib/gc/Dialect/Linalgx/LinalgxDialect.cpp
@@ -16,12 +16,16 @@
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/Parser/Parser.h"
#include "llvm/ADT/TypeSwitch.h"

using namespace mlir;
using namespace mlir::linalgx;

#include "gc/Dialect/Linalgx/LinalgxOpsDialect.cpp.inc"
#define GET_ATTRDEF_CLASSES
#include "gc/Dialect/Linalgx/LinalgxOpsAttributes.cpp.inc"

void LinalgxDialect::initialize() {
addOperations<
@@ -32,4 +36,8 @@ void LinalgxDialect::initialize() {
#define GET_OP_LIST
#include "gc/Dialect/Linalgx/LinalgxStructuredOps.cpp.inc"
>();
addAttributes<
#define GET_ATTRDEF_LIST
#include "gc/Dialect/Linalgx/LinalgxOpsAttributes.cpp.inc"
>();
}
212 changes: 212 additions & 0 deletions lib/gc/Dialect/Linalgx/LinalgxOps.cpp
@@ -365,6 +365,218 @@ LogicalResult Mm4DVnniOp::verify() {
return success();
}

//===----------------------------------------------------------------------===//
// PackedMatmulOp
//===----------------------------------------------------------------------===//

SmallVector<utils::IteratorType> PackedMatmulOp::getIteratorTypesArray() {
SmallVector<utils::IteratorType> iteratorTypes;
// Append iterTy once per packed dst dim of each packing map.
auto getPackingIteratorTypes = [&](ArrayAttr packingMaps,
utils::IteratorType iterTy) {
for (auto &attr : packingMaps) {
auto packingNum =
llvm::cast<PackingMapAttr>(attr).getPackingDstDims().size();
iteratorTypes.insert(iteratorTypes.end(), packingNum, iterTy);
}
};
// Process order: m, n, k packing
getPackingIteratorTypes(getMPacking(), utils::IteratorType::parallel);
getPackingIteratorTypes(getNPacking(), utils::IteratorType::parallel);
getPackingIteratorTypes(getKPacking(), utils::IteratorType::reduction);
return iteratorTypes;
}

unsigned getPackingDimsExpr(PackedMatmulOp self,
SmallVector<SmallVector<AffineExpr>> &exprsArr) {
MLIRContext *context = self.getContext();
auto typeA = cast<ShapedType>(self.getDpsInputOperand(0)->get().getType());
auto typeB = cast<ShapedType>(self.getDpsInputOperand(1)->get().getType());
auto typeC = cast<ShapedType>(self.getDpsInitOperand(0)->get().getType());
SmallVector<AffineExpr> exprsA(typeA.getRank());
SmallVector<AffineExpr> exprsB(typeB.getRank());
SmallVector<AffineExpr> exprsC(typeC.getRank());
// Total number of iteration dims, allocated in order starting from 0.
unsigned dims = 0;
// For each packing map, give every dst dim a fresh iteration dim, and give
// the single src dim the compound expr that linearizes those dims.
auto getPackingExprs = [&](ArrayAttr attrArray, ArrayRef<ShapedType> types,
ArrayRef<SmallVector<AffineExpr> *> exprs) {
for (auto &attr : attrArray) {
auto packingMap = cast<PackingMapAttr>(attr);
auto srcIndex = packingMap.getPackingSrcIndex();
auto dstIndex = packingMap.getPackingDstIndex();
auto srcDims = packingMap.getPackingSrcDims();
auto dstDims = packingMap.getPackingDstDims();
auto &dstExprs = *exprs[dstIndex];
auto &srcExprs = *exprs[srcIndex];
auto compound = getAffineConstantExpr(0, context);
for (auto dim : dstDims) {
auto curr = getAffineDimExpr(dims++, context);
auto constant =
getAffineConstantExpr(types[dstIndex].getDimSize(dim), context);
compound = compound * constant + curr;
dstExprs[dim] = curr;
}
srcExprs[srcDims.front()] = compound;
}
};
// Process order: m, n, k packing, kept same as packing iterator types
getPackingExprs(self.getMPacking(), ArrayRef{typeA, typeC},
ArrayRef{&exprsA, &exprsC});
getPackingExprs(self.getNPacking(), ArrayRef{typeB, typeC},
ArrayRef{&exprsB, &exprsC});
getPackingExprs(self.getKPacking(), ArrayRef{typeA, typeB},
ArrayRef{&exprsA, &exprsB});
exprsArr.emplace_back(exprsA);
exprsArr.emplace_back(exprsB);
exprsArr.emplace_back(exprsC);
return dims;
}

ArrayAttr PackedMatmulOp::getIndexingMaps() {
static const char memoizeAttr[] = "linalg.memoized_indexing_maps";
ArrayAttr cached = getOperation()->getAttrOfType<ArrayAttr>(memoizeAttr);
if (cached)
return cached;

SmallVector<SmallVector<AffineExpr>> exprsArr;
auto dims = getPackingDimsExpr(*this, exprsArr);

MLIRContext *context = getContext();
auto mapA = simplifyAffineMap(AffineMap::get(dims, 0, exprsArr[0], context));
auto mapB = simplifyAffineMap(AffineMap::get(dims, 0, exprsArr[1], context));
auto mapC = simplifyAffineMap(AffineMap::get(dims, 0, exprsArr[2], context));

cached = Builder(context).getAffineMapArrayAttr({mapA, mapB, mapC});
getOperation()->setAttr(memoizeAttr, cached);
return cached;
}
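Tracing the hypothetical example above through getPackingDimsExpr: iteration dims are allocated in m, n, k order (matching getIteratorTypesArray, here parallel, parallel, reduction, reduction), each dst dim gets a fresh dim expr, and the lone src dim gets the linearized compound, so the memoized maps would be:

// d0 = M, d1 = N, d2 = outer K (KB), d3 = inner K (kb); A's K dim = d2 * 2 + d3.
#map_a = affine_map<(d0, d1, d2, d3) -> (d0, d2 * 2 + d3)>
#map_b = affine_map<(d0, d1, d2, d3) -> (d2, d1, d3)>
#map_c = affine_map<(d0, d1, d2, d3) -> (d0, d1)>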

void PackedMatmulOp::regionBuilder(ImplicitLocOpBuilder &b, Block &block,
ArrayRef<NamedAttribute> attrs) {
assert(block.getNumArguments() == 3 &&
"PackedMatmulOp regionBuilder expects 3 args");
RegionBuilderHelper helper(b, block);
SmallVector<Value> yields;

Value value1 =
helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
block.getArgument(0));
Value value2 =
helper.buildTypeFn(TypeFn::cast_signed, block.getArgument(2).getType(),
block.getArgument(1));
Value value3 = helper.buildBinaryFn(BinaryFn::mul, value1, value2);
Value value4 =
helper.buildBinaryFn(BinaryFn::add, block.getArgument(2), value3);
yields.push_back(value4);
helper.yieldOutputs(yields);
}
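The region built here is the usual multiply-accumulate. For f32 operands, where cast_signed is a no-op, it is equivalent to this sketch of a generic-style body:

^bb0(%in_a: f32, %in_b: f32, %acc: f32):
  %mul = arith.mulf %in_a, %in_b : f32
  %sum = arith.addf %acc, %mul : f32
  linalg.yield %sum : f32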

ParseResult PackedMatmulOp::parse(OpAsmParser &parser, OperationState &result) {
return ::parseNamedStructuredOp(parser, result,
PackedMatmulOp::getNumRegionArgs(),
PackedMatmulOp::getRegionBuilder());
}

void PackedMatmulOp::print(OpAsmPrinter &p) {
::printNamedStructuredOp(p, getOperation(), getInputs(), getOutputs());
}

LogicalResult PackedMatmulOp::fold(FoldAdaptor,
SmallVectorImpl<OpFoldResult> &) {
return memref::foldMemRefCast(*this);
}

void PackedMatmulOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
if (hasPureTensorSemantics())
return;
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
}

LogicalResult PackedMatmulOp::verify() {
// A[M, K]
// B[K, N]
// C[M, N]
// mPacking = A -> C
// nPacking = B -> C
// kPacking = A -> B
auto shapeA = cast<ShapedType>(getDpsInputOperand(0)->get().getType());
auto shapeB = cast<ShapedType>(getDpsInputOperand(1)->get().getType());
auto shapeC = cast<ShapedType>(getDpsInitOperand(0)->get().getType());
auto mPacking = getMPacking();
auto nPacking = getNPacking();
auto kPacking = getKPacking();

// check rank
bool hasRank = shapeA.hasRank() && shapeB.hasRank() && shapeC.hasRank();
if (!hasRank)
return emitOpError() << "input/output shape must have rank.";

// check packing axis
auto getAxisSet = [](ArrayAttr arrayAttr,
llvm::SmallSet<uint64_t, 8> &firstIndexSet,
llvm::SmallSet<uint64_t, 8> &secondIndexSet) {
for (auto &attr : arrayAttr) {
auto packingMap = cast<PackingMapAttr>(attr);
auto firstDims = packingMap.getFirst();
firstIndexSet.insert(firstDims.begin(), firstDims.end());
auto secondDims = packingMap.getSecond();
secondIndexSet.insert(secondDims.begin(), secondDims.end());
}
};
llvm::SmallSet<uint64_t, 8> indexSetA;
llvm::SmallSet<uint64_t, 8> indexSetB;
llvm::SmallSet<uint64_t, 8> indexSetC;
getAxisSet(mPacking, indexSetA, indexSetC);
getAxisSet(nPacking, indexSetB, indexSetC);
getAxisSet(kPacking, indexSetA, indexSetB);
bool checkAxis = (shapeA.getRank() == (int64_t)indexSetA.size()) &&
(shapeB.getRank() == (int64_t)indexSetB.size()) &&
(shapeC.getRank() == (int64_t)indexSetC.size());
if (!checkAxis)
return emitOpError() << "input/output must match packing axis.";

// check packing dims match
auto matchDims = [](ArrayAttr arrayAttr, ShapedType firstShape,
ShapedType secondShape) {
for (auto &attr : arrayAttr) {
auto packingMap = cast<PackingMapAttr>(attr);
bool isDynamic = false;
int64_t firstSize = 1;
auto firstDims = packingMap.getFirst();
for (auto dim : firstDims) {
auto size = firstShape.getDimSize(dim);
if (size == ShapedType::kDynamic)
isDynamic = true;
firstSize *= size;
}
int64_t secondSize = 1;
auto secondDims = packingMap.getSecond();
for (auto dim : secondDims) {
auto size = secondShape.getDimSize(dim);
if (size == ShapedType::kDynamic)
isDynamic = true;
secondSize *= size;
}
if (isDynamic)
continue;
if (firstSize != secondSize)
return false;
}
return true;
};
bool matchM = matchDims(mPacking, shapeA, shapeC);
bool matchN = matchDims(nPacking, shapeB, shapeC);
bool matchK = matchDims(kPacking, shapeA, shapeB);
bool checkMatch = matchM && matchN && matchK;
if (!checkMatch)
return emitOpError() << "input/output must match packing dim size.";

return success();
}
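Concretely, for the size check: a hypothetical k_packing of <[1] -> [0, 2]> between A = tensor<64x64xf32> and B = tensor<32x64x2xf32> passes (64 == 32 * 2), whereas B = tensor<32x64x4xf32> would be rejected; any dynamic dim on either side makes that map skip the product comparison.

// Passes: 64 == 32 * 2.
k_packing = [#linalgx.packing_map<[1] -> [0, 2]>]  // A: 64x64, B: 32x64x2
// Rejected with "input/output must match packing dim size.": 64 != 32 * 4.
k_packing = [#linalgx.packing_map<[1] -> [0, 2]>]  // A: 64x64, B: 32x64x4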

//===----------------------------------------------------------------------===//
// BatchReduceMatmulVnniOp
//===----------------------------------------------------------------------===//