
Commit b76d3a9

update
shiyi9801 committed Aug 28, 2024
1 parent 1e49fe4 commit b76d3a9
Showing 2 changed files with 9 additions and 51 deletions.
1 change: 1 addition & 0 deletions js/web/docs/webnn-operators.md
@@ -51,6 +51,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
 | LessOrEqual | ai.onnx(12-15, 16+) | lessOrEqual | ✓ | ✓ | |
 | Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | |
 | LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 |
+| LSTM | ai.onnx(7-13, 14+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' |
 | MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | |
 | Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | |
 | MaxPool | ai.onnx(7, 8-9, 10, 11, 12+) | maxPool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'storage_order' != 1, one output |
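The new LSTM row above packs several support constraints into one cell. As a reading aid, here is a minimal, self-contained sketch of what those checks amount to; the helper and every name in it are hypothetical, not code from this commit:

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical helper mirroring the documented WebNN EP constraints for LSTM.
bool IsLstmConfigSupported(int64_t layout, int64_t input_forget, bool has_clip,
                           const std::vector<std::string>& activations) {
  // 'layout' and 'input_forget' must be 0, and 'clip' must be absent.
  if (layout != 0 || input_forget != 0 || has_clip) return false;
  // Every activation must be one of Relu, Tanh, Sigmoid.
  static const std::unordered_set<std::string> kSupported = {"Relu", "Tanh", "Sigmoid"};
  for (const auto& act : activations) {
    if (kSupported.count(act) == 0) return false;
  }
  // A bidirectional LSTM lists 6 activations; forward (first three) and
  // backward (last three) must match element-wise.
  if (activations.size() >= 6) {
    for (size_t i = 0; i < 3; ++i) {
      if (activations[i] != activations[i + 3]) return false;
    }
  }
  // ('sequence_lens' constancy is a graph-level check, outside this sketch.)
  return true;
}
```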
59 changes: 8 additions & 51 deletions onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc
@@ -90,49 +90,14 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N

   if (helper.HasAttr("activations")) {
     const auto activations = helper.Get("activations", std::vector<std::string>{"Sigmoid", "Tanh", "Tanh"});
-    const auto activation_alpha = helper.Get("activation_alpha", std::vector<float>{});
-    const auto activation_beta = helper.Get("activation_beta", std::vector<float>{});
-
-    auto get_value_or_default = [](std::vector<float>::const_iterator& entry,
-                                   const std::vector<float>::const_iterator& end,
-                                   float def_val) -> float { return entry == end ? def_val : *entry++; };
-
-    auto alpha_iter = activation_alpha.begin();
-    auto beta_iter = activation_beta.begin();
-    const auto alpha_iter_end = activation_alpha.end();
-    const auto beta_iter_end = activation_beta.end();
 
     emscripten::val opt_activations = emscripten::val::array();
     for (size_t i = 0; i < 3; ++i) {
       const std::string& activation = activations[i];
-      if (activation == "Affine") {
-        emscripten::val affine_options = emscripten::val::object();
-        affine_options.set("alpha", get_value_or_default(alpha_iter, alpha_iter_end, 1.0));
-        affine_options.set("beta", get_value_or_default(beta_iter, beta_iter_end, 0));
-        opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("linear", affine_options));
-      } else if (activation == "Elu") {
-        emscripten::val elu_options = emscripten::val::object();
-        elu_options.set("alpha", get_value_or_default(alpha_iter, alpha_iter_end, 1.0));
-        opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("elu", elu_options));
-      } else if (activation == "HardSigmoid") {
-        emscripten::val hard_sigmoid_options = emscripten::val::object();
-        hard_sigmoid_options.set("alpha", get_value_or_default(alpha_iter, alpha_iter_end, 0.2));
-        hard_sigmoid_options.set("beta", get_value_or_default(beta_iter, beta_iter_end, 0.5));
-        opt_activations.call<void>(
-            "push", model_builder.GetBuilder().call<emscripten::val>("hardSigmoid", hard_sigmoid_options));
-      } else if (activation == "LeakyRelu") {
-        emscripten::val leaky_relu_options = emscripten::val::object();
-        leaky_relu_options.set("alpha", get_value_or_default(alpha_iter, alpha_iter_end, 0.01));
-        opt_activations.call<void>("push",
-                                   model_builder.GetBuilder().call<emscripten::val>("leakyRelu", leaky_relu_options));
-      } else if (activation == "Relu") {
+      if (activation == "Relu") {
         opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("relu"));
       } else if (activation == "Sigmoid") {
         opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("sigmoid"));
-      } else if (activation == "Softplus") {
-        opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("softplus"));
-      } else if (activation == "Softsign") {
-        opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("softsign"));
       } else if (activation == "Tanh") {
         opt_activations.call<void>("push", model_builder.GetBuilder().call<emscripten::val>("tanh"));
       }
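The loop that survives this hunk pushes one WebNN activation operator per slot into `opt_activations`. For orientation, a minimal sketch of the hand-off that happens later in this builder; the operand names (`input`, `weight`, `recurrent_weight`, `steps`, `hidden_size`) are assumed from the surrounding function and not shown in the hunk — only the `"activations"` option wiring is the point:

```cpp
// Assumed wiring (sketch): the array built above becomes the "activations"
// member of the option bag passed to the WebNN lstm() builder method.
emscripten::val options = emscripten::val::object();
options.set("activations", opt_activations);
emscripten::val outputs = model_builder.GetBuilder().call<emscripten::val>(
    "lstm", input, weight, recurrent_weight, steps, hidden_size, options);
```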
@@ -193,25 +158,17 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,

     if (activations.size() >= 6) {
       if (activations[0] != activations[3] || activations[1] != activations[4] || activations[2] != activations[5]) {
-        LOGS(logger, ERROR) << "LSTM: forward and reverse directions must have the same activations";
-        return false;
-      }
-      // TODO(shiyi9801): support activation_alpha and activation_beta when provided 6 activations.
-      if (helper.HasAttr("activation_alpha") || helper.HasAttr("activation_beta")) {
-        LOGS(logger, ERROR)
-            << "LSTM: activation_alpha and activation_beta are not supported when provided 6 activations";
+        LOGS(logger, ERROR) << "LSTM: forward and backward activations must be the same";
         return false;
       }
     }
 
-    const InlinedHashSet<std::string> supported_activations = {"Affine", "Relu", "LeakyRelu", "Tanh", "Sigmoid",
-                                                               "HardSigmoid", "Elu", "Softsign", "Softplus"};
-    if (!std::all_of(activations.begin(), activations.end(),
-                     [&supported_activations](const std::string& activation) -> bool {
-                       return supported_activations.contains(activation);
-                     })) {
-      LOGS(logger, ERROR) << "LSTM: activations must be one of Affine, Relu, LeakyRelu, Tanh, Sigmoid, HardSigmoid, "
-                             "Elu, Softsign, Softplus";
+    const InlinedHashSet<std::string> supported_activations = {"Relu", "Tanh", "Sigmoid"};
+    if (std::any_of(activations.begin(), activations.end(),
+                    [&supported_activations](const std::string& activation) -> bool {
+                      return !supported_activations.contains(activation);
+                    })) {
+      LOGS(logger, ERROR) << "LSTM: activations must be one of Relu, Tanh, Sigmoid";
       return false;
     }
   }
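Besides shrinking the supported set, the rewritten check flips `!std::all_of(..., contains)` into `std::any_of(..., !contains)`; by De Morgan's laws the two forms are equivalent, and `any_of` reads directly as "reject if any activation is unsupported". A self-contained illustration of the same pattern, with `std::unordered_set` standing in for onnxruntime's `InlinedHashSet` (an assumption of this sketch):

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Reject the list if any entry falls outside the supported set.
bool AllSupported(const std::vector<std::string>& activations) {
  static const std::unordered_set<std::string> kSupported = {"Relu", "Tanh", "Sigmoid"};
  return !std::any_of(activations.begin(), activations.end(),
                      [](const std::string& a) { return kSupported.count(a) == 0; });
}

int main() {
  std::cout << AllSupported({"Sigmoid", "Tanh", "Tanh"}) << "\n";  // 1: ONNX LSTM defaults pass
  std::cout << AllSupported({"Sigmoid", "Elu", "Tanh"}) << "\n";   // 0: Elu is no longer accepted
}
```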
