diff --git a/src/Native/src/kernels/stackvm/reference/kernel_template.h b/src/Native/src/kernels/stackvm/reference/kernel_template.h index 0b750682ff..a77d1946b3 100644 --- a/src/Native/src/kernels/stackvm/reference/kernel_template.h +++ b/src/Native/src/kernels/stackvm/reference/kernel_template.h @@ -106,28 +106,52 @@ value_t input, value_t output, kernel_context &context) { \ try_input(input_mem, input); \ auto dtype = input_tensor->dtype(); \ - try_output_like_input(out_mem, output, input_tensor); \ + try_output_like_input(output_mem, output, input_tensor); \ + try_var(typecode, to_typecode(input_tensor->dtype())); \ if (is_contiguous(input_tensor)) { \ - try_(_name##_opt_impl(input_mem, out_mem, input_tensor->shape(), \ - input_tensor->strides(), \ - output_tensor->shape(), \ - output_tensor->strides(), context)); \ + try_(UNARY_WITH_DISPTCH(_name##_opt_impl)); \ } else { \ - try_(_name##_impl(input_mem, out_mem, input_tensor->shape(), \ - input_tensor->strides(), output_tensor->shape(), \ - output_tensor->strides(), context)); \ + try_(UNARY_WITH_DISPTCH(_name##_impl)); \ } \ return ok(output); \ } -#define FLOAT_UNARY_TEMPLATE(_name, _compute) \ - FLOAT_UNARY_IMPL_TEMPLATE(_name, _compute) \ - FLOAT_UNARY_OP_TEMPLATE(_name) - #define UNARY_TEMPLATE(_name, _compute) \ UNARY_IMPL_TEMPLATE(_name, _compute) \ + UNARY_WITH_DISPTCH_OP_TEMPLATE_V2(_name##_opt_impl) \ + UNARY_WITH_DISPTCH_OP_TEMPLATE_V2(_name##_impl) \ UNARY_OP_TEMPLATE(_name) +#define UNARY_WITH_DISPTCH(_impl_func) \ + _impl_func##_disptch(typecode, input_mem, output_mem, \ + input_tensor->shape(), input_tensor->strides(), \ + output_tensor->shape(), output_tensor->strides(), \ + context) + +#define UNARY_WITH_DISPTCH_OP_TEMPLATE_V2(_impl_func) \ + result _impl_func##_disptch( \ + typecode_t type, const gsl::byte *input, gsl::byte *output, \ + gsl::span in_shape, gsl::span in_strides, \ + gsl::span out_shape, \ + gsl::span out_strides, \ + NNCASE_UNUSED kernel_context &context) noexcept { \ + TYPE_SELECT_WITH_IMPL(type, UNARY_IMPL_FUNC_WRAPPER_V2, _impl_func); \ + } + +#define UNARY_WITH_DISPTCH_OP_TEMPLATE_V2(_impl_func) \ + result _impl_func##_disptch( \ + typecode_t type, const gsl::byte *input, gsl::byte *output, \ + gsl::span in_shape, gsl::span in_strides, \ + gsl::span out_shape, \ + gsl::span out_strides, \ + NNCASE_UNUSED kernel_context &context) noexcept { \ + TYPE_SELECT_WITH_IMPL(type, UNARY_IMPL_FUNC_WRAPPER_V2, _impl_func); \ + } + +#define UNARY_IMPL_FUNC_WRAPPER_V2(_impl_func, type) \ + return _impl_func(IN_CAST(type, input), OUT_CAST(type, output), in_shape, \ + in_strides, out_shape, out_strides, context) + #define FLOAT_UNARY_WITH_MUL_IMPL_TEMPLATE(_name, _alpha_name, _compute) \ template \ result _name##_impl( \ @@ -397,10 +421,10 @@ result nncase::kernels::stackvm::_name( \ value_t input, VALUE_ARGS_EXPAND(__VA_ARGS__), value_t output, \ kernel_context &context) { \ - try_f32_input(input_mem, input); \ + try_input(input_mem, input); \ auto dtype = input_tensor->dtype(); \ READ_FLOAT_SCALAR_EXPAND(__VA_ARGS__); \ - try_f32_output(out_mem, output, input_tensor->shape()); \ + try_output_like_input(out_mem, output, input_tensor); \ try_(_name##_impl(input_mem, out_mem, input_tensor->shape(), \ input_tensor->strides(), output_tensor->shape(), \ output_tensor->strides(), \ diff --git a/tests/kernels/test_erf.cpp b/tests/kernels/test_erf.cpp index e400d41b91..2c45358083 100644 --- a/tests/kernels/test_erf.cpp +++ b/tests/kernels/test_erf.cpp @@ -47,7 +47,7 @@ class ErfTest INSTANTIATE_TEST_SUITE_P( erf, ErfTest, - testing::Combine(testing::Values(dt_float32), + testing::Combine(testing::Values(dt_float32, dt_float16), testing::Values(dims_t{1, 3, 16, 16}, dims_t{1}, dims_t{8, 8}, dims_t{1, 4, 16}, dims_t{1, 3, 24, 24}, dims_t{}))); diff --git a/tests/kernels/test_hard_sigmoid.cpp b/tests/kernels/test_hard_sigmoid.cpp index d22c597880..f7eb1ec168 100644 --- a/tests/kernels/test_hard_sigmoid.cpp +++ b/tests/kernels/test_hard_sigmoid.cpp @@ -53,7 +53,7 @@ class HardSigmoidTest INSTANTIATE_TEST_SUITE_P( hard_sigmoid, HardSigmoidTest, - testing::Combine(testing::Values(dt_float32), + testing::Combine(testing::Values(dt_float16), testing::Values(dims_t{1, 3, 16, 16}, dims_t{1}, dims_t{1, 3}, dims_t{1, 3, 16}, dims_t{}), testing::Values(1.2f, 0.8f, 0.5f, 0.6f), @@ -63,15 +63,15 @@ TEST_P(HardSigmoidTest, hard_sigmoid) { auto l_ort = runtime_tensor_2_ort_tensor(input); // expected - float alpha_ptr[] = {alpha_value}; - auto alpha = hrt::create(nncase::dt_float32, {1}, + half alpha_ptr[] = {(half)alpha_value}; + auto alpha = hrt::create(nncase::dt_float16, {1}, {reinterpret_cast(alpha_ptr), sizeof(alpha_ptr)}, true, host_runtime_tensor::pool_cpu_only) .expect("create tensor failed"); - float gamma_ptr[] = {gamma_value}; - auto gamma = hrt::create(nncase::dt_float32, {1}, + half gamma_ptr[] = {(half)gamma_value}; + auto gamma = hrt::create(nncase::dt_float16, {1}, {reinterpret_cast(gamma_ptr), sizeof(gamma_ptr)}, true, host_runtime_tensor::pool_cpu_only) diff --git a/tests/kernels/test_hard_swish.cpp b/tests/kernels/test_hard_swish.cpp index 740304c2d8..5fc73c9b6f 100644 --- a/tests/kernels/test_hard_swish.cpp +++ b/tests/kernels/test_hard_swish.cpp @@ -46,32 +46,15 @@ class HardSwishTest INSTANTIATE_TEST_SUITE_P( hard_swish, HardSwishTest, - testing::Combine(testing::Values(dt_float32), + testing::Combine(testing::Values(dt_float16), testing::Values(dims_t{1, 3, 16, 16}, dims_t{1, 2}, dims_t{1}, dims_t{16, 16}, dims_t{}))); TEST_P(HardSwishTest, hard_swish) { auto l_ort = runtime_tensor_2_ort_tensor(input); - auto alpha_value = 1.0f / 6.0f; - auto beta_value = 0.5f; // expected - float alpha_ptr[] = {alpha_value}; - auto alpha = hrt::create(nncase::dt_float32, {1}, - {reinterpret_cast(alpha_ptr), - sizeof(alpha_ptr)}, - true, host_runtime_tensor::pool_cpu_only) - .expect("create tensor failed"); - - float beta_ptr[] = {beta_value}; - auto beta = - hrt::create(nncase::dt_float32, {1}, - {reinterpret_cast(beta_ptr), sizeof(beta_ptr)}, - true, host_runtime_tensor::pool_cpu_only) - .expect("create tensor failed"); - - auto output_ort = - ortki_Mul(l_ort, ortki_HardSigmoid(l_ort, alpha_value, beta_value)); + auto output_ort = ortki_HardSwish(l_ort); size_t size = 0; void *ptr_ort = tensor_buffer(output_ort, &size); dims_t shape(tensor_rank(output_ort)); @@ -90,6 +73,8 @@ TEST_P(HardSwishTest, hard_swish) { cosine_similarity_tensor(expected, actual); if (!result) { + std::cout << "input "; + print_runtime_tensor(input); std::cout << "actual "; print_runtime_tensor(actual); std::cout << "expected "; diff --git a/tests/kernels/test_relu.cpp b/tests/kernels/test_relu.cpp index 0a8a8dcdeb..f16ce4e32f 100644 --- a/tests/kernels/test_relu.cpp +++ b/tests/kernels/test_relu.cpp @@ -47,10 +47,10 @@ class ReluTest INSTANTIATE_TEST_SUITE_P( Relu, ReluTest, - testing::Combine(testing::Values(dt_float32, dt_int32), - testing::Values(dims_t{1, 3, 16, 16}, dims_t{1}, - dims_t{8, 8}, dims_t{1, 4, 16}, - dims_t{1, 3, 24, 24}, dims_t{}))); + testing::Combine( + testing::Values(dt_float32, dt_int32, dt_float16, dt_float64), + testing::Values(dims_t{1, 3, 16, 16}, dims_t{1}, dims_t{8, 8}, + dims_t{1, 4, 16}, dims_t{1, 3, 24, 24}, dims_t{}))); TEST_P(ReluTest, Relu) { auto l_ort = runtime_tensor_2_ort_tensor(input); diff --git a/tests/kernels/test_sigmoid.cpp b/tests/kernels/test_sigmoid.cpp index a2ec67aaf0..1f5d4052f7 100644 --- a/tests/kernels/test_sigmoid.cpp +++ b/tests/kernels/test_sigmoid.cpp @@ -47,7 +47,7 @@ class SigmoidTest INSTANTIATE_TEST_SUITE_P( Sigmoid, SigmoidTest, - testing::Combine(testing::Values(dt_float32), + testing::Combine(testing::Values(dt_float16, dt_float64, dt_float32), testing::Values(dims_t{1}, dims_t{1, 3}, dims_t{1, 3, 16, 16}, dims_t{1, 3, 16}, dims_t{16, 16}, dims_t{}))); diff --git a/tests/kernels/test_softplus.cpp b/tests/kernels/test_softplus.cpp index 6f43ebf672..77ecaa36b7 100644 --- a/tests/kernels/test_softplus.cpp +++ b/tests/kernels/test_softplus.cpp @@ -47,7 +47,7 @@ class SoftplusTest INSTANTIATE_TEST_SUITE_P( Softplus, SoftplusTest, - testing::Combine(testing::Values(dt_float32), + testing::Combine(testing::Values(dt_float32, dt_float16, dt_float64), testing::Values(dims_t{1}, dims_t{1, 3}, dims_t{1, 3, 16, 16}, dims_t{1, 3, 16}, dims_t{}))); diff --git a/tests/kernels/test_softsign.cpp b/tests/kernels/test_softsign.cpp index 9c41331a7c..5160358e79 100644 --- a/tests/kernels/test_softsign.cpp +++ b/tests/kernels/test_softsign.cpp @@ -47,7 +47,7 @@ class SoftsignTest INSTANTIATE_TEST_SUITE_P( Softsign, SoftsignTest, - testing::Combine(testing::Values(dt_float32), + testing::Combine(testing::Values(dt_float32, dt_float16, dt_float64), testing::Values(dims_t{1}, dims_t{1, 3}, dims_t{1, 3, 16, 16}, dims_t{1, 3, 16}, dims_t{}))); diff --git a/tests/kernels/test_swish.cpp b/tests/kernels/test_swish.cpp index 10f9a4bee4..ee237934e1 100644 --- a/tests/kernels/test_swish.cpp +++ b/tests/kernels/test_swish.cpp @@ -47,7 +47,7 @@ class SwishTest INSTANTIATE_TEST_SUITE_P( Swish, SwishTest, - testing::Combine(testing::Values(dt_float32), + testing::Combine(testing::Values(dt_float32, dt_float16, dt_float64), testing::Values(dims_t{1}, dims_t{1, 3}, dims_t{1, 3, 16, 16}, dims_t{1, 3, 16}, dims_t{})));