Skip to content

Commit

Permalink
Add separate op builder for Softmax/LogSoftmax
Browse files · Browse the repository at this point in the history
  • Loading branch information
adrianlizarraga committed Oct 10, 2023
1 parent 5f137ae commit bf83f04
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 28 deletions.
7 changes: 5 additions & 2 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,10 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreateSimpleOpBuilder("Where", *this);
CreateSimpleOpBuilder("Sigmoid", *this);
CreateSimpleOpBuilder("Sin", *this);
CreateSimpleOpBuilder("Softmax", *this);
CreateSimpleOpBuilder("Sqrt", *this);
CreateSimpleOpBuilder("Sub", *this);
CreateSimpleOpBuilder("Tanh", *this);

CreateSimpleOpBuilder("LogSoftmax", *this);
CreateSimpleOpBuilder("MatMul", *this);
CreateSimpleOpBuilder("Concat", *this);

Expand All @@ -67,6 +65,11 @@ OpBuilderRegistrations::OpBuilderRegistrations() {
CreateSimpleOpBuilder("GridSample", *this);
}

{
CreateSoftmaxOpBuilder("Softmax", *this);
CreateSoftmaxOpBuilder("LogSoftmax", *this);
}

{
CreateCastOpBuilder("Cast", *this);
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/qnn/builder/op_builder_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ const IOpBuilder* GetOpBuilder(const std::string& onnx_op_type);

void CreateSimpleOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);

void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class SimpleOpBuilder : public BaseOpBuilder {
bool do_op_validation) const override ORT_MUST_USE_RESULT;

private:
Status ExplicitOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const;
Status ExplicitOpCheck(const NodeUnit& node_unit) const;
Status ProcessSigmoidOrTanhOutput(QnnModelWrapper& qnn_model_wrapper,
const NodeUnit& node_unit,
std::vector<std::string>&& input_names,
Expand All @@ -41,30 +41,9 @@ class SimpleOpBuilder : public BaseOpBuilder {
static constexpr std::array<std::string_view, 3> gridsample_supported_padding_modes = {"zeros", "border", "reflection"};
};

static int32_t GetDefaultAxisAttribute(const std::string& op_type, int opset_version) {
if (op_type == "Softmax" || op_type == "LogSoftmax") {
// Default axis changed from 1 to -1 in opset 13.
return opset_version < 13 ? 1 : -1;
}

return 0;
}

Status SimpleOpBuilder::ExplicitOpCheck(const QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit) const {
Status SimpleOpBuilder::ExplicitOpCheck(const NodeUnit& node_unit) const {
const std::string& op_type = node_unit.OpType();

// QNN Softmax and LogSoftmax only support an axis value equal to input_rank - 1 (i.e., same as -1).
if (op_type == "Softmax" || op_type == "LogSoftmax") {
int32_t axis = GetDefaultAxisAttribute(op_type, node_unit.SinceVersion());
Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis));
std::vector<uint32_t> input_shape;
ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(node_unit.Inputs()[0].node_arg, input_shape),
"QNN EP: Cannot get shape for Softmax input");
ORT_RETURN_IF(axis != static_cast<int32_t>(input_shape.size() - 1),
"QNN ", op_type.c_str(), " only supports an `axis` attribute equal to input_rank-1 (or -1)");
}

if (op_type == "GridSample") {
NodeAttrHelper node_helper(node_unit);
std::string mode = node_helper.Get("mode", "linear");
Expand Down Expand Up @@ -231,7 +210,7 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w
const std::string& op_type = node_unit.OpType();

if (do_op_validation) {
ORT_RETURN_IF_ERROR(ExplicitOpCheck(qnn_model_wrapper, node_unit));
ORT_RETURN_IF_ERROR(ExplicitOpCheck(node_unit));
// Skip the op validation for DepthToSpace & SpaceToDepth if it's not NHWC data layout
if (node_unit.Domain() != kMSInternalNHWCDomain && (op_type == "DepthToSpace" || op_type == "SpaceToDepth" || op_type == "GridSample")) {
return Status::OK();
Expand All @@ -251,8 +230,8 @@ Status SimpleOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_w

std::vector<std::string> param_tensor_names;
// Add attribute
if (op_type == "LogSoftmax" || op_type == "Softmax" || op_type == "Concat") {
int32_t default_axis = GetDefaultAxisAttribute(op_type, node_unit.SinceVersion());
if (op_type == "Concat") {
int32_t default_axis = 0;
Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, default_axis));
QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_qnn_scalar);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/common.h"
#include "core/providers/shared/utils/utils.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/qnn/builder/qnn_model_wrapper.h"
#include "core/providers/qnn/builder/op_builder_factory.h"
#include "core/common/safeint.h"
#include "onnx/defs/data_type_utils.h"

#include "base_op_builder.h"

namespace onnxruntime {
namespace qnn {

// Op builder that maps ONNX Softmax and LogSoftmax nodes to the QNN EP.
// Both ops get dedicated handling (instead of SimpleOpBuilder) because they
// share an opset-dependent default `axis` and a QNN-specific axis restriction
// enforced in IsOpSupported.
class SoftmaxOpBuilder : public BaseOpBuilder {
 public:
  SoftmaxOpBuilder() : BaseOpBuilder("SoftmaxOpBuilder") {}
  ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(SoftmaxOpBuilder);

  // Validates QNN-specific constraints (axis must be the last input dimension)
  // and then runs the node through the builder in validation mode.
  Status IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
                       const NodeUnit& node_unit,
                       const logging::Logger& logger) const override final ORT_MUST_USE_RESULT;

 protected:
  // Translates the ONNX `axis` attribute into a QNN scalar param and emits
  // the node's outputs.
  Status ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                     const NodeUnit& node_unit,
                                     std::vector<std::string>&& input_names,
                                     const logging::Logger& logger,
                                     bool do_op_validation) const override ORT_MUST_USE_RESULT;
};

// Returns the ONNX-spec default value for the `axis` attribute of the given
// op type at the given opset version. For Softmax/LogSoftmax the default
// changed from 1 to -1 in opset 13; every other op type falls back to 0.
static int32_t GetDefaultAxisAttribute(const std::string& op_type, int opset_version) {
  const bool is_softmax_like = (op_type == "Softmax") || (op_type == "LogSoftmax");
  if (!is_softmax_like) {
    return 0;
  }
  return (opset_version >= 13) ? -1 : 1;
}

// Checks whether this Softmax/LogSoftmax node is supported by QNN.
// QNN only supports an `axis` equal to input_rank - 1 (i.e., the last input
// dimension, equivalent to axis == -1); any other axis is rejected here.
// On success, runs the node through the builder in validation mode so QNN
// itself can validate the op configuration.
//
// Fix: removed `ORT_UNUSED_PARAMETER(logger)` — `logger` IS used below in the
// AddToModelBuilder call, so the unused-parameter marker was misleading.
Status SoftmaxOpBuilder::IsOpSupported(QnnModelWrapper& qnn_model_wrapper,
                                       const NodeUnit& node_unit,
                                       const logging::Logger& logger) const {
  const std::string& op_type = node_unit.OpType();

  // Resolve the effective axis: start from the opset-dependent default, then
  // let ProcessAxisAttribute override/normalize it from the node's attribute.
  int32_t axis = GetDefaultAxisAttribute(op_type, node_unit.SinceVersion());
  Qnn_Scalar_t axis_qnn_scalar = QNN_SCALAR_INIT;
  ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_qnn_scalar, axis));

  std::vector<uint32_t> input_shape;
  ORT_RETURN_IF_NOT(qnn_model_wrapper.GetOnnxShape(node_unit.Inputs()[0].node_arg, input_shape),
                    "QNN EP: Cannot get shape for Softmax input");
  ORT_RETURN_IF(axis != static_cast<int32_t>(input_shape.size() - 1),
                "QNN ", op_type.c_str(), " only supports an `axis` attribute equal to input_rank-1 (or -1)");

  return AddToModelBuilder(qnn_model_wrapper, node_unit, logger, true);
}

// Builds the QNN parameter list for a Softmax/LogSoftmax node (just the
// `axis` scalar) and hands off to ProcessOutputs to emit the QNN op.
Status SoftmaxOpBuilder::ProcessAttributesAndOutputs(QnnModelWrapper& qnn_model_wrapper,
                                                     const NodeUnit& node_unit,
                                                     std::vector<std::string>&& input_names,
                                                     const logging::Logger& logger,
                                                     bool do_op_validation) const {
  const std::string& op_type = node_unit.OpType();

  // Resolve the axis from the node attribute, falling back to the
  // opset-dependent ONNX default when the attribute is absent.
  int32_t axis = GetDefaultAxisAttribute(op_type, node_unit.SinceVersion());
  Qnn_Scalar_t axis_scalar = QNN_SCALAR_INIT;
  ORT_RETURN_IF_ERROR(ProcessAxisAttribute(qnn_model_wrapper, node_unit, axis_scalar, axis));

  // NOTE(review): QNN_OP_SOFTMAX_PARAM_AXIS is used for LogSoftmax as well —
  // presumably both QNN op defs name the parameter "axis"; confirm against the
  // QNN SDK op definitions.
  QnnParamWrapper axis_param(node_unit.Index(), node_unit.Name(), QNN_OP_SOFTMAX_PARAM_AXIS, axis_scalar);

  std::vector<std::string> param_tensor_names;
  param_tensor_names.push_back(axis_param.GetParamTensorName());
  qnn_model_wrapper.AddParamWrapper(std::move(axis_param));

  return ProcessOutputs(qnn_model_wrapper, node_unit,
                        std::move(input_names),
                        std::move(param_tensor_names),
                        logger, do_op_validation, GetQnnOpType(op_type));
}

// Factory hook called by OpBuilderRegistrations: registers a SoftmaxOpBuilder
// instance under the given ONNX op type (used for both "Softmax" and
// "LogSoftmax").
void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
  op_registrations.AddOpBuilder(op_type, std::make_unique<SoftmaxOpBuilder>());
}

} // namespace qnn
} // namespace onnxruntime

0 comments on commit bf83f04

Please sign in to comment.