From b76d3a948379ee461b9ff4fb9270810d0ef3cba5 Mon Sep 17 00:00:00 2001
From: Shiyi Zou
Date: Wed, 28 Aug 2024 09:47:44 +0800
Subject: [PATCH] update

---
 js/web/docs/webnn-operators.md                |  1 +
 .../webnn/builders/impl/lstm_op_builder.cc    | 59 +++----------------
 2 files changed, 9 insertions(+), 51 deletions(-)

diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md
index 48b06b780dfc7..ca50b5e431823 100644
--- a/js/web/docs/webnn-operators.md
+++ b/js/web/docs/webnn-operators.md
@@ -51,6 +51,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
 | LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | |
 | Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | |
 | LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 |
+| LSTM | ai.onnx(7-13, 14+) | lstm | ✗ | ✓ | Only supports 'layout' == 0 and 'input_forget' == 0. 'clip' is not supported. Each activation function in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens', if present, must be constant with values equal to the first dimension length of input 'X' |
 | MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | |
 | Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | |
 | MaxPool | ai.onnx(7, 8-9, 10, 11, 12+) | maxPool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'storage_order' != 1, one output |
diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc
index 7371d97bc4577..2f6d2f497a27b 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc
@@ -90,49 +90,14 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
   if (helper.HasAttr("activations")) {
     const auto activations = helper.Get("activations", std::vector<std::string>{"Sigmoid", "Tanh", "Tanh"});
-    const auto activation_alpha = helper.Get("activation_alpha", std::vector<float>{});
-    const auto activation_beta = helper.Get("activation_beta", std::vector<float>{});
-
-    auto get_value_or_default = [](std::vector<float>::const_iterator& entry,
-                                   const std::vector<float>::const_iterator& end,
-                                   float def_val) -> float { return entry == end ? def_val : *entry++; };
-
-    auto alpha_iter = activation_alpha.begin();
-    auto beta_iter = activation_beta.begin();
-    const auto alpha_iter_end = activation_alpha.end();
-    const auto beta_iter_end = activation_beta.end();
 
     emscripten::val opt_activations = emscripten::val::array();
     for (size_t i = 0; i < 3; ++i) {
       const std::string& activation = activations[i];
-      if (activation == "Affine") {
-        emscripten::val affine_options = emscripten::val::object();
-        affine_options.set("alpha", get_value_or_default(alpha_iter, alpha_iter_end, 1.0));
-        affine_options.set("beta", get_value_or_default(beta_iter, beta_iter_end, 0));
-        opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("linear", affine_options));
-      } else if (activation == "Elu") {
-        emscripten::val elu_options = emscripten::val::object();
-        elu_options.set("alpha", get_value_or_default(alpha_iter, alpha_iter_end, 1.0));
-        opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("elu", elu_options));
-      } else if (activation == "HardSigmoid") {
-        emscripten::val hard_sigmoid_options = emscripten::val::object();
-        hard_sigmoid_options.set("alpha", get_value_or_default(alpha_iter, alpha_iter_end, 0.2));
-        hard_sigmoid_options.set("beta", get_value_or_default(beta_iter, beta_iter_end, 0.5));
-        opt_activations.call<void>(
-            "push", model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", hard_sigmoid_options));
-      } else if (activation == "LeakyRelu") {
-        emscripten::val leaky_relu_options = emscripten::val::object();
-        leaky_relu_options.set("alpha", get_value_or_default(alpha_iter, alpha_iter_end, 0.01));
-        opt_activations.call<void>("push",
-                                   model_builder.GetBuilder().call<emscripten::val>("leakyRelu", leaky_relu_options));
-      } else if (activation == "Relu") {
+      if (activation == "Relu") {
         opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("relu"));
       } else if (activation == "Sigmoid") {
         opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("sigmoid"));
-      } else if (activation == "Softplus") {
-        opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("softplus"));
-      } else if (activation == "Softsign") {
-        opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("softsign"));
       } else if (activation == "Tanh") {
         opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("tanh"));
       }
@@ -193,25 +158,17 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
     if (activations.size() >= 6) {
       if (activations[0] != activations[3] || activations[1] != activations[4] || activations[2] != activations[5]) {
-        LOGS(logger, ERROR) << "LSTM: forward and reverse directions must have the same activations";
-        return false;
-      }
-      // TODO(shiyi9801): support activation_alpha and activation_beta when provided 6 activations.
-      if (helper.HasAttr("activation_alpha") || helper.HasAttr("activation_beta")) {
-        LOGS(logger, ERROR)
-            << "LSTM: activation_alpha and activation_beta are not supported when provided 6 activations";
+        LOGS(logger, ERROR) << "LSTM: forward and backward activations must be the same";
         return false;
       }
     }
 
-    const InlinedHashSet<std::string> supported_activations = {"Affine", "Relu", "LeakyRelu", "Tanh", "Sigmoid",
-                                                               "HardSigmoid", "Elu", "Softsign", "Softplus"};
-    if (!std::all_of(activations.begin(), activations.end(),
-                     [&supported_activations](const std::string& activation) -> bool {
-                       return supported_activations.contains(activation);
-                     })) {
-      LOGS(logger, ERROR) << "LSTM: activations must be one of Affine, Relu, LeakyRelu, Tanh, Sigmoid, HardSigmoid, "
-                             "Elu, Softsign, Softplus";
+    const InlinedHashSet<std::string> supported_activations = {"Relu", "Tanh", "Sigmoid"};
+    if (std::any_of(activations.begin(), activations.end(),
+                    [&supported_activations](const std::string& activation) -> bool {
+                      return !supported_activations.contains(activation);
+                    })) {
+      LOGS(logger, ERROR) << "LSTM: activations must be one of Relu, Tanh, Sigmoid";
       return false;
     }
   }
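
Illustrative note, not part of the patch: the support check that remains in `IsOpSupportedImpl` boils down to two rules. For a bidirectional LSTM (6 activations) the forward triple must equal the backward triple, and every entry must be one of Relu, Tanh, Sigmoid. A minimal standalone sketch of that logic, using `std::unordered_set` in place of ORT's `InlinedHashSet`, a plain `bool` return instead of `LOGS(...)` plus `return false`, and a made-up helper name:

```cpp
#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>

// Sketch only: mirrors the validation performed by the patched IsOpSupportedImpl.
bool AreLstmActivationsSupported(const std::vector<std::string>& activations) {
  // A bidirectional LSTM supplies 6 activations, but WebNN's lstm op takes a
  // single set of 3, so forward (0..2) and backward (3..5) must match.
  if (activations.size() >= 6 &&
      (activations[0] != activations[3] || activations[1] != activations[4] ||
       activations[2] != activations[5])) {
    return false;
  }
  // Only activations with a parameter-free WebNN counterpart are accepted,
  // which is why the activation_alpha/activation_beta plumbing could be dropped.
  const std::unordered_set<std::string> supported = {"Relu", "Tanh", "Sigmoid"};
  return std::none_of(activations.begin(), activations.end(),
                      [&supported](const std::string& a) { return supported.count(a) == 0; });
}
```

For example, `{"Sigmoid", "Tanh", "Tanh", "Sigmoid", "Tanh", "Tanh"}` passes, while any list containing "HardSigmoid", "Elu", "LeakyRelu", etc. is now rejected rather than being mapped to a parameterized WebNN activation.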
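For context, the `opt_activations` array kept in `AddToModelBuilderImpl` is presumably attached to the `activations` option of the WebNN `lstm()` builder call later in the same function (outside this hunk). A rough Emscripten-side sketch of that wiring, with hypothetical names (`builder` stands in for `model_builder.GetBuilder()`; operands and the remaining lstm options are elided):

```cpp
#include <emscripten/val.h>

// Sketch only: compiles under Emscripten with embind, not a drop-in for the real builder code.
emscripten::val BuildWebnnLstm(emscripten::val builder, emscripten::val input,
                               emscripten::val weight, emscripten::val recurrent_weight,
                               int32_t steps, int32_t hidden_size,
                               emscripten::val opt_activations) {
  emscripten::val options = emscripten::val::object();
  // e.g. [builder.sigmoid(), builder.tanh(), builder.tanh()] built by the loop above.
  options.set("activations", opt_activations);
  // Other options (direction, bias, initialHiddenState, ...) omitted here.
  return builder.call<emscripten::val>("lstm", input, weight, recurrent_weight,
                                       emscripten::val(steps), emscripten::val(hidden_size),
                                       options);
}
```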