From 473e280500dd38c646472eba90c592780e7e4ac7 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 27 Jul 2024 10:45:56 +0300 Subject: [PATCH] Fusing a mat mul op followed by scale op on the CPU This is useful for Bitnet, where almost all matrix multiplications are followed by scale operations. As a result, we get a ~2% boost in Bitnet PP performance. Implementation is easy when the matrix multiplication is done by iqk_mul_mat. But if iqk_mul_mat is not implemented for the quant type/architecture, we need to add the scaling to llamafile sgemm and to ggml itself, which is much messier, so I haven't done it yet. Given that Bitnet is just a niche thing for now, I'll leave it on a branch for the time being. --- ggml/src/ggml-quants.c | 12 ++++++------ ggml/src/ggml.c | 32 ++++++++++++++++++++++++-------- ggml/src/iqk/iqk_mul_mat.cpp | 9 +++++---- ggml/src/iqk/iqk_mul_mat.h | 2 +- ggml/src/iqk/iqk_quantize.cpp | 4 ++-- 5 files changed, 38 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-quants.c b/ggml/src/ggml-quants.c index da4c9b9a..b76d5c69 100644 --- a/ggml/src/ggml-quants.c +++ b/ggml/src/ggml-quants.c @@ -3812,7 +3812,7 @@ static inline __m128i get_scale_shuffle(int i) { void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) { return; } #endif @@ -4296,7 +4296,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) { + if (iqk_mul_mat(nrc, nrc, n, 
GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1, 1.f)) { return; } #endif @@ -4585,7 +4585,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) { return; } #endif @@ -4942,7 +4942,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) { + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1, 1.f)) { return; } #endif @@ -5318,7 +5318,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) { return; } #endif @@ -11692,7 +11692,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) { #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) { + if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, 
GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) { return; } #endif diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c3cda4c4..8cfda7b8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -12295,7 +12295,8 @@ static void ggml_compute_forward_mul_mat_one_chunk( static void ggml_compute_forward_mul_mat( const struct ggml_compute_params * params, - struct ggml_tensor * dst) { + struct ggml_tensor * dst, + float scale) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -12350,7 +12351,7 @@ static void ggml_compute_forward_mul_mat( src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - 0, 1)) goto IQK_MulMat_Not_Available1; + 0, 1, scale)) goto IQK_MulMat_Not_Available1; } } } @@ -12363,7 +12364,7 @@ static void ggml_compute_forward_mul_mat( src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type), (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith, nth)) goto IQK_MulMat_Not_Available1; + ith, nth, scale)) goto IQK_MulMat_Not_Available1; return; } IQK_MulMat_Not_Available1:; @@ -12388,6 +12389,11 @@ IQK_MulMat_Not_Available1:; src1->type, dst->type)) goto UseGgmlGemm1; + //TODO: apply scale if different from 1 + //if (fabsf(scale-1.f) > 1e-4f) { + // ggml_barrier(params->shared); + // ggml_compute_forward_scale_f32(params, scale); + //} return; } UseGgmlGemm1:; @@ -12441,7 +12447,7 @@ UseGgmlGemm1:; src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type), vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type), (float *)((char *)dst->data + 
i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type), - ith, nth)) goto IQK_MulMat_Not_Available2; + ith, nth, scale)) goto IQK_MulMat_Not_Available2; return; } IQK_MulMat_Not_Available2:; @@ -12554,6 +12560,7 @@ UseGgmlGemm2:; current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1); } + //TODO: apply scale if different from 1 } // ggml_compute_forward_mul_mat_id @@ -16811,11 +16818,11 @@ static void ggml_compute_forward_cross_entropy_loss_back( ///////////////////////////////// -static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { +static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) { GGML_ASSERT(params); if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) { - return; + return false; } switch (tensor->op) { @@ -16909,7 +16916,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm } break; case GGML_OP_MUL_MAT: { - ggml_compute_forward_mul_mat(params, tensor); + if (next && next->op == GGML_OP_SCALE) { + float scale; + memcpy(&scale, next->op_params, sizeof(float)); + ggml_compute_forward_mul_mat(params, tensor, scale); + return true; + } + ggml_compute_forward_mul_mat(params, tensor, 1.f); } break; case GGML_OP_MUL_MAT_ID: { @@ -17143,6 +17156,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm GGML_ASSERT(false); } break; } + return false; } //////////////////////////////////////////////////////////////////////////////// @@ -18991,7 +19005,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; - ggml_compute_forward(&params, node); + if (ggml_compute_forward(&params, node, node_n < cgraph->n_nodes - 1 ? 
cgraph->nodes[node_n+1] : NULL)) { + ++node_n; + } if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) { state->shared->ec = GGML_STATUS_ABORTED; diff --git a/ggml/src/iqk/iqk_mul_mat.cpp b/ggml/src/iqk/iqk_mul_mat.cpp index bf517504..eadf2f4f 100644 --- a/ggml/src/iqk/iqk_mul_mat.cpp +++ b/ggml/src/iqk/iqk_mul_mat.cpp @@ -73,6 +73,7 @@ struct DataInfo { int ne11; const mmid_row_mapping * row_mapping = nullptr; size_t bs2 = 0; + float scale; inline const char * src1_row(int iy) const { if (!row_mapping) return cy + (cur_y + iy)*by; @@ -82,7 +83,7 @@ struct DataInfo { } inline void store(int ix, int iy, float result) const { - *(dst_row(iy) + ix) = result; + *(dst_row(iy) + ix) = result*scale; } inline float * dst_row(int iy) const { if (!row_mapping) return s + (cur_y + iy)*bs; @@ -133,7 +134,7 @@ struct MulMat { bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, - float * C, long stride_C, int ith, int nth) { + float * C, long stride_C, int ith, int nth, float scale) { MulMat mm; if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) { @@ -147,7 +148,7 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00, auto first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; - DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0}; + DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0, scale}; mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny); @@ -171,7 +172,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int first_x = ith*nrc_x; if (first_x + nrc_x > Nx) nrc_x = Nx - first_x; DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), - row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)}; + row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float), 1.f}; mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, 
row_size_qx, info, nrc_x, Ny); return true; } diff --git a/ggml/src/iqk/iqk_mul_mat.h b/ggml/src/iqk/iqk_mul_mat.h index 6bed5f5a..6b105efc 100644 --- a/ggml/src/iqk/iqk_mul_mat.h +++ b/ggml/src/iqk/iqk_mul_mat.h @@ -14,7 +14,7 @@ extern "C" { bool iqk_mul_mat(long Nx, long Ny, long ne00, int typeA, const void * A, long strideA, int typeB, const void * B, long strideB, - float * C, long stride_C, int ith, int nth); + float * C, long stride_C, int ith, int nth, float scale); bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11, int typeA, const void * A, long strideA, diff --git a/ggml/src/iqk/iqk_quantize.cpp b/ggml/src/iqk/iqk_quantize.cpp index 8f541565..c1d8a3d0 100644 --- a/ggml/src/iqk/iqk_quantize.cpp +++ b/ggml/src/iqk/iqk_quantize.cpp @@ -236,7 +236,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64"); #if GGML_USE_IQK_MULMAT - if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) { + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1, 1.f)) { return; } #endif @@ -286,7 +286,7 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si static_assert(QK_IQ1BN == 64, "This dot product implementation for iq2_bn requires a block size of 64"); - if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) { + if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1, 1.f)) { return; }