From bde1985c377cc55f0f1d62546961b4cb8a55e2f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=BB=84=E5=AE=87=E6=89=AC?= Date: Thu, 6 Jun 2024 19:00:58 +0800 Subject: [PATCH] =?UTF-8?q?main=E6=94=AF=E6=8C=81=E8=AE=BE=E7=BD=AEact=20t?= =?UTF-8?q?ype?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- main.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/main.cpp b/main.cpp index 7af06475..5a82b95e 100644 --- a/main.cpp +++ b/main.cpp @@ -18,6 +18,7 @@ struct RunConfig { bool lowMemMode = false; // 是否使用低内存模式 fastllm::DataType dtype = fastllm::DataType::FLOAT16; + fastllm::DataType atype = fastllm::DataType::FLOAT32; int groupCnt = -1; }; @@ -74,6 +75,11 @@ void ParseArgs(int argc, char **argv, RunConfig &config, fastllm::GenerationConf fastllm::AssertInFastLLM(dataTypeDict.find(dtypeStr) != dataTypeDict.end(), "Unsupport data type: " + dtypeStr); config.dtype = dataTypeDict[dtypeStr]; + } else if (sargv[i] == "--atype") { + std::string atypeStr = sargv[++i]; + fastllm::AssertInFastLLM(dataTypeDict.find(atypeStr) != dataTypeDict.end(), + "Unsupport act type: " + atypeStr); + config.atype = dataTypeDict[atypeStr]; } else { Usage(); exit(-1); @@ -91,7 +97,10 @@ int main(int argc, char **argv) { fastllm::SetLowMemMode(config.lowMemMode); bool isHFDir = access((config.path + "/config.json").c_str(), R_OK) == 0 || access((config.path + "config.json").c_str(), R_OK) == 0; auto model = !isHFDir ? fastllm::CreateLLMModelFromFile(config.path) : fastllm::CreateLLMModelFromHF(config.path, config.dtype, config.groupCnt); - model->SetSaveHistoryChat(true); + if (config.atype != fastllm::DataType::FLOAT32) { + model->SetDataType(config.atype); + } + model->SetSaveHistoryChat(true); for (auto &it : config.eosToken) { generationConfig.stop_token_ids.insert(model->weight.tokenizer.GetTokenId(it));