Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DML] Register DML operators for opset 19 #16939

Merged
merged 14 commits into from
Jan 22, 2024
27 changes: 18 additions & 9 deletions docs/OperatorKernels.md

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DmlOperatorCast : public DmlOperator
castDesc.OutputTensor = outputDescs.data();

DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc };

SetDmlOperatorDesc(opDesc, kernelInfo);
}

Expand All @@ -49,5 +49,6 @@ class DmlOperatorCast : public DmlOperator

DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast);
DML_OP_DEFINE_CREATION_FUNCTION(CastLike15, DmlOperatorCast);
DML_OP_DEFINE_CREATION_FUNCTION(CastLike19, DmlOperatorCast);

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@
Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = &inputDescs[0];
Expand All @@ -497,11 +497,11 @@
SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo);
}
else
{
{
Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0));

std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {};
opDesc.InputTensor = &inputDescs[0];
Expand All @@ -519,13 +519,16 @@
public:
DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo)
{
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 3);

Check warning on line 522 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp#L522

Redundant blank line at the start of a code block should be deleted. [whitespace/blank_line] [2]
Raw output
onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp:522:  Redundant blank line at the start of a code block should be deleted.  [whitespace/blank_line] [2]
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2);
ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1);

Initialize(kernelInfo, std::nullopt, std::nullopt);

std::vector<uint32_t> outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0);
const uint32_t outputShapeDimCount = gsl::narrow_cast<uint32_t>(outputShape.size());

Initialize(kernelInfo, std::nullopt, std::nullopt);
const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType();
bool hasZeroPointTensor = kernelInfo.IsInputValid(2);

uint32_t axis = 0;

Expand All @@ -541,9 +544,14 @@
axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount, /*validateAxis*/ false);
}

// Explicitly reshape each of the inputs after the first input (scale and zero point tensors).
// Explicitly reshape each of the inputs after the first input (scale tensor and optional zero point tensor).
for (uint32_t index = 1, inputCount = gsl::narrow_cast<uint32_t>(m_inputTensorDescs.size()); index < inputCount; ++index)
linnealovespie marked this conversation as resolved.
Show resolved Hide resolved
{
if (!kernelInfo.IsInputValid(index))
{
continue;
}

auto edgeDesc = kernelInfo.GetInputEdgeDescription(index);
assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor);

Expand Down Expand Up @@ -583,12 +591,8 @@
TOperatorDesc opDesc = {};
opDesc.InputTensor = &inputDescs[0];
opDesc.ScaleTensor = &inputDescs[1];
opDesc.ZeroPointTensor = &inputDescs[2];
opDesc.ZeroPointTensor = hasZeroPointTensor ? &inputDescs[2] : nullptr;
opDesc.OutputTensor = &outputDescs[0];

TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1);
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2);

SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMul);
DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMulActivation);
DML_OP_EXTERN_CREATION_FUNCTION(Cast);
DML_OP_EXTERN_CREATION_FUNCTION(CastLike15);
DML_OP_EXTERN_CREATION_FUNCTION(CastLike19);
DML_OP_EXTERN_CREATION_FUNCTION(MemcpyFromHost);
DML_OP_EXTERN_CREATION_FUNCTION(MemcpyToHost);
DML_OP_EXTERN_CREATION_FUNCTION(TopK7);
Expand Down Expand Up @@ -785,6 +786,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_COPY(13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(14, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(16, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(19, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
{REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
Expand All @@ -798,6 +800,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},

// Elementwise
{REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
Expand Down Expand Up @@ -857,8 +860,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)},
{REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)},
{REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)},
Expand Down Expand Up @@ -943,6 +948,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported)},
{REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)},
{REG_INFO( 13, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)},
{REG_INFO( 19, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)},
{REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
{REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)},
Expand Down Expand Up @@ -1004,7 +1010,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)},
{REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)},
{REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported)},
Expand All @@ -1015,8 +1023,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 7, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)},
{REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)},
{REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)},
{REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)},

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,7 @@ using ShapeInferenceHelper_Expand = ExpandHelper;
using ShapeInferenceHelper_Reshape7 = ReshapeHelper;
using ShapeInferenceHelper_Reshape13 = ReshapeHelper;
using ShapeInferenceHelper_Reshape14 = ReshapeHelper;
using ShapeInferenceHelper_Reshape19 = ReshapeHelper;
using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper;
using ShapeInferenceHelper_Tile = TileHelper;
using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper<ResizeHelper, 10>;
Expand Down Expand Up @@ -1725,6 +1726,7 @@ using ShapeInferenceHelper_Identity7 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_MatMul = MatMulHelper;
using ShapeInferenceHelper_MatMulInteger = MatMulHelper;
using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper;
Expand All @@ -1750,6 +1752,7 @@ using ShapeInferenceHelper_CumSum14 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Range = RangeHelper;

using ShapeInferenceHelper_CastLike15 = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_CastLike19 = GetOutputShapeAsInputShapeHelper;

using ShapeInferenceHelper_DmlFusedConv = ConvHelper;
using ShapeInferenceHelper_DmlFusedConvTranspose = ConvTransposeHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,16 @@ namespace OperatorHelper
namespace OnnxOperatorSet19
{
static const int sc_sinceVer_AveragePool = 19;
static const int sc_sinceVer_Cast = 19;
linnealovespie marked this conversation as resolved.
Show resolved Hide resolved
static const int sc_sinceVer_CastLike = 19;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_CastLike' can be computed at compile-time. Consider using constexpr (con.5).

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_CastLike' can be computed at compile-time. Consider using constexpr (con.5).
static const int sc_sinceVer_Constant = 19;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Constant' can be computed at compile-time. Consider using constexpr (con.5).

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Constant' can be computed at compile-time. Consider using constexpr (con.5).
static const int sc_sinceVer_Equal = 19;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Equal' can be computed at compile-time. Consider using constexpr (con.5).

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Equal' can be computed at compile-time. Consider using constexpr (con.5).
static const int sc_sinceVer_Identity = 19;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Identity' can be computed at compile-time. Consider using constexpr (con.5).

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Identity' can be computed at compile-time. Consider using constexpr (con.5).
static const int sc_sinceVer_QuantizeLinear = 19;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_QuantizeLinear' can be computed at compile-time. Consider using constexpr (con.5).

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_QuantizeLinear' can be computed at compile-time. Consider using constexpr (con.5).
static const int sc_sinceVer_DequantizeLinear = 19;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_DequantizeLinear' can be computed at compile-time. Consider using constexpr (con.5).

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_DequantizeLinear' can be computed at compile-time. Consider using constexpr (con.5).
static const int sc_sinceVer_Reshape = 19;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Reshape' can be computed at compile-time. Consider using constexpr (con.5).

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Reshape' can be computed at compile-time. Consider using constexpr (con.5).
static const int sc_sinceVer_Shape = 19;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Shape' can be computed at compile-time. Consider using constexpr (con.5).

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Shape' can be computed at compile-time. Consider using constexpr (con.5).
static const int sc_sinceVer_Size = 19;

Check warning

Code scanning / PREfast

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Size' can be computed at compile-time. Consider using constexpr (con.5).

The const variable 'OperatorHelper::OnnxOperatorSet19::sc_sinceVer_Size' can be computed at compile-time. Consider using constexpr (con.5).
}

namespace MsftOperatorSet1
Expand Down
10 changes: 0 additions & 10 deletions onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ TEST(DequantizeLinearOpTest, Int8) {

// scalar zero & scale with int8
TEST(DequantizeLinearOpTest, Int32) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect";
}
PatriceVignola marked this conversation as resolved.
Show resolved Hide resolved

OpTester test("DequantizeLinear", 10);
std::vector<int64_t> dims{4};
test.AddInput<int32_t>("x", dims, {-30, -3, 100, 127});
Expand Down Expand Up @@ -88,11 +83,6 @@ TEST(DequantizeLinearOpMLFloat16Test, Scalar) {

// dequantize without zero point
TEST(DequantizeLinearOpTest, Without_Zero_Point) {
// TODO: Unskip when fixed #41968513
if (DefaultDmlExecutionProvider().get() != nullptr) {
GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect";
}

OpTester test("DequantizeLinear", 10);
test.AddInput<int8_t>("x", {}, {100});
test.AddInput<float>("x_scale", {}, {2.0f});
Expand Down
Loading