diff --git a/neon2rvv.h b/neon2rvv.h index 271f33cc..e6d30cce 100644 --- a/neon2rvv.h +++ b/neon2rvv.h @@ -10713,13 +10713,25 @@ FORCE_INLINE uint64x2_t vmlal_n_u32(uint64x2_t a, uint32x2_t b, uint32_t c) { return __riscv_vlmul_trunc_v_u64m2_u64m1(__riscv_vwmaccu_vx_u64m2(__riscv_vlmul_ext_v_u64m1_u64m2(a), c, b, 2)); } -// FORCE_INLINE int32x4_t vmlal_high_n_s16(int32x4_t a, int16x8_t b, int16_t c); +FORCE_INLINE int32x4_t vmlal_high_n_s16(int32x4_t a, int16x8_t b, int16_t c) { + vint16m1_t b_high = __riscv_vslidedown_vx_i16m1(b, 4, 8); + return __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmacc_vx_i32m2(__riscv_vlmul_ext_v_i32m1_i32m2(a), c, b_high, 4)); +} -// FORCE_INLINE int64x2_t vmlal_high_n_s32(int64x2_t a, int32x4_t b, int32_t c); +FORCE_INLINE int64x2_t vmlal_high_n_s32(int64x2_t a, int32x4_t b, int32_t c) { + vint32m1_t b_high = __riscv_vslidedown_vx_i32m1(b, 2, 4); + return __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmacc_vx_i64m2(__riscv_vlmul_ext_v_i64m1_i64m2(a), c, b_high, 2)); +} -// FORCE_INLINE uint32x4_t vmlal_high_n_u16(uint32x4_t a, uint16x8_t b, uint16_t c); +FORCE_INLINE uint32x4_t vmlal_high_n_u16(uint32x4_t a, uint16x8_t b, uint16_t c) { + vuint16m1_t b_high = __riscv_vslidedown_vx_u16m1(b, 4, 8); + return __riscv_vlmul_trunc_v_u32m2_u32m1(__riscv_vwmaccu_vx_u32m2(__riscv_vlmul_ext_v_u32m1_u32m2(a), c, b_high, 4)); +} -// FORCE_INLINE uint64x2_t vmlal_high_n_u32(uint64x2_t a, uint32x4_t b, uint32_t c); +FORCE_INLINE uint64x2_t vmlal_high_n_u32(uint64x2_t a, uint32x4_t b, uint32_t c) { + vuint32m1_t b_high = __riscv_vslidedown_vx_u32m1(b, 2, 4); + return __riscv_vlmul_trunc_v_u64m2_u64m1(__riscv_vwmaccu_vx_u64m2(__riscv_vlmul_ext_v_u64m1_u64m2(a), c, b_high, 2)); +} FORCE_INLINE int32x4_t vqdmlal_n_s16(int32x4_t a, int16x4_t b, int16_t c) { vint16m1_t c_dup = vdup_n_s16(c); diff --git a/tests/impl.cpp b/tests/impl.cpp index dc364415..8b1acbdb 100644 --- a/tests/impl.cpp +++ b/tests/impl.cpp @@ -34306,13 +34306,81 @@ result_t test_vmlal_n_u32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { #endif // ENABLE_TEST_ALL } -result_t test_vmlal_high_n_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } +result_t test_vmlal_high_n_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1; + const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2; + const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3; + int32_t _d[4]; + for (int i = 0; i < 4; i++) { + _d[i] = _a[i] + (int32_t)_b[i + 4] * (int32_t)_c[0]; + } + + int32x4_t a = vld1q_s32(_a); + int16x8_t b = vld1q_s16(_b); + int32x4_t d = vmlal_high_n_s16(a, b, _c[0]); + return validate_int32(d, _d[0], _d[1], _d[2], _d[3]); +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} + +result_t test_vmlal_high_n_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const int64_t *_a = (int64_t *)impl.test_cases_int_pointer1; + const int32_t *_b = (int32_t *)impl.test_cases_int_pointer2; + const int32_t *_c = (int32_t *)impl.test_cases_int_pointer3; + int64_t _d[2]; + for (int i = 0; i < 2; i++) { + _d[i] = _a[i] + (int64_t)_b[i + 2] * (int64_t)_c[0]; + } + + int64x2_t a = vld1q_s64(_a); + int32x4_t b = vld1q_s32(_b); + int64x2_t d = vmlal_high_n_s32(a, b, _c[0]); + return validate_int64(d, _d[0], _d[1]); +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} -result_t test_vmlal_high_n_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } +result_t test_vmlal_high_n_u16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const uint32_t *_a = (uint32_t *)impl.test_cases_int_pointer1; + const uint16_t *_b = (uint16_t *)impl.test_cases_int_pointer2; + const uint16_t *_c = (uint16_t *)impl.test_cases_int_pointer3; + uint32_t _d[4]; + for (int i = 0; i < 4; i++) { + _d[i] = _a[i] + (uint32_t)_b[i + 4] * (uint32_t)_c[0]; + } + + uint32x4_t a = vld1q_u32(_a); + uint16x8_t b = vld1q_u16(_b); + uint32x4_t d = vmlal_high_n_u16(a, b, _c[0]); + return validate_uint32(d, _d[0], _d[1], _d[2], _d[3]); +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} -result_t test_vmlal_high_n_u16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } +result_t test_vmlal_high_n_u32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { +#ifdef ENABLE_TEST_ALL + const uint64_t *_a = (uint64_t *)impl.test_cases_int_pointer1; + const uint32_t *_b = (uint32_t *)impl.test_cases_int_pointer2; + const uint32_t *_c = (uint32_t *)impl.test_cases_int_pointer3; + uint64_t _d[2]; + for (int i = 0; i < 2; i++) { + _d[i] = _a[i] + (uint64_t)_b[i + 2] * (uint64_t)_c[0]; + } -result_t test_vmlal_high_n_u32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; } + uint64x2_t a = vld1q_u64(_a); + uint32x4_t b = vld1q_u32(_b); + uint64x2_t d = vmlal_high_n_u32(a, b, _c[0]); + return validate_uint64(d, _d[0], _d[1]); +#else + return TEST_UNIMPL; +#endif // ENABLE_TEST_ALL +} result_t test_vqdmlal_n_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { #ifdef ENABLE_TEST_ALL diff --git a/tests/impl.h b/tests/impl.h index cebd996f..ca10ed8a 100644 --- a/tests/impl.h +++ b/tests/impl.h @@ -2221,10 +2221,10 @@ _(vmlal_n_s32) \ _(vmlal_n_u16) \ _(vmlal_n_u32) \ - /*_(vmlal_high_n_s16) */ \ - /*_(vmlal_high_n_s32) */ \ - /*_(vmlal_high_n_u16) */ \ - /*_(vmlal_high_n_u32) */ \ + _(vmlal_high_n_s16) \ + _(vmlal_high_n_s32) \ + _(vmlal_high_n_u16) \ + _(vmlal_high_n_u32) \ _(vqdmlal_n_s16) \ _(vqdmlal_n_s32) \ /*_(vqdmlal_high_n_s16) */ \