Skip to content

Commit

Permalink
Restrict RESAMPLE3 changes to opset18+
Browse files Browse the repository at this point in the history
  • Loading branch information
Sheil Kumar committed Jul 15, 2024
1 parent aeb4b95 commit a4324c3
Showing 1 changed file with 29 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -263,11 +263,6 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper
std::string mode = kernelCreationContext.GetOptionalAttribute<std::string>(AttrName::Mode, "NEAREST");
DML_INTERPOLATION_MODE interpolationMode = Dml::MapStringToInteropolationMode(mode);


#if DML_TARGET_VERSION >= 0x6400
const int antialiased = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Antialiased, 0);
#endif

// Map ONNX to DML's mode using offsets and rounding direction.
// These offsets are in addition to the coordinate transform offsets.
DML_AXIS_DIRECTION roundingDirection = DML_AXIS_DIRECTION_DECREASING;
Expand Down Expand Up @@ -307,25 +302,35 @@ class DmlOperatorResize : public DmlOperator, public ResizeHelper
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

#if DML_TARGET_VERSION >= 0x6400
DML_RESAMPLE3_OPERATOR_DESC operatorDesc = {};
operatorDesc.Antialiased = static_cast<BOOL>(antialiased);
#else
DML_RESAMPLE2_OPERATOR_DESC operatorDesc = {};
#endif
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.InterpolationMode = interpolationMode;
operatorDesc.RoundingDirection = roundingDirection;
operatorDesc.Scales = paddedScales.data();
operatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(paddedScales.size());
operatorDesc.InputPixelOffsets = inputPixelOffsets.data();
operatorDesc.OutputPixelOffsets = outputPixelOffsets.data();
#if DML_TARGET_VERSION >= 0x6400
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE3, &operatorDesc };
#else
DML_OPERATOR_DESC opDesc = { DML_OPERATOR_RESAMPLE2, &operatorDesc };
#endif
DML_OPERATOR_DESC opDesc = {};
if (opsetVersion >= 18) {
// Restrict this change to Resize18 and Resize19
const int antialiased = kernelCreationContext.GetOptionalAttribute<int>(AttrName::Antialiased, 0);
DML_RESAMPLE3_OPERATOR_DESC operatorDesc = {};
operatorDesc.Antialiased = static_cast<BOOL>(antialiased);
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.InterpolationMode = interpolationMode;
operatorDesc.RoundingDirection = roundingDirection;
operatorDesc.Scales = paddedScales.data();
operatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(paddedScales.size());
operatorDesc.InputPixelOffsets = inputPixelOffsets.data();
operatorDesc.OutputPixelOffsets = outputPixelOffsets.data();
opDesc = { DML_OPERATOR_RESAMPLE3, &operatorDesc };
}
else {

Check warning on line 321 in onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 If an else has a brace on one side, it should have it on both [readability/braces] [5] Raw Output: onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp:321: If an else has a brace on one side, it should have it on both [readability/braces] [5]
DML_RESAMPLE2_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.InterpolationMode = interpolationMode;
operatorDesc.RoundingDirection = roundingDirection;
operatorDesc.Scales = paddedScales.data();
operatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(paddedScales.size());
operatorDesc.InputPixelOffsets = inputPixelOffsets.data();
operatorDesc.OutputPixelOffsets = outputPixelOffsets.data();
opDesc = { DML_OPERATOR_RESAMPLE2, &operatorDesc };
}

SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};
Expand Down

0 comments on commit a4324c3

Please sign in to comment.