Skip to content

Commit

Permalink
Use expand_shape/collapse_shape to alter tensor rank
Browse files Browse the repository at this point in the history
  • Loading branch information
zhczhong committed Aug 7, 2024
1 parent 892be10 commit 304dcde
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
32 changes: 20 additions & 12 deletions lib/gc/Transforms/DeepTileContractionNamedOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,19 +275,22 @@ static void setStaticSizeForExtractSliceOp(RewriterBase &rewriter,
SmallVector<OpFoldResult> mixedOffsets = extractSlice.getMixedOffsets();
SmallVector<OpFoldResult> mixedSizes = extractSlice.getMixedSizes();
SmallVector<OpFoldResult> mixedStrides = extractSlice.getMixedStrides();
auto targetTensor = mlir::RankedTensorType::get(
SmallVector<int64_t>(size.begin() + shrinDimNum, size.end()),
extractSlice.getResult().getType().getElementType());
for (auto &&[i, s] : llvm::enumerate(size))
mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s);
if (shrinDimNum > 0)
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
extractSlice,
mlir::RankedTensorType::get(
SmallVector<int64_t>(size.begin() + shrinDimNum, size.end()),
extractSlice.getResult().getType().getElementType()),
extractSlice.getSource(), mixedOffsets, mixedSizes, mixedStrides);
else
rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
extractSlice, extractSlice.getSource(), mixedOffsets, mixedSizes,
mixedStrides);
Operation *newExtractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
extractSlice->getLoc(), extractSlice.getSource(), mixedOffsets,
mixedSizes, mixedStrides);
if (shrinDimNum > 0) {
rewriter.setInsertionPointAfter(newExtractSliceOp);
Value viewResult = tensorViewRankedTensor(
rewriter, targetTensor, newExtractSliceOp->getResult(0));
rewriter.replaceOp(extractSlice, viewResult);
} else {
rewriter.replaceOp(extractSlice, newExtractSliceOp);
}
}
}

Expand All @@ -304,8 +307,13 @@ static void setStaticSizeForInsertSliceOp(RewriterBase &rewriter, Operation *op,
SmallVector<OpFoldResult> mixedStrides = insertSlice.getMixedStrides();
for (auto &&[i, s] : llvm::enumerate(size))
mixedSizes[i] = getAsIndexOpFoldResult(rewriter.getContext(), s);
auto targetTensor = mlir::RankedTensorType::get(
size,
insertSlice.getDest().getType().getElementType());
Value viewResult = tensorViewRankedTensor(
rewriter, targetTensor, source);
rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
insertSlice, source, insertSlice.getDest(), mixedOffsets, mixedSizes,
insertSlice, viewResult, insertSlice.getDest(), mixedOffsets, mixedSizes,
mixedStrides);
}
}
Expand Down
6 changes: 5 additions & 1 deletion test/gc/Transform/deepTileContractionNamedOp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,13 @@ func.func @matmul_4Dx4D_bf16(%arg0: tensor<128x128x32x32xbf16>, %arg1: tensor<12
// CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1]
// CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1]
// CHECK: scf.for
// CHECK: tensor.extract_slice {{.*}} [1, 8, 32, 32] [1, 1, 1, 1]
// CHECK: tensor.collapse_shape {{.*}} tensor<1x8x32x32xbf16> into tensor<8x32x32xbf16>
// CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1]
// CHECK: tensor.collapse_shape {{.*}} tensor<1x8x16x32x2xbf16> into tensor<8x16x32x2xbf16>
// CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1]
// CHECK: tensor.collapse_shape {{.*}} tensor<1x1x32x32xf32> into tensor<32x32xf32>
// CHECK: tensor.extract_slice {{.*}} [1, 1, 32, 32] [1, 1, 1, 1]
// CHECK: tensor.collapse_shape {{.*}} tensor<1x1x32x32xbf16> into tensor<32x32xbf16>
// CHECK: scf.if
// CHECK: linalg.fill
// CHECK: linalgx.batch_reduce_matmul_vnni
Expand Down Expand Up @@ -92,6 +95,7 @@ func.func @matmul_2Dx4D_bf16(%arg0: tensor<4096x4096xbf16>, %arg1: tensor<128x12
// CHECK: tensor.extract_slice {{.*}} [32, 256] [1, 1]
// CHECK: scf.for
// CHECK: tensor.extract_slice {{.*}} [1, 8, 16, 32, 2] [1, 1, 1, 1, 1]
// CHECK: tensor.collapse_shape {{.*}} tensor<1x8x16x32x2xbf16> into tensor<8x16x32x2xbf16>
// CHECK: tensor.extract_slice {{.*}} [32, 32] [1, 1]
// CHECK: linalg.transpose {{.*}} permutation = [1, 0, 2]
// CHECK: scf.if
Expand Down

0 comments on commit 304dcde

Please sign in to comment.