AVX2/Zen4 horizontal sums #57

Open · wants to merge 4 commits into main
164 changes: 140 additions & 24 deletions ggml/src/iqk/iqk_mul_mat.cpp
@@ -260,6 +260,41 @@ inline float hmax_float_8(__m256 x) {
max4 = _mm_max_ss( max4, _mm_movehdup_ps( max4));
return _mm_cvtss_f32(max4);
}
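// Horizontal sums of eight 8-float accumulators at once: lane iy of the returned
// vector is the full horizontal sum of accm[iy] (accm[] is clobbered in the process).
// Scalar equivalent, for clarity only:
//   for (int iy = 0; iy < 8; ++iy) out[iy] = accm[iy][0] + accm[iy][1] + ... + accm[iy][7];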
IQK_ALWAYS_INLINE __m256 hsum_float_8x8(__m256 * accm) {
for (int i = 0; i < 4; ++i) {
accm[i] = _mm256_set_m128(_mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1)),
_mm_add_ps(_mm256_castps256_ps128(accm[i+0]), _mm256_extractf128_ps(accm[i+0], 1)));
}
for (int i = 0; i < 2; ++i) accm[i] = _mm256_add_ps(_mm256_unpacklo_ps(accm[i], accm[i+2]), _mm256_unpackhi_ps(accm[i], accm[i+2]));
return _mm256_add_ps(_mm256_unpacklo_ps(accm[0], accm[1]), _mm256_unpackhi_ps(accm[0], accm[1]));
}
#ifdef HAVE_FANCY_SIMD
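// Compute and store the results for 8 right-hand side rows at column ix in one go:
// a single hsum_float_8x8 reduction replaces 8 separate hsum_float_8 calls.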
IQK_ALWAYS_INLINE void store_8(int ix, __m256 * accm, const DataInfo& info) {
union { __m256 vec; float val[8]; } h;
h.vec = hsum_float_8x8(accm);
for (int iy = 0; iy < 8; ++iy) info.store(ix, iy, h.val[iy]);
}
#else
// Somehow, on the AVX2 system I have available (Ryzen-5975WX), the store_8 version above
// and the commented-out store_8 version below are both slower than this.
IQK_ALWAYS_INLINE void store_8(int ix, __m256 * accm, const DataInfo& info) {
for (int iy = 0; iy < 8; ++iy) info.store(ix, iy, hsum_float_8(accm[iy]));
}
//IQK_ALWAYS_INLINE __m128 hsum_float_4x4(__m128 * a) {
// for (int i = 0; i < 2; ++i) a[i] = _mm_add_ps(_mm_unpacklo_ps(a[i], a[i+2]), _mm_unpackhi_ps(a[i], a[i+2]));
// return _mm_add_ps(_mm_unpacklo_ps(a[0], a[1]), _mm_unpackhi_ps(a[0], a[1]));
//}
//IQK_ALWAYS_INLINE void store_8(int ix, __m256 * accm, const DataInfo& info) {
// union { __m128 vec; float val[4]; } h;
// __m128 a[4];
// for (int i = 0; i < 4; ++i) a[i] = _mm_add_ps(_mm256_castps256_ps128(accm[i]), _mm256_extractf128_ps(accm[i], 1));
// h.vec = hsum_float_4x4(a);
// for (int iy = 0; iy < 4; ++iy) info.store(ix, iy, h.val[iy]);
// for (int i = 0; i < 4; ++i) a[i] = _mm_add_ps(_mm256_castps256_ps128(accm[i+4]), _mm256_extractf128_ps(accm[i+4], 1));
// h.vec = hsum_float_4x4(a);
// for (int iy = 0; iy < 4; ++iy) info.store(ix, iy+4, h.val[iy]);
//}
#endif


#define MM256_SET_M128I(a, b) _mm256_insertf128_si256(_mm256_castsi128_si256(b), (a), 1)

@@ -1128,9 +1163,17 @@ static void mul_mat_qX_K_q8_K_AVX512(int n, const void * vx, size_t bx, const Da

}

for (int iy = 0; iy < nrc_y; ++iy) {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
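// For a full 8-row tile, fold the 512-bit accumulators into accm[] and reduce
// all 8 rows at once with store_8 instead of doing 8 scalar horizontal sums.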
if constexpr (nrc_y == 8) {
for (int iy = 0; iy < nrc_y; ++iy) {
accm[iy] = _mm256_add_ps(accm[iy], _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1)));
}
store_8(ix, accm, info);
}
else {
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
}
}

}
@@ -1177,18 +1220,30 @@ static void mul_mat_iq2tn_q8_K_AVX512(int n, const void * vx, size_t bx, const D
sumi_2 = _mm512_dpbusd_epi32(sumi_2, deq2.bits.values[3], q8q);
// The scale is supposed to be per-tensor, so we can use the same scale
auto vd = _mm512_set1_ps(d*q8.scale(iy, i));
accd[2*iy+0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
accd[2*iy+1] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
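// Index the two per-column accumulators as iy and iy+nrc_y (rather than 2*iy and
// 2*iy+1) so that each column's sums are contiguous, as store_8 expects below.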
accd[iy+ 0] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_1), accd[iy+ 0]);
accd[iy+nrc_y] = _mm512_fmadd_ps(vd, _mm512_cvtepi32_ps(sumi_2), accd[iy+nrc_y]);
// Leaving this here just in case ternary models start using per-row scales
//accd[2*iy+0] = _mm512_fmadd_ps(_mm512_set1_ps(deq1.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_1), accd[2*iy+0]);
//accd[2*iy+1] = _mm512_fmadd_ps(_mm512_set1_ps(deq2.d*q8.scale(iy, i)), _mm512_cvtepi32_ps(sumi_2), accd[2*iy+1]);
}

}

for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[2*iy+0]));
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[2*iy+1]));
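// With 8 rows, narrow each column's accumulators to 256 bits and reduce them together via store_8.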
if constexpr (nrc_y == 8) {
__m256 sums[8];
for (int iy = 0; iy < nrc_y; ++iy) {
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
}
store_8(ix+0, sums, info);
for (int iy = 0; iy < nrc_y; ++iy) {
sums[iy] = _mm256_add_ps(_mm512_castps512_ps256(accd[iy+nrc_y]), _mm512_extractf32x8_ps(accd[iy+nrc_y], 1));
}
store_8(ix+1, sums, info);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix+0, iy, _mm512_reduce_add_ps(accd[iy+ 0]));
info.store(ix+1, iy, _mm512_reduce_add_ps(accd[iy+nrc_y]));
}
}

}
@@ -1230,9 +1285,18 @@ static void mul_mat_iqX_k_q8_K_AVX512(int n, const void * vx, size_t bx, const D

}

for (int iy = 0; iy < nrc_y; ++iy) {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
if constexpr (nrc_y == 8) {
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
accm[iy] = _mm256_add_ps(accm[iy], sum256);
}
store_8(ix, accm, info);
}
else {
for (int iy = 0; iy < nrc_y; ++iy) {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd[iy]), _mm512_extractf32x8_ps(accd[iy], 1));
info.store(ix, iy, hsum_float_8(_mm256_add_ps(accm[iy], sum256)));
}
}

}
@@ -1256,6 +1320,9 @@ static void mul_mat_qX_K_q8_K_AVX512_1(int n, const void * vx, size_t bx, const

__m512i scales[2*k_nx];

__m256 sums[8]; // buffer of up to 8 per-column sums, reduced together via hsum_float_8x8

int ks = 0; // number of columns currently buffered in sums[]
for (int ix = 0; ix < nrc_x; ++ix) {

auto accd = _mm512_setzero_ps();
@@ -1289,12 +1356,21 @@
}

if constexpr (std::is_same_v<Dequantizer, DequantizerIQ2TN>) {
info.store(ix, 0, _mm512_reduce_add_ps(accd));
sums[ks++] = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
//info.store(ix, 0, _mm512_reduce_add_ps(accd));
} else {
auto sum256 = _mm256_add_ps(_mm512_castps512_ps256(accd), _mm512_extractf32x8_ps(accd, 1));
info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
sums[ks++] = _mm256_add_ps(accm, sum256);
//info.store(ix, 0, hsum_float_8(_mm256_add_ps(accm, sum256)));
}
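// Once 8 columns have been buffered, one hsum_float_8x8 yields all 8 results,
// which are written directly into row 0 at columns ix-7..ix.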
if (ks == 8) {
_mm256_storeu_ps(info.dst_row(0) + ix - 7, hsum_float_8x8(sums));
ks = 0;
}
}
if (ks > 0) {
    // Left-over columns when nrc_x is not a multiple of 8: the buffered sums
    // belong to the last ks columns, not to columns 0..ks-1.
    for (int i = 0; i < ks; ++i) info.store(nrc_x - ks + i, 0, hsum_float_8(sums[i]));
}
}

#else
@@ -1833,8 +1909,12 @@ IQK_NOINLINE void mul_mat_iq2tn_q8_K(int n, const void * vx, size_t bx, const Da

}

for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
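// Same trick as in the AVX512 kernels: for 8-row tiles reduce all accumulators at once via store_8.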
if constexpr (nrc_y == 8) {
store_8(ix, accd, info);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}

}
@@ -1877,10 +1957,13 @@ static void mul_mat_qY_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf

}

for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
if constexpr (nrc_y == 8) {
store_8(ix, accd, info);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}

}

}
@@ -1926,8 +2009,12 @@ static void mul_mat_qX_K_q8_K_T(int n, const void * vx, size_t bx, const DataInf

}

for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
if constexpr (nrc_y == 8) {
store_8(ix, accd, info);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}

}
@@ -2094,8 +2181,12 @@ static void mul_mat_qX_K_q8_K_IQ_N(int n, const void * vx, size_t bx, const Data
}
}

for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
if constexpr (nrc_y == 8) {
store_8(ix, accd, info);
} else {
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, hsum_float_8(accd[iy]));
}
}
}
}
@@ -2999,10 +3090,17 @@ struct ScaleHelperQ_1 {
}
};

struct MinusType0 {
template <int nrc_y> struct MinusType0 {
inline __m256 compute(__m128 d, int) const { return _mm256_set_m128(d, d); }
inline float compute(float d, int) const { return d; }
inline float result(__m256 acc, int) const { return hsum_float_8(acc); }
//inline void store(int ix, __m256 * acc, const DataInfo& info) {
// if constexpr (nrc_y == 8) {
// store_8(ix, acc, info);
// } else {
// for (int iy = 0; iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_8(acc[iy]));
// }
//}
};

template <int nrc_y> struct MinusType1 {
@@ -3022,6 +3120,23 @@ template <int nrc_y> struct MinusType1 {
const __m128 sum = _mm_add_ps(_mm256_castps256_ps128(acc), _mm256_extractf128_ps(acc, 1));
return hsum_float_4(_mm_add_ps(sum, accm[iy]));
}
//inline void store(int ix, const __m256 * acc, const DataInfo& info) {
// for (int iy = 0; iy < nrc_y; ++iy) {
// accm[iy] = _mm_add_ps(accm[iy], _mm_add_ps(_mm256_castps256_ps128(acc[iy]), _mm256_extractf128_ps(acc[iy], 1)));
// }
// if constexpr (nrc_y >= 4) {
// union { __m128 vec; float val[4]; } h;
// for (int i = 0; i < nrc_y/4; ++i) {
// accm[4*i+0] = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+0], accm[4*i+2]), _mm_unpackhi_ps(accm[4*i+0], accm[4*i+2]));
// accm[4*i+1] = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+1], accm[4*i+3]), _mm_unpackhi_ps(accm[4*i+1], accm[4*i+3]));
// h.vec = _mm_add_ps(_mm_unpacklo_ps(accm[4*i+0], accm[4*i+1]), _mm_unpackhi_ps(accm[4*i+0], accm[4*i+1]));
// for (int j = 0; j < 4; ++j) info.store(ix, 4*i+j, h.val[j]);
// }
// for (int iy = 4*(nrc_y/4); iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_4(accm[iy]));
// } else {
// for (int iy = 0; iy < nrc_y; ++iy) info.store(ix, iy, hsum_float_4(accm[iy]));
// }
//}
};

template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
@@ -3054,14 +3169,15 @@ template <typename Minus, int nrc_y, bool is_multiple_of_4> struct AccumT {
}
}
}
//accm.store(ix, acc, info);
for (int iy = 0; iy < nrc_y; ++iy) {
info.store(ix, iy, accm.result(acc[iy], iy));
}
}
};

template <int nrc_y, bool is_multiple_of_4>
using AccumType0 = AccumT<MinusType0, nrc_y, is_multiple_of_4>;
using AccumType0 = AccumT<MinusType0<nrc_y>, nrc_y, is_multiple_of_4>;

template <int nrc_y, bool is_multiple_of_4>
using AccumType1 = AccumT<MinusType1<nrc_y>, nrc_y, is_multiple_of_4>;