From a9568935a52b3d51ec802a4ab89ab3852129fc1e Mon Sep 17 00:00:00 2001 From: Sumit Agarwal Date: Mon, 26 Feb 2024 11:35:13 -0800 Subject: [PATCH] [DML EP] Enable DML Graph Serialization (#19505) ### Description This PR adds a feature to serialize all DML EP partitions into DML currency individually for a given a model. This feature can be dynamically turned on by using DML EP option `ep.dml.enable_graph_serialization`. ### Motivation and Context - Why is this change required? What problem does it solve? Useful when user want to capture the DML EP specific partition into DML currency to mitigate the dependency on the framework. --- .../inc/IWinmlExecutionProvider.h | 7 +- .../DmlExecutionProvider/src/ApiTraits.cpp | 570 +++++++ .../src/DmlGraphDeserialization.cpp | 554 +++++++ .../src/DmlGraphFusionHelper.cpp | 247 ++- .../src/DmlGraphFusionHelper.h | 19 +- .../src/DmlGraphFusionTransformer.cpp | 41 +- .../src/DmlGraphFusionTransformer.h | 4 +- .../src/DmlGraphSerialization.cpp | 580 ++++++++ .../src/DmlRuntimeFusedGraphKernel.cpp | 30 +- .../src/External/DirectMLHelpers/ApiTraits.h | 453 +++++- .../External/DirectMLHelpers/DirectMLSchema.h | 112 +- .../DirectMLHelpers/DmlGraphDesc_generated.h | 788 ++++++++++ .../DirectMLHelpers/DmlGraphDeserialization.h | 14 + .../DirectMLHelpers/DmlGraphSerialization.h | 8 + .../DirectMLHelpers/DmlSerializedGraphDesc.h | 73 + .../DirectMLHelpers/GeneratedSchemaHelpers.h | 92 +- .../DirectMLHelpers/GeneratedSchemaTypes.h | 32 +- .../OperatorFieldTypes_generated.h | 1318 +++++++++++++++++ .../External/DirectMLHelpers/SchemaHelpers.h | 54 +- .../src/GraphDescBuilder.cpp | 404 ++--- .../src/GraphDescBuilder.h | 21 +- .../src/MLOperatorAuthorImpl.cpp | 30 +- .../src/Operators/DmlOperator.cpp | 4 +- .../src/Operators/DmlOperatorAttention.cpp | 2 +- .../src/Operators/DmlOperatorBiasAdd.cpp | 2 +- .../Operators/DmlOperatorBiasSplitGelu.cpp | 2 +- .../DmlOperatorEmbedLayerNormalization.cpp | 2 +- .../src/Operators/DmlOperatorGroupNorm.cpp | 2 +- .../DmlOperatorLayerNormalization.cpp | 2 +- .../Operators/DmlOperatorQLinearConcat.cpp | 2 +- .../Operators/DmlOperatorQLinearSigmoid.cpp | 2 +- .../src/Operators/DmlOperatorQuickGelu.cpp | 2 +- .../Operators/DmlOperatorRotaryEmbedding.cpp | 2 +- .../DmlOperatorSkipLayerNormalization.cpp | 2 +- .../dml/DmlExecutionProvider/src/Utility.h | 141 ++ .../dml/DmlExecutionProvider/src/precomp.h | 7 + .../MLOperatorAuthorPrivate.h | 11 +- .../dml/dml_session_options_config_keys.h | 1 + onnxruntime/core/session/inference_session.cc | 9 +- .../test/perftest/command_args_parser.cc | 1 + onnxruntime/test/perftest/ort_test_session.cc | 10 + 41 files changed, 5203 insertions(+), 454 deletions(-) create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h create mode 100644 onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index f29cc3afc3cda..88e3dd487d427 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -80,15 +80,10 @@ namespace Windows::AI::MachineLearning::Adapter }; // This is the counterpart to the MLOperatorGraphDesc ABI struct which owns its memory and uses containers. - // Either nodesAsOperatorDesc or nodesAsIDMLOperator can have non-zero size. struct DmlGraphNodeCreateInfo { uint32_t nodeCount = 0; - std::vector> nodesAsOperatorDesc; - - // TODO (jeffbloo): Remove this - std::vector> nodesAsIDMLOperator; - + std::vector> nodes; std::vector inputEdges; std::vector outputEdges; std::vector intermediateEdges; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp new file mode 100644 index 0000000000000..bf9800458102b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ApiTraits.cpp @@ -0,0 +1,570 @@ +//--------------------------------------------------------------------------- +// Copyright (c) Microsoft Corporation. All rights reserved. +// +// This file is automatically generated. Please do not edit it directly. +// To modify this file, edit the schema: dml/Tools/DirectMLSchema.json +// And run this script to regenerate: dml/Tools/GenerateSchema.ps1 +// +// #dml-new-operator-location +//--------------------------------------------------------------------------- + +#pragma once + +#include "precomp.h" + +template +T ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ +#ifndef WAI_BUILD_LINUX + // Clang will instantiate this template even if it isn't used, + // so this static_assert will always fire and break the build. + static_assert(false, "Not implemented for this type"); +#endif +} + +template <> +DML_TENSOR_DATA_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_TENSOR_DATA_TYPE_UNKNOWN", DML_TENSOR_DATA_TYPE_UNKNOWN}, + {"DML_TENSOR_DATA_TYPE_FLOAT32", DML_TENSOR_DATA_TYPE_FLOAT32}, + {"DML_TENSOR_DATA_TYPE_FLOAT16", DML_TENSOR_DATA_TYPE_FLOAT16}, + {"DML_TENSOR_DATA_TYPE_UINT32", DML_TENSOR_DATA_TYPE_UINT32}, + {"DML_TENSOR_DATA_TYPE_UINT16", DML_TENSOR_DATA_TYPE_UINT16}, + {"DML_TENSOR_DATA_TYPE_UINT8", DML_TENSOR_DATA_TYPE_UINT8}, + {"DML_TENSOR_DATA_TYPE_INT32", DML_TENSOR_DATA_TYPE_INT32}, + {"DML_TENSOR_DATA_TYPE_INT16", DML_TENSOR_DATA_TYPE_INT16}, + {"DML_TENSOR_DATA_TYPE_INT8", DML_TENSOR_DATA_TYPE_INT8}, + {"DML_TENSOR_DATA_TYPE_FLOAT64", DML_TENSOR_DATA_TYPE_FLOAT64}, + {"DML_TENSOR_DATA_TYPE_UINT64", DML_TENSOR_DATA_TYPE_UINT64}, + {"DML_TENSOR_DATA_TYPE_INT64", DML_TENSOR_DATA_TYPE_INT64}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_TENSOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_TENSOR_TYPE_INVALID", DML_TENSOR_TYPE_INVALID}, + {"DML_TENSOR_TYPE_BUFFER", DML_TENSOR_TYPE_BUFFER}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_OPERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_OPERATOR_INVALID", DML_OPERATOR_INVALID}, + {"DML_OPERATOR_ELEMENT_WISE_IDENTITY", DML_OPERATOR_ELEMENT_WISE_IDENTITY}, + {"DML_OPERATOR_ELEMENT_WISE_ABS", DML_OPERATOR_ELEMENT_WISE_ABS}, + {"DML_OPERATOR_ELEMENT_WISE_ACOS", DML_OPERATOR_ELEMENT_WISE_ACOS}, + {"DML_OPERATOR_ELEMENT_WISE_ADD", DML_OPERATOR_ELEMENT_WISE_ADD}, + {"DML_OPERATOR_ELEMENT_WISE_ASIN", DML_OPERATOR_ELEMENT_WISE_ASIN}, + {"DML_OPERATOR_ELEMENT_WISE_ATAN", DML_OPERATOR_ELEMENT_WISE_ATAN}, + {"DML_OPERATOR_ELEMENT_WISE_CEIL", DML_OPERATOR_ELEMENT_WISE_CEIL}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP", DML_OPERATOR_ELEMENT_WISE_CLIP}, + {"DML_OPERATOR_ELEMENT_WISE_COS", DML_OPERATOR_ELEMENT_WISE_COS}, + {"DML_OPERATOR_ELEMENT_WISE_DIVIDE", DML_OPERATOR_ELEMENT_WISE_DIVIDE}, + {"DML_OPERATOR_ELEMENT_WISE_EXP", DML_OPERATOR_ELEMENT_WISE_EXP}, + {"DML_OPERATOR_ELEMENT_WISE_FLOOR", DML_OPERATOR_ELEMENT_WISE_FLOOR}, + {"DML_OPERATOR_ELEMENT_WISE_LOG", DML_OPERATOR_ELEMENT_WISE_LOG}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND", DML_OPERATOR_ELEMENT_WISE_LOGICAL_AND}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS", DML_OPERATOR_ELEMENT_WISE_LOGICAL_EQUALS}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_GREATER_THAN_OR_EQUAL}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL", DML_OPERATOR_ELEMENT_WISE_LOGICAL_LESS_THAN_OR_EQUAL}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT", DML_OPERATOR_ELEMENT_WISE_LOGICAL_NOT}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_OR}, + {"DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR", DML_OPERATOR_ELEMENT_WISE_LOGICAL_XOR}, + {"DML_OPERATOR_ELEMENT_WISE_MAX", DML_OPERATOR_ELEMENT_WISE_MAX}, + {"DML_OPERATOR_ELEMENT_WISE_MEAN", DML_OPERATOR_ELEMENT_WISE_MEAN}, + {"DML_OPERATOR_ELEMENT_WISE_MIN", DML_OPERATOR_ELEMENT_WISE_MIN}, + {"DML_OPERATOR_ELEMENT_WISE_MULTIPLY", DML_OPERATOR_ELEMENT_WISE_MULTIPLY}, + {"DML_OPERATOR_ELEMENT_WISE_POW", DML_OPERATOR_ELEMENT_WISE_POW}, + {"DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW", DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW}, + {"DML_OPERATOR_ELEMENT_WISE_RECIP", DML_OPERATOR_ELEMENT_WISE_RECIP}, + {"DML_OPERATOR_ELEMENT_WISE_SIN", DML_OPERATOR_ELEMENT_WISE_SIN}, + {"DML_OPERATOR_ELEMENT_WISE_SQRT", DML_OPERATOR_ELEMENT_WISE_SQRT}, + {"DML_OPERATOR_ELEMENT_WISE_SUBTRACT", DML_OPERATOR_ELEMENT_WISE_SUBTRACT}, + {"DML_OPERATOR_ELEMENT_WISE_TAN", DML_OPERATOR_ELEMENT_WISE_TAN}, + {"DML_OPERATOR_ELEMENT_WISE_THRESHOLD", DML_OPERATOR_ELEMENT_WISE_THRESHOLD}, + {"DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR}, + {"DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR", DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR}, + {"DML_OPERATOR_ACTIVATION_ELU", DML_OPERATOR_ACTIVATION_ELU}, + {"DML_OPERATOR_ACTIVATION_CELU", DML_OPERATOR_ACTIVATION_CELU}, + {"DML_OPERATOR_ACTIVATION_HARDMAX", DML_OPERATOR_ACTIVATION_HARDMAX}, + {"DML_OPERATOR_ACTIVATION_HARDMAX1", DML_OPERATOR_ACTIVATION_HARDMAX1}, + {"DML_OPERATOR_ACTIVATION_HARD_SIGMOID", DML_OPERATOR_ACTIVATION_HARD_SIGMOID}, + {"DML_OPERATOR_ACTIVATION_IDENTITY", DML_OPERATOR_ACTIVATION_IDENTITY}, + {"DML_OPERATOR_ACTIVATION_LEAKY_RELU", DML_OPERATOR_ACTIVATION_LEAKY_RELU}, + {"DML_OPERATOR_ACTIVATION_LINEAR", DML_OPERATOR_ACTIVATION_LINEAR}, + {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX}, + {"DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1", DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1}, + {"DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU", DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU}, + {"DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS", DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS}, + {"DML_OPERATOR_ACTIVATION_RELU", DML_OPERATOR_ACTIVATION_RELU}, + {"DML_OPERATOR_ACTIVATION_SCALED_ELU", DML_OPERATOR_ACTIVATION_SCALED_ELU}, + {"DML_OPERATOR_ACTIVATION_SCALED_TANH", DML_OPERATOR_ACTIVATION_SCALED_TANH}, + {"DML_OPERATOR_ACTIVATION_SIGMOID", DML_OPERATOR_ACTIVATION_SIGMOID}, + {"DML_OPERATOR_ACTIVATION_SOFTMAX", DML_OPERATOR_ACTIVATION_SOFTMAX}, + {"DML_OPERATOR_ACTIVATION_SOFTMAX1", DML_OPERATOR_ACTIVATION_SOFTMAX1}, + {"DML_OPERATOR_ACTIVATION_SOFTPLUS", DML_OPERATOR_ACTIVATION_SOFTPLUS}, + {"DML_OPERATOR_ACTIVATION_SOFTSIGN", DML_OPERATOR_ACTIVATION_SOFTSIGN}, + {"DML_OPERATOR_ACTIVATION_TANH", DML_OPERATOR_ACTIVATION_TANH}, + {"DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU", DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU}, + {"DML_OPERATOR_CONVOLUTION", DML_OPERATOR_CONVOLUTION}, + {"DML_OPERATOR_GEMM", DML_OPERATOR_GEMM}, + {"DML_OPERATOR_REDUCE", DML_OPERATOR_REDUCE}, + {"DML_OPERATOR_AVERAGE_POOLING", DML_OPERATOR_AVERAGE_POOLING}, + {"DML_OPERATOR_AVERAGE_POOLING1", DML_OPERATOR_AVERAGE_POOLING1}, + {"DML_OPERATOR_LP_POOLING", DML_OPERATOR_LP_POOLING}, + {"DML_OPERATOR_LP_POOLING1", DML_OPERATOR_LP_POOLING1}, + {"DML_OPERATOR_MAX_POOLING", DML_OPERATOR_MAX_POOLING}, + {"DML_OPERATOR_ROI_POOLING", DML_OPERATOR_ROI_POOLING}, + {"DML_OPERATOR_SLICE", DML_OPERATOR_SLICE}, + {"DML_OPERATOR_CAST", DML_OPERATOR_CAST}, + {"DML_OPERATOR_SPLIT", DML_OPERATOR_SPLIT}, + {"DML_OPERATOR_JOIN", DML_OPERATOR_JOIN}, + {"DML_OPERATOR_PADDING", DML_OPERATOR_PADDING}, + {"DML_OPERATOR_PADDING1", DML_OPERATOR_PADDING1}, + {"DML_OPERATOR_VALUE_SCALE_2D", DML_OPERATOR_VALUE_SCALE_2D}, + {"DML_OPERATOR_UPSAMPLE_2D", DML_OPERATOR_UPSAMPLE_2D}, + {"DML_OPERATOR_GATHER", DML_OPERATOR_GATHER}, + {"DML_OPERATOR_SPACE_TO_DEPTH", DML_OPERATOR_SPACE_TO_DEPTH}, + {"DML_OPERATOR_DEPTH_TO_SPACE", DML_OPERATOR_DEPTH_TO_SPACE}, + {"DML_OPERATOR_TILE", DML_OPERATOR_TILE}, + {"DML_OPERATOR_TOP_K", DML_OPERATOR_TOP_K}, + {"DML_OPERATOR_BATCH_NORMALIZATION", DML_OPERATOR_BATCH_NORMALIZATION}, + {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING", DML_OPERATOR_BATCH_NORMALIZATION_TRAINING}, + {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION}, + {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION}, + {"DML_OPERATOR_LP_NORMALIZATION", DML_OPERATOR_LP_NORMALIZATION}, + {"DML_OPERATOR_RNN", DML_OPERATOR_RNN}, + {"DML_OPERATOR_LSTM", DML_OPERATOR_LSTM}, + {"DML_OPERATOR_GRU", DML_OPERATOR_GRU}, + {"DML_OPERATOR_ELEMENT_WISE_SIGN", DML_OPERATOR_ELEMENT_WISE_SIGN}, + {"DML_OPERATOR_ELEMENT_WISE_IS_NAN", DML_OPERATOR_ELEMENT_WISE_IS_NAN}, + {"DML_OPERATOR_ELEMENT_WISE_ERF", DML_OPERATOR_ELEMENT_WISE_ERF}, + {"DML_OPERATOR_ELEMENT_WISE_SINH", DML_OPERATOR_ELEMENT_WISE_SINH}, + {"DML_OPERATOR_ELEMENT_WISE_COSH", DML_OPERATOR_ELEMENT_WISE_COSH}, + {"DML_OPERATOR_ELEMENT_WISE_TANH", DML_OPERATOR_ELEMENT_WISE_TANH}, + {"DML_OPERATOR_ELEMENT_WISE_ASINH", DML_OPERATOR_ELEMENT_WISE_ASINH}, + {"DML_OPERATOR_ELEMENT_WISE_ACOSH", DML_OPERATOR_ELEMENT_WISE_ACOSH}, + {"DML_OPERATOR_ELEMENT_WISE_ATANH", DML_OPERATOR_ELEMENT_WISE_ATANH}, + {"DML_OPERATOR_ELEMENT_WISE_IF", DML_OPERATOR_ELEMENT_WISE_IF}, + {"DML_OPERATOR_ELEMENT_WISE_ADD1", DML_OPERATOR_ELEMENT_WISE_ADD1}, + {"DML_OPERATOR_ACTIVATION_SHRINK", DML_OPERATOR_ACTIVATION_SHRINK}, + {"DML_OPERATOR_MAX_POOLING1", DML_OPERATOR_MAX_POOLING1}, + {"DML_OPERATOR_MAX_UNPOOLING", DML_OPERATOR_MAX_UNPOOLING}, + {"DML_OPERATOR_DIAGONAL_MATRIX", DML_OPERATOR_DIAGONAL_MATRIX}, + {"DML_OPERATOR_SCATTER", DML_OPERATOR_SCATTER}, + {"DML_OPERATOR_ONE_HOT", DML_OPERATOR_ONE_HOT}, + {"DML_OPERATOR_RESAMPLE", DML_OPERATOR_RESAMPLE}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_LEFT}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT", DML_OPERATOR_ELEMENT_WISE_BIT_SHIFT_RIGHT}, + {"DML_OPERATOR_ELEMENT_WISE_ROUND", DML_OPERATOR_ELEMENT_WISE_ROUND}, + {"DML_OPERATOR_ELEMENT_WISE_IS_INFINITY", DML_OPERATOR_ELEMENT_WISE_IS_INFINITY}, + {"DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE", DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE}, + {"DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR", DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR}, + {"DML_OPERATOR_FILL_VALUE_SEQUENCE", DML_OPERATOR_FILL_VALUE_SEQUENCE}, + {"DML_OPERATOR_FILL_VALUE_CONSTANT", DML_OPERATOR_FILL_VALUE_CONSTANT}, + {"DML_OPERATOR_CUMULATIVE_SUMMATION", DML_OPERATOR_CUMULATIVE_SUMMATION}, + {"DML_OPERATOR_REVERSE_SUBSEQUENCES", DML_OPERATOR_REVERSE_SUBSEQUENCES}, + {"DML_OPERATOR_GATHER_ELEMENTS", DML_OPERATOR_GATHER_ELEMENTS}, + {"DML_OPERATOR_GATHER_ND", DML_OPERATOR_GATHER_ND}, + {"DML_OPERATOR_SCATTER_ND", DML_OPERATOR_SCATTER_ND}, + {"DML_OPERATOR_MAX_POOLING2", DML_OPERATOR_MAX_POOLING2}, + {"DML_OPERATOR_SLICE1", DML_OPERATOR_SLICE1}, + {"DML_OPERATOR_TOP_K1", DML_OPERATOR_TOP_K1}, + {"DML_OPERATOR_DEPTH_TO_SPACE1", DML_OPERATOR_DEPTH_TO_SPACE1}, + {"DML_OPERATOR_SPACE_TO_DEPTH1", DML_OPERATOR_SPACE_TO_DEPTH1}, + {"DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1", DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION1}, + {"DML_OPERATOR_RESAMPLE1", DML_OPERATOR_RESAMPLE1}, + {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER}, + {"DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY", DML_OPERATOR_QUANTIZED_LINEAR_MATRIX_MULTIPLY}, + {"DML_OPERATOR_CONVOLUTION_INTEGER", DML_OPERATOR_CONVOLUTION_INTEGER}, + {"DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION", DML_OPERATOR_QUANTIZED_LINEAR_CONVOLUTION}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_AND", DML_OPERATOR_ELEMENT_WISE_BIT_AND}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_OR", DML_OPERATOR_ELEMENT_WISE_BIT_OR}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_XOR", DML_OPERATOR_ELEMENT_WISE_BIT_XOR}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_NOT", DML_OPERATOR_ELEMENT_WISE_BIT_NOT}, + {"DML_OPERATOR_ELEMENT_WISE_BIT_COUNT", DML_OPERATOR_ELEMENT_WISE_BIT_COUNT}, + {"DML_OPERATOR_ACTIVATION_RELU_GRAD", DML_OPERATOR_ACTIVATION_RELU_GRAD}, + {"DML_OPERATOR_AVERAGE_POOLING_GRAD", DML_OPERATOR_AVERAGE_POOLING_GRAD}, + {"DML_OPERATOR_MAX_POOLING_GRAD", DML_OPERATOR_MAX_POOLING_GRAD}, + {"DML_OPERATOR_RANDOM_GENERATOR", DML_OPERATOR_RANDOM_GENERATOR}, + {"DML_OPERATOR_NONZERO_COORDINATES", DML_OPERATOR_NONZERO_COORDINATES}, + {"DML_OPERATOR_RESAMPLE_GRAD", DML_OPERATOR_RESAMPLE_GRAD}, + {"DML_OPERATOR_SLICE_GRAD", DML_OPERATOR_SLICE_GRAD}, + {"DML_OPERATOR_ADAM_OPTIMIZER", DML_OPERATOR_ADAM_OPTIMIZER}, + {"DML_OPERATOR_ARGMIN", DML_OPERATOR_ARGMIN}, + {"DML_OPERATOR_ARGMAX", DML_OPERATOR_ARGMAX}, + {"DML_OPERATOR_ROI_ALIGN", DML_OPERATOR_ROI_ALIGN}, + {"DML_OPERATOR_GATHER_ND1", DML_OPERATOR_GATHER_ND1}, + {"DML_OPERATOR_ELEMENT_WISE_ATAN_YX", DML_OPERATOR_ELEMENT_WISE_ATAN_YX}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD}, + {"DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE", DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE}, + {"DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD", DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD}, + {"DML_OPERATOR_CUMULATIVE_PRODUCT", DML_OPERATOR_CUMULATIVE_PRODUCT}, + {"DML_OPERATOR_BATCH_NORMALIZATION_GRAD", DML_OPERATOR_BATCH_NORMALIZATION_GRAD}, + {"DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD", DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD}, + {"DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD", DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD}, + {"DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR", DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR}, + {"DML_OPERATOR_ROI_ALIGN1", DML_OPERATOR_ROI_ALIGN1}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP1", DML_OPERATOR_ELEMENT_WISE_CLIP1}, + {"DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1", DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1}, + {"DML_OPERATOR_ELEMENT_WISE_NEGATE", DML_OPERATOR_ELEMENT_WISE_NEGATE}, + {"DML_OPERATOR_ACTIVATION_GELU", DML_OPERATOR_ACTIVATION_GELU}, + {"DML_OPERATOR_ACTIVATION_SWISH", DML_OPERATOR_ACTIVATION_SWISH}, + {"DML_OPERATOR_ACTIVATION_HARD_SWISH", DML_OPERATOR_ACTIVATION_HARD_SWISH}, + {"DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2}, + {"DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1}, + {"DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1}, + {"DML_OPERATOR_MULTIHEAD_ATTENTION", DML_OPERATOR_MULTIHEAD_ATTENTION}, + {"DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING}, + {"DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_BINDING_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_BINDING_TYPE_NONE", DML_BINDING_TYPE_NONE}, + {"DML_BINDING_TYPE_BUFFER", DML_BINDING_TYPE_BUFFER}, + {"DML_BINDING_TYPE_BUFFER_ARRAY", DML_BINDING_TYPE_BUFFER_ARRAY}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_REDUCE_FUNCTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_REDUCE_FUNCTION_ARGMAX", DML_REDUCE_FUNCTION_ARGMAX}, + {"DML_REDUCE_FUNCTION_ARGMIN", DML_REDUCE_FUNCTION_ARGMIN}, + {"DML_REDUCE_FUNCTION_AVERAGE", DML_REDUCE_FUNCTION_AVERAGE}, + {"DML_REDUCE_FUNCTION_L1", DML_REDUCE_FUNCTION_L1}, + {"DML_REDUCE_FUNCTION_L2", DML_REDUCE_FUNCTION_L2}, + {"DML_REDUCE_FUNCTION_LOG_SUM", DML_REDUCE_FUNCTION_LOG_SUM}, + {"DML_REDUCE_FUNCTION_LOG_SUM_EXP", DML_REDUCE_FUNCTION_LOG_SUM_EXP}, + {"DML_REDUCE_FUNCTION_MAX", DML_REDUCE_FUNCTION_MAX}, + {"DML_REDUCE_FUNCTION_MIN", DML_REDUCE_FUNCTION_MIN}, + {"DML_REDUCE_FUNCTION_MULTIPLY", DML_REDUCE_FUNCTION_MULTIPLY}, + {"DML_REDUCE_FUNCTION_SUM", DML_REDUCE_FUNCTION_SUM}, + {"DML_REDUCE_FUNCTION_SUM_SQUARE", DML_REDUCE_FUNCTION_SUM_SQUARE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + +template <> +DML_MATRIX_TRANSFORM ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_MATRIX_TRANSFORM_NONE", DML_MATRIX_TRANSFORM_NONE}, + {"DML_MATRIX_TRANSFORM_TRANSPOSE", DML_MATRIX_TRANSFORM_TRANSPOSE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_CONVOLUTION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_CONVOLUTION_MODE_CONVOLUTION", DML_CONVOLUTION_MODE_CONVOLUTION}, + {"DML_CONVOLUTION_MODE_CROSS_CORRELATION", DML_CONVOLUTION_MODE_CROSS_CORRELATION}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_CONVOLUTION_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_CONVOLUTION_DIRECTION_FORWARD", DML_CONVOLUTION_DIRECTION_FORWARD}, + {"DML_CONVOLUTION_DIRECTION_BACKWARD", DML_CONVOLUTION_DIRECTION_BACKWARD}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + +template <> +DML_PADDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_PADDING_MODE_CONSTANT", DML_PADDING_MODE_CONSTANT}, + {"DML_PADDING_MODE_EDGE", DML_PADDING_MODE_EDGE}, + {"DML_PADDING_MODE_REFLECTION", DML_PADDING_MODE_REFLECTION}, + {"DML_PADDING_MODE_SYMMETRIC", DML_PADDING_MODE_SYMMETRIC}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_INTERPOLATION_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR", DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR}, + {"DML_INTERPOLATION_MODE_LINEAR", DML_INTERPOLATION_MODE_LINEAR}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_RECURRENT_NETWORK_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_RECURRENT_NETWORK_DIRECTION_FORWARD", DML_RECURRENT_NETWORK_DIRECTION_FORWARD}, + {"DML_RECURRENT_NETWORK_DIRECTION_BACKWARD", DML_RECURRENT_NETWORK_DIRECTION_BACKWARD}, + {"DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL", DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_FEATURE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT", DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT}, + {"DML_FEATURE_FEATURE_LEVELS", DML_FEATURE_FEATURE_LEVELS}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_FEATURE_LEVEL ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_FEATURE_LEVEL_1_0", DML_FEATURE_LEVEL_1_0}, + {"DML_FEATURE_LEVEL_2_0", DML_FEATURE_LEVEL_2_0}, + {"DML_FEATURE_LEVEL_2_1", DML_FEATURE_LEVEL_2_1}, + {"DML_FEATURE_LEVEL_3_0", DML_FEATURE_LEVEL_3_0}, + {"DML_FEATURE_LEVEL_3_1", DML_FEATURE_LEVEL_3_1}, + {"DML_FEATURE_LEVEL_4_0", DML_FEATURE_LEVEL_4_0}, + {"DML_FEATURE_LEVEL_4_1", DML_FEATURE_LEVEL_4_1}, + {"DML_FEATURE_LEVEL_5_0", DML_FEATURE_LEVEL_5_0}, + {"DML_FEATURE_LEVEL_5_1", DML_FEATURE_LEVEL_5_1}, + {"DML_FEATURE_LEVEL_5_2", DML_FEATURE_LEVEL_5_2}, + {"DML_FEATURE_LEVEL_6_0", DML_FEATURE_LEVEL_6_0}, + {"DML_FEATURE_LEVEL_6_1", DML_FEATURE_LEVEL_6_1}, + {"DML_FEATURE_LEVEL_6_2", DML_FEATURE_LEVEL_6_2}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_IS_INFINITY_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_IS_INFINITY_MODE_EITHER", DML_IS_INFINITY_MODE_EITHER}, + {"DML_IS_INFINITY_MODE_POSITIVE", DML_IS_INFINITY_MODE_POSITIVE}, + {"DML_IS_INFINITY_MODE_NEGATIVE", DML_IS_INFINITY_MODE_NEGATIVE}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_DEPTH_SPACE_ORDER ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW", DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW}, + {"DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH", DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_AXIS_DIRECTION ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_AXIS_DIRECTION_INCREASING", DML_AXIS_DIRECTION_INCREASING}, + {"DML_AXIS_DIRECTION_DECREASING", DML_AXIS_DIRECTION_DECREASING}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_ROUNDING_MODE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN", DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN}, + {"DML_ROUNDING_MODE_TOWARD_ZERO", DML_ROUNDING_MODE_TOWARD_ZERO}, + {"DML_ROUNDING_MODE_TOWARD_INFINITY", DML_ROUNDING_MODE_TOWARD_INFINITY}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_RANDOM_GENERATOR_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10", DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + + +template <> +DML_MULTIHEAD_ATTENTION_MASK_TYPE ApiTraits::StringifyHelpers::FromString(std::string_view value) +{ + constexpr StringUtil::NameAndIndex mapping[] = + { + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE", DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END", DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END}, + {"DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN", DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN}, + }; + auto index = StringUtil::MapToIndex(value, mapping); + if (!index) + { + assert(false); + return static_cast(0); + } + return static_cast(*index); +} + diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp new file mode 100644 index 0000000000000..7d8ed17e7d925 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphDeserialization.cpp @@ -0,0 +1,554 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "precomp.h" + +OperatorFieldVariant CreateAttribute( + const DML_SCHEMA_FIELD* schemaField, + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc); + +OperatorFieldVariant CreateActivation( + const dml::ir::operatorFieldTypes::Activation* activationDesc) +{ + DML_OPERATOR_TYPE activationOperatorType = ApiTraits::StringifyHelpers::FromString(activationDesc->type()->c_str()); + const DML_OPERATOR_SCHEMA& activationSchema = SchemaHelpers::GetSchema(activationOperatorType); + std::vector activationOperatorFields(activationSchema.FieldCount); + uint32_t attributeIndex = 0; + + for (uint32_t fieldIndex = 0; fieldIndex < activationSchema.FieldCount; fieldIndex++) + { + const DML_SCHEMA_FIELD* schemaField = &activationSchema.Fields[fieldIndex]; + OperatorFieldVariant field; + switch (schemaField->Kind) + { + case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR: + case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR: + { + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + field = OperatorFieldTypes::TensorDesc(); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + field = OperatorFieldTypes::TensorDescArray(); + } + break; + } + case DML_SCHEMA_FIELD_KIND_ATTRIBUTE: + { + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = + attributeIndex >= activationDesc->attributes()->size() ? + nullptr : + activationDesc->attributes()->Get(attributeIndex++); + field = CreateAttribute(schemaField, attributeDesc); + break; + } + } + + activationOperatorFields[fieldIndex] = OperatorField(schemaField, std::move(field)); + } + + return AbstractOperatorDesc(&activationSchema, std::move(activationOperatorFields)); +} + +OperatorFieldVariant CreateActivations( + const dml::ir::operatorFieldTypes::ActivationArray* activationDescs) +{ + std::vector activations; + for (uint32_t index = 0; index < static_cast(activationDescs->data()->size()); index++) + { + OperatorFieldVariant activation = CreateActivation(activationDescs->data()->Get(index)); + activations.push_back(std::get(activation).value()); + } + return activations; +} + +OperatorFieldVariant CreateAttribute( + const DML_SCHEMA_FIELD* schemaField, + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc) +{ + switch (schemaField->Type) + { + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC: + { + return attributeDesc != nullptr && attributeDesc->val_as_Activation() != nullptr ? + CreateActivation(attributeDesc->val_as_Activation()) : + OperatorFieldTypes::FusedActivationOperatorDesc(); + } + case DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY: + { + return attributeDesc != nullptr && attributeDesc->val_as_ActivationArray() != nullptr ? + CreateActivations(attributeDesc->val_as_ActivationArray()) : + OperatorFieldTypes::FusedActivationOperatorDescArray(); + } + case DML_SCHEMA_FIELD_TYPE_UINT: + { + OperatorFieldTypes::UInt data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_UInt32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_UINT64: + { + OperatorFieldTypes::UInt64 data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_UInt64()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_INT: + { + OperatorFieldTypes::Int data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Int32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_FLOAT: + { + OperatorFieldTypes::Float data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Float32()->data(); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: + { + OperatorFieldTypes::UIntArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_UIntArray()->data()->begin(), attributeDesc->val_as_UIntArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_INT_ARRAY: + { + OperatorFieldTypes::IntArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_IntArray()->data()->begin(), attributeDesc->val_as_IntArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY: + { + OperatorFieldTypes::FloatArray data; + if (attributeDesc != nullptr) + { + data.assign(attributeDesc->val_as_FloatArray()->data()->begin(), attributeDesc->val_as_FloatArray()->data()->end()); + } + return data; + } + case DML_SCHEMA_FIELD_TYPE_SCALE_BIAS: + { + OperatorFieldTypes::ScaleBias scaleBias; + const dml::ir::operatorFieldTypes::ScaleBias* scaleBiasAttribute = attributeDesc->val_as_ScaleBias(); + if (scaleBiasAttribute != nullptr) + { + scaleBias = {scaleBiasAttribute->scale(), scaleBiasAttribute->bias()}; + } + return scaleBias; + } + case DML_SCHEMA_FIELD_TYPE_SIZE_2D: + { + OperatorFieldTypes::Size2D size2d = {}; + if (attributeDesc != nullptr) + { + size2d.Height = attributeDesc->val_as_Size2D()->height(); + size2d.Width = attributeDesc->val_as_Size2D()->width(); + } + return size2d; + } + case DML_SCHEMA_FIELD_TYPE_SCALAR_UNION: + { + DML_SCALAR_UNION scalarUnion; + if (attributeDesc != nullptr) + { + const dml::ir::operatorFieldTypes::ByteArray* byteArr = attributeDesc->val_as_ScalarUnionData()->data_as_ByteArray(); + std::copy(byteArr->data()->begin(), byteArr->data()->end(), scalarUnion.Bytes); + } + return scalarUnion; + } + case DML_SCHEMA_FIELD_TYPE_BOOL: + { + OperatorFieldTypes::Bool data; + if (attributeDesc != nullptr) + { + data = attributeDesc->val_as_Bool()->data(); + } + return data; + } + default: + { + throw std::invalid_argument("Invalid attribute type."); + } + } +} + +OperatorFieldTypes::TensorDesc CreateBufferTensorDesc( + const dml::ir::DmlBufferTensorDesc* tensorDesc, + const bool isConstantTensor = false) +{ + DmlBufferTensorDesc bufferTensorDesc = {}; + bufferTensorDesc.dataType = ApiTraits::StringifyHelpers::FromString(tensorDesc->dataType()->c_str()); + if (isConstantTensor) + { + bufferTensorDesc.flags = DML_TENSOR_FLAG_OWNED_BY_DML; + } + bufferTensorDesc.sizes.assign(tensorDesc->sizes()->begin(), tensorDesc->sizes()->end()); + if (flatbuffers::IsFieldPresent(tensorDesc, dml::ir::DmlBufferTensorDesc::VT_STRIDES)) + { + bufferTensorDesc.strides.emplace(tensorDesc->strides()->begin(), tensorDesc->strides()->end()); + } + bufferTensorDesc.totalTensorSizeInBytes = tensorDesc->totalTensorSizeInBytes(); + return bufferTensorDesc; +} + +AbstractOperatorDesc CreateAbstractOperatorDesc( + uint32_t nodeIndex, + const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeInputNames, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* nodeOutputNames, + const std::unordered_set& constantInputs) +{ + DML_OPERATOR_TYPE type = ApiTraits::StringifyHelpers::FromString(flatbufferOperatorNodeDesc->type()->c_str()); + if (type == DML_OPERATOR_INVALID) + { + throw std::invalid_argument("Graph operator node at index:" + std::to_string(nodeIndex) + + " either has empty or invalid operator type."); + } + const DML_OPERATOR_SCHEMA& schema = SchemaHelpers::GetSchema(type); + std::vector operatorFields(schema.FieldCount); + + auto inputNameItr = nodeInputNames->begin(); + uint32_t inputTensorDescIndex = 0; + + uint32_t outputTensorDescIndex = 0; + auto outputNameItr = nodeOutputNames->begin(); + + uint32_t attributeIndex = 0; + + + for (uint32_t fieldIndex = 0; fieldIndex < schema.FieldCount; fieldIndex++) + { + const DML_SCHEMA_FIELD* schemaField = &schema.Fields[fieldIndex]; + + OperatorFieldVariant field; + switch (schemaField->Kind) + { + case DML_SCHEMA_FIELD_KIND_INPUT_TENSOR: + { + if (inputNameItr == nodeInputNames->end()) + { + throw std::invalid_argument("Missing input names for node at index:" + std::to_string(nodeIndex)); + } + + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + const flatbuffers::String* inputName = *inputNameItr; + inputNameItr++; + if (inputName->size() == 0) + { + field = OperatorFieldTypes::TensorDesc(); + break; + } + bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end(); + + if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + + "input tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++); + field = CreateBufferTensorDesc(tensorDesc, isConstantTensor); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + std::vector tensors; + while (inputTensorDescIndex < static_cast(flatbufferOperatorNodeDesc->inputs()->size())) + { + const flatbuffers::String* inputName = *inputNameItr; + inputNameItr++; + bool isConstantTensor = !constantInputs.empty() && constantInputs.find(inputName->c_str()) != constantInputs.end(); + + if (flatbufferOperatorNodeDesc->inputs()->size() <= inputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(inputTensorDescIndex + 1) + + "input tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->inputs()->Get(inputTensorDescIndex++); + tensors.push_back(CreateBufferTensorDesc(tensorDesc, isConstantTensor).value()); + } + field = tensors; + } + break; + } + case DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR: + { + if (outputNameItr == nodeOutputNames->end()) + { + throw std::invalid_argument("Missing output names for node at index:" + std::to_string(nodeIndex)); + } + + if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC) + { + const flatbuffers::String* outputName = *outputNameItr; + outputNameItr++; + + if (outputName->size() == 0) + { + field = OperatorFieldTypes::TensorDesc(); + break; + } + + if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(outputTensorDescIndex + 1) + + "output tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++); + field = CreateBufferTensorDesc(tensorDesc); + } + else if (schemaField->Type == DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY) + { + std::vector tensors; + while (outputTensorDescIndex < static_cast(flatbufferOperatorNodeDesc->outputs()->size())) + { + if (flatbufferOperatorNodeDesc->outputs()->size() <= outputTensorDescIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(outputTensorDescIndex + 1) + + "output tensor desc for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::DmlBufferTensorDesc* tensorDesc = flatbufferOperatorNodeDesc->outputs()->Get(outputTensorDescIndex++); + tensors.push_back(CreateBufferTensorDesc(tensorDesc).value()); + } + field = tensors; + } + break; + } + case DML_SCHEMA_FIELD_KIND_ATTRIBUTE: + { + if (flatbufferOperatorNodeDesc->attributes()->size() <= attributeIndex) + { + throw std::invalid_argument("Expecting at least " + std::to_string(attributeIndex + 1) + + "attributes for graph operator node at index:" + std::to_string(nodeIndex)); + } + const dml::ir::operatorFieldTypes::AttributeDesc* attributeDesc = + attributeIndex >= flatbufferOperatorNodeDesc->attributes()->size() ? + nullptr : + flatbufferOperatorNodeDesc->attributes()->Get(attributeIndex++); + field = CreateAttribute(schemaField, attributeDesc); + break; + } + } + + operatorFields[fieldIndex] = OperatorField(schemaField, std::move(field)); + } + + return AbstractOperatorDesc(&schema, std::move(operatorFields)); +} + +std::unordered_map ConvertToEdgeNameToIndexMap( + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* list) +{ + std::unordered_map nameToIndexMap; + for (uint32_t index = 0; index < list->size(); index++) + { + const flatbuffers::String* name = list->GetAsString(index); + if (name->size() == 0) + { + continue; + } + nameToIndexMap[name->string_view()] = index; + } + return nameToIndexMap; // NRVO will automatically move it. no need to use std::move +} + +template void PopulateEdges( + const uint32_t nodeIndex, + const ::flatbuffers::Vector<::flatbuffers::Offset<::flatbuffers::String>>* edgeNames, + const std::unordered_map& edgeNameToIndexMap, + /*out*/ std::vector& edges, + /*out*/ std::vector& intermediateEdges, + /*out*/ std::unordered_map& edgeToOutgoingNodeIndexMap) +{ + for (flatbuffers::uoffset_t edgeIndex = 0; edgeIndex < edgeNames->size(); edgeIndex++) + { + const flatbuffers::String* edgeName = edgeNames->Get(edgeIndex); + if (edgeName->size() == 0) + { + // This must be optional input/output + continue; + } + // edge can be graphInput or graphOutput + if (edgeNameToIndexMap.find(edgeName->string_view()) != edgeNameToIndexMap.end()) + { + EdgeType edge = {}; + edge.Name = edgeName->str(); + + if constexpr (std::is_same_v) + { + edge.GraphInputIndex = edgeNameToIndexMap.at(edgeName->string_view()); + edge.ToNodeIndex = nodeIndex; + edge.ToNodeInputIndex = edgeIndex; + } + else if constexpr (std::is_same_v) + { + edge.GraphOutputIndex = edgeNameToIndexMap.at(edgeName->string_view()); + edge.FromNodeIndex = nodeIndex; + edge.FromNodeOutputIndex = edgeIndex; + edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex}; + } + + edges.push_back(edge); + } + // edge is intermediate edge + else + { + if constexpr (std::is_same_v) + { + if (edgeToOutgoingNodeIndexMap.find(edgeName->string_view()) == edgeToOutgoingNodeIndexMap.end()) + { + throw std::range_error("Neither there is any graph input with name " + edgeName->str() + + "nor there is any node which has " + edgeName->str() + " as one of the output."); + } + auto& intermediateEdgeNodeIndex = edgeToOutgoingNodeIndexMap[edgeName->string_view()]; + DmlIntermediateSerializedGraphEdge intermediateEdge = {}; + intermediateEdge.Name = edgeName->str(); + intermediateEdge.FromNodeIndex = intermediateEdgeNodeIndex.nodeIndex; + intermediateEdge.FromNodeOutputIndex = intermediateEdgeNodeIndex.nodeOutputIndex; + intermediateEdge.ToNodeIndex = nodeIndex; + intermediateEdge.ToNodeInputIndex = edgeIndex; + intermediateEdges.push_back(std::move(intermediateEdge)); + } + else if constexpr (std::is_same_v) + { + edgeToOutgoingNodeIndexMap[edgeName->string_view()] = {nodeIndex, edgeIndex}; + } + } + } +} + +/* +* - Handling of empty optional input/output/attibute for non-constant node: +* input/output +* - and will have an null entry +* but the actual OperatorNodeDesc variant's +* and will not have any entry. +* attribute +* - will have null entry +*/ +DmlSerializedGraphDesc DeserializeDmlGraph( + const uint8_t* flatbufferGraphDescBlob, + /*out*/ std::vector>& rawData) +{ + if (flatbufferGraphDescBlob == nullptr) + { + throw std::invalid_argument("Given pointer to flatbuffer blob is null"); + } + const dml::ir::DmlGraphDesc* flatbufferGraphDesc = dml::ir::GetDmlGraphDesc(flatbufferGraphDescBlob); + + std::unordered_map graphInputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphInputNames()); + std::unordered_map graphOutputEdgeToIndexMap = ConvertToEdgeNameToIndexMap(flatbufferGraphDesc->graphOutputNames()); + + std::unordered_map edgeToOutgoingNodeIndexMap; + std::unordered_set constantInputs; + + std::vector nodes(flatbufferGraphDesc->nodes()->size()); + std::vector inputEdges; + std::vector outputEdges; + std::vector intermediateEdges; + + for (uint32_t nodeIndex = 0; nodeIndex < flatbufferGraphDesc->nodes()->size(); nodeIndex++) + { + const dml::ir::DmlGraphNode* flatbufferNode = flatbufferGraphDesc->nodes()->Get(nodeIndex); + + PopulateEdges( + nodeIndex, + flatbufferNode->inputNames(), + graphInputEdgeToIndexMap, + inputEdges, + intermediateEdges, + edgeToOutgoingNodeIndexMap); + PopulateEdges( + nodeIndex, + flatbufferNode->outputNames(), + graphOutputEdgeToIndexMap, + outputEdges, + intermediateEdges, + edgeToOutgoingNodeIndexMap); + + DmlSerializedGraphNode node = {}; + if (flatbufferNode->name()->size() == 0) + { + throw std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) + " doesn't have any name"); + } + node.Name = flatbufferNode->name()->c_str(); + + if (flatbufferNode->desc_type() == dml::ir::NodeDesc_ConstantNodeDesc) + { + const dml::ir::ConstantNodeDesc* flatbufferConstantNode = flatbufferNode->desc_as_ConstantNodeDesc(); + if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantName) + { + if (flatbufferConstantNode->data_as_ConstantName()->name()->size() == 0) + { + throw std::invalid_argument("Constant node at index:" + std::to_string(nodeIndex) + + " doesn't have constant data name."); + } + + ConstantName constantNode = {flatbufferConstantNode->data_as_ConstantName()->name()->c_str()}; + node.Desc = constantNode; + // output of this node will part of constantInputs list + for (uint32_t outputIndex = 0; outputIndex < flatbufferNode->outputNames()->size(); outputIndex++) + { + constantInputs.insert(flatbufferNode->outputNames()->Get(outputIndex)->c_str()); + } + } + else if (flatbufferConstantNode->data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData) + { + + uint32_t rawDataSize = flatbufferConstantNode->data_as_ConstantRawData()->data()->size(); + rawData.push_back(std::make_unique(rawDataSize)); + std::transform( + flatbufferConstantNode->data_as_ConstantRawData()->data()->begin(), + flatbufferConstantNode->data_as_ConstantRawData()->data()->end(), + rawData.back().get(), + [](uint8_t b) {return static_cast(b);}); + + ConstantData constantData = {}; + constantData.dataSize = rawDataSize; + constantData.data = rawData.back().get(); + node.Desc = constantData; + } + + + } + else if (flatbufferNode->desc_type() == dml::ir::NodeDesc::NodeDesc_OperatorNodeDesc) + { + // convert dml::ir::OperatorNodeDesc to AbstractOperatorDesc + const dml::ir::OperatorNodeDesc* flatbufferOperatorNodeDesc = flatbufferNode->desc_as_OperatorNodeDesc(); + node.Desc = CreateAbstractOperatorDesc( + nodeIndex, + flatbufferOperatorNodeDesc, + flatbufferNode->inputNames(), + flatbufferNode->outputNames(), + constantInputs); + } + + nodes[nodeIndex] = node; + } + + DmlSerializedGraphDesc graphDesc; + graphDesc.InputCount = flatbufferGraphDesc->graphInputNames()->size(); + graphDesc.OutputCount = flatbufferGraphDesc->graphOutputNames()->size(); + graphDesc.InputEdges = std::move(inputEdges); + graphDesc.IntermediateEdges = std::move(intermediateEdges); + graphDesc.OutputEdges = std::move(outputEdges); + graphDesc.Nodes = std::move(nodes); + return graphDesc; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 642d9aa03eeef..202b762d99e01 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -135,8 +135,10 @@ namespace DmlGraphFusionHelper void ProcessInputData( const ExecutionProviderImpl* providerImpl, + const bool graphSerializationEnabled, const std::vector& isInputsUploadedByDmlEP, - const std::vector& inputEdges, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, const gsl::span subGraphInputArgNames, const std::unordered_map>& initializerNameToInitializerMap, onnxruntime::Graph& graph, @@ -162,8 +164,17 @@ namespace DmlGraphFusionHelper // Walk through each graph edge and mark used inputs inputsUsed.assign(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : inputEdges) { - inputsUsed[edge.GraphInputIndex] = true; + for (auto it = serializedGraphInputIndexToSubgraphInputIndex->begin(); it != serializedGraphInputIndexToSubgraphInputIndex->end(); it++) { + inputsUsed[it->second] = true; + } + for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex->begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex->end(); it++) { + inputsUsed[it->second] = true; + } + + std::wstring modelName; + if (graphSerializationEnabled) + { + modelName = GetModelName(graph.ModelPath()); } for (uint32_t i = 0; i < initInputBindings.size(); i++) @@ -209,6 +220,10 @@ namespace DmlGraphFusionHelper // Tensor sizes in DML must be a multiple of 4 bytes large. tensorByteSize = AlignToPow2(tensorByteSize, 4); + if(graphSerializationEnabled) + { + WriteToFile(modelName, ConvertToWString(iter->first) + L".bin", reinterpret_cast(tensorPtr), tensorByteSize); + } if (inputRawData) { @@ -287,55 +302,158 @@ namespace DmlGraphFusionHelper return initializerPartitionMap; } + inline uint32_t GetConstantNodeGraphInputIndex( + const std::string& constantName, + const std::unordered_map* serializedGraphConstantNameToMainGraphInputIndex, + uint32_t& graphMaxInputIndex, + std::unordered_map& localConstantNameToIndexMap) + { + if (serializedGraphConstantNameToMainGraphInputIndex == nullptr) + { + if (localConstantNameToIndexMap.find(constantName) == localConstantNameToIndexMap.end()) + { + localConstantNameToIndexMap[constantName] = ++graphMaxInputIndex; + } + return localConstantNameToIndexMap[constantName]; + } + else + { + graphMaxInputIndex = std::max(graphMaxInputIndex, serializedGraphConstantNameToMainGraphInputIndex->at(constantName)); + return serializedGraphConstantNameToMainGraphInputIndex->at(constantName); + } + } + + template void ConvertGraphDesc( const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, const uint32_t inputCount, const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, - _Inout_ std::vector& dmlConstantGraphNodes, + IDMLDevice* device, + StackAllocator& allocator, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + _Inout_ std::vector>& dmlOperators, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, _Inout_ std::vector& dmlIntermediateEdges) { - for (size_t i = 0; i < graphDesc.nodes.size(); ++i) + std::unordered_map oldNodeIndexToNewNodeIndexMap; + for (uint32_t index = 0; index < static_cast(graphDesc.Nodes.size()); index++) { - auto& nodeInfo = graphDesc.nodes[i]; - - if (std::holds_alternative>(nodeInfo.nodeDef)) + const DmlSerializedGraphNode& node = graphDesc.Nodes[index]; + if (std::holds_alternative(node.Desc)) { - dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{std::get>(nodeInfo.nodeDef).Get(), nodeInfo.name.data()}; - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; + oldNodeIndexToNewNodeIndexMap[index] = static_cast(dmlGraphNodes.size()); + DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(std::get(node.Desc), &allocator); + ComPtr op; + ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); + dmlOperators.push_back(op); + DML_OPERATOR_GRAPH_NODE_DESC* dmlOperatorGraphNode = allocator.template Allocate(); + dmlOperatorGraphNode->Name = node.Name.data(); + dmlOperatorGraphNode->Operator = op.Get(); + dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, dmlOperatorGraphNode}); } else { - auto& nodeDefinitionData = std::get>(nodeInfo.nodeDef); - dmlConstantGraphNodes[i] = DML_CONSTANT_DATA_GRAPH_NODE_DESC{ - nodeDefinitionData.data(), - nodeDefinitionData.size(), - nodeInfo.name.data() - }; - - // TODO: Change as new header is ingested - dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{static_cast(2), &dmlConstantGraphNodes[i]}; + auto& constantNodeVariant = std::get(node.Desc); + if (std::holds_alternative(constantNodeVariant)) + { + oldNodeIndexToNewNodeIndexMap[index] = static_cast(dmlGraphNodes.size()); + + auto& constantData = std::get(constantNodeVariant); + + DML_CONSTANT_DATA_GRAPH_NODE_DESC* constantNode = allocator.template Allocate(); + constantNode->Name = node.Name.data(); + constantNode->DataSize = constantData.dataSize; + constantNode->Data = constantData.data; + dmlGraphNodes.push_back(DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_CONSTANT, constantNode}); + } } } - for (size_t i = 0; i < graphDesc.inputEdges.size(); ++i) + uint32_t graphMaxInputIndex = 0; + + for (size_t i = 0; i < graphDesc.InputEdges.size(); ++i) { - dmlInputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, &graphDesc.inputEdges[i]}; + DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + // 1. If serializedGraphInputIndexToMainGraphInputIndex is not null: + // then use the corresponding main graph input index, because the caller will use corresponding + // main graph input index for extracting the actual input tensor from the main graph and + // the caller does not own the creation of dml bindings directly. + // Use Case: When the caller is ORT (DML EP) or DmlEngine. + // + // 2. If serializedGraphInputIndexToMainGraphInputIndex is null: + // then assign the sequential graph input index, because it owns the creation of dml bindings + // directly. + edge->GraphInputIndex = serializedGraphInputIndexToSubgraphInputIndex == nullptr ? + graphDesc.InputEdges[i].GraphInputIndex : + serializedGraphInputIndexToSubgraphInputIndex->at(graphDesc.InputEdges[i].GraphInputIndex); + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.InputEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.InputEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.InputEdges[i].Name.data(); + + graphMaxInputIndex = std::max(graphMaxInputIndex, edge->GraphInputIndex); + dmlInputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INPUT, edge}); } - for (size_t i = 0; i < graphDesc.outputEdges.size(); ++i) + for (size_t i = 0; i < graphDesc.OutputEdges.size(); ++i) { - dmlOutputEdges[i] = DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, &graphDesc.outputEdges[i]}; + DML_OUTPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->GraphOutputIndex = graphDesc.OutputEdges[i].GraphOutputIndex; + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.OutputEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.OutputEdges[i].FromNodeOutputIndex; + edge->Name = graphDesc.OutputEdges[i].Name.data(); + + dmlOutputEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_OUTPUT, edge}); } - for (size_t i = 0; i < graphDesc.intermediateEdges.size(); ++i) + std::unordered_map localConstantNameToIndexMap; + for (uint32_t i = 0; i < static_cast(graphDesc.IntermediateEdges.size()); ++i) { - dmlIntermediateEdges[i] = - DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, &graphDesc.intermediateEdges[i]}; + DmlSerializedGraphNodeDescVariant descVariant = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Desc; + bool isConstantEdge = std::holds_alternative(descVariant); + if (isConstantEdge) + { + auto& constantNodeVariant = std::get(descVariant); + if (std::holds_alternative(constantNodeVariant)) + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex; + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge}); + } + else + { + const std::string& constantName = graphDesc.Nodes[graphDesc.IntermediateEdges[i].FromNodeIndex].Name; + + DML_INPUT_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->GraphInputIndex = GetConstantNodeGraphInputIndex( + constantName, + serializedGraphLargeConstantNameToSubgraphInputIndex, + graphMaxInputIndex, + localConstantNameToIndexMap); + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + + dmlInputEdges.push_back({DML_GRAPH_EDGE_TYPE_INPUT, edge}); + } + } + else + { + DML_INTERMEDIATE_GRAPH_EDGE_DESC* edge = allocator.template Allocate(); + edge->FromNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].FromNodeIndex]; + edge->FromNodeOutputIndex = graphDesc.IntermediateEdges[i].FromNodeOutputIndex; + edge->ToNodeIndex = oldNodeIndexToNewNodeIndexMap[graphDesc.IntermediateEdges[i].ToNodeIndex]; + edge->ToNodeInputIndex = graphDesc.IntermediateEdges[i].ToNodeInputIndex; + edge->Name = graphDesc.IntermediateEdges[i].Name.data(); + dmlIntermediateEdges.push_back(DML_GRAPH_EDGE_DESC{DML_GRAPH_EDGE_TYPE_INTERMEDIATE, edge}); + } } dmlGraphDesc.InputCount = inputCount; @@ -400,27 +518,34 @@ namespace DmlGraphFusionHelper Microsoft::WRL::ComPtr TryCreateCompiledOperator( const GraphDescBuilder::GraphDesc& graphDesc, const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl) + const ExecutionProviderImpl* providerImpl, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex) { const uint32_t fusedNodeInputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->inputs.size()); const uint32_t fusedNodeOutputCount = gsl::narrow_cast(indexedSubGraph.GetMetaDef()->outputs.size()); // convert DML EP GraphDesc into DML_GRAPH_DESC and create IDMLCompiledOperator - DML_GRAPH_DESC dmlGraphDesc = {}; - std::vector dmlOperatorGraphNodes(graphDesc.nodes.size()); - std::vector dmlConstantGraphNodes(graphDesc.nodes.size()); + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); - std::vector dmlGraphNodes(graphDesc.nodes.size()); - std::vector dmlInputEdges(graphDesc.inputEdges.size()); - std::vector dmlOutputEdges(graphDesc.outputEdges.size()); - std::vector dmlIntermediateEdges(graphDesc.intermediateEdges.size()); + StackAllocator<1024> allocator; + DML_GRAPH_DESC dmlGraphDesc = {}; + std::vector> dmlOperators; + std::vector dmlGraphNodes; + std::vector dmlInputEdges; + std::vector dmlOutputEdges; + std::vector dmlIntermediateEdges; ConvertGraphDesc( graphDesc, - dmlGraphDesc, fusedNodeInputCount, fusedNodeOutputCount, - dmlOperatorGraphNodes, - dmlConstantGraphNodes, + device.Get(), + allocator, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + dmlGraphDesc, + dmlOperators, dmlGraphNodes, dmlInputEdges, dmlOutputEdges, @@ -438,8 +563,6 @@ namespace DmlGraphFusionHelper executionFlags |= DML_EXECUTION_FLAG_DISABLE_META_COMMANDS; } - ComPtr device; - ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); ComPtr device1; ORT_THROW_IF_FAILED(device.As(&device1)); @@ -460,6 +583,7 @@ namespace DmlGraphFusionHelper } void FusePartitionAndRegisterKernel( + const uint32_t partitionIndex, onnxruntime::Graph& graph, onnxruntime::KernelRegistry* registryForPartitionKernels, const std::unordered_map>& initializerNameToInitializerMap, @@ -467,8 +591,43 @@ namespace DmlGraphFusionHelper const onnxruntime::IndexedSubGraph& indexedSubGraph, std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, - Microsoft::WRL::ComPtr compiledExecutionPlanOperator) + Microsoft::WRL::ComPtr compiledExecutionPlanOperator, + const bool graphSerializationEnabled, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex) { + if (graphSerializationEnabled) + { + + const std::wstring modelName = GetModelName(graph.ModelPath()); + auto buffer = SerializeDmlGraph(graphDesc); + + const std::wstring partitionName = + L"Partition_" + + std::to_wstring(partitionIndex) + + L".bin"; + WriteToFile(modelName, partitionName, buffer.data(), buffer.size()); + + std::vector> rawData; + DmlSerializedGraphDesc deserializedGraphDesc = DeserializeDmlGraph(buffer.data(), rawData); + GraphDescBuilder::GraphDesc deserializedDmlGraphDesc = {}; + deserializedDmlGraphDesc.InputCount = deserializedGraphDesc.InputCount; + deserializedDmlGraphDesc.InputEdges = std::move(deserializedGraphDesc.InputEdges); + deserializedDmlGraphDesc.IntermediateEdges = std::move(deserializedGraphDesc.IntermediateEdges); + deserializedDmlGraphDesc.Nodes = std::move(deserializedGraphDesc.Nodes); + deserializedDmlGraphDesc.OutputCount = deserializedGraphDesc.OutputCount; + deserializedDmlGraphDesc.OutputEdges = std::move(deserializedGraphDesc.OutputEdges); + deserializedDmlGraphDesc.reuseCommandList = graphDesc.reuseCommandList; + deserializedDmlGraphDesc.outputShapes = graphDesc.outputShapes; + + compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + deserializedDmlGraphDesc, + indexedSubGraph, + providerImpl, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex); + } + auto& fusedNode = graph.BeginFuseSubGraph(indexedSubGraph, indexedSubGraph.GetMetaDef()->name); fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); @@ -482,8 +641,10 @@ namespace DmlGraphFusionHelper std::vector inputsUsed; ProcessInputData( providerImpl, + graphSerializationEnabled, isInputsUploadedByDmlEP, - graphDesc.inputEdges, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, indexedSubGraph.GetMetaDef()->inputs, initializerNameToInitializerMap, graph, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index f8f6162aaa1e0..f1e9654021196 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -45,12 +45,17 @@ namespace DmlGraphFusionHelper gsl::span> partitions ); + template void ConvertGraphDesc( const Dml::GraphDescBuilder::GraphDesc& graphDesc, - _Out_ DML_GRAPH_DESC& dmlGraphDesc, const uint32_t inputCount, const uint32_t outputCount, - _Inout_ std::vector& dmlOperatorGraphNodes, + IDMLDevice* device, + StackAllocator& allocator, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex, + _Out_ DML_GRAPH_DESC& dmlGraphDesc, + _Inout_ std::vector>& dmlOperators, _Inout_ std::vector& dmlGraphNodes, _Inout_ std::vector& dmlInputEdges, _Inout_ std::vector& dmlOutputEdges, @@ -69,9 +74,12 @@ namespace DmlGraphFusionHelper Microsoft::WRL::ComPtr TryCreateCompiledOperator( const GraphDescBuilder::GraphDesc& graphDesc, const onnxruntime::IndexedSubGraph& indexedSubGraph, - const ExecutionProviderImpl* providerImpl); + const ExecutionProviderImpl* providerImpl, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex); void FusePartitionAndRegisterKernel( + const uint32_t partitionIndex, onnxruntime::Graph& graph, onnxruntime::KernelRegistry* registryForPartitionKernels, const std::unordered_map>& initializerNameToInitializerMap, @@ -79,7 +87,10 @@ namespace DmlGraphFusionHelper const onnxruntime::IndexedSubGraph& indexedSubGraph, std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, - Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + Microsoft::WRL::ComPtr compiledExecutionPlanOperator, + const bool graphSerializationEnabled, + const std::unordered_map* serializedGraphInputIndexToSubgraphInputIndex = nullptr, + const std::unordered_map* serializedGraphLargeConstantNameToSubgraphInputIndex = nullptr); void RegisterDynamicKernel( onnxruntime::Graph& graph, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 679738b639ec9..35a2c451a49a5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -24,15 +24,20 @@ namespace Dml std::vector isInputsUploadedByDmlEP; GraphDescBuilder::GraphDesc graphDesc; std::unordered_map> isInitializerTransferable; + std::vector> smallConstantData; // Need to keep it alive for maintaining lifetime + std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; }; } DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, - const onnxruntime::IExecutionProvider* provider + const onnxruntime::IExecutionProvider* provider, + const bool graphSerializationEnabled ) :onnxruntime::GraphTransformer(name), - m_providerImpl(static_cast(provider)->GetImpl()) + m_providerImpl(static_cast(provider)->GetImpl()), + graphSerializationEnabled(graphSerializationEnabled) { } @@ -227,23 +232,39 @@ namespace Dml ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); + // This map will be used to transfer the initializer to D3D12 system heap memory. + // 'serializedDmlGraphDesc' will have constant input as intermediate edges, that's why + // we need a mapping between intermediateEdgeIndex and indexedSubGraph's (a given partition) + // input arg index. + // For ex: Let's say intermediate edge index = idx, then + // indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx]; + // corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]] + // We are using intermediate edge index as a key because same constant tensor can be used by + // multiple nodes. + std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; + std::vector> smallConstantData; GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, partitionNodePropsMap, - device.Get(), m_providerImpl, modelPath, subgraphNodes, subgraphInputs, - subgraphOutputs); + subgraphOutputs, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + smallConstantData); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( graphDesc, indexedSubGraph, - m_providerImpl); + m_providerImpl, + &serializedGraphInputIndexToSubgraphInputIndex, + &serializedGraphLargeConstantNameToSubgraphInputIndex); if (!compiledPartition) { @@ -264,6 +285,9 @@ namespace Dml compiledPartitionInfo->isInputsUploadedByDmlEP = std::move(isInputsUploadedByDmlEP); compiledPartitionInfo->graphDesc = std::move(graphDesc); compiledPartitionInfo->isInitializerTransferable = std::move(isInitializerTransferable); + compiledPartitionInfo->smallConstantData = std::move(smallConstantData); + compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex = std::move(serializedGraphInputIndexToSubgraphInputIndex); + compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex = std::move(serializedGraphLargeConstantNameToSubgraphInputIndex); compiledPartitionInfos[partitionIndex] = std::move(compiledPartitionInfo); } } @@ -271,12 +295,14 @@ namespace Dml } while (!additionalSplittingNodes.empty()); + uint32_t partitionIndex = 0; for (auto&& compiledPartitionInfo : compiledPartitionInfos) { // Null compiled operators were not DML partitions if (compiledPartitionInfo) { DmlGraphFusionHelper::FusePartitionAndRegisterKernel( + partitionIndex++, graph, m_providerImpl->GetKernelRegistry().get(), compiledPartitionInfo->isInitializerTransferable, @@ -284,7 +310,10 @@ namespace Dml compiledPartitionInfo->indexedSubGraph, std::move(compiledPartitionInfo->isInputsUploadedByDmlEP), compiledPartitionInfo->graphDesc, - compiledPartitionInfo->compiledOperator); + compiledPartitionInfo->compiledOperator, + graphSerializationEnabled, + &compiledPartitionInfo->serializedGraphInputIndexToSubgraphInputIndex, + &compiledPartitionInfo->serializedGraphLargeConstantNameToSubgraphInputIndex); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h index 19dab0c89943c..b370f3ef9043c 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h @@ -16,7 +16,8 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer public: DmlGraphFusionTransformer( const std::string& name, - const onnxruntime::IExecutionProvider* provider + const onnxruntime::IExecutionProvider* provider, + const bool graphSerializationEnabled ); public: @@ -38,5 +39,6 @@ class DmlGraphFusionTransformer : public onnxruntime::GraphTransformer private: const ExecutionProviderImpl* m_providerImpl = nullptr; + const bool graphSerializationEnabled = false; }; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp new file mode 100644 index 0000000000000..5355964e8db74 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphSerialization.cpp @@ -0,0 +1,580 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "precomp.h" + +template +T* ReadAs(uint8_t* base, size_t byteOffset) +{ + return reinterpret_cast(base + byteOffset); +} + +void SerializeAttributeDescs( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc, + /*out*/ std::vector>& attributeDescs); + +flatbuffers::Offset serializeActivation( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& activationOperatorDesc) +{ + std::vector> attributeDescs; + SerializeAttributeDescs(builder, activationOperatorDesc, attributeDescs); + + flatbuffers::Offset offset = dml::ir::operatorFieldTypes::CreateActivationDirect( + builder, + activationOperatorDesc.schema->OperatorName, + &attributeDescs); + return offset; +} + +void SerializeAttributeDescs( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc, + /*out*/ std::vector>& attributeDescs) +{ + for (const OperatorField& field : operatorDesc.fields) + { + if (field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_INPUT_TENSOR || + field.GetSchema()->Kind == DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR) + { + continue; + } + + flatbuffers::Offset offset; + + if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::FusedActivationOperatorDesc& fusedActivation = field.AsFusedActivationOperatorDesc(); + if (!fusedActivation.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation); + } + else + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation, + serializeActivation(builder, fusedActivation.value()).Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::FusedActivationOperatorDescArray& fusedActivations = + field.AsFusedActivationOperatorDescArray(); + if (!fusedActivations.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray); + } + else + { + std::vector> fbActivations; + + for (AbstractOperatorDesc activationOpDesc : fusedActivations.value()) + { + flatbuffers::Offset fbActivation = + serializeActivation(builder, activationOpDesc); + fbActivations.push_back(fbActivation); + } + + flatbuffers::Offset activationOffset = + dml::ir::operatorFieldTypes::CreateActivationArrayDirect(builder, &fbActivations); + + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray, + activationOffset.Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32, + builder.CreateStruct(dml::ir::operatorFieldTypes::UInt32(field.AsUInt())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64, + builder.CreateStruct(dml::ir::operatorFieldTypes::UInt64(field.AsUInt64())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32, + builder.CreateStruct(dml::ir::operatorFieldTypes::Int32(field.AsInt())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32, + builder.CreateStruct(dml::ir::operatorFieldTypes::Float32(field.AsFloat())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray, + dml::ir::operatorFieldTypes::CreateUIntArray(builder, builder.CreateVector(field.AsUIntArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray, + dml::ir::operatorFieldTypes::CreateIntArray(builder, builder.CreateVector(field.AsIntArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray, + dml::ir::operatorFieldTypes::CreateFloatArray(builder, builder.CreateVector(field.AsFloatArray())).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + const OperatorFieldTypes::ScaleBias& scaleBias = field.AsScaleBias(); + if (!scaleBias.has_value()) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias); + } + else + { + dml::ir::operatorFieldTypes::ScaleBias fbScaleBias(scaleBias.value().Scale, scaleBias.value().Bias); + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias, + builder.CreateStruct(fbScaleBias).Union()); + } + } + else if (std::holds_alternative(field.GetData())) + { + const DML_SIZE_2D size2d = field.AsSize2D(); + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D, + builder.CreateStruct(dml::ir::operatorFieldTypes::Size2D(size2d.Width, size2d.Height)).Union()); + } + else if (std::holds_alternative(field.GetData())) + { + OperatorFieldTypes::ScalarUnion scalarUnion = field.AsScalarUnion(); + dml::ir::operatorFieldTypes::ByteArray byteArr; + for (uint32_t index = 0; index < static_cast(sizeof(scalarUnion.Bytes)); index++) + { + byteArr.mutable_data()->Mutate(index, scalarUnion.Bytes[index]); + } + + flatbuffers::Offset scalarUnionOffset = + dml::ir::operatorFieldTypes::CreateScalarUnionData( + builder, + dml::ir::operatorFieldTypes::ScalarVariant_ByteArray, + builder.CreateStruct(byteArr).Union()); + + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData, + scalarUnionOffset.Union()); + } + else if (std::holds_alternative(field.GetData())) + { + offset = dml::ir::operatorFieldTypes::CreateAttributeDescDirect( + builder, + field.GetSchema()->Name, + dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool, + builder.CreateStruct(dml::ir::operatorFieldTypes::Bool(field.AsBool())).Union()); + } + else + { + continue; + } + + attributeDescs.push_back(offset); + } +} + +flatbuffers::Offset SerializeDmlTensorDesc( + flatbuffers::FlatBufferBuilder& builder, + const DmlBufferTensorDesc* tensorDesc) +{ + const std::vector *strides = nullptr; + if (tensorDesc->strides.has_value()) + { + strides = &tensorDesc->strides.value(); + } + + flatbuffers::Offset offset = dml::ir::CreateDmlBufferTensorDescDirect( + builder, + ApiTraits::StringifyHelpers::ToString(tensorDesc->dataType), + &tensorDesc->sizes, + strides, + tensorDesc->totalTensorSizeInBytes); + return offset; +} + +flatbuffers::Offset SerializeOperatorNodeDesc( + flatbuffers::FlatBufferBuilder& builder, + const AbstractOperatorDesc& operatorDesc) +{ + const DML_OPERATOR_SCHEMA* operatorSchema = operatorDesc.schema; + + std::vector> inputTensorDescs; + std::vector> outputTensorDescs; + + for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetInputTensors()) + { + if (tensorDesc == nullptr) + { + continue; + } + flatbuffers::Offset serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc); + inputTensorDescs.push_back(serializedDmlTensorDesc); + } + + for (const DmlBufferTensorDesc* tensorDesc : operatorDesc.GetOutputTensors()) + { + if (tensorDesc == nullptr) + { + continue; + } + flatbuffers::Offset serializedDmlTensorDesc = SerializeDmlTensorDesc(builder, tensorDesc); + outputTensorDescs.push_back(serializedDmlTensorDesc); + } + + std::vector> attributeDescs; + SerializeAttributeDescs(builder, operatorDesc, attributeDescs); + + flatbuffers::Offset offset = dml::ir::CreateOperatorNodeDesc( + builder, + builder.CreateString(operatorSchema->OperatorName), + builder.CreateVector(inputTensorDescs), + builder.CreateVector(outputTensorDescs), + builder.CreateVector(attributeDescs)); + return offset.Union(); +} + +flatbuffers::Offset SerializeConstantNodeDesc( + flatbuffers::FlatBufferBuilder& builder, + uint32_t nodeIndex, + const DmlSerializedGraphNodeConstantVariant& constantNodeDesc) +{ + flatbuffers::Offset offset; + + if (std::holds_alternative(constantNodeDesc)) + { + auto& constantName = std::get(constantNodeDesc); + if (constantName.name.empty()) + { + throw std::invalid_argument("Graph constant node at index:" + std::to_string(nodeIndex) + + " doesn't have the constant data name."); + } + + flatbuffers::Offset constantNameOffset = dml::ir::CreateConstantName( + builder, + builder.CreateString(constantName.name)); + + offset = dml::ir::CreateConstantNodeDesc( + builder, + dml::ir::ConstantNodeDescDetail_ConstantName, + constantNameOffset.Union()); + } + else + { + auto& constantData = std::get(constantNodeDesc); + std::vector rawBytes; + std::transform(constantData.data, constantData.data + constantData.dataSize, + std::back_inserter(rawBytes), [](std::byte b) {return static_cast(b); }); + flatbuffers::Offset constantDataOffset = dml::ir::CreateConstantRawDataDirect( + builder, + &rawBytes); + + offset = dml::ir::CreateConstantNodeDesc( + builder, + dml::ir::ConstantNodeDescDetail_ConstantRawData, + constantDataOffset.Union()); + } + + return offset.Union(); +} + +flatbuffers::Offset SerializeNode( + flatbuffers::FlatBufferBuilder& builder, + const uint32_t nodeIndex, + const DmlSerializedGraphNode& graphNode, + const std::vector>& nodeInputNames, + const std::vector>& nodeOutputNames) +{ + if (graphNode.Name.empty()) + { + throw std::invalid_argument("Graph node at index:" + std::to_string(nodeIndex) + + " does not have any name."); + } + + flatbuffers::Offset offset; + if (std::holds_alternative(graphNode.Desc)) + { + auto& operatorNode = std::get(graphNode.Desc); + offset = dml::ir::CreateDmlGraphNode( + builder, + dml::ir::NodeDesc_OperatorNodeDesc, + SerializeOperatorNodeDesc(builder, operatorNode), + builder.CreateString(graphNode.Name), + builder.CreateVector(nodeInputNames), + builder.CreateVector(nodeOutputNames)); + } + else + { + auto& constantNodeVariant = std::get(graphNode.Desc); + offset = dml::ir::CreateDmlGraphNode( + builder, + dml::ir::NodeDesc_ConstantNodeDesc, + SerializeConstantNodeDesc(builder, nodeIndex, constantNodeVariant), + builder.CreateString(graphNode.Name), + builder.CreateVector(nodeInputNames), + builder.CreateVector(nodeOutputNames)); + } + return offset; +} + +/* +* validates input/output edges and throws exception if an edge +* does not have a name or if an edge has more than 1 names. +*/ +template +std::unordered_map> ConvertToEdgeIndexToNameMap( + const std::vector& edges, + flatbuffers::FlatBufferBuilder& builder) +{ + std::unordered_map> edgeIndexToNameMap; + for (auto& edge : edges) + { + uint32_t index; + if constexpr (std::is_same_v) + { + index = edge.GraphInputIndex; + } + else if constexpr (std::is_same_v) + { + index = edge.GraphOutputIndex; + } + + if (edge.Name.empty()) + { + throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " does not have name."); + } + + if (edgeIndexToNameMap.find(index) != edgeIndexToNameMap.end()) + { + flatbuffers::String* edgeName = ReadAs( + builder.GetCurrentBufferPointer(), + builder.GetSize() - edgeIndexToNameMap[index].o); + if (edge.Name != edgeName->str()) + { + throw std::invalid_argument("Graph input or output edge at index " + std::to_string(index) + " has more than 1 names."); + } + } + + edgeIndexToNameMap[index] = builder.CreateString(edge.Name); + } + return edgeIndexToNameMap; // NRVO will automatically move it. no need to use std::move +} + +void PopulateNonConstantNodeInputOutputCount( + const std::vector& nodes, + /*out*/ std::vector& nodeInputCounts, + /*out*/ std::vector& nodeOutputCounts) +{ + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(nodes.size()); nodeIndex++) + { + auto& node = nodes[nodeIndex]; + if (std::holds_alternative(node.Desc)) + { + auto& operatorNode = std::get(node.Desc); + nodeInputCounts[nodeIndex] = std::max( + nodeInputCounts[nodeIndex], + static_cast(operatorNode.GetInputTensors().size())); + + nodeOutputCounts[nodeIndex] = std::max( + nodeOutputCounts[nodeIndex], + static_cast(operatorNode.GetOutputTensors().size())); + } + } +} + +void PopulateConstantNodeInputOutputCount( + const std::vector& edges, + /*out*/std::vector& maxInputIndexForNodes, + /*out*/std::vector& maxOutputIndexForNodes) +{ + for (auto& edge : edges) + { + maxInputIndexForNodes[edge.ToNodeIndex] = std::max(maxInputIndexForNodes[edge.ToNodeIndex], edge.ToNodeInputIndex + 1); + maxOutputIndexForNodes[edge.FromNodeIndex] = std::max(maxOutputIndexForNodes[edge.FromNodeIndex], edge.FromNodeOutputIndex + 1); + } +} + +/* +* validates intermediate edge and throws exception if an edge +* does not have a name or if an edge has more than 1 names. +*/ +void PopulateNodeInputOutputNames( + flatbuffers::FlatBufferBuilder& builder, + const DmlSerializedGraphDesc& graphDesc, + const std::unordered_map>& graphInputIndexToNameMap, + const std::unordered_map>& graphOutputIndexToNameMap, + /*out*/std::vector>>& nodeToInputNames, + /*out*/std::vector>>& nodeToOutputNames) +{ + for (auto& edge : graphDesc.InputEdges) + { + nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = graphInputIndexToNameMap.at(edge.GraphInputIndex); + } + + for (auto& edge : graphDesc.OutputEdges) + { + nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = graphOutputIndexToNameMap.at(edge.GraphOutputIndex); + } + + std::unordered_map>> intermediateEdgeNames; + for (uint32_t edgeIndex = 0; edgeIndex < static_cast(graphDesc.IntermediateEdges.size()); edgeIndex++) + { + auto& edge = graphDesc.IntermediateEdges[edgeIndex]; + if (edge.Name.empty()) + { + throw std::invalid_argument( + "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) + + " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " doesn't have name."); + } + + if (intermediateEdgeNames.find(edge.FromNodeIndex) != intermediateEdgeNames.end() && + intermediateEdgeNames[edge.FromNodeIndex].find(edge.FromNodeOutputIndex) != intermediateEdgeNames[edge.FromNodeIndex].end()) + { + flatbuffers::Offset edgeNameOffset = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + flatbuffers::String* edgeName = ReadAs( + builder.GetCurrentBufferPointer(), + builder.GetSize() - edgeNameOffset.o); + + if (edgeName->str() != edge.Name) + { + throw std::invalid_argument( + "Graph intermediate edge from nodeIndex:" + std::to_string(edge.FromNodeIndex) + + " & nodeOutputIndex:" + std::to_string(edge.FromNodeOutputIndex) + " has more than 1 names."); + } + } + else + { + intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = builder.CreateString(edge.Name.c_str()); + } + nodeToInputNames[edge.ToNodeIndex][edge.ToNodeInputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + nodeToOutputNames[edge.FromNodeIndex][edge.FromNodeOutputIndex] = intermediateEdgeNames[edge.FromNodeIndex][edge.FromNodeOutputIndex]; + } +} + + +/* +* - If an edge is connected to multiple nodes, then there will be multiple instances +* of input or intermediate edges, all with the same name. +* - The input will be validated incrementally throughout the execution +* of the method. +* - Handling of empty optional input/output/attibute for non-constant node: +* input/output +* - and will have an null entry +* but the actual OperatorNodeDesc variant's +* and will not have any entry. +* attribute +* - will have null entry +*/ +flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc) +{ + + flatbuffers::FlatBufferBuilder builder(1024); + if (graphDesc.Nodes.empty()) + { + return builder.Release(); + } + + // create input/output edge index to name map + std::unordered_map> graphInputIndexToNameMap = + ConvertToEdgeIndexToNameMap(graphDesc.InputEdges, builder); + std::unordered_map> graphOutputIndexToNameMap = + ConvertToEdgeIndexToNameMap(graphDesc.OutputEdges, builder); + + /* + * - Calculate number of input/output for each operator to allocate + * appropriate amount of memory for each node to store input/output names. + * - Non-constant node's input/output count can be determined by the + * AbstractOperatorDesc. + * - Constant node will only have outgoing edges and those outgoing edges + * will be intermediate edges. + */ + std::vector nodeInputCounts(graphDesc.Nodes.size(), 0); + std::vector nodeOutputCounts(graphDesc.Nodes.size(), 0); + PopulateNonConstantNodeInputOutputCount(graphDesc.Nodes, nodeInputCounts, nodeOutputCounts); + PopulateConstantNodeInputOutputCount(graphDesc.IntermediateEdges, nodeInputCounts, nodeOutputCounts); + + // populate node input/output names. + std::vector>> nodeToInputNames(graphDesc.Nodes.size()); + std::vector>> nodeToOutputNames(graphDesc.Nodes.size()); + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(graphDesc.Nodes.size()); nodeIndex++) + { + nodeToInputNames[nodeIndex].assign(nodeInputCounts[nodeIndex], builder.CreateString(nullptr, 0)); + nodeToOutputNames[nodeIndex].assign(nodeOutputCounts[nodeIndex], builder.CreateString(nullptr, 0)); + } + PopulateNodeInputOutputNames(builder, graphDesc, graphInputIndexToNameMap, graphOutputIndexToNameMap, nodeToInputNames, nodeToOutputNames); + + // Create flatbuffer node objects + std::vector> nodes(graphDesc.Nodes.size()); + for (uint32_t nodeIndex = 0; nodeIndex < static_cast(graphDesc.Nodes.size()); nodeIndex++) + { + nodes[nodeIndex] = SerializeNode( + builder, + nodeIndex, + graphDesc.Nodes[nodeIndex], + nodeToInputNames[nodeIndex], + nodeToOutputNames[nodeIndex]); + } + + // Convert to std::vector to create the object. + std::vector> graphInputNames(graphDesc.InputCount, builder.CreateString(nullptr, 0)); + std::vector> graphOutputNames(graphDesc.OutputCount, builder.CreateString(nullptr, 0)); + for (const auto& [key, value] : graphInputIndexToNameMap) + { + graphInputNames[key] = value; + } + for (const auto& [key, value] : graphOutputIndexToNameMap) + { + graphOutputNames[key] = value; + } + + flatbuffers::Offset dmlGraphDescOffset = dml::ir::CreateDmlGraphDescDirect( + builder, + &nodes, + &graphInputNames, + &graphOutputNames); + builder.Finish(dmlGraphDescOffset); + return builder.Release(); +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp index 5c7b7bff1e370..0f0d445a95bae 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -180,32 +180,50 @@ namespace Dml // Convert partitionONNXGraph into DML EP GraphDesc ComPtr device; ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + // This map will be used to transfer the initializer to D3D12 system heap memory. + // 'serializedDmlGraphDesc' will have constant input as intermediate edges, that's why + // we need a mapping between intermediateEdgeIndex and indexedSubGraph's (a given partition) + // input arg index. + // For ex: Let's say intermediate edge index = idx, then + // indexedSubGraphInputArgIdx = constantEdgeIdxToSubgraphInputArgIdxMap[idx]; + // corresponding constant tensor = initializerNameToInitializerMap[indexedSubGraph.GetMetaDef()->inputs[indexedSubGraphInputArgIdx]] + // We are using intermediate edge index as a key because same constant tensor can be used by + // multiple nodes. + std::unordered_map serializedGraphInputIndexToSubgraphInputIndex; + std::unordered_map serializedGraphLargeConstantNameToSubgraphInputIndex; + std::vector> smallConstantData; GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), m_isInitializerTransferable, m_partitionNodePropsMap, - device.Get(), providerImpl, m_modelPath, m_subgraphNodePointers, m_subgraphInputs, - m_subgraphOutputs); + m_subgraphOutputs, + serializedGraphInputIndexToSubgraphInputIndex, + serializedGraphLargeConstantNameToSubgraphInputIndex, + smallConstantData); m_outputShapes = graphDesc.outputShapes; // Walk through each graph edge and mark used inputs m_inputsUsed.resize(fusedNodeInputCount, false); - for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) - { - m_inputsUsed[edge.GraphInputIndex] = true; + for (auto it = serializedGraphInputIndexToSubgraphInputIndex.begin(); it != serializedGraphInputIndexToSubgraphInputIndex.end(); it++) { + m_inputsUsed[it->second] = true; + } + for (auto it = serializedGraphLargeConstantNameToSubgraphInputIndex.begin(); it != serializedGraphLargeConstantNameToSubgraphInputIndex.end(); it++) { + m_inputsUsed[it->second] = true; } // Compile the operator m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( graphDesc, *m_indexedSubGraph, - providerImpl); + providerImpl, + &serializedGraphInputIndexToSubgraphInputIndex, + &serializedGraphLargeConstantNameToSubgraphInputIndex); // Queue references to objects which must be kept alive until resulting GPU work completes m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h index a5415ba85f3d3..e1e7eacfbd85d 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/ApiTraits.h @@ -24,8 +24,8 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 161; - static constexpr size_t ActivationFunctionCount = 24; + static constexpr auto ValueCount = 168; + static constexpr size_t ActivationFunctionCount = 26; }; template <> @@ -62,7 +62,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 4; + static constexpr auto ValueCount = 5; }; template <> @@ -86,7 +86,7 @@ struct EnumTraits template <> struct EnumTraits { - static constexpr auto ValueCount = 8; + static constexpr auto ValueCount = 13; }; template <> @@ -119,6 +119,12 @@ struct EnumTraits static constexpr auto ValueCount = 1; }; +template <> +struct EnumTraits +{ + static constexpr auto ValueCount = 5; +}; + template constexpr auto EnumValueCount = EnumTraits::ValueCount; @@ -495,12 +501,6 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ROI_POOLING; }; -template <> -struct OperatorDescTraits -{ - static constexpr DML_OPERATOR_TYPE Type = (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; -}; - template <> struct OperatorDescTraits { @@ -1029,6 +1029,24 @@ struct OperatorDescTraits static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_DIAGONAL_MATRIX1; }; +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT; +}; + template <> struct OperatorDescTraits { @@ -1174,9 +1192,15 @@ struct OperatorDescTraits }; template <> -struct OperatorDescTraits +struct OperatorDescTraits { - static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_MULTIHEAD_ATTENTION; + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_SWISH; +}; + +template <> +struct OperatorDescTraits +{ + static constexpr DML_OPERATOR_TYPE Type = DML_OPERATOR_ACTIVATION_HARD_SWISH; }; template @@ -1502,12 +1526,6 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ROI_POOLING> using DescType = DML_ROI_POOLING_OPERATOR_DESC; }; -template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> -{ - using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; -}; - template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_SLICE> { @@ -2036,6 +2054,24 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_DIAGONAL_MATRIX1> using DescType = DML_DIAGONAL_MATRIX1_OPERATOR_DESC; }; +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> +{ + using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING> +{ + using DescType = DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT> +{ + using DescType = DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC; +}; + template <> struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_ELU> { @@ -2181,14 +2217,20 @@ struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_GELU> }; template <> -struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_MULTIHEAD_ATTENTION> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_SWISH> { - using DescType = DML_MULTIHEAD_ATTENTION_OPERATOR_DESC; + using DescType = DML_ACTIVATION_SWISH_OPERATOR_DESC; +}; + +template <> +struct OperatorTypeTraits<(DML_OPERATOR_TYPE)DML_OPERATOR_ACTIVATION_HARD_SWISH> +{ + using DescType = DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC; }; // Calls a visitor functor, supplying an empty operator desc corresponding to the given DML_OPERATOR_TYPE as // the first argument. -// +// // For example: // Visit(DML_OPERATOR_ELEMENT_WISE_IDENTITY, [](auto tag) { // using T = decltype(tag); // T is one of the DML_*_OPERATOR_DESC structs @@ -2485,6 +2527,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_DIAGONAL_MATRIX1_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_MULTIHEAD_ATTENTION: return std::invoke(std::forward(visitor), DML_MULTIHEAD_ATTENTION_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return std::invoke(std::forward(visitor), DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_ELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_ELU_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_CELU: @@ -2533,13 +2579,10 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args return std::invoke(std::forward(visitor), DML_ACTIVATION_SHRINK_OPERATOR_DESC{}, std::forward(args)...); case DML_OPERATOR_ACTIVATION_GELU: return std::invoke(std::forward(visitor), DML_ACTIVATION_GELU_OPERATOR_DESC{}, std::forward(args)...); - -#pragma warning(push) -#pragma warning(disable: 4063) - case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: - return std::invoke(std::forward(visitor), DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC{}, std::forward(args)...); -#pragma warning(pop) - + case DML_OPERATOR_ACTIVATION_SWISH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_SWISH_OPERATOR_DESC{}, std::forward(args)...); + case DML_OPERATOR_ACTIVATION_HARD_SWISH: + return std::invoke(std::forward(visitor), DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC{}, std::forward(args)...); default: ORT_THROW_HR(E_INVALIDARG); return std::invoke(std::forward(visitor), DML_ACTIVATION_RELU_OPERATOR_DESC{}, std::forward(args)...); @@ -2547,7 +2590,55 @@ auto OperatorTypeVisitor(DML_OPERATOR_TYPE type, Visitor&& visitor, Ts&&... args } #pragma warning(pop) +namespace StringifyHelpers +{ +template +inline gsl::czstring ToString(T value) +{ +#ifndef WAI_BUILD_LINUX + // Clang will instantiate this template even if it isn't used, + // so this static_assert will always fire and break the build. + static_assert(false, "Not implemented for this type"); +#endif +} + +template <> +inline gsl::czstring ToString(DML_TENSOR_DATA_TYPE value) +{ + switch (value) + { + case DML_TENSOR_DATA_TYPE_UNKNOWN: return "DML_TENSOR_DATA_TYPE_UNKNOWN"; + case DML_TENSOR_DATA_TYPE_FLOAT32: return "DML_TENSOR_DATA_TYPE_FLOAT32"; + case DML_TENSOR_DATA_TYPE_FLOAT16: return "DML_TENSOR_DATA_TYPE_FLOAT16"; + case DML_TENSOR_DATA_TYPE_UINT32: return "DML_TENSOR_DATA_TYPE_UINT32"; + case DML_TENSOR_DATA_TYPE_UINT16: return "DML_TENSOR_DATA_TYPE_UINT16"; + case DML_TENSOR_DATA_TYPE_UINT8: return "DML_TENSOR_DATA_TYPE_UINT8"; + case DML_TENSOR_DATA_TYPE_INT32: return "DML_TENSOR_DATA_TYPE_INT32"; + case DML_TENSOR_DATA_TYPE_INT16: return "DML_TENSOR_DATA_TYPE_INT16"; + case DML_TENSOR_DATA_TYPE_INT8: return "DML_TENSOR_DATA_TYPE_INT8"; + case DML_TENSOR_DATA_TYPE_FLOAT64: return "DML_TENSOR_DATA_TYPE_FLOAT64"; + case DML_TENSOR_DATA_TYPE_UINT64: return "DML_TENSOR_DATA_TYPE_UINT64"; + case DML_TENSOR_DATA_TYPE_INT64: return "DML_TENSOR_DATA_TYPE_INT64"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_TENSOR_TYPE value) +{ + switch (value) + { + case DML_TENSOR_TYPE_INVALID: return "DML_TENSOR_TYPE_INVALID"; + case DML_TENSOR_TYPE_BUFFER: return "DML_TENSOR_TYPE_BUFFER"; + default: + assert(false); + return ""; + } +} +template <> inline gsl::czstring ToString(DML_OPERATOR_TYPE value) { switch (value) @@ -2561,9 +2652,6 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_ATAN: return "DML_OPERATOR_ELEMENT_WISE_ATAN"; case DML_OPERATOR_ELEMENT_WISE_CEIL: return "DML_OPERATOR_ELEMENT_WISE_CEIL"; case DML_OPERATOR_ELEMENT_WISE_CLIP: return "DML_OPERATOR_ELEMENT_WISE_CLIP"; - case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1"; - case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD"; - case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1"; case DML_OPERATOR_ELEMENT_WISE_COS: return "DML_OPERATOR_ELEMENT_WISE_COS"; case DML_OPERATOR_ELEMENT_WISE_DIVIDE: return "DML_OPERATOR_ELEMENT_WISE_DIVIDE"; case DML_OPERATOR_ELEMENT_WISE_EXP: return "DML_OPERATOR_ELEMENT_WISE_EXP"; @@ -2587,24 +2675,41 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_RECIP: return "DML_OPERATOR_ELEMENT_WISE_RECIP"; case DML_OPERATOR_ELEMENT_WISE_SIN: return "DML_OPERATOR_ELEMENT_WISE_SIN"; case DML_OPERATOR_ELEMENT_WISE_SQRT: return "DML_OPERATOR_ELEMENT_WISE_SQRT"; - case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE"; - case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX"; case DML_OPERATOR_ELEMENT_WISE_SUBTRACT: return "DML_OPERATOR_ELEMENT_WISE_SUBTRACT"; case DML_OPERATOR_ELEMENT_WISE_TAN: return "DML_OPERATOR_ELEMENT_WISE_TAN"; case DML_OPERATOR_ELEMENT_WISE_THRESHOLD: return "DML_OPERATOR_ELEMENT_WISE_THRESHOLD"; case DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZE_LINEAR"; case DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR: return "DML_OPERATOR_ELEMENT_WISE_DEQUANTIZE_LINEAR"; + case DML_OPERATOR_ACTIVATION_ELU: return "DML_OPERATOR_ACTIVATION_ELU"; + case DML_OPERATOR_ACTIVATION_CELU: return "DML_OPERATOR_ACTIVATION_CELU"; + case DML_OPERATOR_ACTIVATION_HARDMAX: return "DML_OPERATOR_ACTIVATION_HARDMAX"; + case DML_OPERATOR_ACTIVATION_HARDMAX1: return "DML_OPERATOR_ACTIVATION_HARDMAX1"; + case DML_OPERATOR_ACTIVATION_HARD_SIGMOID: return "DML_OPERATOR_ACTIVATION_HARD_SIGMOID"; + case DML_OPERATOR_ACTIVATION_IDENTITY: return "DML_OPERATOR_ACTIVATION_IDENTITY"; + case DML_OPERATOR_ACTIVATION_LEAKY_RELU: return "DML_OPERATOR_ACTIVATION_LEAKY_RELU"; + case DML_OPERATOR_ACTIVATION_LINEAR: return "DML_OPERATOR_ACTIVATION_LINEAR"; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX"; + case DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_LOG_SOFTMAX1"; + case DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU: return "DML_OPERATOR_ACTIVATION_PARAMETERIZED_RELU"; + case DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_PARAMETRIC_SOFTPLUS"; + case DML_OPERATOR_ACTIVATION_RELU: return "DML_OPERATOR_ACTIVATION_RELU"; + case DML_OPERATOR_ACTIVATION_SCALED_ELU: return "DML_OPERATOR_ACTIVATION_SCALED_ELU"; + case DML_OPERATOR_ACTIVATION_SCALED_TANH: return "DML_OPERATOR_ACTIVATION_SCALED_TANH"; + case DML_OPERATOR_ACTIVATION_SIGMOID: return "DML_OPERATOR_ACTIVATION_SIGMOID"; + case DML_OPERATOR_ACTIVATION_SOFTMAX: return "DML_OPERATOR_ACTIVATION_SOFTMAX"; + case DML_OPERATOR_ACTIVATION_SOFTMAX1: return "DML_OPERATOR_ACTIVATION_SOFTMAX1"; + case DML_OPERATOR_ACTIVATION_SOFTPLUS: return "DML_OPERATOR_ACTIVATION_SOFTPLUS"; + case DML_OPERATOR_ACTIVATION_SOFTSIGN: return "DML_OPERATOR_ACTIVATION_SOFTSIGN"; + case DML_OPERATOR_ACTIVATION_TANH: return "DML_OPERATOR_ACTIVATION_TANH"; + case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return "DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU"; case DML_OPERATOR_CONVOLUTION: return "DML_OPERATOR_CONVOLUTION"; case DML_OPERATOR_GEMM: return "DML_OPERATOR_GEMM"; case DML_OPERATOR_REDUCE: return "DML_OPERATOR_REDUCE"; - case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; - case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_AVERAGE_POOLING: return "DML_OPERATOR_AVERAGE_POOLING"; case DML_OPERATOR_AVERAGE_POOLING1: return "DML_OPERATOR_AVERAGE_POOLING1"; case DML_OPERATOR_LP_POOLING: return "DML_OPERATOR_LP_POOLING"; case DML_OPERATOR_LP_POOLING1: return "DML_OPERATOR_LP_POOLING1"; case DML_OPERATOR_MAX_POOLING: return "DML_OPERATOR_MAX_POOLING"; - case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_ROI_POOLING: return "DML_OPERATOR_ROI_POOLING"; case DML_OPERATOR_SLICE: return "DML_OPERATOR_SLICE"; case DML_OPERATOR_CAST: return "DML_OPERATOR_CAST"; @@ -2620,18 +2725,15 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_TILE: return "DML_OPERATOR_TILE"; case DML_OPERATOR_TOP_K: return "DML_OPERATOR_TOP_K"; case DML_OPERATOR_BATCH_NORMALIZATION: return "DML_OPERATOR_BATCH_NORMALIZATION"; - case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD"; - case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING"; case DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION: return "DML_OPERATOR_MEAN_VARIANCE_NORMALIZATION"; case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION"; - case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD"; case DML_OPERATOR_LP_NORMALIZATION: return "DML_OPERATOR_LP_NORMALIZATION"; case DML_OPERATOR_RNN: return "DML_OPERATOR_RNN"; case DML_OPERATOR_LSTM: return "DML_OPERATOR_LSTM"; case DML_OPERATOR_GRU: return "DML_OPERATOR_GRU"; case DML_OPERATOR_ELEMENT_WISE_SIGN: return "DML_OPERATOR_ELEMENT_WISE_SIGN"; case DML_OPERATOR_ELEMENT_WISE_IS_NAN: return "DML_OPERATOR_ELEMENT_WISE_IS_NAN"; - case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE"; case DML_OPERATOR_ELEMENT_WISE_ERF: return "DML_OPERATOR_ELEMENT_WISE_ERF"; case DML_OPERATOR_ELEMENT_WISE_SINH: return "DML_OPERATOR_ELEMENT_WISE_SINH"; case DML_OPERATOR_ELEMENT_WISE_COSH: return "DML_OPERATOR_ELEMENT_WISE_COSH"; @@ -2641,6 +2743,8 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_ATANH: return "DML_OPERATOR_ELEMENT_WISE_ATANH"; case DML_OPERATOR_ELEMENT_WISE_IF: return "DML_OPERATOR_ELEMENT_WISE_IF"; case DML_OPERATOR_ELEMENT_WISE_ADD1: return "DML_OPERATOR_ELEMENT_WISE_ADD1"; + case DML_OPERATOR_ACTIVATION_SHRINK: return "DML_OPERATOR_ACTIVATION_SHRINK"; + case DML_OPERATOR_MAX_POOLING1: return "DML_OPERATOR_MAX_POOLING1"; case DML_OPERATOR_MAX_UNPOOLING: return "DML_OPERATOR_MAX_UNPOOLING"; case DML_OPERATOR_DIAGONAL_MATRIX: return "DML_OPERATOR_DIAGONAL_MATRIX"; case DML_OPERATOR_SCATTER: return "DML_OPERATOR_SCATTER"; @@ -2652,10 +2756,9 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_ELEMENT_WISE_IS_INFINITY: return "DML_OPERATOR_ELEMENT_WISE_IS_INFINITY"; case DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_TRUNCATE"; case DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR: return "DML_OPERATOR_ELEMENT_WISE_MODULUS_FLOOR"; - case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; case DML_OPERATOR_FILL_VALUE_SEQUENCE: return "DML_OPERATOR_FILL_VALUE_SEQUENCE"; + case DML_OPERATOR_FILL_VALUE_CONSTANT: return "DML_OPERATOR_FILL_VALUE_CONSTANT"; case DML_OPERATOR_CUMULATIVE_SUMMATION: return "DML_OPERATOR_CUMULATIVE_SUMMATION"; - case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT"; case DML_OPERATOR_REVERSE_SUBSEQUENCES: return "DML_OPERATOR_REVERSE_SUBSEQUENCES"; case DML_OPERATOR_GATHER_ELEMENTS: return "DML_OPERATOR_GATHER_ELEMENTS"; case DML_OPERATOR_GATHER_ND: return "DML_OPERATOR_GATHER_ND"; @@ -2684,20 +2787,278 @@ inline gsl::czstring ToString(DML_OPERATOR_TYPE value) case DML_OPERATOR_RESAMPLE_GRAD: return "DML_OPERATOR_RESAMPLE_GRAD"; case DML_OPERATOR_SLICE_GRAD: return "DML_OPERATOR_SLICE_GRAD"; case DML_OPERATOR_ADAM_OPTIMIZER: return "DML_OPERATOR_ADAM_OPTIMIZER"; + case DML_OPERATOR_ARGMIN: return "DML_OPERATOR_ARGMIN"; + case DML_OPERATOR_ARGMAX: return "DML_OPERATOR_ARGMAX"; case DML_OPERATOR_ROI_ALIGN: return "DML_OPERATOR_ROI_ALIGN"; - case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1"; case DML_OPERATOR_GATHER_ND1: return "DML_OPERATOR_GATHER_ND1"; - case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case DML_OPERATOR_ELEMENT_WISE_ATAN_YX: return "DML_OPERATOR_ELEMENT_WISE_ATAN_YX"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD"; + case DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE: return "DML_OPERATOR_ELEMENT_WISE_DIFFERENCE_SQUARE"; + case DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD: return "DML_OPERATOR_LOCAL_RESPONSE_NORMALIZATION_GRAD"; + case DML_OPERATOR_CUMULATIVE_PRODUCT: return "DML_OPERATOR_CUMULATIVE_PRODUCT"; + case DML_OPERATOR_BATCH_NORMALIZATION_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_GRAD"; + case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING_GRAD"; case DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD: return "DML_OPERATOR_ELEMENT_WISE_QUANTIZED_LINEAR_ADD"; - case DML_OPERATOR_ROI_ALIGN_GRAD: return "DML_OPERATOR_ROI_ALIGN_GRAD"; - case DML_OPERATOR_BATCH_NORMALIZATION_TRAINING: return "DML_OPERATOR_BATCH_NORMALIZATION_TRAINING"; + case DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR: return "DML_OPERATOR_DYNAMIC_QUANTIZE_LINEAR"; + case DML_OPERATOR_ROI_ALIGN1: return "DML_OPERATOR_ROI_ALIGN1"; + case DML_OPERATOR_ELEMENT_WISE_CLIP1: return "DML_OPERATOR_ELEMENT_WISE_CLIP1"; + case DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1: return "DML_OPERATOR_ELEMENT_WISE_CLIP_GRAD1"; + case DML_OPERATOR_ELEMENT_WISE_NEGATE: return "DML_OPERATOR_ELEMENT_WISE_NEGATE"; + case DML_OPERATOR_ACTIVATION_GELU: return "DML_OPERATOR_ACTIVATION_GELU"; + case DML_OPERATOR_ACTIVATION_SWISH: return "DML_OPERATOR_ACTIVATION_SWISH"; + case DML_OPERATOR_ACTIVATION_HARD_SWISH: return "DML_OPERATOR_ACTIVATION_HARD_SWISH"; case DML_OPERATOR_RESAMPLE2: return "DML_OPERATOR_RESAMPLE2"; case DML_OPERATOR_RESAMPLE_GRAD1: return "DML_OPERATOR_RESAMPLE_GRAD1"; case DML_OPERATOR_DIAGONAL_MATRIX1: return "DML_OPERATOR_DIAGONAL_MATRIX1"; case DML_OPERATOR_MULTIHEAD_ATTENTION: return "DML_OPERATOR_MULTIHEAD_ATTENTION"; + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING"; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_BINDING_TYPE value) +{ + switch (value) + { + case DML_BINDING_TYPE_NONE: return "DML_BINDING_TYPE_NONE"; + case DML_BINDING_TYPE_BUFFER: return "DML_BINDING_TYPE_BUFFER"; + case DML_BINDING_TYPE_BUFFER_ARRAY: return "DML_BINDING_TYPE_BUFFER_ARRAY"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_REDUCE_FUNCTION value) +{ + switch (value) + { + case DML_REDUCE_FUNCTION_ARGMAX: return "DML_REDUCE_FUNCTION_ARGMAX"; + case DML_REDUCE_FUNCTION_ARGMIN: return "DML_REDUCE_FUNCTION_ARGMIN"; + case DML_REDUCE_FUNCTION_AVERAGE: return "DML_REDUCE_FUNCTION_AVERAGE"; + case DML_REDUCE_FUNCTION_L1: return "DML_REDUCE_FUNCTION_L1"; + case DML_REDUCE_FUNCTION_L2: return "DML_REDUCE_FUNCTION_L2"; + case DML_REDUCE_FUNCTION_LOG_SUM: return "DML_REDUCE_FUNCTION_LOG_SUM"; + case DML_REDUCE_FUNCTION_LOG_SUM_EXP: return "DML_REDUCE_FUNCTION_LOG_SUM_EXP"; + case DML_REDUCE_FUNCTION_MAX: return "DML_REDUCE_FUNCTION_MAX"; + case DML_REDUCE_FUNCTION_MIN: return "DML_REDUCE_FUNCTION_MIN"; + case DML_REDUCE_FUNCTION_MULTIPLY: return "DML_REDUCE_FUNCTION_MULTIPLY"; + case DML_REDUCE_FUNCTION_SUM: return "DML_REDUCE_FUNCTION_SUM"; + case DML_REDUCE_FUNCTION_SUM_SQUARE: return "DML_REDUCE_FUNCTION_SUM_SQUARE"; default: assert(false); return ""; } } + +template <> +inline gsl::czstring ToString(DML_MATRIX_TRANSFORM value) +{ + switch (value) + { + case DML_MATRIX_TRANSFORM_NONE: return "DML_MATRIX_TRANSFORM_NONE"; + case DML_MATRIX_TRANSFORM_TRANSPOSE: return "DML_MATRIX_TRANSFORM_TRANSPOSE"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_CONVOLUTION_MODE value) +{ + switch (value) + { + case DML_CONVOLUTION_MODE_CONVOLUTION: return "DML_CONVOLUTION_MODE_CONVOLUTION"; + case DML_CONVOLUTION_MODE_CROSS_CORRELATION: return "DML_CONVOLUTION_MODE_CROSS_CORRELATION"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_CONVOLUTION_DIRECTION value) +{ + switch (value) + { + case DML_CONVOLUTION_DIRECTION_FORWARD: return "DML_CONVOLUTION_DIRECTION_FORWARD"; + case DML_CONVOLUTION_DIRECTION_BACKWARD: return "DML_CONVOLUTION_DIRECTION_BACKWARD"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_PADDING_MODE value) +{ + switch (value) + { + case DML_PADDING_MODE_CONSTANT: return "DML_PADDING_MODE_CONSTANT"; + case DML_PADDING_MODE_EDGE: return "DML_PADDING_MODE_EDGE"; + case DML_PADDING_MODE_REFLECTION: return "DML_PADDING_MODE_REFLECTION"; + case DML_PADDING_MODE_SYMMETRIC: return "DML_PADDING_MODE_SYMMETRIC"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_INTERPOLATION_MODE value) +{ + switch (value) + { + case DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR: return "DML_INTERPOLATION_MODE_NEAREST_NEIGHBOR"; + case DML_INTERPOLATION_MODE_LINEAR: return "DML_INTERPOLATION_MODE_LINEAR"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_RECURRENT_NETWORK_DIRECTION value) +{ + switch (value) + { + case DML_RECURRENT_NETWORK_DIRECTION_FORWARD: return "DML_RECURRENT_NETWORK_DIRECTION_FORWARD"; + case DML_RECURRENT_NETWORK_DIRECTION_BACKWARD: return "DML_RECURRENT_NETWORK_DIRECTION_BACKWARD"; + case DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL: return "DML_RECURRENT_NETWORK_DIRECTION_BIDIRECTIONAL"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_FEATURE value) +{ + switch (value) + { + case DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT: return "DML_FEATURE_TENSOR_DATA_TYPE_SUPPORT"; + case DML_FEATURE_FEATURE_LEVELS: return "DML_FEATURE_FEATURE_LEVELS"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_FEATURE_LEVEL value) +{ + switch (value) + { + case DML_FEATURE_LEVEL_1_0: return "DML_FEATURE_LEVEL_1_0"; + case DML_FEATURE_LEVEL_2_0: return "DML_FEATURE_LEVEL_2_0"; + case DML_FEATURE_LEVEL_2_1: return "DML_FEATURE_LEVEL_2_1"; + case DML_FEATURE_LEVEL_3_0: return "DML_FEATURE_LEVEL_3_0"; + case DML_FEATURE_LEVEL_3_1: return "DML_FEATURE_LEVEL_3_1"; + case DML_FEATURE_LEVEL_4_0: return "DML_FEATURE_LEVEL_4_0"; + case DML_FEATURE_LEVEL_4_1: return "DML_FEATURE_LEVEL_4_1"; + case DML_FEATURE_LEVEL_5_0: return "DML_FEATURE_LEVEL_5_0"; + case DML_FEATURE_LEVEL_5_1: return "DML_FEATURE_LEVEL_5_1"; + case DML_FEATURE_LEVEL_5_2: return "DML_FEATURE_LEVEL_5_2"; + case DML_FEATURE_LEVEL_6_0: return "DML_FEATURE_LEVEL_6_0"; + case DML_FEATURE_LEVEL_6_1: return "DML_FEATURE_LEVEL_6_1"; + case DML_FEATURE_LEVEL_6_2: return "DML_FEATURE_LEVEL_6_2"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_IS_INFINITY_MODE value) +{ + switch (value) + { + case DML_IS_INFINITY_MODE_EITHER: return "DML_IS_INFINITY_MODE_EITHER"; + case DML_IS_INFINITY_MODE_POSITIVE: return "DML_IS_INFINITY_MODE_POSITIVE"; + case DML_IS_INFINITY_MODE_NEGATIVE: return "DML_IS_INFINITY_MODE_NEGATIVE"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_DEPTH_SPACE_ORDER value) +{ + switch (value) + { + case DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW: return "DML_DEPTH_SPACE_ORDER_DEPTH_COLUMN_ROW"; + case DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH: return "DML_DEPTH_SPACE_ORDER_COLUMN_ROW_DEPTH"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_AXIS_DIRECTION value) +{ + switch (value) + { + case DML_AXIS_DIRECTION_INCREASING: return "DML_AXIS_DIRECTION_INCREASING"; + case DML_AXIS_DIRECTION_DECREASING: return "DML_AXIS_DIRECTION_DECREASING"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_ROUNDING_MODE value) +{ + switch (value) + { + case DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN: return "DML_ROUNDING_MODE_HALVES_TO_NEAREST_EVEN"; + case DML_ROUNDING_MODE_TOWARD_ZERO: return "DML_ROUNDING_MODE_TOWARD_ZERO"; + case DML_ROUNDING_MODE_TOWARD_INFINITY: return "DML_ROUNDING_MODE_TOWARD_INFINITY"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_RANDOM_GENERATOR_TYPE value) +{ + switch (value) + { + case DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10: return "DML_RANDOM_GENERATOR_TYPE_PHILOX_4X32_10"; + default: + assert(false); + return ""; + } +} + +template <> +inline gsl::czstring ToString(DML_MULTIHEAD_ATTENTION_MASK_TYPE value) +{ + switch (value) + { + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_NONE"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_LENGTH"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_SEQUENCE_END_START"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_KEY_QUERY_SEQUENCE_LENGTH_START_END"; + case DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN: return "DML_MULTIHEAD_ATTENTION_MASK_TYPE_BOOLEAN"; + default: + assert(false); + return ""; + } +} + + +template +T FromString(std::string_view value); + +} } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h index 2a82c12872a72..5fe6603c2a0bf 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DirectMLSchema.h @@ -618,7 +618,7 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_THRESHOLD_OPERATOR_SCHEMA { constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, }; @@ -633,7 +633,7 @@ constexpr DML_OPERATOR_SCHEMA DML_ELEMENT_WISE_QUANTIZE_LINEAR_OPERATOR_SCHEMA { constexpr DML_SCHEMA_FIELD DML_ELEMENT_WISE_DEQUANTIZE_LINEAR_OPERATOR_SCHEMA_FIELDS[4] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ZeroPointTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, }; @@ -869,31 +869,6 @@ constexpr DML_OPERATOR_SCHEMA DML_ROI_POOLING_OPERATOR_SCHEMA { DML_ROI_POOLING_OPERATOR_SCHEMA_FIELDS, }; - -constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, -}; - -constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { - "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", - static_cast(DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING), - DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, - 13, - DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, -}; - constexpr DML_SCHEMA_FIELD DML_SLICE_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -1146,7 +1121,7 @@ constexpr DML_SCHEMA_FIELD DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputGradientTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleGradientTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputBiasGradientTensor", false }, - DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Epsilon", false }, }; constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_GRAD_OPERATOR_SCHEMA { @@ -2312,7 +2287,7 @@ constexpr DML_OPERATOR_SCHEMA DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA { DML_BATCH_NORMALIZATION_TRAINING_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{ +constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "InterpolationMode", false }, @@ -2323,7 +2298,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE2_OPERATOR_SCHEMA_FIELDS[8]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE2_OPERATOR_SCHEMA { "DML_OPERATOR_RESAMPLE2", DML_OPERATOR_RESAMPLE2, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2342,7 +2317,7 @@ constexpr DML_SCHEMA_FIELD DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS[8]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY, "OutputPixelOffsets", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA { "DML_OPERATOR_RESAMPLE_GRAD1", DML_OPERATOR_RESAMPLE_GRAD1, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2350,7 +2325,7 @@ constexpr DML_OPERATOR_SCHEMA DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA{ DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA_FIELDS, }; -constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{ +constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", true }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "ValueDataType", false }, @@ -2359,7 +2334,7 @@ constexpr DML_SCHEMA_FIELD DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA_FIELDS[6]{ DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_INT, "DiagonalFillEnd", false }, }; -constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA{ +constexpr DML_OPERATOR_SCHEMA DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA { "DML_OPERATOR_DIAGONAL_MATRIX1", DML_OPERATOR_DIAGONAL_MATRIX1, DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, @@ -2396,6 +2371,48 @@ constexpr DML_OPERATOR_SCHEMA DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA { DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS[13] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "DimensionCount", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Strides", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "WindowSize", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "StartPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "EndPadding", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT_ARRAY, "Dilations", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_UINT, "IncludePadding", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA { + "DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING", + DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 13, + DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS[8] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "ATensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "AZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BScaleTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BZeroPointTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "BiasTensor", true }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA { + "DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT", + DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 8, + DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA_FIELDS, +}; constexpr DML_SCHEMA_FIELD DML_ACTIVATION_ELU_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, @@ -2732,6 +2749,35 @@ constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_GELU_OPERATOR_SCHEMA { DML_ACTIVATION_GELU_OPERATOR_SCHEMA_FIELDS, }; +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS[3] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "SigmoidInputScale", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_SWISH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_SWISH", + DML_OPERATOR_ACTIVATION_SWISH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 3, + DML_ACTIVATION_SWISH_OPERATOR_SCHEMA_FIELDS, +}; + +constexpr DML_SCHEMA_FIELD DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS[4] { + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_OUTPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "OutputTensor", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Alpha", false }, + DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_ATTRIBUTE, DML_SCHEMA_FIELD_TYPE_FLOAT, "Beta", false }, +}; + +constexpr DML_OPERATOR_SCHEMA DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA { + "DML_OPERATOR_ACTIVATION_HARD_SWISH", + DML_OPERATOR_ACTIVATION_HARD_SWISH, + DML_SCHEMA_OPERATOR_SUPPORT_FLAG_NONE, + 4, + DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA_FIELDS, +}; + constexpr DML_SCHEMA_FIELD DML_RNN_ZERO_OPERATOR_SCHEMA_FIELDS[3] { DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "InputTensor", false }, DML_SCHEMA_FIELD { DML_SCHEMA_FIELD_KIND_INPUT_TENSOR, DML_SCHEMA_FIELD_TYPE_TENSOR_DESC, "SequenceLengthsTensor", false }, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h new file mode 100644 index 0000000000000..72059b9a3f911 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDesc_generated.h @@ -0,0 +1,788 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ +#define FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ + +#include "flatbuffers/flatbuffers.h" + +#include "OperatorFieldTypes_generated.h" + +namespace dml { +namespace ir { + +struct ConstantRawData; +struct ConstantRawDataBuilder; + +struct ConstantName; +struct ConstantNameBuilder; + +struct ConstantNodeDesc; +struct ConstantNodeDescBuilder; + +struct DmlBufferTensorDesc; +struct DmlBufferTensorDescBuilder; + +struct OperatorNodeDesc; +struct OperatorNodeDescBuilder; + +struct DmlGraphNode; +struct DmlGraphNodeBuilder; + +struct DmlGraphDesc; +struct DmlGraphDescBuilder; + +enum ConstantNodeDescDetail { + ConstantNodeDescDetail_NONE = 0, + ConstantNodeDescDetail_ConstantName = 1, + ConstantNodeDescDetail_ConstantRawData = 2, + ConstantNodeDescDetail_MIN = ConstantNodeDescDetail_NONE, + ConstantNodeDescDetail_MAX = ConstantNodeDescDetail_ConstantRawData +}; + +inline const ConstantNodeDescDetail (&EnumValuesConstantNodeDescDetail())[3] { + static const ConstantNodeDescDetail values[] = { + ConstantNodeDescDetail_NONE, + ConstantNodeDescDetail_ConstantName, + ConstantNodeDescDetail_ConstantRawData + }; + return values; +} + +inline const char * const *EnumNamesConstantNodeDescDetail() { + static const char * const names[4] = { + "NONE", + "ConstantName", + "ConstantRawData", + nullptr + }; + return names; +} + +inline const char *EnumNameConstantNodeDescDetail(ConstantNodeDescDetail e) { + if (flatbuffers::IsOutRange(e, ConstantNodeDescDetail_NONE, ConstantNodeDescDetail_ConstantRawData)) return ""; + const size_t index = static_cast(e); + return EnumNamesConstantNodeDescDetail()[index]; +} + +template struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_NONE; +}; + +template<> struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantName; +}; + +template<> struct ConstantNodeDescDetailTraits { + static const ConstantNodeDescDetail enum_value = ConstantNodeDescDetail_ConstantRawData; +}; + +bool VerifyConstantNodeDescDetail(flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type); +bool VerifyConstantNodeDescDetailVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +enum NodeDesc { + NodeDesc_NONE = 0, + NodeDesc_OperatorNodeDesc = 1, + NodeDesc_ConstantNodeDesc = 2, + NodeDesc_MIN = NodeDesc_NONE, + NodeDesc_MAX = NodeDesc_ConstantNodeDesc +}; + +inline const NodeDesc (&EnumValuesNodeDesc())[3] { + static const NodeDesc values[] = { + NodeDesc_NONE, + NodeDesc_OperatorNodeDesc, + NodeDesc_ConstantNodeDesc + }; + return values; +} + +inline const char * const *EnumNamesNodeDesc() { + static const char * const names[4] = { + "NONE", + "OperatorNodeDesc", + "ConstantNodeDesc", + nullptr + }; + return names; +} + +inline const char *EnumNameNodeDesc(NodeDesc e) { + if (flatbuffers::IsOutRange(e, NodeDesc_NONE, NodeDesc_ConstantNodeDesc)) return ""; + const size_t index = static_cast(e); + return EnumNamesNodeDesc()[index]; +} + +template struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_NONE; +}; + +template<> struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_OperatorNodeDesc; +}; + +template<> struct NodeDescTraits { + static const NodeDesc enum_value = NodeDesc_ConstantNodeDesc; +}; + +bool VerifyNodeDesc(flatbuffers::Verifier &verifier, const void *obj, NodeDesc type); +bool VerifyNodeDescVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +struct ConstantRawData FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConstantRawDataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct ConstantRawDataBuilder { + typedef ConstantRawData Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(ConstantRawData::VT_DATA, data); + } + explicit ConstantRawDataBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConstantRawDataBuilder &operator=(const ConstantRawDataBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConstantRawData( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + ConstantRawDataBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateConstantRawDataDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? _fbb.CreateVector(*data) : 0; + return dml::ir::CreateConstantRawData( + _fbb, + data__); +} + +struct ConstantName FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConstantNameBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + verifier.EndTable(); + } +}; + +struct ConstantNameBuilder { + typedef ConstantName Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(ConstantName::VT_NAME, name); + } + explicit ConstantNameBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConstantNameBuilder &operator=(const ConstantNameBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConstantName( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0) { + ConstantNameBuilder builder_(_fbb); + builder_.add_name(name); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateConstantNameDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + return dml::ir::CreateConstantName( + _fbb, + name__); +} + +struct ConstantNodeDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ConstantNodeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA_TYPE = 4, + VT_DATA = 6 + }; + dml::ir::ConstantNodeDescDetail data_type() const { + return static_cast(GetField(VT_DATA_TYPE, 0)); + } + const void *data() const { + return GetPointer(VT_DATA); + } + template const T *data_as() const; + const dml::ir::ConstantName *data_as_ConstantName() const { + return data_type() == dml::ir::ConstantNodeDescDetail_ConstantName ? static_cast(data()) : nullptr; + } + const dml::ir::ConstantRawData *data_as_ConstantRawData() const { + return data_type() == dml::ir::ConstantNodeDescDetail_ConstantRawData ? static_cast(data()) : nullptr; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DATA_TYPE) && + VerifyOffset(verifier, VT_DATA) && + VerifyConstantNodeDescDetail(verifier, data(), data_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::ConstantName *ConstantNodeDesc::data_as() const { + return data_as_ConstantName(); +} + +template<> inline const dml::ir::ConstantRawData *ConstantNodeDesc::data_as() const { + return data_as_ConstantRawData(); +} + +struct ConstantNodeDescBuilder { + typedef ConstantNodeDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data_type(dml::ir::ConstantNodeDescDetail data_type) { + fbb_.AddElement(ConstantNodeDesc::VT_DATA_TYPE, static_cast(data_type), 0); + } + void add_data(flatbuffers::Offset data) { + fbb_.AddOffset(ConstantNodeDesc::VT_DATA, data); + } + explicit ConstantNodeDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ConstantNodeDescBuilder &operator=(const ConstantNodeDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateConstantNodeDesc( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::ConstantNodeDescDetail data_type = dml::ir::ConstantNodeDescDetail_NONE, + flatbuffers::Offset data = 0) { + ConstantNodeDescBuilder builder_(_fbb); + builder_.add_data(data); + builder_.add_data_type(data_type); + return builder_.Finish(); +} + +struct DmlBufferTensorDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DmlBufferTensorDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATATYPE = 4, + VT_SIZES = 6, + VT_STRIDES = 8, + VT_TOTALTENSORSIZEINBYTES = 10 + }; + const flatbuffers::String *dataType() const { + return GetPointer(VT_DATATYPE); + } + const flatbuffers::Vector *sizes() const { + return GetPointer *>(VT_SIZES); + } + const flatbuffers::Vector *strides() const { + return GetPointer *>(VT_STRIDES); + } + uint64_t totalTensorSizeInBytes() const { + return GetField(VT_TOTALTENSORSIZEINBYTES, 0); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATATYPE) && + verifier.VerifyString(dataType()) && + VerifyOffset(verifier, VT_SIZES) && + verifier.VerifyVector(sizes()) && + VerifyOffset(verifier, VT_STRIDES) && + verifier.VerifyVector(strides()) && + VerifyField(verifier, VT_TOTALTENSORSIZEINBYTES) && + verifier.EndTable(); + } +}; + +struct DmlBufferTensorDescBuilder { + typedef DmlBufferTensorDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_dataType(flatbuffers::Offset dataType) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_DATATYPE, dataType); + } + void add_sizes(flatbuffers::Offset> sizes) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_SIZES, sizes); + } + void add_strides(flatbuffers::Offset> strides) { + fbb_.AddOffset(DmlBufferTensorDesc::VT_STRIDES, strides); + } + void add_totalTensorSizeInBytes(uint64_t totalTensorSizeInBytes) { + fbb_.AddElement(DmlBufferTensorDesc::VT_TOTALTENSORSIZEINBYTES, totalTensorSizeInBytes, 0); + } + explicit DmlBufferTensorDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DmlBufferTensorDescBuilder &operator=(const DmlBufferTensorDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDmlBufferTensorDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset dataType = 0, + flatbuffers::Offset> sizes = 0, + flatbuffers::Offset> strides = 0, + uint64_t totalTensorSizeInBytes = 0) { + DmlBufferTensorDescBuilder builder_(_fbb); + builder_.add_totalTensorSizeInBytes(totalTensorSizeInBytes); + builder_.add_strides(strides); + builder_.add_sizes(sizes); + builder_.add_dataType(dataType); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDmlBufferTensorDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *dataType = nullptr, + const std::vector *sizes = nullptr, + const std::vector *strides = nullptr, + uint64_t totalTensorSizeInBytes = 0) { + auto dataType__ = dataType ? _fbb.CreateString(dataType) : 0; + auto sizes__ = sizes ? _fbb.CreateVector(*sizes) : 0; + auto strides__ = strides ? _fbb.CreateVector(*strides) : 0; + return dml::ir::CreateDmlBufferTensorDesc( + _fbb, + dataType__, + sizes__, + strides__, + totalTensorSizeInBytes); +} + +struct OperatorNodeDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef OperatorNodeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE = 4, + VT_INPUTS = 6, + VT_OUTPUTS = 8, + VT_ATTRIBUTES = 10 + }; + const flatbuffers::String *type() const { + return GetPointer(VT_TYPE); + } + const flatbuffers::Vector> *inputs() const { + return GetPointer> *>(VT_INPUTS); + } + const flatbuffers::Vector> *outputs() const { + return GetPointer> *>(VT_OUTPUTS); + } + const flatbuffers::Vector> *attributes() const { + return GetPointer> *>(VT_ATTRIBUTES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE) && + verifier.VerifyString(type()) && + VerifyOffset(verifier, VT_INPUTS) && + verifier.VerifyVector(inputs()) && + verifier.VerifyVectorOfTables(inputs()) && + VerifyOffset(verifier, VT_OUTPUTS) && + verifier.VerifyVector(outputs()) && + verifier.VerifyVectorOfTables(outputs()) && + VerifyOffset(verifier, VT_ATTRIBUTES) && + verifier.VerifyVector(attributes()) && + verifier.VerifyVectorOfTables(attributes()) && + verifier.EndTable(); + } +}; + +struct OperatorNodeDescBuilder { + typedef OperatorNodeDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type(flatbuffers::Offset type) { + fbb_.AddOffset(OperatorNodeDesc::VT_TYPE, type); + } + void add_inputs(flatbuffers::Offset>> inputs) { + fbb_.AddOffset(OperatorNodeDesc::VT_INPUTS, inputs); + } + void add_outputs(flatbuffers::Offset>> outputs) { + fbb_.AddOffset(OperatorNodeDesc::VT_OUTPUTS, outputs); + } + void add_attributes(flatbuffers::Offset>> attributes) { + fbb_.AddOffset(OperatorNodeDesc::VT_ATTRIBUTES, attributes); + } + explicit OperatorNodeDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + OperatorNodeDescBuilder &operator=(const OperatorNodeDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateOperatorNodeDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset type = 0, + flatbuffers::Offset>> inputs = 0, + flatbuffers::Offset>> outputs = 0, + flatbuffers::Offset>> attributes = 0) { + OperatorNodeDescBuilder builder_(_fbb); + builder_.add_attributes(attributes); + builder_.add_outputs(outputs); + builder_.add_inputs(inputs); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateOperatorNodeDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *type = nullptr, + const std::vector> *inputs = nullptr, + const std::vector> *outputs = nullptr, + const std::vector> *attributes = nullptr) { + auto type__ = type ? _fbb.CreateString(type) : 0; + auto inputs__ = inputs ? _fbb.CreateVector>(*inputs) : 0; + auto outputs__ = outputs ? _fbb.CreateVector>(*outputs) : 0; + auto attributes__ = attributes ? _fbb.CreateVector>(*attributes) : 0; + return dml::ir::CreateOperatorNodeDesc( + _fbb, + type__, + inputs__, + outputs__, + attributes__); +} + +struct DmlGraphNode FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DmlGraphNodeBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DESC_TYPE = 4, + VT_DESC = 6, + VT_NAME = 8, + VT_INPUTNAMES = 10, + VT_OUTPUTNAMES = 12 + }; + dml::ir::NodeDesc desc_type() const { + return static_cast(GetField(VT_DESC_TYPE, 0)); + } + const void *desc() const { + return GetPointer(VT_DESC); + } + template const T *desc_as() const; + const dml::ir::OperatorNodeDesc *desc_as_OperatorNodeDesc() const { + return desc_type() == dml::ir::NodeDesc_OperatorNodeDesc ? static_cast(desc()) : nullptr; + } + const dml::ir::ConstantNodeDesc *desc_as_ConstantNodeDesc() const { + return desc_type() == dml::ir::NodeDesc_ConstantNodeDesc ? static_cast(desc()) : nullptr; + } + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + const flatbuffers::Vector> *inputNames() const { + return GetPointer> *>(VT_INPUTNAMES); + } + const flatbuffers::Vector> *outputNames() const { + return GetPointer> *>(VT_OUTPUTNAMES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DESC_TYPE) && + VerifyOffset(verifier, VT_DESC) && + VerifyNodeDesc(verifier, desc(), desc_type()) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyOffset(verifier, VT_INPUTNAMES) && + verifier.VerifyVector(inputNames()) && + verifier.VerifyVectorOfStrings(inputNames()) && + VerifyOffset(verifier, VT_OUTPUTNAMES) && + verifier.VerifyVector(outputNames()) && + verifier.VerifyVectorOfStrings(outputNames()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::OperatorNodeDesc *DmlGraphNode::desc_as() const { + return desc_as_OperatorNodeDesc(); +} + +template<> inline const dml::ir::ConstantNodeDesc *DmlGraphNode::desc_as() const { + return desc_as_ConstantNodeDesc(); +} + +struct DmlGraphNodeBuilder { + typedef DmlGraphNode Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_desc_type(dml::ir::NodeDesc desc_type) { + fbb_.AddElement(DmlGraphNode::VT_DESC_TYPE, static_cast(desc_type), 0); + } + void add_desc(flatbuffers::Offset desc) { + fbb_.AddOffset(DmlGraphNode::VT_DESC, desc); + } + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(DmlGraphNode::VT_NAME, name); + } + void add_inputNames(flatbuffers::Offset>> inputNames) { + fbb_.AddOffset(DmlGraphNode::VT_INPUTNAMES, inputNames); + } + void add_outputNames(flatbuffers::Offset>> outputNames) { + fbb_.AddOffset(DmlGraphNode::VT_OUTPUTNAMES, outputNames); + } + explicit DmlGraphNodeBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DmlGraphNodeBuilder &operator=(const DmlGraphNodeBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDmlGraphNode( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE, + flatbuffers::Offset desc = 0, + flatbuffers::Offset name = 0, + flatbuffers::Offset>> inputNames = 0, + flatbuffers::Offset>> outputNames = 0) { + DmlGraphNodeBuilder builder_(_fbb); + builder_.add_outputNames(outputNames); + builder_.add_inputNames(inputNames); + builder_.add_name(name); + builder_.add_desc(desc); + builder_.add_desc_type(desc_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDmlGraphNodeDirect( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::NodeDesc desc_type = dml::ir::NodeDesc_NONE, + flatbuffers::Offset desc = 0, + const char *name = nullptr, + const std::vector> *inputNames = nullptr, + const std::vector> *outputNames = nullptr) { + auto name__ = name ? _fbb.CreateString(name) : 0; + auto inputNames__ = inputNames ? _fbb.CreateVector>(*inputNames) : 0; + auto outputNames__ = outputNames ? _fbb.CreateVector>(*outputNames) : 0; + return dml::ir::CreateDmlGraphNode( + _fbb, + desc_type, + desc, + name__, + inputNames__, + outputNames__); +} + +struct DmlGraphDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef DmlGraphDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NODES = 4, + VT_GRAPHINPUTNAMES = 6, + VT_GRAPHOUTPUTNAMES = 8 + }; + const flatbuffers::Vector> *nodes() const { + return GetPointer> *>(VT_NODES); + } + const flatbuffers::Vector> *graphInputNames() const { + return GetPointer> *>(VT_GRAPHINPUTNAMES); + } + const flatbuffers::Vector> *graphOutputNames() const { + return GetPointer> *>(VT_GRAPHOUTPUTNAMES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NODES) && + verifier.VerifyVector(nodes()) && + verifier.VerifyVectorOfTables(nodes()) && + VerifyOffset(verifier, VT_GRAPHINPUTNAMES) && + verifier.VerifyVector(graphInputNames()) && + verifier.VerifyVectorOfStrings(graphInputNames()) && + VerifyOffset(verifier, VT_GRAPHOUTPUTNAMES) && + verifier.VerifyVector(graphOutputNames()) && + verifier.VerifyVectorOfStrings(graphOutputNames()) && + verifier.EndTable(); + } +}; + +struct DmlGraphDescBuilder { + typedef DmlGraphDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_nodes(flatbuffers::Offset>> nodes) { + fbb_.AddOffset(DmlGraphDesc::VT_NODES, nodes); + } + void add_graphInputNames(flatbuffers::Offset>> graphInputNames) { + fbb_.AddOffset(DmlGraphDesc::VT_GRAPHINPUTNAMES, graphInputNames); + } + void add_graphOutputNames(flatbuffers::Offset>> graphOutputNames) { + fbb_.AddOffset(DmlGraphDesc::VT_GRAPHOUTPUTNAMES, graphOutputNames); + } + explicit DmlGraphDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + DmlGraphDescBuilder &operator=(const DmlGraphDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateDmlGraphDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> nodes = 0, + flatbuffers::Offset>> graphInputNames = 0, + flatbuffers::Offset>> graphOutputNames = 0) { + DmlGraphDescBuilder builder_(_fbb); + builder_.add_graphOutputNames(graphOutputNames); + builder_.add_graphInputNames(graphInputNames); + builder_.add_nodes(nodes); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateDmlGraphDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *nodes = nullptr, + const std::vector> *graphInputNames = nullptr, + const std::vector> *graphOutputNames = nullptr) { + auto nodes__ = nodes ? _fbb.CreateVector>(*nodes) : 0; + auto graphInputNames__ = graphInputNames ? _fbb.CreateVector>(*graphInputNames) : 0; + auto graphOutputNames__ = graphOutputNames ? _fbb.CreateVector>(*graphOutputNames) : 0; + return dml::ir::CreateDmlGraphDesc( + _fbb, + nodes__, + graphInputNames__, + graphOutputNames__); +} + +inline bool VerifyConstantNodeDescDetail(flatbuffers::Verifier &verifier, const void *obj, ConstantNodeDescDetail type) { + switch (type) { + case ConstantNodeDescDetail_NONE: { + return true; + } + case ConstantNodeDescDetail_ConstantName: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case ConstantNodeDescDetail_ConstantRawData: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyConstantNodeDescDetailVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyConstantNodeDescDetail( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline bool VerifyNodeDesc(flatbuffers::Verifier &verifier, const void *obj, NodeDesc type) { + switch (type) { + case NodeDesc_NONE: { + return true; + } + case NodeDesc_OperatorNodeDesc: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case NodeDesc_ConstantNodeDesc: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + default: return true; + } +} + +inline bool VerifyNodeDescVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyNodeDesc( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline const dml::ir::DmlGraphDesc *GetDmlGraphDesc(const void *buf) { + return flatbuffers::GetRoot(buf); +} + +inline const dml::ir::DmlGraphDesc *GetSizePrefixedDmlGraphDesc(const void *buf) { + return flatbuffers::GetSizePrefixedRoot(buf); +} + +inline bool VerifyDmlGraphDescBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifyBuffer(nullptr); +} + +inline bool VerifySizePrefixedDmlGraphDescBuffer( + flatbuffers::Verifier &verifier) { + return verifier.VerifySizePrefixedBuffer(nullptr); +} + +inline void FinishDmlGraphDescBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.Finish(root); +} + +inline void FinishSizePrefixedDmlGraphDescBuffer( + flatbuffers::FlatBufferBuilder &fbb, + flatbuffers::Offset root) { + fbb.FinishSizePrefixed(root); +} + +} // namespace ir +} // namespace dml + +#endif // FLATBUFFERS_GENERATED_DMLGRAPHDESC_DML_IR_H_ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h new file mode 100644 index 0000000000000..9decf0dce1bb2 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphDeserialization.h @@ -0,0 +1,14 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "DmlSerializedGraphDesc.h" + +struct NodeIndex +{ + uint32_t nodeIndex; + uint32_t nodeOutputIndex; +}; + +DmlSerializedGraphDesc DeserializeDmlGraph( + const uint8_t* flatbufferGraphDescBlob, + /*out*/ std::vector>& rawData); \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h new file mode 100644 index 0000000000000..d8d069da906b7 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlGraphSerialization.h @@ -0,0 +1,8 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +#pragma once +#include "DmlGraphDesc_generated.h" + +struct DmlSerializedGraphDesc; + +flatbuffers::DetachedBuffer SerializeDmlGraph(const DmlSerializedGraphDesc& graphDesc); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h new file mode 100644 index 0000000000000..51c3d6c81244b --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/DmlSerializedGraphDesc.h @@ -0,0 +1,73 @@ +//----------------------------------------------------------------------------- +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// +//----------------------------------------------------------------------------- + +#pragma once + +struct ConstantName +{ + std::string name; +}; + +struct ConstantData +{ + std::byte* data; + uint64_t dataSize; +}; + +using DmlSerializedGraphNodeConstantVariant = std::variant< + ConstantName, + ConstantData +>; + +using DmlSerializedGraphNodeDescVariant = std::variant< + AbstractOperatorDesc, + DmlSerializedGraphNodeConstantVariant +>; + +struct DmlSerializedGraphNode +{ + DmlSerializedGraphNodeDescVariant Desc; + std::string Name; +}; + +struct DmlInputSerializedGraphEdge +{ + uint32_t GraphInputIndex; + uint32_t ToNodeIndex; + uint32_t ToNodeInputIndex; + std::string Name; +}; + +struct DmlOutputSerializedGraphEdge +{ + uint32_t FromNodeIndex; + uint32_t FromNodeOutputIndex; + uint32_t GraphOutputIndex; + std::string Name; +}; + +struct DmlIntermediateSerializedGraphEdge +{ + uint32_t FromNodeIndex; + uint32_t FromNodeOutputIndex; + uint32_t ToNodeIndex; + uint32_t ToNodeInputIndex; + std::string Name; +}; + +struct DmlSerializedGraphDesc +{ + uint32_t InputCount; + uint32_t OutputCount; + // nodes must be present in topological order for deserialization to work + // because while creating a intermediate edge during deserialization, node (from + // which given intermediate edge is outputting) must be visited before than the node + // (to which given intermediate edge is inputting) + std::vector Nodes; + std::vector InputEdges; + std::vector OutputEdges; + std::vector IntermediateEdges; +}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h index 99218c135f058..4be41ad3924a2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaHelpers.h @@ -425,7 +425,6 @@ inline std::vector GetFields(const DML_AVERAGE_POOLING_OPERATOR_D OperatorField(&DML_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.IncludePadding))), }; } - inline std::vector GetFields(const DML_AVERAGE_POOLING1_OPERATOR_DESC& desc) { return { @@ -502,24 +501,6 @@ inline std::vector GetFields(const DML_ROI_POOLING_OPERATOR_DESC& OperatorField(&DML_ROI_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.PooledSize))), }; } -inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) -{ - return { - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), - OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), - }; -} inline std::vector GetFields(const DML_SLICE_OPERATOR_DESC& desc) { return { @@ -1488,6 +1469,37 @@ inline std::vector GetFields(const DML_MULTIHEAD_ATTENTION_OPERAT OperatorField(&DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA.Fields[17], ToOperatorFieldType(static_cast(desc.MaskType))), }; } +inline std::vector GetFields(const DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.InputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.InputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.OutputScaleTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.OutputZeroPointTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.DimensionCount))), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.Strides), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[8], ToOperatorFieldType(static_cast(desc.WindowSize), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[9], ToOperatorFieldType(static_cast(desc.StartPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[10], ToOperatorFieldType(static_cast(desc.EndPadding), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[11], ToOperatorFieldType(static_cast(desc.Dilations), desc.DimensionCount)), + OperatorField(&DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA.Fields[12], ToOperatorFieldType(static_cast(desc.IncludePadding))), + }; +} +inline std::vector GetFields(const DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.ATensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.AScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.AZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.BTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[4], ToOperatorFieldType(static_cast(desc.BScaleTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[5], ToOperatorFieldType(static_cast(desc.BZeroPointTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[6], ToOperatorFieldType(static_cast(desc.BiasTensor))), + OperatorField(&DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA.Fields[7], ToOperatorFieldType(static_cast(desc.OutputTensor))), + }; +} inline std::vector GetFields(const DML_ACTIVATION_ELU_OPERATOR_DESC& desc) { return { @@ -1680,6 +1692,23 @@ inline std::vector GetFields(const DML_ACTIVATION_GELU_OPERATOR_D OperatorField(&DML_ACTIVATION_GELU_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), }; } +inline std::vector GetFields(const DML_ACTIVATION_SWISH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.SigmoidInputScale))), + }; +} +inline std::vector GetFields(const DML_ACTIVATION_HARD_SWISH_OPERATOR_DESC& desc) +{ + return { + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[0], ToOperatorFieldType(static_cast(desc.InputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[1], ToOperatorFieldType(static_cast(desc.OutputTensor))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[2], ToOperatorFieldType(static_cast(desc.Alpha))), + OperatorField(&DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA.Fields[3], ToOperatorFieldType(static_cast(desc.Beta))), + }; +} inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) { switch (operatorType) @@ -1826,6 +1855,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_RESAMPLE_GRAD1: return DML_RESAMPLE_GRAD1_OPERATOR_SCHEMA; case DML_OPERATOR_DIAGONAL_MATRIX1: return DML_DIAGONAL_MATRIX1_OPERATOR_SCHEMA; case DML_OPERATOR_MULTIHEAD_ATTENTION: return DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA; + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: return DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA; + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: return DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_ELU: return DML_ACTIVATION_ELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_CELU: return DML_ACTIVATION_CELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_HARDMAX: return DML_ACTIVATION_HARDMAX_OPERATOR_SCHEMA; @@ -1850,6 +1881,8 @@ inline const DML_OPERATOR_SCHEMA& GetSchema(DML_OPERATOR_TYPE operatorType) case DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU: return DML_ACTIVATION_THRESHOLDED_RELU_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_SHRINK: return DML_ACTIVATION_SHRINK_OPERATOR_SCHEMA; case DML_OPERATOR_ACTIVATION_GELU: return DML_ACTIVATION_GELU_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_SWISH: return DML_ACTIVATION_SWISH_OPERATOR_SCHEMA; + case DML_OPERATOR_ACTIVATION_HARD_SWISH: return DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA; default: ORT_THROW_HR(E_INVALIDARG); @@ -2431,6 +2464,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_MULTIHEAD_ATTENTION_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + return AbstractOperatorDesc( + &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT: + return AbstractOperatorDesc( + &DML_MATRIX_MULTIPLY_INTEGER_TO_FLOAT_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); case DML_OPERATOR_ACTIVATION_ELU: return AbstractOperatorDesc( &DML_ACTIVATION_ELU_OPERATOR_SCHEMA, @@ -2527,13 +2568,14 @@ inline AbstractOperatorDesc ConvertOperatorDesc(const DML_OPERATOR_DESC& opDesc) return AbstractOperatorDesc( &DML_ACTIVATION_GELU_OPERATOR_SCHEMA, GetFields(*static_cast(opDesc.Desc))); -#pragma warning(push) -#pragma warning(disable: 4063) - case DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING: + case DML_OPERATOR_ACTIVATION_SWISH: return AbstractOperatorDesc( - &DML_QUANTIZED_LINEAR_AVERAGE_POOLING_OPERATOR_SCHEMA, - GetFields(*static_cast(opDesc.Desc))); -#pragma warning(pop) + &DML_ACTIVATION_SWISH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); + case DML_OPERATOR_ACTIVATION_HARD_SWISH: + return AbstractOperatorDesc( + &DML_ACTIVATION_HARD_SWISH_OPERATOR_SCHEMA, + GetFields(*static_cast(opDesc.Desc))); default: ORT_THROW_HR(E_INVALIDARG); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h index 25f0dd26c6067..a94bb67b68d36 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/GeneratedSchemaTypes.h @@ -15,32 +15,34 @@ using ApiAttributeVariant = std::variant< const FLOAT*, const DML_SCALE_BIAS*, DML_SIZE_2D, - DML_SCALAR_UNION + DML_SCALAR_UNION, + BOOL >; namespace OperatorFieldTypes { using TensorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC using TensorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_TENSOR_DESC_ARRAY - using OperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC - using OperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY + using FusedActivationOperatorDesc = std::optional; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC + using FusedActivationOperatorDescArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_OPERATOR_DESC_ARRAY using UInt = uint32_t; // DML_SCHEMA_FIELD_TYPE_UINT using UInt64 = uint64_t; // DML_SCHEMA_FIELD_TYPE_UINT64 using Int = int32_t; // DML_SCHEMA_FIELD_TYPE_INT using Float = float; // DML_SCHEMA_FIELD_TYPE_FLOAT - using UIntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY - using IntArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY - using FloatArray = std::optional>; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY + using UIntArray = std::vector; // DML_SCHEMA_FIELD_TYPE_UINT_ARRAY + using IntArray = std::vector; // DML_SCHEMA_FIELD_TYPE_INT_ARRAY + using FloatArray = std::vector; // DML_SCHEMA_FIELD_TYPE_FLOAT_ARRAY using ScaleBias = std::optional; // DML_SCHEMA_FIELD_TYPE_SCALE_BIAS using Size2D = DML_SIZE_2D; // DML_SCHEMA_FIELD_TYPE_SIZE_2D using ScalarUnion = DML_SCALAR_UNION; // DML_SCHEMA_FIELD_TYPE_SCALAR_UNION + using Bool = bool; // DML_SCHEMA_FIELD_TYPE_BOOL } using OperatorFieldVariant = std::variant< OperatorFieldTypes::TensorDesc, OperatorFieldTypes::TensorDescArray, - OperatorFieldTypes::OperatorDesc, - OperatorFieldTypes::OperatorDescArray, + OperatorFieldTypes::FusedActivationOperatorDesc, + OperatorFieldTypes::FusedActivationOperatorDescArray, OperatorFieldTypes::UInt, OperatorFieldTypes::UInt64, OperatorFieldTypes::Int, @@ -50,7 +52,8 @@ using OperatorFieldVariant = std::variant< OperatorFieldTypes::FloatArray, OperatorFieldTypes::ScaleBias, OperatorFieldTypes::Size2D, - OperatorFieldTypes::ScalarUnion + OperatorFieldTypes::ScalarUnion, + OperatorFieldTypes::Bool >; class OperatorField @@ -80,11 +83,11 @@ class OperatorField const OperatorFieldTypes::TensorDescArray& AsTensorDescArray() const { return std::get(m_data); } OperatorFieldTypes::TensorDescArray& AsTensorDescArray() { return std::get(m_data); } - const OperatorFieldTypes::OperatorDesc& AsOperatorDesc() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDesc& AsOperatorDesc() { return std::get(m_data); } + const OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() const { return std::get(m_data); } + OperatorFieldTypes::FusedActivationOperatorDesc& AsFusedActivationOperatorDesc() { return std::get(m_data); } - const OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() const { return std::get(m_data); } - OperatorFieldTypes::OperatorDescArray& AsOperatorDescArray() { return std::get(m_data); } + const OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() const { return std::get(m_data); } + OperatorFieldTypes::FusedActivationOperatorDescArray& AsFusedActivationOperatorDescArray() { return std::get(m_data); } const OperatorFieldTypes::UInt& AsUInt() const { return std::get(m_data); } OperatorFieldTypes::UInt& AsUInt() { return std::get(m_data); } @@ -116,6 +119,9 @@ class OperatorField const OperatorFieldTypes::ScalarUnion& AsScalarUnion() const { return std::get(m_data); } OperatorFieldTypes::ScalarUnion& AsScalarUnion() { return std::get(m_data); } + const OperatorFieldTypes::Bool& AsBool() const { return std::get(m_data); } + OperatorFieldTypes::Bool& AsBool() { return std::get(m_data); } + private: const DML_SCHEMA_FIELD* m_schema; OperatorFieldVariant m_data; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h new file mode 100644 index 0000000000000..167a913bb0132 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/OperatorFieldTypes_generated.h @@ -0,0 +1,1318 @@ +// automatically generated by the FlatBuffers compiler, do not modify + + +#ifndef FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ +#define FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ + +#include "flatbuffers/flatbuffers.h" + +namespace dml { +namespace ir { +namespace operatorFieldTypes { + +struct AttributeDesc; +struct AttributeDescBuilder; + +struct Activation; +struct ActivationBuilder; + +struct ActivationArray; +struct ActivationArrayBuilder; + +struct UInt8; + +struct UInt16; + +struct UInt32; + +struct UInt64; + +struct Int8; + +struct Int16; + +struct Int32; + +struct Int64; + +struct Float32; + +struct Float64; + +struct UIntArray; +struct UIntArrayBuilder; + +struct IntArray; +struct IntArrayBuilder; + +struct FloatArray; +struct FloatArrayBuilder; + +struct ScaleBias; + +struct Size2D; + +struct ByteArray; + +struct ScalarUnionData; +struct ScalarUnionDataBuilder; + +struct Bool; + +enum AttributeFieldVariant { + AttributeFieldVariant_NONE = 0, + AttributeFieldVariant_Activation = 1, + AttributeFieldVariant_ActivationArray = 2, + AttributeFieldVariant_UInt32 = 3, + AttributeFieldVariant_UInt64 = 4, + AttributeFieldVariant_Int32 = 5, + AttributeFieldVariant_Float32 = 6, + AttributeFieldVariant_UIntArray = 7, + AttributeFieldVariant_IntArray = 8, + AttributeFieldVariant_FloatArray = 9, + AttributeFieldVariant_ScaleBias = 10, + AttributeFieldVariant_Size2D = 11, + AttributeFieldVariant_ScalarUnionData = 12, + AttributeFieldVariant_Bool = 13, + AttributeFieldVariant_MIN = AttributeFieldVariant_NONE, + AttributeFieldVariant_MAX = AttributeFieldVariant_Bool +}; + +inline const AttributeFieldVariant (&EnumValuesAttributeFieldVariant())[14] { + static const AttributeFieldVariant values[] = { + AttributeFieldVariant_NONE, + AttributeFieldVariant_Activation, + AttributeFieldVariant_ActivationArray, + AttributeFieldVariant_UInt32, + AttributeFieldVariant_UInt64, + AttributeFieldVariant_Int32, + AttributeFieldVariant_Float32, + AttributeFieldVariant_UIntArray, + AttributeFieldVariant_IntArray, + AttributeFieldVariant_FloatArray, + AttributeFieldVariant_ScaleBias, + AttributeFieldVariant_Size2D, + AttributeFieldVariant_ScalarUnionData, + AttributeFieldVariant_Bool + }; + return values; +} + +inline const char * const *EnumNamesAttributeFieldVariant() { + static const char * const names[15] = { + "NONE", + "Activation", + "ActivationArray", + "UInt32", + "UInt64", + "Int32", + "Float32", + "UIntArray", + "IntArray", + "FloatArray", + "ScaleBias", + "Size2D", + "ScalarUnionData", + "Bool", + nullptr + }; + return names; +} + +inline const char *EnumNameAttributeFieldVariant(AttributeFieldVariant e) { + if (flatbuffers::IsOutRange(e, AttributeFieldVariant_NONE, AttributeFieldVariant_Bool)) return ""; + const size_t index = static_cast(e); + return EnumNamesAttributeFieldVariant()[index]; +} + +template struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_NONE; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Activation; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ActivationArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UInt64; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Int32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Float32; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_UIntArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_IntArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_FloatArray; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScaleBias; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Size2D; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_ScalarUnionData; +}; + +template<> struct AttributeFieldVariantTraits { + static const AttributeFieldVariant enum_value = AttributeFieldVariant_Bool; +}; + +bool VerifyAttributeFieldVariant(flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type); +bool VerifyAttributeFieldVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +enum ScalarVariant { + ScalarVariant_NONE = 0, + ScalarVariant_ByteArray = 1, + ScalarVariant_Int8 = 2, + ScalarVariant_UInt8 = 3, + ScalarVariant_Int16 = 4, + ScalarVariant_UInt16 = 5, + ScalarVariant_Int32 = 6, + ScalarVariant_UInt32 = 7, + ScalarVariant_Int64 = 8, + ScalarVariant_UInt64 = 9, + ScalarVariant_Float32 = 10, + ScalarVariant_Float64 = 11, + ScalarVariant_MIN = ScalarVariant_NONE, + ScalarVariant_MAX = ScalarVariant_Float64 +}; + +inline const ScalarVariant (&EnumValuesScalarVariant())[12] { + static const ScalarVariant values[] = { + ScalarVariant_NONE, + ScalarVariant_ByteArray, + ScalarVariant_Int8, + ScalarVariant_UInt8, + ScalarVariant_Int16, + ScalarVariant_UInt16, + ScalarVariant_Int32, + ScalarVariant_UInt32, + ScalarVariant_Int64, + ScalarVariant_UInt64, + ScalarVariant_Float32, + ScalarVariant_Float64 + }; + return values; +} + +inline const char * const *EnumNamesScalarVariant() { + static const char * const names[13] = { + "NONE", + "ByteArray", + "Int8", + "UInt8", + "Int16", + "UInt16", + "Int32", + "UInt32", + "Int64", + "UInt64", + "Float32", + "Float64", + nullptr + }; + return names; +} + +inline const char *EnumNameScalarVariant(ScalarVariant e) { + if (flatbuffers::IsOutRange(e, ScalarVariant_NONE, ScalarVariant_Float64)) return ""; + const size_t index = static_cast(e); + return EnumNamesScalarVariant()[index]; +} + +template struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_NONE; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_ByteArray; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int8; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt8; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int16; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt16; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Int64; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_UInt64; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Float32; +}; + +template<> struct ScalarVariantTraits { + static const ScalarVariant enum_value = ScalarVariant_Float64; +}; + +bool VerifyScalarVariant(flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type); +bool VerifyScalarVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) UInt8 FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_; + + public: + UInt8() { + memset(static_cast(this), 0, sizeof(UInt8)); + } + UInt8(uint8_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint8_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint8_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt8, 1); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) UInt16 FLATBUFFERS_FINAL_CLASS { + private: + uint16_t data_; + + public: + UInt16() { + memset(static_cast(this), 0, sizeof(UInt16)); + } + UInt16(uint16_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint16_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint16_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt16, 2); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) UInt32 FLATBUFFERS_FINAL_CLASS { + private: + uint32_t data_; + + public: + UInt32() { + memset(static_cast(this), 0, sizeof(UInt32)); + } + UInt32(uint32_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint32_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint32_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) UInt64 FLATBUFFERS_FINAL_CLASS { + private: + uint64_t data_; + + public: + UInt64() { + memset(static_cast(this), 0, sizeof(UInt64)); + } + UInt64(uint64_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + uint64_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(uint64_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(UInt64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Int8 FLATBUFFERS_FINAL_CLASS { + private: + int8_t data_; + + public: + Int8() { + memset(static_cast(this), 0, sizeof(Int8)); + } + Int8(int8_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int8_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int8_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int8, 1); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(2) Int16 FLATBUFFERS_FINAL_CLASS { + private: + int16_t data_; + + public: + Int16() { + memset(static_cast(this), 0, sizeof(Int16)); + } + Int16(int16_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int16_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int16_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int16, 2); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Int32 FLATBUFFERS_FINAL_CLASS { + private: + int32_t data_; + + public: + Int32() { + memset(static_cast(this), 0, sizeof(Int32)); + } + Int32(int32_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int32_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int32_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Int64 FLATBUFFERS_FINAL_CLASS { + private: + int64_t data_; + + public: + Int64() { + memset(static_cast(this), 0, sizeof(Int64)); + } + Int64(int64_t _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + int64_t data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(int64_t _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Int64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Float32 FLATBUFFERS_FINAL_CLASS { + private: + float data_; + + public: + Float32() { + memset(static_cast(this), 0, sizeof(Float32)); + } + Float32(float _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + float data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(float _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Float32, 4); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(8) Float64 FLATBUFFERS_FINAL_CLASS { + private: + double data_; + + public: + Float64() { + memset(static_cast(this), 0, sizeof(Float64)); + } + Float64(double _data) + : data_(flatbuffers::EndianScalar(_data)) { + } + double data() const { + return flatbuffers::EndianScalar(data_); + } + void mutate_data(double _data) { + flatbuffers::WriteScalar(&data_, _data); + } +}; +FLATBUFFERS_STRUCT_END(Float64, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) ScaleBias FLATBUFFERS_FINAL_CLASS { + private: + float scale_; + float bias_; + + public: + ScaleBias() { + memset(static_cast(this), 0, sizeof(ScaleBias)); + } + ScaleBias(float _scale, float _bias) + : scale_(flatbuffers::EndianScalar(_scale)), + bias_(flatbuffers::EndianScalar(_bias)) { + } + float scale() const { + return flatbuffers::EndianScalar(scale_); + } + void mutate_scale(float _scale) { + flatbuffers::WriteScalar(&scale_, _scale); + } + float bias() const { + return flatbuffers::EndianScalar(bias_); + } + void mutate_bias(float _bias) { + flatbuffers::WriteScalar(&bias_, _bias); + } +}; +FLATBUFFERS_STRUCT_END(ScaleBias, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(4) Size2D FLATBUFFERS_FINAL_CLASS { + private: + uint32_t width_; + uint32_t height_; + + public: + Size2D() { + memset(static_cast(this), 0, sizeof(Size2D)); + } + Size2D(uint32_t _width, uint32_t _height) + : width_(flatbuffers::EndianScalar(_width)), + height_(flatbuffers::EndianScalar(_height)) { + } + uint32_t width() const { + return flatbuffers::EndianScalar(width_); + } + void mutate_width(uint32_t _width) { + flatbuffers::WriteScalar(&width_, _width); + } + uint32_t height() const { + return flatbuffers::EndianScalar(height_); + } + void mutate_height(uint32_t _height) { + flatbuffers::WriteScalar(&height_, _height); + } +}; +FLATBUFFERS_STRUCT_END(Size2D, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) ByteArray FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_[8]; + + public: + ByteArray() { + memset(static_cast(this), 0, sizeof(ByteArray)); + } + const flatbuffers::Array *data() const { + return reinterpret_cast *>(data_); + } + flatbuffers::Array *mutable_data() { + return reinterpret_cast *>(data_); + } +}; +FLATBUFFERS_STRUCT_END(ByteArray, 8); + +FLATBUFFERS_MANUALLY_ALIGNED_STRUCT(1) Bool FLATBUFFERS_FINAL_CLASS { + private: + uint8_t data_; + + public: + Bool() { + memset(static_cast(this), 0, sizeof(Bool)); + } + Bool(bool _data) + : data_(flatbuffers::EndianScalar(static_cast(_data))) { + } + bool data() const { + return flatbuffers::EndianScalar(data_) != 0; + } + void mutate_data(bool _data) { + flatbuffers::WriteScalar(&data_, static_cast(_data)); + } +}; +FLATBUFFERS_STRUCT_END(Bool, 1); + +struct AttributeDesc FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef AttributeDescBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_NAME = 4, + VT_VAL_TYPE = 6, + VT_VAL = 8 + }; + const flatbuffers::String *name() const { + return GetPointer(VT_NAME); + } + flatbuffers::String *mutable_name() { + return GetPointer(VT_NAME); + } + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type() const { + return static_cast(GetField(VT_VAL_TYPE, 0)); + } + const void *val() const { + return GetPointer(VT_VAL); + } + template const T *val_as() const; + const dml::ir::operatorFieldTypes::Activation *val_as_Activation() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Activation ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ActivationArray *val_as_ActivationArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ActivationArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt32 *val_as_UInt32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt64 *val_as_UInt64() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UInt64 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int32 *val_as_Int32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Int32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float32 *val_as_Float32() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Float32 ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::UIntArray *val_as_UIntArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_UIntArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::IntArray *val_as_IntArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_IntArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::FloatArray *val_as_FloatArray() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_FloatArray ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ScaleBias *val_as_ScaleBias() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScaleBias ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Size2D *val_as_Size2D() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Size2D ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::ScalarUnionData *val_as_ScalarUnionData() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_ScalarUnionData ? static_cast(val()) : nullptr; + } + const dml::ir::operatorFieldTypes::Bool *val_as_Bool() const { + return val_type() == dml::ir::operatorFieldTypes::AttributeFieldVariant_Bool ? static_cast(val()) : nullptr; + } + void *mutable_val() { + return GetPointer(VT_VAL); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_NAME) && + verifier.VerifyString(name()) && + VerifyField(verifier, VT_VAL_TYPE) && + VerifyOffset(verifier, VT_VAL) && + VerifyAttributeFieldVariant(verifier, val(), val_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::operatorFieldTypes::Activation *AttributeDesc::val_as() const { + return val_as_Activation(); +} + +template<> inline const dml::ir::operatorFieldTypes::ActivationArray *AttributeDesc::val_as() const { + return val_as_ActivationArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt32 *AttributeDesc::val_as() const { + return val_as_UInt32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt64 *AttributeDesc::val_as() const { + return val_as_UInt64(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int32 *AttributeDesc::val_as() const { + return val_as_Int32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float32 *AttributeDesc::val_as() const { + return val_as_Float32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UIntArray *AttributeDesc::val_as() const { + return val_as_UIntArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::IntArray *AttributeDesc::val_as() const { + return val_as_IntArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::FloatArray *AttributeDesc::val_as() const { + return val_as_FloatArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::ScaleBias *AttributeDesc::val_as() const { + return val_as_ScaleBias(); +} + +template<> inline const dml::ir::operatorFieldTypes::Size2D *AttributeDesc::val_as() const { + return val_as_Size2D(); +} + +template<> inline const dml::ir::operatorFieldTypes::ScalarUnionData *AttributeDesc::val_as() const { + return val_as_ScalarUnionData(); +} + +template<> inline const dml::ir::operatorFieldTypes::Bool *AttributeDesc::val_as() const { + return val_as_Bool(); +} + +struct AttributeDescBuilder { + typedef AttributeDesc Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_name(flatbuffers::Offset name) { + fbb_.AddOffset(AttributeDesc::VT_NAME, name); + } + void add_val_type(dml::ir::operatorFieldTypes::AttributeFieldVariant val_type) { + fbb_.AddElement(AttributeDesc::VT_VAL_TYPE, static_cast(val_type), 0); + } + void add_val(flatbuffers::Offset val) { + fbb_.AddOffset(AttributeDesc::VT_VAL, val); + } + explicit AttributeDescBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + AttributeDescBuilder &operator=(const AttributeDescBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateAttributeDesc( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset name = 0, + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE, + flatbuffers::Offset val = 0) { + AttributeDescBuilder builder_(_fbb); + builder_.add_val(val); + builder_.add_name(name); + builder_.add_val_type(val_type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateAttributeDescDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *name = nullptr, + dml::ir::operatorFieldTypes::AttributeFieldVariant val_type = dml::ir::operatorFieldTypes::AttributeFieldVariant_NONE, + flatbuffers::Offset val = 0) { + auto name__ = name ? _fbb.CreateString(name) : 0; + return dml::ir::operatorFieldTypes::CreateAttributeDesc( + _fbb, + name__, + val_type, + val); +} + +struct Activation FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ActivationBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_TYPE = 4, + VT_ATTRIBUTES = 6 + }; + const flatbuffers::String *type() const { + return GetPointer(VT_TYPE); + } + flatbuffers::String *mutable_type() { + return GetPointer(VT_TYPE); + } + const flatbuffers::Vector> *attributes() const { + return GetPointer> *>(VT_ATTRIBUTES); + } + flatbuffers::Vector> *mutable_attributes() { + return GetPointer> *>(VT_ATTRIBUTES); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_TYPE) && + verifier.VerifyString(type()) && + VerifyOffset(verifier, VT_ATTRIBUTES) && + verifier.VerifyVector(attributes()) && + verifier.VerifyVectorOfTables(attributes()) && + verifier.EndTable(); + } +}; + +struct ActivationBuilder { + typedef Activation Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_type(flatbuffers::Offset type) { + fbb_.AddOffset(Activation::VT_TYPE, type); + } + void add_attributes(flatbuffers::Offset>> attributes) { + fbb_.AddOffset(Activation::VT_ATTRIBUTES, attributes); + } + explicit ActivationBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ActivationBuilder &operator=(const ActivationBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateActivation( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset type = 0, + flatbuffers::Offset>> attributes = 0) { + ActivationBuilder builder_(_fbb); + builder_.add_attributes(attributes); + builder_.add_type(type); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateActivationDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const char *type = nullptr, + const std::vector> *attributes = nullptr) { + auto type__ = type ? _fbb.CreateString(type) : 0; + auto attributes__ = attributes ? _fbb.CreateVector>(*attributes) : 0; + return dml::ir::operatorFieldTypes::CreateActivation( + _fbb, + type__, + attributes__); +} + +struct ActivationArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ActivationArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector> *data() const { + return GetPointer> *>(VT_DATA); + } + flatbuffers::Vector> *mutable_data() { + return GetPointer> *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.VerifyVectorOfTables(data()) && + verifier.EndTable(); + } +}; + +struct ActivationArrayBuilder { + typedef ActivationArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset>> data) { + fbb_.AddOffset(ActivationArray::VT_DATA, data); + } + explicit ActivationArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ActivationArrayBuilder &operator=(const ActivationArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateActivationArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset>> data = 0) { + ActivationArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateActivationArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector> *data = nullptr) { + auto data__ = data ? _fbb.CreateVector>(*data) : 0; + return dml::ir::operatorFieldTypes::CreateActivationArray( + _fbb, + data__); +} + +struct UIntArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef UIntArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct UIntArrayBuilder { + typedef UIntArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(UIntArray::VT_DATA, data); + } + explicit UIntArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + UIntArrayBuilder &operator=(const UIntArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateUIntArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + UIntArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateUIntArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? _fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateUIntArray( + _fbb, + data__); +} + +struct IntArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef IntArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct IntArrayBuilder { + typedef IntArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(IntArray::VT_DATA, data); + } + explicit IntArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + IntArrayBuilder &operator=(const IntArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateIntArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + IntArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateIntArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? _fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateIntArray( + _fbb, + data__); +} + +struct FloatArray FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef FloatArrayBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA = 4 + }; + const flatbuffers::Vector *data() const { + return GetPointer *>(VT_DATA); + } + flatbuffers::Vector *mutable_data() { + return GetPointer *>(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyOffset(verifier, VT_DATA) && + verifier.VerifyVector(data()) && + verifier.EndTable(); + } +}; + +struct FloatArrayBuilder { + typedef FloatArray Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data(flatbuffers::Offset> data) { + fbb_.AddOffset(FloatArray::VT_DATA, data); + } + explicit FloatArrayBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + FloatArrayBuilder &operator=(const FloatArrayBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateFloatArray( + flatbuffers::FlatBufferBuilder &_fbb, + flatbuffers::Offset> data = 0) { + FloatArrayBuilder builder_(_fbb); + builder_.add_data(data); + return builder_.Finish(); +} + +inline flatbuffers::Offset CreateFloatArrayDirect( + flatbuffers::FlatBufferBuilder &_fbb, + const std::vector *data = nullptr) { + auto data__ = data ? _fbb.CreateVector(*data) : 0; + return dml::ir::operatorFieldTypes::CreateFloatArray( + _fbb, + data__); +} + +struct ScalarUnionData FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef ScalarUnionDataBuilder Builder; + enum FlatBuffersVTableOffset FLATBUFFERS_VTABLE_UNDERLYING_TYPE { + VT_DATA_TYPE = 4, + VT_DATA = 6 + }; + dml::ir::operatorFieldTypes::ScalarVariant data_type() const { + return static_cast(GetField(VT_DATA_TYPE, 0)); + } + const void *data() const { + return GetPointer(VT_DATA); + } + template const T *data_as() const; + const dml::ir::operatorFieldTypes::ByteArray *data_as_ByteArray() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_ByteArray ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int8 *data_as_Int8() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int8 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt8 *data_as_UInt8() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt8 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int16 *data_as_Int16() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int16 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt16 *data_as_UInt16() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt16 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int32 *data_as_Int32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt32 *data_as_UInt32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Int64 *data_as_Int64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Int64 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::UInt64 *data_as_UInt64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_UInt64 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float32 *data_as_Float32() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float32 ? static_cast(data()) : nullptr; + } + const dml::ir::operatorFieldTypes::Float64 *data_as_Float64() const { + return data_type() == dml::ir::operatorFieldTypes::ScalarVariant_Float64 ? static_cast(data()) : nullptr; + } + void *mutable_data() { + return GetPointer(VT_DATA); + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField(verifier, VT_DATA_TYPE) && + VerifyOffset(verifier, VT_DATA) && + VerifyScalarVariant(verifier, data(), data_type()) && + verifier.EndTable(); + } +}; + +template<> inline const dml::ir::operatorFieldTypes::ByteArray *ScalarUnionData::data_as() const { + return data_as_ByteArray(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int8 *ScalarUnionData::data_as() const { + return data_as_Int8(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt8 *ScalarUnionData::data_as() const { + return data_as_UInt8(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int16 *ScalarUnionData::data_as() const { + return data_as_Int16(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt16 *ScalarUnionData::data_as() const { + return data_as_UInt16(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int32 *ScalarUnionData::data_as() const { + return data_as_Int32(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt32 *ScalarUnionData::data_as() const { + return data_as_UInt32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Int64 *ScalarUnionData::data_as() const { + return data_as_Int64(); +} + +template<> inline const dml::ir::operatorFieldTypes::UInt64 *ScalarUnionData::data_as() const { + return data_as_UInt64(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float32 *ScalarUnionData::data_as() const { + return data_as_Float32(); +} + +template<> inline const dml::ir::operatorFieldTypes::Float64 *ScalarUnionData::data_as() const { + return data_as_Float64(); +} + +struct ScalarUnionDataBuilder { + typedef ScalarUnionData Table; + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_data_type(dml::ir::operatorFieldTypes::ScalarVariant data_type) { + fbb_.AddElement(ScalarUnionData::VT_DATA_TYPE, static_cast(data_type), 0); + } + void add_data(flatbuffers::Offset data) { + fbb_.AddOffset(ScalarUnionData::VT_DATA, data); + } + explicit ScalarUnionDataBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + ScalarUnionDataBuilder &operator=(const ScalarUnionDataBuilder &); + flatbuffers::Offset Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset(end); + return o; + } +}; + +inline flatbuffers::Offset CreateScalarUnionData( + flatbuffers::FlatBufferBuilder &_fbb, + dml::ir::operatorFieldTypes::ScalarVariant data_type = dml::ir::operatorFieldTypes::ScalarVariant_NONE, + flatbuffers::Offset data = 0) { + ScalarUnionDataBuilder builder_(_fbb); + builder_.add_data(data); + builder_.add_data_type(data_type); + return builder_.Finish(); +} + +inline bool VerifyAttributeFieldVariant(flatbuffers::Verifier &verifier, const void *obj, AttributeFieldVariant type) { + switch (type) { + case AttributeFieldVariant_NONE: { + return true; + } + case AttributeFieldVariant_Activation: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_ActivationArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_UInt32: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_UInt64: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_Int32: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_Float32: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_UIntArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_IntArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_FloatArray: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_ScaleBias: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_Size2D: { + return verifier.Verify(static_cast(obj), 0); + } + case AttributeFieldVariant_ScalarUnionData: { + auto ptr = reinterpret_cast(obj); + return verifier.VerifyTable(ptr); + } + case AttributeFieldVariant_Bool: { + return verifier.Verify(static_cast(obj), 0); + } + default: return true; + } +} + +inline bool VerifyAttributeFieldVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyAttributeFieldVariant( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +inline bool VerifyScalarVariant(flatbuffers::Verifier &verifier, const void *obj, ScalarVariant type) { + switch (type) { + case ScalarVariant_NONE: { + return true; + } + case ScalarVariant_ByteArray: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int8: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt8: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int16: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt16: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int32: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt32: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Int64: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_UInt64: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Float32: { + return verifier.Verify(static_cast(obj), 0); + } + case ScalarVariant_Float64: { + return verifier.Verify(static_cast(obj), 0); + } + default: return true; + } +} + +inline bool VerifyScalarVariantVector(flatbuffers::Verifier &verifier, const flatbuffers::Vector> *values, const flatbuffers::Vector *types) { + if (!values || !types) return !values && !types; + if (values->size() != types->size()) return false; + for (flatbuffers::uoffset_t i = 0; i < values->size(); ++i) { + if (!VerifyScalarVariant( + verifier, values->Get(i), types->GetEnum(i))) { + return false; + } + } + return true; +} + +} // namespace operatorFieldTypes +} // namespace ir +} // namespace dml + +#endif // FLATBUFFERS_GENERATED_OPERATORFIELDTYPES_DML_IR_OPERATORFIELDTYPES_H_ diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h index 5285481485184..1bc694dfe90c2 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/External/DirectMLHelpers/SchemaHelpers.h @@ -26,14 +26,14 @@ namespace SchemaHelpers return field; } - inline OperatorFieldTypes::OperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) + inline OperatorFieldTypes::FusedActivationOperatorDesc ToOperatorFieldType(const DML_OPERATOR_DESC* value) { - return value ? OperatorFieldTypes::OperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; + return value ? OperatorFieldTypes::FusedActivationOperatorDesc(ConvertOperatorDesc(*value)) : std::nullopt; } - inline OperatorFieldTypes::OperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) + inline OperatorFieldTypes::FusedActivationOperatorDescArray ToOperatorFieldType(const DML_OPERATOR_DESC* values, uint32_t count) { - OperatorFieldTypes::OperatorDescArray field; + OperatorFieldTypes::FusedActivationOperatorDescArray field; if (values && count != 0) { field.emplace(count); @@ -65,13 +65,17 @@ namespace SchemaHelpers return value; } + inline OperatorFieldTypes::Bool ToOperatorFieldType(bool value) + { + return value; + } + inline OperatorFieldTypes::UIntArray ToOperatorFieldType(const uint32_t* values, uint32_t count) { OperatorFieldTypes::UIntArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -81,8 +85,7 @@ namespace SchemaHelpers OperatorFieldTypes::IntArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -92,8 +95,7 @@ namespace SchemaHelpers OperatorFieldTypes::FloatArray field; if (values && count != 0) { - field.emplace(count); - std::copy_n(values, count, field->begin()); + field.assign(values, values + count); } return field; } @@ -237,7 +239,7 @@ namespace SchemaHelpers { DML_OPERATOR_DESC* desc = nullptr; - const auto& value = field.AsOperatorDesc(); + const auto& value = field.AsFusedActivationOperatorDesc(); if (value) { desc = allocator->template Allocate(); @@ -251,7 +253,7 @@ namespace SchemaHelpers { DML_OPERATOR_DESC* descs = nullptr; - const auto& values = field.AsOperatorDescArray(); + const auto& values = field.AsFusedActivationOperatorDescArray(); if (values) { descs = allocator->template Allocate(values->size()); @@ -288,16 +290,20 @@ namespace SchemaHelpers dst->Write(value); } break; + case DML_SCHEMA_FIELD_TYPE_BOOL: + { + // OperatorFieldTypes::Bool is a 'bool' (1 byte) but written as 'BOOL' in op descs (4 bytes). + BOOL value = static_cast(field.AsBool()); + dst->Write(value); + } break; + case DML_SCHEMA_FIELD_TYPE_UINT_ARRAY: { uint32_t* arrayPtr = nullptr; const auto& values = field.AsUIntArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; @@ -307,11 +313,8 @@ namespace SchemaHelpers int32_t* arrayPtr = nullptr; const auto& values = field.AsIntArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; @@ -321,11 +324,8 @@ namespace SchemaHelpers float* arrayPtr = nullptr; const auto& values = field.AsFloatArray(); - if (values) - { - arrayPtr = allocator->template Allocate(values->size()); - std::copy(values->begin(), values->end(), arrayPtr); - } + arrayPtr = allocator->template Allocate(values.size()); + std::copy(values.begin(), values.end(), arrayPtr); dst->Write(arrayPtr); } break; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 2456b396de3f6..e6f008af5c23f 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -33,10 +33,10 @@ namespace Dml::GraphDescBuilder #pragma warning(pop) static void RemoveUnconnectedNodes( - std::vector& graphNodes, - std::vector& graphInputEdges, - std::vector& graphIntermediateEdges, - std::vector& graphOutputEdges) + std::vector& graphNodes, + std::vector& graphInputEdges, + std::vector& graphIntermediateEdges, + std::vector& graphOutputEdges) { enum class NodeState { @@ -52,7 +52,7 @@ namespace Dml::GraphDescBuilder }; std::vector nodesData(graphNodes.size()); - for (const DML_INTERMEDIATE_GRAPH_EDGE_DESC& intermediateEdge : graphIntermediateEdges) + for (const DmlIntermediateSerializedGraphEdge& intermediateEdge : graphIntermediateEdges) { nodesData[intermediateEdge.ToNodeIndex].predecessorIndices.push_back(intermediateEdge.FromNodeIndex); } @@ -60,7 +60,7 @@ namespace Dml::GraphDescBuilder std::stack nodeIndicesToVisit; // Start from the outputs of the graph and traverse upwards - for (const DML_OUTPUT_GRAPH_EDGE_DESC& outputEdge : graphOutputEdges) + for (const DmlOutputSerializedGraphEdge& outputEdge : graphOutputEdges) { nodeIndicesToVisit.push(outputEdge.FromNodeIndex); } @@ -143,17 +143,44 @@ namespace Dml::GraphDescBuilder } } + + uint32_t SetAndGetDmlGraphNodeIndex( + const uint32_t operatorDmlGraphNodeIndex, + const std::string& nodeNamePrefix, + AbstractOperatorDesc& operatorDesc, + /*in_out*/std::unordered_map& operatorDmlGraphToDmlGraphNodeIndexMap, + /*in_out*/std::vector& dmlGraphNodes) + { + auto iter = operatorDmlGraphToDmlGraphNodeIndexMap.find(operatorDmlGraphNodeIndex); + if (iter != operatorDmlGraphToDmlGraphNodeIndexMap.end()) + { + return iter->second; + } + operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex] = static_cast(dmlGraphNodes.size()); + dmlGraphNodes.push_back({operatorDesc, nodeNamePrefix + std::to_string(operatorDmlGraphNodeIndex)}); + return operatorDmlGraphToDmlGraphNodeIndexMap[operatorDmlGraphNodeIndex]; + } + + // Terminology: + // Subgraph: partitioned ONNX graph from the original (main) ONNX graph + // DmlGraph: a graph in DML currency converted from subgraph. + // operatorDmlGraph: a graph in DML currency for a given node or operator + // Main Points to note: + // - GraphDesc will always has sequential indices for input and intermediate edges. + // - 1 onnx node can be converted to one or more dml nodes. GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, - IDMLDevice* device, const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, - gsl::span subgraphOutputs) + gsl::span subgraphOutputs, + /*out*/ std::unordered_map& serializedGraphInputIndexToSubgraphInputIndex, + /*out*/ std::unordered_map& serializedGraphLargeConstantNameToSubgraphInputIndex, + /*out*/ std::vector>& smallConstantData) { struct NodeAndIndex { @@ -161,19 +188,34 @@ namespace Dml::GraphDescBuilder uint32_t targetIndex; // The index of the input/output on the node (e.g. 1 for the second input on a node) }; - // Map from Lotus node argument names to the new node and index where it will be produced - std::unordered_map nameToNodeAndIndexMap; - std::unordered_map nodeOutputShapes; - // Map from Lotus node argument names to input indices of the fused kernel node. - std::unordered_map nameToDmlFusedNodeInputIndex; + // Map from ORT subgraph input names to indices + std::unordered_map subgraphInputNameToIndexMap; + + // - Map from ORT node's output names to DmlGraph . + // - Once a given ORT node (or operator) will be transformed into a operatorDmlGraph, + // then ORT node's output names will become output edges for the operatorDmlGraph. + // - This map will be populated for those output edges. + std::unordered_map dmlGraphNodeOutputNameToNodeAndIndexMap; + + // This map will be used to re-index an subGraphInputIndex to sequential input index + // for DmlGraph + std::unordered_map subGraphInputIndexToDmlGraphInputIndex; + + // Iterate through each node and create a corresponding node in the new graph + // We can iterate the nodes in any order because the edge connectivity will take care of the topological order + std::unordered_map> inferredOutputShapes; + + std::vector dmlGraphNodes; + std::vector dmlGraphInputEdges; + std::vector dmlGraphIntermediateEdges; + std::vector dmlGraphOutputEdges; for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; - - if (!graphInput) + const onnxruntime::NodeArg* subgraphInput = subgraphInputs[inputIndex]; + if (!subgraphInput) { // This is a workaround for when node inputs get manipulated by transformers outside of our control, // which then causes them to have a different name. If that happens we can't figure out how to @@ -181,45 +223,21 @@ namespace Dml::GraphDescBuilder // just bail early. ORT_THROW_HR(E_UNEXPECTED); } - - nameToDmlFusedNodeInputIndex.emplace(graphInput->Name(), gsl::narrow_cast(inputIndex)); - } - - StackAllocator<1024> allocator; // Used for converting abstract operator descs into DML_OPERATOR_DESC - - std::vector graphNodes; - std::vector graphInputEdges; - std::vector graphIntermediateEdges; - std::vector graphOutputEdges; - - // Avoid using separate command lists for small graphs. This value can be reduced by tuning the - // flushing behavior of DmlCommandRecorder. Its current behavior is to assume that graphs contain - // enough GPU work to be worth flushing immediately. - const uint32_t minNodeCountToReuseCommandList = 5; - bool reuseCommandList = false; - - if (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()) - { - reuseCommandList = true; + subgraphInputNameToIndexMap.emplace(subgraphInput->Name(), gsl::narrow_cast(inputIndex)); } auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; - auto iter = isInitializerTransferable.find(argName); if (iter != isInitializerTransferable.end()) { // Using const_cast here is simpler than making surrounding code const correct. tensorWrapper = wil::MakeOrThrow(const_cast(iter->second.first), modelPath); } - return tensorWrapper; }; - // Iterate through each node and create a corresponding node in the new graph - // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - std::unordered_map> inferredOutputShapes; for (const onnxruntime::Node* subgraphNode : subgraphNodes) { @@ -277,195 +295,206 @@ namespace Dml::GraphDescBuilder } EdgeShapes outputShapes; - DmlGraphNodeCreateInfo graphNodeCreateInfo; + DmlGraphNodeCreateInfo operatorDmlGraphCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, &inputShapesOverrides, /*out*/ &outputShapes, - /*out*/ &graphNodeCreateInfo + /*out*/ &operatorDmlGraphCreateInfo ); ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); for (int i = 0; i < node.OutputDefs().size(); ++i) { inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); - } - - // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. - std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; - uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); - const bool isNodeAsOpDesc = graphNodeCreateInfo.nodesAsOperatorDesc.size() > 0; - size_t firstOpDescGraphNodeIndex = graphNodes.size(); - - if (isNodeAsOpDesc) + } + + // Algorithm: + // 1. Create constant nodes by iterating through operatorDmlGraph's input edges and keep a map of it, + // because there would be an intermediate edge from the constantNode and source of the intermediate edge + // should come before the destination. + // 2. Again iterate through operatorDmlGraph's input edges to create mainGraph's input and intermediate edges. + // 3. Iterate through operatorDmlGraph's intermediate edges to create mainGraph's intermediate edges. + // 4. Iterate through operatorDmlGraph's output edges to populate outputEdgeNameToDmlGraphNodeAndIndex + // 5. While performing step 2, 3, and 4, insert operatorDmlGraphNode to the mainDmlGraphNode list. + + for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges) { - // Can't populate graphNodes vector at this point, because operatorDesc may get modified later. - for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++) + const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex]; + if (arg->Exists()) { - ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsOperatorDesc[nodeIndex]); - operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); - } + auto iter = subgraphInputNameToIndexMap.find(arg->Name()); + if (iter != subgraphInputNameToIndexMap.end() && + iter->second < isConstGpuGraphInputCount && + isConstGpuGraphInput[iter->second]) + { + DmlSerializedGraphNode constantNode = {}; + constantNode.Name = arg->Name(); + + // This is a highly inefficient approach to generating constant nodes. It duplicates constant data + // across the graph input as well as every consumer's unique constant node. However it is currently + // only used for small inputs. + auto& operatorDmlGraphInputNode = operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex]; + std::vector toNodeInputTensorDescs = operatorDmlGraphInputNode->GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex]; + ComPtr constantInput; + + if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) + { + constantInput = constantCpuGraphInputGetter(arg->Name()); + } - graphNodes.resize(graphNodes.size() + graphNodeCreateInfo.nodeCount); - } - else - { - for (uint32_t nodeIndex = 0; nodeIndex < graphNodeCreateInfo.nodeCount; nodeIndex++) - { - ORT_THROW_HR_IF(E_UNEXPECTED, !graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex].Get()); - operatorGraphNodeIndexToMainGraphNodeIndexMap.emplace(nodeIndex, graphNodeCount++); - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(graphNodeCreateInfo.nodesAsIDMLOperator[nodeIndex]); - graphNodes.push_back(std::move(nodeInfo)); + if (constantInput) + { + // The tensor description's size should be no larger than the constant input unless it was rounded to + // the required alignment. + assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); + size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); + auto data = static_cast(constantInput->GetData()); + std::vector tensorData(data, data + minimumConstantSize); + + smallConstantData.push_back(std::make_unique(tensorData.size())); + std::transform(tensorData.begin(), tensorData.end(), smallConstantData.back().get(), [](uint8_t b) {return static_cast(b);}); + + ConstantData constantData = {smallConstantData.back().get(), tensorData.size()}; + constantNode.Desc = constantData; + } + else + { + ConstantName constantFileName = {GetSanitizedFileName(arg->Name())}; + constantNode.Desc = constantFileName; + } + dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {static_cast(dmlGraphNodes.size()), 0}; + dmlGraphNodes.push_back(constantNode); + } } } - // map operatorGraphInputEdge as either mainGraphInputEdge or mainGraphIntermediateEdge - for (auto& operatorGraphInputEdge : graphNodeCreateInfo.inputEdges) - { - // operatorGraphInputEdge.GraphInputIndex will be the ONNX input index. - const onnxruntime::NodeArg* arg = node.InputDefs()[operatorGraphInputEdge.GraphInputIndex]; + // Create a map between operatorGraphNodeIndex to dmlGraphNodeIndex. + std::unordered_map operatorDmlGraphToDmlGraphNodeIndexMap; + // map operatorDmlGraphInputEdge as either mainDmlGraphInputEdge or mainDmlGraphIntermediateEdge + for (auto& operatorDmlGraphInputEdge : operatorDmlGraphCreateInfo.inputEdges) + { + // operatorDmlGraphInputEdge.GraphInputIndex will be the ONNX input index. + const onnxruntime::NodeArg* arg = node.InputDefs()[operatorDmlGraphInputEdge.GraphInputIndex]; if (arg->Exists()) { - auto iter = nameToDmlFusedNodeInputIndex.find(arg->Name()); - uint32_t mainGraphNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphInputEdge.ToNodeIndex]; - - if (iter != nameToDmlFusedNodeInputIndex.end()) + uint32_t dmlGraphNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorDmlGraphInputEdge.ToNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorDmlGraphInputEdge.ToNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + + auto iter = subgraphInputNameToIndexMap.find(arg->Name()); + if (iter != subgraphInputNameToIndexMap.end()) { - // This is a graph input - - const uint32_t dmlFusedNodeInputIndex = iter->second; - - // If this is a constant input, set the appropriate flags on the desc - if (isNodeAsOpDesc && - dmlFusedNodeInputIndex < isConstGpuGraphInputCount && - isConstGpuGraphInput[dmlFusedNodeInputIndex]) + const uint32_t subgraphInputIndex = iter->second; + + // Either this edge will be + // a constant input, then it will be an intermediate edge and + // set the OWNED_BY_DML flag if it is large constant + // or, + // a non-constant input, then it will be a mainDmlGraphInputEdge. + if (subgraphInputIndex < isConstGpuGraphInputCount && + isConstGpuGraphInput[subgraphInputIndex]) { - // This is a highly inefficient approach to generating constant nodes. It duplicates constant data - // across the graph input as well as every consumer's unique constant node. However it is currently - // only used for small inputs. - uint32_t c_maxConstNodeDataSize = 8; - - - auto& operatorGraphInputNode = graphNodeCreateInfo.nodesAsOperatorDesc[operatorGraphInputEdge.ToNodeIndex]; - std::vector toNodeInputTensorDescs = operatorGraphInputNode->GetInputTensors(); - DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorGraphInputEdge.ToNodeInputIndex]; - ComPtr constantInput; - - if (tensorDesc->totalTensorSizeInBytes < c_maxConstNodeDataSize) - { - constantInput = constantCpuGraphInputGetter(arg->Name()); - } - - if (constantInput) - { - // The tensor description's size should be no larger than the constant input unless it was rounded to - // the required alignment. - assert(((constantInput->GetTensorByteSize() + 3) & ~3) >= tensorDesc->totalTensorSizeInBytes); - size_t minimumConstantSize = std::min(constantInput->GetTensorByteSize(), gsl::narrow_cast(tensorDesc->totalTensorSizeInBytes)); - auto data = static_cast(constantInput->GetData()); - std::vector tensorData(data, data + minimumConstantSize); - - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(tensorData); - graphNodes.push_back(std::move(nodeInfo)); - - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; - edge.FromNodeIndex = static_cast(graphNodes.size() - 1); - edge.FromNodeOutputIndex = 0; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); - } - else + const auto& constantNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name()); + auto& constantNodeVariant = std::get(dmlGraphNodes[constantNodeAndIndex.nodeIndex].Desc); + if (std::holds_alternative(constantNodeVariant)) { - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphInputEdges.push_back(edge); - + auto& mainDmlGraphNode = dmlGraphNodes[dmlGraphNodeIndex]; + AbstractOperatorDesc& abstractOperatorDesc = std::get(mainDmlGraphNode.Desc); + std::vector toNodeInputTensorDescs = abstractOperatorDesc.GetInputTensors(); + DmlBufferTensorDesc* tensorDesc = toNodeInputTensorDescs[operatorDmlGraphInputEdge.ToNodeInputIndex]; tensorDesc->flags |= DML_TENSOR_FLAG_OWNED_BY_DML; + serializedGraphLargeConstantNameToSubgraphInputIndex[arg->Name()] = subgraphInputIndex; } + + DmlIntermediateSerializedGraphEdge edge = {}; + edge.FromNodeIndex = constantNodeAndIndex.nodeIndex; + edge.FromNodeOutputIndex = constantNodeAndIndex.targetIndex; + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; + edge.Name = arg->Name() + "-nodeIdx:" + std::to_string(edge.FromNodeIndex) + "-outputIdx:" + std::to_string(edge.FromNodeOutputIndex); + dmlGraphIntermediateEdges.push_back(edge); } else { - DML_INPUT_GRAPH_EDGE_DESC edge = {}; - edge.GraphInputIndex = dmlFusedNodeInputIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphInputEdges.push_back(edge); + DmlInputSerializedGraphEdge edge = {}; + if (subGraphInputIndexToDmlGraphInputIndex.find(subgraphInputIndex) == subGraphInputIndexToDmlGraphInputIndex.end()) + { + subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex] = static_cast(subGraphInputIndexToDmlGraphInputIndex.size()); + } + + edge.GraphInputIndex = subGraphInputIndexToDmlGraphInputIndex[subgraphInputIndex]; + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; // ?? might need to point inputIndex + edge.Name = arg->Name(); + + serializedGraphInputIndexToSubgraphInputIndex[edge.GraphInputIndex] = subgraphInputIndex; + dmlGraphInputEdges.push_back(edge); } } else { - const auto& inputNodeAndIndex = nameToNodeAndIndexMap.at(arg->Name()); + const auto& inputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(arg->Name()); - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; + DmlIntermediateSerializedGraphEdge edge = {}; edge.FromNodeIndex = inputNodeAndIndex.nodeIndex; edge.FromNodeOutputIndex = inputNodeAndIndex.targetIndex; - edge.ToNodeIndex = mainGraphNodeIndex; - edge.ToNodeInputIndex = operatorGraphInputEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); + edge.ToNodeIndex = dmlGraphNodeIndex; + edge.ToNodeInputIndex = operatorDmlGraphInputEdge.ToNodeInputIndex; + edge.Name = arg->Name(); + dmlGraphIntermediateEdges.push_back(edge); } } } // map operatorGraphIntermediateEdges as mainGraphIntermediateEdge - for (auto& operatorGraphIntermediateEdge : graphNodeCreateInfo.intermediateEdges) + for (auto& operatorGraphIntermediateEdge : operatorDmlGraphCreateInfo.intermediateEdges) { - DML_INTERMEDIATE_GRAPH_EDGE_DESC edge = {}; - edge.FromNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.FromNodeIndex]; + DmlIntermediateSerializedGraphEdge edge = {}; + uint32_t shiftedFromNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphIntermediateEdge.FromNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.FromNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + uint32_t shiftedToNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphIntermediateEdge.ToNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphIntermediateEdge.ToNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + + edge.FromNodeIndex = shiftedFromNodeIndex; edge.FromNodeOutputIndex = operatorGraphIntermediateEdge.FromNodeOutputIndex; - edge.ToNodeIndex = operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphIntermediateEdge.ToNodeIndex]; + edge.ToNodeIndex = shiftedToNodeIndex; edge.ToNodeInputIndex = operatorGraphIntermediateEdge.ToNodeInputIndex; - graphIntermediateEdges.push_back(edge); + edge.Name = "nodeIdx:" + std::to_string(shiftedFromNodeIndex) + "-outputIdx:" + std::to_string(operatorGraphIntermediateEdge.FromNodeOutputIndex); + dmlGraphIntermediateEdges.push_back(edge); } - + // populate nameToNodeAndIndexMap (which will be used by above loop) for operatorGraphOutputEdges - for (auto& operatorGraphOutputEdge : graphNodeCreateInfo.outputEdges) + for (auto& operatorGraphOutputEdge : operatorDmlGraphCreateInfo.outputEdges) { const onnxruntime::NodeArg* arg = node.OutputDefs()[operatorGraphOutputEdge.GraphOutputIndex]; if (arg->Exists()) { - nameToNodeAndIndexMap[arg->Name()] = NodeAndIndex { - operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], - operatorGraphOutputEdge.FromNodeOutputIndex - }; - + uint32_t shiftedNodeIndex = SetAndGetDmlGraphNodeIndex( + operatorGraphOutputEdge.FromNodeIndex, + node.Name(), + *operatorDmlGraphCreateInfo.nodes[operatorGraphOutputEdge.FromNodeIndex], + operatorDmlGraphToDmlGraphNodeIndexMap, + dmlGraphNodes); + dmlGraphNodeOutputNameToNodeAndIndexMap[arg->Name()] = {shiftedNodeIndex, operatorGraphOutputEdge.FromNodeOutputIndex}; nodeOutputShapes[arg->Name()] = outputShapes; } } - - if (isNodeAsOpDesc) - { - for (size_t i = 0; i < graphNodeCreateInfo.nodesAsOperatorDesc.size(); ++i) - { - auto& opDesc = graphNodeCreateInfo.nodesAsOperatorDesc[i]; - - DML_OPERATOR_DESC dmlDesc = SchemaHelpers::ConvertOperatorDesc(*opDesc, &allocator); - - // TODO: Change as new header is ingested - if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_QUANTIZED_LINEAR_AVERAGE_POOLING) - dmlDesc.Type = (DML_OPERATOR_TYPE) 169; - - // TODO: Change as new header is ingested - if (dmlDesc.Type == (DML_OPERATOR_TYPE) DML_OPERATOR_MATRIX_MULTIPLY_INTEGER_TO_FLOAT) - dmlDesc.Type = (DML_OPERATOR_TYPE) 170; - - ComPtr op; - ORT_THROW_IF_FAILED(device->CreateOperator(&dmlDesc, IID_PPV_ARGS(&op))); - allocator.Reset(); - - NodeInfo nodeInfo = {}; - nodeInfo.nodeDef = std::move(op); - nodeInfo.name = node.Name(); - graphNodes[firstOpDescGraphNodeIndex + i] = std::move(nodeInfo); - } - } } EdgeShapes graphOutputShapes(subgraphOutputs.size()); @@ -476,24 +505,27 @@ namespace Dml::GraphDescBuilder const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); - const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); + const auto& outputNodeAndIndex = dmlGraphNodeOutputNameToNodeAndIndexMap.at(graphOutput->Name()); - DML_OUTPUT_GRAPH_EDGE_DESC edge = {}; + DmlOutputSerializedGraphEdge edge = {}; edge.FromNodeIndex = outputNodeAndIndex.nodeIndex; edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); - graphOutputEdges.push_back(edge); + edge.Name = graphOutput->Name(); + dmlGraphOutputEdges.push_back(edge); graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } - RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); + RemoveUnconnectedNodes(dmlGraphNodes, dmlGraphInputEdges, dmlGraphIntermediateEdges, dmlGraphOutputEdges); GraphDesc graphDesc{}; - graphDesc.nodes = std::move(graphNodes); - graphDesc.inputEdges = std::move(graphInputEdges); - graphDesc.outputEdges = std::move(graphOutputEdges); - graphDesc.intermediateEdges = std::move(graphIntermediateEdges); - graphDesc.reuseCommandList = reuseCommandList; + graphDesc.InputCount = static_cast(dmlGraphInputEdges.size()); + graphDesc.OutputCount = static_cast(subgraphOutputs.size()); + graphDesc.Nodes = std::move(dmlGraphNodes); + graphDesc.InputEdges = std::move(dmlGraphInputEdges); + graphDesc.OutputEdges = std::move(dmlGraphOutputEdges); + graphDesc.IntermediateEdges = std::move(dmlGraphIntermediateEdges); + graphDesc.reuseCommandList = (subgraphNodes.size() >= minNodeCountToReuseCommandList || executionHandle->IsMcdmDevice()); graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index c95e89b45541b..4055984b40405 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -22,22 +22,15 @@ namespace Dml namespace GraphDescBuilder { + constexpr uint32_t minNodeCountToReuseCommandList = 5; + constexpr uint32_t c_maxConstNodeDataSize = 8; + // Gets a unique name for the node which survives recreation and graph manipulations between the point // that graph partitioning occurs and kernel creation happens const std::string& GetUniqueNodeName(const onnxruntime::Node& node); - struct NodeInfo - { - std::variant, std::vector> nodeDef; - std::string name; - }; - - struct GraphDesc + struct GraphDesc : DmlSerializedGraphDesc { - std::vector nodes; - std::vector inputEdges; - std::vector outputEdges; - std::vector intermediateEdges; bool reuseCommandList; Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; @@ -47,11 +40,13 @@ namespace Dml const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, const std::unordered_map& graphNodePropertyMap, - IDMLDevice* device, const ExecutionProviderImpl* executionHandle, const onnxruntime::Path& modelPath, gsl::span subgraphNodes, gsl::span subgraphInputs, - gsl::span subgraphOutputs); + gsl::span subgraphOutputs, + /*out*/ std::unordered_map& serializedGraphInputIndexToSubgraphInputIndex, + /*out*/ std::unordered_map& serializedGraphLargeConstantNameToSubgraphInputIndex, + /*out*/ std::vector>& smallConstantData); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index d524780de71b8..f29fbc7a1a65b 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1508,31 +1508,17 @@ namespace Windows::AI::MachineLearning::Adapter ORT_TRY { assert(operatorGraphDesc != nullptr); - // Either nodesAsOpDesc or nodesIDMLOperator can be present. - assert(operatorGraphDesc->nodeCount == 0 || (!operatorGraphDesc->nodesAsOpDesc ^ !operatorGraphDesc->nodesAsIDMLOperator)); + assert(operatorGraphDesc->nodeCount == 0 || operatorGraphDesc->nodes); - if (operatorGraphDesc->nodesAsOpDesc) + m_graphNodeCreateInfo->nodes = std::vector>(); + for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) { - m_graphNodeCreateInfo->nodesAsOperatorDesc = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) - { - auto* node = operatorGraphDesc->nodesAsOpDesc[nodeIndex]; - assert(node != nullptr); - AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node); - m_graphNodeCreateInfo->nodesAsOperatorDesc.push_back(std::make_unique(std::move(abstractDesc))); - } - } - else - { - m_graphNodeCreateInfo->nodesAsIDMLOperator = std::vector>(); - for (uint32_t nodeIndex = 0; nodeIndex < operatorGraphDesc->nodeCount; nodeIndex++) - { - auto* node = operatorGraphDesc->nodesAsIDMLOperator[nodeIndex]; - assert(node != nullptr); - m_graphNodeCreateInfo->nodesAsIDMLOperator.push_back(node); - } + auto* node = operatorGraphDesc->nodes[nodeIndex]; + assert(node != nullptr); + AbstractOperatorDesc abstractDesc = SchemaHelpers::ConvertOperatorDesc(*node); + m_graphNodeCreateInfo->nodes.push_back(std::make_unique(std::move(abstractDesc))); } - + // There can be operators (or kernels) which don't require any input. assert(operatorGraphDesc->inputEdgeCount == 0 || operatorGraphDesc->inputEdges != nullptr); m_graphNodeCreateInfo->inputEdges.insert( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp index c3bb1a52210f5..287f1e5b6dfe7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperator.cpp @@ -53,7 +53,7 @@ namespace Dml MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = 1; const DML_OPERATOR_DESC* opDescs{&operatorDesc}; - operatorGraphDesc.nodesAsOpDesc = &opDescs; + operatorGraphDesc.nodes = &opDescs; std::vector inputEdges; for (uint32_t inputIndex = 0; inputIndex < m_kernelInputIndices.size(); inputIndex++) @@ -796,7 +796,7 @@ namespace Dml for (size_t i = 0; i < graphDesc.NodeCount; ++i) { // Create the operator. - ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodesAsOpDesc[i], IID_PPV_ARGS(&dmlOperators[i]))); + ORT_THROW_IF_FAILED(m_dmlDevice->CreateOperator(operatorGraphDesc.nodes[i], IID_PPV_ARGS(&dmlOperators[i]))); dmlOperatorGraphNodes[i] = DML_OPERATOR_GRAPH_NODE_DESC{dmlOperators[i].Get()}; dmlGraphNodes[i] = DML_GRAPH_NODE_DESC{DML_GRAPH_NODE_TYPE_OPERATOR, &dmlOperatorGraphNodes[i]}; } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp index c8ca6806e75f7..73c2d57e984af 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorAttention.cpp @@ -531,7 +531,7 @@ class DmlOperatorAttention : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp index 1c851c94c4ddc..5aceebbdabfe3 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasAdd.cpp @@ -103,7 +103,7 @@ class DmlOperatorBiasAdd : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp index 501ce14f1fc08..1e10214ffd463 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorBiasSplitGelu.cpp @@ -137,7 +137,7 @@ class DmlOperatorBiasSplitGelu : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp index 6a8333cd72561..3c9458658c4d0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorEmbedLayerNormalization.cpp @@ -484,7 +484,7 @@ class DmlOperatorEmbedLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp index fed0e4645ffd8..8b275fc550f3e 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorGroupNorm.cpp @@ -287,7 +287,7 @@ class DmlOperatorGroupNorm : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp index 5c64059f7caa9..80e6fefc2fb80 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorLayerNormalization.cpp @@ -247,7 +247,7 @@ class DmlOperatorLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp index c97b03dc36b62..8727610ff3112 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearConcat.cpp @@ -166,7 +166,7 @@ class DmlOperatorQLinearConcat : public DmlOperator, public QLinearConcatHelper MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = static_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); uint32_t joinNodeIndex = operatorGraphDesc.nodeCount - 2; uint32_t quantizeNodeIndex = operatorGraphDesc.nodeCount - 1; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp index 35f926d62c92a..f658e7c7da323 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQLinearSigmoid.cpp @@ -113,7 +113,7 @@ class DmlOperatorQLinearSigmoid : public DmlOperator MLOperatorGraphDesc operatorGraphDesc = {}; operatorGraphDesc.nodeCount = 3; std::vector opDescs{&opDesc1, &opDesc2, &opDesc3}; - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); // set input edges std::pair nodeToNodeInputIndex[5] {{0, 0}, {0, 1}, {0, 2}, {2, 1}, {2, 2}}; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp index 3683ab7b0b0b3..e62b7d707ba78 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorQuickGelu.cpp @@ -123,7 +123,7 @@ class DmlOperatorQuickGelu : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp index 44004b5d77f70..0f15ebf342b3a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRotaryEmbedding.cpp @@ -441,7 +441,7 @@ class DmlOperatorRotaryEmbedding : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelInfo); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp index 4dafd78f21ea8..094c45a0e38e5 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorSkipLayerNormalization.cpp @@ -198,7 +198,7 @@ class DmlOperatorSkipLayerNormalization : public DmlOperator operatorGraphDesc.outputEdgeCount = gsl::narrow_cast(outputEdges.size()); operatorGraphDesc.outputEdges = outputEdges.data(); operatorGraphDesc.nodeCount = gsl::narrow_cast(opDescs.size()); - operatorGraphDesc.nodesAsOpDesc = opDescs.data(); + operatorGraphDesc.nodes = opDescs.data(); SetDmlOperatorGraphDesc(std::move(operatorGraphDesc), kernelCreationContext); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h new file mode 100644 index 0000000000000..02166f992449e --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Utility.h @@ -0,0 +1,141 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include +#include + + +namespace Dml +{ + static inline std::wstring ConvertToWString(std::string_view str) + { + std::wstring_convert,wchar_t> g_converterToUtf16; + return g_converterToUtf16.from_bytes(str.data()); + } + + static inline std::wstring GetModelName(const onnxruntime::Path& modelPath) + { + if (modelPath.GetComponents().empty()) + { + return L""; + } + + const onnxruntime::PathString& pathString = modelPath.GetComponents().back(); + size_t dotPosition = pathString.find_last_of('.'); + if (dotPosition == std::string::npos) + { + return L""; + } + + return pathString.substr(0, dotPosition); + } + + static inline std::wstring GetSanitizedFileName(std::wstring_view name) + { + std::wstring newName(name); + for (wchar_t& c : newName) + { + switch (c) + { + case '\\': + case '/': + case '\"': + case '|': + case '<': + case '>': + case ':': + case '?': + case '*': + c = '_'; + break; + } + } + return newName; + } + + static inline std::string GetSanitizedFileName(std::string_view name) + { + std::string newName(name); + for (char& c : newName) + { + switch (c) + { + case '\\': + case '/': + case '\"': + case '|': + case '<': + case '>': + case ':': + case '?': + case '*': + c = '_'; + break; + } + } + return newName; + } + + static inline void WriteToFile(std::wstring_view directoryName, std::wstring_view fileName, std::uint8_t* data, size_t dataSize) + { + std::wstring sanitizedFileName = GetSanitizedFileName(fileName); + std::filesystem::create_directory(directoryName); + std::wstring fullSanitizedFileName = std::wstring(directoryName) + + (directoryName.empty() ? L"" : L"/") + + sanitizedFileName; + std::ofstream file(fullSanitizedFileName, std::ios::binary); + if (!file.is_open()) + { + std::wstring_convert,wchar_t> g_converterToUtf16; + std::stringstream errorMessage; + errorMessage << "File named: " << g_converterToUtf16.to_bytes(fileName.data()) << " could not be opened\n"; + throw std::ios::failure(errorMessage.str()); + } + file.write(reinterpret_cast(data), dataSize); + } + +} + +namespace StringUtil +{ + struct NameAndIndex + { + const char* name; // Null terminated. + uint32_t index; + }; + + struct WideNameAndIndex + { + const wchar_t* name; // Null terminated. + uint32_t index; + }; + + inline std::optional MapToIndex(std::string_view mode, gsl::span nameAndIndexList) + { + for (auto& nameAndIndex : nameAndIndexList) + { + if (strncmp(nameAndIndex.name, mode.data(), mode.size()) == 0) + { + return nameAndIndex.index; + } + } + + return {}; + } + + inline std::optional MapToIndex(std::wstring_view mode, gsl::span nameAndIndexList) + { + for (auto& nameAndIndex : nameAndIndexList) + { + if (wcsncmp(nameAndIndex.name, mode.data(), mode.size()) == 0) + { + return nameAndIndex.index; + } + } + + return {}; + } +} \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h index 83737d2ba4848..332bf86685e8a 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/precomp.h @@ -17,6 +17,8 @@ #include #include #include +#include +#include #include #include @@ -37,6 +39,7 @@ #include #include "External/D3DX12/d3dx12.h" #endif +#include "flatbuffers/flatbuffers.h" #include "GraphicsUnknownHelper.h" @@ -53,6 +56,9 @@ #include "External/DirectMLHelpers/SchemaHelpers.h" #include "External/DirectMLHelpers/GeneratedSchemaHelpers.h" #include "External/DirectMLHelpers/DirectMLX.h" +#include "External/DirectMLHelpers/DmlSerializedGraphDesc.h" +#include "External/DirectMLHelpers/DmlGraphSerialization.h" +#include "External/DirectMLHelpers/DmlGraphDeserialization.h" using Microsoft::WRL::ComPtr; @@ -67,3 +73,4 @@ using Microsoft::WRL::ComPtr; #include "TensorDesc.h" #include "DescriptorPool.h" #include "IExecutionProvider.h" +#include "Utility.h" \ No newline at end of file diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h index 3bec8d3864cba..ac3a3eb1268b8 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorPrivate.h @@ -10,18 +10,11 @@ struct DML_INPUT_GRAPH_EDGE_DESC; struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; -// Either nodesAsOpDesc or nodesAsIDMLOperator is present. -// 1) Operator kernels which implement operators using only a single DML operator will pass a DML_OPERATOR_DESC. -// These kernels pass DML_OPERATOR_DESC, because while building Dml graph (inside FusedGraphKernel.cpp) we can change the -// the flag of constant inputs to DML_TENSOR_FLAG_OWNED_BY_DML. -// 2) Operator kernels which implement operators using DMLX graph, they will pass IDMLOperator and won't be able -// to use DML_TENSOR_FLAG_OWNED_BY_DML. struct MLOperatorGraphDesc { uint32_t nodeCount; - _Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodesAsOpDesc; - _Field_size_opt_(nodeCount) IDMLOperator** nodesAsIDMLOperator; - + _Field_size_opt_(nodeCount) const DML_OPERATOR_DESC** nodes; + uint32_t inputEdgeCount; _Field_size_(inputEdgeCount) const DML_INPUT_GRAPH_EDGE_DESC* inputEdges; diff --git a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h index d11fa7516e713..5b5f371f51616 100644 --- a/onnxruntime/core/providers/dml/dml_session_options_config_keys.h +++ b/onnxruntime/core/providers/dml/dml_session_options_config_keys.h @@ -21,3 +21,4 @@ // "1": disabled (disallowed). Graph fusion will never be used. // The default value is "0" static const char* const kOrtSessionOptionsConfigDisableDmlGraphFusion = "ep.dml.disable_graph_fusion"; +static const char* const kOrtSessionOptionsConfigEnableGraphSerialization = "ep.dml.enable_graph_serialization"; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index efd7db4ea7629..5fd66c459d382 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1725,10 +1725,17 @@ common::Status InferenceSession::Initialize() { // graph optimization level and is generally always applied. bool dml_graph_fusion_enabled = session_options_.optimized_model_filepath.empty() && session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigDisableDmlGraphFusion, "0") == "0"; + std::string dml_graph_serialization_enabled_config_val = session_options_.config_options.GetConfigOrDefault(kOrtSessionOptionsConfigEnableGraphSerialization, "0"); + std::transform(dml_graph_serialization_enabled_config_val.begin(), + dml_graph_serialization_enabled_config_val.end(), + dml_graph_serialization_enabled_config_val.begin(), + [](char ch) { return std::tolower(ch); }); + bool dml_graph_serialization_enabled = dml_graph_serialization_enabled_config_val == "true"; if (dml_graph_fusion_enabled) { std::unique_ptr dmlGraphFusionTransformer = std::make_unique("DmlGraphFusionTransformer", - dmlExecutionProvider); + dmlExecutionProvider, + dml_graph_serialization_enabled); if (dmlGraphFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); } diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index 3874901f86387..7d4111e3b9c39 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -68,6 +68,7 @@ namespace perftest { "\t [DML only] [device_filter]: DML device filter, options: 'any', 'gpu', 'npu', \n" "\t [DML only] [disable_metacommands]: Options: 'true', 'false', \n" "\t [DML only] [enable_dynamic_graph_fusion]: Options: 'true', 'false', \n" + "\t [DML only] [enable_graph_serialization]: Options: 'true', 'false', \n" "\t [OpenVINO only] [device_type]: Overrides the accelerator hardware type and precision with these values at runtime.\n" "\t [OpenVINO only] [device_id]: Selects a particular hardware device for inference.\n" "\t [OpenVINO only] [enable_npu_fast_compile]: Optionally enabled to speeds up the model's compilation on NPU device targets.\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index 87506c7240578..1934314b8ce43 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -18,6 +18,7 @@ #ifdef USE_DML #include "core/providers/dml/dml_provider_factory.h" +#include "core/providers/dml/dml_session_options_config_keys.h" #endif #ifdef _WIN32 @@ -542,6 +543,15 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)"); "[ERROR] [DML] You have selcted wrong value for the key 'enable_dynamic_graph_fusion'. " "Select from 'true' or 'false' \n"); } + } else if (key == "enable_graph_serialization") { + std::set ov_supported_values = {"true", "True", "false", "False"}; + if (ov_supported_values.find(value) != ov_supported_values.end()) { + session_options.AddConfigEntry(kOrtSessionOptionsConfigEnableGraphSerialization, value.data()); + } else { + ORT_THROW( + "[ERROR] [DML] You have selcted wrong value for the key 'enable_graph_serialization'. " + "Select from 'true' or 'false' \n"); + } } } session_options.AppendExecutionProvider("DML", dml_options);