Skip to content

Commit

Permalink
Restrict inThreadTranspose to the 2D case to fix test_dot3d
Browse files Browse the repository at this point in the history
  • Loading branch information
jtang10 committed Nov 26, 2024
1 parent d07d944 commit ccc6197
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
4 changes: 3 additions & 1 deletion lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ void lowerDistributedToShared(
// NV path because for non-KContig tensor their blocked and shared layout
// still have the same order.
if (auto blocked = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding())) {
auto rank = blocked.getOrder().size();
auto inOrd = blocked.getOrder();
crossGrain = inOrd[0] != outOrd[0];
// it has to be 2D and blocked's and shared's order mismatch
crossGrain = (rank == 2) && (inOrd[0] != outOrd[0]);
}
assert(srcTy.getShape().size() <= 2 ||
(srcTy.getShape().size() == 3 && outOrd[2] == 0) &&
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ SmallVector<Operation *> getLoadInsts(Operation *op) {
}

bool needCvtToThreadRaked(Value operand) {

auto opTensorTy = cast<RankedTensorType>(operand.getType());
auto opEnc = opTensorTy.getEncoding();
auto opDotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(opEnc);
Expand All @@ -116,6 +117,11 @@ bool needCvtToThreadRaked(Value operand) {
// make sure it is converted from blocked layout
if (!blockedEnc)
return false;
auto rank = blockedEnc.getOrder().size();
if (rank != 2) {
LDBG("inThreadRake only supports 2D case right now");
return false;
}
// check whether it's contiguous on K dimension
int kDimNum = opDotOpEnc.getOpIdx() == 0 ? 1 : 0;
auto order = blockedEnc.getOrder();
Expand Down

0 comments on commit ccc6197

Please sign in to comment.