diff --git a/mlir/lib/Conversion/AIRRtToNpuPass.cpp b/mlir/lib/Conversion/AIRRtToNpuPass.cpp index 6a6147303..7f69f3585 100644 --- a/mlir/lib/Conversion/AIRRtToNpuPass.cpp +++ b/mlir/lib/Conversion/AIRRtToNpuPass.cpp @@ -578,15 +578,15 @@ void tileIllegalWrapDim(airrt::DmaMemcpyNdOp memcpy_op) { auto const_stride = *getConstantIntValue(strides[i]); if (const_wrap >= AIE2_WRAP_UPPER_BOUND) { // Found dimension with illegal wrap. Tiling. - int inner_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1); - int new_wrap = mlir::ceilDiv(const_wrap, inner_wrap); + int outer_wrap = findLargestFactor(const_wrap, AIE2_WRAP_UPPER_BOUND - 1); + int inner_wrap = mlir::ceilDiv(const_wrap, outer_wrap); wraps[i] = builder.create( loc, builder.getI64Type(), IntegerAttr::get(builder.getI64Type(), inner_wrap)); wraps.insert(wraps.begin() + i, builder.create( loc, builder.getI64Type(), - IntegerAttr::get(builder.getI64Type(), new_wrap))); + IntegerAttr::get(builder.getI64Type(), outer_wrap))); auto new_const_stride = (const_stride * inner_wrap) % air::getTensorVolume( @@ -1130,56 +1130,71 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase { } std::optional - getAllocOpForSymbol(AIE::DeviceOp dev, StringRef sym_name) { - auto sym = dev.lookupSymbol(sym_name); - if (!sym) - return std::nullopt; - - auto uses = SymbolTable::getSymbolUses(sym, dev); - for (auto use : *uses) - if (auto infoOp = dyn_cast(use.getUser())) - return infoOp; - + getAllocOpForSymbol(SmallVector shimDmaAllocOps, + StringRef sym_name) { + for (auto shimDmaAllocOp : shimDmaAllocOps) + if (shimDmaAllocOp.getSymName() == sym_name) + return shimDmaAllocOp; return std::nullopt; } - std::optional - getObjectFifoCreateOpForSymbol(AIE::DeviceOp dev, StringRef sym_name) { - auto sym = dev.lookupSymbol(sym_name); - if (!sym) - return std::nullopt; - - for (auto objFifoCreateOp : dev.getOps()) { - if (objFifoCreateOp.getSymName().str() == sym_name.str()) - return objFifoCreateOp; - } - + std::optional getObjectFifoCreateOpForSymbol( + SmallVector objectFifoCreateOps, + StringRef sym_name) { + for (auto objectFifoCreateOp : objectFifoCreateOps) + if (objectFifoCreateOp.getSymName().str() == sym_name.str()) + return objectFifoCreateOp; return std::nullopt; } void insertNpuSyncOpForResults(ModuleOp module) { - module.walk([&](mlir::func::FuncOp f) { + SmallVector funcOps; + module.walk([&](mlir::func::FuncOp f) { funcOps.push_back(f); }); + for (auto f : funcOps) { SmallVector dmas; f.walk([&](AIEX::NpuDmaMemcpyNdOp dma) { dmas.push_back(dma); }); auto d = f->getParentOfType(); + + SmallVector shimDmaAllocOps; + if (d) + d.walk([&](AIE::ShimDMAAllocationOp shimDmaAllocOp) { + shimDmaAllocOps.push_back(shimDmaAllocOp); + }); + // Performance optimization: instead of repeating calls to + // getAllocOpForSymbol with the same symbol name, cache the result of the + // first call and use the cache for subsequent calls. This dramatically + // improves compile time for some designs. + llvm::DenseMap> + allocationCache; + auto getAllocOpForSymbolWithCaching = [&](StringRef sym_name) { + auto iter = allocationCache.find(sym_name); + if (iter != allocationCache.end()) { + return iter->second; + } + auto infaOp = getAllocOpForSymbol(shimDmaAllocOps, sym_name); + allocationCache.insert({sym_name, infaOp}); + return infaOp; + }; + if (!d) - return; + continue; + OpBuilder builder(f); for (auto dma : dmas) { - if (auto infoOp = getAllocOpForSymbol(d, dma.getMetadata())) { - if (infoOp->getChannelDir() == AIE::DMAChannelDir::S2MM) { - // Found dma op copying results to host - OpBuilder builder(dma); - auto col = builder.getI32IntegerAttr(infoOp->getCol()); - auto row = builder.getI32IntegerAttr(0); - auto dir = builder.getI32IntegerAttr(0); - auto chan = builder.getI32IntegerAttr(infoOp->getChannelIndex()); - auto col_num = builder.getI32IntegerAttr(1); - auto row_num = builder.getI32IntegerAttr(1); - builder.setInsertionPointAfter(dma); - builder.create(dma->getLoc(), col, row, dir, chan, - col_num, row_num); - } - } + auto infoOp = getAllocOpForSymbolWithCaching(dma.getMetadata()); + if (!infoOp) + continue; + if (infoOp->getChannelDir() != AIE::DMAChannelDir::S2MM) + continue; + // Found dma op copying results to host + auto col = builder.getI32IntegerAttr(infoOp->getCol()); + auto row = builder.getI32IntegerAttr(0); + auto dir = builder.getI32IntegerAttr(0); + auto chan = builder.getI32IntegerAttr(infoOp->getChannelIndex()); + auto col_num = builder.getI32IntegerAttr(1); + auto row_num = builder.getI32IntegerAttr(1); + builder.setInsertionPointAfter(dma); + builder.create(dma->getLoc(), col, row, dir, chan, + col_num, row_num); } // Attempt to make npu.sync ops contiguous if they are not operating on @@ -1189,19 +1204,20 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase { if (auto sync = dyn_cast(op)) previsouSyncs.push_back(sync); else if (auto dma = dyn_cast(op)) { - auto infoOp = getAllocOpForSymbol(d, dma.getMetadata()); - if (infoOp && infoOp->getChannelDir() == AIE::DMAChannelDir::S2MM && - !previsouSyncs.empty()) { + auto infoOp = getAllocOpForSymbolWithCaching(dma.getMetadata()); + if (!infoOp) + return; + if (previsouSyncs.empty()) + return; + if (infoOp->getChannelDir() == AIE::DMAChannelDir::S2MM) { for (auto prevSync : previsouSyncs) prevSync->moveAfter(op); - } else if (infoOp && - infoOp->getChannelDir() == AIE::DMAChannelDir::MM2S && - !previsouSyncs.empty()) { + } else if (infoOp->getChannelDir() == AIE::DMAChannelDir::MM2S) { previsouSyncs.clear(); } } }); - }); + } } // Renumber aiex.npu.dma_memcpy_nd ops per column of AIEs. @@ -1209,34 +1225,65 @@ struct AIRRtToNpuPass : public impl::AIRRtToNpuBase { std::map chanToIdMap; AIE::DeviceOp d = nullptr; blk->walk([&](AIE::DeviceOp op) { d = op; }); + SmallVector shimDmaAllocOps; + if (d) + d.walk([&](AIE::ShimDMAAllocationOp shimDmaAllocOp) { + shimDmaAllocOps.push_back(shimDmaAllocOp); + }); + // Performance optimization: instead of repeating calls to + // getAllocOpForSymbol with the same symbol name, cache the result of the + // first call and use the cache for subsequent calls. This dramatically + // improves compile time for some designs. + llvm::DenseMap> + allocationCache; + auto getAllocOpForSymbolWithCaching = [&](StringRef sym_name) { + auto iter = allocationCache.find(sym_name); + if (iter != allocationCache.end()) { + return iter->second; + } + auto infaOp = getAllocOpForSymbol(shimDmaAllocOps, sym_name); + allocationCache.insert({sym_name, infaOp}); + return infaOp; + }; + SmallVector objectFifoCreateOps; + if (d) + d.walk([&](AIE::ObjectFifoCreateOp objectFifoCreateOp) { + objectFifoCreateOps.push_back(objectFifoCreateOp); + }); + OpBuilder builder(blk->getParentOp()); blk->walk([&](Operation *op) { - if (auto dma = dyn_cast(op)) { - OpBuilder builder(dma); - int col = -1; - if (d) { - if (auto infoOp = getAllocOpForSymbol(d, dma.getMetadata())) { - col = infoOp->getCol(); - } else if (auto objFifoCreateOp = - getObjectFifoCreateOpForSymbol(d, dma.getMetadata())) { - auto prodTileOp = - objFifoCreateOp->getProducerTile().getDefiningOp(); - if (prodTileOp.isShimTile()) - col = prodTileOp.colIndex(); - for (auto consumerTileOp : objFifoCreateOp->getConsumerTiles()) { - auto consTileOp = consumerTileOp.getDefiningOp(); - if (consTileOp.isShimTile()) { - col = consTileOp.colIndex(); - } + auto dma = dyn_cast(op); + auto sync = dyn_cast(op); + if (sync) { + chanToIdMap.clear(); + return; + } + if (!dma) + return; + builder.setInsertionPoint(dma); + int col = -1; + if (d) { + if (auto infoOp = getAllocOpForSymbolWithCaching(dma.getMetadata())) { + col = infoOp->getCol(); + } else if (auto objFifoCreateOp = getObjectFifoCreateOpForSymbol( + objectFifoCreateOps, dma.getMetadata())) { + auto prodTileOp = + objFifoCreateOp->getProducerTile().getDefiningOp(); + if (prodTileOp.isShimTile()) + col = prodTileOp.colIndex(); + for (auto consumerTileOp : objFifoCreateOp->getConsumerTiles()) { + auto consTileOp = consumerTileOp.getDefiningOp(); + if (consTileOp.isShimTile()) { + col = consTileOp.colIndex(); } } } - if (!chanToIdMap.count(col)) - chanToIdMap[col] = 0; - dma->setAttr("id", mlir::IntegerAttr::get( - mlir::IntegerType::get(dma->getContext(), 64), - chanToIdMap[col]++)); - } else if (isa(op)) - chanToIdMap.clear(); + } + if (!chanToIdMap.count(col)) + chanToIdMap[col] = 0; + dma->setAttr("id", mlir::IntegerAttr::get( + mlir::IntegerType::get(dma->getContext(), 64), + chanToIdMap[col]++)); }); } diff --git a/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir b/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir index 4d79d63f6..30c6b9f00 100644 --- a/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir +++ b/mlir/test/Conversion/AIRRtToNpu/airrt_to_npu.mlir @@ -455,10 +455,10 @@ module { // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 64, 0][4, 8, 64, 256][0, 256, 2048]) {id = 1 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 128, 0][4, 8, 64, 256][0, 256, 2048]) {id = 2 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 192, 0][4, 8, 64, 256][0, 256, 2048]) {id = 3 : i64, metadata = @airMemcpyId20} : memref<2048x2048xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 4, 512, 64][64, 1048576, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 4 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 5 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 6 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG1]][0, 0, 0, 0][4, 512, 4, 64][64, 8192, 2048]) {id = 7 : i64, metadata = @airMemcpyId21} : memref<2048x2048xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG2]][0, 0, 0, 0][4, 4, 64, 64][131072, 64, 2048]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2048x2048xi32> #map = affine_map<()[s0] -> (s0 * 64)> @@ -521,9 +521,9 @@ module { // CHECK-LABEL: aie.device(npu) // CHECK: func.func @func10(%[[ARG0:.*]]: memref<2654208xi32>) -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 0 : i64, metadata = @airMemcpyId21} : memref<2654208xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 1 : i64, metadata = @airMemcpyId21} : memref<2654208xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 3, 768, 32][128, 884736, 1152]) {id = 2 : i64, metadata = @airMemcpyId21} : memref<2654208xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 768, 3, 32][128, 3456, 1152]) {id = 0 : i64, metadata = @airMemcpyId21} : memref<2654208xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 768, 3, 32][128, 3456, 1152]) {id = 1 : i64, metadata = @airMemcpyId21} : memref<2654208xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[ARG0]][0, 0, 0, 0][3, 768, 3, 32][128, 3456, 1152]) {id = 2 : i64, metadata = @airMemcpyId21} : memref<2654208xi32> #map = affine_map<()[s0] -> (s0 * 64)> module { @@ -701,8 +701,8 @@ module { // CHECK-SAME: %[[VAL_0:.*]]: memref<262144xi32>, %[[VAL_1:.*]]: memref<262144xi32>, %[[VAL_2:.*]]: memref<131072xi32>) { // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 0][2, 4, 256, 128][0, 128, 512]) {id = 0 : i64, metadata = @airMemcpyId7} : memref<262144xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][2, 4, 256, 128][0, 128, 512]) {id = 1 : i64, metadata = @airMemcpyId7} : memref<262144xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 2, 512, 128][128, 131072, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 2 : i64, metadata = @airMemcpyId12} : memref<262144xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][2, 512, 2, 128][128, 512, 256]) {id = 3 : i64, metadata = @airMemcpyId12} : memref<262144xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][2, 2, 64, 128][65536, 128, 256]) {id = 4 : i64, metadata = @airMemcpyId45} : memref<131072xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 16384][2, 2, 64, 128][65536, 128, 256]) {id = 5 : i64, metadata = @airMemcpyId46} : memref<131072xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 32768][2, 2, 64, 128][65536, 128, 256]) {id = 0 : i64, metadata = @airMemcpyId47} : memref<131072xi32> diff --git a/mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir b/mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir index a6532b892..e85681f73 100644 --- a/mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir +++ b/mlir/test/Conversion/AIRRtToNpu/buffer_memref_to_args.mlir @@ -122,10 +122,10 @@ module { // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 131072][4, 8, 128, 128][0, 128, 1024]) {id = 1 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 262144][4, 8, 128, 128][0, 128, 1024]) {id = 2 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_0]][0, 0, 0, 393216][4, 8, 128, 128][0, 128, 1024]) {id = 3 : i64, metadata = @airMemcpyId10} : memref<2097152xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> -// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 4, 512, 64][64, 524288, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 4 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 5 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 6 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> +// CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_1]][0, 0, 0, 0][4, 512, 4, 64][64, 4096, 1024]) {id = 7 : i64, metadata = @airMemcpyId13} : memref<2097152xi32> // CHECK: aiex.npu.dma_memcpy_nd(0, 0, %[[VAL_2]][0, 0, 0, 0][4, 4, 128, 64][131072, 64, 1024]) {id = 8 : i64, metadata = @airMemcpyId26} : memref<2097152xi32> module {