[DirectML EP] Add DML EP registration for Col2Im (#17786)
### Description
[DirectML EP] Add DML EP registration for Col2Im operator

### Motivation and Context
Add Col2Im support for opset 18.
The operator is implemented using the DirectML Fold operator (DML_OPERATOR_FOLD).
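
For context, ONNX Col2Im rearranges column blocks back into a batched multidimensional image: input 0 has shape [N, C × prod(block_shape), L], and the output has shape [N, C, image_shape...], where `image_shape` and `block_shape` arrive as constant inputs 1 and 2. Below is a minimal, standalone sketch of that shape relationship (the helper name and `main` are hypothetical and not part of this change); the shape inference added to `Col2ImHelper::Initialize` in this commit performs the same computation.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Col2Im shape relationship: input [N, C * prod(block_shape), L] -> output [N, C, image_shape...].
std::vector<uint32_t> ComputeCol2ImOutputShape(
    const std::vector<uint32_t>& inputShape,  // input 0, e.g. {1, 5, 25}
    const std::vector<uint32_t>& imageShape,  // constant input 1, e.g. {5, 5}
    const std::vector<uint32_t>& blockShape)  // constant input 2, e.g. {1, 5}
{
    const uint32_t blockProduct = std::accumulate(
        blockShape.begin(), blockShape.end(), 1u, std::multiplies<uint32_t>());

    std::vector<uint32_t> outputShape(2 + imageShape.size());
    outputShape[0] = inputShape[0];                 // N
    outputShape[1] = inputShape[1] / blockProduct;  // C
    for (size_t i = 0; i < imageShape.size(); ++i)
    {
        outputShape[i + 2] = imageShape[i];         // spatial dims come straight from image_shape
    }
    return outputShape;
}

int main()
{
    // image_shape = [5, 5], block_shape = [1, 5]: input [1, 5, 25] -> output [1, 1, 5, 5]
    for (uint32_t dim : ComputeCol2ImOutputShape({1, 5, 25}, {5, 5}, {1, 5}))
    {
        std::cout << dim << ' ';
    }
    std::cout << '\n';
    return 0;
}
```

The DirectML Fold operator then performs the actual block-to-image data movement; the ONNX dilations, pads, and strides attributes map directly onto the corresponding DML_FOLD_OPERATOR_DESC fields.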

---------

Co-authored-by: Sheil Kumar <[email protected]>
Co-authored-by: Dwayne Robinson <[email protected]>
3 people authored Dec 9, 2023
1 parent d02ca45 commit fc4549e
Showing 6 changed files with 158 additions and 8 deletions.
4 changes: 2 additions & 2 deletions cmake/external/dml.cmake
@@ -72,12 +72,11 @@ else()
if (dml_EXTERNAL_PROJECT)
set(dml_preset_config $<IF:$<CONFIG:Debug>,debug,release>)
set(dml_preset_name ${onnxruntime_target_platform}-win-redist-${dml_preset_config})
target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1)
include(ExternalProject)
ExternalProject_Add(
directml_repo
GIT_REPOSITORY https://dev.azure.com/microsoft/WindowsAI/_git/DirectML
GIT_TAG d460f0f46967bea878786f1bed69487692c779bf
GIT_TAG a5312f72c51864b4d705ac62d25d08bcd88c4fb1
GIT_SHALLOW OFF # not allowed when GIT_TAG is a commit SHA, which is preferred (it's stable, unlike branches)
GIT_PROGRESS ON
BUILD_IN_SOURCE ON
@@ -94,6 +93,7 @@ else()
target_link_libraries(DirectML INTERFACE ${directml_install_path}/lib/DirectML.lib)
add_dependencies(DirectML directml_repo-install)
include_directories(BEFORE ${directml_install_path}/include)
target_compile_definitions(DirectML INTERFACE DML_TARGET_VERSION_USE_LATEST=1)
else()
include_directories(BEFORE ${dml_INCLUDE_DIR})
set(DML_PACKAGE_DIR ${dml_INCLUDE_DIR}/..)
@@ -0,0 +1,59 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "./precomp.h"

namespace Dml
{

class DmlOperatorCol2Im : public DmlOperator, public Col2ImHelper
{
public:
explicit DmlOperatorCol2Im(const MLOperatorKernelCreationContext& kernelCreationContext)
: DmlOperator(kernelCreationContext),
Col2ImHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
{
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() == 3, "Col2Im expects 3 inputs.");
ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() == 1, "Col2Im expects 1 output.");

auto tensorShapeDescription = kernelCreationContext.GetTensorShapeDescription();
std::vector<uint32_t> inputTensorShape = tensorShapeDescription.GetInputTensorShape(0);
std::vector<uint32_t> outputTensorShape = tensorShapeDescription.GetOutputTensorShape(0);

ML_CHECK_VALID_ARGUMENT(outputTensorShape == m_outputShape);

std::vector<std::optional<uint32_t>> inputIndices = { 0 };
gsl::span<const uint32_t> inputShapes[1] = { m_inputShape };
gsl::span<const uint32_t> outputShapes[1] = { m_outputShape };
DmlOperator::InitializeWithShapes(
kernelCreationContext,
inputIndices,
std::nullopt,
inputShapes,
outputShapes,
3
);
// Prepare DML_FOLD_OPERATOR_DESC
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
assert(inputDescs.size() == 1);
assert(outputDescs.size() == 1);

DML_FOLD_OPERATOR_DESC operatorDesc = {};
operatorDesc.InputTensor = inputDescs.data();
operatorDesc.OutputTensor = outputDescs.data();
operatorDesc.DimensionCount = gsl::narrow_cast<uint32_t>(m_blockShape.size());
operatorDesc.WindowSizes = m_blockShape.data();
operatorDesc.Dilations = m_dilations.data();
operatorDesc.StartPadding = m_pads.data();
operatorDesc.EndPadding = m_pads.data();
operatorDesc.Strides = m_strides.data();

DML_OPERATOR_DESC opDesc = { DML_OPERATOR_FOLD, &operatorDesc };
SetDmlOperatorDesc(opDesc, kernelCreationContext);
}
};

DML_OP_DEFINE_CREATION_FUNCTION(Col2Im, DmlOperatorCol2Im);

} // namespace Dml
@@ -505,6 +505,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(MatMulInteger);
DML_OP_EXTERN_CREATION_FUNCTION(MatMulIntegerToFloat);
DML_OP_EXTERN_CREATION_FUNCTION(ConvInteger);
DML_OP_EXTERN_CREATION_FUNCTION(Trilu);
DML_OP_EXTERN_CREATION_FUNCTION(Col2Im);
DML_OP_EXTERN_CREATION_FUNCTION(Shape);
DML_OP_EXTERN_CREATION_FUNCTION(Size);
DML_OP_EXTERN_CREATION_FUNCTION(QAttention);
@@ -789,6 +790,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO( 16, ScatterND, typeNameListScatterGatherND, supportedTypeListScatterGatherND, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryScatter)},
{REG_INFO( 9, EyeLike, typeNameListEyeLike, supportedTypeListEyeLike, DmlGraphSupport::Supported)},
{REG_INFO( 14, Trilu, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))},
{REG_INFO( 18, Col2Im, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2))},

// Data reorganization that merely changes the dimensions while keeping the data identical.
{REG_INFO_COPY( 7, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)},
@@ -257,14 +257,15 @@ namespace OperatorHelper
}
}

void DowncastDimensions(gsl::span<const int64_t> inputDimensions, std::vector<DimensionType>& outputDimensions)
template <typename T>
void DowncastDimensions(gsl::span<T> inputDimensions, std::vector<DimensionType>& outputDimensions)
{
outputDimensions.reserve(inputDimensions.size());
outputDimensions.clear();

for (int64_t dim : inputDimensions)
for (T dim : inputDimensions)
{
outputDimensions.push_back(gsl::narrow_cast<uint32_t>(std::clamp<int64_t>(dim, INT32_MIN, INT32_MAX)));
outputDimensions.push_back(gsl::narrow_cast<DimensionType>(std::clamp<T>(dim, INT32_MIN, INT32_MAX)));
}
}

@@ -1870,6 +1871,64 @@ namespace OperatorHelper
return { std::move(outputShape) };
}

void Col2ImHelper::Initialize(
const IKernelInformationAdapter& kernelInformation,
const IShapeInformationAdapter& shapeInformation)
{
std::vector<int> shapeData;
ReadCpuLocalTensorIntoInt32(kernelInformation.GetConstantInputTensor(1), /*out*/ shapeData);
m_imageShape.resize(shapeData.size());
DowncastDimensions(gsl::span(shapeData), /*out*/ m_imageShape);
ReadCpuLocalTensorIntoInt32(kernelInformation.GetConstantInputTensor(2), /*out*/ shapeData);
m_blockShape.resize(shapeData.size());
DowncastDimensions(gsl::span(shapeData), /*out*/ m_blockShape);

const uint32_t dimCount = gsl::narrow_cast<uint32_t>(m_blockShape.size());
m_dilations = {dimCount, 1};
m_pads = {dimCount * 2, 0};
m_strides = {dimCount, 1};

if (kernelInformation.HasAttribute(AttrName::Dilations, MLOperatorAttributeType::IntArray))
{
shapeData = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Dilations);
m_dilations.resize(shapeData.size());
DowncastDimensions(gsl::span(shapeData), /*out*/ m_dilations);
ML_CHECK_VALID_ARGUMENT(m_dilations.size() == dimCount);
}

if (kernelInformation.HasAttribute(AttrName::Pads, MLOperatorAttributeType::IntArray))
{
shapeData = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Pads);
m_pads.resize(shapeData.size());
DowncastDimensions(gsl::span(shapeData), /*out*/ m_pads);
ML_CHECK_VALID_ARGUMENT(m_pads.size() == dimCount * 2);
}

if (kernelInformation.HasAttribute(AttrName::Strides, MLOperatorAttributeType::IntArray))
{
shapeData = kernelInformation.GetAttributes().GetOptionalAttributeVectorInt32(AttrName::Strides);
m_strides.resize(shapeData.size());
DowncastDimensions(gsl::span(shapeData), /*out*/ m_strides);
ML_CHECK_VALID_ARGUMENT(m_strides.size() == dimCount);
}

m_inputShape = shapeInformation.GetInputTensorShape(0);

auto blockShapeProduct = ComputeElementCountFromDimensions(m_blockShape);
m_outputShape.resize(2 + m_imageShape.size());
m_outputShape[0] = m_inputShape[0]; // N
m_outputShape[1] = m_inputShape[1] / blockShapeProduct; // C
for (int i = 2; i < m_outputShape.size(); i++)
{
m_outputShape[i] = m_imageShape[i - 2];
};
}

std::vector<EdgeShapes> Col2ImHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
return { EdgeShapes(m_outputShape) };
}

void ConcatHelperBase::Initialize(
const MLOperatorAttributes& operatorAttributes,
gsl::span<const DimensionType> inputDimensions
@@ -2020,7 +2079,7 @@ namespace OperatorHelper
}
return outputShapes;
}

std::vector<EdgeShapes> QLinearAveragePoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto inputShape = shapeInfo.GetInputTensorShape(0);
@@ -2050,7 +2109,7 @@ namespace OperatorHelper
}
return outputShapes;
}

std::vector<EdgeShapes> RoiPoolingHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
auto roiShape = shapeInfo.GetInputTensorShape(InputTensors::ROIS);
@@ -2113,7 +2172,7 @@ namespace OperatorHelper
{
std::vector<int64_t> outputDimensions64bit = shapeInfo.GetAttributeVector<int64_t>(AttrName::OutputShape);
ML_CHECK_VALID_ARGUMENT(outputDimensions64bit.size() == m_inputShape.size(), "Input dimensions and output_shape must have same rank.");
DowncastDimensions(outputDimensions64bit, /*out*/ outputDimensions);
DowncastDimensions(gsl::span(outputDimensions64bit), /*out*/ outputDimensions);
}
else
{
@@ -1213,6 +1213,34 @@ class SqueezeHelper
std::vector<int> m_axes;
};

class Col2ImHelper
{
public:
void Initialize(
const IKernelInformationAdapter& kernelInformation,
const IShapeInformationAdapter& shapeInformation);

// Info_t is used to obtain attributes which will be used for calculating the output shape later.
// Shape_t is used to obtain input shape which will be used for adjusting attribute value.
template <typename Info_t, typename Shape_t>
Col2ImHelper(const Info_t& info, const Shape_t& shape)
{
Initialize(KernelInformationAdapter(info), ShapeInformationAdapter(shape));
}

std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;

protected:
std::vector<uint32_t> m_dilations;
std::vector<uint32_t> m_pads;
std::vector<uint32_t> m_strides;
std::vector<uint32_t> m_imageShape;
std::vector<uint32_t> m_blockShape;
std::vector<uint32_t> m_inputShape;
std::vector<uint32_t> m_outputShape;
};


class UnsqueezeHelper
{
public:
@@ -1595,6 +1623,7 @@ using ShapeInferenceHelper_Unsqueeze11 = VersionedOpsetHelper<UnsqueezeHelper, 1
using ShapeInferenceHelper_Unsqueeze13 = VersionedOpsetHelper<UnsqueezeHelper, 13>;
using ShapeInferenceHelper_EyeLike = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Trilu = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Col2Im = Col2ImHelper;

using ShapeInferenceHelper_Expand = ExpandHelper;
using ShapeInferenceHelper_Reshape7 = ReshapeHelper;
@@ -407,6 +407,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Pad = 18;
static const int sc_sinceVer_Split = 18;
static const int sc_sinceVer_LpPool = 18;
static const int sc_sinceVer_Col2Im = 18;
}

namespace OnnxOperatorSet19