From a4324c3ca510bcc541b2fcca9df4fc7c3f555565 Mon Sep 17 00:00:00 2001 From: Sheil Kumar Date: Mon, 15 Jul 2024 08:21:50 -0700 Subject: [PATCH] Restrict RESAMPLE3 changes to opset18+ --- .../src/Operators/DmlOperatorResize.cpp | 53 ++++++++++--------- 1 file changed, 29 insertions(+), 24 deletions(-) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp index 5256e01f86fb6..26ea940a0fdec 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp @@ -263,11 +263,6 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper std::string mode = kernelCreationContext.GetOptionalAttribute(AttrName::Mode, "NEAREST"); DML_INTERPOLATION_MODE interpolationMode = Dml::MapStringToInteropolationMode(mode); - -#if DML_TARGET_VERSION >= 0x6400 - const int antialiased = kernelCreationContext.GetOptionalAttribute(AttrName::Antialiased, 0); -#endif - // Map ONNX to DML's mode using offsets and rounding direction. // These offsets are in addition to the coordinate transform offsets. DML_AXIS_DIRECTION roundingDirection = DML_AXIS_DIRECTION_DECREASING; @@ -307,25 +302,35 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper std::vector inputDescs = GetDmlInputDescs(); std::vector outputDescs = GetDmlOutputDescs(); -#if DML_TARGET_VERSION >= 0x6400 - DML_RESAMPLE3_OPERATOR_DESC operatorDesc = {}; - operatorDesc.Antialiased = static_cast(antialiased); -#else - DML_RESAMPLE2_OPERATOR_DESC operatorDesc = {}; -#endif - operatorDesc.InputTensor = inputDescs.data(); - operatorDesc.OutputTensor = outputDescs.data(); - operatorDesc.InterpolationMode = interpolationMode; - operatorDesc.RoundingDirection = roundingDirection; - operatorDesc.Scales = paddedScales.data(); - operatorDesc.DimensionCount = gsl::narrow_cast(paddedScales.size()); - operatorDesc.InputPixelOffsets = inputPixelOffsets.data(); - operatorDesc.OutputPixelOffsets = outputPixelOffsets.data(); -#if DML_TARGET_VERSION >= 0x6400 - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE3, &operatorDesc }; -#else - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE2, &operatorDesc }; -#endif + DML_OPERATOR_DESC opDesc = {}; + if (opsetVersion >= 18) { + // Restrict this change to Resize18 and Resize19 + const int antialiased = kernelCreationContext.GetOptionalAttribute(AttrName::Antialiased, 0); + DML_RESAMPLE3_OPERATOR_DESC operatorDesc = {}; + operatorDesc.Antialiased = static_cast(antialiased); + operatorDesc.InputTensor = inputDescs.data(); + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.InterpolationMode = interpolationMode; + operatorDesc.RoundingDirection = roundingDirection; + operatorDesc.Scales = paddedScales.data(); + operatorDesc.DimensionCount = gsl::narrow_cast(paddedScales.size()); + operatorDesc.InputPixelOffsets = inputPixelOffsets.data(); + operatorDesc.OutputPixelOffsets = outputPixelOffsets.data(); + opDesc = { DML_OPERATOR_RESAMPLE3, &operatorDesc }; + } + else { + DML_RESAMPLE2_OPERATOR_DESC operatorDesc = {}; + operatorDesc.InputTensor = inputDescs.data(); + operatorDesc.OutputTensor = outputDescs.data(); + operatorDesc.InterpolationMode = interpolationMode; + operatorDesc.RoundingDirection = roundingDirection; + operatorDesc.Scales = paddedScales.data(); + operatorDesc.DimensionCount = gsl::narrow_cast(paddedScales.size()); + operatorDesc.InputPixelOffsets = inputPixelOffsets.data(); + operatorDesc.OutputPixelOffsets = outputPixelOffsets.data(); + opDesc = { DML_OPERATOR_RESAMPLE2, &operatorDesc }; + } + SetDmlOperatorDesc(opDesc, kernelCreationContext); } };