Skip to content

Commit

Permalink
NEON: implement all bf16-related intrinsics (#1110)
Browse files Browse the repository at this point in the history
* [Feat] Add BF16 when the machine is supported.

Finished: vld1_bf16_x4 and vld1q_bf16_x2

* [NEON] Add a C implementation of the bf16 type

* [NEON] Add all ld_*_bf16 intrinsics.

* [NEON] Add all st*_bf16 intrinsics.

* [Test] Add vbfdot_f32 test case

* [NEON] Complete converting function from float32 to bfloat16.

- Also add bf-related functions in three series
- cvt, dot, dot_lane

* [Feat] Add option '+bf16' in cross-file

* [NEON] Completed initial implementation of bf-16 related intrinsics.

* [Fix] Remove redundant commment

* [Fix] Correct native aliases

* [Fix] The test generation code has been completed.
  • Loading branch information
yyctw authored Nov 20, 2023
1 parent c7d314b commit c59db7c
Show file tree
Hide file tree
Showing 101 changed files with 11,220 additions and 23 deletions.
4 changes: 2 additions & 2 deletions docker/cross-files/aarch64-clang-15-ccache.cross
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ ld = 'llvm-ld-15'
exe_wrapper = ['qemu-aarch64-static', '-L', '/usr/aarch64-linux-gnu']

[properties]
c_args = ['--target=aarch64-linux-gnu', '-march=armv8-a+simd+crypto+crc', '-isystem=/usr/aarch64-linux-gnu/include', '-Weverything', '-fno-lax-vector-conversions', '-Werror']
cpp_args = ['--target=aarch64-linux-gnu', '-march=armv8-a+simd+crypto+crc', '-isystem=/usr/aarch64-linux-gnu/include', '-Weverything', '-fno-lax-vector-conversions', '-Werror']
c_args = ['--target=aarch64-linux-gnu', '-march=armv8.2-a+simd+crypto+crc+bf16', '-isystem=/usr/aarch64-linux-gnu/include', '-Weverything', '-fno-lax-vector-conversions', '-Werror']
cpp_args = ['--target=aarch64-linux-gnu', '-march=armv8.2-a+simd+crypto+crc+bf16', '-isystem=/usr/aarch64-linux-gnu/include', '-Weverything', '-fno-lax-vector-conversions', '-Werror']
c_link_args = ['--target=aarch64-linux-gnu']
cpp_link_args = ['--target=aarch64-linux-gnu']

Expand Down
4 changes: 2 additions & 2 deletions docker/cross-files/aarch64-clang-16-ccache.cross
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ ld = 'llvm-ld-16'
exe_wrapper = ['qemu-aarch64-static', '-L', '/usr/aarch64-linux-gnu']

[properties]
c_args = ['--target=aarch64-linux-gnu', '-march=armv8-a+simd+crypto+crc', '-isystem=/usr/aarch64-linux-gnu/include', '-Weverything', '-fno-lax-vector-conversions', '-Werror', '-Wno-unsafe-buffer-usage']
cpp_args = ['--target=aarch64-linux-gnu', '-march=armv8-a+simd+crypto+crc', '-isystem=/usr/aarch64-linux-gnu/include', '-Weverything', '-fno-lax-vector-conversions', '-Werror', '-Wno-unsafe-buffer-usage']
c_args = ['--target=aarch64-linux-gnu', '-march=armv8-a+simd+crypto+crc+bf16', '-isystem=/usr/aarch64-linux-gnu/include', '-Weverything', '-fno-lax-vector-conversions', '-Werror', '-Wno-unsafe-buffer-usage']
cpp_args = ['--target=aarch64-linux-gnu', '-march=armv8-a+simd+crypto+crc+bf16', '-isystem=/usr/aarch64-linux-gnu/include', '-Weverything', '-fno-lax-vector-conversions', '-Werror', '-Wno-unsafe-buffer-usage']
c_link_args = ['--target=aarch64-linux-gnu']
cpp_link_args = ['--target=aarch64-linux-gnu']

Expand Down
4 changes: 2 additions & 2 deletions docker/cross-files/aarch64-gcc-12-ccache.cross
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ ld = 'aarch64-linux-gnu-ld'
exe_wrapper = ['qemu-aarch64-static', '-L', '/usr/aarch64-linux-gnu']

[properties]
c_args = ['-march=armv8-a+simd+crypto+crc', '-Wextra', '-Werror']
cpp_args = ['-march=armv8-a+simd+crypto+crc', '-Wextra', '-Werror']
c_args = ['-march=armv8.2-a+simd+crypto+crc+bf16', '-Wextra', '-Werror']
cpp_args = ['-march=armv8.2-a+simd+crypto+crc+bf16', '-Wextra', '-Werror']

[host_machine]
system = 'linux'
Expand Down
1 change: 1 addition & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ if not meson.is_subproject()
'simde/simde-aes.h',
'simde/simde-align.h',
'simde/simde-arch.h',
'simde/simde-bf16.h',
'simde/simde-common.h',
'simde/simde-constify.h',
'simde/simde-detect-clang.h',
Expand Down
25 changes: 25 additions & 0 deletions simde/arm/neon/combine.h
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,31 @@ simde_vcombine_p64(simde_poly64x1_t low, simde_poly64x1_t high) {
#define vcombine_p64(low, high) simde_vcombine_p64((low), (high))
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16x8_t
simde_vcombine_bf16(simde_bfloat16x4_t low, simde_bfloat16x4_t high) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcombine_bf16(low, high);
#else
simde_bfloat16x8_private r_;
simde_bfloat16x4_private
low_ = simde_bfloat16x4_to_private(low),
high_ = simde_bfloat16x4_to_private(high);

size_t halfway = (sizeof(r_.values) / sizeof(r_.values[0])) / 2;
SIMDE_VECTORIZE
for (size_t i = 0 ; i < halfway ; i++) {
r_.values[i] = low_.values[i];
r_.values[i + halfway] = high_.values[i];
}

return simde_bfloat16x8_from_private(r_);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcombine_bf16
#define vcombine_bf16(low, high) simde_vcombine_bf16((low), (high))
#endif

SIMDE_END_DECLS_
HEDLEY_DIAGNOSTIC_POP
Expand Down
77 changes: 77 additions & 0 deletions simde/arm/neon/copy_lane.h
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,83 @@ simde_vcopyq_laneq_p64(simde_poly64x2_t a, const int lane1, simde_poly64x2_t b,
#define vcopyq_laneq_p64(a, lane1, b, lane2) simde_vcopyq_laneq_p64((a), (lane1), (b), (lane2))
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16x4_t
simde_vcopy_lane_bf16(simde_bfloat16x4_t a, const int lane1, simde_bfloat16x4_t b, const int lane2)
SIMDE_REQUIRE_CONSTANT_RANGE(lane1, 0, 3)
SIMDE_REQUIRE_CONSTANT_RANGE(lane2, 0, 3) {
simde_bfloat16x4_private
b_ = simde_bfloat16x4_to_private(b),
r_ = simde_bfloat16x4_to_private(a);

r_.values[lane1] = b_.values[lane2];
return simde_bfloat16x4_from_private(r_);
}
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
#define simde_vcopy_lane_bf16(a, lane1, b, lane2) vcopy_lane_bf16((a), (lane1), (b), (lane2))
#endif
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
#undef vcopy_lane_bf16
#define vcopy_lane_bf16(a, lane1, b, lane2) simde_vcopy_lane_bf16((a), (lane1), (b), (lane2))
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16x4_t
simde_vcopy_laneq_bf16(simde_bfloat16x4_t a, const int lane1, simde_bfloat16x8_t b, const int lane2)
SIMDE_REQUIRE_CONSTANT_RANGE(lane1, 0, 3)
SIMDE_REQUIRE_CONSTANT_RANGE(lane2, 0, 7) {
simde_bfloat16x4_private r_ = simde_bfloat16x4_to_private(a);
simde_bfloat16x8_private b_ = simde_bfloat16x8_to_private(b);

r_.values[lane1] = b_.values[lane2];
return simde_bfloat16x4_from_private(r_);
}
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
#define simde_vcopy_laneq_bf16(a, lane1, b, lane2) vcopy_laneq_bf16((a), (lane1), (b), (lane2))
#endif
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
#undef vcopy_laneq_bf16
#define vcopy_laneq_bf16(a, lane1, b, lane2) simde_vcopy_laneq_bf16((a), (lane1), (b), (lane2))
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16x8_t
simde_vcopyq_lane_bf16(simde_bfloat16x8_t a, const int lane1, simde_bfloat16x4_t b, const int lane2)
SIMDE_REQUIRE_CONSTANT_RANGE(lane1, 0, 7)
SIMDE_REQUIRE_CONSTANT_RANGE(lane2, 0, 3) {
simde_bfloat16x4_private b_ = simde_bfloat16x4_to_private(b);
simde_bfloat16x8_private r_ = simde_bfloat16x8_to_private(a);

r_.values[lane1] = b_.values[lane2];
return simde_bfloat16x8_from_private(r_);
}
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
#define simde_vcopyq_lane_bf16(a, lane1, b, lane2) vcopyq_lane_bf16((a), (lane1), (b), (lane2))
#endif
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
#undef vcopyq_lane_bf16
#define vcopyq_lane_bf16(a, lane1, b, lane2) simde_vcopyq_lane_bf16((a), (lane1), (b), (lane2))
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16x8_t
simde_vcopyq_laneq_bf16(simde_bfloat16x8_t a, const int lane1, simde_bfloat16x8_t b, const int lane2)
SIMDE_REQUIRE_CONSTANT_RANGE(lane1, 0, 7)
SIMDE_REQUIRE_CONSTANT_RANGE(lane2, 0, 7) {
simde_bfloat16x8_private
b_ = simde_bfloat16x8_to_private(b),
r_ = simde_bfloat16x8_to_private(a);

r_.values[lane1] = b_.values[lane2];
return simde_bfloat16x8_from_private(r_);
}
#if defined(SIMDE_ARM_NEON_A64V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
#define simde_vcopyq_laneq_bf16(a, lane1, b, lane2) vcopyq_laneq_bf16((a), (lane1), (b), (lane2))
#endif
#if defined(SIMDE_ARM_NEON_A64V8_ENABLE_NATIVE_ALIASES)
#undef vcopyq_laneq_bf16
#define vcopyq_laneq_bf16(a, lane1, b, lane2) simde_vcopyq_laneq_bf16((a), (lane1), (b), (lane2))
#endif

SIMDE_END_DECLS_
HEDLEY_DIAGNOSTIC_POP
Expand Down
15 changes: 13 additions & 2 deletions simde/arm/neon/create.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
* 2023 Yi-Yen Chung <[email protected]> (Copyright owned by Andes Technology)
*/

/* Yi-Yen Chung: Added vcreate_f16 */

#if !defined(SIMDE_ARM_NEON_CREATE_H)
#define SIMDE_ARM_NEON_CREATE_H

Expand Down Expand Up @@ -235,6 +233,19 @@ simde_vcreate_p64(simde_poly64_t a) {
#define vcreate_p64(a) simde_vcreate_p64(a)
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16x4_t
simde_vcreate_bf16(uint64_t a) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcreate_bf16(a);
#else
return simde_vreinterpret_bf16_u64(simde_vdup_n_u64(a));
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcreate_bf16
#define vcreate_bf16(a) simde_vcreate_bf16(a)
#endif

SIMDE_END_DECLS_
HEDLEY_DIAGNOSTIC_POP
Expand Down
164 changes: 164 additions & 0 deletions simde/arm/neon/cvt.h
Original file line number Diff line number Diff line change
Expand Up @@ -2068,6 +2068,170 @@ simde_vcvtx_high_f32_f64(simde_float32x2_t r, simde_float64x2_t a) {
#define vcvtx_high_f32_f64(r, a) simde_vcvtx_high_f32_f64((r), (a))
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16x4_t
simde_vcvt_bf16_f32(simde_float32x4_t a) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcvt_bf16_f32(a);
#else
simde_float32x4_private a_ = simde_float32x4_to_private(a);
simde_bfloat16x4_private r_;

SIMDE_VECTORIZE
for (size_t i = 0 ; i < (sizeof(r_.values) / sizeof(r_.values[0])) ; i++) {
r_.values[i] = simde_bfloat16_from_float32(a_.values[i]);
}

return simde_bfloat16x4_from_private(r_);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcvt_bf16_f32
#define vcvt_bf16_f32(a) simde_vcvt_bf16_f32(a)
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_float32x4_t
simde_vcvt_f32_bf16(simde_bfloat16x4_t a) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcvt_f32_bf16(a);
#else
simde_bfloat16x4_private a_ = simde_bfloat16x4_to_private(a);
simde_float32x4_private r_;

SIMDE_VECTORIZE
for (size_t i = 0 ; i < (sizeof(r_.values) / sizeof(r_.values[0])) ; i++) {
r_.values[i] = simde_bfloat16_to_float32(a_.values[i]);
}

return simde_float32x4_from_private(r_);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcvt_f32_bf16
#define vcvt_f32_bf16(a) simde_vcvt_f32_bf16(a)
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_float32_t
simde_vcvtah_f32_bf16(simde_bfloat16_t a) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcvtah_f32_bf16(a);
#else
return simde_bfloat16_to_float32(a);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcvtah_f32_bf16
#define vcvtah_f32_bf16(a) simde_vcvtah_f32_bf16(a)
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16_t
simde_vcvth_bf16_f32(float a) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcvth_bf16_f32(a);
#else
return simde_bfloat16_from_float32(a);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcvth_bf16_f32
#define vcvth_bf16_f32(a) simde_vcvth_bf16_f32(a)
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_float32x4_t
simde_vcvtq_low_f32_bf16(simde_bfloat16x8_t a) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcvtq_low_f32_bf16(a);
#else
simde_bfloat16x8_private a_ = simde_bfloat16x8_to_private(a);
simde_float32x4_private r_;

SIMDE_VECTORIZE
for (size_t i = 0 ; i < (sizeof(r_.values) / sizeof(r_.values[0])) ; i++) {
r_.values[i] = simde_bfloat16_to_float32(a_.values[i]);
}

return simde_float32x4_from_private(r_);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcvtq_low_f32_bf16
#define vcvtq_low_f32_bf16(a) simde_vcvtq_low_f32_bf16(a)
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_float32x4_t
simde_vcvtq_high_f32_bf16(simde_bfloat16x8_t a) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcvtq_high_f32_bf16(a);
#else
simde_bfloat16x8_private a_ = simde_bfloat16x8_to_private(a);
simde_float32x4_private r_;

size_t rsize = (sizeof(r_.values) / sizeof(r_.values[0]));
SIMDE_VECTORIZE
for (size_t i = 0 ; i < (sizeof(r_.values) / sizeof(r_.values[0])) ; i++) {
r_.values[i] = simde_bfloat16_to_float32(a_.values[i + rsize]);
}

return simde_float32x4_from_private(r_);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcvtq_high_f32_bf16
#define vcvtq_high_f32_bf16(a) simde_vcvtq_high_f32_bf16(a)
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16x8_t
simde_vcvtq_low_bf16_f32(simde_float32x4_t a) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcvtq_low_bf16_f32(a);
#else
simde_float32x4_private a_ = simde_float32x4_to_private(a);
simde_bfloat16x8_private r_;

size_t asize = (sizeof(a_.values) / sizeof(a_.values[0]));
SIMDE_VECTORIZE
for (size_t i = 0 ; i < asize; i++) {
r_.values[i] = simde_bfloat16_from_float32(a_.values[i]);
r_.values[i + asize] = SIMDE_BFLOAT16_VALUE(0.0);
}

return simde_bfloat16x8_from_private(r_);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcvtq_low_bf16_f32
#define vcvtq_low_bf16_f32(a) simde_vcvtq_low_bf16_f32(a)
#endif

SIMDE_FUNCTION_ATTRIBUTES
simde_bfloat16x8_t
simde_vcvtq_high_bf16_f32(simde_bfloat16x8_t inactive, simde_float32x4_t a) {
#if defined(SIMDE_ARM_NEON_A32V8_NATIVE) && defined(SIMDE_ARM_NEON_BF16)
return vcvtq_high_bf16_f32(inactive, a);
#else
simde_bfloat16x8_private inactive_ = simde_bfloat16x8_to_private(inactive);
simde_float32x4_private a_ = simde_float32x4_to_private(a);
simde_bfloat16x8_private r_;

size_t asize = (sizeof(a_.values) / sizeof(a_.values[0]));
SIMDE_VECTORIZE
for (size_t i = 0 ; i < (sizeof(a_.values) / sizeof(a_.values[0])) ; i++) {
r_.values[i] = inactive_.values[i];
r_.values[i + asize] = simde_bfloat16_from_float32(a_.values[i]);
}
return simde_bfloat16x8_from_private(r_);
#endif
}
#if defined(SIMDE_ARM_NEON_A32V8_ENABLE_NATIVE_ALIASES)
#undef vcvtq_high_bf16_f32
#define vcvtq_high_bf16_f32(inactive, a) simde_vcvtq_high_bf16_f32((inactive), (a))
#endif

SIMDE_END_DECLS_
HEDLEY_DIAGNOSTIC_POP
Expand Down
Loading

0 comments on commit c59db7c

Please sign in to comment.