diff --git a/src/model.cpp b/src/model.cpp
index 4b06059..eb1ad48 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -733,9 +733,13 @@ namespace fastllm {
             int outDim = loraTensors->itmeDict[loraB].intShape[0];
             int lora = loraTensors->itmeDict[loraA].intShape[0];
-            AssertInFastLLM(loraTensors->itmeDict[loraA].dtype == "F32" &&
-                            loraTensors->itmeDict[loraB].dtype == "F32",
-                            "Lora error: lora's dtype should be F32.");
+            AssertInFastLLM((loraTensors->itmeDict[loraA].dtype == "F32" ||
+                             loraTensors->itmeDict[loraA].dtype == "F16" ||
+                             loraTensors->itmeDict[loraA].dtype == "BF16") &&
+                            (loraTensors->itmeDict[loraB].dtype == "F32" ||
+                             loraTensors->itmeDict[loraB].dtype == "F16" ||
+                             loraTensors->itmeDict[loraB].dtype == "BF16"),
+                            "Lora error: lora's dtype should be F32 or F16 or BF16.");
             loraTensors->itmeDict[loraA].CreateBuffer(DataType::FLOAT32);
             loraTensors->itmeDict[loraB].CreateBuffer(DataType::FLOAT32);
             float *weightA = (float*)loraTensors->itmeDict[loraA].buffer;
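
Note on the added F16/BF16 path (reviewer sketch, not part of the patch): the assertion now admits F16 and BF16 LoRA tensors, but the buffers are still created with DataType::FLOAT32 and read through float*, so CreateBuffer(DataType::FLOAT32) is presumably expected to upcast the half-precision data before the merge. A minimal sketch of what that upcast involves is below; Bf16ToFloat32 and Fp16ToFloat32 are hypothetical helper names for illustration, not functions from fastllm.

#include <cstdint>
#include <cstring>

// Hypothetical helpers, for illustration only.

// BF16 is the upper 16 bits of an IEEE-754 float32, so widening is a shift.
static float Bf16ToFloat32(uint16_t v) {
    uint32_t bits = static_cast<uint32_t>(v) << 16;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

// IEEE-754 half (F16): 1 sign bit, 5 exponent bits (bias 15), 10 fraction bits.
static float Fp16ToFloat32(uint16_t h) {
    uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
    uint32_t exp  = (h >> 10) & 0x1Fu;
    uint32_t frac = h & 0x3FFu;
    uint32_t bits;
    if (exp == 0) {
        if (frac == 0) {
            bits = sign;                              // signed zero
        } else {                                      // subnormal: renormalize
            exp = 127 - 15 + 1;
            while ((frac & 0x400u) == 0) { frac <<= 1; --exp; }
            frac &= 0x3FFu;
            bits = sign | (exp << 23) | (frac << 13);
        }
    } else if (exp == 0x1Fu) {
        bits = sign | 0x7F800000u | (frac << 13);     // inf / NaN
    } else {
        bits = sign | ((exp - 15 + 127) << 23) | (frac << 13);
    }
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}

With conversions of this kind applied while filling the FLOAT32 buffers, weightA and weightB can remain plain float* on the F16/BF16 paths, which is consistent with the unchanged lines following the assertion.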