diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp index 88b827f61f0c9..65568725b0e60 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorRecurrentNeuralNetwork.cpp @@ -123,58 +123,58 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper for (size_t i = 0; i < activations.size(); ++i) { const std::string& activationName = activations[i]; DML_OPERATOR_DESC& desc = descs[i]; ActivationOperatorDescUnion& activationDesc = m_activationDescs[i]; desc.Desc = &activationDesc; - - if (activationName == AttrValue::ActivationRelu) + + if (ActivationNameCompare(activationName, AttrValue::ActivationRelu)) { desc.Type = DML_OPERATOR_ACTIVATION_RELU; - } - else if (activationName == AttrValue::ActivationLeakyRelu) + } + else if (ActivationNameCompare(activationName, AttrValue::ActivationLeakyRelu)) { desc.Type = DML_OPERATOR_ACTIVATION_LEAKY_RELU; activationDesc.leakyRelu.Alpha = NextAlpha(desc.Type); } - else if (activationName == AttrValue::ActivationThresholdedRelu) + else if (ActivationNameCompare(activationName, AttrValue::ActivationThresholdedRelu)) { desc.Type = DML_OPERATOR_ACTIVATION_THRESHOLDED_RELU; activationDesc.thresholdedRelu.Alpha = NextAlpha(desc.Type); - } - else if (activationName == AttrValue::ActivationTanh) + } + else if (ActivationNameCompare(activationName, AttrValue::ActivationTanh)) { desc.Type = DML_OPERATOR_ACTIVATION_TANH; - } - else if (activationName == AttrValue::ActivationScaledTanh) + } + else if (ActivationNameCompare(activationName, AttrValue::ActivationScaledTanh)) { desc.Type = DML_OPERATOR_ACTIVATION_SCALED_TANH; activationDesc.scaledTanh.Alpha = 
NextAlpha(desc.Type); activationDesc.scaledTanh.Beta = NextBeta(desc.Type); - } - else if (activationName == AttrValue::ActivationSigmoid) + } + else if (ActivationNameCompare(activationName, AttrValue::ActivationSigmoid)) { desc.Type = DML_OPERATOR_ACTIVATION_SIGMOID; - } - else if (activationName == AttrValue::ActivationSigmoidHard) + } + else if (ActivationNameCompare(activationName, AttrValue::ActivationSigmoidHard)) { desc.Type = DML_OPERATOR_ACTIVATION_HARD_SIGMOID; activationDesc.hardSigmoid.Alpha = NextAlpha(desc.Type); activationDesc.hardSigmoid.Beta = NextBeta(desc.Type); - } - else if (activationName == AttrValue::ActivationElu) + } + else if (ActivationNameCompare(activationName, AttrValue::ActivationElu)) { desc.Type = DML_OPERATOR_ACTIVATION_ELU; activationDesc.elu.Alpha = NextAlpha(desc.Type); - } - else if (activationName == AttrValue::ActivationSoftsign) + } + else if (ActivationNameCompare(activationName, AttrValue::ActivationSoftsign)) { desc.Type = DML_OPERATOR_ACTIVATION_SOFTSIGN; - } - else if (activationName == AttrValue::ActivationSoftplus) + } + else if (ActivationNameCompare(activationName, AttrValue::ActivationSoftplus)) { desc.Type = DML_OPERATOR_ACTIVATION_SOFTPLUS; } else { ML_INVALID_ARGUMENT("Unsupported activation function"); @@ -182,6 +182,23 @@ class DmlOperatorRecurrentBase: public DmlOperator, public RecurrentHelper } } + bool ActivationNameCompare(const std::string& activationName, const char* attrValue) + { + if (activationName.size() != std::char_traits<char>::length(attrValue)) + { + return false; + } + + for (size_t i = 0; i < activationName.size(); ++i) + { + if (std::tolower(static_cast<unsigned char>(activationName[i])) != std::tolower(static_cast<unsigned char>(attrValue[i]))) + { + return false; + } + } + return true; + } + void Compute(const MLOperatorKernelContext& kernelContext) override { // Assume that enough GPU work has been queued up 
after the RNN operator that it is worth