diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 0317ffcfb0e31..d2e07ad9543ef 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -1059,7 +1059,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "(k-v buffer), it is of length max_sequence_length... otherwise of length past_sequence_length +" "kv_sequence_length.", "T") - .TypeConstraint("T", {"tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") + .TypeConstraint("T", {"tensor(float16)", "tensor(float)", "tensor(bfloat16)"}, "Constrain input and output to float tensors.") .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to int tensor.") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { GroupQueryAttentionTypeAndShapeInference(ctx, 3); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index c6a15e76f4736..eaf1fde0e32f8 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -105,6 +105,8 @@ namespace Dml::GraphDescBuilder // Mapping from the old indices to the new indices that have been shifted after removing earlier nodes std::vector shiftedIndicesMapping(graphNodes.size()); + std::unordered_set nodesRemoved; + uint32_t shift = 0; for (uint32_t nodeIndex = 0; nodeIndex < graphNodes.size(); ++nodeIndex) { @@ -112,6 +114,7 @@ namespace Dml::GraphDescBuilder { // The node is not connected, so we simply increase the shift value (the node will be overwritten by the following nodes) ++shift; + nodesRemoved.insert(nodeIndex); } else { @@ -123,6 +126,13 @@ namespace Dml::GraphDescBuilder graphNodes.resize(graphNodes.size() - shift); + // Remove the inputs that are not connected to anything anymore + auto inputEdgesEndIter = std::remove_if(graphInputEdges.begin(), graphInputEdges.end(), [&nodesRemoved](const auto& inputEdge) { + return nodesRemoved.count(inputEdge.ToNodeIndex); + }); + + graphInputEdges.erase(inputEdgesEndIter, graphInputEdges.end()); + // Adjust the node indices in the input edges for (auto& inputEdge : graphInputEdges) { @@ -344,8 +354,8 @@ namespace Dml::GraphDescBuilder dmlFusedNodeInputIndex < isConstGpuGraphInputCount && isConstGpuGraphInput[dmlFusedNodeInputIndex]) { - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently // only used for small inputs. uint32_t c_maxConstNodeDataSize = 8; @@ -357,7 +367,7 @@ namespace Dml::GraphDescBuilder if (constantInput && tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) { - // The tensor description's size should be no larger than the constant input unless it was rounded to + // The tensor description's size should be no larger than the constant input unless it was rounded to // the required alignment. assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp index 5cd8aa574155f..897fb53559034 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupQueryAttention.cpp @@ -55,7 +55,11 @@ class DmlOperatorGroupQueryAttention : public DmlOperator, public GroupQueryAtte inputIndices[dmlQueryIndex] = queryIndex; inputIndices[dmlKeyIndex] = keyIndex; inputIndices[dmlValueIndex] = valueIndex; - inputIndices[dmlPastSequenceLengthsIndex] = seqLensIndex; + + if (kernelCreationContext.GetInputTensorShape(queryIndex)[1] == 1) + { + inputIndices[dmlPastSequenceLengthsIndex] = seqLensIndex; + } std::vector> outputIndices = { outputIndex, @@ -101,20 +105,23 @@ class DmlOperatorGroupQueryAttention : public DmlOperator, public GroupQueryAtte ML_CHECK_VALID_ARGUMENT(valueSizes[1] == kvSequenceLength); ML_CHECK_VALID_ARGUMENT(valueSizes[2] == kvHiddenSize); - // Validate PastSequenceLengths dimensions - if (m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetDimensionCount() == 1) - { - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetSizes()[0] == batchSize); - } - else + if (sequenceLength == 1) { - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetDimensionCount() == 2); - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetSizes()[0] == batchSize); - ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetSizes()[1] == 1); + // Validate PastSequenceLengths dimensions + if (m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetDimensionCount() == 1) + { + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetSizes()[0] == batchSize); + } + else + { + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetDimensionCount() == 2); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetSizes()[0] == batchSize); + ML_CHECK_VALID_ARGUMENT(m_inputTensorDescs[dmlPastSequenceLengthsIndex].GetSizes()[1] == 1); + } } const std::array pastSequenceLengthsShape = {batchSize}; - auto pastSequenceLengthsDataType = kernelCreationContext.GetInputEdgeDescription(seqLensIndex).tensorDataType; + auto pastSequenceLengthsDataType = MLOperatorTensorDataType::Int32; TensorDesc pastSequenceLengthsTensorDesc = TensorDesc::ConstructDefaultTensorDesc(pastSequenceLengthsDataType, pastSequenceLengthsShape); const DML_TENSOR_DESC pastSequenceLengthsDmlTensorDesc = pastSequenceLengthsTensorDesc.GetDmlDesc(); @@ -132,9 +139,76 @@ class DmlOperatorGroupQueryAttention : public DmlOperator, public GroupQueryAtte mhaDesc.QueryHeadCount = queryNumHeads; mhaDesc.KeyValueHeadCount = kvNumHeads; mhaDesc.Scale = kernelCreationContext.GetOptionalAttribute(AttrName::Scale, gsl::narrow_cast(1.0f / std::sqrt(queryHeadSize))); + DML_OPERATOR_DESC mhaDmlDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION1, &mhaDesc }; - DML_OPERATOR_DESC opDesc = { DML_OPERATOR_MULTIHEAD_ATTENTION1, &mhaDesc }; - SetDmlOperatorDesc(opDesc, kernelCreationContext); + if (sequenceLength == 1) + { + SetDmlOperatorDesc(mhaDmlDesc, kernelCreationContext); + } + else + { + // The GQA offline fusion does this thing where it sums the number of 1's in the mask to figure out the value of the past sequence. + // This doesn't work well for the first iteration since, obviously, there are no past sequences and the mask in this case represents + // only the elements in the initial sequence. To work around this, the CUDA implementation of the operator ignores the value of + // pastSequenceLengths for the first iteration and acts as if it was 0. This feels like a pretty dirty hack and something that should + // be polished in the future, but for compatibility with the GQA fusion and the CUDA implementation we do the same thing here. We DO NOT + // want to do this within DirectML since DirectML should be agnostic w.r.t which iteration it's currently executing MHA for, and such a + // hack that is likely to be modified in the future shouldn't be enshrined within DirectML. Doing it here is OK because the nature of contrib + // ops is that they can change at any time. + DML_FILL_VALUE_CONSTANT_OPERATOR_DESC zeroScalarDesc = {}; + zeroScalarDesc.OutputTensor = &pastSequenceLengthsDmlTensorDesc; + zeroScalarDesc.ValueDataType = pastSequenceLengthsTensorDesc.GetDmlDataType(); + DML_OPERATOR_DESC zeroScalarDmlDesc = { DML_OPERATOR_FILL_VALUE_CONSTANT, &zeroScalarDesc }; + + std::vector opDescs = { + &zeroScalarDmlDesc, + &mhaDmlDesc, + }; + + // Construct the graph + std::vector inputEdges; + std::vector intermediateEdges; + std::vector outputEdges; + + // Link the query/key/value inputs to MHA + for (uint32_t i = 0; i < 3; ++i) + { + DML_INPUT_GRAPH_EDGE_DESC inputToMhaEdge = {}; + inputToMhaEdge.GraphInputIndex = i; + inputToMhaEdge.ToNodeIndex = 1; + inputToMhaEdge.ToNodeInputIndex = i; + inputEdges.push_back(inputToMhaEdge); + } + + // Link the zero scalar to MHA + DML_INTERMEDIATE_GRAPH_EDGE_DESC zeroScalarToMhaEdge = {}; + zeroScalarToMhaEdge.FromNodeIndex = 0; + zeroScalarToMhaEdge.FromNodeOutputIndex = 0; + zeroScalarToMhaEdge.ToNodeIndex = 1; + zeroScalarToMhaEdge.ToNodeInputIndex = dmlPastSequenceLengthsIndex; + intermediateEdges.push_back(zeroScalarToMhaEdge); + + // Link MHA's outputs to the graph's outputs + for (uint32_t i = 0; i < 3; ++i) + { + DML_OUTPUT_GRAPH_EDGE_DESC mhaToOutputEdge = {}; + mhaToOutputEdge.FromNodeIndex = 1; + mhaToOutputEdge.FromNodeOutputIndex = i; + mhaToOutputEdge.GraphOutputIndex = i; + outputEdges.push_back(mhaToOutputEdge); + } + + MLOperatorGraphDesc operatorGraphDesc = {}; + operatorGraphDesc.inputEdgeCount = gsl::narrow_cast(inputEdges.size()); + operatorGraphDesc.inputEdges = inputEdges.data(); + operatorGraphDesc.intermediateEdgeCount = gsl::narrow_cast(intermediateEdges.size()); + operatorGraphDesc.intermediateEdges = intermediateEdges.data(); + operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); + operatorGraphDesc.outputEdges = outputEdges.data(); + operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); + operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); + } } };