
Commit

cogvlm supports float16
黄宇扬 committed Oct 24, 2024
1 parent e6a5833 commit d295968
Showing 3 changed files with 11 additions and 5 deletions.
src/devices/cuda/fastllm-cuda.cu: 5 changes (3 additions, 2 deletions)
@@ -3776,14 +3776,15 @@ bool FastllmCudaHalfAttention(const fastllm::Data &q, const fastllm::Data &k, co
    half beta = __float2half_rn(0.0f), one = __float2half_rn(1.0f), hscale = __float2half_rn(scale);
    if (q1 >= 1024 || (q1 > 1 && q1 != k1 && k1 >= 1024)) {
        int alignQ1 = q1, alignK1 = k1;
        int part = alignK1;
        bool useFastAttn = getCudaInfos()->hasTensorCore && batch == 1 && (q2 == 128 && v2 == 128) && maskType == 0;
        if (useFastAttn) {
            alignQ1 = ((q1 - 1) / 128 + 1) * 128;
            alignK1 = ((k1 - 1) / 128 + 1) * 128;
            part = (alignK1 > 8192 ? 8192 : alignK1);
        }

        int part = (alignK1 > 8192 ? 8192 : alignK1);
        half *qk = (half *) FastllmCudaMalloc(alignQ1 * part * sizeof(half));

        cudaMemset(qk, 0, alignQ1 * part * sizeof(half));
        auto fastllmCublasHandle = getFastllmCublasHandle();
        cublasStatus_t status;
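Note on the hunk above: if the unconditional "int part = (alignK1 > 8192 ? 8192 : alignK1);" is the removed line, the 8192-column cap on the qk scratch buffer now applies only on the tensor-core fast path, and the fallback path keeps the full alignK1 width. A standalone sketch of the resulting buffer geometry (the q1/k1 values are made up for illustration; only the 128 alignment and the 8192 cap come from the code above):

    #include <cstdio>

    int main() {
        int q1 = 2000, k1 = 20000;                     // illustrative query/key lengths
        bool useFastAttn = true;                       // tensor-core path from the diff
        int alignQ1 = q1, alignK1 = k1;
        int part = alignK1;                            // default: whole key range
        if (useFastAttn) {
            alignQ1 = ((q1 - 1) / 128 + 1) * 128;      // round q1 up to a multiple of 128
            alignK1 = ((k1 - 1) / 128 + 1) * 128;      // round k1 up to a multiple of 128
            part = (alignK1 > 8192 ? 8192 : alignK1);  // cap the qk chunk width at 8192
        }
        // qk holds alignQ1 x part half-precision scores (2 bytes each).
        size_t qkBytes = (size_t)alignQ1 * (size_t)part * 2;
        printf("alignQ1=%d alignK1=%d part=%d qk=%zu bytes\n", alignQ1, alignK1, part, qkBytes);
        return 0;
    }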
src/models/basellm.cpp: 3 changes (2 additions, 1 deletion)
@@ -1125,7 +1125,8 @@ printf("len = %d, spend = %f s. tokens / s = %f\n", (int)total, spend, (float)to
    } else if (dataType == DataType::FLOAT16) {
        AssertInFastLLM(this->model_struct == "chatglm" ||
                        this->model_struct == "llama" ||
                        this->model_struct == "graph",
                        this->model_struct == "graph" ||
                        this->model_struct == "cogvlm",
                        this->model_struct + " doesn't support float16");
    } else {
        ErrorInFastLLM("SetDataType Error: datatype should be float32 or float16");
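With "cogvlm" added to the whitelist above, SetDataType with FLOAT16 now passes the assertion for CogVLM models. A minimal standalone rendition of the check (a sketch mirroring the diff, not the fastllm sources):

    #include <set>
    #include <stdexcept>
    #include <string>

    enum class DataType { FLOAT32, FLOAT16 };

    // Standalone sketch of the float16 whitelist after this commit.
    void CheckSetDataType(const std::string &model_struct, DataType dataType) {
        if (dataType == DataType::FLOAT32) {
            return;  // float32 is always accepted
        } else if (dataType == DataType::FLOAT16) {
            static const std::set<std::string> fp16Models = {"chatglm", "llama", "graph", "cogvlm"};
            if (!fp16Models.count(model_struct)) {
                throw std::runtime_error(model_struct + " doesn't support float16");
            }
        } else {
            throw std::runtime_error("SetDataType Error: datatype should be float32 or float16");
        }
    }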
src/models/cogvlm.cpp: 8 changes (6 additions, 2 deletions)
@@ -179,6 +179,7 @@ namespace fastllm {
    int visionLayers = atoi(this->weight.dicts["vision_config.num_hidden_layers"].c_str());
    int visionNumHeads = atoi(this->weight.dicts["vision_config.num_heads"].c_str());

    ToDataType(x, this->dataType);
    for (int i = 0; i < visionLayers; i++) {
        std::string pre = "model.vision.transformer.layers." + std::to_string(i);
        int B = x.dims[0], L = x.dims[1];
@@ -231,6 +232,7 @@ namespace fastllm {
        AddTo(x, mlp);
    }

    ToDataType(x, DataType::FLOAT32);
    Split(x, 1, 1, x.dims[1], y);
    int gridSize = int(sqrt(y.dims[1]) + 1e-9);
    y.Reshape({y.dims[0], gridSize, gridSize, y.dims[2]});
@@ -272,7 +274,8 @@ namespace fastllm {
        startPos = 1;
        endPos = 1;
    }

    ToDataType(x, this->dataType);
    Data &hiddenStates = x;
    Data attenInput, w1, w2, textW2, visionW2, w3;
    Data* sinDataPtr = &sinData;
@@ -460,7 +463,8 @@ namespace fastllm {
    Data norm, logit;
    RMSNorm(*lastHiddenStates, this->weight["model.norm.weight"], rms_norm_eps, norm);
    Linear(norm, this->weight["lm_head.weight"], Data(), logit);

    ToDataType(logit, DataType::FLOAT32);
    std::vector <int> lastRet;
    Data topk;
    TopK(logit, topk, 1);
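Taken together, the cogvlm.cpp hunks follow one pattern: activations are cast to the configured model dtype (this->dataType, float16 once SetDataType(FLOAT16) is allowed) before the vision transformer and before the decoder hidden states, and cast back to FLOAT32 where downstream code still works in float32, namely the image-feature Split/Reshape and the logits fed to TopK. A minimal sketch of that bracket, assuming the fastllm.h header and the ToDataType/Data calls shown in the diff (the helper and its callback are hypothetical):

    #include "fastllm.h"    // assumed header providing fastllm::Data, DataType, ToDataType
    #include <functional>

    namespace fastllm {
        // Hypothetical helper: run a block of layers in the model dtype, then cast back,
        // mirroring the ToDataType bracket added around the vision and decoder blocks.
        void RunInModelDtype(Data &x, DataType modelDtype,
                             const std::function<void(Data &)> &forwardBlock) {
            ToDataType(x, modelDtype);           // e.g. FLOAT16 when configured
            forwardBlock(x);                     // vision layers or decoder layers
            ToDataType(x, DataType::FLOAT32);    // back to float32 for Split / TopK consumers
        }
    }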
