Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
shiyi9801 committed Aug 29, 2024
1 parent 9971294 commit 24ffef5
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,20 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger)) {
LOGS(logger, ERROR) << "Cannot get input's shape";
return false;
}
int32_t steps = static_cast<int32_t>(input_shape[0]);

if (input_defs.size() > 4 && input_defs[4]->Exists()) {
if (!Contains(initializers, input_defs[4]->Name())) {
LOGS(logger, ERROR) << "LSTM: sequence_lens must be constant and the values must be equal to input shape[0]";
LOGS(logger, ERROR) << "LSTM: sequence_lens must be constant";
return false;
}

std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger)) {
LOGS(logger, ERROR) << "Cannot get input's shape";
return false;
}
int32_t steps = static_cast<int32_t>(input_shape[0]);

const auto& sequence_lens_tensor = *initializers.at(input_defs[4]->Name());
std::vector<int32_t> sequence_lens;
if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, logger)) {
Expand All @@ -189,22 +190,26 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
NodeAttrHelper helper(node);
if (helper.HasAttr("activations")) {
const auto activations = helper.Get("activations", std::vector<std::string>{"Sigmoid", "Tanh", "Tanh"});
if (activations.size() >= 6 &&
(activations[0] != activations[3] || activations[1] != activations[4] || activations[2] != activations[5])) {
LOGS(logger, ERROR) << "LSTM: forward and reverse directions must have the same activations";
return false;
}

if (activations.size() >= 6 && (helper.HasAttr("activation_alpha") || helper.HasAttr("activation_beta"))) {
LOGS(logger, ERROR) << "LSTM: activation_alpha and activation_beta are not supported when provided 6 activations";
return false;
if (activations.size() >= 6) {
if (activations[0] != activations[3] || activations[1] != activations[4] || activations[2] != activations[5]) {
LOGS(logger, ERROR) << "LSTM: forward and reverse directions must have the same activations";
return false;
}
// TODO(shiyi9801): support activation_alpha and activation_beta when provided 6 activations.
if (helper.HasAttr("activation_alpha") || helper.HasAttr("activation_beta")) {
LOGS(logger, ERROR)
<< "LSTM: activation_alpha and activation_beta are not supported when provided 6 activations";
return false;
}
}

if (!std::all_of(activations.begin(), activations.end(), [](const std::string& activation) -> bool {
return activation == "Affine" || activation == "Relu" || activation == "LeakyRelu" || activation == "Tanh" ||
activation == "Sigmoid" || activation == "HardSigmoid" || activation == "Elu" ||
activation == "Softsign" || activation == "Softplus";
})) {
const InlinedHashSet<std::string> supported_activations = {"Affine", "Relu", "LeakyRelu", "Tanh", "Sigmoid",
"HardSigmoid", "Elu", "Softsign", "Softplus"};
if (!std::all_of(activations.begin(), activations.end(),
[&supported_activations](const std::string& activation) -> bool {
return supported_activations.contains(activation);
})) {
LOGS(logger, ERROR) << "LSTM: activations must be one of Affine, Relu, LeakyRelu, Tanh, Sigmoid, HardSigmoid, "
"Elu, Softsign, Softplus";
return false;
Expand Down

0 comments on commit 24ffef5

Please sign in to comment.