Commit 82f2333

port add ops, create new 3p buck targets, add op_add kernel modification
Differential Revision: D65300260

Pull Request resolved: #6601
zonglinpeng authored Nov 22, 2024
1 parent fc42a4e commit 82f2333
Showing 16 changed files with 458 additions and 153 deletions.
4 changes: 2 additions & 2 deletions backends/cadence/aot/compiler.py
@@ -235,12 +235,12 @@ def quantize_and_export_to_cadence(
def export_to_executorch_gen_etrecord(
model: torch.nn.Module,
inputs: tuple[object, ...],
dump_graphs: bool = False,
output_dir: Optional[str] = None,
opt_level: int = 1,
dump_graphs: bool = False,
) -> ExecutorchProgramManager:
edge_prog_manager = export_to_edge(model, inputs)
cadence_passes = get_cadence_passes(opt_level)
edge_prog_manager = export_to_edge(model, inputs, dump_graphs)

# Run a couple required passes for quant/dequant ops
cadence_prog_manager = edge_prog_manager.transform(
40 changes: 22 additions & 18 deletions backends/cadence/hifi/operators/op_add.cpp
@@ -9,6 +9,8 @@
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
@@ -121,7 +123,7 @@ Tensor& add_out(
float alpha_val;
torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

constexpr auto name = "add.out";
static constexpr const char op_name[] = "add.out";
constexpr int kNnlibMaxDim = 4; /*fallback if broadcast and dim > 4 */

int a_dim = a.dim(), b_dim = b.dim(), out_dim = out.dim();
@@ -178,23 +180,25 @@ Tensor& add_out(
return out;
}

ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
CTYPE_IN alpha_val;
torch::executor::native::utils::extract_scalar(alpha, &alpha_val);

ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
AddInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, alpha_val, out);
});
});
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
const CTYPE_COMPUTE val_alpha =
torch::executor::native::utils::scalar_to<CTYPE_COMPUTE>(alpha);
torch::executor::native::utils::
apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[val_alpha](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a + val_alpha * val_b;
},
ctx,
a,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
b,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
out,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16);
});

return out;
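
The rewritten add.out path above replaces the nested ET_SWITCH type dispatch with a single dispatch on the compute dtype plus an apply_bitensor_elementwise_fn call. As a rough illustration of the arithmetic that lambda performs, here is a minimal standalone sketch over plain float arrays; it assumes equal-length inputs and a float compute type, and ignores the broadcasting and dtype handling the real kernel gets from the elementwise utilities.

#include <cstddef>
#include <vector>

// Sketch only: same arithmetic as the add.out lambda (val_a + val_alpha * val_b),
// but over flat float buffers with no broadcasting or dtype promotion.
std::vector<float> add_elementwise_sketch(
    const std::vector<float>& a, const std::vector<float>& b, float alpha) {
  std::vector<float> out(a.size());
  for (size_t i = 0; i < a.size(); ++i) {
    out[i] = a[i] + alpha * b[i];
  }
  return out;
}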
113 changes: 70 additions & 43 deletions backends/cadence/hifi/operators/op_div.cpp
@@ -9,6 +9,8 @@
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/kernels/portable/cpu/util/math_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
@@ -134,25 +136,26 @@ div_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
InvalidArgument,
out);

ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.out", CTYPE_A, [&]() {
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out", CTYPE_B, [&]() {
ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out", CTYPE_IN, [&]() {
ET_SWITCH_FLOAT_TYPES(out_type, ctx, "div.out", CTYPE_OUT, [&]() {
torch::executor::
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = a_casted / b_casted;

return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
});
});
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "div.out";

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
torch::executor::native::utils::
apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a / val_b;
},
ctx,
a,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
b,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
out,
torch::executor::native::utils::SupportedTensorDtypes::FLOATHBF16);
});

return out;
@@ -254,35 +257,59 @@ Tensor& div_out_mode(
return out;
}

ET_SWITCH_REAL_TYPES_AND(Bool, a_type, ctx, "div.out_mode", CTYPE_A, [&]() {
ET_SWITCH_REAL_TYPES_AND(Bool, b_type, ctx, "div.out_mode", CTYPE_B, [&]() {
ET_SWITCH_FLOAT_TYPES(common_type, ctx, "div.out_mode", CTYPE_IN, [&]() {
ET_SWITCH_REAL_TYPES(out_type, ctx, "div.out_mode", CTYPE_OUT, [&]() {
torch::executor::
apply_binary_elementwise_fn<CTYPE_A, CTYPE_B, CTYPE_OUT>(
[mode](const CTYPE_A val_a, const CTYPE_B val_b) {
CTYPE_IN a_casted = static_cast<CTYPE_IN>(val_a);
CTYPE_IN b_casted = static_cast<CTYPE_IN>(val_b);
CTYPE_IN value = a_casted / b_casted;
if (mode.has_value() && mode.value() == "trunc") {
value = std::trunc(value);
} else if (mode.has_value() && mode.value() == "floor") {
value = std::floor(value);
}
return static_cast<CTYPE_OUT>(value);
},
a,
b,
out);
});
});
});
bool div_by_zero_error = false;
const bool mode_is_trunc = (mode.has_value() && mode.value() == "trunc");
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "div.out";

ET_SWITCH_REAL_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
torch::executor::native::utils::
apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[mode_is_trunc, &div_by_zero_error](
const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
if (executorch::runtime::is_integral_type<
CTYPE_COMPUTE,
/*includeBool=*/true>::value) {
if (val_b == 0) {
div_by_zero_error = true;
return static_cast<CTYPE_COMPUTE>(0);
}
}
CTYPE_COMPUTE value = val_a / val_b;
if (mode_is_trunc) {
value = std::trunc(value);
} else {
// We established above that the mode is either trunc or floor,
// so it must be floor.
value =
torch::executor::native::utils::floor_divide(val_a, val_b);
}
return value;
},
ctx,
a,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
b,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
out,
torch::executor::native::utils::SupportedTensorDtypes::REALHBF16);
});

ET_KERNEL_CHECK_MSG(
ctx,
!div_by_zero_error,
InvalidArgument,
out,
"Div mode operation encountered integer division by zero");

return out;
}

} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
} // namespace cadence
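
The div.out_mode change above adds an explicit integer divide-by-zero guard and keeps two rounding behaviors: std::trunc for mode "trunc" and a floor divide otherwise. A small standalone sketch of the rounding difference follows; floor_divide_sketch is a local stand-in for the utils::floor_divide helper, written only to show the semantics, and a float compute type is assumed.

#include <cmath>

// Stand-in for utils::floor_divide, for illustration only.
float floor_divide_sketch(float a, float b) {
  return std::floor(a / b);
}

// Same rounding choice as the div.out_mode lambda: trunc rounds toward zero,
// floor rounds toward negative infinity.
float div_with_mode_sketch(float a, float b, bool mode_is_trunc) {
  float value = a / b;
  return mode_is_trunc ? std::trunc(value) : floor_divide_sketch(a, b);
}

// Example: -7 / 2 is -3 under trunc but -4 under floor.
// div_with_mode_sketch(-7.0f, 2.0f, /*mode_is_trunc=*/true)  -> -3.0f
// div_with_mode_sketch(-7.0f, 2.0f, /*mode_is_trunc=*/false) -> -4.0f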
9 changes: 5 additions & 4 deletions backends/cadence/hifi/operators/op_mean.cpp
@@ -6,6 +6,7 @@
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/kernel_ops_util.h>
#include <executorch/kernels/portable/cpu/util/reduce_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
@@ -141,11 +142,11 @@ Tensor& mean_dim_out(
return out;
}

ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, name, CTYPE_IN, [&] {
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, name, CTYPE_OUT, [&] {
ET_SWITCH_REALHB_TYPES(in.scalar_type(), ctx, "mean.out", CTYPE_IN, [&] {
ET_SWITCH_FLOATH_TYPES(out.scalar_type(), ctx, "mean.out", CTYPE_OUT, [&] {
CTYPE_OUT* out_data = out.mutable_data_ptr<CTYPE_OUT>();
const size_t num = torch::executor::get_reduced_dim_product(in, dim_list);

const size_t num =
torch::executor::get_reduced_dim_product(in, dim_list);
for (size_t out_ix = 0; out_ix < out.numel(); ++out_ix) {
CTYPE_OUT sum = 0;
if (in.numel() > 0) {
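
The mean_dim_out hunk above accumulates a sum per output element and divides by the reduced dimension product (num). Below is a minimal sketch of that accumulate-then-divide step, for the assumed simple case where every dimension is reduced and the input is non-empty.

#include <cstddef>

// Sketch of the sum / num step for a fully reduced, non-empty input. The real kernel
// loops per output element and obtains the divisor from get_reduced_dim_product(in, dim_list).
float mean_all_sketch(const float* in_data, size_t numel) {
  float sum = 0.0f;
  for (size_t i = 0; i < numel; ++i) {
    sum += in_data[i];
  }
  return sum / static_cast<float>(numel);
}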
38 changes: 23 additions & 15 deletions backends/cadence/hifi/operators/op_mul.cpp
@@ -9,6 +9,8 @@
#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/scalar_utils.h>
#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>
@@ -144,20 +146,26 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
return out;
}

ET_SWITCH_REALHB_TYPES(a_type, ctx, "mul.out", CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, "mul.out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALHB_TYPES(out_type, ctx, "mul.out", CTYPE_OUT, [&]() {
MulInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
CTYPE_B,
CTYPE_IN,
CTYPE_OUT>::run(a, b, out);
});
});
// Compute Dtype
ScalarType compute_type =
torch::executor::native::utils::get_compute_type(common_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "mul.Scalar_out";

ET_SWITCH_REALB_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
torch::executor::native::utils::
apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const CTYPE_COMPUTE val_a, const CTYPE_COMPUTE val_b) {
return val_a * val_b;
},
ctx,
a,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
b,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
out,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16);
});

return out;
@@ -166,4 +174,4 @@ mul_out(RuntimeContext& ctx, const Tensor& a, const Tensor& b, Tensor& out) {
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
} // namespace cadence
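
Each of the ported kernels above (add, div, mul) now dispatches on a compute dtype obtained from utils::get_compute_type rather than switching over every input/output dtype pair. The sketch below is only an assumption about the idea behind that helper (reduced-precision floating inputs are promoted to float for the arithmetic), not the library's actual implementation or type list.

// Hypothetical illustration of the compute-type idea; not the ExecuTorch implementation.
enum class SketchScalarType { Float, Double, Half, BFloat16, Long, Int, Bool };

SketchScalarType get_compute_type_sketch(SketchScalarType common_type) {
  // Assumption: half-precision float types are widened to float for the math;
  // everything else computes in its own type.
  if (common_type == SketchScalarType::Half ||
      common_type == SketchScalarType::BFloat16) {
    return SketchScalarType::Float;
  }
  return common_type;
}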
38 changes: 24 additions & 14 deletions backends/cadence/hifi/operators/op_sigmoid.cpp
@@ -9,6 +9,8 @@
#include <cmath>

#include <executorch/backends/cadence/hifi/kernels/kernels.h>
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
#include <executorch/kernels/portable/cpu/util/functional_util.h>
#include <executorch/runtime/kernel/kernel_includes.h>

@@ -58,19 +60,27 @@ Tensor& sigmoid_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
return out;
}

ET_SWITCH_REALHB_TYPES(in_type, ctx, "sigmoid.out", CTYPE_IN, [&]() {
ET_SWITCH_FLOATH_TYPES(out_type, ctx, "sigmoid.out", CTYPE_OUT, [&]() {
torch::executor::apply_unary_map_fn(
[](const CTYPE_IN val_in) {
// perform math in double to preserve precision
double in_casted = static_cast<double>(val_in);
double out_val = 1.0 / (1.0 + exp(-in_casted));
return static_cast<CTYPE_OUT>(out_val);
},
in.const_data_ptr<CTYPE_IN>(),
out.mutable_data_ptr<CTYPE_OUT>(),
in.numel());
});
ScalarType compute_type =
executorch::runtime::isFloatingType(in.scalar_type()) ? in.scalar_type()
: ScalarType::Float;
compute_type = torch::executor::native::utils::get_compute_type(compute_type);

// @lint-ignore CLANGTIDY facebook-hte-CArray
static constexpr const char op_name[] = "sigmoid.out";

ET_SWITCH_FLOAT_TYPES(compute_type, ctx, op_name, CTYPE_COMPUTE, [&]() {
torch::executor::native::utils::
apply_unitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
[](const CTYPE_COMPUTE val_in) {
CTYPE_COMPUTE out_val = static_cast<CTYPE_COMPUTE>(1.0) /
(static_cast<CTYPE_COMPUTE>(1.0) + exp(-val_in));
return out_val;
},
ctx,
in,
torch::executor::native::utils::SupportedTensorDtypes::REALHBBF16,
out,
torch::executor::native::utils::SupportedTensorDtypes::FLOATHBF16);
});

return out;
@@ -79,4 +89,4 @@ Tensor& sigmoid_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) {
} // namespace native
} // namespace HiFi
} // namespace impl
} // namespace cadence
} // namespace cadence
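
The new sigmoid.out path above picks a floating compute type (falling back to float for non-floating inputs) and applies 1 / (1 + exp(-x)) elementwise through apply_unitensor_elementwise_fn, instead of computing in double via apply_unary_map_fn. A minimal standalone sketch of that formula, assuming a float compute type and flat buffers:

#include <cmath>
#include <cstddef>

// Sketch only: the same formula as the sigmoid.out lambda, over flat float buffers.
void sigmoid_elementwise_sketch(const float* in, float* out, size_t numel) {
  for (size_t i = 0; i < numel; ++i) {
    out[i] = 1.0f / (1.0f + std::exp(-in[i]));
  }
}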