This repository has been archived by the owner on Aug 30, 2024. It is now read-only.

Commit

enable phi3
Signed-off-by: intellinjun <[email protected]>
intellinjun committed May 8, 2024
1 parent a47533a commit 5d58dec
Showing 6 changed files with 170 additions and 123 deletions.
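Note: the bulk of this change teaches the RoPE kernels about Phi-3's scaled rotary embeddings. Two new mode values (16 and 17) select per-dimension short or long scaling factors, each rotary angle is divided by the matching factor, and the resulting sine/cosine are multiplied by a constant attention scale (about 1.19024). A minimal, self-contained sketch of that math follows; the placeholder factor values, context sizes, and function names in it are assumptions for illustration and are not part of the commit.

/* Illustrative sketch (not from the commit): Phi-3-style scaled RoPE applied to
 * one (x0, x1) pair. The placeholder factors, context sizes and names below are
 * assumptions for illustration only. */
#include <math.h>
#include <stdio.h>

#define HALF_DIMS 48 /* assumes head_dim = 96, matching the 48-entry factor tables in the diff */

static void phi3_rope_pair(float x0, float x1, int half_idx, int pos, const float* factor,
                           float attn_scale, float* out0, float* out1) {
  /* vanilla RoPE angle for this dimension: pos * base^(-2*i/head_dim) */
  float theta = (float)pos * powf(10000.0f, -2.0f * (float)half_idx / (2.0f * HALF_DIMS));
  theta /= factor[half_idx];                    /* Phi-3 long/short rescaling of the angle */
  const float cos_t = attn_scale * cosf(theta); /* constant attention scale, 1.19024 in the diff */
  const float sin_t = attn_scale * sinf(theta);
  *out0 = x0 * cos_t - x1 * sin_t;
  *out1 = x0 * sin_t + x1 * cos_t;
}

int main(void) {
  /* 1.1902380714238083 in the diff is consistent with sqrt(1 + ln(131072/4096)/ln(4096)),
   * i.e. a 4k -> 128k context extension; this derivation is an assumption, not stated in the commit. */
  const float attn_scale = sqrtf(1.0f + logf(131072.0f / 4096.0f) / logf(4096.0f));
  float factor[HALF_DIMS];
  for (int i = 0; i < HALF_DIMS; ++i) factor[i] = 1.0f; /* placeholder per-dimension factors */
  float y0, y1;
  phi3_rope_pair(1.0f, 0.0f, /*half_idx=*/0, /*pos=*/100, factor, attn_scale, &y0, &y1);
  printf("attn_scale = %.6f, rotated pair = (%.4f, %.4f)\n", attn_scale, y0, y1);
  return 0;
}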
175 changes: 106 additions & 69 deletions neural_speed/core/ne_layers.c
@@ -8965,56 +8965,24 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
const int64_t mode = ((int32_t*)src1->data)[ROPE_MODE_IDX];
const int64_t prompt_size = ((int32_t*)src1->data)[ROPE_PROMPTSIZE_IDX];
const int64_t n_keep = ((int32_t*)src1->data)[ROPE_NKEEP_IDX];
const float longfactor[48] = {1.0299999713897705,1.0499999523162842,1.0499999523162842,1.0799999237060547,1.2299998998641968,1.2299998998641968,1.2999999523162842,1.4499999284744263,
1.5999999046325684,
1.6499998569488525,
1.8999998569488525,
2.859999895095825,
3.68999981880188,
5.419999599456787,
5.489999771118164,
5.489999771118164,
9.09000015258789,
11.579999923706055,
15.65999984741211,
15.769999504089355,
15.789999961853027,
18.360000610351562,
21.989999771118164,
23.079999923706055,
30.009998321533203,
32.35000228881836,
32.590003967285156,
35.56000518798828,
39.95000457763672,
53.840003967285156,
56.20000457763672,
57.95000457763672,
59.29000473022461,
59.77000427246094,
59.920005798339844,
61.190006256103516,
61.96000671386719,
62.50000762939453,
63.3700065612793,
63.48000717163086,
63.48000717163086,
63.66000747680664,
63.850006103515625,
64.08000946044922,
64.760009765625,
64.80001068115234,
64.81001281738281,
64.81001281738281
};
const float shortfactor[48] = {1.04999995, 1.04999995, 1.04999995, 1.10000002, 1.10000002, 1.14999998,
1.20000005, 1.25000000, 1.29999995, 1.35000002, 1.50000000, 2.00000000,
2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000,
2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000,
2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000,
2.00000000, 2.00000000, 2.04999995, 2.04999995, 2.04999995, 2.09999990,
2.09999990, 2.09999990, 2.15000010, 2.15000010, 2.34999990, 2.54999995,
2.59999990, 2.59999990, 2.75000000, 2.84999990, 2.84999990, 2.95000005};
const float longfactor[48] = {
1.0299999713897705, 1.0499999523162842, 1.0499999523162842, 1.0799999237060547, 1.2299998998641968,
1.2299998998641968, 1.2999999523162842, 1.4499999284744263, 1.5999999046325684, 1.6499998569488525,
1.8999998569488525, 2.859999895095825, 3.68999981880188, 5.419999599456787, 5.489999771118164,
5.489999771118164, 9.09000015258789, 11.579999923706055, 15.65999984741211, 15.769999504089355,
15.789999961853027, 18.360000610351562, 21.989999771118164, 23.079999923706055, 30.009998321533203,
32.35000228881836, 32.590003967285156, 35.56000518798828, 39.95000457763672, 53.840003967285156,
56.20000457763672, 57.95000457763672, 59.29000473022461, 59.77000427246094, 59.920005798339844,
61.190006256103516, 61.96000671386719, 62.50000762939453, 63.3700065612793, 63.48000717163086,
63.48000717163086, 63.66000747680664, 63.850006103515625, 64.08000946044922, 64.760009765625,
64.80001068115234, 64.81001281738281, 64.81001281738281};
const float shortfactor[48] = {1.04999995, 1.04999995, 1.04999995, 1.10000002, 1.10000002, 1.14999998, 1.20000005,
1.25000000, 1.29999995, 1.35000002, 1.50000000, 2.00000000, 2.00000000, 2.00000000,
2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000,
2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000,
2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.04999995, 2.04999995, 2.04999995,
2.09999990, 2.09999990, 2.09999990, 2.15000010, 2.15000010, 2.34999990, 2.54999995,
2.59999990, 2.59999990, 2.75000000, 2.84999990, 2.84999990, 2.95000005};
assert(n_past >= 0);

NE_TENSOR_UNARY_OP_LOCALS;
@@ -9047,8 +9015,8 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
const bool skip = mode & 1;
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const bool is_phi_short = mode==16? true : false;
const bool is_phi_long = mode==17? true : false;
const bool is_phi_short = mode == 16 ? true : false;
const bool is_phi_long = mode == 17 ? true : false;
const bool is_shift = n_keep >= 0;
const bool use_yarn = ((mode & 0x8) != 0);
NE_ASSERT(("RoPE shift not supported!", !is_shift));
@@ -9091,7 +9059,7 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
dst_data[n_dims] = x2 * cos_block_theta - x3 * sin_block_theta;
dst_data[n_dims / 2 * 3] = x2 * sin_block_theta + x3 * cos_block_theta;
}
} else if(is_phi_short){
} else if (is_phi_short) {
// TODO: this is probably wrong, but I can't figure it out ..
// ref:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
@@ -9103,11 +9071,11 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
float cur_rot = inv_ndims * ic - ib;

float cos_theta, sin_theta;
float tmp_theta_base=theta_base/shortfactor[ic/2];
float tmp_theta_base = theta_base / shortfactor[ic / 2];
rope_yarn(tmp_theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta,
&sin_theta);
cos_theta *=scale_factor;
sin_theta *=scale_factor;
cos_theta *= scale_factor;
sin_theta *= scale_factor;
theta_base *= theta_scale;

const int64_t i0 = ib * n_dims + ic / 2;
@@ -9120,10 +9088,9 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,

dst_data[0] = x0 * cos_theta - x1 * sin_theta;
dst_data[n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
}
}
}
else if(is_phi_long){
}
}
} else if (is_phi_long) {
// TODO: this is probably wrong, but I can't figure it out ..
// ref:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
@@ -9136,11 +9103,11 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
float cur_rot = inv_ndims * ic - ib;

float cos_theta, sin_theta;
float tmp_theta_base=theta_base / longfactor[ic/2];
rope_yarn(tmp_theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta,
float tmp_theta_base = theta_base / longfactor[ic / 2];
rope_yarn(theta_base, freq_scale, corr_dims, (int)cur_rot, ext_factor, attn_factor, &cos_theta,
&sin_theta);
cos_theta *=scale_factor;
sin_theta *=scale_factor;
cos_theta *= scale_factor;
sin_theta *= scale_factor;
theta_base *= theta_scale;

const int64_t i0 = ib * n_dims + ic / 2;
@@ -9153,9 +9120,9 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,

dst_data[0] = x0 * cos_theta - x1 * sin_theta;
dst_data[n_dims / 2] = x0 * sin_theta + x1 * cos_theta;
}
}
}else if (!is_neox) {
}
}
} else if (!is_neox) {
// printf("theta_base = %ld, freq_scale %.4f, ne0 %d\n", p, freq_scale, ne0);
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
float cos_theta, sin_theta;
@@ -9177,7 +9144,7 @@ static void ne_compute_forward_rope_f32(const struct ne_compute_params* params,
// ref:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
theta_base = theta_base * freq_scale;

for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
// simplified from `(ib * n_dims + ic) * inv_ndims`
@@ -9237,6 +9204,24 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
const size_t nb1 = dst->nb[1];
const size_t nb2 = dst->nb[2];
const size_t nb3 = dst->nb[3];
const float longfactor[48] = {
1.0299999713897705, 1.0499999523162842, 1.0499999523162842, 1.0799999237060547, 1.2299998998641968,
1.2299998998641968, 1.2999999523162842, 1.4499999284744263, 1.5999999046325684, 1.6499998569488525,
1.8999998569488525, 2.859999895095825, 3.68999981880188, 5.419999599456787, 5.489999771118164,
5.489999771118164, 9.09000015258789, 11.579999923706055, 15.65999984741211, 15.769999504089355,
15.789999961853027, 18.360000610351562, 21.989999771118164, 23.079999923706055, 30.009998321533203,
32.35000228881836, 32.590003967285156, 35.56000518798828, 39.95000457763672, 53.840003967285156,
56.20000457763672, 57.95000457763672, 59.29000473022461, 59.77000427246094, 59.920005798339844,
61.190006256103516, 61.96000671386719, 62.50000762939453, 63.3700065612793, 63.48000717163086,
63.48000717163086, 63.66000747680664, 63.850006103515625, 64.08000946044922, 64.760009765625,
64.80001068115234, 64.81001281738281, 64.81001281738281};
const float shortfactor[48] = {1.04999995, 1.04999995, 1.04999995, 1.10000002, 1.10000002, 1.14999998, 1.20000005,
1.25000000, 1.29999995, 1.35000002, 1.50000000, 2.00000000, 2.00000000, 2.00000000,
2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000,
2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.00000000,
2.00000000, 2.00000000, 2.00000000, 2.00000000, 2.04999995, 2.04999995, 2.04999995,
2.09999990, 2.09999990, 2.09999990, 2.15000010, 2.15000010, 2.34999990, 2.54999995,
2.59999990, 2.59999990, 2.75000000, 2.84999990, 2.84999990, 2.95000005};

NE_ASSERT(nb0 == sizeof(ne_fp16_t));

@@ -9266,6 +9251,8 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
const bool skip = mode & 1;
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
const bool is_phi_short = mode == 16 ? true : false;
const bool is_phi_long = mode == 17 ? true : false;
NE_ASSERT(("glm mode RoPE is not implemented!", !is_glm));
const bool is_shift = n_keep >= 0;
NE_ASSERT(("shift RoPE is only implemented for the vanilla mode", !is_shift || !(is_glm || is_neox || skip)));
@@ -9317,8 +9304,58 @@ static void ne_compute_forward_rope_f16(const struct ne_compute_params* params,
if (ir > ir1) break;

float theta = freq_scale * (float)p;
float scale_factor = 1.1902380714238083;
if (is_phi_short) {
// TODO: this is probably wrong, but I can't figure it out ..
// ref:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
float tmp_theta = theta / shortfactor[ic / 2];
const float cos_theta = scale_factor * cosf(tmp_theta);
const float sin_theta = scale_factor * sinf(tmp_theta);

if (!is_neox) {
theta *= theta_scale;

const int64_t i0 = ib * n_dims + ic / 2;

const ne_fp16_t* const src =
(ne_fp16_t*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
ne_fp16_t* dst_data = (ne_fp16_t*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);

const float x0 = NE_FP16_TO_FP32(src[0]);
const float x1 = NE_FP16_TO_FP32(src[n_dims / 2]);

dst_data[0] = NE_FP32_TO_FP16(x0 * cos_theta - x1 * sin_theta);
dst_data[n_dims / 2] = NE_FP32_TO_FP16(x0 * sin_theta + x1 * cos_theta);
}
}
} else if (is_phi_long) {
// TODO: this is probably wrong, but I can't figure it out ..
// ref:
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py#LL251C1-L294C28
for (int64_t ib = 0; ib < ne0 / n_dims; ++ib) {
for (int64_t ic = 0; ic < n_dims; ic += 2) {
float tmp_theta = theta / longfactor[ic / 2];
const float cos_theta = scale_factor * cosf(tmp_theta);
const float sin_theta = scale_factor * sinf(tmp_theta);

theta *= theta_scale;

const int64_t i0 = ib * n_dims + ic / 2;

const ne_fp16_t* const src =
(ne_fp16_t*)((char*)src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01 + i0 * nb00);
ne_fp16_t* dst_data = (ne_fp16_t*)((char*)dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1 + i0 * nb0);

const float x0 = NE_FP16_TO_FP32(src[0]);
const float x1 = NE_FP16_TO_FP32(src[n_dims / 2]);

dst_data[0] = NE_FP32_TO_FP16(x0 * cos_theta - x1 * sin_theta);
dst_data[n_dims / 2] = NE_FP32_TO_FP16(x0 * sin_theta + x1 * cos_theta);
}
}
} else if (!is_neox) {
for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
const float cos_theta = cosf(theta);
const float sin_theta = sinf(theta);
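The two new mode values handled above (16 for the short-factor path, 17 for the long-factor path) are plain integers read from src1; the caller presumably selects between them depending on whether the sequence length exceeds the model's original context window. A minimal sketch of that dispatch, with assumed constant and function names rather than the library's actual API:

/* Hypothetical helper: choose between the two new Phi-3 RoPE modes.
 * The constant names and threshold logic are assumptions for illustration. */
enum { ROPE_MODE_PHI3_SHORT = 16, ROPE_MODE_PHI3_LONG = 17 };

static int pick_phi3_rope_mode(int n_ctx, int original_max_position_embeddings) {
  return (n_ctx > original_max_position_embeddings) ? ROPE_MODE_PHI3_LONG
                                                    : ROPE_MODE_PHI3_SHORT;
}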
44 changes: 23 additions & 21 deletions neural_speed/models/model_utils/gguf.h
@@ -239,27 +239,29 @@ enum llm_arch {
LLM_ARCH_PHI3,
};

static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {{LLM_ARCH_LLAMA, "llama"},
{LLM_ARCH_FALCON, "falcon"},
{LLM_ARCH_GPT2, "gpt2"},
{LLM_ARCH_GPTJ, "gptj"},
{LLM_ARCH_GPTNEOX, "gptneox"},
{LLM_ARCH_MPT, "mpt"},
{LLM_ARCH_BAICHUAN, "baichuan"},
{LLM_ARCH_STARCODER, "starcoder"},
{LLM_ARCH_PERSIMMON, "persimmon"},
{LLM_ARCH_REFACT, "refact"},
{LLM_ARCH_BLOOM, "bloom"},
{LLM_ARCH_STABLELM, "stablelm"},
{LLM_ARCH_QWEN, "qwen"},
{LLM_ARCH_CHATGLM, "chatglm"},
{LLM_ARCH_CHATGLM2, "chatglm2"},
{LLM_ARCH_CHATGLM3, "chatglm3"},
{LLM_ARCH_PHI, "phi"},
{LLM_ARCH_GEMMA, "gemma"},
{LLM_ARCH_QWEN2, "qwen2"},
{LLM_ARCH_GROK, "grok"},
{LLM_ARCH_PHI3, "phi3"},};
static std::map<llm_arch, std::string> LLM_ARCH_NAMES = {
{LLM_ARCH_LLAMA, "llama"},
{LLM_ARCH_FALCON, "falcon"},
{LLM_ARCH_GPT2, "gpt2"},
{LLM_ARCH_GPTJ, "gptj"},
{LLM_ARCH_GPTNEOX, "gptneox"},
{LLM_ARCH_MPT, "mpt"},
{LLM_ARCH_BAICHUAN, "baichuan"},
{LLM_ARCH_STARCODER, "starcoder"},
{LLM_ARCH_PERSIMMON, "persimmon"},
{LLM_ARCH_REFACT, "refact"},
{LLM_ARCH_BLOOM, "bloom"},
{LLM_ARCH_STABLELM, "stablelm"},
{LLM_ARCH_QWEN, "qwen"},
{LLM_ARCH_CHATGLM, "chatglm"},
{LLM_ARCH_CHATGLM2, "chatglm2"},
{LLM_ARCH_CHATGLM3, "chatglm3"},
{LLM_ARCH_PHI, "phi"},
{LLM_ARCH_GEMMA, "gemma"},
{LLM_ARCH_QWEN2, "qwen2"},
{LLM_ARCH_GROK, "grok"},
{LLM_ARCH_PHI3, "phi3"},
};

struct gguf_tensor_info {
struct gguf_str name;
2 changes: 1 addition & 1 deletion neural_speed/models/model_utils/model_files.h
@@ -1423,7 +1423,7 @@ struct model_model_loader {
}

struct ne_tensor* get_tensor(const std::string& name, const std::vector<uint32_t>& ne, ne_backend backend) {
auto it = tensors_map.name_to_idx.find(name);
auto it = tensors_map.name_to_idx.find(name);
if (it == tensors_map.name_to_idx.end()) {
throw format("%s: tensor '%s' is missing from model", __func__, name.c_str());
}