Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cherry-pick b8f373b0aee086a36ea357bc2ffb2944246be15a #17894

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x64/packages.config
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="python" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.12.0" targetFramework="native" />
<package id="Microsoft.AI.DirectML.Preview" version="1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion .pipelines/nuget_config/x86/packages.config
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<?xml version="1.0" encoding="utf-8"?>
<packages>
<package id="pythonx86" version="3.9.7" targetFramework="native" />
<package id="Microsoft.AI.DirectML" version="1.12.0" targetFramework="native" />
<package id="Microsoft.AI.DirectML.Preview" version="1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5" targetFramework="native" />
<package id="Microsoft.Windows.CppWinRT" version="2.0.201201.7" targetFramework="native" />
</packages>
2 changes: 1 addition & 1 deletion cmake/external/dml.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ if (NOT onnxruntime_USE_CUSTOM_DIRECTML)
set(NUGET_CONFIG ${PROJECT_SOURCE_DIR}/../NuGet.config)
set(PACKAGES_CONFIG ${PROJECT_SOURCE_DIR}/../packages.config)
get_filename_component(PACKAGES_DIR ${CMAKE_CURRENT_BINARY_DIR}/../packages ABSOLUTE)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.1.12.0)
set(DML_PACKAGE_DIR ${PACKAGES_DIR}/Microsoft.AI.DirectML.Preview.1.13.0-deveb7a0e89e82dcf90ae58761b35fb3aebc2275ef5)

# Restore nuget packages, which will pull down the DirectML redist package.
add_custom_command(
Expand Down
4 changes: 2 additions & 2 deletions cmake/onnxruntime_providers.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1314,13 +1314,13 @@ if (onnxruntime_USE_DML)
if (GDK_PLATFORM STREQUAL Scarlett)
target_link_libraries(onnxruntime_providers_dml PRIVATE ${gdk_dx_libs})
else()
target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib)
target_link_libraries(onnxruntime_providers_dml PRIVATE dxguid.lib d3d12.lib dxgi.lib dxcore.lib)
endif()

target_link_libraries(onnxruntime_providers_dml PRIVATE delayimp.lib)

if (NOT GDK_PLATFORM)
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /ignore:4199")
set(onnxruntime_DELAYLOAD_FLAGS "${onnxruntime_DELAYLOAD_FLAGS} /DELAYLOAD:DirectML.dll /DELAYLOAD:d3d12.dll /DELAYLOAD:dxgi.dll /DELAYLOAD:api-ms-win-core-com-l1-1-0.dll /DELAYLOAD:shlwapi.dll /DELAYLOAD:oleaut32.dll /DELAYLOAD:ext-ms-win-dxcore-l1-*.dll /ignore:4199")
endif()

target_compile_definitions(onnxruntime_providers_dml
Expand Down
32 changes: 32 additions & 0 deletions include/onnxruntime/core/providers/dml/dml_provider_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,31 @@ typedef struct IDMLDevice IDMLDevice;
extern "C" {
#endif

// Caller's power/performance preference, consulted when the DML execution
// provider selects a device (see OrtDmlDeviceOptions below).
enum OrtDmlPerformancePreference {
    Default = 0,          // no preference; system default selection
    HighPerformance = 1,  // favor the highest-performance device
    MinimumPower = 2      // favor the lowest-power device
};

// Bitmask restricting which classes of hardware the DML execution provider may
// select. Values combine with the bitwise operators defined below.
// NOTE(review): the fixed underlying type (`: uint32_t`) and the inline operator
// overloads are C++-only; confirm this region is excluded for C consumers of
// this header (it sits near the extern "C" block above).
enum OrtDmlDeviceFilter : uint32_t {
  Any = 0xffffffff,  // no restriction: any supported device class
  Gpu = 1 << 0,      // graphics processing units
  Npu = 1 << 1,      // neural processing units / ML accelerators
};

// Type-safe bitwise operators so filters compose without casts at call sites.
inline OrtDmlDeviceFilter operator~(OrtDmlDeviceFilter a) { return static_cast<OrtDmlDeviceFilter>(~static_cast<uint32_t>(a)); }
inline OrtDmlDeviceFilter operator|(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return static_cast<OrtDmlDeviceFilter>(static_cast<uint32_t>(a) | static_cast<uint32_t>(b)); }
inline OrtDmlDeviceFilter operator&(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return static_cast<OrtDmlDeviceFilter>(static_cast<uint32_t>(a) & static_cast<uint32_t>(b)); }
inline OrtDmlDeviceFilter operator^(OrtDmlDeviceFilter a, OrtDmlDeviceFilter b) { return static_cast<OrtDmlDeviceFilter>(static_cast<uint32_t>(a) ^ static_cast<uint32_t>(b)); }
// The compound assignments previously mutated the enum through an (int&) cast,
// which violates strict aliasing; assigning through the enum itself is
// well-defined and produces identical results.
inline OrtDmlDeviceFilter& operator|=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { a = a | b; return a; }
inline OrtDmlDeviceFilter& operator&=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { a = a & b; return a; }
inline OrtDmlDeviceFilter& operator^=(OrtDmlDeviceFilter& a, OrtDmlDeviceFilter b) { a = a ^ b; return a; }

// Device-selection options consumed by SessionOptionsAppendExecutionProvider_DML2
// when creating the DirectML execution provider.
struct OrtDmlDeviceOptions {
    OrtDmlPerformancePreference Preference;  // power/performance preference
    OrtDmlDeviceFilter Filter;               // bitmask of eligible device classes
};

/**
* [[deprecated]]
* This export is deprecated.
Expand Down Expand Up @@ -99,6 +124,13 @@ struct OrtDmlApi {
* This API gets the D3D12 resource when an OrtValue has been allocated by the DML EP.
*/
ORT_API2_STATUS(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* provider, _In_ void* dml_resource, _Out_ ID3D12Resource** d3d_resource);

/**
* SessionOptionsAppendExecutionProvider_DML2
* Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
* (high performance, minimum power, or default) and a device filter (Any, GPU, or NPU).
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
};

#ifdef __cplusplus
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ namespace Windows::AI::MachineLearning::Adapter
{
uint32_t nodeCount;
std::vector<std::unique_ptr<AbstractOperatorDesc>> nodesAsOperatorDesc;

// TODO: Remove this
std::vector<Microsoft::WRL::ComPtr<IDMLOperator>> nodesAsIDMLOperator;

std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,38 +3,6 @@
#pragma once


// Descriptor for the preview MatMulIntegerToFloat DML operator, declared locally
// until the operator ships in the ingested DirectML header (see the
// "TODO: Change as new header is ingested" remapping sites elsewhere in this PR).
struct DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC
{
    const DML_TENSOR_DESC* ATensor;                       // input A
    const DML_TENSOR_DESC* AScaleTensor;                  // scale tensor for A
    _Maybenull_ const DML_TENSOR_DESC* AZeroPointTensor;  // optional zero point for A
    const DML_TENSOR_DESC* BTensor;                       // input B
    const DML_TENSOR_DESC* BScaleTensor;                  // scale tensor for B
    _Maybenull_ const DML_TENSOR_DESC* BZeroPointTensor;  // optional zero point for B
    _Maybenull_ const DML_TENSOR_DESC* BiasTensor;        // optional bias
    const DML_TENSOR_DESC* OutputTensor;                  // result tensor
};
// Preview operator-type value (outside the public DML_OPERATOR_TYPE range —
// presumably the 0x80000000 bit marks experimental IDs; confirm against DML docs).
const int DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT = 0x80000011;

// Descriptor for the preview QLinearAveragePooling DML operator, declared locally
// until the operator ships in the ingested DirectML header.
struct DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC
{
    const DML_TENSOR_DESC* InputTensor;
    const DML_TENSOR_DESC* InputScaleTensor;
    _Maybenull_ const DML_TENSOR_DESC* InputZeroPointTensor;   // optional
    const DML_TENSOR_DESC* OutputScaleTensor;
    _Maybenull_ const DML_TENSOR_DESC* OutputZeroPointTensor;  // optional
    const DML_TENSOR_DESC* OutputTensor;
    UINT DimensionCount;                                       // spatial dimension count; sizes the arrays below
    _Field_size_(DimensionCount) const UINT* Strides;
    _Field_size_(DimensionCount) const UINT* WindowSize;
    _Field_size_(DimensionCount) const UINT* StartPadding;
    _Field_size_(DimensionCount) const UINT* EndPadding;
    _Field_size_(DimensionCount) const UINT* Dilations;
    BOOL IncludePadding;                                       // count padded elements in the average (ONNX count_include_pad)
};
// Preview operator-type value (outside the public DML_OPERATOR_TYPE range).
const int DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING = 0x8000000B;


namespace ApiTraits
{
template <typename T>
Expand Down Expand Up @@ -2711,11 +2679,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value)
case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2";
case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1";
case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1";
#pragma warning(push)
#pragma warning(disable: 4063)
case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT";
case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION";
#pragma warning(pop)
default:
assert(false);
return "<unknown>";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,11 @@ namespace Dml::GraphDescBuilder
// This is a highly inefficient approach to generating constant nodes. It duplicates constant data
// across the graph input as well as every consumer's unique constant node. However it is currently
// only used for small inputs.
uint32_t c_maxConstNodeDataSize = 64;

// TODO: Rework this to create DML constant nodes with the minimum data size actually used by consuming
// nodes. This would allow this size to be reduced while still handling the case of 1D scale and zero
// point values that have been de-duplicated by conversion to scalars in kernels.
uint32_t c_maxConstNodeDataSize = 1024 * 1024;

ComPtr<OnnxTensorWrapper> constantInput = constantCpuGraphInputGetter(arg->Name());

Expand Down Expand Up @@ -405,6 +409,15 @@ namespace Dml::GraphDescBuilder
auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i];

DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator);

// TODO: Change as new header is ingested
if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
dmlDesc.Type = (DML_OPERATOR_TYPE) 169;

// TODO: Change as new header is ingested
if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
dmlDesc.Type = (DML_OPERATOR_TYPE) 170;

ComPtr<IDMLOperator> op;
ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op)));
allocator.Reset();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

namespace Dml
{

/*static*/ const uint32_t DmlOperator::zeroArray[8] = {};

DmlOperator::DmlOperator(const MLOperatorKernelCreationContext& kernelInfo)
{
ML_CHECK_HRESULT(kernelInfo.GetExecutionInterface().As(&m_executionProvider));
Expand Down Expand Up @@ -38,10 +41,6 @@ namespace Dml
}
}

// Create and compile the operator.
ComPtr<IDMLOperator> dmlOperator;
ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDesc, IID_PPV_ARGS(&dmlOperator)));

ComPtr<IMLOperatorKernelCreationContextPrivate> contextPrivate;
ORT_THROW_IF_FAILED(kernelInfo.GetInterface()->QueryInterface(contextPrivate.GetAddressOf()));

Expand Down Expand Up @@ -87,6 +86,20 @@ namespace Dml
}
else
{
auto operatorDescCopy = operatorDesc;

// TODO: Change as new header is ingested
if (operatorDescCopy.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
operatorDescCopy.Type = (DML_OPERATOR_TYPE) 169;

// TODO: Change as new header is ingested
if (operatorDescCopy.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
operatorDescCopy.Type = (DML_OPERATOR_TYPE) 170;

// Create and compile the operator.
ComPtr<IDMLOperator> dmlOperator;
ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&operatorDescCopy, IID_PPV_ARGS(&dmlOperator)));

DML_EXECUTION_FLAGS executionFlags = GetExecutionFlags();
ORT_THROW_IF_FAILED(m_dmlDevice->CompileOperator(dmlOperator.Get(), executionFlags, IID_PPV_ARGS(&m_compiledOperator)));

Expand Down Expand Up @@ -792,8 +805,18 @@ namespace Dml
graphDesc.NodeCount = operatorGraphDesc.nodeCount;
for (size_t i = 0; i < graphDesc.NodeCount; ++i)
{
DML_OPERATOR_DESC opDesc = *operatorGraphDesc.nodesAsOpDesc[i];

// TODO: Change as new header is ingested
if (opDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING)
opDesc.Type = (DML_OPERATOR_TYPE) 169;

// TODO: Change as new header is ingested
if (opDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT)
opDesc.Type = (DML_OPERATOR_TYPE) 170;

// Create the operator.
ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodesAsOpDesc[i], IID_PPV_ARGS(&dmlOperators[i])));
ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(&opDesc, IID_PPV_ARGS(&dmlOperators[i])));
dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{dmlOperators[i].Get()};
dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]};
}
Expand Down Expand Up @@ -824,4 +847,80 @@ namespace Dml
graphDesc.IntermediateEdges = dmlIntermediateEdges.data();
}

// If the kernel input at kernelInputIndex is a constant tensor whose elements are
// all byte-identical, rewrite the given DML buffer tensor desc in place so DML
// treats it as a broadcast scalar: strides are zeroed (every logical coordinate
// maps to element 0) and the total size shrinks to a single element rounded up to
// a multiple of 4 bytes (presumably DML's minimum buffer size granularity — confirm).
// No-op when the tensor is absent, the input is not constant, the element type is
// unsupported, or the values are not uniform.
/*static*/ void DmlOperator::TryConvertTensorToBroadcastScalar(
    const MLOperatorKernelCreationContext& kernelInfo,
    const DML_TENSOR_DESC* tensor,
    uint32_t kernelInputIndex)
{
    // Optional tensors may be absent entirely.
    if (!tensor)
    {
        return;
    }

    // Only constant (CPU-visible) inputs can be inspected; dynamic inputs are left alone.
    auto constExpTensor = kernelInfo.TryGetConstantInputTensor(kernelInputIndex);
    if (!constExpTensor)
    {
        return;
    }

    // A 0- or 1-element tensor is already effectively scalar; nothing to collapse.
    uint32_t totalKernelInputElementCount = constExpTensor->GetTotalElementCount();
    if (totalKernelInputElementCount <= 1)
    {
        return;
    }

    // Map the data type to its element size in bytes; bail out on any type this
    // routine does not handle.
    uint32_t elementSize = 0;

    switch (constExpTensor->GetTensorDataType())
    {
    case MLOperatorTensorDataType::UInt8:
    case MLOperatorTensorDataType::Int8:
        elementSize = 1;
        break;

    case MLOperatorTensorDataType::Float16:
    case MLOperatorTensorDataType::UInt16:
    case MLOperatorTensorDataType::Int16:
        elementSize = 2;
        break;

    case MLOperatorTensorDataType::/*Float32*/Float:
    case MLOperatorTensorDataType::UInt32:
    case MLOperatorTensorDataType::Int32:
        elementSize = 4;
        break;

    case MLOperatorTensorDataType::/*Float64*/Double:
    case MLOperatorTensorDataType::UInt64:
    case MLOperatorTensorDataType::Int64:
        elementSize = 8;
        break;

    default:
        return;
    }

    const std::uint8_t* byteData = static_cast<const std::uint8_t*>(constExpTensor->GetByteData());

    // The desc is expected to describe a buffer tensor; it is mutated in place
    // below, hence the const_cast on a desc this class does not formally own.
    assert(tensor->Type == DML_TENSOR_TYPE_BUFFER);
    auto *bufferTensorDesc = const_cast<DML_BUFFER_TENSOR_DESC*>(static_cast<const DML_BUFFER_TENSOR_DESC*>(tensor->Desc));

    // Verify every element is byte-identical to element 0; any mismatch means
    // the tensor is not a uniform constant and must be left untouched.
    for (size_t i = 1; i < totalKernelInputElementCount; ++i)
    {
        if (memcmp(byteData, byteData + i * elementSize, elementSize))
        {
            return;
        }
    }

    // zeroArray supplies the all-zero strides; it must cover every dimension.
    if (bufferTensorDesc->DimensionCount > sizeof(zeroArray) / sizeof(zeroArray[0]))
    {
        assert(false);
        return;
    }

    // Zero strides make DML read element 0 for every coordinate; the backing
    // buffer then only needs one element, rounded up to a 4-byte multiple.
    bufferTensorDesc->Strides = zeroArray;
    bufferTensorDesc->TotalTensorSizeInBytes = (elementSize + 3) & ~3;
}

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,11 @@ namespace Dml
uint32_t minDimensionCount = NchwDimensionCount
) const;

static void TryConvertTensorToBroadcastScalar(
const MLOperatorKernelCreationContext& kernelInfo,
const DML_TENSOR_DESC* tensor,
uint32_t kernelInputIndex);

private:
// For each input or output of the DML kernel, the corresponding input or output of the original
// kernel. Entries for unused DML inputs are nullopt.
Expand All @@ -164,6 +169,7 @@ namespace Dml
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlOutputEdges,
_Inout_ std::vector<DML_GRAPH_EDGE_DESC>& dmlIntermediateEdges);

static const uint32_t zeroArray[8];
};

} // namespace Dml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ class DmlOperatorBatchNormalization15 : public DmlOperator, BatchNormalizationHe
std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();

// TODO: Port this to a graph description to enable DML graph optimization

dml::Graph graph(m_dmlDevice.Get());
dml::TensorDesc inputTensorDesc = inputDescs[OnnxInputIndex::X];
dml::TensorDesc scaleTensorDesc = inputDescs[OnnxInputIndex::Scale];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,9 @@ class DmlOperatorElementwiseQLinear : public DmlOperator
opDesc.ScaleTensor = &inputDescs[1];
opDesc.ZeroPointTensor = &inputDescs[2];
opDesc.OutputTensor = &outputDescs[0];

TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1);
TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2);

SetDmlOperatorDesc({ApiTraits::OperatorDescTraits<TOperatorDesc>::Type, &opDesc}, kernelInfo);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,15 @@ class DmlOperatorQLinearAdd : public DmlOperator
AddDesc.OutputScaleTensor = &inputDescs[IN_C_SCALE];
AddDesc.OutputZeroPointTensor = inputDescs[IN_C_ZERO_POINT].Desc != nullptr ? &inputDescs[IN_C_ZERO_POINT] : nullptr;
AddDesc.OutputTensor = &outputDescs[0];

TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AScaleTensor, IN_A_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.AZeroPointTensor, IN_A_ZERO_POINT);

TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BScaleTensor, IN_B_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.BZeroPointTensor, IN_B_ZERO_POINT);

TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputScaleTensor, IN_C_SCALE);
TryConvertTensorToBroadcastScalar(kernelInfo, AddDesc.OutputZeroPointTensor, IN_C_ZERO_POINT);

DML_OPERATOR_DESC opDesc = { DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD, &AddDesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe
qLinearAvgPooldesc.InputTensor = &inputDescs[OrtInputTensors::ortInput];
qLinearAvgPooldesc.InputScaleTensor = &inputDescs[OrtInputTensors::ortInputScale];
qLinearAvgPooldesc.InputZeroPointTensor = &inputDescs[OrtInputTensors::ortInputZeroPoint];
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];;
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];;
qLinearAvgPooldesc.OutputScaleTensor = &inputDescs[OrtInputTensors::ortOutputScale];
qLinearAvgPooldesc.OutputZeroPointTensor = &inputDescs[OrtInputTensors::ortOutputZeroPoint];
qLinearAvgPooldesc.OutputTensor = &outputDescs[0];
qLinearAvgPooldesc.DimensionCount = m_kernel.spatialDimensionCount;
qLinearAvgPooldesc.WindowSize = m_kernel.windowSize;
Expand All @@ -129,6 +129,12 @@ class DmlOperatorQLinearAveragePooling : public DmlOperator, public PoolingHelpe
qLinearAvgPooldesc.Dilations = m_kernel.dilations;
qLinearAvgPooldesc.IncludePadding = kernelInfo.GetOptionalAttribute<bool>(AttrName::CountIncludePad, false);

TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputScaleTensor, OrtInputTensors::ortInputScale);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.InputZeroPointTensor, OrtInputTensors::ortInputZeroPoint);

TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputScaleTensor, OrtInputTensors::ortOutputScale);
TryConvertTensorToBroadcastScalar(kernelInfo, qLinearAvgPooldesc.OutputZeroPointTensor, OrtInputTensors::ortOutputZeroPoint);

DML_OPERATOR_DESC opDesc = { (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, &qLinearAvgPooldesc };
SetDmlOperatorDesc(opDesc, kernelInfo);
}
Expand Down
Loading
Loading