diff --git a/mlir/lib/Conversion/AIRLoweringPass.cpp b/mlir/lib/Conversion/AIRLoweringPass.cpp index a3babad0f..3589f8f43 100644 --- a/mlir/lib/Conversion/AIRLoweringPass.cpp +++ b/mlir/lib/Conversion/AIRLoweringPass.cpp @@ -506,8 +506,15 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder, } else { opers.push_back(builder.create( loc, IntegerType::get(ctx, 64), launch.getInductionVars()[0])); - opers.push_back(builder.create( - loc, IntegerType::get(ctx, 64), launch.getInductionVars()[1])); + if (launch.getNumLoops() == 2) + opers.push_back(builder.create( + loc, IntegerType::get(ctx, 64), launch.getInductionVars()[1])); + else if (launch.getNumLoops() == 1) + opers.push_back(builder.create( + loc, i64Ty, IntegerAttr::get(i64Ty, 0))); + else + assert(false && "lowering of air.launch with more than 2 dimensions is " + "currently unsupported"); } opers.push_back(thisOp.getMemref()); diff --git a/mlir/test/Conversion/AIRLowering/air_channel_get_put.mlir b/mlir/test/Conversion/AIRLowering/air_channel_get_put.mlir index d44e595a4..3e7fbf4a9 100644 --- a/mlir/test/Conversion/AIRLowering/air_channel_get_put.mlir +++ b/mlir/test/Conversion/AIRLowering/air_channel_get_put.mlir @@ -163,3 +163,49 @@ func.func @par_with_for_put_get(%arg0: memref<32x16xi32>, %arg1: memref<32x16xi3 } return } + +// CHECK-LABEL: func.func @one_d_scf_parallel +// CHECK: affine.for +// CHECK: airrt.dma_memcpy_nd(%{{.*}}, %{{.*}}, %{{.*}}, %arg0[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}], [%{{.*}}, %{{.*}}, %{{.*}}]) : (i32, i64, i64, memref<128xf32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event +// CHECK: } {affine_opt_label = "tiling"} + +#map = affine_map<()[s0] -> (s0 * 64)> +air.channel @channel_6 [1, 1] +func.func @one_d_scf_parallel(%arg0: memref<128xf32>, %arg1: memref<128xf32>) { + %c2 = arith.constant 2 : index + %0 = air.launch async (%arg2) in (%arg3=%c2) args(%arg4=%arg0) : memref<128xf32> attributes {id = 1 : i32} { + %c64 = arith.constant 64 : index + %c1 = arith.constant 1 : index + %async_token, %results = air.execute -> (index) { + %3 = affine.apply #map()[%arg2] + air.execute_terminator %3 : index + } + %1 = air.channel.put async [%async_token] @channel_6[] (%arg4[%results] [%c64] [%c1]) {id = 1 : i32} : (memref<128xf32>) + %2 = air.segment @segment_0 async attributes {id = 2 : i32, x_loc = 0 : i64, x_size = 1 : i64, y_loc = 2 : i64, y_size = 4 : i64} { + %c1_0 = arith.constant 1 : index + %c2_1 = arith.constant 2 : index + %3 = air.wait_all async + %async_token_2, %results_3 = air.execute -> (memref<64xf32, 1>) { + %alloc = memref.alloc() : memref<64xf32, 1> + air.execute_terminator %alloc : memref<64xf32, 1> + } + %4 = air.channel.get async [%3, %async_token_2] @channel_6[] (%results_3[] [] []) {id = 3 : i32} : (memref<64xf32, 1>) + %5 = air.herd @herd_0 async [%4] tile (%arg5, %arg6) in (%arg7=%c1_0, %arg8=%c2_1) attributes {id = 3 : i32, x_loc = 0 : i64, y_loc = 2 : i64} { + %async_token_5, %results_6 = air.execute -> (memref<32xf32, 2>) { + %alloc = memref.alloc() : memref<32xf32, 2> + air.execute_terminator %alloc : memref<32xf32, 2> + } + %async_token_7 = air.execute [%async_token_5] { + memref.dealloc %results_6 : memref<32xf32, 2> + } + air.herd_terminator + } + %async_token_4 = air.execute [%4] { + memref.dealloc %results_3 : memref<64xf32, 1> + } + air.segment_terminator + } + air.launch_terminator + } + return +}