Optimize parallel computation in graphllm
黄宇扬 committed Jul 10, 2024
1 parent 8501b1a commit 7ca2153
Showing 1 changed file with 185 additions and 94 deletions.
279 changes: 185 additions & 94 deletions src/graph.cpp
@@ -220,111 +220,202 @@ namespace fastllm {
                 }
             } else {
                 int batch = seqLens.size(), total = 0;
-                std::vector <Data> curQs, curKs, curVs, curOutputs;
-                curQs.resize(batch);
-                curKs.resize(batch);
-                curVs.resize(batch);
-                curOutputs.resize(batch);
-                for (int b = 0; b < batch; b++) {
-                    excutor.Run("Split", {
-                        {"input", allDatas[op.datas.find("q")->second]}, {"output", &curQs[b]}
-                    }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}});
-                    excutor.Run("Split", {
-                        {"input", allDatas[op.datas.find("curk")->second]}, {"output", &curKs[b]}
-                    }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}});
-                    excutor.Run("Split", {
-                        {"input", allDatas[op.datas.find("curv")->second]}, {"output", &curVs[b]}
-                    }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}});
-                    total += seqLens[b];
-                }
-                std::vector <int> axis = {0, 2, 1, 3};
-                Data axisData = Data(DataType::INT32PARAM, {(int)axis.size()});
-                axisData.Allocate();
-                for (int i = 0; i < axisData.Count(0); i++) {
-                    ((int32_t*)axisData.cpuData)[i] = axis[i];
-                }
-                for (int b = 0; b < batch; b++) {
-                    excutor.Run("PermuteSelf", {
-                        {"input", (Data*)&curQs[b]}, {"axis", &axisData}
-                    }, {}, {});
-                    curQs[b].Reshape({-1, curQs[b].dims[2], curQs[b].dims[3]});
-
-                    excutor.Run("PermuteSelf", {
-                        {"input", (Data*)&curKs[b]}, {"axis", &axisData}
-                    }, {}, {});
-                    curKs[b].Reshape({-1, curKs[b].dims[2], curKs[b].dims[3]});
-
-                    excutor.Run("PermuteSelf", {
-                        {"input", (Data*)&curVs[b]}, {"axis", &axisData}
-                    }, {}, {});
-                    curVs[b].Reshape({-1, curVs[b].dims[2], curVs[b].dims[3]});
-                }
-
-                int unitLen = op.intParams.find("unitLen")->second;
-                for (int b = 0; b < batch; b++) {
-                    for (int i = 0; i < 2; i++) {
-                        auto cache = allDatas[op.datas.find(i == 0 ? "k" : "v")->second + "_" + std::to_string(b)];
-                        auto cur = i == 0 ? &curKs[b] : &curVs[b];
-                        while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1]))
-                            || (cache->dims.size() > 0 && cache->dims[1] + cur->dims[1] > cache->expansionDims[1])) {
-                            std::vector <int> newDims;
-                            if (cache->Count(0) == 0 || cache->dims.size() == 0) {
-                                newDims = std::vector <int> {cur->dims[0], ((cur->dims[1] - 1) / unitLen + 1) * unitLen, cur->dims[2]};
-                            } else {
-                                newDims = cache->dims;
-                                newDims[1] += ((cur->dims[1] - 1) / unitLen + 1) * unitLen;
-                            }
-                            cache->Expansion(newDims);
-                        }
-                        excutor.Run("CatDirect", {
-                            {"input0", cache}, {"input1", cur}
-                        }, {}, {{"axis", 1}});
-                    }
-                }
-
-                for (int b = 0; b < batch; b++) {
-                    std::string sb = "_" + std::to_string(b);
-                    Data *k = allDatas[op.datas.find("k")->second + sb];
-                    Data *v = allDatas[op.datas.find("v")->second + sb];
-                    Data *mask = allDatas[op.datas.find("mask")->second + sb];
-                    excutor.Run("Attention", {
-                        {"q", (Data*)&curQs[b]}, {"k", k}, {"v", v},
-                        {"mask", mask}, {"output", (Data*)&curOutputs[b]}
-                    }, {{"scale", op.floatParams.find("scale")->second}},
-                    {{"maskType", 0}});
-                }
-
-                for (int b = 0; b < batch; b++) {
-                    std::vector <int> axis = {1, 0, 2};
-                    Data axisData = Data(DataType::INT32PARAM, {(int)axis.size()});
-                    axisData.Allocate();
-                    for (int i = 0; i < axisData.Count(0); i++) {
-                        ((int32_t*)axisData.cpuData)[i] = axis[i];
-                    }
-                    Data *output = (Data*)&curOutputs[b];
-                    excutor.Run("PermuteSelf", {
-                        {"input", output}, {"axis", &axisData}
-                    }, {}, {});
-                    output->Reshape({seqLens[b], 1, -1});
-                    excutor.Run("PermuteSelf", {
-                        {"input", output}, {"axis", &axisData}
-                    }, {}, {});
-                }
-
-                auto lastOutput = allDatas[op.datas.find("output")->second];
-                for (int b = 0; b < batch; b++) {
-                    Data *output = (Data*)&curOutputs[b];
-                    if (b == 0) {
-                        lastOutput->dataType = output->dataType;
-                        std::vector <int> dims = output->dims;
-                        dims[1] = 0;
-                        lastOutput->Resize(dims);
-                        dims[1] = total;
-                        lastOutput->Expansion(dims);
-                    }
-                    excutor.Run("CatDirect", {
-                        {"input0", lastOutput}, {"input1", output}
-                    }, {}, {{"axis", 1}});
-                }
+                bool all1 = true;
+                for (int i = 0; i < seqLens.size(); i++) {
+                    if (seqLens[i] != 1) {
+                        all1 = false;
+                        break;
+                    }
+                }
+
+                if (all1) {
+                    std::vector <Data> curQs, curKs, curVs, curOutputs;
+                    curQs.resize(batch);
+                    curKs.resize(batch);
+                    curVs.resize(batch);
+                    curOutputs.resize(batch);
+                    auto &q = *allDatas[op.datas.find("q")->second];
+                    auto &k = *allDatas[op.datas.find("curk")->second];
+                    auto &v = *allDatas[op.datas.find("curv")->second];
+
+                    q.Reshape({-1, q.dims[2], q.dims[3]});
+                    k.Reshape({-1, k.dims[2], k.dims[3]});
+                    v.Reshape({-1, v.dims[2], v.dims[3]});
+                    int embed_dim = q.dims[1] * v.dims[2];
+
+                    std::vector <int> qdims = {q.dims[1], 1, q.dims[2]};
+                    std::vector <uint64_t> qstrides = {(uint64_t)q.dims[2], (uint64_t)q.dims[2], 1};
+                    std::vector <int> kdims = {k.dims[1], 1, k.dims[2]};
+                    std::vector <uint64_t> kstrides = {(uint64_t)k.dims[2], (uint64_t)k.dims[2], 1};
+                    std::vector <int> vdims = {v.dims[1], 1, v.dims[2]};
+                    std::vector <uint64_t> vstrides = {(uint64_t)v.dims[2], (uint64_t)v.dims[2], 1};
+                    for (int b = 0; b < batch; b++) {
+                        curQs[b].dims = qdims;
+                        curQs[b].strides = qstrides;
+                        curQs[b].FakeFrom(q, b * q.strides[0] * q.unitSize);
+                        curKs[b].dims = kdims;
+                        curKs[b].strides = kstrides;
+                        curKs[b].FakeFrom(k, b * k.strides[0] * k.unitSize);
+                        curVs[b].dims = vdims;
+                        curVs[b].strides = vstrides;
+                        curVs[b].FakeFrom(v, b * v.strides[0] * v.unitSize);
+                    }
+                    total = batch;
+
+                    int unitLen = op.intParams.find("unitLen")->second;
+                    for (int i = 0; i < 2; i++) {
+                        std::vector <Data*> caches, curs;
+                        for (int b = 0; b < batch; b++) {
+                            auto cache = allDatas[op.datas.find(i == 0 ? "k" : "v")->second + "_" + std::to_string(b)];
+                            auto cur = i == 0 ? &curKs[b] : &curVs[b];
+                            while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1]))
+                                || (cache->dims.size() > 0 && cache->dims[1] + cur->dims[1] > cache->expansionDims[1])) {
+                                std::vector <int> newDims;
+                                if (cache->Count(0) == 0 || cache->dims.size() == 0) {
+                                    newDims = std::vector <int> {cur->dims[0], ((cur->dims[1] - 1) / unitLen + 1) * unitLen, cur->dims[2]};
+                                } else {
+                                    newDims = cache->dims;
+                                    newDims[1] += ((cur->dims[1] - 1) / unitLen + 1) * unitLen;
+                                }
+                                cache->Expansion(newDims);
+                            }
+                            caches.push_back(cache);
+                            curs.push_back(cur);
+                        }
+                        CatDirectBatch(caches, curs, 1);
+                    }
+
+                    auto &attenOutput = *allDatas[op.datas.find("output")->second];
+                    attenOutput.dataType = q.dataType;
+                    attenOutput.ToDevice(q.dataDevice);
+                    attenOutput.Resize({1, batch, embed_dim});
+                    attenOutput.Allocate();
+                    std::vector <Data> curContextLayer;
+                    std::vector <Data*> qs, keys, values, masks, contexts;
+                    curContextLayer.resize(batch);
+                    qs.resize(batch);
+                    keys.resize(batch);
+                    values.resize(batch);
+                    masks.resize(batch);
+                    contexts.resize(batch);
+
+                    for (int b = 0; b < batch; b++) {
+                        std::string sb = "_" + std::to_string(b);
+                        qs[b] = (&curQs[b]);
+                        keys[b] = allDatas[op.datas.find("k")->second + sb];
+                        values[b] = allDatas[op.datas.find("v")->second + sb];
+                        masks[b] = allDatas[op.datas.find("mask")->second + sb];
+                        curContextLayer[b].FakeFrom(attenOutput, b * embed_dim * attenOutput.unitSize);
+                        contexts[b] = (&curContextLayer[b]);
+                    }
+                    AttentionBatch(qs, keys, values, masks, contexts, qs[0]->dims[0] / values[0]->dims[0], op.floatParams.find("scale")->second, 1);
+                } else {
+                    std::vector <Data> curQs, curKs, curVs, curOutputs;
+                    curQs.resize(batch);
+                    curKs.resize(batch);
+                    curVs.resize(batch);
+                    curOutputs.resize(batch);
+                    for (int b = 0; b < batch; b++) {
+                        excutor.Run("Split", {
+                            {"input", allDatas[op.datas.find("q")->second]}, {"output", &curQs[b]}
+                        }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}});
+                        excutor.Run("Split", {
+                            {"input", allDatas[op.datas.find("curk")->second]}, {"output", &curKs[b]}
+                        }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}});
+                        excutor.Run("Split", {
+                            {"input", allDatas[op.datas.find("curv")->second]}, {"output", &curVs[b]}
+                        }, {}, {{"axis", 1}, {"start", total}, {"end", total + seqLens[b]}});
+                        total += seqLens[b];
+                    }
+                    std::vector <int> axis = {0, 2, 1, 3};
+                    Data axisData = Data(DataType::INT32PARAM, {(int)axis.size()});
+                    axisData.Allocate();
+                    for (int i = 0; i < axisData.Count(0); i++) {
+                        ((int32_t*)axisData.cpuData)[i] = axis[i];
+                    }
+                    for (int b = 0; b < batch; b++) {
+                        excutor.Run("PermuteSelf", {
+                            {"input", (Data*)&curQs[b]}, {"axis", &axisData}
+                        }, {}, {});
+                        curQs[b].Reshape({-1, curQs[b].dims[2], curQs[b].dims[3]});
+
+                        excutor.Run("PermuteSelf", {
+                            {"input", (Data*)&curKs[b]}, {"axis", &axisData}
+                        }, {}, {});
+                        curKs[b].Reshape({-1, curKs[b].dims[2], curKs[b].dims[3]});
+
+                        excutor.Run("PermuteSelf", {
+                            {"input", (Data*)&curVs[b]}, {"axis", &axisData}
+                        }, {}, {});
+                        curVs[b].Reshape({-1, curVs[b].dims[2], curVs[b].dims[3]});
+                    }
+
+                    int unitLen = op.intParams.find("unitLen")->second;
+                    for (int b = 0; b < batch; b++) {
+                        for (int i = 0; i < 2; i++) {
+                            auto cache = allDatas[op.datas.find(i == 0 ? "k" : "v")->second + "_" + std::to_string(b)];
+                            auto cur = i == 0 ? &curKs[b] : &curVs[b];
+                            while ((cache->dims.size() == 0 && (cache->expansionDims.size() == 0 || cur->dims[1] > cache->expansionDims[1]))
+                                || (cache->dims.size() > 0 && cache->dims[1] + cur->dims[1] > cache->expansionDims[1])) {
+                                std::vector <int> newDims;
+                                if (cache->Count(0) == 0 || cache->dims.size() == 0) {
+                                    newDims = std::vector <int> {cur->dims[0], ((cur->dims[1] - 1) / unitLen + 1) * unitLen, cur->dims[2]};
+                                } else {
+                                    newDims = cache->dims;
+                                    newDims[1] += ((cur->dims[1] - 1) / unitLen + 1) * unitLen;
+                                }
+                                cache->Expansion(newDims);
+                            }
+                            excutor.Run("CatDirect", {
+                                {"input0", cache}, {"input1", cur}
+                            }, {}, {{"axis", 1}});
+                        }
+                    }
+
+                    for (int b = 0; b < batch; b++) {
+                        std::string sb = "_" + std::to_string(b);
+                        Data *k = allDatas[op.datas.find("k")->second + sb];
+                        Data *v = allDatas[op.datas.find("v")->second + sb];
+                        Data *mask = allDatas[op.datas.find("mask")->second + sb];
+                        excutor.Run("Attention", {
+                            {"q", (Data*)&curQs[b]}, {"k", k}, {"v", v},
+                            {"mask", mask}, {"output", (Data*)&curOutputs[b]}
+                        }, {{"scale", op.floatParams.find("scale")->second}},
+                        {{"maskType", 0}});
+                    }
+
+                    for (int b = 0; b < batch; b++) {
+                        std::vector <int> axis = {1, 0, 2};
+                        Data axisData = Data(DataType::INT32PARAM, {(int)axis.size()});
+                        axisData.Allocate();
+                        for (int i = 0; i < axisData.Count(0); i++) {
+                            ((int32_t*)axisData.cpuData)[i] = axis[i];
+                        }
+                        Data *output = (Data*)&curOutputs[b];
+                        excutor.Run("PermuteSelf", {
+                            {"input", output}, {"axis", &axisData}
+                        }, {}, {});
+                        output->Reshape({seqLens[b], 1, -1});
+                        excutor.Run("PermuteSelf", {
+                            {"input", output}, {"axis", &axisData}
+                        }, {}, {});
+                    }
+
+                    auto lastOutput = allDatas[op.datas.find("output")->second];
+                    for (int b = 0; b < batch; b++) {
+                        Data *output = (Data*)&curOutputs[b];
+                        if (b == 0) {
+                            lastOutput->dataType = output->dataType;
+                            std::vector <int> dims = output->dims;
+                            dims[1] = 0;
+                            lastOutput->Resize(dims);
+                            dims[1] = total;
+                            lastOutput->Expansion(dims);
+                        }
+                        excutor.Run("CatDirect", {
+                            {"input0", lastOutput}, {"input1", output}
+                        }, {}, {{"axis", 1}});
+                    }
+                }
             }
} else if (op.type == "SplitLastTokenStates") {
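What the change does, in short: when every `seqLens[b]` equals 1 (a pure decode batch), the old path still issued `Split`, `PermuteSelf`, `CatDirect`, and `Attention` once per request. The new `all1` fast path replaces the per-request `Split`/`PermuteSelf` with zero-copy strided views built via `FakeFrom`, and collapses the per-request cache appends and attention calls into single `CatDirectBatch` and `AttentionBatch` calls, so the number of operator dispatches no longer grows with the batch size.

Below is a minimal sketch of the `FakeFrom` view trick, using a simplified stand-in `Tensor` struct rather than fastllm's real `Data` class; the struct, its field names, and the element (rather than byte) offsets are illustrative assumptions, not the library API:

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Simplified stand-in for fastllm's Data: dims/strides plus a
// non-owning pointer, so a "tensor" can alias another one's buffer.
struct Tensor {
    std::vector<int> dims;
    std::vector<uint64_t> strides;  // element strides, outermost first
    float *data = nullptr;          // non-owning; may point into a parent

    // Analogue of the commit's FakeFrom: become a view into `parent`
    // starting at offsetElems, keeping whatever dims/strides are set.
    // (The real call takes a byte offset, hence the unitSize factor
    // in the diff above.)
    void FakeFrom(Tensor &parent, uint64_t offsetElems) {
        data = parent.data + offsetElems;
    }
};

int main() {
    // Packed q for a decode batch: dims {batch, heads, headDim}.
    const int batch = 3, heads = 2, headDim = 4;
    std::vector<float> storage(batch * heads * headDim);
    for (size_t i = 0; i < storage.size(); i++) storage[i] = (float)i;

    Tensor q;
    q.dims = {batch, heads, headDim};
    q.strides = {(uint64_t)(heads * headDim), (uint64_t)headDim, 1};
    q.data = storage.data();

    // Per-request views shaped {heads, 1, headDim}, mirroring the
    // qdims/qstrides setup in the diff: no Split op, no copy, each
    // view just aliases request b's slice of the packed buffer.
    std::vector<Tensor> curQs(batch);
    for (int b = 0; b < batch; b++) {
        curQs[b].dims = {heads, 1, headDim};
        curQs[b].strides = {(uint64_t)headDim, (uint64_t)headDim, 1};
        curQs[b].FakeFrom(q, (uint64_t)b * q.strides[0]);
    }

    // Request 1's slice starts at offset heads * headDim = 8.
    printf("curQs[1] starts at value %.1f\n", curQs[1].data[0]);
    return 0;
}
```

The batching half of the change follows the same logic: each `excutor.Run(...)` dispatch carries fixed per-call overhead, which is likely to dominate when every request contributes a single token, so handing all `batch` cache appends to one `CatDirectBatch` call and all `batch` attention computations to one `AttentionBatch` call removes most of it. The `else` branch keeps the original per-request path for prefill and mixed-length batches, where `Split` along the token axis is still needed.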
