tfacc: support float16 * int8
黄宇扬 committed Aug 21, 2024
1 parent 3a748e7 commit 29b8fa2
Showing 4 changed files with 77 additions and 8 deletions.
3 changes: 2 additions & 1 deletion include/devices/tfacc/fastllm-tfacc.h
@@ -37,7 +37,8 @@ namespace fastllm {
                          fastllm::Data *weight, fastllm::Data *bias,
                          std::vector <LowBitConfig> *inputConfigs,
                          uint8_t *uinput, float *output,
-                         LinearExType exType);
+                         LinearExType exType,
+                         DataType outputType);
 
     void RunTfaccLinearF(int n, int m, int k, fastllm::Data *weight, fastllm::Data *bias,
                          float *input, float *output, LinearExType exType, DataType dataType);
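The header change threads the caller's requested output type down to the TFACC client, so the server can return FLOAT32 or FLOAT16 results from the same uint8 linear call. A minimal standalone sketch of why the buffer arithmetic has to follow, assuming a stand-in `DataType` enum (the real one lives in fastllm's headers): a FLOAT16 result takes half the bytes of a FLOAT32 one, so byte offsets into the output must scale with the element size instead of a fixed 4 bytes.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <initializer_list>

// Hypothetical stand-in for fastllm's DataType enum, for illustration only.
enum class DataType { FLOAT32, FLOAT16 };

// Per-element output size, mirroring the outputUnitSize selection this
// commit adds in fastllm-tfacc.cpp.
static std::size_t OutputUnitSize(DataType outputType) {
    return outputType == DataType::FLOAT32 ? sizeof(float) : sizeof(uint16_t);
}

int main() {
    // For an [n, outK] result copied in row chunks: the byte offset of the
    // chunk starting at row baseN, and the byte span of curN rows.
    int outK = 4096, baseN = 8, curN = 4;
    for (DataType t : {DataType::FLOAT32, DataType::FLOAT16}) {
        std::size_t unit = OutputUnitSize(t);
        std::printf("offset = %zu bytes, span = %zu bytes\n",
                    (std::size_t)baseN * outK * unit,
                    (std::size_t)curN * outK * unit);
    }
    return 0;
}
```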
9 changes: 6 additions & 3 deletions src/devices/tfacc/fastllm-tfacc.cpp
@@ -211,7 +211,8 @@ namespace fastllm {
                                     fastllm::Data *weight, fastllm::Data *bias,
                                     std::vector <LowBitConfig> *inputConfigs,
                                     uint8_t *uinput, float *output,
-                                    LinearExType exType) {
+                                    LinearExType exType,
+                                    DataType outputType) {
         std::string linearType = "linear";
         if (exType == LinearExType::ExSwiglu) {
             linearType = "linearSwiglu";
@@ -235,6 +236,7 @@ namespace fastllm {
         maxN = std::min(maxN, (int)(transLimit / (k * sizeof(float))));
 
         // printf("maxN = %d\n", maxN);
+        int outputUnitSize = (outputType == DataType::FLOAT32 ? sizeof(float) : sizeof(uint16_t));
         for (int baseN = 0; baseN < n; baseN += maxN) {
             // auto st0 = std::chrono::system_clock::now();
             int curN = std::min(maxN, n - baseN);
@@ -246,6 +248,7 @@ namespace fastllm {
             ((int32_t*)buf)[5] = weight->name.size();
             ((int32_t*)buf)[6] = biasName.size();
             ((int32_t*)buf)[7] = exType;
+            ((int32_t*)buf)[8] = outputType;
 
             volatile uint8_t *cur = (uint8_t*)buf + 10 * sizeof(int32_t);
             for (int i = 0; i < curN * group; i++) {
@@ -269,9 +272,9 @@ namespace fastllm {

             auto pool = GetAlivePool();
 
-            RunMultiThreadMemcpy(((uint8_t*) output) + baseN * outK * sizeof(int32_t),
+            RunMultiThreadMemcpy(((uint8_t*) output) + baseN * outK * outputUnitSize,
                                  (uint8_t*) result,
-                                 curN * outK * sizeof(int32_t), GetAlivePool());
+                                 curN * outK * outputUnitSize, GetAlivePool());
             // auto st3 = std::chrono::system_clock::now();
             // if (n > 0) printf("n = %d, m = %d, k = %d, input = %f s, calc = %f s, output = %f. total = %f\n", n, m, k, GetSpan(st0, st1), GetSpan(st1, st2), GetSpan(st2, st3), GetSpan(st0, st3));
         }
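On the wire, the change is one header word plus a size-aware copy: word 8 of the request header now carries `outputType`, and the response copy scales by `outputUnitSize` rather than the previous hard-coded `sizeof(int32_t)`. Below is a compilable sketch of the request header as far as these hunks show it; `BuildLinearHeader` is a hypothetical helper, words 0-4 and 9 are not visible in the diff and stay zeroed, and the weight name in `main` is purely illustrative.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Sketch of the 10-word request header this commit extends. Only words
// 5-8 appear in the hunks above; the rest are zeroed placeholders here.
std::vector<uint8_t> BuildLinearHeader(const std::string &weightName,
                                       const std::string &biasName,
                                       int32_t exType, int32_t outputType) {
    std::vector<uint8_t> buf(10 * sizeof(int32_t), 0);
    int32_t *words = reinterpret_cast<int32_t *>(buf.data());
    words[5] = static_cast<int32_t>(weightName.size());
    words[6] = static_cast<int32_t>(biasName.size());
    words[7] = exType;
    words[8] = outputType;  // new in this commit: requested result type
    // Per-row quantization configs and the uint8 input follow at byte
    // offset 10 * sizeof(int32_t), as in the hunk above.
    return buf;
}

int main() {
    auto header = BuildLinearHeader("some.weight.name", "", /*exType=*/0,
                                    /*outputType=*/0);
    return header.size() == 40 ? 0 : 1;
}
```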
73 changes: 69 additions & 4 deletions src/devices/tfacc/tfaccdevice.cpp
@@ -24,16 +24,19 @@
 #include "utils.h"
 
 namespace fastllm {
+    extern FP16ToFP32Manager fp16tofp32;
+    extern void Float16ToFloat32(uint16_t *float16, float *float32, int len);
 
     static TfaccClient tfaccClient;
 
     TfaccDevice::TfaccDevice() {
         this->deviceType = "tfacc";
         this->ops["Linear"] = (BaseOperator *) (new TfaccLinearOp());
-        this->ops["CatDirect"] = (BaseOperator *) (new TfaccCatDirectOp());
+        /*this->ops["CatDirect"] = (BaseOperator *) (new TfaccCatDirectOp());
         this->ops["Attention"] = (BaseOperator *) (new TfaccAttention());
         this->ops["AttentionBatch"] = (BaseOperator *) (new TfaccAttentionBatchOp());
-        this->ops["CatDirectBatch"] = (BaseOperator *) (new TfaccCatDirectBatchOp());
+        this->ops["CatDirectBatch"] = (BaseOperator *) (new TfaccCatDirectBatchOp());*/
     }
 
     bool TfaccDevice::Malloc(void **ret, size_t size) {
@@ -176,15 +179,77 @@ namespace fastllm {
                     ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
                 } else if (weight.dataType == DataType::INT8 || weight.dataType == DataType::INT4_NOZERO || weight.dataType == DataType::INT4_GROUP) {
                     // auto st = std::chrono::system_clock::now();
-                    tfaccClient.RunTfaccLinearU(n, m, k, group, groupCnt, &weight, &bias, &inputConfigs, uinput.data(), outputData, exType);
+                    tfaccClient.RunTfaccLinearU(n, m, k, group, groupCnt, &weight, &bias, &inputConfigs, uinput.data(), outputData, exType, output.dataType);
                     // float spend = GetSpan(st, std::chrono::system_clock::now());
                     // float gops = (float)n * m * k / spend / 1e9;
                     // inner = spend;
                     // if (n > 0) printf("n = %d, m = %d, k = %d, spend %f s, gops = %f (inner)\n", n, m, k, spend, gops);
                 }
             }
         } else if (input.dataType == DataType::FLOAT16 && output.dataType == DataType::FLOAT16) {
-            ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
+            if (weight.dataType == DataType::FLOAT32 || weight.dataType == DataType::FLOAT16) {
+                ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
+            } else if (weight.dataType == DataType::INT4 ||
+                       weight.dataType == DataType::INT4_NOZERO ||
+                       weight.dataType == DataType::INT4_GROUP ||
+                       weight.dataType == DataType::INT8) {
+                uint16_t *inputData = (uint16_t *) input.cpuData;
+                uint8_t *weightData = (uint8_t *) weight.cpuData;
+                uint16_t *outputData = (uint16_t *) output.cpuData;
+                float *biasData = bias.dims.size() > 0 ? (float *) bias.cpuData : nullptr;
+                weight.CalcWeightSum();
+
+                int group = weight.group, groupCnt = weight.groupCnt;
+                if (weight.dataType != DataType::INT4_GROUP) {
+                    group = 1;
+                    groupCnt = m;
+                }
+
+                int outputLen = output.Count(0);
+                std::vector <float> floatInputData;
+                floatInputData.resize(n * m);
+                Float16ToFloat32(inputData, floatInputData.data(), n * m);
+
+                std::vector<LowBitConfig> inputConfigs;
+                inputConfigs.resize(n * group);
+                std::vector<uint8_t> uinput;
+                uinput.resize(n * m);
+
+                if (n > 1) {
+                    auto pool = GetAlivePool();
+                    int threadNum = pool->threads.size();
+                    int per = n / pool->threads.size();
+                    int cur = 0;
+                    std::vector<fastllm::MultiThreadOnlineQuantizationOp*> ops;
+                    for (int i = 0; i < threadNum; i++) {
+                        int end = (i == threadNum - 1 ? n : cur + per + (cur + per * (threadNum - i) < n));
+                        ops.push_back(new MultiThreadOnlineQuantizationOp(
+                            floatInputData.data() + cur * m, uinput.data() + cur * m, inputConfigs.data() + cur * group,
+                            end - cur, m, group, groupCnt, nullptr, nullptr, nullptr));
+                        cur = end;
+                    }
+                    for (int i = 0; i < threadNum; i++) {
+                        pool->PushOp(i, ops[i]);
+                    }
+                    for (int i = 0; i < threadNum; i++) {
+                        pool->Wait(i);
+                        delete ops[i];
+                    }
+                } else {
+                    MultiThreadOnlineQuantizationOp(floatInputData.data(), uinput.data(), inputConfigs.data(), n, m, group, groupCnt, nullptr, nullptr, nullptr).Run();
+                }
+
+                if (weight.dataType == DataType::INT4) {
+                    ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
+                } else if (weight.dataType == DataType::INT8 || weight.dataType == DataType::INT4_NOZERO || weight.dataType == DataType::INT4_GROUP) {
+                    // auto st = std::chrono::system_clock::now();
+                    tfaccClient.RunTfaccLinearU(n, m, k, group, groupCnt, &weight, &bias, &inputConfigs, uinput.data(), (float*)outputData, exType, output.dataType);
+                    // float spend = GetSpan(st, std::chrono::system_clock::now());
+                    // float gops = (float)n * m * k / spend / 1e9;
+                    // inner = spend;
+                    // if (n > 0) printf("n = %d, m = %d, k = %d, spend %f s, gops = %f (inner)\n", n, m, k, spend, gops);
+                }
+            }
         } else {
             ErrorInFastLLM("Linear error: unsupport weight's dataType.\n");
         }
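The new FLOAT16 path is what the commit title promises: fp16 activations times int8/int4 weights. It reuses the uint8 kernels end to end: the fp16 input is widened to fp32 with `Float16ToFloat32`, quantized online per group, and `RunTfaccLinearU` is told via `output.dataType` to return fp16, so the `(float *)` cast on `outputData` only satisfies the existing parameter type while the client actually copies `curN * outK * sizeof(uint16_t)` bytes. Below is a self-contained sketch of the row partitioning the quantization uses, assuming nothing from fastllm; the balancing expression is the one from the diff, which gives earlier chunks one extra row until the remainder is used up.

```cpp
#include <cstdio>
#include <utility>
#include <vector>

// Row partitioning used by the multi-threaded online quantization above:
// n rows split into threadNum contiguous chunks of about n/threadNum rows,
// with the remainder absorbed one row at a time by the earlier chunks.
std::vector<std::pair<int, int>> PartitionRows(int n, int threadNum) {
    std::vector<std::pair<int, int>> ranges;
    int per = n / threadNum, cur = 0;
    for (int i = 0; i < threadNum; i++) {
        int end = (i == threadNum - 1)
                      ? n
                      : cur + per + (cur + per * (threadNum - i) < n ? 1 : 0);
        ranges.push_back({cur, end});
        cur = end;
    }
    return ranges;
}

int main() {
    // 10 rows over 4 threads -> [0,3) [3,6) [6,8) [8,10): chunk sizes differ
    // by at most one row, so no thread waits much longer than the others.
    for (auto [s, e] : PartitionRows(10, 4)) std::printf("[%d, %d) ", s, e);
    std::printf("\n");
    return 0;
}
```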
Binary file modified third_party/tfacc/server
