Skip to content

Commit

Permalink
GNNE-1904 clamp f16&&bf16 dt input support (#1035)
Browse files Browse the repository at this point in the history
* fix clamp.cpp

* fix

* fix

* Support half datatype in clamp and matmul test

* fix float

* Apply code-format changes

* add new kernel test

* fix BucketPad's test

* Apply code-format changes

* Refactor activate unary with mul template macro

* Apply code-format changes

* Fix bool type instance error

* Apply code-format changes

* support bf16

* fix

* merge conflict

* bf16 support

* Apply code-format changes

---------

Co-authored-by: hejunchao <[email protected]>
Co-authored-by: HeJunchao <[email protected]>
Co-authored-by: lerenhua <[email protected]>
Co-authored-by: Hejunchao6 <[email protected]>
Co-authored-by: lerenhua <[email protected]>
Co-authored-by: HeJunchao100813 <[email protected]>
  • Loading branch information
7 people authored Aug 25, 2023
1 parent 8991ced commit f5b79ef
Show file tree
Hide file tree
Showing 63 changed files with 1,345 additions and 1,075 deletions.
6 changes: 6 additions & 0 deletions src/Native/include/nncase/runtime/bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ struct bfloat16 {
explicit bfloat16(const T &val) noexcept
: bfloat16(static_cast<float>(val)) {}

// Implicit (non-explicit) converting constructor taking an rvalue int.
// Because the parameter is `int &&`, it binds only to temporaries/literals,
// not to named int variables. The value is cast to float and construction is
// delegated to whichever constructor is selected for a float argument.
// NOTE(review): presumably added so integer literals can be used where a
// bfloat16 is expected in templated kernels — confirm against callers.
bfloat16(int &&val) noexcept : bfloat16(static_cast<float>(val)) {}

constexpr bfloat16(from_raw_t, uint16_t value) noexcept : value_(value) {}

operator float() const noexcept {
Expand Down Expand Up @@ -153,6 +155,10 @@ struct bfloat16 {
return (value_ & 0x7FFF) == ZERO_VALUE;
}

void operator=(const float &v) noexcept {
value_ = (round_to_bfloat16(v).value_);
}

private:
uint16_t value_;
};
Expand Down
6 changes: 6 additions & 0 deletions src/Native/include/nncase/runtime/half.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ struct half {
std::is_floating_point<T>::value>>
explicit half(const T &val) noexcept : half(static_cast<float>(val)) {}

// Implicit (non-explicit) converting constructor taking an rvalue int.
// Because the parameter is `int &&`, it binds only to temporaries/literals,
// not to named int variables. The value is cast to float and construction is
// delegated to whichever constructor is selected for a float argument.
// NOTE(review): presumably added so integer literals can be used where a
// half is expected in templated kernels — confirm against callers.
half(int &&val) noexcept : half(static_cast<float>(val)) {}

constexpr half(fp16_from_raw_t, uint16_t value) noexcept : value_(value) {}

operator float() const noexcept {
Expand Down Expand Up @@ -156,6 +158,10 @@ struct half {
return (value_ & 0x7FFF) == ZERO_VALUE;
}

void operator=(const float &v) noexcept {
value_ = (round_to_half(v).value_);
}

private:
uint16_t value_;
};
Expand Down
4 changes: 4 additions & 0 deletions src/Native/include/nncase/runtime/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,10 @@ inline bool is_contiguous(tensor tensor) {
switch (_typecode) { \
case dt_float32: \
_impl(float); \
case dt_float16: \
_impl(half); \
case dt_bfloat16: \
_impl(bfloat16); \
case dt_int8: \
_impl(int8_t); \
case dt_int16: \
Expand Down
5 changes: 4 additions & 1 deletion src/Native/src/kernels/stackvm/optimized/resize_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,10 @@ inline result<void> resize_bilinear_impl(
auto a3 = (in_y - in_y0) * (in_x - in_x0);

*output_ptr = bfloat16::round_to_bfloat16(
v0 * a0 + v1 * a1 + v2 * a2 + v3 * a3);
static_cast<float>(v0) * a0 +
static_cast<float>(v1) * a1 +
static_cast<float>(v2) * a2 +
static_cast<float>(v3) * a3);
++output_ptr;
}
}
Expand Down
52 changes: 26 additions & 26 deletions src/Native/src/kernels/stackvm/reference/activation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,30 +29,30 @@ using namespace nncase::runtime::stackvm;
using namespace nncase::kernels;
using namespace nncase::kernels::stackvm;

// Activation kernel instantiations. Each macro invocation passes an activation
// name together with an expression written in terms of `x`; the *_WITH_MUL_*
// forms additionally name one parameter (alpha) before the expression, and the
// *ACTIVATION_TEMPLATE* forms list their parameter names after the expression.
// The macro definitions are not visible here.
// NOTE(review): both FLOAT_* and UNARY_*/…_V2 invocations for the same
// activation names appear in this listing; this looks like the before/after
// forms of one change — confirm that only one set is meant to be compiled.

// FLOAT_* forms: literals are cast to float.
FLOAT_UNARY_TEMPLATE(relu, std::max((float)0, x))
FLOAT_UNARY_TEMPLATE(softsign, x / (1 + std::abs(x)))
FLOAT_UNARY_TEMPLATE(softplus, std::log(1 + std::exp(x)))
FLOAT_UNARY_TEMPLATE(sigmoid, 1 / (1 + exp(-x)))
FLOAT_UNARY_TEMPLATE(swish, x / (1 + exp(-x)))
FLOAT_UNARY_TEMPLATE(hard_swish,
x *std::max(0.f, std::min((float)1.f,
(float)(1.f / 6 * x + 0.5))))
FLOAT_UNARY_TEMPLATE(erf, erff(x)) // for k510 toolchain
FLOAT_UNARY_WITH_MUL_TEMPLATE(elu, alpha, x < 0 ? alpha * (exp(x) - 1) : x)
// UNARY_* forms of the same activations: the expressions are unchanged except
// that the min/max operands are cast to double instead of float.
UNARY_TEMPLATE(relu, std::max((double)0, x))
UNARY_TEMPLATE(softsign, x / (1 + std::abs(x)))
UNARY_TEMPLATE(softplus, std::log(1 + std::exp(x)))
UNARY_TEMPLATE(sigmoid, 1 / (1 + exp(-x)))
UNARY_TEMPLATE(swish, x / (1 + exp(-x)))
UNARY_TEMPLATE(hard_swish,
x *std::max((double)0.f,
std::min((double)1.f, (double)(1.f / 6 * x + 0.5))))
UNARY_TEMPLATE(erf, erff(x)) // for k510 toolchain
UNARY_WITH_MUL_TEMPLATE_V2(elu, alpha, x < 0 ? alpha * (exp(x) - 1) : x)
// FLOAT_UNARY_WITH_MUL_TEMPLATE(prelu, slope, x < 0 ? slope * x : x)
// FLOAT_* parameterised activations (celu, leaky_relu, gelu take alpha;
// selu takes alpha and gamma; hard_sigmoid takes alpha and beta).
FLOAT_UNARY_WITH_MUL_TEMPLATE(
celu, alpha,
std::max((float)0, x) +
std::min((float)0, (float)(alpha *(exp(x / alpha) - 1))))
FLOAT_UNARY_WITH_MUL_TEMPLATE(leaky_relu, alpha, x < 0 ? alpha * x : x)
FLOAT_UNARY_WITH_MUL_TEMPLATE(gelu, alpha,
0.5f * (alpha * x) *
(1.f + erff(alpha * x / sqrtf(2.f))))
FLOAT_ACTIVATION_TEMPLATE(selu,
x <= 0 ? gamma * (alpha * std::exp(x) - alpha)
: x * gamma,
alpha, gamma)
FLOAT_ACTIVATION_TEMPLATE(hard_sigmoid,
std::max((float)0,
std::min((float)1, x *alpha + beta)),
alpha, beta)
// …_V2 forms of the parameterised activations. Note that hard_sigmoid's second
// parameter is named gamma here (x * alpha + gamma), where the FLOAT_* form
// above names it beta.
UNARY_WITH_MUL_TEMPLATE_V2(celu, alpha,
std::max((double)0, x) +
std::min((double)0,
(double)(alpha *(exp(x / alpha) - 1))))
UNARY_WITH_MUL_TEMPLATE_V2(leaky_relu, alpha, x < 0 ? alpha * x : x)
UNARY_WITH_MUL_TEMPLATE_V2(gelu, alpha,
0.5f * (alpha * x) *
(1.f + erff(alpha * x / sqrtf(2.f))))
ACTIVATION_TEMPLATE_V2(selu,
x <= 0 ? gamma * (alpha * std::exp(x) - alpha)
: x * gamma,
alpha, gamma)
ACTIVATION_TEMPLATE_V2(hard_sigmoid,
std::max((double)0,
std::min((double)1, x *alpha + gamma)),
alpha, gamma)
4 changes: 3 additions & 1 deletion src/Native/src/kernels/stackvm/reference/clamp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ result<void> clamp_impl(const T *input, T min, T max, T *output,
NNCASE_UNUSED kernel_context &context) {
return apply(in_shape, [&](gsl::span<const size_t> index) -> result<void> {
const auto v = input[offset(index, in_strides)];
output[offset(index, out_strides)] = std::min(std::max(v, min), max);
output[offset(index, out_strides)] = static_cast<T>(
std::min(std::max(static_cast<float>(v), static_cast<float>(min)),
static_cast<float>(max)));
return ok();
});
}
Expand Down
Loading

0 comments on commit f5b79ef

Please sign in to comment.