Skip to content

Commit

Permalink
Fix qdmlal instructions
Browse files Browse the repository at this point in the history
qdmlal instructions were implemented without saturation.
This has been fixed by utilising existing SIMDe saturating mult and add instructions.
Unit tests have been updated to test for all possible saturation cases.

- Fix qdmlal, qdmlal_n, qdmlal_lane,  qdmlal_high, qdmlal_high_n and qdmlal_high_lane
- Update unit tests for qdmlal, qdmlal_n, qdmlal_lane, qdmlal_high, qdmlal_high_n, qdmala_high_lane

Change-Id: I8d0d8cfba3f8d5203f2028efbe74b00c51485c61
  • Loading branch information
Ryo-not-rio committed Jul 15, 2024
1 parent 6f52a1d commit 5af8523
Show file tree
Hide file tree
Showing 10 changed files with 290 additions and 170 deletions.
19 changes: 6 additions & 13 deletions simde/arm/neon/qdmlal.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,9 @@
#if !defined(SIMDE_ARM_NEON_QDMLAL_H)
#define SIMDE_ARM_NEON_QDMLAL_H

#include "add.h"
#include "mul.h"
#include "mul_n.h"
#include "movl.h"
#include "qadd.h"
#include "types.h"
#include "qadd.h"
#include "qdmull.h"

HEDLEY_DIAGNOSTIC_PUSH
SIMDE_DISABLE_UNWANTED_DIAGNOSTICS
Expand All @@ -44,7 +41,7 @@ simde_vqdmlalh_s16(int32_t a, int16_t b, int16_t c) {
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
return vqdmlalh_s16(a, b, c);
#else
return HEDLEY_STATIC_CAST(int32_t, b) * HEDLEY_STATIC_CAST(int32_t, c) * 2 + a;
return simde_vqadds_s32(a, simde_vqdmullh_s16(b, c));
#endif
}
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
Expand All @@ -58,7 +55,7 @@ simde_vqdmlals_s32(int64_t a, int32_t b, int32_t c) {
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
return vqdmlals_s32(a, b, c);
#else
return HEDLEY_STATIC_CAST(int64_t, b) * HEDLEY_STATIC_CAST(int64_t, c) * 2 + a;
return simde_vqaddd_s64(a, simde_vqdmulls_s32(b, c));
#endif
}
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
Expand All @@ -72,8 +69,7 @@ simde_vqdmlal_s16(simde_int32x4_t a, simde_int16x4_t b, simde_int16x4_t c) {
#if defined(SIMDE_ARM_NEON_A32V7_NATIVE)
return vqdmlal_s16(a, b, c);
#else
simde_int32x4_t temp = simde_vmulq_s32(simde_vmovl_s16(b), simde_vmovl_s16(c));
return simde_vqaddq_s32(simde_vqaddq_s32(temp, temp), a);
return simde_vqaddq_s32(simde_vqdmull_s16(b, c), a);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V7_ENABLE_NATIVE_ALIASES)
Expand All @@ -87,10 +83,7 @@ simde_vqdmlal_s32(simde_int64x2_t a, simde_int32x2_t b, simde_int32x2_t c) {
#if defined(SIMDE_ARM_NEON_A32V7_NATIVE)
return vqdmlal_s32(a, b, c);
#else
simde_int64x2_t r = simde_x_vmulq_s64(
simde_vmovl_s32(b),
simde_vmovl_s32(c));
return simde_vqaddq_s64(a, simde_vqaddq_s64(r, r));
return simde_vqaddq_s64(simde_vqdmull_s32(b, c), a);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V7_ENABLE_NATIVE_ALIASES)
Expand Down
22 changes: 4 additions & 18 deletions simde/arm/neon/qdmlal_high.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,9 @@
#if !defined(SIMDE_ARM_NEON_QDMLAL_HIGH_H)
#define SIMDE_ARM_NEON_QDMLAL_HIGH_H

#include "movl_high.h"
#include "mla.h"
#include "mul_n.h"
#include "types.h"
#include "qadd.h"
#include "qdmull_high.h"

HEDLEY_DIAGNOSTIC_PUSH
SIMDE_DISABLE_UNWANTED_DIAGNOSTICS
Expand All @@ -42,10 +41,7 @@ simde_vqdmlal_high_s16(simde_int32x4_t a, simde_int16x8_t b, simde_int16x8_t c)
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
return vqdmlal_high_s16(a, b, c);
#else
return simde_vaddq_s32(
simde_vmulq_n_s32(
simde_vmulq_s32(
simde_vmovl_high_s16(b), simde_vmovl_high_s16(c)), 2), a);
return simde_vqaddq_s32(simde_vqdmull_high_s16(b, c), a);
#endif
}
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
Expand All @@ -59,17 +55,7 @@ simde_vqdmlal_high_s32(simde_int64x2_t a, simde_int32x4_t b, simde_int32x4_t c)
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
return vqdmlal_high_s32(a, b, c);
#else
simde_int64x2_private r_ = simde_int64x2_to_private(
simde_x_vmulq_s64(
simde_vmovl_high_s32(b),
simde_vmovl_high_s32(c)));

SIMDE_VECTORIZE
for (size_t i = 0 ; i < (sizeof(r_.values) / sizeof(r_.values[0])) ; i++) {
r_.values[i] = r_.values[i] * HEDLEY_STATIC_CAST(int64_t, 2);
}

return simde_vaddq_s64(a, simde_int64x2_from_private(r_));
return simde_vqaddq_s64(simde_vqdmull_high_s32(b, c), a);
#endif
}
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
Expand Down
67 changes: 12 additions & 55 deletions simde/arm/neon/qdmlal_high_lane.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,92 +27,49 @@
#if !defined(SIMDE_ARM_NEON_QDMLAL_HIGH_LANE_H)
#define SIMDE_ARM_NEON_QDMLAL_HIGH_LANE_H

#include "movl_high.h"
#include "add.h"
#include "mul.h"
#include "mul_n.h"
#include "dup_n.h"
#include "mla.h"
#include "dup_lane.h"
#include "get_high.h"
#include "types.h"
#include "qdmlal.h"

HEDLEY_DIAGNOSTIC_PUSH
SIMDE_DISABLE_UNWANTED_DIAGNOSTICS
SIMDE_BEGIN_DECLS_

SIMDE_FUNCTION_ATTRIBUTES
simde_int32x4_t
simde_vqdmlal_high_lane_s16(simde_int32x4_t a, simde_int16x8_t b, simde_int16x4_t v, const int lane) SIMDE_REQUIRE_CONSTANT_RANGE(lane, 0, 3) {
return simde_vaddq_s32(
simde_vmulq_n_s32(
simde_vmulq_s32(
simde_vmovl_high_s16(b),
simde_vmovl_high_s16(simde_vdupq_n_s16(simde_int16x4_to_private(v).values[lane]))), 2), a);
}
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
#define simde_vqdmlal_high_lane_s16(a, b, v, lane) vqdmlal_high_lane_s16(a, b, v, lane)
#else
#define simde_vqdmlal_high_lane_s16(a, b, v, lane) simde_vqdmlal_s16((a), simde_vget_high_s16((b)), simde_vdup_lane_s16((v), (lane)))
#endif
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
#undef vqdmlal_high_lane_s16
#define vqdmlal_high_lane_s16(a, b, v, lane) simde_vqdmlal_high_lane_s16((a), (b), (v), (lane))
#define vqdmlal_high_lane_s16(a, b, c, lane) simde_vqdmlal_high_lane_s16((a), (b), (c), (lane))
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_int32x4_t
simde_vqdmlal_high_laneq_s16(simde_int32x4_t a, simde_int16x8_t b, simde_int16x8_t v, const int lane) SIMDE_REQUIRE_CONSTANT_RANGE(lane, 0, 7) {
return simde_vaddq_s32(
simde_vmulq_n_s32(
simde_vmulq_s32(
simde_vmovl_high_s16(b),
simde_vmovl_high_s16(simde_vdupq_n_s16(simde_int16x8_to_private(v).values[lane]))), 2), a);
}
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
#define simde_vqdmlal_high_laneq_s16(a, b, v, lane) vqdmlal_high_laneq_s16(a, b, v, lane)
#else
#define simde_vqdmlal_high_laneq_s16(a, b, v, lane) simde_vqdmlal_s16((a), simde_vget_high_s16((b)), simde_vdup_laneq_s16((v), (lane)))
#endif
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
#undef vqdmlal_high_laneq_s16
#define vqdmlal_high_laneq_s16(a, b, v, lane) simde_vqdmlal_high_laneq_s16((a), (b), (v), (lane))
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_int64x2_t
simde_vqdmlal_high_lane_s32(simde_int64x2_t a, simde_int32x4_t b, simde_int32x2_t v, const int lane) SIMDE_REQUIRE_CONSTANT_RANGE(lane, 0, 1) {
simde_int64x2_private r_ = simde_int64x2_to_private(
simde_x_vmulq_s64(
simde_vmovl_high_s32(b),
simde_vmovl_high_s32(simde_vdupq_n_s32(simde_int32x2_to_private(v).values[lane]))));

SIMDE_VECTORIZE
for (size_t i = 0 ; i < (sizeof(r_.values) / sizeof(r_.values[0])) ; i++) {
r_.values[i] = r_.values[i] * HEDLEY_STATIC_CAST(int64_t, 2);
}

return simde_vaddq_s64(a, simde_int64x2_from_private(r_));
}
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
#define simde_vqdmlal_high_lane_s32(a, b, v, lane) vqdmlal_high_lane_s32(a, b, v, lane)
#else
#define simde_vqdmlal_high_lane_s32(a, b, v, lane) simde_vqdmlal_s32((a), simde_vget_high_s32((b)), simde_vdup_lane_s32((v), (lane)))
#endif
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
#undef vqdmlal_high_lane_s32
#define vqdmlal_high_lane_s32(a, b, v, lane) simde_vqdmlal_high_lane_s32((a), (b), (v), (lane))
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_int64x2_t
simde_vqdmlal_high_laneq_s32(simde_int64x2_t a, simde_int32x4_t b, simde_int32x4_t v, const int lane) SIMDE_REQUIRE_CONSTANT_RANGE(lane, 0, 3) {
simde_int64x2_private r_ = simde_int64x2_to_private(
simde_x_vmulq_s64(
simde_vmovl_high_s32(b),
simde_vmovl_high_s32(simde_vdupq_n_s32(simde_int32x4_to_private(v).values[lane]))));

SIMDE_VECTORIZE
for (size_t i = 0 ; i < (sizeof(r_.values) / sizeof(r_.values[0])) ; i++) {
r_.values[i] = r_.values[i] * HEDLEY_STATIC_CAST(int64_t, 2);
}

return simde_vaddq_s64(a, simde_int64x2_from_private(r_));
}
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
#define simde_vqdmlal_high_laneq_s32(a, b, v, lane) vqdmlal_high_laneq_s32(a, b, v, lane)
#else
#define simde_vqdmlal_high_laneq_s32(a, b, v, lane) simde_vqdmlal_s32((a), simde_vget_high_s32((b)), simde_vdup_laneq_s32((v), (lane)))
#endif
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
#undef vqdmlal_high_laneq_s32
Expand Down
23 changes: 3 additions & 20 deletions simde/arm/neon/qdmlal_high_n.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,9 @@
#if !defined(SIMDE_ARM_NEON_QDMLAL_HIGH_N_H)
#define SIMDE_ARM_NEON_QDMLAL_HIGH_N_H

#include "movl_high.h"
#include "dup_n.h"
#include "add.h"
#include "mul.h"
#include "mul_n.h"
#include "types.h"
#include "qdmlal_high.h"

HEDLEY_DIAGNOSTIC_PUSH
SIMDE_DISABLE_UNWANTED_DIAGNOSTICS
Expand All @@ -44,11 +41,7 @@ simde_vqdmlal_high_n_s16(simde_int32x4_t a, simde_int16x8_t b, int16_t c) {
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
return vqdmlal_high_n_s16(a, b, c);
#else
return simde_vaddq_s32(
simde_vmulq_n_s32(
simde_vmulq_s32(
simde_vmovl_high_s16(b),
simde_vmovl_high_s16(simde_vdupq_n_s16(c))), 2), a);
return simde_vqdmlal_high_s16(a, b, simde_vdupq_n_s16(c));
#endif
}
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
Expand All @@ -62,17 +55,7 @@ simde_vqdmlal_high_n_s32(simde_int64x2_t a, simde_int32x4_t b, int32_t c) {
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE)
return vqdmlal_high_n_s32(a, b, c);
#else
simde_int64x2_private r_ = simde_int64x2_to_private(
simde_x_vmulq_s64(
simde_vmovl_high_s32(b),
simde_vmovl_high_s32(simde_vdupq_n_s32(c))));

SIMDE_VECTORIZE
for (size_t i = 0 ; i < (sizeof(r_.values) / sizeof(r_.values[0])) ; i++) {
r_.values[i] = r_.values[i] * HEDLEY_STATIC_CAST(int64_t, 2);
}

return simde_vaddq_s64(a, simde_int64x2_from_private(r_));
return simde_vqdmlal_high_s32(a, b, simde_vdupq_n_s32(c));
#endif
}
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
Expand Down
36 changes: 36 additions & 0 deletions test/arm/neon/qdmlal.c
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ test_simde_vqdmlalh_s16 (SIMDE_MUNIT_TEST_ARGS) {
{ INT16_C( 6764) },
{ -INT16_C( 707) },
{ -INT32_C( 6880798) } },
{ { INT32_C( INT32_MAX) },
{ INT16_C( 1) },
{ INT16_C( 1) },
{ INT32_C( INT32_MAX) } },
{ { INT32_C( INT32_MIN) },
{ INT16_C( 1) },
{ -INT16_C( 1) },
{ INT32_C( INT32_MIN) } },
{ { INT32_C( 0) },
{ INT16_C( INT16_MIN) },
{ INT16_C( INT16_MIN) },
{ INT32_C( INT32_MAX) } },
};

for (size_t i = 0 ; i < (sizeof(test_vec) / sizeof(test_vec[0])) ; i++) {
Expand Down Expand Up @@ -94,6 +106,18 @@ test_simde_vqdmlals_s32 (SIMDE_MUNIT_TEST_ARGS) {
{ INT32_C( 2995714) },
{ -INT32_C( 3814223) },
{ -INT64_C( 22853477950349) } },
{ { INT64_MAX },
{ INT32_C( 1) },
{ INT32_C( 1) },
{ INT64_MAX } },
{ { INT64_MIN },
{ INT32_C( 1) },
{ -INT32_C( 1) },
{ INT64_MIN } },
{ { INT64_C( 0) },
{ INT32_C( INT32_MIN) },
{ INT32_C( INT32_MIN) },
{ INT64_MAX } },
};

for (size_t i = 0 ; i < (sizeof(test_vec) / sizeof(test_vec[0])) ; i++) {
Expand Down Expand Up @@ -149,6 +173,10 @@ test_simde_vqdmlal_s16 (SIMDE_MUNIT_TEST_ARGS) {
{ INT16_MIN, INT16_MIN, INT16_MIN, INT16_MIN },
{ INT16_MIN, INT16_MIN, INT16_MIN, INT16_MIN },
{ INT32_C(2147483631), INT32_C(2147483632), INT32_C(2147483633), INT32_C(2147483634) } },
{ { INT32_C( INT32_MAX), INT32_C( INT32_MIN), INT32_C( 0), -INT32_C( 68184) },
{ INT16_C( 1), -INT16_C( 1), INT16_C( INT16_MIN), INT16_C( 9252) },
{ INT16_C( 1), INT16_C( 1), INT16_C( INT16_MIN), INT16_C( 5749) },
{ INT32_C( INT32_MAX), INT32_C( INT32_MIN), INT32_C( INT32_MAX), INT32_C( 106311312) } },
};

for (size_t i = 0 ; i < (sizeof(test_vec) / sizeof(test_vec[0])) ; i++) {
Expand Down Expand Up @@ -207,6 +235,14 @@ test_simde_vqdmlal_s32 (SIMDE_MUNIT_TEST_ARGS) {
{ INT32_MIN, INT32_MIN },
{ INT32_MIN, INT32_MIN },
{INT64_C(9223372036854775791), INT64_C(9223372036854775792) } },
{ { INT64_MAX, INT64_MIN },
{ INT32_C( 1), -INT32_C( 1) },
{ INT32_C( 1), INT32_C( 1) },
{ INT64_MAX, INT64_MIN } },
{ { INT64_C( 0), -INT64_C( 68184) },
{ INT32_C( INT32_MIN), INT32_C( 9252) },
{ INT32_C( INT32_MIN), INT32_C( 5749) },
{ INT64_MAX, INT64_C( 106311312) } }
};

for (size_t i = 0 ; i < (sizeof(test_vec) / sizeof(test_vec[0])) ; i++) {
Expand Down
14 changes: 14 additions & 0 deletions test/arm/neon/qdmlal_high.c
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ test_simde_vqdmlal_high_s16 (SIMDE_MUNIT_TEST_ARGS) {
{ -INT16_C( 9269), -INT16_C( 5310), INT16_C( 5746), INT16_C( 4013),
INT16_C( 5760), INT16_C( 4110), INT16_C( 8914), -INT16_C( 764) },
{ -INT32_C( 74794532), INT32_C( 36362128), INT32_C( 94724016), INT32_C( 13825770) } },
{ { INT32_C( INT32_MAX), INT32_C( INT32_MIN), INT32_C( 0), INT32_C( 5368290) },
{ -INT16_C( 9903), -INT16_C( 7336), INT16_C( 1785), INT16_C( 5751),
INT16_C( 1), INT16_C( 1), INT16_C( INT16_MIN), -INT16_C( 5535) },
{ -INT16_C( 9269), -INT16_C( 5310), INT16_C( 5746), INT16_C( 4013),
INT16_C( 1), -INT16_C( 1), INT16_C( INT16_MIN), -INT16_C( 764) },
{ INT32_C( INT32_MAX), INT32_C( INT32_MIN), INT32_C( INT32_MAX), INT32_C( 13825770) } },
};

for (size_t i = 0 ; i < (sizeof(test_vec) / sizeof(test_vec[0])) ; i++) {
Expand Down Expand Up @@ -113,6 +119,14 @@ test_simde_vqdmlal_high_s32 (SIMDE_MUNIT_TEST_ARGS) {
{ -INT32_C( 759050), -INT32_C( 437291), INT32_C( 207575), -INT32_C( 177006) },
{ -INT32_C( 262650), INT32_C( 912777), INT32_C( 556302), -INT32_C( 41245) },
{ INT64_C( 231133969127), INT64_C( 14599655586) } },
{ { INT64_MAX, INT64_MIN },
{ -INT32_C( 759050), -INT32_C( 437291), INT32_C( 1), INT32_C( 1) },
{ -INT32_C( 262650), INT32_C( 912777), INT32_C( 1), -INT32_C( 1) },
{ INT64_MAX, INT64_MIN } },
{ { INT64_C( 0), -INT64_C( 1569354) },
{ -INT32_C( 759050), -INT32_C( 437291), INT32_C( INT32_MIN), -INT32_C( 177006) },
{ -INT32_C( 262650), INT32_C( 912777), INT32_C( INT32_MIN), -INT32_C( 41245) },
{ INT64_MAX, INT64_C( 14599655586) } },
};

for (size_t i = 0 ; i < (sizeof(test_vec) / sizeof(test_vec[0])) ; i++) {
Expand Down
Loading

0 comments on commit 5af8523

Please sign in to comment.