diff --git a/include/fastllm.h b/include/fastllm.h
index fc052e86..4305a51a 100644
--- a/include/fastllm.h
+++ b/include/fastllm.h
@@ -285,6 +285,8 @@ namespace fastllm {
         Data (const Data &ori); // deep copy
 
+        void CreateFromOriData(WeightType weightType, DataType oriDataType, uint8_t *oriData, int groupCnt = -1); // create this Data's contents from oriData
+
         void CopyFrom(const Data &ori); // copy
 
         void FakeFrom(const Data &ori, size_t offset); // point the data pointer at ori's data + offset; not freed on delete
@@ -461,6 +463,8 @@ namespace fastllm {
         void AddAdapterDict(const std::string &name, const std::string &key, const std::string &value);
 
+        void AddEmptyWeight(const std::string &key, const std::vector <int> &dims, fastllm::DataType dataType);
+
         void AddWeight(const std::string &key, const std::vector <int> &dims, DataType dataType, WeightType weightType, DataType oriDataType, uint8_t *oriData, int groupCnt = -1); // insert a weight
diff --git a/src/fastllm.cpp b/src/fastllm.cpp
index 60b6e3d1..8fe74484 100644
--- a/src/fastllm.cpp
+++ b/src/fastllm.cpp
@@ -321,6 +321,298 @@ namespace fastllm {
         std::memcpy(this->cpuData, ori.cpuData, this->GetBytes());
     }
 
+    struct BF16ToFP16Manager {
+        uint16_t dict[65536]; // fp16 bit patterns, indexed by bf16 bit pattern
+
+        BF16ToFP16Manager() {
+            for (uint32_t i = 0; i < 65536; i++) { // 32-bit counter: a uint16_t one would wrap forever or skip dict[65535]
+                uint32_t x = (i << 16);
+                dict[i] = float_to_half(*((float*)&x));
+            }
+        }
+    } bf16tofp16;
+
+    struct BF16ToFP32Manager {
+        float dict[65536];
+
+        BF16ToFP32Manager() {
+            for (uint32_t i = 0; i < 65536; i++) {
+                uint32_t x = (i << 16);
+                dict[i] = *((float*)&x);
+            }
+        }
+    } bf16tofp32;
+
+    struct MultiThreadGroupQuantizationBF16Op : MultiThreadBaseOp {
+        int st, end, m;
+        uint16_t *bf;
+        uint8_t *u8;
+        LowBitConfig *configs;
+        int bit;
+        int group, groupCnt;
+
+        MultiThreadGroupQuantizationBF16Op (int st, int end, int m,
+                                            uint16_t *bf, uint8_t *u8, LowBitConfig *configs, int bit, int group, int groupCnt) :
+                st(st), end(end), m(m), bf(bf), u8(u8), configs(configs), bit(bit), group(group), groupCnt(groupCnt) {}
+
+        void Run() {
+            int type = (bit == 4) ? 1 : 0;
+            for (int i = st; i < end; i++) {
+                for (int g = 0; g < group; g++) {
+                    int cid = i * group + g;
+                    int groupStart = g * groupCnt;
+                    int groupEnd = std::min((g + 1) * groupCnt, m);
+
+                    float minValue = 1e9, maxValue = -1e9;
+                    for (int j = groupStart; j < groupEnd; j++) {
+                        minValue = std::min(minValue, bf16tofp32.dict[bf[i * m + j]]);
+                        maxValue = std::max(maxValue, bf16tofp32.dict[bf[i * m + j]]);
+                    }
+                    if (bit == 8) {
+                        configs[cid] = LowBitConfig(minValue, maxValue, 8, type);
+                        for (int j = groupStart; j < groupEnd; j++) {
+                            u8[i * m + j] = configs[cid].quantization(bf16tofp32.dict[bf[i * m + j]]);
+                        }
+                    } else {
+                        configs[cid] = LowBitConfig(minValue, maxValue, 4, type);
+                        for (int j = groupStart; j < groupEnd; j++) {
+                            int id = (i * m + j) / 2;
+                            uint8_t value = configs[cid].quantization(bf16tofp32.dict[bf[i * m + j]]);
+                            if ((i * m + j) % 2) {
+                                u8[id] = (u8[id] & 0xF0) | value;
+                            } else {
+                                u8[id] = (u8[id] & 0xF) | (value << 4);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    };
+
+    struct MultiThreadGroupQuantizationOp : MultiThreadBaseOp {
+        int st, end, m;
+        float *f;
+        uint8_t *u8;
+        LowBitConfig *configs;
+        int bit;
+        int group, groupCnt;
+
+        MultiThreadGroupQuantizationOp (int st, int end, int m,
+                                        float *f, uint8_t *u8, LowBitConfig *configs, int bit, int group, int groupCnt) :
+                st(st), end(end), m(m), f(f), u8(u8), configs(configs), bit(bit), group(group), groupCnt(groupCnt) {}
+
+        void Run() {
+            int type = (bit == 4) ? 1 : 0;
+            for (int i = st; i < end; i++) {
+                for (int g = 0; g < group; g++) {
+                    int cid = i * group + g;
+                    int groupStart = g * groupCnt;
+                    int groupEnd = std::min((g + 1) * groupCnt, m);
+
+                    float minValue = 1e9, maxValue = -1e9;
+                    for (int j = groupStart; j < groupEnd; j++) {
+                        minValue = std::min(minValue, f[i * m + j]);
+                        maxValue = std::max(maxValue, f[i * m + j]);
+                    }
+                    if (bit == 8) {
+                        configs[cid] = LowBitConfig(minValue, maxValue, 8, type);
+                        for (int j = groupStart; j < groupEnd; j++) {
+                            u8[i * m + j] = configs[cid].quantization(f[i * m + j]);
+                        }
+                    } else {
+                        configs[cid] = LowBitConfig(minValue, maxValue, 4, type);
+                        for (int j = groupStart; j < groupEnd; j++) {
+                            int id = (i * m + j) / 2;
+                            uint8_t value = configs[cid].quantization(f[i * m + j]);
+                            if ((i * m + j) % 2) {
+                                u8[id] = (u8[id] & 0xF0) | value;
+                            } else {
+                                u8[id] = (u8[id] & 0xF) | (value << 4);
+                            }
+                        }
+                    }
+                }
+            }
+        }
+    };
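The two 65536-entry tables work because a bfloat16 value is bit-for-bit the upper half of an IEEE-754 float: widening is a 16-bit left shift, so the whole bf16 range can be precomputed once at startup. A minimal standalone sketch of that trick (not fastllm API; memcpy replaces the pointer cast to sidestep strict-aliasing issues):

```cpp
#include <cstdint>
#include <cstring>
#include <cstdio>

static float Bf16ToFp32(uint16_t b) {
    uint32_t x = (uint32_t)b << 16;   // bf16 bits occupy the high half of a float
    float f;
    std::memcpy(&f, &x, sizeof(f));   // well-defined type punning
    return f;
}

int main() {
    // 0x3F80 is 1.0 in bf16 (sign 0, exponent 127, empty mantissa)
    std::printf("%f\n", Bf16ToFp32(0x3F80));  // prints 1.000000
}
```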
+
+    struct MultiThreadPerChannelQuantizationBF16Op : MultiThreadBaseOp {
+        int st, end, m;
+        uint16_t *bf;
+        uint8_t *u8;
+        LowBitConfig *configs;
+        int bit;
+
+        MultiThreadPerChannelQuantizationBF16Op (int st, int end, int m,
+                                                 uint16_t *bf, uint8_t *u8, LowBitConfig *configs, int bit) :
+                st(st), end(end), m(m), bf(bf), u8(u8), configs(configs), bit(bit) {}
+
+        void Run() {
+            int type = (bit == 4) ? 1 : 0;
+            for (int i = st; i < end; i++) {
+                float minValue = 1e9, maxValue = -1e9;
+                for (int j = 0; j < m; j++) {
+                    minValue = std::min(minValue, bf16tofp32.dict[bf[i * m + j]]);
+                    maxValue = std::max(maxValue, bf16tofp32.dict[bf[i * m + j]]);
+                }
+                if (bit == 8) {
+                    configs[i] = LowBitConfig(minValue, maxValue, 8, type);
+                    for (int j = 0; j < m; j++) {
+                        u8[i * m + j] = configs[i].quantization(bf16tofp32.dict[bf[i * m + j]]);
+                    }
+                } else {
+                    configs[i] = LowBitConfig(minValue, maxValue, 4, type);
+                    for (int j = 0; j < m; j++) {
+                        int id = (i * m + j) / 2;
+                        uint8_t value = configs[i].quantization(bf16tofp32.dict[bf[i * m + j]]);
+                        if ((i * m + j) % 2) {
+                            u8[id] = (u8[id] & 0xF0) | value;
+                        } else {
+                            u8[id] = (u8[id] & 0xF) | (value << 4);
+                        }
+                    }
+                }
+            }
+        }
+    };
+
+    struct MultiThreadPerChannelQuantizationOp : MultiThreadBaseOp {
+        int st, end, m;
+        float *f;
+        uint8_t *u8;
+        LowBitConfig *configs;
+        int bit;
+
+        MultiThreadPerChannelQuantizationOp (int st, int end, int m,
+                                             float *f, uint8_t *u8, LowBitConfig *configs, int bit) :
+                st(st), end(end), m(m), f(f), u8(u8), configs(configs), bit(bit) {}
+
+        void Run() {
+            int type = (bit == 4) ? 1 : 0;
+            for (int i = st; i < end; i++) {
+                float minValue = 1e9, maxValue = -1e9;
+                for (int j = 0; j < m; j++) {
+                    minValue = std::min(minValue, f[i * m + j]);
+                    maxValue = std::max(maxValue, f[i * m + j]);
+                }
+                if (bit == 8) {
+                    configs[i] = LowBitConfig(minValue, maxValue, 8, type);
+                    for (int j = 0; j < m; j++) {
+                        u8[i * m + j] = configs[i].quantization(f[i * m + j]);
+                    }
+                } else {
+                    configs[i] = LowBitConfig(minValue, maxValue, 4, type);
+                    for (int j = 0; j < m; j++) {
+                        int id = (i * m + j) / 2;
+                        uint8_t value = configs[i].quantization(f[i * m + j]);
+                        if ((i * m + j) % 2) {
+                            u8[id] = (u8[id] & 0xF0) | value;
+                        } else {
+                            u8[id] = (u8[id] & 0xF) | (value << 4);
+                        }
+                    }
+                }
+            }
+        }
+    };
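All four quantization ops delegate the scalar math to LowBitConfig and, in the 4-bit case, pack two codes per byte. A self-contained sketch of the standard asymmetric min/max scheme they appear to implement (LowBitConfig's exact rounding and clamping may differ), together with the same nibble layout as the `(i * m + j) % 2` branches above, where an even flat index takes the high nibble:

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Asymmetric min/max quantization: scale = (max - min) / (2^bit - 1),
// q = round((x - min) / scale), clamped to [0, 2^bit - 1].
struct MiniConfig {
    float min, scale;
    int qmax;
    MiniConfig(float minV, float maxV, int bit)
        : min(minV), scale((maxV - minV) / ((1 << bit) - 1)), qmax((1 << bit) - 1) {}
    uint8_t quantize(float x) const {
        int q = (int)std::round((x - min) / scale);
        return (uint8_t)std::min(std::max(q, 0), qmax);
    }
    float dequantize(uint8_t q) const { return min + scale * q; }
};

// 4-bit packing as above: even flat index -> high nibble, odd -> low nibble.
void PackInt4(uint8_t *u8, int flatIdx, uint8_t value) {
    int id = flatIdx / 2;
    if (flatIdx % 2) {
        u8[id] = (u8[id] & 0xF0) | value;
    } else {
        u8[id] = (u8[id] & 0x0F) | (value << 4);
    }
}

int main() {
    MiniConfig c(-1.0f, 1.0f, 4);
    uint8_t packed[1] = {0};
    PackInt4(packed, 0, c.quantize(0.5f));   // high nibble
    PackInt4(packed, 1, c.quantize(-0.5f));  // low nibble
    std::printf("packed byte: 0x%02X\n", packed[0]);  // 0xB4 with these formulas
}
```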
+
+    void Data::CreateFromOriData(WeightType weightType, DataType oriDataType, uint8_t *oriData, int groupCnt) {
+        auto &data = *this;
+        data.weightType = weightType;
+        data.UpdateUnitSize();
+        data.Allocate();
+        if (dataType == oriDataType) {
+            memcpy(data.cpuData, oriData, data.GetBytes());
+        } else if (oriDataType == DataType::BFLOAT16
+                   && dataType == DataType::FLOAT16) {
+            uint16_t *a = (uint16_t*)data.cpuData;
+            uint16_t *b = (uint16_t*)oriData;
+            int len = data.Count(0);
+            for (int i = 0; i < len; i++) {
+                a[i] = bf16tofp16.dict[b[i]];
+            }
+        } else if (oriDataType == DataType::FLOAT32
+                   && dataType == DataType::FLOAT16) {
+            uint16_t *a = (uint16_t*)data.cpuData;
+            float *b = (float*)oriData;
+            int len = data.Count(0);
+            for (int i = 0; i < len; i++) {
+                a[i] = float_to_half(b[i]);
+            }
+        } else if ((oriDataType == DataType::FLOAT32 || oriDataType == DataType::BFLOAT16)
+                   && dataType == DataType::INT4_GROUP) {
+            int bit = (dataType == DataType::INT4_GROUP) ? 4 : 8; // always 4 in this branch; kept for symmetry with the per-channel branch
+            int type = (bit == 4) ? 1 : 0;
+            int k = data.dims[0], m = data.dims[1];
+            if (groupCnt == -1) {
+                groupCnt = 128;
+            }
+            int group = (m - 1) / groupCnt + 1;
+            std::vector <LowBitConfig> configs;
+            std::vector <uint8_t> uDatas;
+            configs.resize(k * group);
+
+            int bytes = k * m;
+            if (bit == 4) {
+                bytes = (k * m + 1) / 2;
+            }
+            uDatas.resize(bytes);
+            if (oriDataType == DataType::FLOAT32) {
+                (MultiThreadGroupQuantizationOp(0, k, m, (float*)oriData, uDatas.data(), configs.data(), bit, group, groupCnt)).Run();
+            } else if (oriDataType == DataType::BFLOAT16) {
+                (MultiThreadGroupQuantizationBF16Op(0, k, m, (uint16_t*)oriData, uDatas.data(), configs.data(), bit, group, groupCnt)).Run();
+            }
+            data.perChannelAxis = 0;
+            data.perChannelsConfigs.resize(k * group);
+            data.group = group;
+            data.groupCnt = groupCnt;
+            data.zeros.resize(k * group);
+            data.scales.resize(k * group);
+            data.mins.resize(k * group);
+            for (int i = 0; i < k * group; i++) {
+                data.perChannelsConfigs[i] = LowBitConfig(configs[i].min, configs[i].max, bit, type);
+                data.mins[i] = data.perChannelsConfigs[i].min;
+                data.zeros[i] = data.perChannelsConfigs[i].zeroPoint;
+                data.scales[i] = data.perChannelsConfigs[i].scale;
+            }
+            memcpy((uint8_t*)data.cpuData, (uint8_t*)uDatas.data(), bytes);
+        } else if ((oriDataType == DataType::FLOAT32 || oriDataType == DataType::BFLOAT16) &&
+                   (dataType == DataType::INT8 || dataType == DataType::INT4_NOZERO)) {
+            int bit = (dataType == DataType::INT4_NOZERO) ? 4 : 8;
+            int type = (bit == 4) ? 1 : 0;
+            int k = data.dims[0], m = data.dims[1];
+            std::vector <LowBitConfig> configs;
+            std::vector <uint8_t> uDatas;
+            configs.resize(k);
+
+            int bytes = k * m;
+            if (bit == 4) {
+                bytes = (k * m + 1) / 2;
+            }
+            uDatas.resize(bytes);
+            if (oriDataType == DataType::FLOAT32) {
+                (MultiThreadPerChannelQuantizationOp(0, k, m, (float *) oriData, uDatas.data(), configs.data(), bit)).Run();
+            } else if (oriDataType == DataType::BFLOAT16) {
+                (MultiThreadPerChannelQuantizationBF16Op(0, k, m, (uint16_t *) oriData, uDatas.data(), configs.data(), bit)).Run();
+            }
+            data.perChannelAxis = 0;
+            data.perChannelsConfigs.resize(k);
+            data.zeros.resize(k);
+            data.scales.resize(k);
+            data.mins.resize(k);
+            for (int i = 0; i < k; i++) {
+                data.perChannelsConfigs[i] = LowBitConfig(configs[i].min, configs[i].max, bit, type);
+                data.mins[i] = data.perChannelsConfigs[i].min;
+                data.zeros[i] = data.perChannelsConfigs[i].zeroPoint;
+                data.scales[i] = data.perChannelsConfigs[i].scale;
+            }
+            memcpy((uint8_t*)data.cpuData, (uint8_t*)uDatas.data(), bytes);
+        } else {
+            ErrorInFastLLM("wrong data type");
+        }
+    }
+
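A hypothetical usage sketch of the new entry point: construct the Data with its final storage type, then let CreateFromOriData convert or quantize the raw buffer in place. The shape, source buffer, and LINEAR weight type below are illustrative, not taken from the patch:

```cpp
#include "fastllm.h"
#include <cstdint>
#include <vector>

// Illustrative only: quantize a 4096 x 4096 fp32 linear weight to int4
// with 128-wide groups; 'src' stands in for real weight bytes.
void Example() {
    fastllm::Data w(fastllm::DataType::INT4_GROUP, {4096, 4096});
    std::vector <float> src(4096LL * 4096, 0.01f);
    w.CreateFromOriData(fastllm::WeightType::LINEAR, fastllm::DataType::FLOAT32,
                        (uint8_t*)src.data(), /*groupCnt=*/128);
}
```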
     uint64_t Data::Count(int i) const {
         if (i >= this->dims.size()) {
             return 1;
@@ -1764,93 +2056,6 @@ namespace fastllm {
             return;
         }
 
-    struct MultiThreadGroupQuantizationOp : MultiThreadBaseOp {
-        int st, end, m;
-        float *f;
-        uint8_t *u8;
-        LowBitConfig *configs;
-        int bit;
-        int group, groupCnt;
-
-        MultiThreadGroupQuantizationOp (int st, int end, int m,
-                                        float *f, uint8_t *u8, LowBitConfig *configs, int bit, int group, int groupCnt) :
-                st(st), end(end), m(m), f(f), u8(u8), configs(configs), bit(bit), group(group), groupCnt(groupCnt) {}
-
-        void Run() {
-            int type = (bit == 4) ? 1 : 0;
-            for (int i = st; i < end; i++) {
-                for (int g = 0; g < group; g++) {
-                    int cid = i * group + g;
-                    int groupStart = g * groupCnt;
-                    int groupEnd = std::min((g + 1) * groupCnt, m);
-
-                    float minValue = 1e9, maxValue = -1e9;
-                    for (int j = groupStart; j < groupEnd; j++) {
-                        minValue = std::min(minValue, f[i * m + j]);
-                        maxValue = std::max(maxValue, f[i * m + j]);
-                    }
-                    if (bit == 8) {
-                        configs[cid] = LowBitConfig(minValue, maxValue, 8, type);
-                        for (int j = groupStart; j < groupEnd; j++) {
-                            u8[i * m + j] = configs[cid].quantization(f[i * m + j]);
-                        }
-                    } else {
-                        configs[cid] = LowBitConfig(minValue, maxValue, 4, type);
-                        for (int j = groupStart; j < groupEnd; j++) {
-                            int id = (i * m + j) / 2;
-                            uint8_t value = configs[cid].quantization(f[i * m + j]);
-                            if ((i * m + j) % 2) {
-                                u8[id] = (u8[id] & 0xF0) | value;
-                            } else {
-                                u8[id] = (u8[id] & 0xF) | (value << 4);
-                            }
-                        }
-                    }
-                }
-            }
-        }
-    };
-
-    struct MultiThreadPerChannelQuantizationOp : MultiThreadBaseOp {
-        int st, end, m;
-        float *f;
-        uint8_t *u8;
-        LowBitConfig *configs;
-        int bit;
-
-        MultiThreadPerChannelQuantizationOp (int st, int end, int m,
-                                             float *f, uint8_t *u8, LowBitConfig *configs, int bit) :
-                st(st), end(end), m(m), f(f), u8(u8), configs(configs), bit(bit) {}
-
-        void Run() {
-            int type = (bit == 4) ? 1 : 0;
-            for (int i = st; i < end; i++) {
-                float minValue = 1e9, maxValue = -1e9;
-                for (int j = 0; j < m; j++) {
-                    minValue = std::min(minValue, f[i * m + j]);
-                    maxValue = std::max(maxValue, f[i * m + j]);
-                }
-                if (bit == 8) {
-                    configs[i] = LowBitConfig(minValue, maxValue, 8, type);
-                    for (int j = 0; j < m; j++) {
-                        u8[i * m + j] = configs[i].quantization(f[i * m + j]);
-                    }
-                } else {
-                    configs[i] = LowBitConfig(minValue, maxValue, 4, type);
-                    for (int j = 0; j < m; j++) {
-                        int id = (i * m + j) / 2;
-                        uint8_t value = configs[i].quantization(f[i * m + j]);
-                        if ((i * m + j) % 2) {
-                            u8[id] = (u8[id] & 0xF0) | value;
-                        } else {
-                            u8[id] = (u8[id] & 0xF) | (value << 4);
-                        }
-                    }
-                }
-            }
-        }
-    };
-
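The two structs removed here are the same ones re-added near the top of fastllm.cpp, relocated so that Data::CreateFromOriData, which now sits earlier in the translation unit, can reach them. For reference, the round trip that the stored mins/zeros/scales support, assuming the standard asymmetric formulas (the 4-bit "no-zero" variant, type == 1, may map differently):

```cpp
#include <cmath>
#include <cstdio>

int main() {
    float minV = -0.8f, maxV = 1.2f;
    int bit = 8;
    float scale = (maxV - minV) / ((1 << bit) - 1); // ~0.00784
    float zero  = std::round(-minV / scale);        // ~102
    float x = 0.5f;
    int   q    = (int)std::round(x / scale + zero); // quantize
    float back = scale * (q - zero);                // dequantize
    std::printf("q=%d back=%.4f err=%.4f\n", q, back, std::fabs(back - x));
}
```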
     void WeightMap::SaveLowBitModel(const std::string &fileName, int bit) {
         AssertInFastLLM(fileName != "", "Error: output's name shouldn't be empty.\n");
         AssertInFastLLM(bit == 0 || bit == 4 || bit == 8 || bit == 16, "Error: only support 16 bit or 8 bit or 4 bit model.\n");
@@ -2113,6 +2318,11 @@ namespace fastllm {
         }
     }
 
+    void WeightMap::AddEmptyWeight(const std::string &key, const std::vector <int> &dims, fastllm::DataType dataType) {
+        this->weight[key] = Data(dataType, dims);
+        this->weight[key].name = std::string(key);
+    }
+
     void WeightMap::AddWeight(const std::string &key, const std::vector <int> &dims, fastllm::DataType dataType,
                               fastllm::WeightType weightType, fastllm::DataType oriDataType, uint8_t *oriData, int groupCnt) {
         if (weightType == WeightType::AUTO) {
diff --git a/src/model.cpp b/src/model.cpp
index 800ce4ea..1f9e0723 100644
--- a/src/model.cpp
+++ b/src/model.cpp
@@ -456,16 +456,75 @@ namespace fastllm {
         }
 
         // 4. Read the weights
+        auto tensors = safeTensors.GetSortedItemNames();
         int cur = 0;
-        for (auto &weightName : safeTensors.GetSortedItemNames()) {
+        long long totalBytes = 0;
+        for (auto &weightName : tensors) {
             auto &tensor = safeTensors.itmeDict[weightName];
-            tensor.CreateBuffer(DataType::FLOAT32);
-            model->weight.AddWeight(weightName, tensor.intShape, linearDataType, WeightType::AUTO, DataType::FLOAT32, tensor.buffer, groupCnt);
-            tensor.ClearBuffer();
+            auto oriDataType = DataType::FLOAT32;
+            auto weightType = model->weight.GetWeightType(weightName);
+            auto dataType = (weightType == WeightType::EMBEDDING || weightType == WeightType::NONE) ? oriDataType : linearDataType;
+            model->weight.AddEmptyWeight(weightName, tensor.intShape, dataType);
+            totalBytes += tensor.bytes;
 
-            printf("Load (%d / %d) \r", (++cur), (int)safeTensors.itmeDict.size());
+            printf("Load %d%% \r", (++cur) * 100 / (int)safeTensors.itmeDict.size());
             fflush(stdout);
         }
+
+        std::vector <std::thread*> threads;
+        int threadNum = std::min(16, std::max(4, (int)GetAlivePool()->threads.size()));
+        std::mutex locker;
+        int cnt = 0;
+
+        std::vector <std::pair <int, int> > parts; // consecutive tensor ranges of roughly equal byte size, one per thread
+        int start = 0;
+        for (int i = 0; i < threadNum; i++) {
+            int cur = start;
+            long long now = 0;
+            while (true) {
+                if (now * threadNum >= totalBytes || start >= (int)tensors.size()) {
+                    break;
+                }
+                now += safeTensors.itmeDict[tensors[start]].bytes;
+                start++;
+            }
+            parts.push_back(std::make_pair(cur, start));
+        }
+        parts.back().second = tensors.size();
+        while ((int)parts.size() < threadNum) {
+            parts.push_back(std::make_pair(-1, -1)); // empty range; its worker loop runs zero iterations
+        }
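The partitioning above hands each worker a consecutive run of tensors sized by bytes rather than by count, so a thread that draws a few huge tensors is given fewer of them. A standalone sketch of the same scheme:

```cpp
#include <utility>
#include <vector>

// Each part consumes tensors until its byte share times threadNum reaches the
// total, keeping parts near total/threadNum bytes; the last part is stretched
// over any tail left by rounding.
std::vector <std::pair <int, int> > Partition(const std::vector <long long> &bytes, int threadNum) {
    long long total = 0;
    for (auto b : bytes) total += b;
    std::vector <std::pair <int, int> > parts;
    int start = 0;
    for (int i = 0; i < threadNum; i++) {
        int cur = start;
        long long now = 0;
        while (start < (int)bytes.size() && now * threadNum < total) {
            now += bytes[start++];
        }
        parts.push_back(std::make_pair(cur, start));
    }
    parts.back().second = (int)bytes.size();
    return parts;
}

int main() {
    std::vector <long long> bytes = {700, 100, 100, 50, 50};
    auto parts = Partition(bytes, 2); // -> {0,1} and {1,5}: the 700-byte tensor rides alone
}
```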
+
+        for (int i = 0; i < threadNum; i++) {
+            threads.push_back(
+                new std::thread([&](int st, int end) {
+                    for (int i = st; i < end; i++) {
+                        auto &weightName = tensors[i];
+                        auto &tensor = safeTensors.itmeDict[weightName];
+                        auto oriDataType = DataType::FLOAT32;
+                        auto weightType = model->weight.GetWeightType(weightName);
+                        auto dataType = (weightType == WeightType::EMBEDDING || weightType == WeightType::NONE) ? oriDataType : linearDataType;
+
+                        if (tensor.dtype == "BF16" &&
+                            (dataType == DataType::FLOAT16 || dataType == DataType::INT8 || dataType == DataType::INT4_GROUP || dataType == DataType::INT4_NOZERO)) {
+                            oriDataType = DataType::BFLOAT16;
+                        }
+                        tensor.CreateBuffer(oriDataType);
+                        model->weight[weightName].CreateFromOriData(weightType, oriDataType, tensor.buffer, groupCnt);
+                        tensor.ClearBuffer();
+
+                        locker.lock();
+                        printf("Convert %d%% \r", (++cnt) * 100 / (int)safeTensors.itmeDict.size());
+                        fflush(stdout);
+                        locker.unlock();
+                    }
+                }, parts[i].first, parts[i].second)
+            );
+        }
+        for (int i = 0; i < (int)threads.size(); i++) {
+            threads[i]->join();
+            delete threads[i];
+        }
+        printf("\n");
+        fflush(stdout);
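The worker lambda routes dtypes as follows: embeddings and untyped weights stay fp32, everything else takes linearDataType, and a BF16 source headed for fp16 or a quantized type is decoded as bf16 directly, skipping the fp32 intermediate. A standalone mirror of that decision, with local stand-in enums rather than fastllm's:

```cpp
#include <string>
#include <utility>

enum class DT { FLOAT32, FLOAT16, BFLOAT16, INT8, INT4_GROUP, INT4_NOZERO };
enum class WT { NONE, EMBEDDING, LINEAR };

// Returns {oriDataType, dataType}: how to decode the source buffer and how
// to store the weight, mirroring the branch in the worker lambda above.
std::pair <DT, DT> RouteDtype(WT weightType, const std::string &srcDtype, DT linearDataType) {
    DT ori = DT::FLOAT32;
    DT dst = (weightType == WT::EMBEDDING || weightType == WT::NONE) ? ori : linearDataType;
    if (srcDtype == "BF16" &&
        (dst == DT::FLOAT16 || dst == DT::INT8 || dst == DT::INT4_GROUP || dst == DT::INT4_NOZERO)) {
        ori = DT::BFLOAT16; // decode bf16 directly; skip the fp32 intermediate
    }
    return {ori, dst};
}
```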