[DML EP] Add GroupQueryAttention (#20327)
PatriceVignola authored Apr 19, 2024
1 parent 7c80c39 commit 4d98f06
Showing 17 changed files with 465 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -1294,6 +1294,7 @@ Do not modify directly.*
 |FusedMatMulActivation|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
 |GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**M** = tensor(float), tensor(float16)<br/> **T** = tensor(float), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float), tensor(float16)|
 |MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
 |NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
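The new table row above is the DML EP's registration surface for GroupQueryAttention. As a quick orientation, here is a minimal sketch of running a model containing the operator through the DirectML execution provider; the model filename is hypothetical, and it assumes a DirectML-enabled ONNX Runtime build. The DML EP documentation recommends disabling memory pattern optimization and using sequential execution, as done below.

    // Minimal sketch: "model_with_gqa.onnx" is a hypothetical model containing
    // the com.microsoft GroupQueryAttention operator; requires a DirectML build.
    #include <onnxruntime_cxx_api.h>
    #include <dml_provider_factory.h>

    int main()
    {
        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "gqa-demo");
        Ort::SessionOptions options;

        // The DML EP does not support memory patterns or parallel execution.
        options.DisableMemPattern();
        options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);

        // Register the DML EP on adapter/device 0.
        Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(options, 0));

        Ort::Session session(env, L"model_with_gqa.onnx", options);
        return 0;
    }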
@@ -330,7 +330,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
     IMLOperatorKernelFactory* operatorKernelFactory,
     _In_opt_ IMLOperatorShapeInferrer* shapeInferrer) const noexcept
 {
-    return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, nullptr, false, false, false);
+    return RegisterOperatorKernel(opKernel, operatorKernelFactory, shapeInferrer, nullptr, false, false);
 }

 HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
@@ -339,11 +339,12 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
     _In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
     _In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
     bool isInternalOperator,
-    bool canAliasFirstInput,
     bool supportsGraph,
     const uint32_t* requiredInputCountForGraph,
     _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs,
-    uint32_t constantCpuInputCount) const noexcept
+    uint32_t constantCpuInputCount,
+    _In_reads_(aliasCount) const std::pair<uint32_t, uint32_t>* aliases,
+    uint32_t aliasCount) const noexcept
 {
     ORT_TRY
     {
@@ -417,9 +418,9 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
         builder.InputMemoryType(::OrtMemType::OrtMemTypeCPUInput, inputIndex);
     }

-    if (canAliasFirstInput)
+    for (uint32_t i = 0; i < aliasCount; ++i)
     {
-        builder.Alias(0, 0);
+        builder.Alias(aliases[i].first, aliases[i].second);
     }

     // Set type constraints
@@ -553,7 +554,7 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel(
     else
     {
         // Currently unsupported for external operators
-        if (canAliasFirstInput ||
+        if (aliasCount > 0 ||
             supportsGraph ||
             requiredInputCountForGraph)
{
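In short, the single canAliasFirstInput flag is generalized to an arbitrary list of (input index, output index) alias pairs, each of which becomes a builder.Alias call. Below is a sketch of how a caller might use the new tail parameters; the variable names and the specific indices are illustrative rather than copied from this commit, though for GroupQueryAttention the natural mapping is past_key to present_key and past_value to present_value so the KV cache can be updated in place.

    // Sketch only: registry, kernelDesc, factory, and inferrer are placeholders.
    const std::pair<uint32_t, uint32_t> kAliases[] = {
        {3, 1}, // input 3 (past_key)   may share memory with output 1 (present_key)
        {4, 2}, // input 4 (past_value) may share memory with output 2 (present_value)
    };

    registry->RegisterOperatorKernel(
        &kernelDesc,
        factory.Get(),
        inferrer.Get(),
        nullptr,      // supportQuery
        true,         // isInternalOperator
        false,        // supportsGraph
        nullptr,      // requiredInputCountForGraph
        nullptr,      // requiredConstantCpuInputs
        0,            // constantCpuInputCount
        kAliases,     // aliases
        static_cast<uint32_t>(std::size(kAliases)));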
@@ -15,7 +15,7 @@ namespace WRL
 }

 namespace Windows::AI::MachineLearning::Adapter
-{
+{

 using namespace Microsoft::WRL;

@@ -38,11 +38,12 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
         _In_opt_ IMLOperatorShapeInferrer* shapeInferrer,
         _In_opt_ IMLOperatorSupportQueryPrivate* supportQuery,
         bool isInternalOperator,
-        bool canAliasFirstInput,
         bool supportsGraph,
         const uint32_t* requiredInputCountForGraph = nullptr,
         _In_reads_(constantCpuInputCount) const uint32_t* requiredConstantCpuInputs = nullptr,
-        uint32_t constantCpuInputCount = 0) const noexcept override;
+        uint32_t constantCpuInputCount = 0,
+        _In_reads_(aliasCount) const std::pair<uint32_t, uint32_t>* aliases = nullptr,
+        uint32_t aliasCount = 0) const noexcept override;

     HRESULT STDMETHODCALLTYPE RegisterOperatorKernel(
         const MLOperatorKernelDescription* opKernel,
@@ -56,7 +57,7 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis
     {
         registries.push_back(registry.second);
     }
-
+
     registries.push_back(m_kernelRegistry);

     return registries;
@@ -86,15 +87,15 @@ class AbiCustomRegistry : public WRL::Base<IMLOperatorRegistry, IMLOperatorRegis

 private:
     static onnx::OpSchema ConvertOpSchema(
-        _In_z_ const char* domain,
+        _In_z_ const char* domain,
         const MLOperatorSchemaDescription& abiSchema,
         IMLOperatorTypeInferrer* typeInferrer,
         IMLOperatorShapeInferrer* shapeInferrer);

     static std::string ConvertFormalParameterType(const MLOperatorSchemaEdgeDescription& formalParameter);
     static onnx::OpSchema::FormalParameterOption ConvertFormalParameterOption(MLOperatorParameterOptions options);
     static void SetAttributesAndDefaults(onnx::OpSchema& schema, const MLOperatorSchemaDescription& abiSchema);

     static AttributeMap GetDefaultAttributes(const MLOperatorKernelDescription* opKernel);

     std::shared_ptr<onnxruntime::CustomRegistry> m_kernelRegistry;
@@ -107,13 +107,16 @@ namespace Dml::GraphDescBuilder
         // Mapping from the old indices to the new indices that have been shifted after removing earlier nodes
         std::vector<uint32_t> shiftedIndicesMapping(graphNodes.size());

+        std::unordered_set<uint32_t> nodesRemoved;
+
         uint32_t shift = 0;
         for (uint32_t nodeIndex = 0; nodeIndex < graphNodes.size(); ++nodeIndex)
         {
             if (nodesData[nodeIndex].state == NodeState::NotVisited)
             {
                 // The node is not connected, so we simply increase the shift value (the node will be overwritten by the following nodes)
                 ++shift;
+                nodesRemoved.insert(nodeIndex);
             }
             else
             {
@@ -125,6 +128,13 @@

         graphNodes.resize(graphNodes.size() - shift);

+        // Remove the inputs that are not connected to anything anymore
+        auto inputEdgesEndIter = std::remove_if(graphInputEdges.begin(), graphInputEdges.end(), [&nodesRemoved](const auto& inputEdge) {
+            return nodesRemoved.count(inputEdge.ToNodeIndex);
+        });
+
+        graphInputEdges.erase(inputEdgesEndIter, graphInputEdges.end());
+
         // Adjust the node indices in the input edges
         std::unordered_set<uint32_t> usedInputEdgeIndex;
         for (auto& inputEdge : graphInputEdges)
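One detail worth noting in the hunk above: std::remove_if only compacts the vector and returns an iterator to the new logical end; the erase call is what actually shrinks graphInputEdges. A self-contained sketch of the same erase–remove_if pattern, with a simplified edge type (names are illustrative):

    #include <algorithm>
    #include <cstdint>
    #include <unordered_set>
    #include <vector>

    struct Edge { uint32_t ToNodeIndex; };

    // Drop every edge that points into a removed node.
    void PruneEdges(std::vector<Edge>& edges, const std::unordered_set<uint32_t>& removedNodes)
    {
        auto newEnd = std::remove_if(edges.begin(), edges.end(), [&removedNodes](const Edge& e) {
            return removedNodes.count(e.ToNodeIndex) != 0;
        });
        edges.erase(newEnd, edges.end()); // remove_if alone does not shrink the vector
    }

    int main()
    {
        std::vector<Edge> edges{{0}, {1}, {2}};
        PruneEdges(edges, std::unordered_set<uint32_t>{1});
        return static_cast<int>(edges.size()); // 2
    }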
@@ -1147,7 +1147,6 @@ class GpuDFTOperatorFactory : public WRL::Base<IMLOperatorKernelFactory>
             shareInferrer.Get(),
             nullptr,
             false, // isInternalOperator
-            false, // alias
             false, // supportsGraph
             nullptr,
             requiredConstantCpuInputs.data(),
@@ -844,7 +844,6 @@ class DmlGridSampleOperatorFactory : public WRL::Base<IMLOperatorKernelFactory>
             shareInferrer.Get(),
             nullptr,
             false, // isInternalOperator
-            false, // alias
             false, // supportsGraph
             nullptr,
             nullptr,