Skip to content

Commit

Permalink
Merge pull request #433 from howjmay/vmlal_high_n
Browse files Browse the repository at this point in the history
feat: Add vmlal_high_n_[s16|s32|u16|u32]
  • Loading branch information
howjmay authored Jul 23, 2024
2 parents 303ce4e + 5c72bac commit 245014e
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 12 deletions.
20 changes: 16 additions & 4 deletions neon2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -10713,13 +10713,25 @@ FORCE_INLINE uint64x2_t vmlal_n_u32(uint64x2_t a, uint32x2_t b, uint32_t c) {
return __riscv_vlmul_trunc_v_u64m2_u64m1(__riscv_vwmaccu_vx_u64m2(__riscv_vlmul_ext_v_u64m1_u64m2(a), c, b, 2));
}

// FORCE_INLINE int32x4_t vmlal_high_n_s16(int32x4_t a, int16x8_t b, int16_t c);
FORCE_INLINE int32x4_t vmlal_high_n_s16(int32x4_t a, int16x8_t b, int16_t c) {
vint16m1_t b_high = __riscv_vslidedown_vx_i16m1(b, 4, 8);
return __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmacc_vx_i32m2(__riscv_vlmul_ext_v_i32m1_i32m2(a), c, b_high, 4));
}

// FORCE_INLINE int64x2_t vmlal_high_n_s32(int64x2_t a, int32x4_t b, int32_t c);
FORCE_INLINE int64x2_t vmlal_high_n_s32(int64x2_t a, int32x4_t b, int32_t c) {
vint32m1_t b_high = __riscv_vslidedown_vx_i32m1(b, 2, 4);
return __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmacc_vx_i64m2(__riscv_vlmul_ext_v_i64m1_i64m2(a), c, b_high, 2));
}

// FORCE_INLINE uint32x4_t vmlal_high_n_u16(uint32x4_t a, uint16x8_t b, uint16_t c);
FORCE_INLINE uint32x4_t vmlal_high_n_u16(uint32x4_t a, uint16x8_t b, uint16_t c) {
vuint16m1_t b_high = __riscv_vslidedown_vx_u16m1(b, 4, 8);
return __riscv_vlmul_trunc_v_u32m2_u32m1(__riscv_vwmaccu_vx_u32m2(__riscv_vlmul_ext_v_u32m1_u32m2(a), c, b_high, 4));
}

// FORCE_INLINE uint64x2_t vmlal_high_n_u32(uint64x2_t a, uint32x4_t b, uint32_t c);
FORCE_INLINE uint64x2_t vmlal_high_n_u32(uint64x2_t a, uint32x4_t b, uint32_t c) {
vuint32m1_t b_high = __riscv_vslidedown_vx_u32m1(b, 2, 4);
return __riscv_vlmul_trunc_v_u64m2_u64m1(__riscv_vwmaccu_vx_u64m2(__riscv_vlmul_ext_v_u64m1_u64m2(a), c, b_high, 2));
}

FORCE_INLINE int32x4_t vqdmlal_n_s16(int32x4_t a, int16x4_t b, int16_t c) {
vint16m1_t c_dup = vdup_n_s16(c);
Expand Down
76 changes: 72 additions & 4 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34306,13 +34306,81 @@ result_t test_vmlal_n_u32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#endif // ENABLE_TEST_ALL
}

result_t test_vmlal_high_n_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
result_t test_vmlal_high_n_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1;
const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2;
const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3;
int32_t _d[4];
for (int i = 0; i < 4; i++) {
_d[i] = _a[i] + (int32_t)_b[i + 4] * (int32_t)_c[0];
}

int32x4_t a = vld1q_s32(_a);
int16x8_t b = vld1q_s16(_b);
int32x4_t d = vmlal_high_n_s16(a, b, _c[0]);
return validate_int32(d, _d[0], _d[1], _d[2], _d[3]);
#else
return TEST_UNIMPL;
#endif // ENABLE_TEST_ALL
}

result_t test_vmlal_high_n_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
const int64_t *_a = (int64_t *)impl.test_cases_int_pointer1;
const int32_t *_b = (int32_t *)impl.test_cases_int_pointer2;
const int32_t *_c = (int32_t *)impl.test_cases_int_pointer3;
int64_t _d[2];
for (int i = 0; i < 2; i++) {
_d[i] = _a[i] + (int64_t)_b[i + 2] * (int64_t)_c[0];
}

int64x2_t a = vld1q_s64(_a);
int32x4_t b = vld1q_s32(_b);
int64x2_t d = vmlal_high_n_s32(a, b, _c[0]);
return validate_int64(d, _d[0], _d[1]);
#else
return TEST_UNIMPL;
#endif // ENABLE_TEST_ALL
}

result_t test_vmlal_high_n_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
result_t test_vmlal_high_n_u16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
const uint32_t *_a = (uint32_t *)impl.test_cases_int_pointer1;
const uint16_t *_b = (uint16_t *)impl.test_cases_int_pointer2;
const uint16_t *_c = (uint16_t *)impl.test_cases_int_pointer3;
uint32_t _d[4];
for (int i = 0; i < 4; i++) {
_d[i] = _a[i] + (uint32_t)_b[i + 4] * (uint32_t)_c[0];
}

uint32x4_t a = vld1q_u32(_a);
uint16x8_t b = vld1q_u16(_b);
uint32x4_t d = vmlal_high_n_u16(a, b, _c[0]);
return validate_uint32(d, _d[0], _d[1], _d[2], _d[3]);
#else
return TEST_UNIMPL;
#endif // ENABLE_TEST_ALL
}

result_t test_vmlal_high_n_u16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
result_t test_vmlal_high_n_u32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
const uint64_t *_a = (uint64_t *)impl.test_cases_int_pointer1;
const uint32_t *_b = (uint32_t *)impl.test_cases_int_pointer2;
const uint32_t *_c = (uint32_t *)impl.test_cases_int_pointer3;
uint64_t _d[2];
for (int i = 0; i < 2; i++) {
_d[i] = _a[i] + (uint64_t)_b[i + 2] * (uint64_t)_c[0];
}

result_t test_vmlal_high_n_u32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) { return TEST_UNIMPL; }
uint64x2_t a = vld1q_u64(_a);
uint32x4_t b = vld1q_u32(_b);
uint64x2_t d = vmlal_high_n_u32(a, b, _c[0]);
return validate_uint64(d, _d[0], _d[1]);
#else
return TEST_UNIMPL;
#endif // ENABLE_TEST_ALL
}

result_t test_vqdmlal_n_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
Expand Down
8 changes: 4 additions & 4 deletions tests/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2221,10 +2221,10 @@
_(vmlal_n_s32) \
_(vmlal_n_u16) \
_(vmlal_n_u32) \
/*_(vmlal_high_n_s16) */ \
/*_(vmlal_high_n_s32) */ \
/*_(vmlal_high_n_u16) */ \
/*_(vmlal_high_n_u32) */ \
_(vmlal_high_n_s16) \
_(vmlal_high_n_s32) \
_(vmlal_high_n_u16) \
_(vmlal_high_n_u32) \
_(vqdmlal_n_s16) \
_(vqdmlal_n_s32) \
/*_(vqdmlal_high_n_s16) */ \
Expand Down

0 comments on commit 245014e

Please sign in to comment.