Skip to content

Commit

Permalink
Resize-19
Browse files Browse the repository at this point in the history
  • Loading branch information
Linnea May committed Jan 25, 2024
2 parents 8133f88 + 77b7194 commit 59e75cf
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ namespace Dml
// Lookup table pairing each Resize 'coordinate_transformation_mode' name with a numeric index.
// NOTE(review): this listing comes from a rendered diff without +/- markers, so the pre-change
// rows ("pytorch_half_pixel" = 1 through "tf_crop_and_resize" = 5) and the post-change rows
// ("half_pixel_symmetric" = 1 through "tf_crop_and_resize" = 6) both appear below. Confirm
// against the real file which set is live before relying on any index.
// The indices presumably match the `case` labels used in ComputePixelOffsetsAndScales.
constexpr NameAndIndex coordinateTransformationModes[] =
{
{"half_pixel", 0},
{"pytorch_half_pixel", 1},
{"align_corners", 2},
{"asymmetric", 3},
{"tf_half_pixel_for_nn", 4},
{"tf_crop_and_resize", 5},
{"half_pixel_symmetric", 1},
{"pytorch_half_pixel", 2},
{"align_corners", 3},
{"asymmetric", 4},
{"tf_half_pixel_for_nn", 5},
{"tf_crop_and_resize", 6},
};

constexpr NameAndIndex nearestNeighborRoundingModes[] =
Expand Down Expand Up @@ -50,7 +51,7 @@ void ComputePixelOffsetsAndScales(
uint32_t coordinateTransformationModeValue = *optionalCoordinateTransformationModeValue;

ML_CHECK_VALID_ARGUMENT(
!regionOfInterest.empty() || coordinateTransformationModeValue != 5 /*tf_crop_and_resize*/,
!regionOfInterest.empty() || coordinateTransformationModeValue != 6 /*tf_crop_and_resize*/,
"Resize expects 'roi' tensor for 'tf_crop_and_resize' mode."
);

Expand Down Expand Up @@ -88,6 +89,18 @@ void ComputePixelOffsetsAndScales(
break;

case 1:
// if coordinate_transformation_mode is "half_pixel_symmetric",
// adjustment = output_width_int / output_width
// center = input_width / 2
// offset = center * (1 - adjustment)
// x_original = (x + 0.5) / scale - (0.5 - offset)
//            = (x + 0.5) / scale - (0.5 - [(input_width / 2) * (1 - (output_width_int / output_width))])
// where output_width = scale * input_width, which can be fractional when calculated with a
// scale factor, and output_width_int is the actual integer output dimension.
inputPixelOffset = 0.5f - float((inputDimensions[i] / 2.0f) * (1.0f - outputDimensions[i] / (scales[i] * inputDimensions[i])));
outputPixelOffset = -0.5;
break;

case 2:
// if coordinate_transformation_mode is "pytorch_half_pixel",
// x_original = length_resized > 1 ? (x_resized + 0.5) / scale - 0.5 : 0
if (inputDimensions[i] <= 1)
Expand All @@ -104,7 +117,7 @@ void ComputePixelOffsetsAndScales(
}
break;

case 2:
case 3:
// if coordinate_transformation_mode is "align_corners",
// x_original = x_resized * (length_original - 1) / (length_resized - 1)
inputPixelOffset = 0.0;
Expand All @@ -121,23 +134,23 @@ void ComputePixelOffsetsAndScales(
}
break;

case 3:
case 4:
// if coordinate_transformation_mode is "asymmetric",
// x_original = x_resized / scale
inputPixelOffset = 0.0;
outputPixelOffset = 0.0;
// Keep existing scales.
break;

case 4:
case 5:
// if coordinate_transformation_mode is "tf_half_pixel_for_nn",
// x_original = (x_resized + 0.5) / scale
inputPixelOffset = 0.0;
outputPixelOffset = -0.5;
// Keep existing scales.
break;

case 5:
case 6:
// if coordinate_transformation_mode is "tf_crop_and_resize",
// x_original = length_resized > 1 ? start_x * (length_original - 1) + x_resized * (end_x - start_x) * (length_original - 1) / (length_resized - 1)
// : 0.5 * (start_x + end_x) * (length_original - 1)
Expand Down Expand Up @@ -357,6 +370,7 @@ DML_OP_DEFINE_CREATION_FUNCTION(Resize11, VersionedKernel<DmlOperatorResize, 11>
DML_OP_DEFINE_CREATION_FUNCTION(Resize13, VersionedKernel<DmlOperatorResize, 13>);
#if DML_TARGET_VERSION >= 0x6300
DML_OP_DEFINE_CREATION_FUNCTION(Resize18, VersionedKernel<DmlOperatorResize, 18>);
DML_OP_DEFINE_CREATION_FUNCTION(Resize19, VersionedKernel<DmlOperatorResize, 19>);
#endif
DML_OP_DEFINE_CREATION_FUNCTION(Upsample7, VersionedKernel<DmlOperatorResize, 7>);
DML_OP_DEFINE_CREATION_FUNCTION(Upsample9, VersionedKernel<DmlOperatorResize, 9>);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
#if DML_TARGET_VERSION >= 0x6300
DML_OP_EXTERN_CREATION_FUNCTION(Col2Im);
DML_OP_EXTERN_CREATION_FUNCTION(Resize18);
DML_OP_EXTERN_CREATION_FUNCTION(Resize19);
#endif

DML_OP_EXTERN_CREATION_FUNCTION(Shape);
Expand Down Expand Up @@ -965,6 +966,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_VER( 13, Resize, typeNameListTwo, supportedTypeListResize13, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
#if DML_TARGET_VERSION >= 0x6300
{REG_INFO_VER( 18, Resize, typeNameListTwo, supportedTypeListResize18, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
{REG_INFO_VER( 19, Resize, typeNameListTwo, supportedTypeListResize18, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*roi, scales, sizes*/, std::nullopt, QueryResize)},
#endif
// Activation Functions
{REG_INFO( 7, Sigmoid, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,7 @@ using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper<ResizeHelper, 10>;
using ShapeInferenceHelper_Resize11 = VersionedOpsetHelper<ResizeHelper, 11>;
using ShapeInferenceHelper_Resize13 = VersionedOpsetHelper<ResizeHelper, 13>;
using ShapeInferenceHelper_Resize18 = VersionedOpsetHelper<ResizeHelper, 18>;
using ShapeInferenceHelper_Resize19 = VersionedOpsetHelper<ResizeHelper, 19>;
using ShapeInferenceHelper_OneHot = OneHotHelper;

using ShapeInferenceHelper_Sqrt = GetOutputShapeAsInputShapeHelper;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ namespace OperatorHelper
// Constants for operator set 19. Each sc_sinceVer_<Op> holds the value 19; judging by the
// name, it presumably records the opset version at which this revision of <Op> was introduced.
namespace OnnxOperatorSet19
{
static const int sc_sinceVer_AveragePool = 19;
// Resize entry for opset 19.
static const int sc_sinceVer_Resize = 19;
}

namespace MsftOperatorSet1
Expand Down

0 comments on commit 59e75cf

Please sign in to comment.