diff --git a/src/Dialect/ONNX/AdditionalONNXOps.td b/src/Dialect/ONNX/AdditionalONNXOps.td
index a72af4d7c2..7dd5afd583 100644
--- a/src/Dialect/ONNX/AdditionalONNXOps.td
+++ b/src/Dialect/ONNX/AdditionalONNXOps.td
@@ -620,3 +620,113 @@ def ONNXRMSLayerNormalizationOp:ONNX_Op<"RMSLayerNormalization",
   }];
   let hasVerifier = 1;
 }
+
+//===----------------------------------------------------------------------===//
+// ONNXForkOp
+def ONNXForkOp:ONNX_Op<"Fork",
+    [Pure, HasParent<"ONNXParallelOp">,
+     DeclareOpInterfaceMethods<ShapeHelperOpInterface>,
+     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+     DeclareOpInterfaceMethods<ResultTypeInferenceOpInterface>,
+     OpInterface<"mlir::HasOnnxSubgraphOpInterface">]> {
+  let summary = "ONNX operation for creating a thread";
+  let description = [{
+    The "onnx.Fork" operation creates a thread that runs the operations in its
+    `body` region. The parent operation of "onnx.Fork" must be "onnx.Parallel".
+    The thread is synchronized with the main thread at the end of the enclosing
+    "onnx.Parallel" op. Values that are used by subsequent operations must be
+    returned as results of this operation.
+    In the following example, the two MatMul ops run in parallel on two threads.
+
+    Example:
+    ```mlir
+    %0:2 = "onnx.Parallel"() ({
+      %00 = "onnx.Fork"() ({
+        %01 = "onnx.MatMul"(%arg0, %c0) : (tensor<64x32xf32>, tensor<32x32xf32>) -> tensor<64x32xf32>
+        onnx.Yield %01 : tensor<64x32xf32>
+      }) {id = 0 : si64} : () -> tensor<64x32xf32>
+      %01 = "onnx.Fork"() ({
+        %01 = "onnx.MatMul"(%arg0, %c2) : (tensor<64x32xf32>, tensor<32x32xf32>) -> tensor<64x32xf32>
+        onnx.Yield %01 : tensor<64x32xf32>
+      }) {id = 1 : si64} : () -> tensor<64x32xf32>
+      "onnx.Yield"(%00, %01) : (tensor<64x32xf32>, tensor<64x32xf32>) -> ()
+    }) : () -> (tensor<64x32xf32>, tensor<64x32xf32>)
+    ```
+  }];
+  let arguments = (ins DefaultValuedAttr<SI64Attr, "-1">:$id);
+  let results = (outs Variadic<AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>,
+      TensorOf<[F64]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>,
+      TensorOf<[I64]>, TensorOf<[BF16]>]>>:$results);
+  let regions = (region SizedRegion<1>:$body);
+  let skipDefaultBuilders = 1;
+  let extraClassDeclaration = [{
+    int64_t getSubgraphRegionIdx(const std::string &name) {
+      if (name == "body")
+        return 0;
+      llvm_unreachable("region with the specified name does not exist");
+    }
+    using BodyBuilderFn = llvm::function_ref<void(
+        mlir::OpBuilder &, mlir::Location, mlir::ValueRange)>;
+  }];
+  let extraClassDefinition = [{
+    onnx_mlir::ONNXOpShapeHelper *ONNXForkOp::getShapeHelper(
+        mlir::Operation *op, mlir::ArrayRef<mlir::Value> oper,
+        onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
+      onnx_mlir::ONNXOpShapeHelper *sh =
+          new ONNXForkOpShapeHelper(op, oper, ieb, scope);
+      assert(sh && "failed to allocate shape helper");
+      return sh;
+    }
+  }];
+  let builders = [
+    OpBuilder<(ins "mlir::TypeRange":$resultTypes,
+        "mlir::ValueRange":$operands,
+        CArg<"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::ValueRange)>",
+            "nullptr">:$bodyBuilder)>
+  ];
+}
+
+//===----------------------------------------------------------------------===//
+// ONNXParallelOp
+def ONNXParallelOp:ONNX_Op<"Parallel",
+    [Pure, DeclareOpInterfaceMethods<ShapeHelperOpInterface>,
+     DeclareOpInterfaceMethods<ShapeInferenceOpInterface>,
+     DeclareOpInterfaceMethods<ResultTypeInferenceOpInterface>,
+     OpInterface<"mlir::HasOnnxSubgraphOpInterface">]> {
+  let summary = "ONNX operation to specify a parallel region.";
+  let description = [{
+    The "onnx.Parallel" operation represents a parallel region in which
+    operations run in parallel on multiple threads. The threads are
+    synchronized with the main thread at the end of this operation. The
+    operation takes a `body` region containing one or more "onnx.Fork"
+    operations, each of which creates a thread (see "onnx.Fork" for details).
+    The i-th result is the result of the i-th "onnx.Fork" operation;
+    consequently, the number of results of the "onnx.Parallel" op matches the
+    number of "onnx.Fork" ops in the region. The body region must be
+    terminated by an "onnx.Yield" op whose operands are the results of the
+    "onnx.Fork" ops, as in the example below.
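+
+    Example (the bodies of the two "onnx.Fork" ops are elided; see the
+    "onnx.Fork" description above for their full form):
+    ```mlir
+    %0:2 = "onnx.Parallel"() ({
+      %00 = "onnx.Fork"() ({ ... }) {id = 0 : si64} : () -> tensor<64x32xf32>
+      %01 = "onnx.Fork"() ({ ... }) {id = 1 : si64} : () -> tensor<64x32xf32>
+      "onnx.Yield"(%00, %01) : (tensor<64x32xf32>, tensor<64x32xf32>) -> ()
+    }) : () -> (tensor<64x32xf32>, tensor<64x32xf32>)
+    ```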
+  }];
+  let results = (outs Variadic<AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>,
+      TensorOf<[F64]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I32]>,
+      TensorOf<[I64]>, TensorOf<[BF16]>]>>:$results);
+  let regions = (region SizedRegion<1>:$body);
+
+  let skipDefaultBuilders = 1;
+  let extraClassDeclaration = [{
+    int64_t getSubgraphRegionIdx(const std::string &name) {
+      if (name == "body")
+        return 0;
+      llvm_unreachable("region with the specified name does not exist");
+    }
+    using BodyBuilderFn = llvm::function_ref<void(
+        mlir::OpBuilder &, mlir::Location, mlir::ValueRange)>;
+  }];
+  let extraClassDefinition = [{
+    onnx_mlir::ONNXOpShapeHelper *ONNXParallelOp::getShapeHelper(
+        mlir::Operation *op, mlir::ArrayRef<mlir::Value> oper,
+        onnx_mlir::IndexExprBuilder *ieb, onnx_mlir::IndexExprScope *scope) {
+      onnx_mlir::ONNXOpShapeHelper *sh =
+          new ONNXParallelOpShapeHelper(op, oper, ieb, scope);
+      assert(sh && "failed to allocate shape helper");
+      return sh;
+    }
+  }];
+  let builders = [
+    OpBuilder<(ins "mlir::TypeRange":$resultTypes,
+        "mlir::ValueRange":$operands,
+        CArg<"llvm::function_ref<void(mlir::OpBuilder &, mlir::Location, mlir::ValueRange)>",
+            "nullptr">:$bodyBuilder)>
+  ];
+}
diff --git a/src/Dialect/ONNX/CMakeLists.txt b/src/Dialect/ONNX/CMakeLists.txt
index 3917c94aa4..f15506fa1d 100644
--- a/src/Dialect/ONNX/CMakeLists.txt
+++ b/src/Dialect/ONNX/CMakeLists.txt
@@ -31,9 +31,11 @@ add_onnx_mlir_library(OMONNXOps
   ONNXOps/Additional/Custom.cpp
   ONNXOps/Additional/Dim.cpp
   ONNXOps/Additional/EntryPoint.cpp
+  ONNXOps/Additional/Fork.cpp
   ONNXOps/Additional/Return.cpp
   ONNXOps/Additional/LayoutTransform.cpp
   ONNXOps/Additional/None.cpp
+  ONNXOps/Additional/Parallel.cpp
   ONNXOps/Additional/ShapeTransform.cpp
   ONNXOps/ControlFlow/If.cpp
   ONNXOps/ControlFlow/Loop.cpp
diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Fork.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Fork.cpp
new file mode 100644
index 0000000000..6e02786f1b
--- /dev/null
+++ b/src/Dialect/ONNX/ONNXOps/Additional/Fork.cpp
@@ -0,0 +1,95 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===---------------- Fork.cpp - ONNX Operations -------------------------===//
+//
+// Copyright 2019-2024 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file provides the definition of the ONNX dialect Fork operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+//===----------------------------------------------------------------------===//
+// ShapeHelper
+//===----------------------------------------------------------------------===//
+
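+// The output shapes of onnx.Fork are derived from the operands of the body's
+// terminator (onnx.Yield): shape inference is first run on the body so that
+// the yielded values carry refined types, and each result's dimensions are
+// then read off the corresponding yielded value.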
+template <>
+LogicalResult ONNXForkOpShapeHelper::computeShape() {
+  ONNXForkOp forkOp = llvm::cast<ONNXForkOp>(op);
+  (void)forkOp.inferShapes([](Region &region) {});
+  Operation *yieldOp = forkOp.getBody().front().getTerminator();
+  for (unsigned i = 0; i < yieldOp->getNumOperands(); ++i) {
+    DimsExpr outputDims;
+    Value returnVal = yieldOp->getOperands()[i];
+    int64_t outRank = returnVal.getType().cast<ShapedType>().getRank();
+    for (int64_t j = 0; j < outRank; ++j)
+      outputDims.emplace_back(createIE->getShapeAsDim(returnVal, j));
+    setOutputDims(outputDims, i);
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Type Inference
+//===----------------------------------------------------------------------===//
+
+std::vector<Type> ONNXForkOp::resultTypeInference() {
+  Operation *terminator = getRegion().back().getTerminator();
+  auto bodyOutputTys = terminator->getOperandTypes();
+  std::vector<Type> resultTypes;
+  for (auto [i, ty] : llvm::enumerate(bodyOutputTys)) {
+    resultTypes.push_back(ty);
+  }
+  return resultTypes;
+}
+
+//===----------------------------------------------------------------------===//
+// Shape Inference
+//===----------------------------------------------------------------------===//
+
+LogicalResult ONNXForkOp::inferShapes(
+    std::function<void(Region &)> doShapeInference) {
+  doShapeInference(getRegion());
+  for (auto [i, ty] : llvm::enumerate(resultTypeInference()))
+    getResult(i).setType(ty);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Builder: Refer to Async ExecuteOp
+//===----------------------------------------------------------------------===//
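+// A typical use passes a body-builder callback that fills in the region and
+// terminates it with onnx.Yield. Sketch (assuming a builder `b`, a location
+// `loc`, and suitable `resultTypes`/`operands` are in scope):
+//
+//   b.create<ONNXForkOp>(loc, resultTypes, operands,
+//       [&](OpBuilder &b, Location loc, ValueRange blockArgs) {
+//         Value v = /* ... create the forked computation ... */;
+//         b.create<ONNXYieldOp>(loc, v);
+//       });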
+void ONNXForkOp::build(OpBuilder &builder, OperationState &result,
+    TypeRange resultTypes, ValueRange operands, BodyBuilderFn bodyBuilder) {
+
+  result.addOperands(operands);
+  result.addTypes(resultTypes);
+
+  // Add a body region with block arguments.
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  for (Value operand : operands) {
+    bodyBlock.addArgument(operand.getType(), operand.getLoc());
+  }
+
+  // Create the default terminator if the body builder is not provided and the
+  // expected result list is empty. Otherwise, leave this to the caller
+  // because we do not know which values to return from the fork op.
+  if (resultTypes.empty() && !bodyBuilder) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(&bodyBlock);
+    builder.create<ONNXYieldOp>(result.location, ValueRange());
+  } else if (bodyBuilder) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(&bodyBlock);
+    bodyBuilder(builder, result.location, bodyBlock.getArguments());
+  }
+}
diff --git a/src/Dialect/ONNX/ONNXOps/Additional/Parallel.cpp b/src/Dialect/ONNX/ONNXOps/Additional/Parallel.cpp
new file mode 100644
index 0000000000..570afd39ea
--- /dev/null
+++ b/src/Dialect/ONNX/ONNXOps/Additional/Parallel.cpp
@@ -0,0 +1,96 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+//===---------------- Parallel.cpp - ONNX Operations ---------------------===//
+//
+// Copyright 2019-2024 The IBM Research Authors.
+//
+// =============================================================================
+//
+// This file provides the definition of the ONNX dialect Parallel operation.
+//
+//===----------------------------------------------------------------------===//
+
+#include "src/Dialect/ONNX/ONNXOps/OpHelper.hpp"
+
+using namespace mlir;
+using namespace onnx_mlir;
+
+//===----------------------------------------------------------------------===//
+// ShapeHelper
+//===----------------------------------------------------------------------===//
+
+template <>
+LogicalResult ONNXParallelOpShapeHelper::computeShape() {
+  ONNXParallelOp parallelOp = llvm::cast<ONNXParallelOp>(op);
+  (void)parallelOp.inferShapes([](Region &region) {});
+  Operation *yieldOp = parallelOp.getBody().front().getTerminator();
+  for (unsigned i = 0; i < yieldOp->getNumOperands(); ++i) {
+    DimsExpr outputDims;
+    Value returnVal = yieldOp->getOperands()[i];
+    int64_t outRank = returnVal.getType().cast<ShapedType>().getRank();
+    for (int64_t j = 0; j < outRank; ++j)
+      outputDims.emplace_back(createIE->getShapeAsDim(returnVal, j));
+    setOutputDims(outputDims, i);
+  }
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Type Inference
+//===----------------------------------------------------------------------===//
+
+std::vector<Type> ONNXParallelOp::resultTypeInference() {
+  Operation *terminator = getRegion().back().getTerminator();
+  auto bodyOutputTys = terminator->getOperandTypes();
+
+  std::vector<Type> resultTypes;
+  for (auto [i, ty] : llvm::enumerate(bodyOutputTys)) {
+    resultTypes.push_back(ty);
+  }
+  return resultTypes;
+}
+
+//===----------------------------------------------------------------------===//
+// Shape Inference
+//===----------------------------------------------------------------------===//
+
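+// The i-th result of onnx.Parallel mirrors the i-th operand of the body's
+// onnx.Yield (i.e., the result of the i-th onnx.Fork), so the result types
+// can be copied directly from the refined terminator operand types.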
+LogicalResult ONNXParallelOp::inferShapes(
+    std::function<void(Region &)> doShapeInference) {
+  doShapeInference(getRegion());
+  for (auto [i, ty] : llvm::enumerate(resultTypeInference()))
+    getResult(i).setType(ty);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Builder: Refer to Async ExecuteOp
+//===----------------------------------------------------------------------===//
+void ONNXParallelOp::build(OpBuilder &builder, OperationState &result,
+    TypeRange resultTypes, ValueRange operands, BodyBuilderFn bodyBuilder) {
+
+  result.addOperands(operands);
+  result.addTypes(resultTypes);
+
+  // Add a body region with block arguments.
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  for (Value operand : operands) {
+    bodyBlock.addArgument(operand.getType(), operand.getLoc());
+  }
+
+  // Create the default terminator if the body builder is not provided and the
+  // expected result list is empty. Otherwise, leave this to the caller
+  // because we do not know which values to return from the parallel op.
+  if (resultTypes.empty() && !bodyBuilder) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(&bodyBlock);
+    builder.create<ONNXYieldOp>(result.location, ValueRange());
+  } else if (bodyBuilder) {
+    OpBuilder::InsertionGuard guard(builder);
+    builder.setInsertionPointToStart(&bodyBlock);
+    bodyBuilder(builder, result.location, bodyBlock.getArguments());
+  }
+}
diff --git a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
index a58f9572f7..aefd3b272f 100644
--- a/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
+++ b/src/Dialect/ONNX/ONNXOps/ShapeHelper.hpp
@@ -885,6 +885,8 @@ using ONNXTileOpShapeHelper = ONNXNonSpecificOpShapeHelper<ONNXTileOp>;
 using ONNXTopKOpShapeHelper = ONNXNonSpecificOpShapeHelper<ONNXTopKOp>;
 using ONNXTransposeOpShapeHelper = ONNXNonSpecificOpShapeHelper<ONNXTransposeOp>;
 using ONNXUpsampleOpShapeHelper = ONNXNonSpecificOpShapeHelper<ONNXUpsampleOp>;
+using ONNXForkOpShapeHelper = ONNXNonSpecificOpShapeHelper<ONNXForkOp>;
+using ONNXParallelOpShapeHelper = ONNXNonSpecificOpShapeHelper<ONNXParallelOp>;
 // clang-format on
 
 //===----------------------------------------------------------------------===//
diff --git a/test/mlir/onnx/onnx_shape_inference.mlir b/test/mlir/onnx/onnx_shape_inference.mlir
index 07889a0abb..94e0fef19a 100644
--- a/test/mlir/onnx/onnx_shape_inference.mlir
+++ b/test/mlir/onnx/onnx_shape_inference.mlir
@@ -3801,3 +3801,47 @@ func.func @test_RMSlayer_norm_2inputs(%arg0: tensor<12x3x5xf32>, %arg1: tensor<5
 // CHECK: }
 }
 
+// -----
+
+//===----------------------------------------------------------------------===//
+/// Test shape inference for Parallel and Fork.
+//===----------------------------------------------------------------------===//
+
+func.func @test_parallel_fork_1(%arg0: tensor<8x64x32xf32>, %arg1: tensor<32x32xf32>) -> (tensor<*xf32>, tensor<*xf32>) {
+  %c0 = onnx.Constant dense<1.0> : tensor<32x32xf32>
+  %c1 = onnx.Constant dense<1.0> : tensor<32xf32>
+  %c2 = onnx.Constant dense<1.0> : tensor<32x32xf32>
+
+  %0:2 = "onnx.Parallel"() ({
+    %00 = "onnx.Fork"() ({
+      %01 = "onnx.MatMul"(%arg0, %c0) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<*xf32>
+      onnx.Yield %01 : tensor<*xf32>
+    }) {id = 0 : si64} : () -> tensor<*xf32>
+    %01 = "onnx.Fork"() ({
+      %01 = "onnx.MatMul"(%arg0, %c2) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<*xf32>
+      onnx.Yield %01 : tensor<*xf32>
+    }) {id = 1 : si64} : () -> tensor<*xf32>
+    "onnx.Yield"(%00, %01) : (tensor<*xf32>, tensor<*xf32>) -> ()
+  }) : () -> (tensor<*xf32>, tensor<*xf32>)
+  "onnx.Return"(%0#0, %0#1) : (tensor<*xf32>, tensor<*xf32>) -> ()
+
+// CHECK-LABEL:  func.func @test_parallel_fork_1
+// CHECK-SAME:   ([[PARAM_0_:%.+]]: tensor<8x64x32xf32>, [[PARAM_1_:%.+]]: tensor<32x32xf32>) -> (tensor<8x64x32xf32>, tensor<8x64x32xf32>) {
+// CHECK-DAG:       [[VAR_0_:%.+]] = onnx.Constant dense<1.000000e+00> : tensor<32x32xf32>
+// CHECK-DAG:       [[VAR_1_:%.+]]:2 = "onnx.Parallel"() ({
+// CHECK-DAG:         [[VAR_2_:%.+]] = "onnx.Fork"() ({
+// CHECK:               [[VAR_4_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<8x64x32xf32>
+// CHECK:               onnx.Yield [[VAR_4_]] : tensor<8x64x32xf32>
+// CHECK:             }) {id = 0 : si64} : () -> tensor<8x64x32xf32>
+// CHECK-DAG:         [[VAR_3_:%.+]] = "onnx.Fork"() ({
+// CHECK-DAG:           [[VAR_4_1_:%.+]] = "onnx.MatMul"([[PARAM_0_]], [[VAR_0_]]) : (tensor<8x64x32xf32>, tensor<32x32xf32>) -> tensor<8x64x32xf32>
+// CHECK:               onnx.Yield [[VAR_4_1_]] : tensor<8x64x32xf32>
+// CHECK:             }) {id = 1 : si64} : () -> tensor<8x64x32xf32>
+// CHECK:             onnx.Yield [[VAR_2_]], [[VAR_3_]] : tensor<8x64x32xf32>, tensor<8x64x32xf32>
+// CHECK:           }) : () -> (tensor<8x64x32xf32>, tensor<8x64x32xf32>)
+// CHECK:           onnx.Return [[VAR_1_]]#0, [[VAR_1_]]#1 : tensor<8x64x32xf32>, tensor<8x64x32xf32>
+// CHECK:         }
+}
+
+// -----