Commit

Support gemma2
黄宇扬 committed Dec 11, 2024
1 parent a51dea3 commit 60871ce
Showing 9 changed files with 199 additions and 0 deletions.
4 changes: 4 additions & 0 deletions include/devices/cpu/cpudevice.h
@@ -146,6 +146,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CpuAddOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CpuAddToOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
4 changes: 4 additions & 0 deletions include/devices/cuda/cudadevice.h
@@ -111,6 +111,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaAddOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaMulOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
1 change: 1 addition & 0 deletions include/devices/cuda/fastllm-cuda.cuh
@@ -38,6 +38,7 @@ bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaGelu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaSwiglu(const fastllm::Data &input, fastllm::Data &output);
bool FastllmCudaAdd(const fastllm::Data &input, float v, fastllm::Data &output);
bool FastllmCudaMul(const fastllm::Data &input, float v, fastllm::Data &output);
bool FastllmCudaSoftmax(const fastllm::Data &input, fastllm::Data &output, int axis);
bool FastllmCudaAddTo(fastllm::Data &input0, const fastllm::Data &input1, float alpha);
2 changes: 2 additions & 0 deletions include/graph.h
@@ -51,6 +51,7 @@ namespace fastllm {

void Update();

void Add(ComputeGraphNode &input, float v, ComputeGraphNode &output); // output = input + v
void AddTo(ComputeGraphNode &input0, ComputeGraphNode &input1, float alpha = 1.0); // input0 += input1 * alpha
void Cat(ComputeGraphNode &input0, ComputeGraphNode &input1, int axis, ComputeGraphNode &output);
void DataTypeAs(ComputeGraphNode &input, ComputeGraphNode &input1); // set input's dataType to match input1's
@@ -61,6 +62,7 @@ namespace fastllm {
ComputeGraphNode &original, ComputeGraphNode &mask, ComputeGraphNode &output,
ComputeGraphNode &seqLens,
float scale, int maskType, int unitLen); // fused attention
void Gelu(ComputeGraphNode &input, ComputeGraphNode &output);
void Linear(ComputeGraphNode &input, ComputeGraphNode &weight, ComputeGraphNode &bias, ComputeGraphNode &output);
void LlamaRotatePosition2D(ComputeGraphNode &input, ComputeGraphNode &positionIds, ComputeGraphNode &sinData, ComputeGraphNode &cosData, int rotaryDim); // 2D position for llama
void Mul(ComputeGraphNode &input, float v, ComputeGraphNode &output); // output = input * v
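The two graph ops declared here exist for the Gemma2 graph added later in this commit: Gelu feeds Gemma2's GeGLU MLP (GELU(gate) * up), and Add shifts RMSNorm weights by 1, since Gemma2's RMSNorm multiplies by (1 + weight) rather than weight. A minimal, hypothetical sketch of that second pattern follows; it assumes only the ComputeGraph / ComputeGraphNode API visible in this diff plus the RMSNorm signature used by src/models/graph/gemma2.cpp, and the helper name AddShiftedRMSNorm is invented for illustration.

#include "graph.h"   // path as in this repository's include/graph.h

// Hypothetical helper, not part of this commit: applies Gemma2-style RMSNorm,
// which scales by (1 + weight), by first shifting the stored weight with the
// new Add op and then reusing the existing RMSNorm op.
static void AddShiftedRMSNorm(fastllm::ComputeGraph &graph,
                              fastllm::ComputeGraphNode &hidden,        // activations to normalize
                              fastllm::ComputeGraphNode &storedWeight,  // zero-centered norm weight
                              fastllm::ComputeGraphNode &output,
                              float rmsNormEps) {
    fastllm::ComputeGraphNode shifted("shiftedNormWeight");
    graph.Add(storedWeight, 1.0f, shifted);              // shifted = storedWeight + 1
    graph.RMSNorm(hidden, shifted, rmsNormEps, output);  // output = RMSNorm(hidden) * shifted
}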
28 changes: 28 additions & 0 deletions src/devices/cpu/cpudevice.cpp
@@ -46,6 +46,7 @@ namespace fastllm {
this->ops["Swiglu"] = (BaseOperator*)(new CpuSwigluOp());
this->ops["Mul"] = (BaseOperator*)(new CpuMulOp());
this->ops["MulTo"] = (BaseOperator*)(new CpuMulToOp());
this->ops["Add"] = (BaseOperator*)(new CpuAddOp());
this->ops["AddTo"] = (BaseOperator*)(new CpuAddToOp());
this->ops["AttentionMask"] = (BaseOperator*)(new CpuAttentionMaskOp());
this->ops["AttentionExtendedMask"] = (BaseOperator*)(new CpuAttentionExtendedMaskOp());
@@ -3612,6 +3613,33 @@ namespace fastllm {
}
}

void CpuAddOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
output.Allocate();

float v = floatParams.find("v") != floatParams.end() ? floatParams.find("v")->second : 1.0;
AssertInFastLLM(input.dataType == DataType::FLOAT32 || input.dataType == DataType::FLOAT16,
"Add error: Data's type should be float32 or float16.\n");

int len = input.Count(0);

if (input.dataType == DataType::FLOAT32) {
float *inputData = (float *) input.cpuData;
float *outputData = (float *) output.cpuData;
for (int i = 0; i < len; i++) {
outputData[i] = inputData[i] + v;
}
} else if (input.dataType == DataType::FLOAT16) {
uint16_t *inputData = (uint16_t *) input.cpuData;
uint16_t *outputData = (uint16_t *) output.cpuData;
for (int i = 0; i < len; i++) {
outputData[i] = float_to_half(fp16tofp32.dict[inputData[i]] + v);
}
}
}

void CpuMulToOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input0 = *(datas.find("input0")->second);
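One detail of the CPU implementation above worth noting: when the "v" float parameter is absent it falls back to 1.0 (mirroring CpuMulOp), not 0.0, so callers such as the gemma2 graph always pass v explicitly. A trivial standalone sketch (plain C++, not fastllm code) of the float32 branch's contract:

#include <cassert>
#include <vector>

int main() {
    // out[i] = in[i] + v, the contract of CpuAddOp's float32 path.
    std::vector<float> in = {0.5f, -2.0f, 3.25f}, out(in.size());
    float v = 1.0f;   // the default used when "v" is not supplied
    for (size_t i = 0; i < in.size(); i++) {
        out[i] = in[i] + v;
    }
    assert(out[0] == 1.5f && out[1] == -1.0f && out[2] == 4.25f);
    return 0;
}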
14 changes: 14 additions & 0 deletions src/devices/cuda/cudadevice.cpp
@@ -32,6 +32,7 @@ namespace fastllm {
this->ops["GeluNew"] = (BaseOperator*)(new CudaGeluNewOp());
this->ops["Silu"] = (BaseOperator*)(new CudaSiluOp());
this->ops["Swiglu"] = (BaseOperator*)(new CudaSwigluOp());
this->ops["Add"] = (BaseOperator*)(new CudaAddOp());
this->ops["Mul"] = (BaseOperator*)(new CudaMulOp());
this->ops["AddTo"] = (BaseOperator*)(new CudaAddToOp());
this->ops["MulTo"] = (BaseOperator*)(new CudaMulToOp());
@@ -700,6 +701,19 @@ namespace fastllm {
FastllmCudaSilu(input, output);
}

void CudaAddOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
Data &output = *(datas.find("output")->second);
output.Allocate();

float v = floatParams.find("v") != floatParams.end() ? floatParams.find("v")->second : 1.0;
AssertInFastLLM(input.dataType == DataType::FLOAT32 ||
input.dataType == DataType::FLOAT16,
"Mul error: Data's type should be float32 or float16.\n");
FastllmCudaAdd(input, v, output);
}

void CudaMulOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input = *(datas.find("input")->second);
35 changes: 35 additions & 0 deletions src/devices/cuda/fastllm-cuda.cu
@@ -495,6 +495,24 @@ __global__ void FastllmSwigluKernel(half* __restrict__ a, half* __restrict__ b,
}
}

__global__ void FastllmAddKernel(float* a, float *b, float v, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
b[idx] = a[idx] + v;
}
}

__global__ void FastllmAddKernel(half* a, half *b, half v, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
#ifdef CUDA_NO_TENSOR_CORE
b[idx] = __float2half(__half2float(a[idx]) + __half2float(v));
#else
b[idx] = __hadd(a[idx], v);
#endif
}
}

__global__ void FastllmMulKernel(float* a, float *b, float v, int len) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < len) {
@@ -3267,6 +3285,23 @@ bool FastllmCudaSwiglu(const fastllm::Data &input, fastllm::Data &output) {
return true;
}

bool FastllmCudaAdd(const fastllm::Data &input, float v, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
int threadPerBlock = std::min(256, len);

if (input.dataType == fastllm::DataType::FLOAT32) {
FastllmAddKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, v, len);
} else {
FastllmAddKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>((half*)cudaInput, (half*)cudaOutput, __float2half_rn(v), len);
}

FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
return true;
}

bool FastllmCudaMul(const fastllm::Data &input, float v, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
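The new kernel launch above sizes its grid with the usual ceiling-division idiom; the idx < len guard in the kernels handles the overhang in the last block. A standalone sketch (plain C++, not fastllm code) of that arithmetic:

#include <algorithm>
#include <cassert>

// Number of blocks needed so that blocks * threadPerBlock >= len (for len >= 1),
// i.e. ceil(len / threadPerBlock) written with integer arithmetic.
static int BlocksFor(int len, int threadPerBlock) {
    return (len - 1) / threadPerBlock + 1;
}

int main() {
    int len = 1000;
    int threadPerBlock = std::min(256, len);       // same clamp as FastllmCudaAdd above
    assert(BlocksFor(len, threadPerBlock) == 4);   // 4 * 256 = 1024 >= 1000
    assert(BlocksFor(256, 256) == 1);              // exact fit, no extra block
    assert(BlocksFor(257, 256) == 2);              // one element spills into a second block
    return 0;
}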
16 changes: 16 additions & 0 deletions src/graph.cpp
@@ -709,6 +709,14 @@ namespace fastllm {
);
}

void ComputeGraph::Add(ComputeGraphNode &input, float v, ComputeGraphNode &output) {
this->ops.push_back (
ComputeGraphOp("Add",
{{"input", input.name}, {"output", output.name}},
{{"v", v}}, {})
);
}

void ComputeGraph::Mul(ComputeGraphNode &input, float v, ComputeGraphNode &output) {
this->ops.push_back (
ComputeGraphOp("Mul",
@@ -733,6 +741,14 @@ namespace fastllm {
);
}

void ComputeGraph::Gelu(ComputeGraphNode &input, ComputeGraphNode &output) {
this->ops.push_back (
ComputeGraphOp("Gelu",
{{"input", input.name}, {"output", output.name}},
{}, {})
);
}

void ComputeGraph::Silu(ComputeGraphNode &input, ComputeGraphNode &output) {
this->ops.push_back (
ComputeGraphOp("Silu",
95 changes: 95 additions & 0 deletions src/models/graph/gemma2.cpp
@@ -0,0 +1,95 @@
#include "graphllm.h"

namespace fastllm {
class Gemma2GraphModelConfig : GraphLLMModelConfig {
public:
void InitParams(GraphLLMModel *model) {
model->rotary_dim = atoi(model->weight.dicts["head_dim"].c_str());
}

std::map <std::string, std::vector <std::pair <std::string, DataType> > >
GetTensorMap(GraphLLMModel *model, const std::vector <std::string> &tensorNames) {
std::map <std::string, std::vector <std::pair <std::string, DataType> > > ret;
std::string embeddingName = "model.embed_tokens.weight";
std::string logitsName = "lm_head.weight";
std::set <std::string> linearNames = {
".self_attn.q_proj.weight", ".self_attn.k_proj.weight", ".self_attn.v_proj.weight", ".self_attn.o_proj.weight",
".mlp.gate_proj.weight", ".mlp.up_proj.weight", ".mlp.down_proj.weight"
};
ret[embeddingName].push_back(std::make_pair(embeddingName, DataType::DATA_AUTO_EMBEDDING));
for (int i = 0; i < model->block_cnt; i++) {
std::string pre = "model.layers." + std::to_string(i);
for (auto &it : linearNames) {
ret[pre + it].push_back(std::make_pair(pre + it, DataType::DATA_AUTO_LINEAR));
}
}
for (auto &name : tensorNames) {
if (ret[name].size() == 0) {
ret[name].push_back(std::make_pair(name, DataType::DATA_AUTO_NONE));
}
}
if (ret.find(logitsName) == ret.end()) {
ret[embeddingName].push_back(std::make_pair(logitsName, DataType::DATA_AUTO_LINEAR));
} else {
ret[logitsName][0].second = DataType::DATA_AUTO_LINEAR;
}
return ret;
}

void BuildGraph(GraphLLMModel *model) {
int head_dim = atoi(model->weight.dicts["head_dim"].c_str());
int query_pre_attn_scalar = atoi(model->weight.dicts["query_pre_attn_scalar"].c_str());

auto &graph = *(model->GetGraph());
std::map <std::string, ComputeGraphNode> wNodes;
for (auto &it : model->weight.weight) {
wNodes[it.first] = ComputeGraphNode(it.first);
}
ComputeGraphNode inputIds("inputIds"), positionIds("positionIds"), attentionMask("attentionMask"), atype("atype"), sin("sin"), cos("cos"), seqLens("seqLens");
ComputeGraphNode hiddenStates("hiddenStates"), attenInput("attenInput"), attenOutput("attenOutput"), attenLastOutput("attenLastOutput");
ComputeGraphNode q("q"), k("k"), v("v"), w1("w1"), w2("w2"), w3("w3"), lastTokensStates("lastTokensStates"), logits("logits");
ComputeGraphNode rmsNormWeight("rmsNormWeight");
graph.Embedding(inputIds, wNodes["model.embed_tokens.weight"], hiddenStates);
graph.Mul(hiddenStates, sqrt(model->embed_dim), hiddenStates);
graph.DataTypeAs(hiddenStates, atype);
for (int i = 0; i < model->block_cnt; i++) {
std::string pre = "model.layers." + std::to_string(i);
ComputeGraphNode pastKey("pastKey." + std::to_string(i)), pastValue("pastValue." + std::to_string(i));
graph.Add(wNodes[pre + ".input_layernorm.weight"], 1.0f, rmsNormWeight);
graph.RMSNorm(hiddenStates, rmsNormWeight, model->rms_norm_eps, attenInput);
graph.Linear(attenInput, wNodes[pre + ".self_attn.q_proj.weight"], wNodes[pre + ".self_attn.q_proj.bias"], q);
graph.Linear(attenInput, wNodes[pre + ".self_attn.k_proj.weight"], wNodes[pre + ".self_attn.k_proj.bias"], k);
graph.Linear(attenInput, wNodes[pre + ".self_attn.v_proj.weight"], wNodes[pre + ".self_attn.v_proj.bias"], v);
graph.ExpandHead(q, head_dim);
graph.ExpandHead(k, head_dim);
graph.ExpandHead(v, head_dim);
graph.LlamaRotatePosition2D(q, positionIds, sin, cos, model->rotary_dim);
graph.LlamaRotatePosition2D(k, positionIds, sin, cos, model->rotary_dim);
graph.FusedAttention(q, pastKey, pastValue, k, v, attenInput, attentionMask, attenOutput, seqLens, 1.0 / sqrt(query_pre_attn_scalar), 0, 128);
graph.Linear(attenOutput, wNodes[pre + ".self_attn.o_proj.weight"], wNodes[pre + ".self_attn.o_proj.bias"], attenLastOutput);
graph.Add(wNodes[pre + ".post_attention_layernorm.weight"], 1.0f, rmsNormWeight);
graph.RMSNorm(attenLastOutput, rmsNormWeight, model->rms_norm_eps, attenOutput);
graph.AddTo(hiddenStates, attenOutput);
graph.Add(wNodes[pre + ".pre_feedforward_layernorm.weight"], 1.0f, rmsNormWeight);
graph.RMSNorm(hiddenStates, rmsNormWeight, model->rms_norm_eps, attenInput);
graph.Linear(attenInput, wNodes[pre + ".mlp.gate_proj.weight"], wNodes[pre + ".mlp.gate_proj.bias"], w1);
graph.Linear(attenInput, wNodes[pre + ".mlp.up_proj.weight"], wNodes[pre + ".mlp.up_proj.bias"], w3);
graph.Gelu(w1, w1);
graph.MulTo(w1, w3);
graph.Linear(w1, wNodes[pre + ".mlp.down_proj.weight"], wNodes[pre + ".mlp.down_proj.bias"], w2);
graph.Add(wNodes[pre + ".post_feedforward_layernorm.weight"], 1.0f, rmsNormWeight);
graph.RMSNorm(w2, rmsNormWeight, model->rms_norm_eps, w1);
graph.AddTo(hiddenStates, w1);
}

graph.SplitLastTokenStates(hiddenStates, seqLens, lastTokensStates);
graph.Add(wNodes["model.norm.weight"], 1.0f, rmsNormWeight);
graph.RMSNorm(lastTokensStates, rmsNormWeight, model->rms_norm_eps, lastTokensStates);
graph.Linear(lastTokensStates, wNodes["lm_head.weight"], wNodes["lm_head.bias"], logits);

OptimizeComputeGraph(graph, model->weight);
graph.Update();
}
};
REGISTERGRAPHMODELCONFIG(gemma2, Gemma2GraphModelConfig)
}
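For reference, the Gemma2-specific details this graph encodes (all visible in BuildGraph above, assuming fastllm's RMSNorm op implements the usual x / sqrt(mean(x^2) + eps) scaled by the supplied weight) can be summarized as:

\begin{aligned}
h_0 &= \sqrt{d_{\text{model}}}\,\cdot\,\mathrm{Embed}(\text{inputIds}) \\
\mathrm{RMSNorm}(x; w) &= \frac{x}{\sqrt{\tfrac{1}{d_{\text{model}}}\sum_i x_i^{2} + \varepsilon}}\odot(1 + w)
\quad\text{(hence Add(weight, 1.0f, rmsNormWeight) before every RMSNorm)} \\
\mathrm{Attn}(Q,K,V) &= \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{\text{query\_pre\_attn\_scalar}}}\right)V
\quad\text{rather than scaling by } 1/\sqrt{d_{\text{head}}} \\
\mathrm{MLP}(x) &= W_{\text{down}}\big(\mathrm{GELU}(W_{\text{gate}}\,x)\odot(W_{\text{up}}\,x)\big)
\end{aligned}

Each residual branch is normalized both before and after its sub-block (input_layernorm / post_attention_layernorm and pre_feedforward_layernorm / post_feedforward_layernorm), matching the four RMSNorm calls per layer above.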
