Support f16 and bf16 LoRA
黄宇扬 committed Sep 26, 2024
1 parent 3d61022 commit 7502e4c
10 changes: 7 additions & 3 deletions src/model.cpp
@@ -733,9 +733,13 @@ namespace fastllm {
             int outDim = loraTensors->itmeDict[loraB].intShape[0];
             int lora = loraTensors->itmeDict[loraA].intShape[0];

-            AssertInFastLLM(loraTensors->itmeDict[loraA].dtype == "F32" &&
-                            loraTensors->itmeDict[loraB].dtype == "F32",
-                            "Lora error: lora's dtype should be F32.");
+            AssertInFastLLM((loraTensors->itmeDict[loraA].dtype == "F32" ||
+                             loraTensors->itmeDict[loraA].dtype == "F16" ||
+                             loraTensors->itmeDict[loraA].dtype == "BF16") &&
+                            (loraTensors->itmeDict[loraB].dtype == "F32" ||
+                             loraTensors->itmeDict[loraB].dtype == "F16" ||
+                             loraTensors->itmeDict[loraB].dtype == "BF16"),
+                            "Lora error: lora's dtype should be F32 or F16 or BF16.");
             loraTensors->itmeDict[loraA].CreateBuffer(DataType::FLOAT32);
             loraTensors->itmeDict[loraB].CreateBuffer(DataType::FLOAT32);
             float *weightA = (float*)loraTensors->itmeDict[loraA].buffer;
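The relaxed assert only widens the set of accepted dtypes; the CreateBuffer(DataType::FLOAT32) calls that follow are what let the rest of the LoRA merge run unchanged, since every variant ends up exposed as a plain float* buffer. A minimal standalone sketch of the half-width-to-float widening such a conversion implies (hypothetical helpers, not fastllm's actual CreateBuffer internals):

    #include <cstdint>
    #include <cstring>

    // Hypothetical helpers illustrating the widening that
    // CreateBuffer(DataType::FLOAT32) presumably performs.

    // bf16 -> f32: bf16 is the top 16 bits of an IEEE-754 float,
    // so widening is just a shift into the high half.
    static float Bf16ToFloat(uint16_t h) {
        uint32_t bits = (uint32_t)h << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

    // f16 -> f32: rebias the exponent (15 -> 127) and widen the
    // mantissa; subnormals, infinities and NaNs need extra branches.
    static float Fp16ToFloat(uint16_t h) {
        uint32_t sign = (uint32_t)(h >> 15) << 31;
        uint32_t exp  = (h >> 10) & 0x1F;
        uint32_t man  = h & 0x3FF;
        uint32_t bits;
        if (exp == 0) {
            if (man == 0) {
                bits = sign;                            // +-0
            } else {                                    // subnormal: normalize
                exp = 127 - 15 + 1;
                while (!(man & 0x400)) { man <<= 1; exp--; }
                man &= 0x3FF;
                bits = sign | (exp << 23) | (man << 13);
            }
        } else if (exp == 0x1F) {
            bits = sign | 0x7F800000 | (man << 13);     // inf / NaN
        } else {
            bits = sign | ((exp - 15 + 127) << 23) | (man << 13);
        }
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }

Note the asymmetry: bf16 keeps float32's 8-bit exponent, so widening is a pure bit shift, while f16's 5-bit exponent must be rebiased, which is why only the f16 path needs the subnormal and inf/NaN handling.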
