diff --git a/neural_speed/core/ne_layers.c b/neural_speed/core/ne_layers.c index 080aa8ba6..e5c7d6147 100644 --- a/neural_speed/core/ne_layers.c +++ b/neural_speed/core/ne_layers.c @@ -8965,56 +8965,24 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, const int64_t mode = ((int32_t*)src1->data)[ROPE_MODE_IDX]; const int64_t prompt_size = ((int32_t*)src1->data)[ROPE_PROMPTSIZE_IDX]; const int64_t n_keep = ((int32_t*)src1->data)[ROPE_NKEEP_IDX]; - const float longfactor[48] = {1.0299999713897705,1.0499999523162842,1.0499999523162842,1.0799999237060547,1.2299998998641968,1.2299998998641968,1.2999999523162842,1.4499999284744263, - 1.5999999046325684, - 1.6499998569488525, - 1.8999998569488525, - 2.859999895095825, - 3.68999981880188, - 5.419999599456787, - 5.489999771118164, - 5.489999771118164, - 9.09000015258789, - 11.579999923706055, - 15.65999984741211, - 15.769999504089355, - 15.789999961853027, - 18.360000610351562, - 21.989999771118164, - 23.079999923706055, - 30.009998321533203, - 32.35000228881836, - 32.590003967285156, - 35.56000518798828, - 39.95000457763672, - 53.840003967285156, - 56.20000457763672, - 57.95000457763672, - 59.29000473022461, - 59.77000427246094, - 59.920005798339844, - 61.190006256103516, - 61.96000671386719, - 62.50000762939453, - 63.3700065612793, - 63.48000717163086, - 63.48000717163086, - 63.66000747680664, - 63.850006103515625, - 64.08000946044922, - 64.760009765625, - 64.80001068115234, - 64.81001281738281, - 64.81001281738281 - }; - const float shortfactor[48] = {1.04999995, 1.04999995, 1.04999995, 1.10000002, 1.10000002, 1.14999998, - 1.20000005, 1.25000000, 1.29999995, 1.35000002, 1.50000000, 2.00000000, - 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, - 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, - 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, - 2.00000000, 2.00000000, 2.04999995, 2.04999995, 2.04999995, 2.09999990, - 2.09999990, 2.09999990, 2.15000010, 2.15000010, 2.34999990, 2.54999995, - 2.59999990, 2.59999990, 2.75000000, 2.84999990, 2.84999990, 2.95000005}; + const float longfactor[48] = { + 1.0299999713897705, 1.0499999523162842, 1.0499999523162842, 1.0799999237060547, 1.2299998998641968, + 1.2299998998641968, 1.2999999523162842, 1.4499999284744263, 1.5999999046325684, 1.6499998569488525, + 1.8999998569488525, 2.859999895095825, 3.68999981880188, 5.419999599456787, 5.489999771118164, + 5.489999771118164, 9.09000015258789, 11.579999923706055, 15.65999984741211, 15.769999504089355, + 15.789999961853027, 18.360000610351562, 21.989999771118164, 23.079999923706055, 30.009998321533203, + 32.35000228881836, 32.590003967285156, 35.56000518798828, 39.95000457763672, 53.840003967285156, + 56.20000457763672, 57.95000457763672, 59.29000473022461, 59.77000427246094, 59.920005798339844, + 61.190006256103516, 61.96000671386719, 62.50000762939453, 63.3700065612793, 63.48000717163086, + 63.48000717163086, 63.66000747680664, 63.850006103515625, 64.08000946044922, 64.760009765625, + 64.80001068115234, 64.81001281738281, 64.81001281738281}; + const float shortfactor[48] = {1.04999995, 1.04999995, 1.04999995, 1.10000002, 1.10000002, 1.14999998, 1.20000005, + 1.25000000, 1.29999995, 1.35000002, 1.50000000, 2.00000000, 2.00000000, 2.00000000, + 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, + 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, + 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.04999995, 2.04999995, 2.04999995, + 2.09999990, 2.09999990, 2.09999990, 2.15000010, 2.15000010, 2.34999990, 2.54999995, + 2.59999990, 2.59999990, 2.75000000, 2.84999990, 2.84999990, 2.95000005}; assert(n_past >= 0); NE_TENSOR_UNARY_OP_LOCALS; @@ -9047,8 +9015,8 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, const bool skip = mode & 1; const bool is_neox = mode & 2; const bool is_glm = mode & 4; - const bool is_phi_short = mode==16? true : false; - const bool is_phi_long = mode==17? true : false; + const bool is_phi_short = mode == 16 ? true : false; + const bool is_phi_long = mode == 17 ? true : false; const bool is_shift = n_keep >= 0; const bool use_yarn = ((mode & 0x8) != 0); NE_ASSERT(("RoPE shift not supported!", !is_shift)); @@ -9091,7 +9059,7 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, dst_data[n_dims] = x2 * cos_block_theta - x3 * sin_block_theta; dst_data[n_dims / 2 * 3] = x2 * sin_block_theta + x3 * cos_block_theta; } - } else if(is_phi_short){ + } else if (is_phi_short) { // TODO: this is probably wrong, but I can't figure it out .. // ref: // https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 @@ -9103,11 +9071,11 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, float cur_rot = inv_ndims * ic - ib; float cos_theta, sin_theta; - float tmp_theta_base=theta_base/shortfactor[ic/2]; + float tmp_theta_base = theta_base / shortfactor[ic / 2]; rope_yarn(tmp_theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); - cos_theta *=scale_factor; - sin_theta *=scale_factor; + cos_theta *= scale_factor; + sin_theta *= scale_factor; theta_base *= theta_scale; const int64_t i0 = ib * n_dims + ic / 2; @@ -9120,10 +9088,9 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, dst_data[0] = x0 * cos_theta - x1 * sin_theta; dst_data[n_dims / 2] = x0 * sin_theta + x1 * cos_theta; - } - } - } - else if(is_phi_long){ + } + } + } else if (is_phi_long) { // TODO: this is probably wrong, but I can't figure it out .. // ref: // https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 @@ -9136,11 +9103,11 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, float cur_rot = inv_ndims * ic - ib; float cos_theta, sin_theta; - float tmp_theta_base=theta_base / longfactor[ic/2]; - rope_yarn(tmp_theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta, + float tmp_theta_base = theta_base / longfactor[ic / 2]; + rope_yarn(theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta); - cos_theta *=scale_factor; - sin_theta *=scale_factor; + cos_theta *= scale_factor; + sin_theta *= scale_factor; theta_base *= theta_scale; const int64_t i0 = ib * n_dims + ic / 2; @@ -9153,9 +9120,9 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, dst_data[0] = x0 * cos_theta - x1 * sin_theta; dst_data[n_dims / 2] = x0 * sin_theta + x1 * cos_theta; - } - } - }else if (!is_neox) { + } + } + } else if (!is_neox) { // printf("theta_base = %ld, freq_scale %.4f, ne0 %d\n", p, freq_scale, ne0); for (int64_t i0 = 0; i0 < ne0; i0 += 2) { float cos_theta, sin_theta; @@ -9177,7 +9144,7 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params, // ref: // https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 theta_base = theta_base * freq_scale; - + for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) { for (int64_t ic = 0; ic < n_dims; ic += 2) { // simplified from `(ib * n_dims + ic) * inv_ndims` @@ -9237,6 +9204,24 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params, const size_t nb1 = dst->nb[1]; const size_t nb2 = dst->nb[2]; const size_t nb3 = dst->nb[3]; + const float longfactor[48] = { + 1.0299999713897705, 1.0499999523162842, 1.0499999523162842, 1.0799999237060547, 1.2299998998641968, + 1.2299998998641968, 1.2999999523162842, 1.4499999284744263, 1.5999999046325684, 1.6499998569488525, + 1.8999998569488525, 2.859999895095825, 3.68999981880188, 5.419999599456787, 5.489999771118164, + 5.489999771118164, 9.09000015258789, 11.579999923706055, 15.65999984741211, 15.769999504089355, + 15.789999961853027, 18.360000610351562, 21.989999771118164, 23.079999923706055, 30.009998321533203, + 32.35000228881836, 32.590003967285156, 35.56000518798828, 39.95000457763672, 53.840003967285156, + 56.20000457763672, 57.95000457763672, 59.29000473022461, 59.77000427246094, 59.920005798339844, + 61.190006256103516, 61.96000671386719, 62.50000762939453, 63.3700065612793, 63.48000717163086, + 63.48000717163086, 63.66000747680664, 63.850006103515625, 64.08000946044922, 64.760009765625, + 64.80001068115234, 64.81001281738281, 64.81001281738281}; + const float shortfactor[48] = {1.04999995, 1.04999995, 1.04999995, 1.10000002, 1.10000002, 1.14999998, 1.20000005, + 1.25000000, 1.29999995, 1.35000002, 1.50000000, 2.00000000, 2.00000000, 2.00000000, + 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, + 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, + 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.04999995, 2.04999995, 2.04999995, + 2.09999990, 2.09999990, 2.09999990, 2.15000010, 2.15000010, 2.34999990, 2.54999995, + 2.59999990, 2.59999990, 2.75000000, 2.84999990, 2.84999990, 2.95000005}; NE_ASSERT(nb0 == sizeof(ne_fp16_t)); @@ -9266,6 +9251,8 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params, const bool skip = mode & 1; const bool is_neox = mode & 2; const bool is_glm = mode & 4; + const bool is_phi_short = mode == 16 ? true : false; + const bool is_phi_long = mode == 17 ? true : false; NE_ASSERT(("glm mode RoPE is not implemented!", !is_glm)); const bool is_shift = n_keep >= 0; NE_ASSERT(("shift RoPE is only implemented for the vanilla mode", !is_shift || !(is_glm || is_neox || skip))); @@ -9317,8 +9304,58 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params, if (ir > ir1) break; float theta = freq_scale * (float)p; + float scale_factor = 1.1902380714238083; + if (is_phi_short) { + // TODO: this is probably wrong, but I can't figure it out .. + // ref: + // https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + float tmp_theta = theta / shortfactor[ic / 2]; + const float cos_theta = scale_factor * cosf(tmp_theta); + const float sin_theta = scale_factor * sinf(tmp_theta); - if (!is_neox) { + theta *= theta_scale; + + const int64_t i0 = ib * n_dims + ic / 2; + + const ne_fp16_t* const src = + (ne_fp16_t*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00); + ne_fp16_t* dst_data = (ne_fp16_t*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); + + const float x0 = NE_FP16_TO_FP32(src[0]); + const float x1 = NE_FP16_TO_FP32(src[n_dims / 2]); + + dst_data[0] = NE_FP32_TO_FP16(x0 * cos_theta - x1 * sin_theta); + dst_data[n_dims / 2] = NE_FP32_TO_FP16(x0 * sin_theta + x1 * cos_theta); + } + } + } else if (is_phi_long) { + // TODO: this is probably wrong, but I can't figure it out .. + // ref: + // https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28 + for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) { + for (int64_t ic = 0; ic < n_dims; ic += 2) { + float tmp_theta = theta / longfactor[ic / 2]; + const float cos_theta = scale_factor * cosf(tmp_theta); + const float sin_theta = scale_factor * sinf(tmp_theta); + + theta *= theta_scale; + + const int64_t i0 = ib * n_dims + ic / 2; + + const ne_fp16_t* const src = + (ne_fp16_t*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00); + ne_fp16_t* dst_data = (ne_fp16_t*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0); + + const float x0 = NE_FP16_TO_FP32(src[0]); + const float x1 = NE_FP16_TO_FP32(src[n_dims / 2]); + + dst_data[0] = NE_FP32_TO_FP16(x0 * cos_theta - x1 * sin_theta); + dst_data[n_dims / 2] = NE_FP32_TO_FP16(x0 * sin_theta + x1 * cos_theta); + } + } + } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); diff --git a/neural_speed/models/model_utils/gguf.h b/neural_speed/models/model_utils/gguf.h index 054c40d87..69d1b856b 100644 --- a/neural_speed/models/model_utils/gguf.h +++ b/neural_speed/models/model_utils/gguf.h @@ -239,27 +239,29 @@ enum llm_arch { LLM_ARCH_PHI3, }; -static std::map LLM_ARCH_NAMES = {{LLM_ARCH_LLAMA, "llama"}, - {LLM_ARCH_FALCON, "falcon"}, - {LLM_ARCH_GPT2, "gpt2"}, - {LLM_ARCH_GPTJ, "gptj"}, - {LLM_ARCH_GPTNEOX, "gptneox"}, - {LLM_ARCH_MPT, "mpt"}, - {LLM_ARCH_BAICHUAN, "baichuan"}, - {LLM_ARCH_STARCODER, "starcoder"}, - {LLM_ARCH_PERSIMMON, "persimmon"}, - {LLM_ARCH_REFACT, "refact"}, - {LLM_ARCH_BLOOM, "bloom"}, - {LLM_ARCH_STABLELM, "stablelm"}, - {LLM_ARCH_QWEN, "qwen"}, - {LLM_ARCH_CHATGLM, "chatglm"}, - {LLM_ARCH_CHATGLM2, "chatglm2"}, - {LLM_ARCH_CHATGLM3, "chatglm3"}, - {LLM_ARCH_PHI, "phi"}, - {LLM_ARCH_GEMMA, "gemma"}, - {LLM_ARCH_QWEN2, "qwen2"}, - {LLM_ARCH_GROK, "grok"}, - {LLM_ARCH_PHI3, "phi3"},}; +static std::map LLM_ARCH_NAMES = { + {LLM_ARCH_LLAMA, "llama"}, + {LLM_ARCH_FALCON, "falcon"}, + {LLM_ARCH_GPT2, "gpt2"}, + {LLM_ARCH_GPTJ, "gptj"}, + {LLM_ARCH_GPTNEOX, "gptneox"}, + {LLM_ARCH_MPT, "mpt"}, + {LLM_ARCH_BAICHUAN, "baichuan"}, + {LLM_ARCH_STARCODER, "starcoder"}, + {LLM_ARCH_PERSIMMON, "persimmon"}, + {LLM_ARCH_REFACT, "refact"}, + {LLM_ARCH_BLOOM, "bloom"}, + {LLM_ARCH_STABLELM, "stablelm"}, + {LLM_ARCH_QWEN, "qwen"}, + {LLM_ARCH_CHATGLM, "chatglm"}, + {LLM_ARCH_CHATGLM2, "chatglm2"}, + {LLM_ARCH_CHATGLM3, "chatglm3"}, + {LLM_ARCH_PHI, "phi"}, + {LLM_ARCH_GEMMA, "gemma"}, + {LLM_ARCH_QWEN2, "qwen2"}, + {LLM_ARCH_GROK, "grok"}, + {LLM_ARCH_PHI3, "phi3"}, +}; struct gguf_tensor_info { struct gguf_str name; diff --git a/neural_speed/models/model_utils/model_files.h b/neural_speed/models/model_utils/model_files.h index 5997e9b7e..444d28976 100644 --- a/neural_speed/models/model_utils/model_files.h +++ b/neural_speed/models/model_utils/model_files.h @@ -1423,7 +1423,7 @@ struct model_model_loader { } struct ne_tensor* get_tensor(const std::string& name, const std::vector& ne, ne_backend backend) { - auto it = tensors_map.name_to_idx.find(name); + auto it = tensors_map.name_to_idx.find(name); if (it == tensors_map.name_to_idx.end()) { throw format("%s: tensor '%s' is missing from model", __func__, name.c_str()); } diff --git a/neural_speed/models/phi/phi3.cpp b/neural_speed/models/phi/phi3.cpp index 7045792be..ef189b47b 100644 --- a/neural_speed/models/phi/phi3.cpp +++ b/neural_speed/models/phi/phi3.cpp @@ -37,7 +37,6 @@ #include "models/model_utils/model_utils.h" #include "models/model_utils/util.h" - // evaluate the transformer // // - lctx: model context @@ -122,7 +121,7 @@ static bool phi3_model_eval_internal(model_context* ctx, const model_input* inpu ne_set_name(embd, "embd"); for (int i = 0; i < batch_size; ++i) { memcpy(static_cast(embd->data) + i * N, (inputs + i)->tokens, N * ne_element_size(embd)); - // memcpy(static_cast(embd->data) + i * N, embd_input, N * ne_element_size(embd)); + // memcpy(static_cast(embd->data) + i * N, embd_input, N * ne_element_size(embd)); } struct ne_tensor* inpL = ne_get_rows(ctx0, model.others[0], embd); @@ -141,33 +140,37 @@ static bool phi3_model_eval_internal(model_context* ctx, const model_input* inpu } // compute QKV cur = ne_mul_mat(ctx0, model.layers[il].attn[0], cur); - struct ne_tensor* Qcur = ne_reshape_3d(ctx0, ne_cont(ctx0, ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd)),head_dim,n_head,N); - struct ne_tensor* Kcur = ne_reshape_3d(ctx0,ne_cont(ctx0, ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd)),head_dim,n_head,N); - struct ne_tensor* Vcur = ne_reshape_3d(ctx0,ne_cont(ctx0, ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd)),head_dim,n_head,N); + struct ne_tensor* Qcur = + ne_reshape_3d(ctx0, ne_cont(ctx0, ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 0 * sizeof(float) * n_embd)), + head_dim, n_head, N); + struct ne_tensor* Kcur = + ne_reshape_3d(ctx0, ne_cont(ctx0, ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 1 * sizeof(float) * n_embd)), + head_dim, n_head, N); + struct ne_tensor* Vcur = + ne_reshape_3d(ctx0, ne_cont(ctx0, ne_view_2d(ctx0, cur, n_embd, N, cur->nb[1], 2 * sizeof(float) * n_embd)), + head_dim, n_head, N); // using mode = 2 for GPT-NeoX mode // struct ne_tensor* Qcur_Part = ne_view_4d(ctx0, ne_permute(ctx0, Qcur, 0, 2, 1, 3), n_rot, n_head, N, 1, // Qcur->nb[1], Qcur->nb[2], Qcur->nb[3], 0); - if (hparams.max_seq_len > 4096){ - if(N <=4096) { + if (hparams.max_seq_len > 4096) { + if (N <= 4096) { Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 16, 0, hparams.freq_base, hparams.freq_scale); Kcur = ne_rope_inplace(ctx0, Kcur, n_past, n_rot, 16, 0, hparams.freq_base, hparams.freq_scale); - } - else { + } else { Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 17, 0, hparams.freq_base, hparams.freq_scale); Kcur = ne_rope_inplace(ctx0, Kcur, n_past, n_rot, 17, 0, hparams.freq_base, hparams.freq_scale); } + } else { + Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale); + Kcur = ne_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale); } - else { - Qcur = ne_rope_inplace(ctx0, Qcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale); - Kcur = ne_rope_inplace(ctx0, Kcur, n_past, n_rot, 2, 0, hparams.freq_base, hparams.freq_scale); - } - + // ne_build_forward_expand(&gf, Qcur_Part); ne_set_name(Qcur, "Qcur"); // struct ne_tensor* Kcur_Part = ne_view_4d(ctx0, ne_permute(ctx0, Kcur, 0, 2, 1, 3), n_rot, n_head, N, 1, // Kcur->nb[1], Kcur->nb[2], Kcur->nb[3], 0); - + // ne_build_forward_expand(&gf, Kcur_Part); ne_set_name(Kcur, "kcur"); const float attn_scale = 1.0f / sqrtf(static_cast(head_dim)); @@ -227,7 +230,7 @@ static bool phi3_model_eval_internal(model_context* ctx, const model_input* inpu ne_scale_inplace(ctx0, KQ, ne_new_f32(ctx0, 1.0f / sqrt(static_cast((n_embd) / n_head)))); // KQ_masked = mask_past(KQ_scaled) - struct ne_tensor* KQ_masked = ne_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past); + struct ne_tensor* KQ_masked = ne_diag_mask_inf(ctx0, KQ_scaled, n_past); // KQ = soft_max(KQ_masked) struct ne_tensor* KQ_soft_max = ne_soft_max_inplace(ctx0, KQ_masked); @@ -302,12 +305,17 @@ static bool phi3_model_eval_internal(model_context* ctx, const model_input* inpu cur = ne_rms_norm(ctx0, cur, hparams.norm_eps); cur = ne_mul(ctx0, cur, model.layers[il].norm[1]); } - struct ne_tensor* ffn_gate = ne_view_2d(ctx0,model.layers[il].ffn[1],n_embd, 8192, model.layers[il].ffn[1]->nb[1], 0 * sizeof(float) * n_embd * 8192); - struct ne_tensor* ffn_up = ne_view_2d(ctx0,model.layers[il].ffn[1],n_embd, 8192, model.layers[il].ffn[1]->nb[1], 1 * sizeof(float) * n_embd * 8192); + // size_t weight_size=ne_element_size(model.layers[il].ffn[1]); + // struct ne_tensor* ffn_gate = ne_cont(ctx0,ne_view_2d(ctx0,model.layers[il].ffn[1],n_embd, 8192, + // model.layers[il].ffn[1]->nb[1], 0 * weight_size * n_embd * 8192)); struct ne_tensor* ffn_up = + // ne_cont(ctx0,ne_view_2d(ctx0,model.layers[il].ffn[1],n_embd, 8192, model.layers[il].ffn[1]->nb[1], 1 * + // weight_size * n_embd * 8192)); { - struct ne_tensor* cur1 = ne_mul_mat(ctx0, ffn_gate, cur); - struct ne_tensor* cur2 = ne_mul_mat(ctx0, ffn_up, cur); - cur = ne_mul(ctx0, cur2, ne_silu(ctx0,cur1)); + struct ne_tensor* cur1 = ne_mul_mat(ctx0, model.layers[il].ffn[1], cur); + struct ne_tensor* cur_gate = ne_cont(ctx0, ne_view_2d(ctx0, cur1, cur1->ne[0] / 2, cur1->ne[1], cur1->nb[1], 0)); + struct ne_tensor* cur_up = + ne_cont(ctx0, ne_view_2d(ctx0, cur1, cur1->ne[0] / 2, cur1->ne[1], cur1->nb[1], cur1->nb[1] / 2)); + cur = ne_mul(ctx0, cur_up, ne_silu(ctx0, cur_gate)); cur = ne_mul_mat(ctx0, model.layers[il].ffn[0], cur); } diff --git a/neural_speed/models/phi/phi3_utils.cpp b/neural_speed/models/phi/phi3_utils.cpp index 052b3c647..5820fe8a7 100644 --- a/neural_speed/models/phi/phi3_utils.cpp +++ b/neural_speed/models/phi/phi3_utils.cpp @@ -49,7 +49,7 @@ void model_load_internal(const std::string& fname, model_archs arch, model_conte } void phi3::init(const char* path_model, model_context* ctx, int n_gpu_layer_, bool use_mmap_, bool use_mlock_, - bool vocab_only_) { + bool vocab_only_) { model_context& lctx = *ctx; n_gpu_layer = n_gpu_layer_; use_mmap = use_mmap_; @@ -131,16 +131,16 @@ void phi3::load(model_context* ctx, model_progress_callback progress_callback, v layer.attn[1] = ml->get_tensor(layers_i + ".attn_output.weight", {n_embd, n_embd}, backend); // ffn norm // layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - // ffn GEMM + // ffn GEMM layer.ffn[0] = ml->get_tensor(layers_i + ".ffn_down.weight", {8192, n_embd}, backend); layer.ffn[1] = ml->get_tensor(layers_i + ".ffn_up.weight", {n_embd, 2 * 8192}, backend); // ffn GEMM if (backend != NE_BACKEND_CPU) { - vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) - + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]); + vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]); } } - }else{ // ns_bin + } else { // ns_bin model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); model.others[1] = ml->get_tensor("model.norm.weight", {n_embd}, NE_BACKEND_CPU); model.others[2] = ml->get_tensor("lm_head.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); @@ -158,13 +158,13 @@ void phi3::load(model_context* ctx, model_progress_callback progress_callback, v layer.attn[1] = ml->get_tensor(layers_i + ".self_attn.o_proj.weight", {n_embd, n_embd}, backend); // ffn norm // layer.norm[1] = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend); - // ffn GEMM + // ffn GEMM layer.ffn[0] = ml->get_tensor(layers_i + ".mlp.down_proj.weight", {8192, n_embd}, backend); layer.ffn[1] = ml->get_tensor(layers_i + ".mlp.gate_up_proj.weight", {n_embd, 2 * 8192}, backend); // ffn GEMM if (backend != NE_BACKEND_CPU) { - vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) - + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]); + vram_total += ne_nbytes(layer.norm[0]) + ne_nbytes(layer.attn[0]) + ne_nbytes(layer.attn[1]) + + ne_nbytes(layer.norm[1]) + ne_nbytes(layer.ffn[0]) + ne_nbytes(layer.ffn[1]); } } } @@ -198,9 +198,9 @@ class phi3_quant_layer : public quant_layer_base { public: quant_params_internal get_layer_config(std::string layername, std::vector ne, ne_type type) override { bool quantize = layername.rfind("weight") == layername.size() - 6; // ends with 'weight'? - if (layername == "model.embed_tokens.weight") { + if (layername == "model.embed_tokens.weight" || layername == "output.weight" || layername == "lm_head.weight") { // special layer process, can be loaded by config file - return quant_params_internal(); // return q4_0 to cover the usage of getrow + return quant_params_internal{quant_bits::count}; // return q4_0 to cover the usage of getrow } quantize &= (ne.size() == 2); if (quantize) { diff --git a/neural_speed/models/phi/phi_utils.cpp b/neural_speed/models/phi/phi_utils.cpp index c864c94dd..eee50b252 100644 --- a/neural_speed/models/phi/phi_utils.cpp +++ b/neural_speed/models/phi/phi_utils.cpp @@ -111,7 +111,7 @@ void phi::load(model_context* ctx, model_progress_callback progress_callback, vo // PHI is set up so that if padding_idx is specified then offset the embedding ids by 2 // and adjust num_embeddings appropriately. Other models don't have this hack - + model.others[0] = ml->get_tensor("model.embed_tokens.weight", {n_embd, n_vocab}, NE_BACKEND_CPU); model.others[1] = ml->get_tensor("model.final_layernorm.weight", {n_embd}, NE_BACKEND_CPU); model.others[2] = ml->get_tensor("model.final_layernorm.bias", {n_embd}, NE_BACKEND_CPU);