diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 0ce4874a620bb..0668cb4a93a62 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -1294,6 +1294,7 @@ Do not modify directly.*
|FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(int8), tensor(uint8)<br> **T3** = tensor(float), tensor(float16)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br> **T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp
index eb068087de4ad..353f698bb6f2c 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp
@@ -330,7 +330,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
IMLOperatorKernelFactory* operatorKernelFactory,
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer) const noexcept
{
- return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, nullptr, false, false, false);
+ return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, nullptr, false, false);
}
HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
@@ -339,11 +339,12 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
_In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
bool isInternalOperator,
- bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
- uint32_t constantCpuInputCount) const noexcept
+ uint32_t constantCpuInputCount,
+ _In_reads_(aliasCount) const std::pair<uint32_t, uint32_t>* aliases,
+ uint32_t aliasCount) const noexcept
{
ORT_TRY
{
@@ -417,9 +418,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
builder.InputMemoryType(::OrtMemType::OrtMemTypeCPUInput, inputIndex);
}
- if (canAliasFirstInput)
+ for (uint32_t i = 0; i < aliasCount; ++i)
{
- builder.Alias(0, 0);
+ builder.Alias(aliases[i].first, aliases[i].second);
}
// Set type constraints
@@ -553,7 +554,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
else
{
// Currently unsupported for external operators
- if (canAliasFirstInput ||
+ if (aliasCount > 0 ||
supportsGraph ||
requiredInputCountForGraph)
{
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h
index d6b1448b559b1..eb84b4f822e92 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h
@@ -15,7 +15,7 @@ namespace WRL
}
namespace Windows::AI::MachineLearning::Adapter
-{
+{
using namespace Microsoft::WRL;
@@ -38,11 +38,12 @@ class AbiCustomRegistry : public WRL::Base<
 _In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
 _In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
 bool isInternalOperator,
- bool canAliasFirstInput,
 bool supportsGraph,
 const uint32_t* requiredInputCountForGraph = nullptr,
 _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
- uint32_t constantCpuInputCount = 0) const noexcept override;
+ uint32_t constantCpuInputCount = 0,
+ _In_reads_(aliasCount) const std::pair<uint32_t, uint32_t>* aliases = nullptr,
+ uint32_t aliasCount = 0) const noexcept override;
HRESULT STDMETHODCALLTYPE RegisterOperatorKernel(
@@ -56,7 +57,7 @@ class AbiCustomRegistry : public WRL::Base<
std::shared_ptr<onnxruntime::CustomRegistry> m_kernelRegistry;
@@ -56,7 +57,7 @@ class AbiCustomRegistry : public WRL::Base m_kernelRegistry;
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
index 1c82c40a88bb6..2bd9377e4c2fa 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp
@@ -107,6 +107,8 @@ namespace Dml::GraphDescBuilder
// Mapping from the old indices to the new indices that have been shifted after removing earlier nodes
std::vector<uint32_t> shiftedIndicesMapping(graphNodes.size());
+ std::unordered_set<uint32_t> nodesRemoved;
+
uint32_t shift = 0;
for (uint32_t nodeIndex = 0; nodeIndex < graphNodes.size(); ++nodeIndex)
{
@@ -114,6 +116,7 @@ namespace Dml::GraphDescBuilder
{
// The node is not connected, so we simply increase the shift value (the node will be overwritten by the following nodes)
++shift;
+ nodesRemoved.insert(nodeIndex);
}
else
{
@@ -125,6 +128,13 @@ namespace Dml::GraphDescBuilder
graphNodes.resize(graphNodes.size() - shift);
+ // Remove the inputs that are not connected to anything anymore
+ auto inputEdgesEndIter = std::remove_if(graphInputEdges.begin(), graphInputEdges.end(), [&nodesRemoved](const auto& inputEdge) {
+ return nodesRemoved.count(inputEdge.ToNodeIndex);
+ });
+
+ graphInputEdges.erase(inputEdgesEndIter, graphInputEdges.end());
+
// Adjust the node indices in the input edges
std::unordered_set<uint32_t> usedInputEdgeIndex;
for (auto& inputEdge : graphInputEdges)
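The pruning added above is the classic erase-remove idiom: once unconnected nodes are dropped, any graph input edge that targeted one of them must be removed as well, or its ToNodeIndex would later be remapped onto an unrelated surviving node. A self-contained sketch of the same step, with the edge type simplified:

#include <algorithm>
#include <cstdint>
#include <unordered_set>
#include <vector>

struct InputEdge { uint32_t ToNodeIndex; };

// Drops every input edge whose destination node has been removed from the graph.
void PruneDanglingInputEdges(std::vector<InputEdge>& graphInputEdges,
                             const std::unordered_set<uint32_t>& nodesRemoved)
{
    auto endIter = std::remove_if(graphInputEdges.begin(), graphInputEdges.end(),
        [&](const InputEdge& edge) { return nodesRemoved.count(edge.ToNodeIndex) != 0; });
    graphInputEdges.erase(endIter, graphInputEdges.end());
}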
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h
index ddd6d56128461..a99d8bf655fec 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h
@@ -1147,7 +1147,6 @@ class GpuDFTOperatorFactory : public WRL::Base<IMLOperatorKernelFactory>
shareInferrer.Get(),
nullptr,
false, // isInternalOperator
- false, // alias
false, // supportsGraph
nullptr,
requiredConstantCpuInputs.data(),
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h
index 4f5da9dd05491..5ba936ddf3976 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h
@@ -844,7 +844,6 @@ class DmlGridSampleOperatorFactory : public WRL::Base<IMLOperatorKernelFactory>
shareInferrer.Get(),
nullptr,
false, // isInternalOperator
- false, // alias
false, // supportsGraph
nullptr,
nullptr,
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp
new file mode 100644
index 0000000000000..bf5a182f9662c
--- /dev/null
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp
@@ -0,0 +1,318 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "precomp.h"
+
+namespace Dml
+{
+class DmlOperatorGroupQueryAttention : public DmlOperator, public GroupQueryAttentionHelper
+{
+public:
+ DmlOperatorGroupQueryAttention(const MLOperatorKernelCreationContext& kernelCreationContext)
+ : DmlOperator(kernelCreationContext),
+ GroupQueryAttentionHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription())
+ {
+ enum InputIndex : uint32_t
+ {
+ queryIndex,
+ keyIndex,
+ valueIndex,
+ pastKeyIndex,
+ pastValueIndex,
+ seqLensIndex,
+ inputCount,
+ };
+
+ enum OutputIndex : uint32_t
+ {
+ outputIndex,
+ outputPresentKeyIndex,
+ outputPresentValueIndex,
+ outputCount,
+ };
+
+ ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 1);
+ ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1);
+
+ std::vector<std::optional<uint32_t>> inputIndices(inputCount);
+ inputIndices[queryIndex] = queryIndex;
+ inputIndices[keyIndex] = keyIndex;
+ inputIndices[valueIndex] = valueIndex;
+
+ const uint32_t sequenceLength = kernelCreationContext.GetInputTensorShape(queryIndex)[1];
+
+ if (kernelCreationContext.GetInputTensorShape(queryIndex)[1] == 1)
+ {
+ inputIndices[seqLensIndex] = seqLensIndex;
+ }
+
+ std::vector<std::optional<uint32_t>> outputIndices = {
+ outputIndex,
+ outputPresentKeyIndex,
+ outputPresentValueIndex,
+ };
+ DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices, std::nullopt, std::nullopt, 1);
+
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[queryIndex].GetDimensionCount() == 3);
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[keyIndex].GetDimensionCount() == 3);
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[valueIndex].GetDimensionCount() == 3);
+
+ const uint32_t queryNumHeads = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::NumHeads));
+ const uint32_t kvNumHeads = gsl::narrow_cast<uint32_t>(kernelCreationContext.GetAttribute<int64_t>(AttrName::KvNumHeads));
+
+ auto querySizes = m_inputTensorDescs[queryIndex].GetSizes();
+ auto keySizes = m_inputTensorDescs[keyIndex].GetSizes();
+ auto valueSizes = m_inputTensorDescs[valueIndex].GetSizes();
+
+ const uint32_t batchSize = querySizes[0];
+ const uint32_t queryHiddenSize = querySizes[2];
+
+ const uint32_t kvSequenceLength = keySizes[1];
+ const uint32_t kvHiddenSize = keySizes[2];
+
+ const uint32_t queryHeadSize = queryHiddenSize / queryNumHeads;
+ const uint32_t kvHeadSize = kvHiddenSize / kvNumHeads;
+ const uint32_t totalSequenceLength = GetTotalSequenceLength();
+
+ // Validate Query dimensions
+ ML_CHECK_VALID_ARGUMENT(querySizes[0] == batchSize);
+ ML_CHECK_VALID_ARGUMENT(querySizes[1] == sequenceLength);
+ ML_CHECK_VALID_ARGUMENT(querySizes[2] == queryHiddenSize);
+
+ // Validate Key dimensions
+ ML_CHECK_VALID_ARGUMENT(keySizes[0] == batchSize);
+ ML_CHECK_VALID_ARGUMENT(keySizes[1] == kvSequenceLength);
+ ML_CHECK_VALID_ARGUMENT(keySizes[2] == kvHiddenSize);
+
+ // Validate Value dimensions
+ ML_CHECK_VALID_ARGUMENT(valueSizes[0] == batchSize);
+ ML_CHECK_VALID_ARGUMENT(valueSizes[1] == kvSequenceLength);
+ ML_CHECK_VALID_ARGUMENT(valueSizes[2] == kvHiddenSize);
+
+ if (sequenceLength == 1)
+ {
+ // Validate PastSequenceLengths dimensions
+ if (m_inputTensorDescs[seqLensIndex].GetDimensionCount() == 1)
+ {
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[seqLensIndex].GetSizes()[0] == batchSize);
+ }
+ else
+ {
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[seqLensIndex].GetDimensionCount() == 2);
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[seqLensIndex].GetSizes()[0] == batchSize);
+ ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[seqLensIndex].GetSizes()[1] == 1);
+ }
+ }
+
+ const std::array<uint32_t, 1> pastSequenceLengthsShape = {batchSize};
+ auto pastSequenceLengthsDataType = MLOperatorTensorDataType::Int32;
+ TensorDesc pastSequenceLengthsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastSequenceLengthsDataType, pastSequenceLengthsShape);
+ const DML_TENSOR_DESC pastSequenceLengthsDmlTensorDesc = pastSequenceLengthsTensorDesc.GetDmlDesc();
+
+ std::vector<DML_TENSOR_DESC> inputDescs = GetDmlInputDescs();
+ std::vector<DML_TENSOR_DESC> outputDescs = GetDmlOutputDescs();
+
+ // GQA is very sensitive to overflows, so we cast all inputs to fp32 and cast the outputs back to fp16. At the DML level,
+ // those casts will be eliminated and replaced with half precision computation instead, which mimics the CUDA EP behavior
+ // of their flash attention kernel.
+ TensorDesc queryCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[queryIndex].GetSizes());
+ DML_TENSOR_DESC queryCastDmlTensorDesc = queryCastTensorDesc.GetDmlDesc();
+ DML_CAST_OPERATOR_DESC queryCastOpDesc{};
+ queryCastOpDesc.InputTensor = &inputDescs[queryIndex];
+ queryCastOpDesc.OutputTensor = &queryCastDmlTensorDesc;
+ DML_OPERATOR_DESC queryCastDmlDesc = { DML_OPERATOR_CAST, &queryCastOpDesc };
+
+ TensorDesc keyCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[keyIndex].GetSizes());
+ DML_TENSOR_DESC keyCastDmlTensorDesc = keyCastTensorDesc.GetDmlDesc();
+ DML_CAST_OPERATOR_DESC keyCastOpDesc{};
+ keyCastOpDesc.InputTensor = &inputDescs[keyIndex];
+ keyCastOpDesc.OutputTensor = &keyCastDmlTensorDesc;
+ DML_OPERATOR_DESC keyCastDmlDesc = { DML_OPERATOR_CAST, &keyCastOpDesc };
+
+ TensorDesc valueCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[valueIndex].GetSizes());
+ DML_TENSOR_DESC valueCastDmlTensorDesc = valueCastTensorDesc.GetDmlDesc();
+ DML_CAST_OPERATOR_DESC valueCastOpDesc{};
+ valueCastOpDesc.InputTensor = &inputDescs[valueIndex];
+ valueCastOpDesc.OutputTensor = &valueCastDmlTensorDesc;
+ DML_OPERATOR_DESC valueCastDmlDesc = { DML_OPERATOR_CAST, &valueCastOpDesc };
+
+ TensorDesc outputCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_outputTensorDescs[outputIndex].GetSizes());
+ DML_TENSOR_DESC outputCastDmlTensorDesc = outputCastTensorDesc.GetDmlDesc();
+ DML_CAST_OPERATOR_DESC outputCastOpDesc{};
+ outputCastOpDesc.InputTensor = &outputCastDmlTensorDesc;
+ outputCastOpDesc.OutputTensor = &outputDescs[outputIndex];
+ DML_OPERATOR_DESC outputCastDmlDesc = { DML_OPERATOR_CAST, &outputCastOpDesc };
+
+ TensorDesc outputPresentKeyCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_outputTensorDescs[outputPresentKeyIndex].GetSizes());
+ DML_TENSOR_DESC outputPresentKeyCastDmlTensorDesc = outputPresentKeyCastTensorDesc.GetDmlDesc();
+ DML_CAST_OPERATOR_DESC outputPresentKeyCastOpDesc{};
+ outputPresentKeyCastOpDesc.InputTensor = &outputPresentKeyCastDmlTensorDesc;
+ outputPresentKeyCastOpDesc.OutputTensor = &outputDescs[outputPresentKeyIndex];
+ DML_OPERATOR_DESC outputPresentKeyCastDmlDesc = { DML_OPERATOR_CAST, &outputPresentKeyCastOpDesc };
+
+ TensorDesc outputPresentValueCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_outputTensorDescs[outputPresentValueIndex].GetSizes());
+ DML_TENSOR_DESC outputPresentValueCastDmlTensorDesc = outputPresentValueCastTensorDesc.GetDmlDesc();
+ DML_CAST_OPERATOR_DESC outputPresentValueCastOpDesc{};
+ outputPresentValueCastOpDesc.InputTensor = &outputPresentValueCastDmlTensorDesc;
+ outputPresentValueCastOpDesc.OutputTensor = &outputDescs[outputPresentValueIndex];
+ DML_OPERATOR_DESC outputPresentValueCastDmlDesc = { DML_OPERATOR_CAST, &outputPresentValueCastOpDesc };
+
+ const bool isFp16 = m_inputTensorDescs[queryIndex].GetDmlDataType() == DML_TENSOR_DATA_TYPE_FLOAT16;
+
+ DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC mhaDesc = {};
+ mhaDesc.QueryTensor = isFp16 ? &queryCastDmlTensorDesc : &inputDescs[queryIndex];
+ mhaDesc.KeyTensor = isFp16 ? &keyCastDmlTensorDesc : &inputDescs[keyIndex];
+ mhaDesc.ValueTensor = isFp16 ? &valueCastDmlTensorDesc : &inputDescs[valueIndex];
+ mhaDesc.PastSequenceLengthsTensor = &pastSequenceLengthsDmlTensorDesc;
+ mhaDesc.OutputTensor = isFp16 ? &outputCastDmlTensorDesc : &outputDescs[outputIndex];
+ mhaDesc.OutputPresentKeyTensor = isFp16 ? &outputPresentKeyCastDmlTensorDesc : &outputDescs[outputPresentKeyIndex];
+ mhaDesc.OutputPresentValueTensor = isFp16 ? &outputPresentValueCastDmlTensorDesc : &outputDescs[outputPresentValueIndex];
+ mhaDesc.QueryHeadCount = queryNumHeads;
+ mhaDesc.KeyValueHeadCount = kvNumHeads;
+ mhaDesc.Scale = kernelCreationContext.GetOptionalAttribute<float>(AttrName::Scale, gsl::narrow_cast<float>(1.0f / std::sqrt(queryHeadSize)));
+ mhaDesc.MaskFilterValue = -10'000.0f;
+ DML_OPERATOR_DESC mhaDmlDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION1, &mhaDesc };
+
+ DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zeroScalarDesc = {};
+ zeroScalarDesc.OutputTensor = &pastSequenceLengthsDmlTensorDesc;
+ zeroScalarDesc.ValueDataType = pastSequenceLengthsTensorDesc.GetDmlDataType();
+ DML_OPERATOR_DESC zeroScalarDmlDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &zeroScalarDesc };
+
+ std::vector<const DML_OPERATOR_DESC*> opDescs = {
+ &mhaDmlDesc,
+ };
+
+ // Construct the graph
+ std::vector<DML_INPUT_GRAPH_EDGE_DESC> inputEdges;
+ std::vector<DML_INTERMEDIATE_GRAPH_EDGE_DESC> intermediateEdges;
+ std::vector<DML_OUTPUT_GRAPH_EDGE_DESC> outputEdges;
+
+ if (isFp16)
+ {
+ opDescs.push_back(&queryCastDmlDesc);
+ opDescs.push_back(&keyCastDmlDesc);
+ opDescs.push_back(&valueCastDmlDesc);
+ opDescs.push_back(&outputCastDmlDesc);
+ opDescs.push_back(&outputPresentKeyCastDmlDesc);
+ opDescs.push_back(&outputPresentValueCastDmlDesc);
+
+ // Link the query/key/value inputs to the cast nodes
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ DML_INPUT_GRAPH_EDGE_DESC inputToMhaEdge = {};
+ inputToMhaEdge.GraphInputIndex = i;
+ inputToMhaEdge.ToNodeIndex = 1 + i;
+ inputToMhaEdge.ToNodeInputIndex = 0;
+ inputEdges.push_back(inputToMhaEdge);
+ }
+
+ // Link the input cast nodes to MHA
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC castToMhaEdge = {};
+ castToMhaEdge.FromNodeIndex = 1 + i;
+ castToMhaEdge.FromNodeOutputIndex = 0;
+ castToMhaEdge.ToNodeIndex = 0;
+ castToMhaEdge.ToNodeInputIndex = i;
+ intermediateEdges.push_back(castToMhaEdge);
+ }
+ }
+ else
+ {
+ // Link the query/key/value inputs to MHA
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ DML_INPUT_GRAPH_EDGE_DESC inputToMhaEdge = {};
+ inputToMhaEdge.GraphInputIndex = i;
+ inputToMhaEdge.ToNodeIndex = 0;
+ inputToMhaEdge.ToNodeInputIndex = i;
+ inputEdges.push_back(inputToMhaEdge);
+ }
+ }
+
+ constexpr uint32_t dmlPastSequenceLengthsIndex = 11;
+
+ // The GQA offline fusion derives the past sequence length by summing the number of 1's in the mask. This doesn't work
+ // for the first iteration since, obviously, there is no past sequence yet and the mask at that point covers only the
+ // elements of the initial sequence. To work around this, the CUDA implementation of the operator ignores the value of
+ // pastSequenceLengths on the first iteration and acts as if it were 0. This is a fairly dirty hack that should be
+ // polished in the future, but for compatibility with the GQA fusion and the CUDA implementation we do the same thing
+ // here. We DO NOT want to do this within DirectML, since DirectML should be agnostic w.r.t. which iteration of MHA it
+ // is currently executing, and a hack that is likely to change shouldn't be enshrined there. Doing it here is acceptable
+ // because contrib ops, by their nature, can change at any time.
+ if (sequenceLength == 1)
+ {
+ // Link the PastSequenceLengths input to MHA
+ DML_INPUT_GRAPH_EDGE_DESC inputToMhaEdge = {};
+ inputToMhaEdge.GraphInputIndex = seqLensIndex;
+ inputToMhaEdge.ToNodeIndex = 0;
+ inputToMhaEdge.ToNodeInputIndex = dmlPastSequenceLengthsIndex;
+ inputEdges.push_back(inputToMhaEdge);
+ }
+ else
+ {
+ opDescs.push_back(&zeroScalarDmlDesc);
+
+ // Link the zero scalar to MHA
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC zeroScalarToMhaEdge = {};
+ zeroScalarToMhaEdge.FromNodeIndex = gsl::narrow_cast<uint32_t>(opDescs.size() - 1);
+ zeroScalarToMhaEdge.FromNodeOutputIndex = 0;
+ zeroScalarToMhaEdge.ToNodeIndex = 0;
+ zeroScalarToMhaEdge.ToNodeInputIndex = dmlPastSequenceLengthsIndex;
+ intermediateEdges.push_back(zeroScalarToMhaEdge);
+ }
+
+ if (isFp16)
+ {
+ // Output cast nodes start at the 4th index (previously we have the mha, query, key and value nodes)
+ const uint32_t outputCastNodeStart = 4;
+
+ // Link MHA's output to the output cast nodes
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ DML_INTERMEDIATE_GRAPH_EDGE_DESC mhaToCastEdge = {};
+ mhaToCastEdge.FromNodeIndex = 0;
+ mhaToCastEdge.FromNodeOutputIndex = i;
+ mhaToCastEdge.ToNodeIndex = outputCastNodeStart + i;
+ mhaToCastEdge.ToNodeInputIndex = 0;
+ intermediateEdges.push_back(mhaToCastEdge);
+ }
+
+ // Link the output cast nodes to the graph's outputs
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ DML_OUTPUT_GRAPH_EDGE_DESC castToOutputEdge = {};
+ castToOutputEdge.FromNodeIndex = outputCastNodeStart + i;
+ castToOutputEdge.FromNodeOutputIndex = 0;
+ castToOutputEdge.GraphOutputIndex = i;
+ outputEdges.push_back(castToOutputEdge);
+ }
+ }
+ else
+ {
+ // Link MHA's outputs to the graph's outputs
+ for (uint32_t i = 0; i < 3; ++i)
+ {
+ DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {};
+ mhaToOutputEdge.FromNodeIndex = 0;
+ mhaToOutputEdge.FromNodeOutputIndex = i;
+ mhaToOutputEdge.GraphOutputIndex = i;
+ outputEdges.push_back(mhaToOutputEdge);
+ }
+ }
+
+ MLOperatorGraphDesc operatorGraphDesc = {};
+ operatorGraphDesc.inputEdgeCount = gsl::narrow_cast<uint32_t>(inputEdges.size());
+ operatorGraphDesc.inputEdges = inputEdges.data();
+ operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast<uint32_t>(intermediateEdges.size());
+ operatorGraphDesc.intermediateEdges = intermediateEdges.data();
+ operatorGraphDesc.outputEdgeCount = gsl::narrow_cast<uint32_t>(outputEdges.size());
+ operatorGraphDesc.outputEdges = outputEdges.data();
+ operatorGraphDesc.nodeCount = gsl::narrow_cast<uint32_t>(opDescs.size());
+ operatorGraphDesc.nodes = opDescs.data();
+ SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext);
+ }
+};
+
+DML_OP_DEFINE_CREATION_FUNCTION(GroupQueryAttention, DmlOperatorGroupQueryAttention);
+} // namespace Dml
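The first-iteration workaround described in the comment above reduces to a simple rule: during the prompt pass (query sequence length > 1) the seqlens_k input is ignored and MHA sees zeros, while during token-by-token decoding (sequence length == 1) the provided lengths are used. A condensed sketch of that rule with the DML plumbing stripped away (illustrative, not the kernel's actual code path):

#include <cstdint>
#include <vector>

// Mirrors the sequenceLength == 1 branch above: returns the past sequence lengths
// that MHA should consume for the current iteration.
std::vector<int32_t> EffectivePastSequenceLengths(
    uint32_t sequenceLength,               // query shape dimension 1
    const std::vector<int32_t>& seqLensK,  // seqlens_k input, one entry per batch element
    uint32_t batchSize)
{
    if (sequenceLength == 1)
    {
        return seqLensK;                        // decoding: trust the mask-derived lengths
    }
    return std::vector<int32_t>(batchSize, 0);  // prompt pass: act as if there is no past
}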
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp
index 03500d0ee86a9..cde08864ca54e 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp
@@ -56,8 +56,8 @@ class DmlOperatorMultiHeadAttention : public DmlOperator
const bool hasBias = kernelCreationContext.IsInputValid(biasIndex);
const bool hasMask = kernelCreationContext.IsInputValid(maskIndex);
const bool hasRelativePositionBias = kernelCreationContext.IsInputValid(relativePositionBiasIndex);
- const bool hasPastKey = keyValueIsPast || kernelCreationContext.IsInputValid(pastKeyIndex);
- const bool hasPastValue = keyValueIsPast || kernelCreationContext.IsInputValid(pastValueIndex);
+ const bool hasPastKey = keyValueIsPast || (kernelCreationContext.IsInputValid(pastKeyIndex) && kernelCreationContext.GetInputTensorShape(pastKeyIndex)[2] != 0);
+ const bool hasPastValue = keyValueIsPast || (kernelCreationContext.IsInputValid(pastValueIndex) && kernelCreationContext.GetInputTensorShape(pastValueIndex)[2] != 0);
const bool hasPresentKeyOutput = kernelCreationContext.IsOutputValid(outputPresentKeyIndex);
const bool hasPresentValueOutput = kernelCreationContext.IsOutputValid(outputPresentValueIndex);
const bool stackedQkv = kernelCreationContext.GetInputTensorDimensionCount(queryIndex) == 5;
@@ -74,8 +74,8 @@ class DmlOperatorMultiHeadAttention : public DmlOperator
biasIndex,
hasMask ? std::optional<uint32_t>(maskIndex) : std::nullopt,
relativePositionBiasIndex,
- keyValueIsPast ? keyIndex : pastKeyIndex,
- keyValueIsPast ? valueIndex : pastValueIndex,
+ hasPastKey ? std::optional<uint32_t>(keyValueIsPast ? keyIndex : pastKeyIndex) : std::nullopt,
+ hasPastValue ? std::optional<uint32_t>(keyValueIsPast ? valueIndex : pastValueIndex) : std::nullopt,
};
std::vector<std::optional<uint32_t>> outputIndices = {
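The net effect of this hunk is that a bound past_key/past_value tensor whose past-sequence dimension is zero is treated exactly like an absent input, so nothing is bound for it. A reduced sketch of the check, assuming the [batch, numHeads, pastSequenceLength, headSize] layout used here:

#include <cstdint>
#include <vector>

// Treats a bound-but-empty past tensor (past sequence length of zero) as "no past".
bool HasUsablePast(bool isInputValid, const std::vector<uint32_t>& pastShape)
{
    // Dimension 2 is the past sequence length in [batch, numHeads, pastSeq, headSize].
    return isInputValid && pastShape.size() == 4 && pastShape[2] != 0;
}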
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h
index 15dcf4fb174fb..e2f38231f7295 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h
@@ -588,7 +588,6 @@ class DmlSTFTOperatorFactory : public WRL::Base<IMLOperatorKernelFactory>
shareInferrer.Get(),
nullptr,
false, // isInternalOperator
- false, // alias
false, // supportsGraph
nullptr,
requiredConstantCpuInputs.data(),
diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
index 2230ee74d9ff6..0091210f439a4 100644
--- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
+++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp
@@ -291,7 +291,7 @@ struct OperatorRegistrationInformation
const char* domain;
MLOperatorKernelCreateFn creationFunction;
MLOperatorShapeInferenceFunction shapeInferenceFunction;
- bool canAliasFirstInput;
+ std::pair<std::array<std::pair<uint32_t, uint32_t>, 4>, int> aliases;
gsl::span<char const* const> tensorTypeNames;
gsl::span<const SupportedTensorDataTypes> supportedTensorDataTypes;
@@ -522,6 +522,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Size);
DML_OP_EXTERN_CREATION_FUNCTION(QAttention);
DML_OP_EXTERN_CREATION_FUNCTION(Attention);
DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention);
+DML_OP_EXTERN_CREATION_FUNCTION(GroupQueryAttention);
DML_OP_EXTERN_CREATION_FUNCTION(NonZero);
DML_OP_EXTERN_CREATION_FUNCTION(QuickGelu);
DML_OP_EXTERN_CREATION_FUNCTION(BitwiseAnd);
@@ -679,29 +680,48 @@ constexpr auto requiredConstantCpuInputs(Args... args)
return std::make_pair(inputs, static_cast<int>(sizeof...(args)));
}
+template <typename... Args>
+constexpr auto Aliases(Args... args)
+{
+ if constexpr (sizeof...(args) == 0)
+ {
+ std::array<std::pair<uint32_t, uint32_t>, 4> aliases = {std::make_pair(0, 0)};
+ return std::make_pair(aliases, 0);
+ }
+ else
+ {
+ std::array<std::pair<uint32_t, uint32_t>, 4> aliases = {static_cast<std::pair<uint32_t, uint32_t>>(args)...};
+ return std::make_pair(aliases, static_cast<int>(sizeof...(args)));
+ }
+}
+
// Define a single row of OperatorRegistrationInformation.
#define REG_INFO(version, operatorName, ...) \
- #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__,
+ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, Aliases(), ##__VA_ARGS__,
#define REG_INFO_DYNAMIC_OUTPUTS(version, operatorName, ...) \
- #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, nullptr, false, ##__VA_ARGS__,
+ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, nullptr, Aliases(), ##__VA_ARGS__,
// Versioned operator
#define REG_INFO_VER(version, operatorName, ...) \
- #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, false, ##__VA_ARGS__,
+ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, Aliases(), ##__VA_ARGS__,
// Identity operators use Copy, alias their first input, and use elementwise identity operators
// when needed for striding support, but issue actual copies outside the graph.
#define REG_INFO_COPY(version, operatorName, ...) \
- #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction, true, ##__VA_ARGS__,
+ #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction, Aliases(std::make_pair(0, 0)), ##__VA_ARGS__,
// MS-domain operators
#define REG_INFO_MS(version, operatorName, ...) \
- #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__,
+ #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, Aliases(), ##__VA_ARGS__,
+
+// MS-domain operators
+#define REG_INFO_MS_ALIAS(version, operatorName, aliases, ...) \
+ #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, aliases, ##__VA_ARGS__,
// MS-domain operators
#define REG_INFO_MSDML(version, operatorName, ...) \
- #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__,
+ #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, Aliases(), ##__VA_ARGS__,
constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] =
{
@@ -1120,6 +1140,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
{REG_INFO_MS( 1, BiasAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)},
{REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)},
+
+ // Operators that need to alias an input with an output
+ {REG_INFO_MS_ALIAS(1, GroupQueryAttention, Aliases(std::make_pair(3, 1), std::make_pair(4, 2)), typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(6))},
};
template
@@ -1275,11 +1298,12 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry)
shapeInferrer.Get(),
supportQuery.Get(),
true, // isInternalOperator
- information.canAliasFirstInput, // alias
kernelSupportsGraph, // supportsGraph
information.requiredInputCountForDmlGraphSupport ? &(*information.requiredInputCountForDmlGraphSupport) : nullptr,
information.requiredConstantCpuInputs.first.data(),
- static_cast(information.requiredConstantCpuInputs.second)
+ static_cast(information.requiredConstantCpuInputs.second),
+ information.aliases.first.data(),
+ static_cast(information.aliases.second)
));
}
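For GroupQueryAttention, the table entry above declares two alias pairs: past_key (input 3) onto present_key (output 1), and past_value (input 4) onto present_value (output 2), so the KV cache can be updated in place rather than reallocated every step. Spelled out as a plain constant (a sketch of the mapping, not the registration table itself):

#include <array>
#include <cstdint>
#include <utility>

// (input index, output index) pairs from the REG_INFO_MS_ALIAS entry above.
constexpr std::array<std::pair<uint32_t, uint32_t>, 2> kGqaAliases = {{
    {3u, 1u},  // past_key   -> present_key
    {4u, 2u},  // past_value -> present_value
}};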
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
index 287deaa513f64..5d5806865a8da 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h
@@ -108,6 +108,7 @@ namespace AttrName
static constexpr const char* QkvHiddenSizes = "qkv_hidden_sizes";
static constexpr const char* Unidirectional = "unidirectional";
static constexpr const char* NumHeads = "num_heads";
+ static constexpr const char* KvNumHeads = "kv_num_heads";
static constexpr const char* PastPresentShareBuffer = "past_present_share_buffer";
static constexpr const char* FusedActivation = "fused_activation";
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h
index ac3a3eb1268b8..8a218470d30bb 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h
@@ -14,7 +14,7 @@ struct MLOperatorGraphDesc
{
uint32_t nodeCount;
_Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodes;
-
+
uint32_t inputEdgeCount;
_Field_size_(inputEdgeCount) const DML_INPUT_GRAPH_EDGE_DESC* inputEdges;
@@ -35,7 +35,7 @@ IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContex
) const noexcept PURE;
STDMETHOD(TryGetConstantInputTensor)(
- uint32_t inputIndex,
+ uint32_t inputIndex,
_Outptr_ IMLOperatorTensor** tensor
) const noexcept PURE;
@@ -72,7 +72,7 @@ IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex
) const noexcept PURE;
STDMETHOD(TryGetConstantInputTensor)(
- uint32_t inputIndex,
+ uint32_t inputIndex,
_Outptr_ IMLOperatorTensor** tensor
) const noexcept PURE;
@@ -174,11 +174,12 @@ IMLOperatorRegistryPrivate : public IUnknown
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
_In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
bool isInternalOperator,
- bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph = nullptr,
_In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr,
- uint32_t constantCpuInputCount = 0
+ uint32_t constantCpuInputCount = 0,
+ _In_reads_(aliasCount) const std::pair<uint32_t, uint32_t>* aliases = nullptr,
+ uint32_t aliasCount = 0
) const noexcept PURE;
};
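A caller-side sketch of the widened private interface, assuming a `registry` object implementing IMLOperatorRegistryPrivate and kernel artifacts created elsewhere; the argument values mirror the GroupQueryAttention registration, but the snippet is illustrative rather than lifted from the codebase:

// Assumes the surrounding Windows/DML headers, plus kernelDescription,
// kernelFactory, and shapeInferrer created beforehand.
const std::pair<uint32_t, uint32_t> aliases[] = { {3, 1}, {4, 2} };
const uint32_t constantCpuInputs[] = { 6 };  // total_sequence_length is read on the CPU

HRESULT hr = registry->RegisterOperatorKernel(
    &kernelDescription,
    kernelFactory.Get(),
    shapeInferrer.Get(),
    nullptr,            // supportQuery
    true,               // isInternalOperator
    true,               // supportsGraph
    nullptr,            // requiredInputCountForGraph
    constantCpuInputs,
    1,                  // constantCpuInputCount
    aliases,
    2);                 // aliasCount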
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
index 53522110f6f1c..7637c120867f7 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp
@@ -2771,6 +2771,52 @@ namespace OperatorHelper
m_numHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::NumHeads));
}
+ std::vector<EdgeShapes> GroupQueryAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
+ {
+ ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 2);
+
+ const auto queryShape = shapeInfo.GetInputTensorShape(0);
+ ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3);
+ const uint32_t batchSize = queryShape[0];
+ const uint32_t sequenceLength = queryShape[1];
+ const uint32_t hiddenSize = queryShape[2];
+
+ const auto keyShape = shapeInfo.GetInputTensorShape(1);
+ ML_CHECK_VALID_ARGUMENT(keyShape.size() == 3);
+ const uint32_t kvHiddenSize = keyShape[2];
+ const uint32_t kvHeadSize = kvHiddenSize / m_kvNumHeads;
+
+ uint32_t pastSequenceLength = 0;
+
+ if (shapeInfo.IsInputValid(3))
+ {
+ const auto pastKeyShape = shapeInfo.GetInputTensorShape(3);
+ ML_CHECK_VALID_ARGUMENT(pastKeyShape.size() == 4);
+ pastSequenceLength = pastKeyShape[2];
+ }
+
+ const uint32_t presentSequenceLength = std::max(pastSequenceLength, m_totalSequenceLength);
+
+ std::vector<EdgeShapes> outputShapes =
+ {
+ EdgeShapes({batchSize, sequenceLength, hiddenSize}),
+ EdgeShapes({batchSize, m_kvNumHeads, presentSequenceLength, kvHeadSize}),
+ EdgeShapes({batchSize, m_kvNumHeads, presentSequenceLength, kvHeadSize}),
+ };
+
+ return outputShapes;
+ }
+
+ void GroupQueryAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation)
+ {
+ m_kvNumHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::KvNumHeads));
+
+ std::vector<int32_t> totalSequenceLength;
+ ReadCpuLocalTensorIntoInt32(kernelInformation.GetConstantInputTensor(6), /*out*/ totalSequenceLength);
+ ML_CHECK_VALID_ARGUMENT(totalSequenceLength.size() == 1, "total_sequence_length must be a scalar.");
+ m_totalSequenceLength = totalSequenceLength[0];
+ }
+
std::vector<EdgeShapes> AttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const
{
ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 2);
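To make the shape math in GroupQueryAttentionHelper::GetOutputShapes concrete, here is a worked example with assumed values (batch 2, hidden 4096, 8 KV heads, KV hidden 1024, past length 100, total_sequence_length 101). It prints output {2, 1, 4096} and present K/V {2, 8, 101, 128}, since kvHeadSize = 1024 / 8 = 128:

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    // Assumed example values, not taken from a real model.
    const uint32_t batchSize = 2, sequenceLength = 1, hiddenSize = 4096;
    const uint32_t kvNumHeads = 8, kvHiddenSize = 1024;
    const uint32_t pastSequenceLength = 100, totalSequenceLength = 101;

    const uint32_t kvHeadSize = kvHiddenSize / kvNumHeads;                                     // 128
    const uint32_t presentSequenceLength = std::max(pastSequenceLength, totalSequenceLength);  // 101

    std::printf("output:  {%u, %u, %u}\n", batchSize, sequenceLength, hiddenSize);
    std::printf("present: {%u, %u, %u, %u}\n", batchSize, kvNumHeads, presentSequenceLength, kvHeadSize);
    return 0;
}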
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
index 42468d7829c8f..0ad6ff6e3ee6d 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h
@@ -1538,6 +1538,27 @@ class MultiHeadAttentionHelper
uint32_t m_numHeads;
};
+class GroupQueryAttentionHelper
+{
+public:
+ template <typename Info_t, typename Shape_t>
+ GroupQueryAttentionHelper(const Info_t& info, const Shape_t& shapeInfo)
+ {
+ Initialize(KernelInformationAdapter(info));
+ }
+
+ std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const;
+
+protected:
+ uint32_t GetTotalSequenceLength() { return m_totalSequenceLength; }
+
+private:
+ void Initialize(const IKernelInformationAdapter& kernelInformation);
+
+ uint32_t m_kvNumHeads;
+ uint32_t m_totalSequenceLength;
+};
+
class AttentionHelper
{
public:
@@ -1720,6 +1741,7 @@ using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_QAttention = QAttentionHelper;
using ShapeInferenceHelper_Attention = AttentionHelper;
using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
+using ShapeInferenceHelper_GroupQueryAttention = GroupQueryAttentionHelper;
using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper;
using ShapeInferenceHelper_IsNaN = GetBroadcastedOutputShapeHelper;
diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
index 9d2f88008185b..1c84e3badae2c 100644
--- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
+++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h
@@ -454,6 +454,7 @@ namespace OperatorHelper
static const int sc_sinceVer_Attention = 1;
static const int sc_sinceVer_MatMulIntegerToFloat = 1;
static const int sc_sinceVer_MultiHeadAttention = 1;
+ static const int sc_sinceVer_GroupQueryAttention = 1;
static const int sc_sinceVer_SkipLayerNormalization = 1;
static const int sc_sinceVer_SkipSimplifiedLayerNormalization = 1;
static const int sc_sinceVer_EmbedLayerNormalization = 1;
diff --git a/winml/adapter/abi_custom_registry_impl.cpp b/winml/adapter/abi_custom_registry_impl.cpp
index a3921d8a92fb7..7aceafaf5fb2c 100644
--- a/winml/adapter/abi_custom_registry_impl.cpp
+++ b/winml/adapter/abi_custom_registry_impl.cpp
@@ -51,11 +51,12 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
_In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
_In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
bool isInternalOperator,
- bool canAliasFirstInput,
bool supportsGraph,
const uint32_t* requiredInputCountForGraph,
_In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
- uint32_t constantCpuInputCount
+ uint32_t constantCpuInputCount,
+ _In_reads_(aliasCount) const std::pair<uint32_t, uint32_t>* aliases,
+ uint32_t aliasCount
) const noexcept try {
#ifdef LAYERING_DONE
// Log a custom op telemetry if the operator is not a built-in DML operator
@@ -73,11 +74,12 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel(
shapeInferrer,
supportQuery,
isInternalOperator,
- canAliasFirstInput,
supportsGraph,
requiredInputCountForGraph,
requiredConstantCpuInputs,
- constantCpuInputCount
+ constantCpuInputCount,
+ aliases,
+ aliasCount
);
}
CATCH_RETURN();
diff --git a/winml/adapter/abi_custom_registry_impl.h b/winml/adapter/abi_custom_registry_impl.h
index 1942a2e4f82f1..101997d7c94ab 100644
--- a/winml/adapter/abi_custom_registry_impl.h
+++ b/winml/adapter/abi_custom_registry_impl.h
@@ -26,11 +26,12 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry {
_In_opt_ IMLOperatorShapeInferrer* shape_inferrer,
_In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
bool is_internal_operator,
- bool can_alias_first_input,
bool supports_graph,
const uint32_t* required_input_count_for_graph = nullptr,
_In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr,
- uint32_t constant_cpu_input_count = 0
+ uint32_t constant_cpu_input_count = 0,
+ _In_reads_(aliasCount) const std::pair<uint32_t, uint32_t>* aliases = nullptr,
+ uint32_t aliasCount = 0
) const noexcept override;
HRESULT STDMETHODCALLTYPE RegisterOperatorKernel(