diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp index 53f7ce1ba4657..b7cceb1d1d998 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorResize.cpp @@ -9,11 +9,12 @@ namespace Dml constexpr NameAndIndex coordinateTransformationModes[] = { {"half_pixel", 0}, - {"pytorch_half_pixel", 1}, - {"align_corners", 2}, - {"asymmetric", 3}, - {"tf_half_pixel_for_nn", 4}, - {"tf_crop_and_resize", 5}, + {"half_pixel_symmetric", 1}, + {"pytorch_half_pixel", 2}, + {"align_corners", 3}, + {"asymmetric", 4}, + {"tf_half_pixel_for_nn", 5}, + {"tf_crop_and_resize", 6}, }; constexpr NameAndIndex nearestNeighborRoundingModes[] = @@ -50,7 +51,7 @@ void ComputePixelOffsetsAndScales( uint32_t coordinateTransformationModeValue = *optionalCoordinateTransformationModeValue; ML_CHECK_VALID_ARGUMENT( - !regionOfInterest.empty() || coordinateTransformationModeValue != 5 /*tf_crop_and_resize*/, + !regionOfInterest.empty() || coordinateTransformationModeValue != 6 /*tf_crop_and_resize*/, "Resize expects 'roi' tensor for 'tf_crop_and_resize' mode." ); @@ -88,6 +89,18 @@ void ComputePixelOffsetsAndScales( break; case 1: + // coordinate_transformation_mode is "half_pixel_symmetric", + // adjustment = output_width_int / output_width + // center = input_width / 2 + // offset = center * (1 - adjustment) + // x_original = (x + 0.5) / scale - (0.5 - offset) + // x_original = (x + 0.5) / scale - (0.5 - [(input_width / 2) * (1 - (output_width_int / output_width))]) + // output_width can be fractional when calculated with scale factor + inputPixelOffset = 0.5f - float((inputDimensions[i] / 2.0f) * (1.0f - outputDimensions[i] / (scales[i] * inputDimensions[i]))); + outputPixelOffset = -0.5; + break; + + case 2: // if coordinate_transformation_mode is "pytorch_half_pixel", // x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0 if (inputDimensions[i] <= 1) @@ -104,7 +117,7 @@ void ComputePixelOffsetsAndScales( } break; - case 2: + case 3: // if coordinate_transformation_mode is "align_corners", // x_original = x_resized * (length_original - 1) / (length_resized - 1) inputPixelOffset = 0.0; @@ -121,7 +134,7 @@ void ComputePixelOffsetsAndScales( } break; - case 3: + case 4: // if coordinate_transformation_mode is "asymmetric", // x_original = x_resized / scale inputPixelOffset = 0.0; @@ -129,7 +142,7 @@ void ComputePixelOffsetsAndScales( // Keep existing scales. break; - case 4: + case 5: // if coordinate_transformation_mode is "tf_half_pixel_for_nn", // x_original = (x_resized + 0.5) / scale inputPixelOffset = 0.0; @@ -137,7 +150,7 @@ void ComputePixelOffsetsAndScales( // Keep existing scales. break; - case 5: + case 6: // if coordinate_transformation_mode is "tf_crop_and_resize", // x_original = length_resized > 1 ? start_x * (length_original - 1) + x_resized * (end_x - start_x) * (length_original - 1) / (length_resized - 1) // : 0.5 * (start_x + end_x) * (length_original - 1) @@ -357,6 +370,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel DML_OP_DEFINE_CREATION_FUNCTION(Resize13, VersionedKernel); #if DML_TARGET_VERSION >= 0x6300 DML_OP_DEFINE_CREATION_FUNCTION(Resize18, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Resize19, VersionedKernel); #endif DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 40b054549afbc..cc202fd7b8e30 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -507,6 +507,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Trilu); #if DML_TARGET_VERSION >= 0x6300 DML_OP_EXTERN_CREATION_FUNCTION(Col2Im); DML_OP_EXTERN_CREATION_FUNCTION(Resize18); +DML_OP_EXTERN_CREATION_FUNCTION(Resize19); #endif DML_OP_EXTERN_CREATION_FUNCTION(Shape); @@ -965,6 +966,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 13, Resize, typeNameListTwo, supportedTypeListResize13, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, #if DML_TARGET_VERSION >= 0x6300 {REG_INFO_VER( 18, Resize, typeNameListTwo, supportedTypeListResize18, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, + {REG_INFO_VER( 19, Resize, typeNameListTwo, supportedTypeListResize18, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)}, #endif // Activation Functions {REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 8820ad22b01d7..e25816811426b 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1655,6 +1655,7 @@ using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper; using ShapeInferenceHelper_Resize11 = VersionedOpsetHelper; using ShapeInferenceHelper_Resize13 = VersionedOpsetHelper; using ShapeInferenceHelper_Resize18 = VersionedOpsetHelper; +using ShapeInferenceHelper_Resize19 = VersionedOpsetHelper; using ShapeInferenceHelper_OneHot = OneHotHelper; using ShapeInferenceHelper_Sqrt = GetOutputShapeAsInputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 7cd5343d4db63..798adde905448 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -414,6 +414,7 @@ namespace OperatorHelper namespace OnnxOperatorSet19 { static const int sc_sinceVer_AveragePool = 19; + static const int sc_sinceVer_Resize = 19; } namespace MsftOperatorSet1