diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 0ce4874a620bb..0668cb4a93a62 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -1294,6 +1294,7 @@ Do not modify directly.* |FusedMatMulActivation|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**M** = tensor(float), tensor(float16)
**T** = tensor(float), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)| |NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index eb068087de4ad..353f698bb6f2c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -330,7 +330,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( IMLOperatorKernelFactory* operatorKernelFactory, _In_opt_ IMLOperatorShapeInferrer* shapeInferrer) const noexcept { - return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, nullptr, false, false, false); + return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, nullptr, false, false); } HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( @@ -339,11 +339,12 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( _In_opt_ IMLOperatorShapeInferrer* shapeInferrer, _In_opt_ IMLOperatorSupportQueryPrivate* supportQuery, bool isInternalOperator, - bool canAliasFirstInput, bool supportsGraph, const uint32_t* requiredInputCountForGraph, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, - uint32_t constantCpuInputCount) const noexcept + uint32_t constantCpuInputCount, + _In_reads_(aliasCount) const std::pair* aliases, + uint32_t aliasCount) const noexcept { ORT_TRY { @@ -417,9 +418,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( builder.InputMemoryType(::OrtMemType::OrtMemTypeCPUInput, inputIndex); } - if (canAliasFirstInput) + for (uint32_t i = 0; i < aliasCount; ++i) { - builder.Alias(0, 0); + builder.Alias(aliases[i].first, aliases[i].second); } // Set type constraints @@ -553,7 +554,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( else { // Currently unsupported for external operators - if 
(canAliasFirstInput || + if (aliasCount > 0 || supportsGraph || requiredInputCountForGraph) { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h index d6b1448b559b1..eb84b4f822e92 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.h @@ -15,7 +15,7 @@ namespace WRL } namespace Windows::AI::MachineLearning::Adapter -{ +{ using namespace Microsoft::WRL; @@ -38,11 +38,12 @@ class AbiCustomRegistry : public WRL::Base* aliases = nullptr, + uint32_t aliasCount = 0) const noexcept override; HRESULT STDMETHODCALLTYPE RegisterOperatorKernel( const MLOperatorKernelDescription* opKernel, @@ -56,7 +57,7 @@ class AbiCustomRegistry : public WRL::Base m_kernelRegistry; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 1c82c40a88bb6..2bd9377e4c2fa 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -107,6 +107,8 @@ namespace Dml::GraphDescBuilder // Mapping from the old indices to the new indices that have been shifted after removing earlier nodes std::vector shiftedIndicesMapping(graphNodes.size()); + std::unordered_set nodesRemoved; + uint32_t shift = 0; for (uint32_t nodeIndex = 0; nodeIndex < graphNodes.size(); ++nodeIndex) { @@ -114,6 +116,7 @@ namespace Dml::GraphDescBuilder { // The node is not connected, so we simply increase the shift value (the node will be overwritten by the following nodes) ++shift; + nodesRemoved.insert(nodeIndex); } else { @@ -125,6 +128,13 @@ namespace Dml::GraphDescBuilder graphNodes.resize(graphNodes.size() - shift); + // Remove the inputs that are not connected to anything 
anymore + auto inputEdgesEndIter = std::remove_if(graphInputEdges.begin(), graphInputEdges.end(), [&nodesRemoved](const auto& inputEdge) { + return nodesRemoved.count(inputEdge.ToNodeIndex); + }); + + graphInputEdges.erase(inputEdgesEndIter, graphInputEdges.end()); + // Adjust the node indices in the input edges std::unordered_set usedInputEdgeIndex; for (auto& inputEdge : graphInputEdges) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h index ddd6d56128461..a99d8bf655fec 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlDFT.h @@ -1147,7 +1147,6 @@ class GpuDFTOperatorFactory : public WRL::Base shareInferrer.Get(), nullptr, false, // isInternalOperator - false, // alias false, // supportsGraph nullptr, requiredConstantCpuInputs.data(), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h index 4f5da9dd05491..5ba936ddf3976 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlGridSample.h @@ -844,7 +844,6 @@ class DmlGridSampleOperatorFactory : public WRL::Base shareInferrer.Get(), nullptr, false, // isInternalOperator - false, // alias false, // supportsGraph nullptr, nullptr, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp new file mode 100644 index 0000000000000..bf5a182f9662c --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp @@ -0,0 +1,318 @@ +// Copyright (c) 
Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +namespace Dml +{ +class DmlOperatorGroupQueryAttention : public DmlOperator, public GroupQueryAttentionHelper +{ +public: + DmlOperatorGroupQueryAttention(const MLOperatorKernelCreationContext& kernelCreationContext) + : DmlOperator(kernelCreationContext), + GroupQueryAttentionHelper(kernelCreationContext, kernelCreationContext.GetTensorShapeDescription()) + { + enum InputIndex : uint32_t + { + queryIndex, + keyIndex, + valueIndex, + pastKeyIndex, + pastValueIndex, + seqLensIndex, + inputCount, + }; + + enum OutputIndex : uint32_t + { + outputIndex, + outputPresentKeyIndex, + outputPresentValueIndex, + outputCount, + }; + + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetInputCount() >= 1); + ML_CHECK_VALID_ARGUMENT(kernelCreationContext.GetOutputCount() >= 1); + + std::vector> inputIndices(inputCount); + inputIndices[queryIndex] = queryIndex; + inputIndices[keyIndex] = keyIndex; + inputIndices[valueIndex] = valueIndex; + + const uint32_t sequenceLength = kernelCreationContext.GetInputTensorShape(queryIndex)[1]; + + if (kernelCreationContext.GetInputTensorShape(queryIndex)[1] == 1) + { + inputIndices[seqLensIndex] = seqLensIndex; + } + + std::vector> outputIndices = { + outputIndex, + outputPresentKeyIndex, + outputPresentValueIndex, + }; + DmlOperator::Initialize(kernelCreationContext, inputIndices, outputIndices, std::nullopt, std::nullopt, 1); + + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[queryIndex].GetDimensionCount() == 3); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[keyIndex].GetDimensionCount() == 3); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[valueIndex].GetDimensionCount() == 3); + + const uint32_t queryNumHeads = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::NumHeads)); + const uint32_t kvNumHeads = gsl::narrow_cast(kernelCreationContext.GetAttribute(AttrName::KvNumHeads)); + + auto querySizes = 
m_inputTensorDescs[queryIndex].GetSizes(); + auto keySizes = m_inputTensorDescs[keyIndex].GetSizes(); + auto valueSizes = m_inputTensorDescs[valueIndex].GetSizes(); + + const uint32_t batchSize = querySizes[0]; + const uint32_t queryHiddenSize = querySizes[2]; + + const uint32_t kvSequenceLength = keySizes[1]; + const uint32_t kvHiddenSize = keySizes[2]; + + const uint32_t queryHeadSize = queryHiddenSize / queryNumHeads; + const uint32_t kvHeadSize = kvHiddenSize / kvNumHeads; + const uint32_t totalSequenceLength = GetTotalSequenceLength(); + + // Validate Query dimensions + ML_CHECK_VALID_ARGUMENT(querySizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(querySizes[1] == sequenceLength); + ML_CHECK_VALID_ARGUMENT(querySizes[2] == queryHiddenSize); + + // Validate Key dimensions + ML_CHECK_VALID_ARGUMENT(keySizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(keySizes[1] == kvSequenceLength); + ML_CHECK_VALID_ARGUMENT(keySizes[2] == kvHiddenSize); + + // Validate Value dimensions + ML_CHECK_VALID_ARGUMENT(valueSizes[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(valueSizes[1] == kvSequenceLength); + ML_CHECK_VALID_ARGUMENT(valueSizes[2] == kvHiddenSize); + + if (sequenceLength == 1) + { + // Validate PastSequenceLengths dimensions + if (m_inputTensorDescs[seqLensIndex].GetDimensionCount() == 1) + { + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[seqLensIndex].GetSizes()[0] == batchSize); + } + else + { + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[seqLensIndex].GetDimensionCount() == 2); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[seqLensIndex].GetSizes()[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[seqLensIndex].GetSizes()[1] == 1); + } + } + + const std::array pastSequenceLengthsShape = {batchSize}; + auto pastSequenceLengthsDataType = MLOperatorTensorDataType::Int32; + TensorDesc pastSequenceLengthsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastSequenceLengthsDataType, pastSequenceLengthsShape); + const DML_TENSOR_DESC 
pastSequenceLengthsDmlTensorDesc = pastSequenceLengthsTensorDesc.GetDmlDesc(); + + std::vector inputDescs = GetDmlInputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); + + // GQA is very sensitive to overflows, so we cast all inputs to fp32 and cast the outputs back to fp16. At the DML level, + // those casts will be eliminated and replaced with half precision computation instead, which mimics the CUDA EP behavior + // of their flash attention kernel. + TensorDesc queryCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[queryIndex].GetSizes()); + DML_TENSOR_DESC queryCastDmlTensorDesc = queryCastTensorDesc.GetDmlDesc(); + DML_CAST_OPERATOR_DESC queryCastOpDesc{}; + queryCastOpDesc.InputTensor = &inputDescs[queryIndex]; + queryCastOpDesc.OutputTensor = &queryCastDmlTensorDesc; + DML_OPERATOR_DESC queryCastDmlDesc = { DML_OPERATOR_CAST, &queryCastOpDesc }; + + TensorDesc keyCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[keyIndex].GetSizes()); + DML_TENSOR_DESC keyCastDmlTensorDesc = keyCastTensorDesc.GetDmlDesc(); + DML_CAST_OPERATOR_DESC keyCastOpDesc{}; + keyCastOpDesc.InputTensor = &inputDescs[keyIndex]; + keyCastOpDesc.OutputTensor = &keyCastDmlTensorDesc; + DML_OPERATOR_DESC keyCastDmlDesc = { DML_OPERATOR_CAST, &keyCastOpDesc }; + + TensorDesc valueCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_inputTensorDescs[valueIndex].GetSizes()); + DML_TENSOR_DESC valueCastDmlTensorDesc = valueCastTensorDesc.GetDmlDesc(); + DML_CAST_OPERATOR_DESC valueCastOpDesc{}; + valueCastOpDesc.InputTensor = &inputDescs[valueIndex]; + valueCastOpDesc.OutputTensor = &valueCastDmlTensorDesc; + DML_OPERATOR_DESC valueCastDmlDesc = { DML_OPERATOR_CAST, &valueCastOpDesc }; + + TensorDesc outputCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_outputTensorDescs[outputIndex].GetSizes()); + DML_TENSOR_DESC outputCastDmlTensorDesc = outputCastTensorDesc.GetDmlDesc(); + DML_CAST_OPERATOR_DESC outputCastOpDesc{}; + outputCastOpDesc.InputTensor = 
&outputCastDmlTensorDesc; + outputCastOpDesc.OutputTensor = &outputDescs[outputIndex]; + DML_OPERATOR_DESC outputCastDmlDesc = { DML_OPERATOR_CAST, &outputCastOpDesc }; + + TensorDesc outputPresentKeyCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_outputTensorDescs[outputPresentKeyIndex].GetSizes()); + DML_TENSOR_DESC outputPresentKeyCastDmlTensorDesc = outputPresentKeyCastTensorDesc.GetDmlDesc(); + DML_CAST_OPERATOR_DESC outputPresentKeyCastOpDesc{}; + outputPresentKeyCastOpDesc.InputTensor = &outputPresentKeyCastDmlTensorDesc; + outputPresentKeyCastOpDesc.OutputTensor = &outputDescs[outputPresentKeyIndex]; + DML_OPERATOR_DESC outputPresentKeyCastDmlDesc = { DML_OPERATOR_CAST, &outputPresentKeyCastOpDesc }; + + TensorDesc outputPresentValueCastTensorDesc(DML_TENSOR_DATA_TYPE_FLOAT32, m_outputTensorDescs[outputPresentValueIndex].GetSizes()); + DML_TENSOR_DESC outputPresentValueCastDmlTensorDesc = outputPresentValueCastTensorDesc.GetDmlDesc(); + DML_CAST_OPERATOR_DESC outputPresentValueCastOpDesc{}; + outputPresentValueCastOpDesc.InputTensor = &outputPresentValueCastDmlTensorDesc; + outputPresentValueCastOpDesc.OutputTensor = &outputDescs[outputPresentValueIndex]; + DML_OPERATOR_DESC outputPresentValueCastDmlDesc = { DML_OPERATOR_CAST, &outputPresentValueCastOpDesc }; + + const bool isFp16 = m_inputTensorDescs[queryIndex].GetDmlDataType() == DML_TENSOR_DATA_TYPE_FLOAT16; + + DML_MULTIHEAD_ATTENTION1_OPERATOR_DESC mhaDesc = {}; + mhaDesc.QueryTensor = isFp16 ? &queryCastDmlTensorDesc : &inputDescs[queryIndex]; + mhaDesc.KeyTensor = isFp16 ? &keyCastDmlTensorDesc : &inputDescs[keyIndex]; + mhaDesc.ValueTensor = isFp16 ? &valueCastDmlTensorDesc : &inputDescs[valueIndex]; + mhaDesc.PastSequenceLengthsTensor = &pastSequenceLengthsDmlTensorDesc; + mhaDesc.OutputTensor = isFp16 ? &outputCastDmlTensorDesc : &outputDescs[outputIndex]; + mhaDesc.OutputPresentKeyTensor = isFp16 ? 
&outputPresentKeyCastDmlTensorDesc : &outputDescs[outputPresentKeyIndex]; + mhaDesc.OutputPresentValueTensor = isFp16 ? &outputPresentValueCastDmlTensorDesc : &outputDescs[outputPresentValueIndex]; + mhaDesc.QueryHeadCount = queryNumHeads; + mhaDesc.KeyValueHeadCount = kvNumHeads; + mhaDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(queryHeadSize))); + mhaDesc.MaskFilterValue = -10'000.0f; + DML_OPERATOR_DESC mhaDmlDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION1, &mhaDesc }; + + DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zeroScalarDesc = {}; + zeroScalarDesc.OutputTensor = &pastSequenceLengthsDmlTensorDesc; + zeroScalarDesc.ValueDataType = pastSequenceLengthsTensorDesc.GetDmlDataType(); + DML_OPERATOR_DESC zeroScalarDmlDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &zeroScalarDesc }; + + std::vector opDescs = { + &mhaDmlDesc, + }; + + // Construct the graph + std::vector inputEdges; + std::vector intermediateEdges; + std::vector outputEdges; + + if (isFp16) + { + opDescs.push_back(&queryCastDmlDesc); + opDescs.push_back(&keyCastDmlDesc); + opDescs.push_back(&valueCastDmlDesc); + opDescs.push_back(&outputCastDmlDesc); + opDescs.push_back(&outputPresentKeyCastDmlDesc); + opDescs.push_back(&outputPresentValueCastDmlDesc); + + // Link the query/key/value inputs to the cast nodes + for (uint32_t i = 0; i < 3; ++i) + { + DML_INPUT_GRAPH_EDGE_DESC inputToMhaEdge = {}; + inputToMhaEdge.GraphInputIndex = i; + inputToMhaEdge.ToNodeIndex = 1 + i; + inputToMhaEdge.ToNodeInputIndex = 0; + inputEdges.push_back(inputToMhaEdge); + } + + // Link the input cast nodes to MHA + for (uint32_t i = 0; i < 3; ++i) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC castToMhaEdge = {}; + castToMhaEdge.FromNodeIndex = 1 + i; + castToMhaEdge.FromNodeOutputIndex = 0; + castToMhaEdge.ToNodeIndex = 0; + castToMhaEdge.ToNodeInputIndex = i; + intermediateEdges.push_back(castToMhaEdge); + } + } + else + { + // Link the query/key/value inputs to MHA + for (uint32_t i 
= 0; i < 3; ++i) + { + DML_INPUT_GRAPH_EDGE_DESC inputToMhaEdge = {}; + inputToMhaEdge.GraphInputIndex = i; + inputToMhaEdge.ToNodeIndex = 0; + inputToMhaEdge.ToNodeInputIndex = i; + inputEdges.push_back(inputToMhaEdge); + } + } + + constexpr uint32_t dmlPastSequenceLengthsIndex = 11; + + // The GQA offline fusion does this thing where it sums the number of 1's in the mask to figure out the value of the past sequence. + // This doesn't work well for the first iteration since, obviously, there are no past sequences and the mask in this case represents + // only the elements in the initial sequence. To work around this, the CUDA implementation of the operator ignores the value of + // pastSequenceLengths for the first iteration and acts as if it was 0. This feels like a pretty dirty hack and something that should + // be polished in the future, but for compatibility with the GQA fusion and the CUDA implementation we do the same thing here. We DO NOT + // want to do this within DirectML since DirectML should be agnostic w.r.t which iteration it's currently executing MHA for, and such a + // hack that is likely to be modified in the future shouldn't be enshrined within DirectML. Doing it here is OK because the nature of contrib + // ops is that they can change at any time. 
+ if (sequenceLength == 1) + { + // Link the PastSequenceLengths input to MHA + DML_INPUT_GRAPH_EDGE_DESC inputToMhaEdge = {}; + inputToMhaEdge.GraphInputIndex = seqLensIndex; + inputToMhaEdge.ToNodeIndex = 0; + inputToMhaEdge.ToNodeInputIndex = dmlPastSequenceLengthsIndex; + inputEdges.push_back(inputToMhaEdge); + } + else + { + opDescs.push_back(&zeroScalarDmlDesc); + + // Link the zero scalar to MHA + DML_INTERMEDIATE_GRAPH_EDGE_DESC zeroScalarToMhaEdge = {}; + zeroScalarToMhaEdge.FromNodeIndex = gsl::narrow_cast(opDescs.size() - 1); + zeroScalarToMhaEdge.FromNodeOutputIndex = 0; + zeroScalarToMhaEdge.ToNodeIndex = 0; + zeroScalarToMhaEdge.ToNodeInputIndex = dmlPastSequenceLengthsIndex; + intermediateEdges.push_back(zeroScalarToMhaEdge); + } + + if (isFp16) + { + // Output cast nodes start at the 4th index (previously we have the mha, query, key and value nodes) + const uint32_t outputCastNodeStart = 4; + + // Link MHA's output to the output cast nodes + for (uint32_t i = 0; i < 3; ++i) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC mhaToCastEdge = {}; + mhaToCastEdge.FromNodeIndex = 0; + mhaToCastEdge.FromNodeOutputIndex = i; + mhaToCastEdge.ToNodeIndex = outputCastNodeStart + i; + mhaToCastEdge.ToNodeInputIndex = 0; + intermediateEdges.push_back(mhaToCastEdge); + } + + // Link the output cast nodes to the graph's outputs + for (uint32_t i = 0; i < 3; ++i) + { + DML_OUTPUT_GRAPH_EDGE_DESC castToOutputEdge = {}; + castToOutputEdge.FromNodeIndex = outputCastNodeStart + i; + castToOutputEdge.FromNodeOutputIndex = 0; + castToOutputEdge.GraphOutputIndex = i; + outputEdges.push_back(castToOutputEdge); + } + } + else + { + // Link MHA's outputs to the graph's outputs + for (uint32_t i = 0; i < 3; ++i) + { + DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {}; + mhaToOutputEdge.FromNodeIndex = 0; + mhaToOutputEdge.FromNodeOutputIndex = i; + mhaToOutputEdge.GraphOutputIndex = i; + outputEdges.push_back(mhaToOutputEdge); + } + } + + MLOperatorGraphDesc operatorGraphDesc = {}; + 
operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodes = opDescs.data(); + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + } +}; + +DML_OP_DEFINE_CREATION_FUNCTION(GroupQueryAttention, DmlOperatorGroupQueryAttention); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp index 03500d0ee86a9..cde08864ca54e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorMultiHeadAttention.cpp @@ -56,8 +56,8 @@ class DmlOperatorMultiHeadAttention : public DmlOperator const bool hasBias = kernelCreationContext.IsInputValid(biasIndex); const bool hasMask = kernelCreationContext.IsInputValid(maskIndex); const bool hasRelativePositionBias = kernelCreationContext.IsInputValid(relativePositionBiasIndex); - const bool hasPastKey = keyValueIsPast || kernelCreationContext.IsInputValid(pastKeyIndex); - const bool hasPastValue = keyValueIsPast || kernelCreationContext.IsInputValid(pastValueIndex); + const bool hasPastKey = keyValueIsPast || (kernelCreationContext.IsInputValid(pastKeyIndex) && kernelCreationContext.GetInputTensorShape(pastKeyIndex)[2] != 0); + const bool hasPastValue = keyValueIsPast || (kernelCreationContext.IsInputValid(pastValueIndex) && 
kernelCreationContext.GetInputTensorShape(pastValueIndex)[2] != 0); const bool hasPresentKeyOutput = kernelCreationContext.IsOutputValid(outputPresentKeyIndex); const bool hasPresentValueOutput = kernelCreationContext.IsOutputValid(outputPresentValueIndex); const bool stackedQkv = kernelCreationContext.GetInputTensorDimensionCount(queryIndex) == 5; @@ -74,8 +74,8 @@ class DmlOperatorMultiHeadAttention : public DmlOperator biasIndex, hasMask ? std::optional(maskIndex) : std::nullopt, relativePositionBiasIndex, - keyValueIsPast ? keyIndex : pastKeyIndex, - keyValueIsPast ? valueIndex : pastValueIndex, + hasPastKey ? std::optional(keyValueIsPast ? keyIndex : pastKeyIndex) : std::nullopt, + hasPastValue ? std::optional(keyValueIsPast ? valueIndex : pastValueIndex) : std::nullopt, }; std::vector> outputIndices = { diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h index 15dcf4fb174fb..e2f38231f7295 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlSTFT.h @@ -588,7 +588,6 @@ class DmlSTFTOperatorFactory : public WRL::Base shareInferrer.Get(), nullptr, false, // isInternalOperator - false, // alias false, // supportsGraph nullptr, requiredConstantCpuInputs.data(), diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 2230ee74d9ff6..0091210f439a4 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -291,7 +291,7 @@ struct OperatorRegistrationInformation const char* domain; MLOperatorKernelCreateFn creationFunction; MLOperatorShapeInferenceFunction 
shapeInferenceFunction; - bool canAliasFirstInput; + std::pair, 4>, int> aliases; gsl::span tensorTypeNames; gsl::span supportedTensorDataTypes; @@ -522,6 +522,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Size); DML_OP_EXTERN_CREATION_FUNCTION(QAttention); DML_OP_EXTERN_CREATION_FUNCTION(Attention); DML_OP_EXTERN_CREATION_FUNCTION(MultiHeadAttention); +DML_OP_EXTERN_CREATION_FUNCTION(GroupQueryAttention); DML_OP_EXTERN_CREATION_FUNCTION(NonZero); DML_OP_EXTERN_CREATION_FUNCTION(QuickGelu); DML_OP_EXTERN_CREATION_FUNCTION(BitwiseAnd); @@ -679,29 +680,48 @@ constexpr auto requiredConstantCpuInputs(Args... args) return std::make_pair(inputs, static_cast(sizeof...(args))); } +template +constexpr auto Aliases(Args... args) +{ + if constexpr (sizeof...(args) == 0) + { + std::array, 4> aliases = {std::make_pair(0, 0)}; + return std::make_pair(aliases, 0); + } + else + { + std::array, 4> aliases = {static_cast>(args)...}; + return std::make_pair(aliases, static_cast(sizeof...(args))); + } +} + // Define a single row of OperatorRegistrationInformation. #define REG_INFO(version, operatorName, ...) \ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, ShapeInferenceFunction, Aliases(), ##__VA_ARGS__, #define REG_INFO_DYNAMIC_OUTPUTS(version, operatorName, ...) \ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, nullptr, false, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName, nullptr, Aliases(), ##__VA_ARGS__, // Versioned operator #define REG_INFO_VER(version, operatorName, ...) 
\ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, false, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, Create##operatorName##version, ShapeInferenceFunction, Aliases(), ##__VA_ARGS__, // Identity operators use Copy, alias their first input, and use elementwise identity operators // when needed for striding support, but issue actual copies outside the graph. #define REG_INFO_COPY(version, operatorName, ...) \ - #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction, true, ##__VA_ARGS__, + #operatorName, OnnxOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kOnnxDomain, CreateCopy, ShapeInferenceFunction, Aliases(std::make_pair(0, 0)), ##__VA_ARGS__, // MS-domain operators #define REG_INFO_MS(version, operatorName, ...) \ - #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, + #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, Aliases(), ##__VA_ARGS__, + +// MS-domain operators that alias an input with an output +#define REG_INFO_MS_ALIAS(version, operatorName, aliases, ...) \ + #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDomain, Create##operatorName, ShapeInferenceFunction, aliases, ##__VA_ARGS__, // MS-domain operators #define REG_INFO_MSDML(version, operatorName, ...) 
\ - #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, false, ##__VA_ARGS__, + #operatorName, MsftOperatorSet##version::sc_sinceVer_##operatorName, onnxruntime::kMSDmlDomain, Create##operatorName, ShapeInferenceFunction, Aliases(), ##__VA_ARGS__, constexpr static OperatorRegistrationInformation operatorRegistrationInformationTable[] = { @@ -1120,6 +1140,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_MS( 1, BiasAdd, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QuickGelu, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, GroupNorm, typeNameListGroupNorm, supportedTypeListGroupNorm, DmlGraphSupport::Supported)}, + + // Operators that need to alias an input with an output + {REG_INFO_MS_ALIAS(1, GroupQueryAttention, Aliases(std::make_pair(3, 1), std::make_pair(4, 2)), typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(6))}, }; template @@ -1275,11 +1298,12 @@ void RegisterDmlOperators(IMLOperatorRegistry* registry) shapeInferrer.Get(), supportQuery.Get(), true, // isInternalOperator - information.canAliasFirstInput, // alias kernelSupportsGraph, // supportsGraph information.requiredInputCountForDmlGraphSupport ? 
&(*information.requiredInputCountForDmlGraphSupport) : nullptr, information.requiredConstantCpuInputs.first.data(), - static_cast(information.requiredConstantCpuInputs.second) + static_cast(information.requiredConstantCpuInputs.second), + information.aliases.first.data(), + static_cast(information.aliases.second) )); } diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index 287deaa513f64..5d5806865a8da 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -108,6 +108,7 @@ namespace AttrName static constexpr const char* QkvHiddenSizes = "qkv_hidden_sizes"; static constexpr const char* Unidirectional = "unidirectional"; static constexpr const char* NumHeads = "num_heads"; + static constexpr const char* KvNumHeads = "kv_num_heads"; static constexpr const char* PastPresentShareBuffer = "past_present_share_buffer"; static constexpr const char* FusedActivation = "fused_activation"; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index ac3a3eb1268b8..8a218470d30bb 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -14,7 +14,7 @@ struct MLOperatorGraphDesc { uint32_t nodeCount; _Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodes; - + uint32_t inputEdgeCount; _Field_size_(inputEdgeCount) const DML_INPUT_GRAPH_EDGE_DESC* inputEdges; @@ -35,7 +35,7 @@ IMLOperatorShapeInferenceContextPrivate : public IMLOperatorShapeInferenceContex ) const noexcept PURE; STDMETHOD(TryGetConstantInputTensor)( - uint32_t inputIndex, + uint32_t inputIndex, _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; @@ -72,7 +72,7 @@ 
IMLOperatorKernelCreationContextPrivate : public IMLOperatorKernelCreationContex ) const noexcept PURE; STDMETHOD(TryGetConstantInputTensor)( - uint32_t inputIndex, + uint32_t inputIndex, _Outptr_ IMLOperatorTensor** tensor ) const noexcept PURE; @@ -174,11 +174,12 @@ IMLOperatorRegistryPrivate : public IUnknown _In_opt_ IMLOperatorShapeInferrer* shapeInferrer, _In_opt_ IMLOperatorSupportQueryPrivate* supportQuery, bool isInternalOperator, - bool canAliasFirstInput, bool supportsGraph, const uint32_t* requiredInputCountForGraph = nullptr, _In_reads_(constantCpuInputCount) const uint32_t* constantCpuInputs = nullptr, - uint32_t constantCpuInputCount = 0 + uint32_t constantCpuInputCount = 0, + _In_reads_(aliasCount) const std::pair* aliases = nullptr, + uint32_t aliasCount = 0 ) const noexcept PURE; }; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp index 53522110f6f1c..7637c120867f7 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp @@ -2771,6 +2771,52 @@ namespace OperatorHelper m_numHeads = gsl::narrow_cast(kernelInformation.GetAttributes().GetAttribute(AttrName::NumHeads)); } + std::vector GroupQueryAttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const + { + ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 2); + + const auto queryShape = shapeInfo.GetInputTensorShape(0); + ML_CHECK_VALID_ARGUMENT(queryShape.size() == 3); + const uint32_t batchSize = queryShape[0]; + const uint32_t sequenceLength = queryShape[1]; + const uint32_t hiddenSize = queryShape[2]; + + const auto keyShape = shapeInfo.GetInputTensorShape(1); + ML_CHECK_VALID_ARGUMENT(keyShape.size() == 3); + const uint32_t kvHiddenSize = keyShape[2]; + const uint32_t kvHeadSize = kvHiddenSize / m_kvNumHeads; + + uint32_t pastSequenceLength = 0; + + if 
(shapeInfo.IsInputValid(3)) + { + const auto pastKeyShape = shapeInfo.GetInputTensorShape(3); + ML_CHECK_VALID_ARGUMENT(pastKeyShape.size() == 4); + pastSequenceLength = pastKeyShape[2]; + } + + const uint32_t presentSequenceLength = std::max(pastSequenceLength, m_totalSequenceLength); + + std::vector<EdgeShapes> outputShapes = + { + EdgeShapes({batchSize, sequenceLength, hiddenSize}), + EdgeShapes({batchSize, m_kvNumHeads, presentSequenceLength, kvHeadSize}), + EdgeShapes({batchSize, m_kvNumHeads, presentSequenceLength, kvHeadSize}), + }; + + return outputShapes; + } + + void GroupQueryAttentionHelper::Initialize(const IKernelInformationAdapter& kernelInformation) + { + m_kvNumHeads = gsl::narrow_cast<uint32_t>(kernelInformation.GetAttributes().GetAttribute<int64_t>(AttrName::KvNumHeads)); + + std::vector<int32_t> totalSequenceLength; + ReadCpuLocalTensorIntoInt32(kernelInformation.GetConstantInputTensor(6), /*out*/ totalSequenceLength); + ML_CHECK_VALID_ARGUMENT(totalSequenceLength.size() == 1, "total_sequence_length must be a scalar."); + m_totalSequenceLength = totalSequenceLength[0]; + } + std::vector<EdgeShapes> AttentionHelper::GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const { ML_CHECK_VALID_ARGUMENT(shapeInfo.GetInputCount() >= 2); diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 42468d7829c8f..0ad6ff6e3ee6d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1538,6 +1538,27 @@ class MultiHeadAttentionHelper uint32_t m_numHeads; }; +class GroupQueryAttentionHelper +{ +public: + template <typename Info_t, typename Shape_t> + GroupQueryAttentionHelper(const Info_t& info, const Shape_t& shapeInfo) + { + Initialize(KernelInformationAdapter(info)); + } + + std::vector<EdgeShapes> GetOutputShapes(const MLShapeInferenceContext& shapeInfo) const; + +protected: + uint32_t GetTotalSequenceLength() { return m_totalSequenceLength; } + 
+private: + void Initialize(const IKernelInformationAdapter& kernelInformation); + + uint32_t m_kvNumHeads; + uint32_t m_totalSequenceLength; +}; + class AttentionHelper { public: @@ -1720,6 +1741,7 @@ using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_QAttention = QAttentionHelper; using ShapeInferenceHelper_Attention = AttentionHelper; using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper; +using ShapeInferenceHelper_GroupQueryAttention = GroupQueryAttentionHelper; using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper; using ShapeInferenceHelper_IsNaN = GetBroadcastedOutputShapeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 9d2f88008185b..1c84e3badae2c 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -454,6 +454,7 @@ namespace OperatorHelper static const int sc_sinceVer_Attention = 1; static const int sc_sinceVer_MatMulIntegerToFloat = 1; static const int sc_sinceVer_MultiHeadAttention = 1; + static const int sc_sinceVer_GroupQueryAttention = 1; static const int sc_sinceVer_SkipLayerNormalization = 1; static const int sc_sinceVer_SkipSimplifiedLayerNormalization = 1; static const int sc_sinceVer_EmbedLayerNormalization = 1; diff --git a/winml/adapter/abi_custom_registry_impl.cpp b/winml/adapter/abi_custom_registry_impl.cpp index a3921d8a92fb7..7aceafaf5fb2c 100644 --- a/winml/adapter/abi_custom_registry_impl.cpp +++ b/winml/adapter/abi_custom_registry_impl.cpp @@ -51,11 +51,12 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( _In_opt_ IMLOperatorShapeInferrer* shapeInferrer, _In_opt_ IMLOperatorSupportQueryPrivate* supportQuery, bool 
isInternalOperator, - bool canAliasFirstInput, bool supportsGraph, const uint32_t* requiredInputCountForGraph, _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs, - uint32_t constantCpuInputCount + uint32_t constantCpuInputCount, + _In_reads_(aliasCount) const std::pair<uint32_t, uint32_t>* aliases, + uint32_t aliasCount ) const noexcept try { #ifdef LAYERING_DONE // Log a custom op telemetry if the operator is not a built-in DML operator @@ -73,11 +74,12 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistryImpl::RegisterOperatorKernel( shapeInferrer, supportQuery, isInternalOperator, - canAliasFirstInput, supportsGraph, requiredInputCountForGraph, requiredConstantCpuInputs, - constantCpuInputCount + constantCpuInputCount, + aliases, + aliasCount ); } CATCH_RETURN(); diff --git a/winml/adapter/abi_custom_registry_impl.h b/winml/adapter/abi_custom_registry_impl.h index 1942a2e4f82f1..101997d7c94ab 100644 --- a/winml/adapter/abi_custom_registry_impl.h +++ b/winml/adapter/abi_custom_registry_impl.h @@ -26,11 +26,12 @@ class AbiCustomRegistryImpl : public AbiCustomRegistry { _In_opt_ IMLOperatorShapeInferrer* shape_inferrer, _In_opt_ IMLOperatorSupportQueryPrivate* supportQuery, bool is_internal_operator, - bool can_alias_first_input, bool supports_graph, const uint32_t* required_input_count_for_graph = nullptr, _In_reads_(constant_cpu_input_count) const uint32_t* required_constant_cpu_inputs = nullptr, - uint32_t constant_cpu_input_count = 0 + uint32_t constant_cpu_input_count = 0, + _In_reads_(aliasCount) const std::pair<uint32_t, uint32_t>* aliases = nullptr, + uint32_t aliasCount = 0 ) const noexcept override; HRESULT STDMETHODCALLTYPE RegisterOperatorKernel(