diff --git a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
index 0ebaeeada893..4252b95ef54d 100644
--- a/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -28,8 +28,10 @@ void lowerDistributedToShared(
   // NV path because for non-KContig tensor their blocked and shared layout
   // still have the same order.
   if (auto blocked = dyn_cast<BlockedEncodingAttr>(srcTy.getEncoding())) {
+    auto rank = blocked.getOrder().size();
     auto inOrd = blocked.getOrder();
-    crossGrain = inOrd[0] != outOrd[0];
+    // it has to be 2D and blocked's and shared's order mismatch
+    crossGrain = (rank == 2) && (inOrd[0] != outOrd[0]);
   }
   assert(srcTy.getShape().size() <= 2 ||
          (srcTy.getShape().size() == 3 && outOrd[2] == 0) &&
diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp
index 405fef881411..d0d93e6e7afe 100644
--- a/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp
+++ b/third_party/amd/lib/TritonAMDGPUTransforms/inThreadTranspose.cpp
@@ -95,6 +95,7 @@ SmallVector<Operation *> getLoadInsts(Operation *op) {
 }
 
 bool needCvtToThreadRaked(Value operand) {
+  auto opTensorTy = cast<RankedTensorType>(operand.getType());
   auto opEnc = opTensorTy.getEncoding();
   auto opDotOpEnc = dyn_cast<DotOperandEncodingAttr>(opEnc);
@@ -116,6 +117,11 @@ bool needCvtToThreadRaked(Value operand) {
   // make sure it is converted from blocked layout
   if (!blockedEnc)
     return false;
+  auto rank = blockedEnc.getOrder().size();
+  if (rank != 2) {
+    LDBG("inThreadRake only supports 2D case right now");
+    return false;
+  }
   // check whether it's contiguous on K dimension
   int kDimNum = opDotOpEnc.getOpIdx() == 0 ? 1 : 0;
   auto order = blockedEnc.getOrder();