diff --git a/src/Compiler/CompilerPasses.cpp b/src/Compiler/CompilerPasses.cpp index fc58920ed0..c63aad0b43 100644 --- a/src/Compiler/CompilerPasses.cpp +++ b/src/Compiler/CompilerPasses.cpp @@ -186,13 +186,15 @@ void addONNXToKrnlPasses(mlir::PassManager &pm, int optLevel, bool enableCSE, if (enableInstrumentONNXSignature) pm.addNestedPass( onnx_mlir::createInstrumentONNXSignaturePass()); - pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3, - /*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel, - /*opsToCall*/ opsForCall)); - // An additional pass of canonicalization is helpful because lowering - // from ONNX dialect to Standard dialect exposes additional canonicalization - // opportunities. - pm.addPass(mlir::createCanonicalizerPass()); + for (unsigned i = 0; i < 2; i++) { + pm.addPass(onnx_mlir::createLowerToKrnlPass(/*enableTiling*/ optLevel >= 3, + /*enableSIMD*/ optLevel >= 3 && !disableSimdOption, enableParallel, + /*opsToCall*/ opsForCall)); + // An additional pass of canonicalization is helpful because lowering + // from ONNX dialect to Standard dialect exposes additional canonicalization + // opportunities. 
+ pm.addPass(mlir::createCanonicalizerPass()); + } } void addKrnlToAffinePasses(mlir::PassManager &pm) { diff --git a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp index 1ef6b0467a..c46eecf491 100644 --- a/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp +++ b/src/Conversion/KrnlToAffine/ConvertKrnlToAffine.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/LoopUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Types.h" diff --git a/src/Conversion/ONNXToKrnl/Additional/Parallel.cpp b/src/Conversion/ONNXToKrnl/Additional/Parallel.cpp new file mode 100644 index 0000000000..63f3a038db --- /dev/null +++ b/src/Conversion/ONNXToKrnl/Additional/Parallel.cpp @@ -0,0 +1,292 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-------------------- Parallel.cpp - Lowering Parallel Op and Fork Op +//---------------------===// +// +// Copyright 2019-2023 The IBM Research Authors. +// +// ============================================================================= +// +// This file lowers the ONNX Parallel and Fork Operators to Krnl dialect. +// +//===----------------------------------------------------------------------===// + +#include "src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp" +#include "src/Dialect/Krnl/DialectBuilder.hpp" + +#include + +#define DEBUG_TYPE "lowering-parallelop-to-krnl" + +using namespace mlir; + +namespace onnx_mlir { + +//===----------------------------------------------------------------------===// +// Helper function +// Return true if `a` happens before `b`, i.e., `a` or one of its ancestors +// properly dominates `b` and `b` is not inside `a`. 
+// Reference: llvm-project/mlir/lib/Dialect/Transform/IR/TransformInterfaces.cpp +//===----------------------------------------------------------------------===// + +static bool happensBefore(Operation *a, Operation *b) { + do { + if (a->isProperAncestor(b)) + return false; + if (Operation *bAncestor = a->getBlock()->findAncestorOpInBlock(*b)) { + return a->isBeforeInBlock(bAncestor); + } + } while ((a = a->getParentOp())); + return false; +} + +void moveAllocOpOperands(SmallVector &opsToMove, + SmallVector &globalOpsToMove, Operation *returnValOp, + Operation *parentOp) { + if (opsToMove.size() == 0) + return; + + SmallVector nextOpsToMove; + for (Operation *op : opsToMove) { + LLVM_DEBUG(llvm::dbgs() << "@@START opsToMove op: = " << *op << "\n"); + // Add the op to the list of ops to move if it has not been added yet. + if (llvm::find(globalOpsToMove, op) == globalOpsToMove.end()) { + globalOpsToMove.push_back(op); + LLVM_DEBUG(llvm::dbgs() << "Added in opsToMove : " << *op << "\n"); + } + + Region &parentOpRegion = parentOp->getRegions().front(); + Block &parentOpBlock = parentOpRegion.getBlocks().front(); + + // AllocOp: If the allocated value is used in a KrnlStoreOp, the KrnlStoreOp + // needs to be moved. 
+ if (op != returnValOp) { + if (auto allocOp = dyn_cast(op)) { + for (Operation *user : allocOp.getResult().getUsers()) { + if (auto krnlStoreOp = dyn_cast(user)) { + if (user->getBlock() == &parentOpBlock) { + if ((llvm::find(nextOpsToMove, user) == nextOpsToMove.end()) and + (llvm::find(globalOpsToMove, user) == + globalOpsToMove.end())) { + LLVM_DEBUG(llvm::dbgs() + << "Added in nextOpsTomove (Single KrnlStore) = " + << *user << "\n"); + nextOpsToMove.push_back(user); + } + } else { + if ((llvm::find(nextOpsToMove, user->getParentOp()) == + nextOpsToMove.end()) and + (llvm::find(globalOpsToMove, user->getParentOp()) == + globalOpsToMove.end())) { + LLVM_DEBUG(llvm::dbgs() + << "Added in nextOpsTomove (KrnlIterateOp): " + << *(user->getParentOp()) << "\n"); + nextOpsToMove.push_back(user->getParentOp()); + } + } + } + } + } + } + // KrnlIterateOp: Operations in the region of KrnlIterateOp are already + // added in opsToMove. So, add operands of operations in the region of + // KrnlIterateOp. + if (auto iterateOp = dyn_cast(op)) { + Block &iterationBlock = iterateOp.getBodyRegion().front(); + for (Operation &iop : iterationBlock.getOperations()) { + LLVM_DEBUG(llvm::dbgs() << "Ops in krnlIterateOp: " << *(&iop) << "\n"); + for (unsigned i = 0; i < iop.getNumOperands(); ++i) { + Value oprd = iop.getOperand(i); + if (isa(oprd)) + continue; + Operation *oprdDefOp = oprd.getDefiningOp(); + if (oprdDefOp->getBlock() != &iterationBlock and + oprdDefOp->getBlock() == &parentOpBlock) { + if ((llvm::find(nextOpsToMove, oprdDefOp) == + nextOpsToMove.end()) and + (llvm::find(globalOpsToMove, oprdDefOp) == + globalOpsToMove.end())) { + + LLVM_DEBUG(llvm::dbgs() + << "Added in nextOpsTomove operand in KrnlIterateOp: " + << *oprdDefOp << "\n"); + nextOpsToMove.push_back(oprdDefOp); + } + } + } + } + } + + // Check if operands need to be moved. Need to move if defining op for the + // operand exists in block in parentOp. 
+ for (unsigned i = 0; i < op->getNumOperands(); ++i) { + Value oprd = op->getOperand(i); + if (isa(oprd)) + continue; + Operation *oprdDefOp = oprd.getDefiningOp(); + if (oprdDefOp->getBlock() == &parentOpBlock) { + if ((llvm::find(nextOpsToMove, oprdDefOp) == nextOpsToMove.end()) and + (llvm::find(globalOpsToMove, oprdDefOp) == globalOpsToMove.end())) { + LLVM_DEBUG(llvm::dbgs() << "Added in nextOpsTomove operand for op: " + << i << " = " << *oprdDefOp << "\n"); + nextOpsToMove.push_back(oprdDefOp); + } + } + } + LLVM_DEBUG(llvm::dbgs() << "@@END\n"); + } + // Check if operands of operandDefOp need to be moved recursively + moveAllocOpOperands(nextOpsToMove, globalOpsToMove, returnValOp, parentOp); +} + +LogicalResult moveAllocOpBeforeAndReplaceAllUses( + ConversionPatternRewriter &rewriter, Operation *op, Operation *yieldOp) { + SmallVector globalOpsToMove; + for (unsigned ii = 0; ii < yieldOp->getNumOperands(); ++ii) { + // Check the return value of the block to determine whether the operations + // in the block are already lowered to memref-level IR such as KrnlIR. + // Assume the block is not yet lowered if the return value is still of + // Tensor type. The actual return value of ONNXYieldOp is converted into a + // tensor by unrealized_conversion_cast, so check the operand of the + // preceding operation. 
+ Value returnVal = yieldOp->getOperands()[ii]; + if (isa(returnVal.getDefiningOp())) + returnVal = returnVal.getDefiningOp()->getOperands()[0]; + if (isa(returnVal.getType())) + return failure(); + + // Move allocOps for results before op + Operation *allocOpForReturnVal = returnVal.getDefiningOp(); + SmallVector opsToMove; + opsToMove.push_back(allocOpForReturnVal); + moveAllocOpOperands(opsToMove, globalOpsToMove, allocOpForReturnVal, op); + rewriter.replaceAllUsesWith( + op->getResults()[ii], allocOpForReturnVal->getResult(0)); + } + + llvm::sort(globalOpsToMove, + [](Operation *a, Operation *b) { return !happensBefore(a, b); }); + Operation *justMovedOp = op; + for (Operation *gop : globalOpsToMove) { + gop->moveBefore(justMovedOp); + justMovedOp = gop; + } + return success(); +} + +void insertRegionInIterateOp( + ConversionPatternRewriter &rewriter, Location &loc, Block &block) { + for (auto iterateOp : block.getOps()) { + // Currently KrnlRegionOp does not support input and output to the region. + // So, if KrnlIterateOp has output (by krnl.yield), can't use KrnlRegionOp. 
+ if (iterateOp.getNumResults() == 0) { + KrnlRegionOp regionOp = rewriter.create(loc); + Block ®ionBlock = regionOp.getBodyRegion().front(); + Block &iterateBlock = iterateOp.getRegion().back(); + insertRegionInIterateOp(rewriter, loc, iterateBlock); // recursive call + rewriter.eraseOp(iterateBlock.getTerminator()); + regionBlock.getOperations().splice( + regionBlock.end(), iterateBlock.getOperations()); + rewriter.setInsertionPointToStart(&iterateBlock); + KrnlYieldOp krnlYieldOp = rewriter.create(loc); + rewriter.moveOpBefore(regionOp, krnlYieldOp); + } + } +} + +struct ONNXParallelOpLowering : public OpConversionPattern { + explicit ONNXParallelOpLowering( + TypeConverter &typeConverter, MLIRContext *ctx) + : OpConversionPattern(typeConverter, ctx) {} + + LogicalResult matchAndRewrite(ONNXParallelOp parallelOp, + ONNXParallelOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + Operation *op = parallelOp.getOperation(); + Location loc = ONNXLoc(op); + IndexExprScope ieScope(&rewriter, loc); + MultiDialectBuilder + create(rewriter, loc); + + auto onnxParallelOp = dyn_cast(op); + // Get the parallel region. + Region ¶llelBody = onnxParallelOp.getBody(); + // Make sure the region has only one block. + if (!parallelBody.hasOneBlock()) + return success(); + // Get YieldOp of the body block. 
+ Block &bodyBlock = parallelBody.front(); + Operation *yieldOp = bodyBlock.getTerminator(); + if (!isa(yieldOp)) + return failure(); + + // Move alloc ops included in ForkOps + SmallVector forkOps; + for (Operation &bOp : bodyBlock.getOperations()) { + if (auto forkOp = dyn_cast(bOp)) { + forkOps.push_back(forkOp); + Operation *forkYieldOp = forkOp.getBody().front().getTerminator(); + if (!isa(forkYieldOp)) + return failure(); + + if (failed(moveAllocOpBeforeAndReplaceAllUses( + rewriter, &bOp, forkYieldOp))) + return failure(); + } + } + + // Move allocOp included in ParallelOp + if (failed(moveAllocOpBeforeAndReplaceAllUses(rewriter, op, yieldOp))) + return failure(); + + // Create KrnlIterateOp and replace ParallelOp with it. + rewriter.setInsertionPoint(op); + std::vector loop; + defineLoops(rewriter, loc, loop, 1); + krnl::KrnlIterateOperandPack pack(rewriter, loop); + pack.pushConstantBound(0); + pack.pushConstantBound(forkOps.size()); + KrnlBuilder createKrnl(rewriter, loc); + createKrnl.parallel(loop); + KrnlIterateOp iterateOp = createKrnl.iterate(pack); + Block &iterationBlock = iterateOp.getBodyRegion().back(); + rewriter.setInsertionPointToStart(&iterationBlock); + ValueRange indices = createKrnl.getInductionVarValue({loop[0]}); + rewriter.eraseOp(yieldOp); + rewriter.inlineBlockBefore(&bodyBlock, iterationBlock.getTerminator()); + + // Create SCFIfOp and replace ForkOp with it. + int64_t id = 0; + for (auto forkOp : forkOps) { + rewriter.setInsertionPoint(forkOp); + // Insert scf::IfOp + Value forkId = create.math.constantIndex(id); + Value eq = create.math.eq(forkId, indices[0]); + scf::IfOp ifOp = rewriter.create(loc, eq, /*else=*/false); + Block &ifBlock = ifOp.getThenRegion().back(); + rewriter.setInsertionPointToStart(&ifBlock); + // Insert KrnlRegionOp in every KrnlIterateOps. This needs to avoid + // errors in convertKrnlToAffinePass. 
+ Block &forkBlock = forkOp.getRegion().back(); + insertRegionInIterateOp(rewriter, loc, forkBlock); + Operation *forkYieldOp = forkBlock.getTerminator(); + rewriter.eraseOp(forkYieldOp); + rewriter.inlineBlockBefore(&forkBlock, ifBlock.getTerminator()); + rewriter.eraseOp(forkOp); + id++; + } + + rewriter.eraseOp(op); + return success(); + } +}; + +void populateLoweringONNXParallelOpPattern(RewritePatternSet &patterns, + TypeConverter &typeConverter, MLIRContext *ctx) { + patterns.insert(typeConverter, ctx); +} + +} // namespace onnx_mlir diff --git a/src/Conversion/ONNXToKrnl/CMakeLists.txt b/src/Conversion/ONNXToKrnl/CMakeLists.txt index f5faedf2a5..6cca5dcc8b 100644 --- a/src/Conversion/ONNXToKrnl/CMakeLists.txt +++ b/src/Conversion/ONNXToKrnl/CMakeLists.txt @@ -8,6 +8,7 @@ add_onnx_mlir_library(OMONNXToKrnl Additional/Custom.cpp Additional/LayoutTransform.cpp Additional/ShapeTransform.cpp + Additional/Parallel.cpp ControlFlow/If.cpp ControlFlow/Loop.cpp ControlFlow/Scan.cpp diff --git a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp index 3a450c4e65..6c52626bad 100644 --- a/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp +++ b/src/Conversion/ONNXToKrnl/ConvertONNXToKrnl.cpp @@ -286,6 +286,7 @@ void populateONNXToKrnlConversionPattern(RewritePatternSet &patterns, populateLoweringONNXCustomOpPattern(patterns, typeConverter, ctx); populateLoweringONNXLayoutTransformOpPattern(patterns, typeConverter, ctx, enableParallel); populateLoweringONNXShapeTransformOpPattern(patterns, typeConverter, ctx); + populateLoweringONNXParallelOpPattern(patterns, typeConverter, ctx); // clang-format on } diff --git a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp index 82e10ba141..8c2f0ad8e1 100644 --- a/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp +++ b/src/Conversion/ONNXToKrnl/ONNXToKrnlCommon.hpp @@ -24,6 +24,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include 
"mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -470,6 +471,9 @@ void populateLoweringONNXShapeTransformOpPattern( void populateLoweringONNXCustomOpPattern( mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); +void populateLoweringONNXParallelOpPattern( + mlir::RewritePatternSet &, mlir::TypeConverter &, mlir::MLIRContext *); + // Utilities for generating krnl.call for ONNX Ops // Create allocate based on COMPUTED shapeHelper. diff --git a/src/Dialect/ONNX/AdditionalONNXOps.td b/src/Dialect/ONNX/AdditionalONNXOps.td index a72af4d7c2..608801a558 100644 --- a/src/Dialect/ONNX/AdditionalONNXOps.td +++ b/src/Dialect/ONNX/AdditionalONNXOps.td @@ -620,3 +620,95 @@ def ONNXRMSLayerNormalizationOp:ONNX_Op<"RMSLayerNormalization", }]; let hasVerifier = 1; } + +//===----------------------------------------------------------------------===// +// ONNXForkOp +def ONNXForkOp:ONNX_Op<"Fork", + [Pure, HasParent<"ONNXParallelOp">, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + OpInterface<"mlir::HasOnnxSubgraphOpInterface"> +]> { + let summary = "ONNX operation for forking threads."; + let description = [{ + + }]; + let arguments = (ins Variadic, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[UI32]>, + TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>]>>:$inputs, + DefaultValuedAttr:$id); + let results = (outs Variadic, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[UI32]>, + TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>]>>:$results); + let regions = (region SizedRegion<1>:$body); + let skipDefaultBuilders = 1; + let extraClassDeclaration = [{ + int64_t getSubgraphRegionIdx(const std::string& name) { + if (name == "body") return 0; + llvm_unreachable("region with the specified name does not exist"); + } + using 
BodyBuilderFn = + llvm::function_ref; + }]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ONNXForkOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new ONNXForkOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; +} + }]; + let builders = [ + OpBuilder<(ins "mlir::TypeRange":$resultTypes, + "mlir::ValueRange":$operands, + CArg<"llvm::function_ref", + "nullptr">:$bodyBuilder)> + ]; +} + +//===----------------------------------------------------------------------===// +// ONNXParallelOp +def ONNXParallelOp:ONNX_Op<"Parallel", + [Pure, DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + DeclareOpInterfaceMethods, + OpInterface<"mlir::HasOnnxSubgraphOpInterface">]> { + let summary = "ONNX operation to specify paralell region."; + let description = [{ + + }]; + let arguments = (ins Variadic, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[UI32]>, + TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>]>>:$inputs); + let results = (outs Variadic, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[UI32]>, + TensorOf<[UI64]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>]>>:$results); + let regions = (region SizedRegion<1>:$body); + + let skipDefaultBuilders = 1; + let extraClassDeclaration = [{ + static int getNumberOfOperands() { + return -1; + } + static int getNumberOfResults() { + return -1; + } + int64_t getSubgraphRegionIdx(const std::string& name) { + if (name == "body") return 0; + llvm_unreachable("region with the specified name does not exist"); + } + using BodyBuilderFn = + llvm::function_ref; + }]; + let extraClassDefinition = [{ + onnx_mlir::ONNXOpShapeHelper * ONNXParallelOp::getShapeHelper(mlir::Operation *op, mlir::ArrayRef oper, + onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) { + onnx_mlir::ONNXOpShapeHelper *sh = new 
ONNXParallelOpShapeHelper(op, oper, ieb, scope); + assert(sh && "failed to allocate shape helper"); + return sh; + } + }]; + let builders = [ + OpBuilder<(ins "mlir::TypeRange":$resultTypes, + "mlir::ValueRange":$operands, + CArg<"llvm::function_ref", + "nullptr">:$bodyBuilder)> + ]; +} diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt index 3917c94aa4..f15506fa1d 100644 --- a/src/Dialect/ONNX/CMakeLists.txt +++ b/src/Dialect/ONNX/CMakeLists.txt @@ -31,9 +31,11 @@ add_onnx_mlir_library(OMONNXOps ONNXOps/Additional/Custom.cpp ONNXOps/Additional/Dim.cpp ONNXOps/Additional/EntryPoint.cpp + ONNXOps/Additional/Fork.cpp ONNXOps/Additional/Return.cpp ONNXOps/Additional/LayoutTransform.cpp ONNXOps/Additional/None.cpp + ONNXOps/Additional/Parallel.cpp ONNXOps/Additional/ShapeTransform.cpp ONNXOps/ControlFlow/If.cpp ONNXOps/ControlFlow/Loop.cpp diff --git a/src/Dialect/ONNX/ONNXDimAnalysis.cpp b/src/Dialect/ONNX/ONNXDimAnalysis.cpp index c1b06859d8..1c31999beb 100644 --- a/src/Dialect/ONNX/ONNXDimAnalysis.cpp +++ b/src/Dialect/ONNX/ONNXDimAnalysis.cpp @@ -875,6 +875,24 @@ void DimAnalysis::visitDim( return; } + // ParallelOp or ForkOp + // The result tensors are the same with the operand tensors for yieldOp in the + // region. + if (isa(op) || isa(op)) { + for (unsigned i = 0; i < op->getNumResults(); ++i) { + if (tensor == op->getResults()[i]) { + Operation *yieldOp = + op->getRegions().front().getBlocks().front().getTerminator(); + DimAnalysis::DimT newSameDim(yieldOp->getOperands()[i], dimIndex); + sameDims.insert(newSameDim); + LLVM_DEBUG(llvm::dbgs() + << " - Added a new dim(" << yieldOp->getOperands()[i] + << ", " << dimIndex << ")\n"); + } + } + return; + } + // All dimensions in the analysis must be dynamic. If not, something really // wrong happened. 
ShapedType ty = tensor.getType().cast(); diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Fork.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Fork.cpp new file mode 100644 index 0000000000..45486cf442 --- /dev/null +++ b/src/Dialect/ONNX/ONNXOps/Additional/Fork.cpp @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===---------------- Fork.cpp - ONNX Operations -------------------------===// +// +// Copyright 2019-2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file provides definition of ONNX dialect Fork operation. +// +//===----------------------------------------------------------------------===// + +#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" + +using namespace mlir; +using namespace onnx_mlir; + +//===----------------------------------------------------------------------===// +// ShapeHelper +//===----------------------------------------------------------------------===// + +template <> +LogicalResult ONNXForkOpShapeHelper::computeShape() { + ONNXForkOp forkOp = llvm::cast(op); + (void)forkOp.inferShapes([](Region ®ion) {}); + Operation *yieldOp = forkOp.getBody().front().getTerminator(); + for (unsigned i = 0; i < yieldOp->getNumOperands(); ++i) { + DimsExpr outputDims; + Value returnVal = yieldOp->getOperands()[i]; + int64_t outRank = returnVal.getType().cast().getRank(); + for (int64_t j = 0; j < outRank; ++j) + outputDims.emplace_back(createIE->getShapeAsDim(returnVal, j)); + setOutputDims(outputDims, i); + } + return success(); +} + +//===----------------------------------------------------------------------===// +// Type Inference +//===----------------------------------------------------------------------===// + +std::vector ONNXForkOp::resultTypeInference() { + Operation *terminator = getRegion().back().getTerminator(); + auto bodyOutputTys = terminator->getOperandTypes(); + + // // assert is checked in verify() + // assert(getNumResults() == 
thenResultTypes.size() && + // getNumResults() == elseResultTypes.size() && + // "if #results and branches #results differ"); + std::vector resultTypes; + for (auto [i, ty] : llvm::enumerate(bodyOutputTys)) { + resultTypes.push_back(ty); + } + return resultTypes; +} + +//===----------------------------------------------------------------------===// +// Shape Inference +//===----------------------------------------------------------------------===// + +LogicalResult ONNXForkOp::inferShapes( + std::function doShapeInference) { + doShapeInference(getRegion()); + for (auto [i, ty] : llvm::enumerate(resultTypeInference())) + getResult(i).setType(ty); + return success(); +} + +//===----------------------------------------------------------------------===// +// Builder: Refer to Async ExecuteOp +//===----------------------------------------------------------------------===// +void ONNXForkOp::build(OpBuilder &builder, OperationState &result, + TypeRange resultTypes, ValueRange operands, BodyBuilderFn bodyBuilder) { + + result.addOperands(operands); + result.addTypes(resultTypes); + + // Add a body region with block arguments + Region *bodyRegion = result.addRegion(); + bodyRegion->push_back(new Block); + Block &bodyBlock = bodyRegion->front(); + for (Value operand : operands) { + bodyBlock.addArgument(operand.getType(), operand.getLoc()); + } + + // Create the default terminator if the builder is not provided and if the + // expected result is empty. Otherwise, leave this to the caller + // because we don't know which values to return from the execute op. 
+ if (resultTypes.empty() && !bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&bodyBlock); + builder.create(result.location, ValueRange()); + } else if (bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&bodyBlock); + bodyBuilder(builder, result.location, bodyBlock.getArguments()); + } +} diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Parallel.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Parallel.cpp new file mode 100644 index 0000000000..6ae36ff08b --- /dev/null +++ b/src/Dialect/ONNX/ONNXOps/Additional/Parallel.cpp @@ -0,0 +1,100 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + */ + +//===-------------- Parallel.cpp - ONNX Operations -----------------------===// +// +// Copyright 2019-2024 The IBM Research Authors. +// +// ============================================================================= +// +// This file provides the definition of the ONNX dialect Parallel operation. +// +//===----------------------------------------------------------------------===// + +#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp" + +using namespace mlir; +using namespace onnx_mlir; + +//===----------------------------------------------------------------------===// +// ShapeHelper +//===----------------------------------------------------------------------===// + +template <> +LogicalResult ONNXParallelOpShapeHelper::computeShape() { + ONNXParallelOp parallelOp = llvm::cast(op); + (void)parallelOp.inferShapes([](Region &region) {}); + Operation *yieldOp = parallelOp.getBody().front().getTerminator(); + for (unsigned i = 0; i < yieldOp->getNumOperands(); ++i) { + DimsExpr outputDims; + Value returnVal = yieldOp->getOperands()[i]; + int64_t outRank = returnVal.getType().cast().getRank(); + for (int64_t j = 0; j < outRank; ++j) + outputDims.emplace_back(createIE->getShapeAsDim(returnVal, j)); + setOutputDims(outputDims, i); + } + return success(); +} + 
+//===----------------------------------------------------------------------===// +// Type Inference +//===----------------------------------------------------------------------===// + +std::vector ONNXParallelOp::resultTypeInference() { + Operation *terminator = getRegion().back().getTerminator(); + auto bodyOutputTys = terminator->getOperandTypes(); + + // // assert is checked in verify() + // assert(getNumResults() == thenResultTypes.size() && + // getNumResults() == elseResultTypes.size() && + // "if #results and branches #results differ"); + std::vector resultTypes; + for (auto [i, ty] : llvm::enumerate(bodyOutputTys)) { + resultTypes.push_back(ty); + } + return resultTypes; +} + +//===----------------------------------------------------------------------===// +// Shape Inference +//===----------------------------------------------------------------------===// + +LogicalResult ONNXParallelOp::inferShapes( + std::function doShapeInference) { + doShapeInference(getRegion()); + for (auto [i, ty] : llvm::enumerate(resultTypeInference())) + getResult(i).setType(ty); + return success(); +} + +//===----------------------------------------------------------------------===// +// Builder: Refer to Async ExecuteOp +//===----------------------------------------------------------------------===// +void ONNXParallelOp::build(OpBuilder &builder, OperationState &result, + TypeRange resultTypes, ValueRange operands, BodyBuilderFn bodyBuilder) { + + result.addOperands(operands); + result.addTypes(resultTypes); + + // Add a body region with block arguments + Region *bodyRegion = result.addRegion(); + bodyRegion->push_back(new Block); + Block &bodyBlock = bodyRegion->front(); + for (Value operand : operands) { + bodyBlock.addArgument(operand.getType(), operand.getLoc()); + } + + // Create the default terminator if the builder is not provided and if the + // expected result is empty. 
Otherwise, leave this to the caller + // because we don't know which values to return from the execute op. + if (resultTypes.empty() && !bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&bodyBlock); + builder.create(result.location, ValueRange()); + } else if (bodyBuilder) { + OpBuilder::InsertionGuard guard(builder); + builder.setInsertionPointToStart(&bodyBlock); + bodyBuilder(builder, result.location, bodyBlock.getArguments()); + } +} diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp index a58f9572f7..aefd3b272f 100644 --- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp +++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp @@ -885,6 +885,8 @@ using ONNXTileOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXTopKOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXTransposeOpShapeHelper = ONNXNonSpecificOpShapeHelper; using ONNXUpsampleOpShapeHelper = ONNXNonSpecificOpShapeHelper; +using ONNXForkOpShapeHelper = ONNXNonSpecificOpShapeHelper; +using ONNXParallelOpShapeHelper = ONNXNonSpecificOpShapeHelper; // clang-format on //===----------------------------------------------------------------------===// diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir index 07889a0abb..94e0fef19a 100644 --- a/test/mlir/onnx/onnx_shape_inference.mlir +++ b/test/mlir/onnx/onnx_shape_inference.mlir @@ -3801,3 +3801,47 @@ func.func @test_RMSlayer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5 // CHECK: } } +// ----- + +//===----------------------------------------------------------------------===// +/// Test shape inference for Parallel and Fork. 
+//===----------------------------------------------------------------------===// + +func.func @test_parallel_fork_1(%arg0: tensor<8x64x32xf32>, %arg1: tensor<32x32xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %c0 = onnx.Constant dense<1.0> : tensor<32x32xf32> + %c1 = onnx.Constant dense<1.0> : tensor<32xf32> + %c2 = onnx.Constant dense<1.0> : tensor<32x32xf32> + + %0:2 = "onnx.Parallel"() ({ + %00 = "onnx.Fork"() ({ + %01 = "onnx.MatMul"(%arg0, %c0) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<*xf32> + onnx.Yield %01 : tensor<*xf32> + }) {id = 0 : si64} : () -> tensor<*xf32> + %01 = "onnx.Fork"() ({ + %01 = "onnx.MatMul"(%arg0, %c2) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<*xf32> + onnx.Yield %01 : tensor<*xf32> + }) {id = 1 : si64} : () -> tensor<*xf32> + "onnx.Yield"(%00, %01) : (tensor<*xf32>, tensor<*xf32>) -> () + }) : () -> (tensor<*xf32>, tensor<*xf32>) + "onnx.Return"(%0#0,%0#1): (tensor<*xf32>, tensor<*xf32>) -> () + +// CHECK-LABEL: func.func @test_parallel_fork_1 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<8x64x32xf32>, [[PARAM_1_:%.+]]: tensor<32x32xf32>) -> (tensor<8x64x32xf32>, tensor<8x64x32xf32>) { +// CHECK-DAG: [[VAR_0_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<32x32xf32> +// CHECK-DAG: [[VAR_1_:%.+]]:2 = "onnx.Parallel"() ({ +// CHECK-DAG: [[VAR_2_:%.+]] = "onnx.Fork"() ({ +// CHECK: [[VAR_4_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<8x64x32xf32> +// CHECK: onnx.Yield [[VAR_4_]] : tensor<8x64x32xf32> +// CHECK: }) {id = 0 : si64} : () -> tensor<8x64x32xf32> +// CHECK-DAG: [[VAR_3_:%.+]] = "onnx.Fork"() ({ +// CHECK-DAG: [[VAR_4_1_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<8x64x32xf32> +// CHECK: onnx.Yield [[VAR_4_1_]] : tensor<8x64x32xf32> +// CHECK: }) {id = 1 : si64} : () -> tensor<8x64x32xf32> +// CHECK: onnx.Yield [[VAR_2_]], [[VAR_3_]] : tensor<8x64x32xf32>, tensor<8x64x32xf32> +// CHECK: }) : () 
-> (tensor<8x64x32xf32>, tensor<8x64x32xf32>) +// CHECK: onnx.Return [[VAR_1_]]#0, [[VAR_1_]]#1 : tensor<8x64x32xf32>, tensor<8x64x32xf32> +// CHECK: } +} + +// ----- + diff --git a/utils/OpLevelParallel.py b/utils/OpLevelParallel.py new file mode 100755 index 0000000000..2aa17cf0b8 --- /dev/null +++ b/utils/OpLevelParallel.py @@ -0,0 +1,664 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: Apache-2.0 + +################# OpLevelParallel.py ########################################### +# +# Copyright 2019-2024 The IBM Research Authors. +# +################################################################################ +# +# This code analyzes the profile data and the compiled MLIR code, then +# (1) Detects sets of operations to be parallelized by fork and join, and +# (2) Generates parallelized MLIR code +# +################################################################################ + +import os +import sys +import re +import networkx as nx +import argparse + + +num_in_str = lambda str: [int(c) if c.isdigit() else c for c in re.split(r"(\d+)", str)] + +MAX_NODE_NUM_IN_BLOCK = 100 +MIN_PARALLEL_NUM = 2 +MIN_EXECUTION_TIME_IN_BLOCK = -1.0 +KEY_OPERATIONS = ["onnx.Conv", "onnx.MatMul", "onnx.LSTM"] +NO_PARENT_OPERATIONS = ["onnx.Constant", "onnx.NoValue"] +INST_OPERATIONS = [ + "onnx.Conv", + "onnx.MatMul", + "onnx.LSTM", + "onnx.Softplus", + "onnx.Erf", + "onnx.Add", + "onnx.Div", + "onnx.Mul", +] + +LINE_ONNXPARALLEL = '{} = "onnx.Parallel"() ({{' # endvar +LINE_ONNXPARALLELYIELD = "onnx.Yield {} : {}" # endvar, endvars_shape +LINE_ONNXPARALLELEND = "}}) : () -> ({})" # endvar_shape +LINE_ONNXYIELD = '"onnx.Yield"({}):({}) -> ()' # endvars, endvars_shape +LINE_ONNXFORK = '{} = "onnx.Fork"() ({{' # forkvar +LINE_ONNXFORKEND = "}}) {{id={}:si64}}:() -> {}" # th_id, endvar_shape + +LINE_INSTRUMENT = '"krnl.runtime_instrument"() {{nodeName = "{}", opName = "{}", tag = {} : i64 }}: () -> ()' # nodeName, opName, tag +INSTRUMENT_BEFORE_OP = 0x1 +INSTRUMENT_AFTER_OP = 
NODENAME_STR = "onnx_node_name = "
NONEOP_STR = "NONE"


def parse_line_in_model(line):
    """Parse one line of an MLIR model file.

    Returns a tuple ``(outvar, invars, operation, nodename, opshape)``:
      * outvar    - first result name on the left of ``=`` (e.g. ``%0``)
      * invars    - list of operand names found between the first ``(`` and
                    the following ``)`` after ``=``, with spaces removed
      * operation - operation name, taken from the quoted generic form
                    (``"onnx.MatMul"``) or from the first bare token
      * nodename  - value of the ``onnx_node_name`` attribute when present,
                    otherwise NONEOP_STR
      * opshape   - the text after the last ``:`` on the line, spaces removed

    Lines that do not define a ``%`` value (blank lines, ``module``,
    ``func.func``, ``}``, lines without ``=`` or whose left-hand side does not
    start with ``%``) yield ``("", [], "", "", "")``.
    """
    not_an_op = ("", [], "", "", "")
    first_word = line.strip().split(" ")[0]
    if first_word in ("module", "func.func", "}"):
        return not_an_op
    eq_pos = line.find("=")
    if eq_pos < 0:
        # Without '=' the line cannot define a value; this also covers blank
        # lines, which previously raised IndexError.
        return not_an_op
    results = [name.strip() for name in line[:eq_pos].split(",")]
    if not all(name.startswith("%") for name in results):
        return not_an_op
    # TODO: only the first result is tracked; second and further results are
    # ignored. That is fine for ONNX IR but not for other dialects (e.g. Krnl).
    outvar = results[0]
    rhs = line[eq_pos + 1 :]

    # Operation name: quoted in generic form, first bare token otherwise.
    head = rhs.strip()
    if head.startswith('"'):
        operation = head[1:].partition('"')[0].strip().split(" ")[0]
    elif head:
        operation = head.split(None, 1)[0].split("(", 1)[0]
    else:
        operation = ""

    # Operands: text between the first '(' and the next ')'.
    operands = rhs[rhs.find("(") + 1 :]
    close_pos = operands.find(")")
    operands = operands[:close_pos].replace(" ", "") if close_pos >= 0 else ""
    invars = operands.split(",") if operands else []

    # Node name: the quoted value following NODENAME_STR.
    nodename = NONEOP_STR
    name_pos = line.find(NODENAME_STR)
    if name_pos >= 0:
        start = name_pos + len(NODENAME_STR) + 1  # skip the opening quote
        end = start + line[start:].find('"')
        nodename = line[start:end].strip()

    # Shape/type text: everything after the last ':' with spaces removed.
    opshape = line[line.rfind(":") + 1 :].replace(" ", "").strip()
    return (outvar, invars, operation, nodename, opshape)
def get_node_str(
    node,
    model_graph,
    get_node=False,
    get_nodename=True,
    get_operation=True,
    get_opshape=True,
    key_opshape_only=True,
    get_time=False,
):
    """Build a ':'-separated description of `node` from its graph attributes.

    The enabled fields are emitted in this fixed order: the node itself, its
    "nodename", its "operation", its "opshape" and its "time" (scaled by 1000
    and formatted with three decimals, followed by a ':'). Attributes missing
    on the node are rendered as empty strings (0.0 for the time). With
    key_opshape_only set, the shape is included only when the operation is
    listed in KEY_OPERATIONS.
    """
    attrs = model_graph.nodes[node]

    def field(key, default):
        return attrs[key] if key in attrs else default

    name = field("nodename", "")
    operation = field("operation", "")
    shape = field("opshape", "")
    elapsed = float(field("time", 0.0))

    pieces = []
    if get_node:
        pieces.append(node)
    if get_nodename:
        pieces.append(name)
    if get_operation:
        pieces.append(operation)
    if get_opshape and not (key_opshape_only and operation not in KEY_OPERATIONS):
        pieces.append(shape)
    if get_time:
        pieces.append("{:.3f}:".format(elapsed * 1000))
    return ":".join(pieces)
def get_no_input_node_set(node, model_graph):
    """Return the set of input-less nodes that feed the predecessors of `node`.

    For every predecessor of `node`, each of that predecessor's own
    predecessors is kept when it has no predecessors itself.

    NOTE(review): predecessors of `node` that are themselves input-less are
    not collected here, only those one hop further up -- confirm this is the
    intended set.
    """
    return {
        grand_pred
        for pred in list(model_graph.predecessors(node))
        for grand_pred in list(model_graph.predecessors(pred))
        if not list(model_graph.predecessors(grand_pred))
    }
def node_in_candidate_dict(node, candidate_dict):
    """Return True if `node` is in the subgraph of any recorded block.

    `candidate_dict` maps a parent output variable to a list of blocks, each a
    ``(subgraph, no_input_nodes, time)`` tuple; only the subgraph part is
    searched. Always returns a bool (the not-found case used to fall off the
    end of the function and return None).
    """
    return any(
        node in subgraph
        for blocks in candidate_dict.values()
        for subgraph, _, _ in blocks
    )
def print_key_operations(key_operations, model_graph):
    """Print the key operations on a single line.

    Output format: ``KEYOPERATIONS: <count> [ <node>, <node>, ... ]`` where
    each node is rendered with get_node_str().
    """
    # `end=""` must be an argument of print(); it was previously passed to
    # str.format(), which silently ignores unused keyword arguments, so the
    # header was followed by an unintended newline.
    print("KEYOPERATIONS: {} [ ".format(len(key_operations)), end="")
    print(
        ", ".join(get_node_str(node, model_graph) for node in key_operations),
        end="",
    )
    print(" ]")
def get_var_shape(var, outvar_lineno_dict, lines_list):
    """Return the result-type text of the line that defines `var`.

    `outvar_lineno_dict` maps an output variable to its index in `lines_list`.
    The type text produced by parse_line_in_model() is expected to look like
    ``(<operand types>)-><result type>``; the part after ``->`` is returned.
    When there is no ``->`` the whole type text is returned unchanged
    (previously str.find() returning -1 caused the first character to be
    dropped).
    """
    line = lines_list[outvar_lineno_dict[var]]
    _, _, _, _, opshape = parse_line_in_model(line)
    arrow_pos = opshape.find("->")
    if arrow_pos < 0:
        return opshape
    return opshape[arrow_pos + 2 :]
def generate_paracode(candidates, model, output, dummy=False):
    """Write `model` to `output`, wrapping candidate blocks in parallel ops.

    candidates: dict mapping a parent output variable to its list of blocks;
                pass an empty dict to reproduce the input unchanged.
    model:      path of the input MLIR file.
    output:     path of the file to write.
    dummy:      forwarded to generate_paracode_for_candidate().

    When args.instrument is set, a "Total" start instrument is emitted just
    before the first line that defines a value, and a "Total" finish
    instrument just before the last ``onnx.Return`` / ``return`` line.
    """
    # First pass: remember every line, the variable it defines, and the index
    # of the last return line.
    lines_list = []
    outvar_lineno_dict = {}
    outvar_of_line = []
    # None (not 0) when no return line exists, so the finish instrument is
    # not emitted in front of the first line by accident.
    last_return_lineno = None
    with open(model) as f:
        for lineno, line in enumerate(f):
            lines_list.append(line)
            outvar, _, _, _, _ = parse_line_in_model(line)
            outvar_lineno_dict[outvar] = lineno
            outvar_of_line.append(outvar)
            if line.strip().split(" ")[0] in ("onnx.Return", "return"):
                last_return_lineno = lineno

    # Second pass: emit lines, expanding candidate parents in place.
    with open(output, "w") as f:
        processed_lineno_dict = {}
        total_instrumentation_started = False
        for lineno, line in enumerate(lines_list):
            outvar = outvar_of_line[lineno]
            if args.instrument and outvar and not total_instrumentation_started:
                tag = INSTRUMENT_INIT | INSTRUMENT_BEFORE_OP | INSTRUMENT_REPORT_TIME
                print_line(4, LINE_INSTRUMENT.format("Total", "Total", tag), f)
                total_instrumentation_started = True
            if args.instrument and lineno == last_return_lineno:
                tag = INSTRUMENT_AFTER_OP | INSTRUMENT_REPORT_TIME
                print_line(4, LINE_INSTRUMENT.format("Total", "Total", tag), f)
            if outvar in candidates:
                generate_paracode_for_candidate(
                    outvar,
                    lineno,
                    candidates[outvar],
                    lines_list,
                    outvar_lineno_dict,
                    processed_lineno_dict,
                    f,
                    dummy,
                )
            elif lineno not in processed_lineno_dict:
                # Lines already emitted as part of a candidate are skipped.
                print_input_line(0, line, f)
                processed_lineno_dict[lineno] = True
def main():
    """Run the analysis and produce the outputs selected on the command line.

    Steps: optionally load the profile, build the model graph, detect the
    parallelization candidates, print the requested reports, then generate
    the dummy-parallel, original and parallel code files, in that order, for
    each output path that was given.
    """
    profile_dict = read_profile(args.profile)[0] if args.profile else {}
    model_graph, key_operations = generate_model_graph(args.model, profile_dict)
    if args.print_model_graph:
        print_graph(model_graph)
    candidates = get_candidates(model_graph)

    if args.profile:
        print("PROFILE: {}".format(args.profile))
    if args.print_key_operations:
        print_key_operations(key_operations, model_graph)
    if args.print_candidates:
        print_candidates(candidates, model_graph)

    # (output path, candidates to apply, dummy flag), in emission order.
    outputs = (
        (args.generate_dummyparacode, candidates, True),
        (args.generate_originalcode, {}, False),
        (args.generate_paracode, candidates, False),
    )
    for out_path, selected, is_dummy in outputs:
        if out_path:
            generate_paracode(selected, args.model, out_path, dummy=is_dummy)


if __name__ == "__main__":
    main()
output_path = os.path.join(temp_dir, "model") shared_lib_path = os.path.join(temp_dir, "model.so") @@ -648,10 +665,58 @@ def main(): "the shapes of the model's inputs will be " "changed to the shapes of the inputs in the data folder" ) + + if args.oplevel_parallel: + print("Parallelizing the model at operation-level...") + # Generate ONNXIR file in the temp_dir from the input file + base_output_path = ( + args.keep_oplevel_parallel_code + if args.keep_oplevel_parallel_code + else os.path.join(temp_dir, "model") + ) + onnxir_path = base_output_path + "-onnxir.mlir" + onnxir_short_path = base_output_path + "-onnxir-short.mlir" + prepare_command_str = command_str + [ + "--EmitONNXIR", + args.model, + "-o", + base_output_path, + ] + ok, msg = execute_commands(prepare_command_str) + # Rename the generated ONNXIR files + ok, msg = execute_commands( + ["/bin/mv", "-f", base_output_path + ".onnx.mlir", onnxir_path] + ) + ok, msg = execute_commands( + ["/bin/mv", "-f", base_output_path + ".tmp", onnxir_short_path] + ) + # Generate paralelized IR file from the ONNXIR file + parataskir_path = base_output_path + "-onnxir-oppara.mlir" + paratask_command_str = [ + os.path.dirname(__file__) + "/OpLevelParallel.py", + "-m", + onnxir_path, + "--generate-paracode", + parataskir_path, + ] + if args.oplevel_parallel_report: + paratask_command_str.append("--print-candidates") + ok, msg = execute_commands(paratask_command_str) + if args.output_message or args.oplevel_parallel_report: + print(msg) + if not ok: + print(msg) + exit(1) + # Set input_model_path as the parallelized IR file + input_model_path = parataskir_path + # Add "--parallel" option to onnx-mlir + command_str += ["--parallel"] + command_str += [input_model_path] command_str += ["-o", output_path] # Compile the model. + print("Compiling the model ...") start = time.perf_counter() ok, msg = execute_commands(command_str) # Dump the compilation log into a file. 
diff --git a/utils/RunONNXModelZoo.py b/utils/RunONNXModelZoo.py index 91fab6a861..1d44aa9ba5 100755 --- a/utils/RunONNXModelZoo.py +++ b/utils/RunONNXModelZoo.py @@ -163,6 +163,31 @@ def get_args(): default=os.getcwd(), help="Work dir for cloning and downloading, default cwd.", ) + parser.add_argument( + "-n", + "--n-iteration", + type=int, + default=1, + help="The number of inference runs excluding warmup", + ) + parser.add_argument("--output-message", action="store_true", help="Output message") + parser.add_argument( + "-P", + "--oplevel-parallel", + action="store_true", + help="Enable operation level parallelization", + ) + parser.add_argument( + "--oplevel-parallel-report", + action="store_true", + help="Report operation level parallelization status", + ) + parser.add_argument( + "--keep-oplevel-parallel-code", + type=str, + default="", + help="Keep generated oplevel-parallel code at the specified directory", + ) return parser.parse_args() @@ -388,10 +413,27 @@ def check_model(model_path, model_name, compile_args, report_dir): if args.compile_only: options += ["--compile-only"] options += ["--model={}".format(onnx_file)] + if args.n_iteration > 1: + options += ["--n-iteration={}".format(args.n_iteration)] + if args.output_message: + options += ["--output-message"] if args.log_to_file: options += ["--log-to-file={}".format(args.log_to_file)] + if args.oplevel_parallel: + options += ["--oplevel-parallel"] + if args.oplevel_parallel_report: + options += ["--oplevel-parallel-report"] + if args.keep_oplevel_parallel_code: + options += [ + "--keep-oplevel-parallel-code={}".format( + args.keep_oplevel_parallel_code + ) + ] + # Wait up to 30 minutes for compilation and inference to finish ok, msg = execute_commands(RUN_ONNX_MODEL_CMD + options, tmout=1800) + if args.output_message or args.oplevel_parallel_report: + print(msg) state = TEST_PASSED if ok else TEST_FAILED logger.info("[{}] check {}".format(model_name, "passed" if ok else "failed")) logger.debug("[{}] 
{}".format(model_name, msg))