
Commit

Complete the graph-form implementation of minicpm3
黄宇扬 committed Dec 9, 2024
1 parent f711b32 commit 5e27fb7
Showing 8 changed files with 183 additions and 4 deletions.
4 changes: 4 additions & 0 deletions include/devices/cuda/cudadevice.h
@@ -67,6 +67,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaRepeatOp : CpuRepeatOp {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaCatOp : CpuCatOp {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
1 change: 1 addition & 0 deletions include/devices/cuda/fastllm-cuda.cuh
@@ -25,6 +25,7 @@ void FastllmCudaMemcpy2DDeviceToDevice(void * dst, size_t dpitch, const void *
void FastllmCudaMemcpy2DDeviceToDeviceBatch(void ** dsts, size_t * dpitchs, void ** srcs,
size_t * spitchs, size_t *widths, size_t * heights,
int batch);
void FastllmCudaRepeat(void *input, void *output, int outer, int repeatTimes, int inputStride, int outputStride0, int outputStride1, int copyLen);

bool FastllmFloatToHalf(void *a, void *b, int len);
bool FastllmHalfToFloat(void *a, void *b, int len);
3 changes: 3 additions & 0 deletions include/graph.h
@@ -52,6 +52,7 @@ namespace fastllm {
void Update();

void AddTo(ComputeGraphNode &input0, ComputeGraphNode &input1, float alpha = 1.0); // input0 += input1 * alpha
void Cat(ComputeGraphNode &input0, ComputeGraphNode &input1, int axis, ComputeGraphNode &output);
void DataTypeAs(ComputeGraphNode &input, ComputeGraphNode &input1); // set input's dataType to match input1's
void Embedding(ComputeGraphNode &input, ComputeGraphNode &weight, ComputeGraphNode &output);
void ExpandHead(ComputeGraphNode &input, int headDim);
@@ -62,7 +63,9 @@
float scale, int maskType, int unitLen); // fused attention
void Linear(ComputeGraphNode &input, ComputeGraphNode &weight, ComputeGraphNode &bias, ComputeGraphNode &output);
void LlamaRotatePosition2D(ComputeGraphNode &input, ComputeGraphNode &positionIds, ComputeGraphNode &sinData, ComputeGraphNode &cosData, int rotaryDim); // 2D position for llama
void Mul(ComputeGraphNode &input, float v, ComputeGraphNode &output); // output = input * v
void MulTo(ComputeGraphNode &input0, ComputeGraphNode &input1); // input0 *= input1
void Repeat(ComputeGraphNode &input, int axis, int repeatTimes, ComputeGraphNode &output);
void RMSNorm(ComputeGraphNode &input, ComputeGraphNode &weight, float eps, ComputeGraphNode &output);
void Silu(ComputeGraphNode &input, ComputeGraphNode &output);
void Split(ComputeGraphNode &input, int axis, int start, int end, ComputeGraphNode &output);
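For orientation, a minimal sketch of how the three new graph ops compose (the node names, shapes, and standalone graph here are hypothetical, not from this commit):

ComputeGraph graph;
ComputeGraphNode x("x"), tiled("tiled"), out("out");
graph.Mul(x, 0.5f, x);                                  // output = input * v; in-place use is allowed
graph.Repeat(x, /*axis*/ 2, /*repeatTimes*/ 4, tiled);  // tile x four times along axis 2
graph.Cat(x, tiled, /*axis*/ -1, out);                  // concatenate x and tiled along the last axis
graph.Update();                                         // refresh the graph after building it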
22 changes: 22 additions & 0 deletions src/devices/cuda/cudadevice.cpp
@@ -22,6 +22,7 @@ namespace fastllm {
this->ops["Linear"] = (BaseOperator*)(new CudaLinearOp());
this->ops["Conv2D"] = (BaseOperator*)(new CudaConv2DOp());
this->ops["Split"] = (BaseOperator*)(new CudaSplitOp());
this->ops["Repeat"] = (BaseOperator*)(new CudaRepeatOp());
this->ops["Cat"] = (BaseOperator*)(new CudaCatOp());
this->ops["CatDirect"] = (BaseOperator*)(new CudaCatDirectOp());
this->ops["MatMul"] = (BaseOperator*)(new CudaMatMulOp());
@@ -401,6 +402,27 @@ namespace fastllm {
(end - start) * inner * unitSize, outer);
}

void CudaRepeatOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
int axis = intParams.find("axis") != intParams.end() ? intParams.find("axis")->second : -1;
int repeatTimes = intParams.find("repeatTimes") != intParams.end() ? intParams.find("repeatTimes")->second : 1;
int dimsLen = input.dims.size();
axis = (axis % dimsLen + dimsLen) % dimsLen;

output.Allocate();

int outer = output.Count(0) / output.Count(axis);
int inputStride = input.Count(axis);
int outputStride = output.Count(axis);
int channels = input.dims[axis];
int inner = input.strides[axis];
int unitSize = input.unitSize;

FastllmCudaRepeat(input.cudaData, output.cudaData, outer, repeatTimes, inputStride * unitSize, outputStride * unitSize, channels * inner * unitSize, channels * inner * unitSize);
}
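As a concrete check of the stride arithmetic (a made-up example, not from the commit): repeating an fp32 tensor of dims {2, 1, 5} four times along axis 2 gives output dims {2, 1, 20}, so outer = 40 / 20 = 2, inputStride = 5, outputStride = 20, channels = 5, inner = 1, unitSize = 4. The launcher then runs 2 × 4 blocks, each copying copyLen = 5 · 4 = 20 bytes: block (i, j) reads input + i·20 and writes output + i·80 + j·20.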

void CudaCatOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input0 = *(datas.find("input0")->second);
15 changes: 15 additions & 0 deletions src/devices/cuda/fastllm-cuda.cu
@@ -3165,6 +3165,21 @@ __global__ void FastllmMemcpyBatchKernel (uint8_t** pointer) {
}
}

template <int THREAD_PER_BLOCK>
__global__ void FastllmRepeatKernel (void *inputOri, void *outputOri, int outer, int repeatTimes, int inputStride, int outputStride0, int outputStride1, int copyLen) {
int id = blockIdx.x;
int i = id / repeatTimes, j = id % repeatTimes;
uint8_t *output = (uint8_t*)outputOri + i * outputStride0 + j * outputStride1;
uint8_t *input = (uint8_t*)inputOri + i * inputStride;
for (int x = threadIdx.x; x < copyLen; x += THREAD_PER_BLOCK) {
output[x] = input[x];
}
}

void FastllmCudaRepeat(void *input, void *output, int outer, int repeatTimes, int inputStride, int outputStride0, int outputStride1, int copyLen) {
FastllmRepeatKernel <256> <<< outer * repeatTimes, 256 >>> (input, output, outer, repeatTimes, inputStride, outputStride0, outputStride1, copyLen);
}
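To exercise the new launcher in isolation, a sketch of a standalone test (my harness, not part of the commit; note that all strides are passed in bytes, matching how CudaRepeatOp::Run computes them):

#include <cuda_runtime.h>
#include "devices/cuda/fastllm-cuda.cuh" // FastllmCudaRepeat

int main() {
    float *in, *out;
    cudaMalloc(&in, 2 * 3 * sizeof(float));      // 2 rows of 3 floats
    cudaMalloc(&out, 2 * 4 * 3 * sizeof(float)); // each row tiled 4 times
    FastllmCudaRepeat(in, out, /*outer*/ 2, /*repeatTimes*/ 4,
                      /*inputStride*/ 3 * sizeof(float),
                      /*outputStride0*/ 4 * 3 * sizeof(float),
                      /*outputStride1*/ 3 * sizeof(float),
                      /*copyLen*/ 3 * sizeof(float));
    cudaDeviceSynchronize();
    cudaFree(in); cudaFree(out);
    return 0;
}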

void FastllmCudaMemcpy2DDeviceToDeviceBatch(void ** dsts, size_t * dpitchs, void ** srcs,
size_t * spitchs, size_t *widths, size_t * heights,
int batch) {
28 changes: 26 additions & 2 deletions src/graph.cpp
@@ -406,7 +406,7 @@ namespace fastllm {
CatDirectBatch(pastKeys[layerId], pointersK, 1);
CatDirectBatch(pastValues[layerId], pointersV, 1);

int q0 = q.dims[2], k0 = k.dims[2], dims = q.dims[3];
int q0 = q.dims[2], k0 = k.dims[2], dims = q.dims[3], vdims = v.dims[3];
q.Reshape({batch, maxLen, q0, dims});
PermuteSelf(q, {0, 2, 1, 3});
q.Reshape({batch * q0, maxLen, -1});
@@ -415,7 +415,7 @@
PermuteSelf(k, {0, 2, 1, 3});
k.Reshape({batch * k0, maxLen, -1});

v.Reshape({batch, maxLen, k0, dims});
v.Reshape({batch, maxLen, k0, vdims});
PermuteSelf(v, {0, 2, 1, 3});
v.Reshape({batch * k0, maxLen, -1});

@@ -693,6 +693,14 @@ namespace fastllm {
);
}

void ComputeGraph::Cat(ComputeGraphNode &input0, ComputeGraphNode &input1, int axis, ComputeGraphNode &output) {
this->ops.push_back (
ComputeGraphOp("Cat",
{{"input0", input0.name}, {"input1", input1.name}, {"output", output.name}},
{}, {{"axis", axis}})
);
}

void ComputeGraph::DataTypeAs(ComputeGraphNode &input, ComputeGraphNode &input1) {
this->ops.push_back (
ComputeGraphOp("DataTypeAs",
@@ -701,6 +709,14 @@
);
}

void ComputeGraph::Mul(ComputeGraphNode &input, float v, ComputeGraphNode &output) {
this->ops.push_back (
ComputeGraphOp("Mul",
{{"input", input.name}, {"output", output.name}},
{{"v", v}}, {})
);
}

void ComputeGraph::MulTo(ComputeGraphNode &input0, ComputeGraphNode &input1) {
this->ops.push_back (
ComputeGraphOp("MulTo",
@@ -709,6 +725,14 @@
);
}

void ComputeGraph::Repeat(ComputeGraphNode &input, int axis, int repeatTimes, ComputeGraphNode &output) {
this->ops.push_back (
ComputeGraphOp("Repeat",
{{"input", input.name}, {"output", output.name}},
{}, {{"axis", axis}, {"repeatTimes", repeatTimes}})
);
}

void ComputeGraph::Silu(ComputeGraphNode &input, ComputeGraphNode &output) {
this->ops.push_back (
ComputeGraphOp("Silu",
2 changes: 0 additions & 2 deletions src/model.cpp
@@ -190,8 +190,6 @@ namespace fastllm {
model->model_type = "phi3";
} else if (modelType=="minicpm") {
model = new MiniCpmModel();
} else if (modelType=="minicpm3") {
model = new MiniCpm3Model();
} else if (modelType == "qwen") {
model = (basellm *) (new QWenModel());
model->weight.tokenizer.type = Tokenizer::TokenizerType::QWEN;
112 changes: 112 additions & 0 deletions src/models/graph/minicpm3.cpp
@@ -0,0 +1,112 @@
#include "graphllm.h"

namespace fastllm {
class Minicpm3GraphModelConfig : GraphLLMModelConfig {
public:
void InitParams(GraphLLMModel *model) {
model->rotary_dim = atoi(model->weight.dicts["qk_rope_head_dim"].c_str());
}

std::map <std::string, std::vector <std::pair <std::string, DataType> > >
GetTensorMap(GraphLLMModel *model, const std::vector <std::string> &tensorNames) {
std::map <std::string, std::vector <std::pair <std::string, DataType> > > ret;
std::string embeddingName = "model.embed_tokens.weight";
std::string logitsName = "lm_head.weight";
std::set <std::string> linearNames = {
".self_attn.q_a_proj.weight",
".self_attn.q_b_proj.weight",
".self_attn.kv_a_proj_with_mqa.weight",
".self_attn.kv_b_proj.weight",
".self_attn.o_proj.weight",
".mlp.gate_proj.weight",
".mlp.up_proj.weight",
".mlp.down_proj.weight"
};
ret[embeddingName].push_back(std::make_pair(embeddingName, DataType::DATA_AUTO_EMBEDDING));
for (int i = 0; i < model->block_cnt; i++) {
std::string pre = "model.layers." + std::to_string(i);
for (auto &it : linearNames) {
ret[pre + it].push_back(std::make_pair(pre + it, DataType::DATA_AUTO_LINEAR));
}
}
for (auto &name : tensorNames) {
if (ret[name].size() == 0) {
ret[name].push_back(std::make_pair(name, DataType::DATA_AUTO_NONE));
}
}
if (ret.find(logitsName) == ret.end()) {
ret[embeddingName].push_back(std::make_pair(logitsName, DataType::DATA_AUTO_LINEAR));
} else {
ret[logitsName][0].second = DataType::DATA_AUTO_LINEAR;
}
return ret;
}

void BuildGraph(GraphLLMModel *model) {
int qk_nope_head_dim = atoi(model->weight.dicts["qk_nope_head_dim"].c_str());
int qk_rope_head_dim = atoi(model->weight.dicts["qk_rope_head_dim"].c_str());
int kv_lora_rank = atoi(model->weight.dicts["kv_lora_rank"].c_str());
int v_head_dim = atoi(model->weight.dicts["v_head_dim"].c_str());
float scale_depth = atof(model->weight.dicts["scale_depth"].c_str());
float attention_scale = scale_depth / std::sqrt(model->block_cnt);
int dim_model_base = atoi(model->weight.dicts["dim_model_base"].c_str());
float rms_scale = 1.f / (model->embed_dim / dim_model_base);

auto &graph = *(model->GetGraph());
std::map <std::string, ComputeGraphNode> wNodes;
for (auto &it : model->weight.weight) {
wNodes[it.first] = ComputeGraphNode(it.first);
}
ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos"), seqLens("seqLens");
ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput");
ComputeGraphNode qa("qa"), qa_norm("qa_norm"), qb("qb"), q_nope("q_nope"), q_rope("q_rope"), qkv("qkv"), kva("kva"), kv_norm("kv_norm"), kvb("kvb"), compressed_kv("compressed_kv"), k_rope("k_rope"), k_nope("k_nope"), q("q"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits");
ComputeGraphNode k_rope_expand("k_rope_expand");
graph.Embedding(inputIds, wNodes["model.embed_tokens.weight"], hiddenStates);
graph.Mul(hiddenStates, atof(model->weight.dicts["scale_emb"].c_str()), hiddenStates);
graph.DataTypeAs(hiddenStates, atype);
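// Per-layer structure below: MiniCPM3's multi-head latent attention
// (low-rank q_a/q_b and compressed kv_a/kv_b projections, with rotary
// applied only to the qk_rope_head_dim slice), then a SwiGLU-style MLP.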
for (int i = 0; i < model->block_cnt; i++) {
std::string pre = "model.layers." + std::to_string(i);
ComputeGraphNode pastKey("pastKey." + std::to_string(i)), pastValue("pastValue." + std::to_string(i));
graph.RMSNorm(hiddenStates, wNodes[pre + ".input_layernorm.weight"], model->rms_norm_eps, attenInput);
graph.Linear(attenInput, wNodes[pre + ".self_attn.q_a_proj.weight"], wNodes[pre + ".self_attn.q_a_proj.bias"], qa);
graph.RMSNorm(qa, wNodes[pre + ".self_attn.q_a_layernorm.weight"], model->rms_norm_eps, qa_norm);
graph.Linear(qa_norm, wNodes[pre + ".self_attn.q_b_proj.weight"], wNodes[pre + ".self_attn.q_b_proj.bias"], qb);
graph.ExpandHead(qb, qk_nope_head_dim + qk_rope_head_dim);
graph.Split(qb, -1, 0, qk_nope_head_dim, q_nope);
graph.Split(qb, -1, qk_nope_head_dim, qk_nope_head_dim + qk_rope_head_dim, q_rope);
graph.Linear(attenInput, wNodes[pre + ".self_attn.kv_a_proj_with_mqa.weight"], wNodes[pre + ".self_attn.kv_a_proj_with_mqa.bias"], kva);
graph.Split(kva, -1, 0, kv_lora_rank, compressed_kv);
graph.Split(kva, -1, kv_lora_rank, kv_lora_rank + qk_rope_head_dim, k_rope);
graph.ExpandHead(k_rope, qk_rope_head_dim);
graph.RMSNorm(compressed_kv, wNodes[pre + ".self_attn.kv_a_layernorm.weight"], model->rms_norm_eps, kv_norm);
graph.Linear(kv_norm, wNodes[pre + ".self_attn.kv_b_proj.weight"], wNodes[pre + ".self_attn.kv_b_proj.bias"], kvb);
graph.ExpandHead(kvb, qk_nope_head_dim + v_head_dim);
graph.Split(kvb, -1, 0, qk_nope_head_dim, k_nope);
graph.Split(kvb, -1, qk_nope_head_dim, qk_nope_head_dim + v_head_dim, v);
graph.LlamaRotatePosition2D(q_rope, positionIds, sin, cos, model->rotary_dim);
graph.LlamaRotatePosition2D(k_rope, positionIds, sin, cos, model->rotary_dim);
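// Reassemble full q and k: k_rope is produced with a single head, so tile
// it across all attention heads (axis 2) before concatenating the nope and
// rope slices back together.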
graph.Cat(q_nope, q_rope, -1, q);
graph.Repeat(k_rope, 2, model->num_attention_heads, k_rope_expand);
graph.Cat(k_nope, k_rope_expand, -1, k);
graph.FusedAttention(q, pastKey, pastValue, k, v, attenInput, attentionMask, attenOutput, seqLens, 1.0 / sqrt(v_head_dim), 0, 128);
graph.Linear(attenOutput, wNodes[pre + ".self_attn.o_proj.weight"], wNodes[pre + ".self_attn.o_proj.bias"], attenLastOutput);
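// Both residual adds below are scaled by attention_scale =
// scale_depth / sqrt(num_layers), MiniCPM's depth-scaling trick.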
graph.AddTo(hiddenStates, attenLastOutput, attention_scale);
graph.RMSNorm(hiddenStates, wNodes[pre + ".post_attention_layernorm.weight"], model->rms_norm_eps, attenInput);
graph.Linear(attenInput, wNodes[pre + ".mlp.gate_proj.weight"], wNodes[pre + ".mlp.gate_proj.bias"], w1);
graph.Linear(attenInput, wNodes[pre + ".mlp.up_proj.weight"], wNodes[pre + ".mlp.up_proj.bias"], w3);
graph.Silu(w1, w1);
graph.MulTo(w1, w3);
graph.Linear(w1, wNodes[pre + ".mlp.down_proj.weight"], wNodes[pre + ".mlp.down_proj.bias"], w2);
graph.AddTo(hiddenStates, w2, attention_scale);
}

graph.SplitLastTokenStates(hiddenStates, seqLens, lastTokensStates);
graph.RMSNorm(lastTokensStates, wNodes["model.norm.weight"], model->rms_norm_eps, lastTokensStates);
graph.Mul(lastTokensStates, rms_scale, lastTokensStates);
graph.Linear(lastTokensStates, wNodes["lm_head.weight"], wNodes["lm_head.bias"], logits);
OptimizeComputeGraph(graph, model->weight);
graph.Update();
}
};
REGISTERGRAPHMODELCONFIG(minicpm3, Minicpm3GraphModelConfig)
}
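With REGISTERGRAPHMODELCONFIG(minicpm3, Minicpm3GraphModelConfig) in place, model_type == "minicpm3" is resolved through the graph-model registry, which is why the hand-written MiniCpm3Model branch is deleted from src/model.cpp above. Loading should be unchanged from the caller's side; a sketch, assuming fastllm's usual loader entry point and a converted model file:

auto model = fastllm::CreateLLMModelFromFile("minicpm3.flm"); // now dispatches to the graph implementation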
