diff --git a/include/devices/tfacc/fastllm-tfacc.h b/include/devices/tfacc/fastllm-tfacc.h
index 54480c21..04ec2f85 100644
--- a/include/devices/tfacc/fastllm-tfacc.h
+++ b/include/devices/tfacc/fastllm-tfacc.h
@@ -37,7 +37,8 @@ namespace fastllm {
                              fastllm::Data *weight, fastllm::Data *bias,
                              std::vector <LowBitConfig> *inputConfigs,
                              uint8_t *uinput, float *output,
-                             LinearExType exType);
+                             LinearExType exType,
+                             DataType outputType);
 
         void RunTfaccLinearF(int n, int m, int k, fastllm::Data *weight, fastllm::Data *bias,
                              float *input, float *output, LinearExType exType, DataType dataType);
diff --git a/src/devices/tfacc/fastllm-tfacc.cpp b/src/devices/tfacc/fastllm-tfacc.cpp
index 1018785c..c7bcb23e 100644
--- a/src/devices/tfacc/fastllm-tfacc.cpp
+++ b/src/devices/tfacc/fastllm-tfacc.cpp
@@ -211,7 +211,8 @@ namespace fastllm {
                                      fastllm::Data *weight, fastllm::Data *bias,
                                      std::vector <LowBitConfig> *inputConfigs,
                                      uint8_t *uinput, float *output,
-                                     LinearExType exType) {
+                                     LinearExType exType,
+                                     DataType outputType) {
         std::string linearType = "linear";
         if (exType == LinearExType::ExSwiglu) {
             linearType = "linearSwiglu";
@@ -235,6 +236,7 @@ namespace fastllm {
         maxN = std::min(maxN, (int)(transLimit / (k * sizeof(float))));
 
         // printf("maxN = %d\n", maxN);
+        int outputUnitSize = (outputType == DataType::FLOAT32 ? sizeof(float) : sizeof(uint16_t));
         for (int baseN = 0; baseN < n; baseN += maxN) {
             // auto st0 = std::chrono::system_clock::now();
             int curN = std::min(maxN, n - baseN);
@@ -246,6 +248,7 @@ namespace fastllm {
             ((int32_t*)buf)[5] = weight->name.size();
             ((int32_t*)buf)[6] = biasName.size();
             ((int32_t*)buf)[7] = exType;
+            ((int32_t*)buf)[8] = outputType;
 
             volatile uint8_t *cur = (uint8_t*)buf + 10 * sizeof(int32_t);
             for (int i = 0; i < curN * group; i++) {
@@ -269,9 +272,9 @@ namespace fastllm {
 
             auto pool = GetAlivePool();
 
-            RunMultiThreadMemcpy(((uint8_t*) output) + baseN * outK * sizeof(int32_t),
+            RunMultiThreadMemcpy(((uint8_t*) output) + baseN * outK * outputUnitSize,
                                  (uint8_t*) result,
-                                 curN * outK * sizeof(int32_t), GetAlivePool());
+                                 curN * outK * outputUnitSize, GetAlivePool());
             // auto st3 = std::chrono::system_clock::now();
             // if (n > 0) printf("n = %d, m = %d, k = %d, input = %f s, calc = %f s, output = %f. total = %f\n", n, m, k, GetSpan(st0, st1), GetSpan(st1, st2), GetSpan(st2, st3), GetSpan(st0, st3));
         }
diff --git a/src/devices/tfacc/tfaccdevice.cpp b/src/devices/tfacc/tfaccdevice.cpp
index b8d264bb..ccce00c8 100644
--- a/src/devices/tfacc/tfaccdevice.cpp
+++ b/src/devices/tfacc/tfaccdevice.cpp
@@ -24,16 +24,19 @@
 #include "utils.h"
 
 namespace fastllm {
+    extern FP16ToFP32Manager fp16tofp32;
+    extern void Float16ToFloat32(uint16_t *float16, float *float32, int len);
+
     static TfaccClient tfaccClient;
 
     TfaccDevice::TfaccDevice() {
         this->deviceType = "tfacc";
 
         this->ops["Linear"] = (BaseOperator *) (new TfaccLinearOp());
-        this->ops["CatDirect"] = (BaseOperator *) (new TfaccCatDirectOp());
+        /*this->ops["CatDirect"] = (BaseOperator *) (new TfaccCatDirectOp());
         this->ops["Attention"] = (BaseOperator *) (new TfaccAttention());
         this->ops["AttentionBatch"] = (BaseOperator *) (new TfaccAttentionBatchOp());
-        this->ops["CatDirectBatch"] = (BaseOperator *) (new TfaccCatDirectBatchOp());
+        this->ops["CatDirectBatch"] = (BaseOperator *) (new TfaccCatDirectBatchOp());*/
     }
 
     bool TfaccDevice::Malloc(void **ret, size_t size) {
@@ -176,7 +179,7 @@ namespace fastllm {
                     ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
                 } else if (weight.dataType == DataType::INT8 || weight.dataType == DataType::INT4_NOZERO || weight.dataType == DataType::INT4_GROUP) {
 //                    auto st = std::chrono::system_clock::now();
-                    tfaccClient.RunTfaccLinearU(n, m, k, group, groupCnt, &weight, &bias, &inputConfigs, uinput.data(), outputData, exType);
+                    tfaccClient.RunTfaccLinearU(n, m, k, group, groupCnt, &weight, &bias, &inputConfigs, uinput.data(), outputData, exType, output.dataType);
 //                    float spend = GetSpan(st, std::chrono::system_clock::now());
 //                    float gops = (float)n * m * k / spend / 1e9;
 //                    inner = spend;
@@ -184,7 +187,69 @@ namespace fastllm {
                 }
             }
         } else if (input.dataType == DataType::FLOAT16 && output.dataType == DataType::FLOAT16) {
-            ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
+            if (weight.dataType == DataType::FLOAT32 || weight.dataType == DataType::FLOAT16) {
+                ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
+            } else if (weight.dataType == DataType::INT4 ||
+                       weight.dataType == DataType::INT4_NOZERO ||
+                       weight.dataType == DataType::INT4_GROUP ||
+                       weight.dataType == DataType::INT8) {
+                uint16_t *inputData = (uint16_t *) input.cpuData;
+                uint8_t *weightData = (uint8_t *) weight.cpuData;
+                uint16_t *outputData = (uint16_t *) output.cpuData;
+                float *biasData = bias.dims.size() > 0 ? (float *) bias.cpuData : nullptr;
+                weight.CalcWeightSum();
+
+                int group = weight.group, groupCnt = weight.groupCnt;
+                if (weight.dataType != DataType::INT4_GROUP) {
+                    group = 1;
+                    groupCnt = m;
+                }
+
+                int outputLen = output.Count(0);
+                std::vector <float> floatInputData;
+                floatInputData.resize(n * m);
+                Float16ToFloat32(inputData, floatInputData.data(), n * m);
+
+                std::vector <LowBitConfig> inputConfigs;
+                inputConfigs.resize(n * group);
+                std::vector <uint8_t> uinput;
+                uinput.resize(n * m);
+
+                if (n > 1) {
+                    auto pool = GetAlivePool();
+                    int threadNum = pool->threads.size();
+                    int per = n / pool->threads.size();
+                    int cur = 0;
+                    std::vector <fastllm::MultiThreadBaseOp*> ops;
+                    for (int i = 0; i < threadNum; i++) {
+                        int end = (i == threadNum - 1 ? n : cur + per + (cur + per * (threadNum - i) < n));
+                        ops.push_back(new MultiThreadOnlineQuantizationOp(
+                                floatInputData.data() + cur * m, uinput.data() + cur * m, inputConfigs.data() + cur * group,
+                                end - cur, m, group, groupCnt, nullptr, nullptr, nullptr));
+                        cur = end;
+                    }
+                    for (int i = 0; i < threadNum; i++) {
+                        pool->PushOp(i, ops[i]);
+                    }
+                    for (int i = 0; i < threadNum; i++) {
+                        pool->Wait(i);
+                        delete ops[i];
+                    }
+                } else {
+                    MultiThreadOnlineQuantizationOp(floatInputData.data(), uinput.data(), inputConfigs.data(), n, m, group, groupCnt, nullptr, nullptr, nullptr).Run();
+                }
+
+                if (weight.dataType == DataType::INT4) {
+                    ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
+                } else if (weight.dataType == DataType::INT8 || weight.dataType == DataType::INT4_NOZERO || weight.dataType == DataType::INT4_GROUP) {
+//                    auto st = std::chrono::system_clock::now();
+                    tfaccClient.RunTfaccLinearU(n, m, k, group, groupCnt, &weight, &bias, &inputConfigs, uinput.data(), (float*)outputData, exType, output.dataType);
+//                    float spend = GetSpan(st, std::chrono::system_clock::now());
+//                    float gops = (float)n * m * k / spend / 1e9;
+//                    inner = spend;
+//                    if (n > 0) printf("n = %d, m = %d, k = %d, spend %f s, gops = %f (inner)\n", n, m, k, spend, gops);
+                }
+            }
         } else {
             ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
         }
diff --git a/third_party/tfacc/server b/third_party/tfacc/server
index eb6a92e7..893bfbb6 100755
Binary files a/third_party/tfacc/server and b/third_party/tfacc/server differ
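Reviewer note (not part of the patch): the request header now carries the requested output type at word 8, and the reply copy in RunTfaccLinearU is sized by outputUnitSize, i.e. 4 bytes per element for FLOAT32 and 2 for FLOAT16 (the old sizeof(int32_t) was only coincidentally correct for FP32 replies). Below is a minimal, self-contained C++ sketch of that sizing logic and of the FP16-to-FP32 widening the new FLOAT16 Linear path performs before online quantization. Only the DataType distinction and the unit-size expression come from this diff; Fp16ToFp32 is a hypothetical software stand-in for fastllm's table-driven fp16tofp32 / Float16ToFloat32, and the main() harness is illustrative only.

#include <cstdint>
#include <cstdio>
#include <cstring>
#include <vector>

// Assumed stand-in for the two fastllm::DataType values this patch distinguishes.
enum class DataType { FLOAT32, FLOAT16 };

// Hypothetical software fp16 -> fp32 widening; the real code uses the
// table-driven fp16tofp32 manager / Float16ToFloat32 declared extern in the diff.
static float Fp16ToFp32(uint16_t h) {
    uint32_t sign = (uint32_t)(h >> 15) << 31;
    uint32_t exp  = (h >> 10) & 0x1F;
    uint32_t man  = h & 0x3FF;
    uint32_t bits;
    if (exp == 0) {
        if (man == 0) {                          // signed zero
            bits = sign;
        } else {                                 // subnormal: renormalize
            int e = 0;
            while ((man & 0x400) == 0) { man <<= 1; ++e; }
            bits = sign | ((uint32_t)(127 - 15 + 1 - e) << 23) | ((man & 0x3FF) << 13);
        }
    } else if (exp == 0x1F) {                    // inf / NaN
        bits = sign | 0x7F800000u | (man << 13);
    } else {                                     // normal number: rebias exponent
        bits = sign | ((exp - 15 + 127) << 23) | (man << 13);
    }
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main() {
    // 1) Reply sizing, as in RunTfaccLinearU after this change: the memcpy
    //    back into `output` moves outputUnitSize bytes per element.
    for (DataType outputType : {DataType::FLOAT32, DataType::FLOAT16}) {
        size_t outputUnitSize =
            (outputType == DataType::FLOAT32 ? sizeof(float) : sizeof(uint16_t));
        int curN = 8, outK = 4096;               // illustrative chunk shape
        std::printf("outputType=%d -> %zu bytes copied per chunk\n",
                    (int)outputType, (size_t)curN * outK * outputUnitSize);
    }

    // 2) FP16 input widening before online quantization, as in the new
    //    FLOAT16 Linear path (Float16ToFloat32 over n * m elements).
    std::vector<uint16_t> half = {0x3C00, 0x4000, 0xC000};   // 1.0f, 2.0f, -2.0f
    std::vector<float> single(half.size());
    for (size_t i = 0; i < half.size(); ++i) single[i] = Fp16ToFp32(half[i]);
    std::printf("widened: %.1f %.1f %.1f\n", single[0], single[1], single[2]);
    return 0;
}

The sketch compiles standalone (e.g. g++ -std=c++17); the three fp16 constants decode to 1.0, 2.0 and -2.0, which is what the widening step would feed into MultiThreadOnlineQuantizationOp.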