Skip to content

Commit

Permalink
Move dma to channel pass from Conversion to Transform (Xilinx#619)
Browse files Browse the repository at this point in the history
  • Loading branch information
fifield authored Jun 25, 2024
1 parent 05284a2 commit bcbfed5
Show file tree
Hide file tree
Showing 10 changed files with 1,752 additions and 1,688 deletions.
1 change: 0 additions & 1 deletion mlir/include/air/Conversion/ConvertToAIRPass.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
createParallelToLaunchPass(const ParallelToLaunchOptions &options);

std::unique_ptr<mlir::Pass> createCopyToDmaPass();
std::unique_ptr<mlir::Pass> createDmaToChannelPass();
std::unique_ptr<mlir::Pass> createInsertEmptyLaunchOverHerdPass();

} // namespace air
Expand Down
1 change: 0 additions & 1 deletion mlir/include/air/Conversion/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ using namespace mlir;
#define GEN_PASS_DEF_AIRTOAIE
#define GEN_PASS_DEF_AIRTOASYNC
#define GEN_PASS_DEF_COPYTODMA
#define GEN_PASS_DEF_DMATOCHANNEL
#define GEN_PASS_DEF_INSERTEMPTYLAUNCHOVERHERD
#define GEN_PASS_DEF_PARALLELTOHERD
#define GEN_PASS_DEF_PARALLELTOLAUNCH
Expand Down
82 changes: 0 additions & 82 deletions mlir/include/air/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -66,88 +66,6 @@ def CopyToDma : Pass<"air-copy-to-dma", "ModuleOp"> {
}];
}

def DmaToChannel : Pass<"air-dma-to-channel", "ModuleOp"> {
let summary = "Convert air.dma_memcpy_nd to air.channel";
let constructor = "xilinx::air::createDmaToChannelPass()";
let description = [{
Transforms direct memory access (DMA) operations into channel-based
communications, consisting of a series of channel put and get operations
via shared channel constructs.

Example:

Input:
```mlir
%0 = air.launch async [%async_token_0, %async_token_3, %async_token_6] (%arg0, %arg1) in (%arg2=%c4, %arg3=%c4) args(%arg4=%results_5, %arg5=%results, %arg6=%results_2) : memref<512x512xi32>, memref<512x1024xi32>, memref<1024x512xi32> attributes {id = 3 : i32} {
%1 = air.segment @segment_0 async args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<512x512xi32>, memref<512x1024xi32>, memref<1024x512xi32> attributes {id = 2 : i32} {
...
%3 = scf.for %arg12 = %c0_8 to %c1024 step %c256 iter_args(%arg13 = %2) -> (!air.async.token) {
%8 = air.dma_memcpy_nd async [%arg13, %arg13] (%results_14[%c0_8, %arg12] [%c128, %c256] [%c1024, %c1], %arg10[%results_10, %arg12] [%c128, %c256] [%c1024, %c1]) {id = 1 : i32} : (memref<128x1024xi32, 1 : i32>, memref<512x1024xi32>)
...
}
%6 = air.herd @herd_0 async [%async_token_13, %async_token_15, %async_token_17] tile (%arg12, %arg13) in (%arg14=%c4_7, %arg15=%c4_7) args(%arg16=%results_14, %arg17=%results_16, %arg18=%results_18) : memref<128x1024xi32, 1 : i32>, memref<1024x128xi32, 1 : i32>, memref<128x128xi32, 1 : i32> attributes {id = 1 : i32} {
...
%9 = scf.for %arg19 = %c0_23 to %c128_26 step %c4_24 iter_args(%arg20 = %8) -> (!air.async.token) {
...
%16 = air.dma_memcpy_nd async [%async_token_37, %async_token_35, %arg20] (%results_38[%c0_23] [%c1024_22] [%c1_25], %arg16[%c0_44, %c0_43, %results_36] [%c4_24, %c32, %c8] [%c8, %c1024_22, %c1_25]) {broadcast_set = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 3 >= 0)>, id = 3 : i32} : (memref<4x8x4x8xi32, 2 : i32>, memref<128x1024xi32, 1 : i32>)
...
}
...
air.herd_terminator
}
...
air.segment_terminator
}
air.launch_terminator
}
```

Output:
```mlir
...
air.channel @channel_8 [1, 1]
...
air.channel @channel_0 [1, 1] {broadcast_shape = [1, 4]}
...
%0 = air.launch async [%async_token_0, %async_token_3, %async_token_6] (%arg0, %arg1) in (%arg2=%c4, %arg3=%c4) args(%arg4=%results_5, %arg5=%results, %arg6=%results_2) : memref<512x512xi32>, memref<512x1024xi32>, memref<1024x512xi32> attributes {id = 3 : i32} {
...
%2 = scf.for %arg7 = %c0_7 to %c1024 step %c256 iter_args(%arg8 = %1) -> (!air.async.token) {
...
%17 = air.channel.put async [%async_token_8, %arg8] @channel_8[] (%arg5[%results_9, %arg7] [%c128, %c256] [%c1024, %c1]) : (memref<512x1024xi32>)
...
}
...
%16 = air.segment @segment_0 async args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<512x512xi32>, memref<512x1024xi32>, memref<1024x512xi32> attributes {id = 2 : i32} {
...
%18 = scf.for %arg12 = %c0_32 to %c1024_33 step %c256_34 iter_args(%arg13 = %17) -> (!air.async.token) {
%49 = air.channel.get async [%arg13, %arg13] @channel_8[] (%results_40[%c0_32, %arg12] [%c128_30, %c256_34] [%c1024_33, %c1_29]) : (memref<128x1024xi32, 1 : i32>)
...
}
...
%23 = scf.for %arg12 = %c0_47 to %c128_50 step %c4_48 iter_args(%arg13 = %22) -> (!air.async.token) {
...
%49 = air.channel.put async [%async_token_160, %async_token_39, %arg13] @channel_0[] (%results_40[%c0_163, %c0_162, %results_161] [%c4_48, %c32, %c8] [%c8, %c1024_46, %c1_49]) : (memref<128x1024xi32, 1 : i32>)
...
}
...
%47 = air.herd @herd_0 async [%async_token_39, %async_token_41, %async_token_43] tile (%arg12, %arg13) in (%arg14=%c4_31, %arg15=%c4_31) args(%arg16=%results_40, %arg17=%results_42, %arg18=%results_44) : memref<128x1024xi32, 1 : i32>, memref<1024x128xi32, 1 : i32>, memref<128x128xi32, 1 : i32> attributes {id = 1 : i32} {
...
%50 = scf.for %arg19 = %c0_155 to %c128_159 step %c4_156 iter_args(%arg20 = %49) -> (!air.async.token) {
...
%57 = air.channel.get async [%async_token_170, %async_token_168, %arg20] @channel_0[%arg12, %arg13] (%results_171[%c0_155] [%c1024_154] [%c1_158]) : (memref<4x8x4x8xi32, 2 : i32>)
...
}
...
air.herd_terminator
}
air.segment_terminator
}
air.launch_terminator
}
```
}];
}

def AIRToAsync : Pass<"air-to-async", "ModuleOp"> {
let summary = "AIR dialect lowering";
let constructor = "xilinx::air::createAIRToAsyncPass()";
Expand Down
24 changes: 24 additions & 0 deletions mlir/include/air/Transform/AIRDmaToChannel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
//===- AIRDmaToChannel.h ----------------------------------------*- C++ -*-===//
//
// Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
// SPDX-License-Identifier: MIT
//
//===----------------------------------------------------------------------===//

#ifndef AIR_DMA_TO_CHANNEL_H
#define AIR_DMA_TO_CHANNEL_H

#include "air/Transform/PassDetail.h"

#include "mlir/Pass/Pass.h"
#include <memory>

namespace xilinx {
namespace air {

std::unique_ptr<mlir::Pass> createDmaToChannelPass();

} // namespace air
} // namespace xilinx

#endif // AIR_DMA_TO_CHANNEL_H
1 change: 1 addition & 0 deletions mlir/include/air/Transform/PassDetail.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ namespace air {
#define GEN_PASS_DEF_AFFINELOOPOPTPASS
#define GEN_PASS_DEF_AIRSEGMENTLOOPFUSION
#define GEN_PASS_DEF_AIRSPLITL2MEMREFFORBUFFERCONSTRAINTPASS
#define GEN_PASS_DEF_DMATOCHANNEL
#include "air/Transform/Passes.h.inc"

} // namespace air
Expand Down
1 change: 1 addition & 0 deletions mlir/include/air/Transform/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "air/Transform/AIRDependencyCanonicalize.h"
#include "air/Transform/AIRDependencyParseGraph.h"
#include "air/Transform/AIRDependencyScheduleOpt.h"
#include "air/Transform/AIRDmaToChannel.h"
#include "air/Transform/AIRHerdAssignPass.h"
#include "air/Transform/AIRHerdPlacementPass.h"
#include "air/Transform/AIRLinalgCodegen.h"
Expand Down
81 changes: 81 additions & 0 deletions mlir/include/air/Transform/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1380,4 +1380,85 @@ def AIRSplitL2MemrefForBufferConstraintPass : Pass<"air-split-l2-memref", "func:
}];
}

def DmaToChannel : Pass<"air-dma-to-channel", "ModuleOp"> {
let summary = "Convert air.dma_memcpy_nd to air.channel";
let constructor = "xilinx::air::createDmaToChannelPass()";
let description = [{
Transforms direct memory access (DMA) operations into channel-based
communications, consisting of a series of channel put and get operations
via shared channel constructs.

Example:

Input:
```mlir
%0 = air.launch async [%async_token_0, %async_token_3, %async_token_6] (%arg0, %arg1) in (%arg2=%c4, %arg3=%c4) args(%arg4=%results_5, %arg5=%results, %arg6=%results_2) : memref<512x512xi32>, memref<512x1024xi32>, memref<1024x512xi32> attributes {id = 3 : i32} {
%1 = air.segment @segment_0 async args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<512x512xi32>, memref<512x1024xi32>, memref<1024x512xi32> attributes {id = 2 : i32} {
...
%3 = scf.for %arg12 = %c0_8 to %c1024 step %c256 iter_args(%arg13 = %2) -> (!air.async.token) {
%8 = air.dma_memcpy_nd async [%arg13, %arg13] (%results_14[%c0_8, %arg12] [%c128, %c256] [%c1024, %c1], %arg10[%results_10, %arg12] [%c128, %c256] [%c1024, %c1]) {id = 1 : i32} : (memref<128x1024xi32, 1 : i32>, memref<512x1024xi32>)
...
}
%6 = air.herd @herd_0 async [%async_token_13, %async_token_15, %async_token_17] tile (%arg12, %arg13) in (%arg14=%c4_7, %arg15=%c4_7) args(%arg16=%results_14, %arg17=%results_16, %arg18=%results_18) : memref<128x1024xi32, 1 : i32>, memref<1024x128xi32, 1 : i32>, memref<128x128xi32, 1 : i32> attributes {id = 1 : i32} {
...
%9 = scf.for %arg19 = %c0_23 to %c128_26 step %c4_24 iter_args(%arg20 = %8) -> (!air.async.token) {
...
%16 = air.dma_memcpy_nd async [%async_token_37, %async_token_35, %arg20] (%results_38[%c0_23] [%c1024_22] [%c1_25], %arg16[%c0_44, %c0_43, %results_36] [%c4_24, %c32, %c8] [%c8, %c1024_22, %c1_25]) {broadcast_set = affine_set<()[s0, s1] : (s0 == 0, s1 >= 0, -s1 + 3 >= 0)>, id = 3 : i32} : (memref<4x8x4x8xi32, 2 : i32>, memref<128x1024xi32, 1 : i32>)
...
}
...
air.herd_terminator
}
...
air.segment_terminator
}
air.launch_terminator
}
```

Output:
```mlir
...
air.channel @channel_8 [1, 1]
...
air.channel @channel_0 [1, 1] {broadcast_shape = [1, 4]}
...
%0 = air.launch async [%async_token_0, %async_token_3, %async_token_6] (%arg0, %arg1) in (%arg2=%c4, %arg3=%c4) args(%arg4=%results_5, %arg5=%results, %arg6=%results_2) : memref<512x512xi32>, memref<512x1024xi32>, memref<1024x512xi32> attributes {id = 3 : i32} {
...
%2 = scf.for %arg7 = %c0_7 to %c1024 step %c256 iter_args(%arg8 = %1) -> (!air.async.token) {
...
%17 = air.channel.put async [%async_token_8, %arg8] @channel_8[] (%arg5[%results_9, %arg7] [%c128, %c256] [%c1024, %c1]) : (memref<512x1024xi32>)
...
}
...
%16 = air.segment @segment_0 async args(%arg7=%arg0, %arg8=%arg1, %arg9=%arg4, %arg10=%arg5, %arg11=%arg6) : index, index, memref<512x512xi32>, memref<512x1024xi32>, memref<1024x512xi32> attributes {id = 2 : i32} {
...
%18 = scf.for %arg12 = %c0_32 to %c1024_33 step %c256_34 iter_args(%arg13 = %17) -> (!air.async.token) {
%49 = air.channel.get async [%arg13, %arg13] @channel_8[] (%results_40[%c0_32, %arg12] [%c128_30, %c256_34] [%c1024_33, %c1_29]) : (memref<128x1024xi32, 1 : i32>)
...
}
...
%23 = scf.for %arg12 = %c0_47 to %c128_50 step %c4_48 iter_args(%arg13 = %22) -> (!air.async.token) {
...
%49 = air.channel.put async [%async_token_160, %async_token_39, %arg13] @channel_0[] (%results_40[%c0_163, %c0_162, %results_161] [%c4_48, %c32, %c8] [%c8, %c1024_46, %c1_49]) : (memref<128x1024xi32, 1 : i32>)
...
}
...
%47 = air.herd @herd_0 async [%async_token_39, %async_token_41, %async_token_43] tile (%arg12, %arg13) in (%arg14=%c4_31, %arg15=%c4_31) args(%arg16=%results_40, %arg17=%results_42, %arg18=%results_44) : memref<128x1024xi32, 1 : i32>, memref<1024x128xi32, 1 : i32>, memref<128x128xi32, 1 : i32> attributes {id = 1 : i32} {
...
%50 = scf.for %arg19 = %c0_155 to %c128_159 step %c4_156 iter_args(%arg20 = %49) -> (!air.async.token) {
...
%57 = air.channel.get async [%async_token_170, %async_token_168, %arg20] @channel_0[%arg12, %arg13] (%results_171[%c0_155] [%c1024_154] [%c1_158]) : (memref<4x8x4x8xi32, 2 : i32>)
...
}
...
air.herd_terminator
}
air.segment_terminator
}
air.launch_terminator
}
```
}];
}
#endif // AIR_CONVERSION_PASSES
Loading

0 comments on commit bcbfed5

Please sign in to comment.