[DML EP] Add RotaryEmbedding #18158

Merged: 11 commits, Nov 7, 2023
1 change: 1 addition & 0 deletions docs/ContribOperators.md
@@ -2582,6 +2582,7 @@ This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
 Input B is stored as uint8_t with shape: [(N * K + 1) / 2].
 Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size].
+
 
 #### Version
 
 This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
1 change: 1 addition & 0 deletions docs/OperatorKernels.md
@@ -1246,6 +1246,7 @@ Do not modify directly.*
 |QLinearSigmoid|*in* X:**T**<br> *in* X_scale:**tensor(float)**<br> *in* X_zero_point:**T**<br> *in* Y_scale:**tensor(float)**<br> *in* Y_zero_point:**T**<br> *out* Y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
 |QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float), tensor(float16), tensor(int32)<br/> **T2** = tensor(int8), tensor(uint8)|
 |QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
+|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
 |SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
 | |
 | |
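
The new row above advertises com.microsoft RotaryEmbedding on the DML EP, with T constrained to float/float16 and M (position_ids) to int64. As a reference for what the op computes, here is a minimal sketch of the rotation applied to one (batch, head) slice; this is an illustration under an assumed row-major layout, with an invented function name, not the DML implementation:

```cpp
#include <cstdint>
#include <vector>

// Reference rotary embedding for a single (batch, head) slice.
// input:         [seq_len, head_size], row-major
// position_ids:  [seq_len] (int64), indices into the caches
// cos/sin cache: [max_seq_len, head_size / 2]
// interleaved=true  rotates adjacent pairs (x[2i], x[2i+1]);
// interleaved=false rotates split halves   (x[i],  x[i + head_size/2]).
std::vector<float> RotaryReference(const std::vector<float>& input,
                                   const std::vector<int64_t>& position_ids,
                                   const std::vector<float>& cos_cache,
                                   const std::vector<float>& sin_cache,
                                   int seq_len, int head_size, bool interleaved) {
  std::vector<float> output(input.size());
  const int half = head_size / 2;
  for (int s = 0; s < seq_len; ++s) {
    const int64_t pos = position_ids[s];
    for (int i = 0; i < half; ++i) {
      const float c = cos_cache[pos * half + i];
      const float sn = sin_cache[pos * half + i];
      const int i0 = interleaved ? 2 * i : i;             // first element of the pair
      const int i1 = interleaved ? 2 * i + 1 : i + half;  // second element of the pair
      const float x0 = input[s * head_size + i0];
      const float x1 = input[s * head_size + i1];
      output[s * head_size + i0] = x0 * c - x1 * sn;  // 2D rotation of the pair
      output[s * head_size + i1] = x0 * sn + x1 * c;
    }
  }
  return output;
}
```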

(Large diff for one file in this PR is not rendered.)

@@ -510,6 +510,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(BitwiseAnd);
 DML_OP_EXTERN_CREATION_FUNCTION(BitwiseOr);
 DML_OP_EXTERN_CREATION_FUNCTION(BitwiseXor);
 DML_OP_EXTERN_CREATION_FUNCTION(BitwiseNot);
+DML_OP_EXTERN_CREATION_FUNCTION(RotaryEmbedding);
 
 DML_OP_EXTERN_QUERY_FUNCTION(MaxPool);
 DML_OP_EXTERN_QUERY_FUNCTION(Slice);
@@ -527,6 +528,7 @@ DML_OP_EXTERN_QUERY_FUNCTION(Attention);
 constexpr static std::array<const char*, 1> typeNameListDefault = {"T"};
 constexpr static std::array<const char*, 1> typeNameListDefaultV = {"V"};
 constexpr static std::array<const char*, 2> typeNameListAttention = {"T", "M"};
+constexpr static std::array<const char*, 2> typeNameListRotaryEmbedding = {"T", "M"};
 constexpr static std::array<const char*, 2> typeNameListTwo = { "T1", "T2" };
 constexpr static std::array<const char*, 2> typeNameListLayerNorm = { "T", "U" };
 constexpr static std::array<const char*, 2> typeNameListLayerNormContrib = { "T", "V" };
@@ -597,6 +599,7 @@ constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListShape
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListSize = {SupportedTensorDataTypes::All, SupportedTensorDataTypes::Int64};
 constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListQLinearSigmoid = {SupportedTensorDataTypes::UInt8 | SupportedTensorDataTypes::Int8};
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListAttention = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int32};
+constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListRotaryEmbedding = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Int64};
 constexpr static std::array<SupportedTensorDataTypes, 2> supportedTypeListGroupNorm = {SupportedTensorDataTypes::Float16to32, SupportedTensorDataTypes::Float16to32};
 constexpr static std::array<SupportedTensorDataTypes, 1> supportedTypeListNonZero = {SupportedTensorDataTypes::Float16to32 | SupportedTensorDataTypes::Ints8Bit | SupportedTensorDataTypes::Ints16Bit | SupportedTensorDataTypes::Ints32Bit | SupportedTensorDataTypes::Bool};
 
@@ -1006,6 +1009,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation
 {REG_INFO_MS( 1, QLinearSigmoid, typeNameListDefault, supportedTypeListQLinearSigmoid, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryQLinearSigmoid)},
 {REG_INFO_MS( 1, Attention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported, requiredConstantCpuInputs(), std::nullopt, QueryAttention)},
 {REG_INFO_MS( 1, MultiHeadAttention, typeNameListAttention, supportedTypeListAttention, DmlGraphSupport::Supported)},
+{REG_INFO_MS( 1, RotaryEmbedding, typeNameListRotaryEmbedding, supportedTypeListRotaryEmbedding, DmlGraphSupport::Supported)},
 
 {REG_INFO( 10, IsInf, typeNameListTwo, supportedTypeListIsInf, DmlGraphSupport::Supported)},
 {REG_INFO( 10, Mod, typeNameListDefault, supportedTypeListNumericDefault, DmlGraphSupport::Supported)},
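
Together with the kernel-table row in docs/OperatorKernels.md, these entries are all the registration the DML EP needs: DML_OP_EXTERN_CREATION_FUNCTION(RotaryEmbedding) declares the kernel's creation entry point (implemented in the file whose large diff is not rendered above), typeNameListRotaryEmbedding names the schema's two type parameters ("T" for the float tensors, "M" for position_ids), supportedTypeListRotaryEmbedding constrains them to float16/float32 and int64, and DmlGraphSupport::Supported lets the node be fused into a compiled DML graph.
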
@@ -122,6 +122,7 @@ namespace AttrName
 
 static constexpr const char* GraphFusedActivation = "activation";
 static constexpr const char* GraphFusedAxis = "activation_axis";
+static constexpr const char* Interleaved = "interleaved";
 
 } // namespace AttrName
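
The new Interleaved constant exposes the op's interleaved attribute. With interleaved=1, each rotated pair is two adjacent elements (x[2i], x[2i+1]), the GPT-J-style layout; with the default interleaved=0, a pair spans the two halves of the head (x[i], x[i + head_size/2]), the GPT-NeoX/LLaMA-style layout. Both layouts appear in the reference sketch earlier.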

@@ -1584,6 +1584,7 @@ using ShapeInferenceHelper_DequantizeLinear = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_QLinearSigmoid = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_Attention = AttentionHelper;
 using ShapeInferenceHelper_MultiHeadAttention = MultiHeadAttentionHelper;
+using ShapeInferenceHelper_RotaryEmbedding = GetOutputShapeAsInputShapeHelper;
 using ShapeInferenceHelper_Sign = GetBroadcastedOutputShapeHelper;
 using ShapeInferenceHelper_IsNaN = GetBroadcastedOutputShapeHelper;
 using ShapeInferenceHelper_Erf = GetBroadcastedOutputShapeHelper;
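
Rotary embedding rotates pairs of values in place and never changes tensor dimensions, so the output shape always equals the input shape; the generic GetOutputShapeAsInputShapeHelper is sufficient and no dedicated shape-inference helper is needed.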
@@ -437,6 +437,7 @@ namespace OperatorHelper
 static const int sc_sinceVer_BiasAdd = 1;
 static const int sc_sinceVer_QuickGelu = 1;
 static const int sc_sinceVer_GroupNorm = 1;
+static const int sc_sinceVer_RotaryEmbedding = 1;
 } // namespace MsftOperatorSet1
 
 } // namespace OperatorHelper
25 changes: 17 additions & 8 deletions onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
@@ -25,7 +25,8 @@ static void RunTest(
     int64_t interleaved,
     bool use_float16,
     bool disable_cpu,
-    bool disable_cuda) {
+    bool disable_cuda,
+    bool disable_dml) {
   // input : (batch_size, sequence_length, hidden_size)
   // position ids : (1) or (batch_size, sequence_length)
   // cos cache : (max_sequence_length, head_size / 2)
@@ -50,9 +51,14 @@
 
   int min_cuda_architecture = use_float16 ? 530 : 0;
   bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
+  bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml;
+
   if (enable_cuda && !disable_cuda) {
     execution_providers.push_back(DefaultCudaExecutionProvider());
   }
+  if (enable_dml && !disable_dml) {
+    execution_providers.push_back(DefaultDmlExecutionProvider());
+  }
   if (!use_float16 && !disable_cpu) {
     execution_providers.push_back(DefaultCpuExecutionProvider());
   }
@@ -107,9 +113,10 @@ static void RunTests(const std::vector<float>& input_data,
           interleaved,
           false, /* use_fp16 */
           false, /* disable_cpu */
-          true /* disable_cuda */);
+          true, /* disable_cuda */
+          true /* disable_dml */);
 
-  // FP32 test for CUDA
+  // FP32 test for CUDA and DML
   RunTest(input_data,
           position_ids,
           cos_cache,
@@ -123,9 +130,10 @@
           interleaved,
           false, /* use_fp16 */
           false, /* disable_cpu */
-          false /* disable_cuda */);
+          false, /* disable_cuda */
+          false /* disable_dml */);
 
-  // FP16 test for CUDA
+  // FP16 test for CUDA and DML
   if (use_float16) {
     RunTest(input_data,
             position_ids,
@@ -138,9 +146,10 @@
             num_heads,
             max_sequence_length,
             interleaved,
-            true, /* use_fp16 */
-            true, /* disable_cpu */
-            false /* disable_cuda*/);
+            true, /* use_fp16 */
+            true, /* disable_cpu */
+            false, /* disable_cuda*/
+            false /* disable_dml */);
   }
 }

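With the disable_dml flag added above, the tests now exercise the kernel on CPU, CUDA, and DML, whichever providers are available at run time. Outside the test harness, running the new kernel only takes appending the DML execution provider to the session options before loading a model that contains a com.microsoft RotaryEmbedding node. A minimal sketch using the public C++ API (the model path is hypothetical):

```cpp
#include <onnxruntime_cxx_api.h>
#include "dml_provider_factory.h"  // OrtSessionOptionsAppendExecutionProvider_DML

int main() {
  Ort::Env env{ORT_LOGGING_LEVEL_WARNING, "rotary_dml"};
  Ort::SessionOptions options;
  // Place supported nodes, including RotaryEmbedding, on DirectML device 0;
  // unsupported nodes fall back to the CPU provider.
  Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(options, 0));
  Ort::Session session{env, L"model_with_rotary_embedding.onnx", options};  // hypothetical model
  // ... bind input, position_ids, cos_cache, sin_cache and run ...
  return 0;
}
```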