Skip to content

Commit

Permalink
bf16 support
Browse files Browse the repository at this point in the history
  • Loading branch information
hejunchao committed Aug 24, 2023
1 parent bc31b26 commit 0d6125f
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 5 deletions.
2 changes: 2 additions & 0 deletions src/Native/include/nncase/runtime/bfloat16.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ struct bfloat16 {
explicit bfloat16(const T &val) noexcept
: bfloat16(static_cast<float>(val)) {}

bfloat16(int &&val) noexcept : bfloat16(static_cast<float>(val)) {}

constexpr bfloat16(from_raw_t, uint16_t value) noexcept : value_(value) {}

operator float() const noexcept {
Expand Down
2 changes: 2 additions & 0 deletions src/Native/include/nncase/runtime/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,8 @@ inline bool is_contiguous(tensor tensor) {
_impl(float); \
case dt_float16: \
_impl(half); \
case dt_bfloat16: \
_impl(bfloat16); \
case dt_int8: \
_impl(int8_t); \
case dt_int16: \
Expand Down
13 changes: 8 additions & 5 deletions src/Native/src/kernels/stackvm/optimized/resize_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ result<void> resize_bilinear_impl(
auto *begin_output_ptr = output + batch * in_shape[1] * out_w * out_h;
#ifdef NNCASE_OPENMP
#pragma omp parallel for num_threads( \
kernels::default_kernel_context().num_threads)
kernels::default_kernel_context().num_threads)
#endif
for (int oc = 0; oc < in_shape[1]; oc++) {
auto in_c = in_batch + (size_t)oc * in_img_size;
Expand Down Expand Up @@ -103,7 +103,7 @@ result<void> resize_nearest_neighbor_impl(
auto *begin_output_ptr = output + batch * in_shape[1] * out_image_size;
#ifdef NNCASE_OPENMP
#pragma omp parallel for num_threads( \
kernels::default_kernel_context().num_threads)
kernels::default_kernel_context().num_threads)
#endif
for (int oc = 0; oc < in_shape[1]; oc++) {
auto *input_ptr = begin_input_ptr + oc * in_image_size;
Expand Down Expand Up @@ -148,7 +148,7 @@ inline result<void> gnne_resize_nearest_neighbor(
auto *begin_output_ptr = output + batch * in_shape[1] * out_image_size;
#ifdef NNCASE_OPENMP
#pragma omp parallel for num_threads( \
kernels::default_kernel_context().num_threads)
kernels::default_kernel_context().num_threads)
#endif
for (int oc = 0; oc < in_shape[1]; oc++) {
auto *input_ptr = begin_input_ptr + oc * in_image_size;
Expand Down Expand Up @@ -195,7 +195,7 @@ inline result<void> resize_bilinear_impl(
auto *begin_output_ptr = output + batch * in_shape[1] * out_w * out_h;
#ifdef NNCASE_OPENMP
#pragma omp parallel for num_threads( \
kernels::default_kernel_context().num_threads)
kernels::default_kernel_context().num_threads)
#endif
for (int oc = 0; oc < in_shape[1]; oc++) {
auto in_c = in_batch + (size_t)oc * in_img_size;
Expand All @@ -221,7 +221,10 @@ inline result<void> resize_bilinear_impl(
auto a3 = (in_y - in_y0) * (in_x - in_x0);

*output_ptr = bfloat16::round_to_bfloat16(
v0 * a0 + v1 * a1 + v2 * a2 + v3 * a3);
static_cast<float>(v0) * a0 +
static_cast<float>(v1) * a1 +
static_cast<float>(v2) * a2 +
static_cast<float>(v3) * a3);
++output_ptr;
}
}
Expand Down

0 comments on commit 0d6125f

Please sign in to comment.