This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Add Fused-Attention Layer for AVX2 Platforms #137

Merged: 10 commits, Feb 26, 2024
bestla/bestla/kernel_avx2.h (90 additions, 1 deletion)
@@ -24,7 +24,7 @@ namespace avx2 {
#if CompileAVX2()
#ifdef __GNUC__
#pragma GCC push_options
-#pragma GCC target("avx2", "fma")
+#pragma GCC target("avx2", "fma", "f16c")
#else
#endif

@@ -1118,6 +1118,95 @@ static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, co
return BTLA_CODE::Success;
}

inline __m256 poly_scale_2nd_ps(const __m256i z, const __m256 f, const __m256 c0, const __m256 c1, const __m256 c2) {
const auto y = _mm256_fmadd_ps(_mm256_fmadd_ps(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2;
static const auto mask_exp = _mm256_set1_epi32(0x7f800000);
static const auto mask_not_exp = _mm256_set1_epi32(~0x7f800000);

const auto y_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_exp);
const auto y_not_exp = _mm256_and_si256(_mm256_castps_si256(y), mask_not_exp);

const auto y_exp_scaled = _mm256_add_epi32(y_exp, _mm256_slli_epi32(z, 23));
return _mm256_castsi256_ps(_mm256_or_si256(y_not_exp, _mm256_and_si256(y_exp_scaled, mask_exp)));
}

inline __m256 exp_ps_0_1(const __m256 x) {
static const auto c0 = _mm256_set1_ps(0.240226507f);
static const auto c1 = _mm256_set1_ps(0.452920674f);
static const auto c2 = _mm256_set1_ps(0.713483036f);
static const float v_log2e = std::log2(std::exp(1.f));
static const auto log2e = _mm256_set1_ps(v_log2e);
static const auto half = _mm256_set1_ps(.5f);

const auto x1 = _mm256_fmadd_ps(x, log2e, half); // auto x1 = x * log2e + _mm256_set1_ps(.5f);
const auto z = _mm256_floor_ps(x1);
const auto f = _mm256_sub_ps(x1, z); // auto f = x1 - z;

return poly_scale_2nd_ps(_mm256_cvtps_epi32(z), f, c0, c1, c2);
}

#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes" // https://stackoverflow.com/a/49216021
#endif
// Transpose 8 xmm vectors of 16-bit words (an 8x8 word transpose); src is clobbered as scratch
static inline std::array<__m128i, 8> tr_x8_word(std::array<__m128i, 8>& src) { // NOLINT [runtime/references]
std::array<__m128i, 8> dst;

for (int i = 0; i < 8; i += 2) {
dst[i + 0] = _mm_unpacklo_epi16(src[i + 0], src[i + 1]);
dst[i + 1] = _mm_unpackhi_epi16(src[i + 0], src[i + 1]);
}
for (int i = 0; i < 8; i += 4) {
src[i + 0] = _mm_unpacklo_epi32(dst[i + 0], dst[i + 2]);
src[i + 1] = _mm_unpackhi_epi32(dst[i + 0], dst[i + 2]);
src[i + 2] = _mm_unpacklo_epi32(dst[i + 1], dst[i + 3]);
src[i + 3] = _mm_unpackhi_epi32(dst[i + 1], dst[i + 3]);
}
dst[0] = _mm_unpacklo_epi64(src[0], src[4]);
dst[1] = _mm_unpackhi_epi64(src[0], src[4]);
dst[2] = _mm_unpacklo_epi64(src[1], src[5]);
dst[3] = _mm_unpackhi_epi64(src[1], src[5]);
dst[4] = _mm_unpacklo_epi64(src[2], src[6]);
dst[5] = _mm_unpackhi_epi64(src[2], src[6]);
dst[6] = _mm_unpacklo_epi64(src[3], src[7]);
dst[7] = _mm_unpackhi_epi64(src[3], src[7]);
return dst;
}

template <int tail>
inline std::array<__m128i, 8> load_fp32_fp16_tr_x8_word(const float* a, size_t lda) {
static_assert(tail > 0 && tail <= 8, "Unexpected tail value.");
std::array<__m128i, 8> dst;
for (int i = 0; i < tail; ++i) {
dst[i] = _mm256_cvtps_ph(_mm256_loadu_ps(a + i * lda), _MM_FROUND_TO_NEAREST_INT);
}
for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128();
return tr_x8_word(dst);
}
constexpr decltype(load_fp32_fp16_tr_x8_word<1>)* load_fp32_fp16_tr_x8_word_tbl[9]{
load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<1>, load_fp32_fp16_tr_x8_word<2>,
load_fp32_fp16_tr_x8_word<3>, load_fp32_fp16_tr_x8_word<4>, load_fp32_fp16_tr_x8_word<5>,
load_fp32_fp16_tr_x8_word<6>, load_fp32_fp16_tr_x8_word<7>, load_fp32_fp16_tr_x8_word<8>};

template <int tail>
inline std::array<__m128i, 8> load_maskz_fp32_fp16_tr_x8_word(const float* a, size_t lda, __m256i mask) {
static_assert(tail > 0 && tail <= 8, "Unexpected tail value.");
std::array<__m128i, 8> dst;
for (int i = 0; i < tail; ++i) {
dst[i] = _mm256_cvtps_ph(_mm256_maskload_ps(a + i * lda, mask), _MM_FROUND_TO_NEAREST_INT);
}
for (int i = tail; i < 8; ++i) dst[i] = _mm_setzero_si128();
return tr_x8_word(dst);
}
constexpr decltype(load_maskz_fp32_fp16_tr_x8_word<1>)* load_maskz_fp32_fp16_tr_x8_word_tbl[9]{
load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<1>, load_maskz_fp32_fp16_tr_x8_word<2>,
load_maskz_fp32_fp16_tr_x8_word<3>, load_maskz_fp32_fp16_tr_x8_word<4>, load_maskz_fp32_fp16_tr_x8_word<5>,
load_maskz_fp32_fp16_tr_x8_word<6>, load_maskz_fp32_fp16_tr_x8_word<7>, load_maskz_fp32_fp16_tr_x8_word<8>};
#ifdef __GNUC__
#pragma GCC diagnostic pop
#endif

#ifdef __GNUC__
#pragma GCC pop_options
#else
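
For context on how the new tail-dispatch tables are meant to be consumed, a minimal caller sketch follows. It is an illustration only, not part of this change: the helper name pack_fp32_tile_to_fp16 and the destination layout are assumptions, and it presumes the caller sits in (or fully qualifies) the namespace of the kernels above. rows is a runtime value in [1, 8] that the table maps to the matching template instantiation; the maskz variant additionally takes a _mm256_maskload_ps mask for tiles narrower than 8 columns.

#include <immintrin.h>
#include <array>
#include <cstddef>
#include <cstdint>

// Illustrative only: pack a (rows x 8) fp32 tile, where rows may be 1..8 at a matrix edge,
// into the word-transposed fp16 layout produced by the loaders above.
void pack_fp32_tile_to_fp16(const float* src, std::size_t lda, int rows, std::uint16_t* dst) {
  const std::array<__m128i, 8> tr = load_fp32_fp16_tr_x8_word_tbl[rows](src, lda);
  for (int j = 0; j < 8; ++j) {
    // tr[j] holds the j-th fp16 word of every row, i.e. column j of the tile
    _mm_storeu_si128(reinterpret_cast<__m128i*>(dst + j * 8), tr[j]);
  }
}
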
bestla/bestla/kernel_avx512f.h (22 additions, 0 deletions)
@@ -2383,6 +2383,28 @@ static inline BTLA_CODE layernorm(const float* srcptr, const float* scaleptr, co
}
return BTLA_CODE::Success;
}

inline __m512 poly_scale_2nd_ps(const __m512 z, const __m512 f, const __m512 c0, const __m512 c1, const __m512 c2) {
const auto y = _mm512_fmadd_ps(_mm512_fmadd_ps(f, c0, c1), f, c2); // auto y = (f * c0 + c1) * f + c2;
const auto exp = _mm512_scalef_ps(y, z);
return exp;
}

inline __m512 exp_ps_0_1(const __m512 x) {
static const auto c0 = _mm512_set1_ps(0.240226507f);
static const auto c1 = _mm512_set1_ps(0.452920674f);
static const auto c2 = _mm512_set1_ps(0.713483036f);
static const float v_log2e = std::log2(std::exp(1.f));
static const auto log2e = _mm512_set1_ps(v_log2e);
static const auto half = _mm512_set1_ps(.5f);

const auto x1 = _mm512_fmadd_ps(x, log2e, half); // auto x1 = x * log2e + _mm512_set1_ps(.5f);
const auto z = _mm512_floor_ps(x1);
const auto f = _mm512_sub_ps(x1, z); // auto f = x1 - z;

return poly_scale_2nd_ps(z, f, c0, c1, c2);
}

#ifdef __GNUC__
#pragma GCC pop_options
#else
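
A note on the two vector implementations: AVX-512F provides _mm512_scalef_ps, so the 2^z scaling is a single instruction here (z is already floored), whereas the AVX2 version above has no scalef equivalent and emulates it by adding z to the biased exponent field (the z << 23 addition in poly_scale_2nd_ps). In both cases x * log2e + 0.5 is split into an integer part z and a fraction f, so the second-order polynomial in f effectively approximates 2^(f - 1/2) and exp(x) is reconstructed as poly(f) * 2^z.
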
bestla/bestla/kernel_ref.h (11 additions, 0 deletions)
@@ -1522,6 +1522,17 @@ static inline BTLA_CODE layernorm(const T* srcptr, const T* scaleptr, const T* b
}
return BTLA_CODE::Success;
}

inline float exp_ps_0_1(float x) {
static const float log2e = std::log2(std::exp(1.f));
static const float ln2 = std::log(2.f);
const float x1 = x * log2e + .5f;
const float z = std::floor(x1);
const float f = x1 - z;
constexpr std::array<float, 3> coeff{0.240226507f, 0.452920674f, 0.713483036f};
// same as multiplying the polynomial value by std::pow(2, z), but more precise
return ldexpf(coeff[0] * f * f + coeff[1] * f + coeff[2], static_cast<int>(z));
}
} // namespace ref
} // namespace kernel
} // namespace bestla
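
As a quick sanity check on the shared coefficients, a standalone program along the following lines (illustrative only, not part of this change) can compare the scalar reference against std::exp; with these coefficients the maximum relative error on [0, 1] comes out around 9e-3, at inputs where the fractional part f approaches 0.

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>

// Same formula as ref::exp_ps_0_1 above, duplicated here so the check is self-contained.
static float exp_approx_0_1(float x) {
  static const float log2e = std::log2(std::exp(1.f));
  const float x1 = x * log2e + .5f;
  const float z = std::floor(x1);
  const float f = x1 - z;
  constexpr std::array<float, 3> coeff{0.240226507f, 0.452920674f, 0.713483036f};
  return std::ldexp(coeff[0] * f * f + coeff[1] * f + coeff[2], static_cast<int>(z));
}

int main() {
  float max_rel_err = 0.f;
  for (int i = 0; i <= 1024; ++i) {
    const float x = static_cast<float>(i) / 1024.f;
    const float rel = std::fabs(exp_approx_0_1(x) - std::exp(x)) / std::exp(x);
    max_rel_err = std::max(max_rel_err, rel);
  }
  std::printf("max relative error on [0, 1]: %g\n", max_rel_err);
  return 0;
}
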
neural_speed/core/CMakeLists.txt (6 additions, 4 deletions)
@@ -14,6 +14,8 @@

find_package(Threads REQUIRED)
file(GLOB layers_srcs "layers/*.cpp")
file(GLOB test_srcs "layers/*test*.cpp")
list(REMOVE_ITEM layers_srcs ${test_srcs})
set(sources ne_layers.c ${layers_srcs})

add_shareable_library_w_warning(ne_layers "${sources}")
@@ -37,27 +39,27 @@ endif()

if (NS_BUILD_TESTS)

-function(add_test_target src)
+function(add_test_target src) # ARGN: additional source
get_filename_component(test_target ${src} NAME_WE)
get_filename_component(src_dir ${src} DIRECTORY)
string(REGEX REPLACE [/\\] "_" src_dir ${src_dir})
if(src_dir)
set (test_target "${src_dir}_${test_target}")
endif()
set (test_target "test_${test_target}")
-add_executable_w_warning(${test_target} ${src})
+add_executable_w_warning(${test_target} ${src} ${ARGN})
target_compile_definitions(${test_target} PRIVATE NS_TESTS)
target_compile_options(${test_target} PRIVATE -fsanitize=address)
target_link_options(${test_target} PRIVATE -fsanitize=address)
target_include_directories(${test_target} PUBLIC .)
-target_link_libraries(${test_target} PUBLIC Threads::Threads bestla::bestla ne_vec)
+target_link_libraries(${test_target} PUBLIC Threads::Threads bestla ne_vec)
if(NOT WIN32)
target_link_libraries(${test_target} PUBLIC rt)
endif()
add_test(NAME ${test_target} COMMAND ${test_target})
set_tests_properties(${test_target} PROPERTIES LABELS "${src_dir}_test")
endfunction()

-add_test_target(layers/mha_dense.cpp)
+add_test_target(layers/mha_dense.cpp layers/mha_dense_tests.cpp)

endif()
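
The build changes follow from the new test file: file(GLOB layers_srcs "layers/*.cpp") would otherwise compile layers/mha_dense_tests.cpp into the ne_layers library, so test sources are filtered out of layers_srcs and passed to add_test_target through ARGN instead. The test targets also now link the bestla target directly rather than the bestla::bestla alias; presumably both resolve to the same library, so this only simplifies the dependency.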