Skip to content

Commit

Permalink
Add support for rank-reduced vector.transfer_read/write (Xilinx#645)
Browse files Browse the repository at this point in the history
  • Loading branch information
erwei-xilinx authored Jul 3, 2024
1 parent c6d171d commit 0fb668c
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 4 deletions.
16 changes: 12 additions & 4 deletions mlir/lib/Util/Util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1282,9 +1282,13 @@ air::writeAccessPattern(mlir::vector::TransferReadOp readOp) {
std::get<0>(pattern), std::get<1>(pattern),
std::get<2>(pattern));
// Update wraps based on vector shape and vector access patterns.
for (unsigned i = 0; i < std::get<1>(pattern).size(); i++)
unsigned rankOffset =
vectorTy.getShape().size() >= std::get<1>(pattern).size()
? 0
: std::get<1>(pattern).size() - vectorTy.getShape().size();
for (unsigned i = rankOffset; i < std::get<1>(pattern).size(); i++)
std::get<1>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), vectorTy.getShape()[i]);
builder.getUnknownLoc(), vectorTy.getShape()[i - rankOffset]);
updateAccessPatternByScfForNest(pattern, readOp.getIndices(), builder);
return pattern;
}
Expand All @@ -1303,9 +1307,13 @@ air::writeAccessPattern(mlir::vector::TransferWriteOp writeOp) {
std::get<0>(pattern), std::get<1>(pattern),
std::get<2>(pattern));
// Update wraps based on vector shape and vector access patterns.
for (unsigned i = 0; i < std::get<1>(pattern).size(); i++)
unsigned rankOffset =
vectorTy.getShape().size() >= std::get<1>(pattern).size()
? 0
: std::get<1>(pattern).size() - vectorTy.getShape().size();
for (unsigned i = rankOffset; i < std::get<1>(pattern).size(); i++)
std::get<1>(pattern)[i] = builder.create<arith::ConstantIndexOp>(
builder.getUnknownLoc(), vectorTy.getShape()[i]);
builder.getUnknownLoc(), vectorTy.getShape()[i - rankOffset]);
updateAccessPatternByScfForNest(pattern, writeOp.getIndices(), builder);
return pattern;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,3 +581,46 @@ func.func @func6(%arg0: memref<512x512xi32>, %arg1: memref<512x512xi32>, %arg2:
}
return
}

// Rank-reduced vector.transfer_read and vector.transfer_write: the vector type has a lower rank (3) than the memref it accesses (4).

// CHECK-LABEL: func.func @func7
// CHECK-DAG: %[[CST0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[CSTBF16:.*]] = arith.constant 0.000000e+00 : bf16
// CHECK: vector.transfer_read %{{.*}}[%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]], %[[CSTBF16]] {in_bounds = [true, true, true]} : memref<1x1x4x8xbf16, 2 : i32>, vector<1x4x8xbf16>
// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]] {in_bounds = [true, true, true]} : vector<1x4x8xbf16>, memref<1x1x4x8xbf16, 2 : i32>

// Test input for rank-reduced vector transfers: a rank-3 vector<1x4x8xbf16> is
// read from, and then written to, rank-4 memref<1x1x4x8xbf16, 2 : i32> buffers
// inside an async air.launch / air.segment / air.herd nest.
func.func @func7() {
%0 = air.launch async () in () {
%1 = air.segment @segment_0 async {
%c4 = arith.constant 4 : index
%c2 = arith.constant 2 : index
// The herd's size operands %arg2 and %arg3 are bound to %c2 and %c4.
%2 = air.herd @herd_0 async tile (%arg0, %arg1) in (%arg2=%c2, %arg3=%c4) {
// Padding value for the transfer_read and the shared zero index.
%cst = arith.constant 0.000000e+00 : bf16
%c0 = arith.constant 0 : index
// First buffer (%results): the destination of the vector.transfer_write below.
%async_token, %results = air.execute -> (memref<1x1x4x8xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x4x8xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x1x4x8xbf16, 2 : i32>
}
// Second buffer (%results_1): the source of the vector.transfer_read below.
%async_token_0, %results_1 = air.execute -> (memref<1x1x4x8xbf16, 2 : i32>) {
%alloc = memref.alloc() : memref<1x1x4x8xbf16, 2 : i32>
air.execute_terminator %alloc : memref<1x1x4x8xbf16, 2 : i32>
}
// Rank-reduced read: four memref indices but only three vector dimensions
// (and three in_bounds entries). Depends on %async_token_0, the token of the
// source buffer's allocation.
%async_token_2, %results_3 = air.execute [%async_token_0] -> (vector<1x4x8xbf16>) {
%3 = vector.transfer_read %results_1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x4x8xbf16, 2 : i32>, vector<1x4x8xbf16>
air.execute_terminator %3 : vector<1x4x8xbf16>
}
// Rank-reduced write of the vector produced by the read; depends on
// %async_token_2, the token of the read.
%async_token_4 = air.execute [%async_token_2] {
vector.transfer_write %results_3, %results[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x8xbf16>, memref<1x1x4x8xbf16, 2 : i32>
}
// Deallocation of the source buffer depends on the read's token.
%async_token_5 = air.execute [%async_token_2] {
memref.dealloc %results_1 : memref<1x1x4x8xbf16, 2 : i32>
}
// Deallocation of the destination buffer depends on the write's token.
%async_token_6 = air.execute [%async_token_4] {
memref.dealloc %results : memref<1x1x4x8xbf16, 2 : i32>
}
}
}
}
return
}

0 comments on commit 0fb668c

Please sign in to comment.