Fusing a mat mul op followed by a scale op on the CPU #5

Draft
wants to merge 1 commit into
base: main
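What this draft does: when ggml_compute_forward() sees a GGML_OP_MUL_MAT node whose immediate successor in the graph is a GGML_OP_SCALE node, it reads the scale factor from the successor's op_params, passes it down through ggml_compute_forward_mul_mat() into iqk_mul_mat(), and the iqk kernels multiply each dot-product result by that factor as it is written to dst (DataInfo::store). The compute thread then skips the SCALE node, so the scale costs no extra pass over the output. The snippet below is a minimal standalone sketch of that idea with hypothetical names (mul_mat_scaled is not part of ggml or this PR); it only illustrates folding the scale into the result store instead of running a second scaling pass.

    // Sketch only: compute C = scale * (A*B) by applying the scale when each
    // dot product is stored, rather than scaling C in a second pass over memory.
    #include <stdio.h>

    static void mul_mat_scaled(int m, int n, int k,
                               const float * A, const float * B, float * C,
                               float scale) {
        for (int i = 0; i < m; ++i) {
            for (int j = 0; j < n; ++j) {
                float sum = 0.f;
                for (int l = 0; l < k; ++l) {
                    sum += A[i*k + l] * B[l*n + j];
                }
                C[i*n + j] = sum * scale;   // scale applied at store time
            }
        }
    }

    int main(void) {
        const float A[4] = {1, 2, 3, 4};
        const float B[4] = {5, 6, 7, 8};
        float C[4];
        mul_mat_scaled(2, 2, 2, A, B, C, 0.5f);               // C = 0.5 * (A*B)
        printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);     // 9.5 11 / 21.5 25
        return 0;
    }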
12 changes: 6 additions & 6 deletions ggml/src/ggml-quants.c
@@ -3812,7 +3812,7 @@ static inline __m128i get_scale_shuffle(int i) {

 void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) {
         return;
     }
 #endif
@@ -4296,7 +4296,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r

 void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q4_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1, 1.f)) {
         return;
     }
 #endif
@@ -4585,7 +4585,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, size_t bs, const void * r

 void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) {
         return;
     }
 #endif
@@ -4942,7 +4942,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * restrict s, size_t bs, const void * r

 void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q5_1, vx, bx, GGML_TYPE_Q8_1, vy, by, s, bs, 0, 1, 1.f)) {
         return;
     }
 #endif
@@ -5318,7 +5318,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * restrict s, size_t bs, const void * r

 void ggml_vec_dot_q8_0_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_Q8_0, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) {
         return;
     }
 #endif
@@ -11692,7 +11692,7 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void

 void ggml_vec_dot_iq4_nl_q8_0(int n, float * restrict s, size_t bs, const void * restrict vx, size_t bx, const void * restrict vy, size_t by, int nrc) {
 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1)) {
+    if (iqk_mul_mat(nrc, nrc, n, GGML_TYPE_IQ4_NL, vx, bx, GGML_TYPE_Q8_0, vy, by, s, bs, 0, 1, 1.f)) {
         return;
     }
 #endif
32 changes: 24 additions & 8 deletions ggml/src/ggml.c
@@ -12295,7 +12295,8 @@ static void ggml_compute_forward_mul_mat_one_chunk(

 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
-              struct ggml_tensor * dst) {
+              struct ggml_tensor * dst,
+              float scale) {

     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
@@ -12350,7 +12351,7 @@ static void ggml_compute_forward_mul_mat(
                     src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
                     src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
                     (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
-                    0, 1)) goto IQK_MulMat_Not_Available1;
+                    0, 1, scale)) goto IQK_MulMat_Not_Available1;
             }
         }
     }
@@ -12363,7 +12364,7 @@ static void ggml_compute_forward_mul_mat(
                 src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
                 src1->type, (const char *)src1->data + i12*nb12 + i13*nb13, nb11/ggml_type_size(src1->type),
                 (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
-                ith, nth)) goto IQK_MulMat_Not_Available1;
+                ith, nth, scale)) goto IQK_MulMat_Not_Available1;
         return;
     }
 IQK_MulMat_Not_Available1:;
@@ -12388,6 +12389,11 @@ IQK_MulMat_Not_Available1:;
                                      src1->type,
                                      dst->type))
                     goto UseGgmlGemm1;
+        //TODO: apply scale if different from 1
+        //if (fabsf(scale-1.f) > 1e-4f) {
+        //    ggml_barrier(params->shared);
+        //    ggml_compute_forward_scale_f32(params, scale);
+        //}
         return;
     }
 UseGgmlGemm1:;
@@ -12441,7 +12447,7 @@ UseGgmlGemm1:;
                 src0->type, (const char *)src0->data + i12/r2*nb02 + i13/r3*nb03, nb01/ggml_type_size(src0->type),
                 vec_dot_type, (const char *)wdata + (i12*ne11 + i13*ne12*ne11)*row_size, row_size/ggml_type_size(vec_dot_type),
                 (float *)((char *)dst->data + i12*nb2 + i13*nb3), nb1/ggml_type_size(dst->type),
-                ith, nth)) goto IQK_MulMat_Not_Available2;
+                ith, nth, scale)) goto IQK_MulMat_Not_Available2;
         return;
     }
 IQK_MulMat_Not_Available2:;
@@ -12554,6 +12560,7 @@ UseGgmlGemm2:;

         current_chunk = atomic_fetch_add(&params->shared->current_chunk, 1);
     }
+    //TODO: apply scale if different from 1
 }

 // ggml_compute_forward_mul_mat_id
@@ -16811,11 +16818,11 @@ static void ggml_compute_forward_cross_entropy_loss_back(

 /////////////////////////////////

-static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
+static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor, struct ggml_tensor * next) {
     GGML_ASSERT(params);

     if (tensor->op == GGML_OP_NONE || ggml_is_empty(tensor)) {
-        return;
+        return false;
     }

     switch (tensor->op) {
@@ -16909,7 +16916,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_MUL_MAT:
             {
-                ggml_compute_forward_mul_mat(params, tensor);
+                if (next && next->op == GGML_OP_SCALE) {
+                    float scale;
+                    memcpy(&scale, next->op_params, sizeof(float));
+                    ggml_compute_forward_mul_mat(params, tensor, scale);
+                    return true;
+                }
+                ggml_compute_forward_mul_mat(params, tensor, 1.f);
             } break;
         case GGML_OP_MUL_MAT_ID:
             {
@@ -17143,6 +17156,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 GGML_ASSERT(false);
             } break;
     }
+    return false;
 }

 ////////////////////////////////////////////////////////////////////////////////
@@ -18991,7 +19005,9 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
     for (int node_n = 0; node_n < cgraph->n_nodes; node_n++) {
         struct ggml_tensor * node = cgraph->nodes[node_n];

-        ggml_compute_forward(&params, node);
+        if (ggml_compute_forward(&params, node, node_n < cgraph->n_nodes - 1 ? cgraph->nodes[node_n+1] : NULL)) {
+            ++node_n;
+        }

         if (state->ith == 0 && cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
             state->shared->ec = GGML_STATUS_ABORTED;
9 changes: 5 additions & 4 deletions ggml/src/iqk/iqk_mul_mat.cpp
@@ -73,6 +73,7 @@ struct DataInfo {
     int ne11;
     const mmid_row_mapping * row_mapping = nullptr;
     size_t bs2 = 0;
+    float scale;

     inline const char * src1_row(int iy) const {
         if (!row_mapping) return cy + (cur_y + iy)*by;
@@ -82,7 +83,7 @@
     }

     inline void store(int ix, int iy, float result) const {
-        *(dst_row(iy) + ix) = result;
+        *(dst_row(iy) + ix) = result*scale;
     }
     inline float * dst_row(int iy) const {
         if (!row_mapping) return s + (cur_y + iy)*bs;
@@ -133,7 +134,7 @@ struct MulMat {
 bool iqk_mul_mat(long Nx, long Ny, long ne00,
                  int typeA, const void * A, long strideA,
                  int typeB, const void * B, long strideB,
-                 float * C, long stride_C, int ith, int nth) {
+                 float * C, long stride_C, int ith, int nth, float scale) {

     MulMat mm;
     if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
@@ -147,7 +148,7 @@ bool iqk_mul_mat(long Nx, long Ny, long ne00,
     auto first_x = ith*nrc_x;
     if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;

-    DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
+    DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0, scale};

    mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);

@@ -171,7 +172,7 @@ bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
     int first_x = ith*nrc_x;
     if (first_x + nrc_x > Nx) nrc_x = Nx - first_x;
     DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
-        row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
+        row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float), 1.f};
     mm.mul_mat_NxM(ne00, (const char *)A + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny);
     return true;
 }
2 changes: 1 addition & 1 deletion ggml/src/iqk/iqk_mul_mat.h
@@ -14,7 +14,7 @@ extern "C" {
 bool iqk_mul_mat(long Nx, long Ny, long ne00,
                  int typeA, const void * A, long strideA,
                  int typeB, const void * B, long strideB,
-                 float * C, long stride_C, int ith, int nth);
+                 float * C, long stride_C, int ith, int nth, float scale);

 bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
                  int typeA, const void * A, long strideA,
4 changes: 2 additions & 2 deletions ggml/src/iqk/iqk_quantize.cpp
@@ -236,7 +236,7 @@ void ggml_vec_dot_iq1_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si
     static_assert(QK_IQ1BN == 64, "This dot product implementation for iq1_bn requires a block size of 64");

 #if GGML_USE_IQK_MULMAT
-    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
+    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ1_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1, 1.f)) {
         return;
     }
 #endif
@@ -286,7 +286,7 @@ void ggml_vec_dot_iq2_bn_q8_K64(int n, float * s, size_t bs, const void * vx, si

     static_assert(QK_IQ1BN == 64, "This dot product implementation for iq2_bn requires a block size of 64");

-    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1)) {
+    if (iqk_mul_mat(1, 1, n, GGML_TYPE_IQ2_BN, vx, 0, GGML_TYPE_Q8_K64, vy, 0, s, 0, 0, 1, 1.f)) {
         return;
     }

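Note on scope: as the //TODO comments in ggml_compute_forward_mul_mat() indicate, the scale factor is only actually applied on the iqk_mul_mat paths (through the new DataInfo::scale used in store()). The llamafile sgemm shortcut and the generic chunked fallback do not apply it yet, even though the following GGML_OP_SCALE node is already skipped whenever the fused dispatch is taken, so in this draft the fusion appears complete only for cases handled by iqk_mul_mat.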