diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp index cc2d866e1..4e84788b4 100644 --- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp +++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp @@ -2459,21 +2459,6 @@ struct BroadcastDetection { } } } - // If not variant wrt herd, then check for fixed row-wise or col-wise - // offset. - int src_memspace = llvm::cast(dma_op.getSrcMemref().getType()) - .getMemorySpaceAsInt(); - auto externalOffsets = src_memspace == (int)air::MemorySpace::L1 - ? dma_op.getDstOffsets() - : dma_op.getSrcOffsets(); - if (!hl_op && externalOffsets.size() == - dma_op->getParentOfType().getNumDims()) { - hl_op = dma_op->getParentOfType(); - if (getConstantIntValue(externalOffsets[0])) - isVariantWrtHerdRows = true; - if (getConstantIntValue(externalOffsets[1])) - isVariantWrtHerdCols = true; - } if (!hl_op) { // If dma op is completely independent of the parent herd induction diff --git a/mlir/lib/Util/Dependency.cpp b/mlir/lib/Util/Dependency.cpp index 77fa58bf3..00f5037b5 100644 --- a/mlir/lib/Util/Dependency.cpp +++ b/mlir/lib/Util/Dependency.cpp @@ -53,6 +53,8 @@ void traceDependentInductionVar(SmallVector candidate_scalar_operands, SmallVector &loop_dep_history, std::vector &op_history) { for (auto operand : candidate_scalar_operands) { + if (!llvm::isa(operand.getType())) + continue; // Only tracing scalar operands // If parent loop op is an scf.for if (auto for_op = mlir::scf::getForInductionVarOwner(operand)) { loop_dep_history.push_back(for_op.getInductionVar()); @@ -76,27 +78,26 @@ void traceDependentInductionVar(SmallVector candidate_scalar_operands, // Recursively trace dependency to loop induction vars for (auto operand : candidate_scalar_operands) { - if (operand && llvm::isa( - operand.getType())) { // Only tracing scalar operands - if (operand.getDefiningOp() && - mlir::dyn_cast(operand.getDefiningOp())) { - auto ancestor_async_op = - dyn_cast(operand.getDefiningOp()); - op_history.push_back(ancestor_async_op.getOperation()); - traceDependentInductionVar(ancestor_async_op, loop_dep_history, - op_history); - } else { - // Trace dependency through a for loop - if (auto for_op = getForRegionIterArgsOwner(operand)) { - for (auto iter_arg : for_op.getInitArgs()) { - if (operand == iter_arg) { - loop_dep_history.push_back(iter_arg); - } + if (!llvm::isa(operand.getType())) + continue; // Only tracing scalar operands + if (operand.getDefiningOp() && + mlir::dyn_cast(operand.getDefiningOp())) { + auto ancestor_async_op = + dyn_cast(operand.getDefiningOp()); + op_history.push_back(ancestor_async_op.getOperation()); + traceDependentInductionVar(ancestor_async_op, loop_dep_history, + op_history); + } else { + // Trace dependency through a for loop + if (auto for_op = getForRegionIterArgsOwner(operand)) { + for (auto iter_arg : for_op.getInitArgs()) { + if (operand == iter_arg) { + loop_dep_history.push_back(iter_arg); } } - // Trace dependency through a parallel loop - // TODO: decide if parallel should exist in herd launch } + // Trace dependency through a parallel loop + // TODO: decide if parallel should exist in herd launch } } } diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/broadcast_detection.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/broadcast_detection.mlir index 75ced9b48..781ccbdbc 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/broadcast_detection.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/broadcast_detection.mlir @@ -64,7 +64,7 @@ func.func @matmul(%arg0: memref<512x512xbf16>, %arg1: memref<512x512xbf16>, %arg // ----- -// CHECK: [[$SET0:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 3 >= 0, d1 - s0 == 0, s0 >= 0, -s0 + 3 >= 0)> +// CHECK: [[$SET0:#set[0-9]*]] = affine_set<(d0, d1)[s0] : (d0 >= 0, -d0 + 3 >= 0, d1 >= 0, -d1 + 3 >= 0, s0 >= 0, -s0 >= 0)> // CHECK-LABEL: func.func @func0 // CHECK: %[[EVENT0:.*]] = air.dma_memcpy_nd {{.*}} {id = 1 : i32} : (memref<256x64xbf16, 1>, memref<1024x256xbf16>) // CHECK: %[[EVENT1:.*]] = air.dma_memcpy_nd {{.*}}broadcast_pattern = [[$SET0]]{{.*}}