Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master'
Browse files Browse the repository at this point in the history
  • Loading branch information
jiewlmrh committed Jul 4, 2024
2 parents 1c9448f + dead943 commit f52960c
Show file tree
Hide file tree
Showing 3 changed files with 197 additions and 75 deletions.
2 changes: 1 addition & 1 deletion include/fastllm.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ namespace fastllm {
INT4_NOZERO = 8, // 不用zeroPoint的int4, floatValue = min + uint4Value * scale
INT4_GROUP = 9, // 不用zeroPoint的int4, floatValue = min + uint4Value * scale, 且使用分组量化
INT32PARAM = 100, // int32的参数,这种类型的数据永远存在CPU上
DATA_AUTO_NONE = 99999, DATA_AUTO_LINEAR, DATA_AUTO_EMBEDDING // 不确定类型
DATA_AUTO_NONE = 99999, DATA_AUTO_LINEAR, DATA_AUTO_EMBEDDING, DATA_AUTO_CONV
};

enum DataDevice {
Expand Down
48 changes: 43 additions & 5 deletions src/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,16 @@ namespace fastllm {
return std::unique_ptr<fastllm::basellm> (model);
}

/**
 * Generic element-by-element matrix transpose.
 *
 * Reads an n-row by m-column matrix from pSrc (consecutive rows are srcStride
 * elements apart) and writes its transpose to pDst (consecutive rows are
 * dstStride elements apart), i.e. dst(j, i) = src(i, j).
 *
 * @param pDst      destination storage, receives m rows of n elements
 * @param pSrc      source storage, holds n rows of m elements
 * @param dstStride distance, in elements, between destination rows
 * @param srcStride distance, in elements, between source rows
 * @param n         number of source rows
 * @param m         number of source columns
 */
template <typename T>
void TransposeSimple(T *pDst, T *pSrc, int dstStride, int srcStride, int n, int m) {
    for (int row = 0; row < n; ++row) {
        // Offset of the first element of the current source row.
        const int srcBase = row * srcStride;
        for (int col = 0; col < m; ++col) {
            pDst[col * dstStride + row] = pSrc[srcBase + col];
        }
    }
}
// float transpose routine whose definition lives outside this file; its parameter list
// (dst, src, dstStride, srcStride, n, m) matches TransposeSimple above.
// NOTE(review): presumably an optimized implementation with the same semantics as
// TransposeSimple -- confirm against its definition.
extern void Transpose(float *pDst, float *pSrc, int dstStride, int srcStride, int n, int m);

struct SafeTensorItem {
std::string tensorName;
std::string fileName;
Expand Down Expand Up @@ -270,6 +280,23 @@ namespace fastllm {
fclose(fi);
}

/**
 * Transpose the 2-D tensor currently held in `buffer`, in place.
 *
 * The buffer is interpreted as intShape[0] rows by intShape[1] columns of
 * elements of `type`; afterwards it holds intShape[1] rows by intShape[0]
 * columns. `intShape` itself is NOT modified, so callers that need the
 * swapped shape must swap it themselves.
 *
 * @param type element type of the data in `buffer`. FLOAT32 goes through
 *             fastllm::Transpose; FLOAT16 and BFLOAT16 are both moved as raw
 *             16-bit words through TransposeSimple. Any other type is
 *             reported through ErrorInFastLLM.
 */
void Transpose(DataType type) {
    // The code below indexes intShape[0] and intShape[1] and moves exactly
    // n * m elements, so any rank other than 2 would either read out of
    // bounds or silently scramble the tensor.
    if (intShape.size() != 2) {
        ErrorInFastLLM("SafeTensorItem.Transpose: only 2-D tensors are supported, got rank " + std::to_string(intShape.size()) + "\n");
        return;
    }
    const int n = intShape[0], m = intShape[1];
    if (type == DataType::FLOAT32) {
        // Scratch copy of the source; owned by unique_ptr so it is released
        // even if the transpose routine throws.
        std::unique_ptr<float[]> temp(new float[len]);
        memcpy(temp.get(), this->buffer, len * sizeof(float));
        // Destination rows are n elements long, source rows are m elements long.
        fastllm::Transpose((float*)this->buffer, temp.get(), n, m, n, m);
    } else if (type == DataType::FLOAT16 || type == DataType::BFLOAT16) {
        // Both half-precision formats are 16 bits wide, so they can be
        // shuffled as opaque uint16_t values.
        std::unique_ptr<uint16_t[]> temp(new uint16_t[len]);
        memcpy(temp.get(), this->buffer, len * sizeof(uint16_t));
        TransposeSimple((uint16_t*)this->buffer, temp.get(), n, m, n, m);
    } else {
        ErrorInFastLLM("SafeTensorItem.Transpose: unsupport dtype " + std::to_string(type) + "\n");
    }
}

void ClearBuffer() {
delete[] buffer;
buffer = nullptr;
Expand Down Expand Up @@ -304,7 +331,9 @@ namespace fastllm {
std::vector <std::string> GetSortedItemNames() {
std::vector <std::pair <std::pair <std::string, uint64_t>, std::string> > v;
for (auto &it : itmeDict) {
v.push_back(std::make_pair(std::make_pair(it.second.fileName, it.second.data_offsets[0]), it.first));
if (it.second.intShape.size() > 0 && it.second.dtype != "BOOL") {
v.push_back(std::make_pair(std::make_pair(it.second.fileName, it.second.data_offsets[0]), it.first));
}
}
std::sort(v.begin(), v.end());
std::vector <std::string> ret;
Expand Down Expand Up @@ -540,9 +569,15 @@ namespace fastllm {
auto dataType = it.second;
if (dataType >= DATA_AUTO_NONE) {
// AUTO类型
dataType = (dataType == DATA_AUTO_LINEAR) ? linearDataType : oriDataType;
dataType = (dataType == DATA_AUTO_LINEAR || dataType == DATA_AUTO_CONV) ? linearDataType : oriDataType;
}
if (it.second == DATA_AUTO_CONV) {
std::vector <int> realShape = tensor.intShape;
std::swap(realShape[0], realShape[1]);
model->weight.AddEmptyWeight(weightName, realShape, dataType);
} else {
model->weight.AddEmptyWeight(weightName, tensor.intShape, dataType);
}
model->weight.AddEmptyWeight(weightName, tensor.intShape, dataType);
}

totalBytes += tensor.bytes;
Expand Down Expand Up @@ -590,7 +625,7 @@ namespace fastllm {
auto dataType = it.second;
if (dataType >= DATA_AUTO_NONE) {
// AUTO类型
dataType = (dataType == DATA_AUTO_LINEAR) ? linearDataType : oriDataType;
dataType = (dataType == DATA_AUTO_LINEAR || dataType == DATA_AUTO_CONV) ? linearDataType : oriDataType;
}
if (tensor.dtype == "BF16" &&
(dataType == DataType::FLOAT16 || dataType == DataType::INT8 || dataType == DataType::INT4_GROUP || dataType == DataType::INT4_NOZERO)) {
Expand All @@ -601,12 +636,15 @@ namespace fastllm {
oriDataType = DataType::FLOAT16;
}
tensor.CreateBuffer(oriDataType);
if (it.second == DATA_AUTO_CONV) {
tensor.Transpose(oriDataType);
}
model->weight[weightName].CreateFromOriData(WeightType::AUTO, oriDataType, tensor.buffer, groupCnt);
tensor.ClearBuffer();
}

locker.lock();
printf("Convert %d \r", (++cnt) * 100 / (int)safeTensors.itmeDict.size());
printf("Convert %d \r", (++cnt) * 100 / (int)tensorMap.size());
fflush(stdout);
locker.unlock();
}
Expand Down
Loading

0 comments on commit f52960c

Please sign in to comment.