Skip to content

Commit

Permalink
feat: Add vqdmlsl_high_n_[s16|s32]
Browse files Browse the repository at this point in the history
  • Loading branch information
howjmay committed Jul 23, 2024
1 parent 269ae13 commit 07dc1c8
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 12 deletions.
16 changes: 14 additions & 2 deletions neon2rvv.h
Original file line number Diff line number Diff line change
Expand Up @@ -10845,9 +10845,21 @@ FORCE_INLINE int64x2_t vqdmlsl_n_s32(int64x2_t a, int32x2_t b, int32_t c) {
return __riscv_vsub_vv_i64m1(a, bc_mulx2, 2);
}

/**
 * Signed saturating doubling multiply-subtract long, upper half, by scalar.
 *
 * For i in 0..3 computes
 *   r[i] = sat32(a[i] - sat32(2 * (int32_t)b[i + 4] * (int32_t)c))
 *
 * @param a  four 32-bit accumulator lanes
 * @param b  eight 16-bit lanes; only lanes 4..7 are used
 * @param c  16-bit scalar multiplied against every selected lane of b
 * @return   the four saturated 32-bit results
 */
FORCE_INLINE int32x4_t vqdmlsl_high_n_s16(int32x4_t a, int16x8_t b, int16_t c) {
  vint16m1_t c_dup = vdup_n_s16(c);
  // Bring lanes 4..7 of b down to lanes 0..3 so they line up with a.
  vint16m1_t b_high = __riscv_vslidedown_vx_i16m1(b, 4, 8);
  // Widening multiply: a 16x16 product always fits in 32 bits.
  vint32m1_t bc_mul = __riscv_vlmul_trunc_v_i32m2_i32m1(__riscv_vwmul_vv_i32m2(b_high, c_dup, 4));
  // Double with a saturating self-add rather than a shift: (-32768 * -32768) * 2
  // equals 2^31 and has to clamp to INT32_MAX instead of wrapping to INT32_MIN.
  vint32m1_t bc_mulx2 = __riscv_vsadd_vv_i32m1(bc_mul, bc_mul, 4);
  // The accumulate step saturates as well.
  return __riscv_vssub_vv_i32m1(a, bc_mulx2, 4);
}

/**
 * Signed saturating doubling multiply-subtract long, upper half, by scalar.
 *
 * For i in 0..1 computes
 *   r[i] = sat64(a[i] - sat64(2 * (int64_t)b[i + 2] * (int64_t)c))
 *
 * @param a  two 64-bit accumulator lanes
 * @param b  four 32-bit lanes; only lanes 2..3 are used
 * @param c  32-bit scalar multiplied against every selected lane of b
 * @return   the two saturated 64-bit results
 */
FORCE_INLINE int64x2_t vqdmlsl_high_n_s32(int64x2_t a, int32x4_t b, int32_t c) {
  vint32m1_t c_dup = vdup_n_s32(c);
  // Bring lanes 2..3 of b down to lanes 0..1 so they line up with a.
  vint32m1_t b_high = __riscv_vslidedown_vx_i32m1(b, 2, 4);
  // Widening multiply: a 32x32 product always fits in 64 bits.
  vint64m1_t bc_mul = __riscv_vlmul_trunc_v_i64m2_i64m1(__riscv_vwmul_vv_i64m2(b_high, c_dup, 2));
  // Double with a saturating self-add rather than a shift: (INT32_MIN * INT32_MIN) * 2
  // equals 2^63 and has to clamp to INT64_MAX instead of wrapping to INT64_MIN.
  vint64m1_t bc_mulx2 = __riscv_vsadd_vv_i64m1(bc_mul, bc_mul, 2);
  // The accumulate step saturates as well.
  return __riscv_vssub_vv_i64m1(a, bc_mulx2, 2);
}

FORCE_INLINE int8x8_t vext_s8(int8x8_t a, int8x8_t b, const int c) {
vint8m1_t a_slidedown = __riscv_vslidedown_vx_i8m1(a, c, 8);
Expand Down
46 changes: 40 additions & 6 deletions tests/impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34421,13 +34421,13 @@ result_t test_vqdmlal_n_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
}

result_t test_vqdmlal_high_n_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
#ifdef ENABLE_TEST_ALL
const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1;
const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2;
const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3;
int32_t _d[4];
for (int i = 0; i < 4; i++) {
_d[i] = sat_add(_a[i], sat_dmull(_b[i+4], _c[0]));
_d[i] = sat_add(_a[i], sat_dmull(_b[i + 4], _c[0]));
}

int32x4_t a = vld1q_s32(_a);
Expand All @@ -34440,13 +34440,13 @@ result_t test_vqdmlal_high_n_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter)
}

result_t test_vqdmlal_high_n_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
#ifdef ENABLE_TEST_ALL
const int64_t *_a = (int64_t *)impl.test_cases_int_pointer1;
const int32_t *_b = (int32_t *)impl.test_cases_int_pointer2;
const int32_t *_c = (int32_t *)impl.test_cases_int_pointer3;
int64_t _d[2];
for (int i = 0; i < 2; i++) {
_d[i] = sat_add(_a[i], sat_dmull(_b[i+2], _c[0]));
_d[i] = sat_add(_a[i], sat_dmull(_b[i + 2], _c[0]));
}

int64x2_t a = vld1q_s64(_a);
Expand Down Expand Up @@ -34838,9 +34838,43 @@ result_t test_vqdmlsl_n_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#endif // ENABLE_TEST_ALL
}

// Checks vqdmlsl_high_n_s16 against a scalar model.
//
// Expected lane i (i = 0..3) is sat_sub(_a[i], sat_dmull(_b[i + 4], _c[0])):
// only the upper four 16-bit lanes of b take part, and the scalar operand is
// the first element of the third test buffer. `iter` is not used here.
// Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined.
result_t test_vqdmlsl_high_n_s16(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
  const int32_t *_a = (int32_t *)impl.test_cases_int_pointer1;
  const int16_t *_b = (int16_t *)impl.test_cases_int_pointer2;
  const int16_t *_c = (int16_t *)impl.test_cases_int_pointer3;
  int32_t _d[4];
  for (int i = 0; i < 4; i++) {
    // Upper half of b starts at lane 4.
    _d[i] = sat_sub(_a[i], sat_dmull(_b[i + 4], _c[0]));
  }

  int32x4_t a = vld1q_s32(_a);
  int16x8_t b = vld1q_s16(_b);
  int32x4_t d = vqdmlsl_high_n_s16(a, b, _c[0]);
  return validate_int32(d, _d[0], _d[1], _d[2], _d[3]);
#else
  return TEST_UNIMPL;
#endif  // ENABLE_TEST_ALL
}

// Checks vqdmlsl_high_n_s32 against a scalar model.
//
// Expected lane i (i = 0..1) is sat_sub(_a[i], sat_dmull(_b[i + 2], _c[0])):
// only the upper two 32-bit lanes of b take part, and the scalar operand is
// the first element of the third test buffer. `iter` is not used here.
// Returns TEST_UNIMPL when ENABLE_TEST_ALL is not defined.
result_t test_vqdmlsl_high_n_s32(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
  const int64_t *_a = (int64_t *)impl.test_cases_int_pointer1;
  const int32_t *_b = (int32_t *)impl.test_cases_int_pointer2;
  const int32_t *_c = (int32_t *)impl.test_cases_int_pointer3;
  int64_t _d[2];
  for (int i = 0; i < 2; i++) {
    // Upper half of b starts at lane 2.
    _d[i] = sat_sub(_a[i], sat_dmull(_b[i + 2], _c[0]));
  }

  int64x2_t a = vld1q_s64(_a);
  int32x4_t b = vld1q_s32(_b);
  int64x2_t d = vqdmlsl_high_n_s32(a, b, _c[0]);
  return validate_int64(d, _d[0], _d[1]);
#else
  return TEST_UNIMPL;
#endif  // ENABLE_TEST_ALL
}

result_t test_vext_s8(const NEON2RVV_TEST_IMPL &impl, uint32_t iter) {
#ifdef ENABLE_TEST_ALL
Expand Down
8 changes: 4 additions & 4 deletions tests/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2227,8 +2227,8 @@
_(vmlal_high_n_u32) \
_(vqdmlal_n_s16) \
_(vqdmlal_n_s32) \
_(vqdmlal_high_n_s16) \
_(vqdmlal_high_n_s32) \
_(vqdmlal_high_n_s16) \
_(vqdmlal_high_n_s32) \
_(vmls_n_s16) \
_(vmls_n_s32) \
_(vmls_n_f32) \
Expand All @@ -2249,8 +2249,8 @@
_(vmlsl_high_n_u32) \
_(vqdmlsl_n_s16) \
_(vqdmlsl_n_s32) \
/*_(vqdmlsl_high_n_s16) */ \
/*_(vqdmlsl_high_n_s32) */ \
_(vqdmlsl_high_n_s16) \
_(vqdmlsl_high_n_s32) \
/*_(vext_p64) */ \
_(vext_s8) \
_(vext_s16) \
Expand Down

0 comments on commit 07dc1c8

Please sign in to comment.