Skip to content

Commit

Permalink
add cudacat
Browse files Browse the repository at this point in the history
  • Loading branch information
黄宇扬 committed Oct 24, 2024
1 parent 2823624 commit e6a5833
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 0 deletions.
4 changes: 4 additions & 0 deletions include/devices/cuda/cudadevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ namespace fastllm {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaCatOp : CpuCatOp {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};

class CudaCatDirectOp : BaseOperator {
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams);
};
Expand Down
37 changes: 37 additions & 0 deletions src/devices/cuda/cudadevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ namespace fastllm {
this->ops["Linear"] = (BaseOperator*)(new CudaLinearOp());
this->ops["Conv2D"] = (BaseOperator*)(new CudaConv2DOp());
this->ops["Split"] = (BaseOperator*)(new CudaSplitOp());
this->ops["Cat"] = (BaseOperator*)(new CudaCatOp());
this->ops["CatDirect"] = (BaseOperator*)(new CudaCatDirectOp());
this->ops["MatMul"] = (BaseOperator*)(new CudaMatMulOp());
this->ops["MatMulTransB"] = (BaseOperator*)(new CudaMatMulTransBOp());
Expand Down Expand Up @@ -400,6 +401,42 @@ namespace fastllm {
(end - start) * inner * unitSize, outer);
}

void CudaCatOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input0 = *(datas.find("input0")->second);
Data &input1 = *(datas.find("input1")->second);
Data &output = *(datas.find("output")->second);

output.Allocate();

int axis = intParams.find("axis") != intParams.end() ? intParams.find("axis")->second : -1;
if (input0.dims.size() == 0 && input1.dims.size() > 0) {
output.CopyFrom(input1);
return;
}
if (input1.dims.size() == 0 && input0.dims.size() > 0) {
output.CopyFrom(input0);
return;
}

int dimsLen = input0.dims.size();
axis = (axis % dimsLen + dimsLen) % dimsLen;

int outer = output.Count(0) / output.Count(axis);
int input0Stride = input0.Count(axis);
int input1Stride = input1.Count(axis);
int outputStride = output.Count(axis);
int inner = input0.strides[axis];
int unitSize = input0.unitSize;

FastllmCudaMemcpy2DDeviceToDevice((uint8_t *) output.cudaData, outputStride * unitSize,
(uint8_t *) input0.cudaData, input0Stride * unitSize,
input0.dims[axis] * inner * unitSize, outer);
FastllmCudaMemcpy2DDeviceToDevice((uint8_t *) output.cudaData + input0.dims[axis] * inner * unitSize, outputStride * unitSize,
(uint8_t *) input1.cudaData, input1Stride * unitSize,
input1.dims[axis] * inner * unitSize, outer);
}

void CudaCatDirectOp::Run(const std::string &opType, const fastllm::DataDict &datas,
const fastllm::FloatDict &floatParams, const fastllm::IntDict &intParams) {
Data &input0 = *(datas.find("input0")->second);
Expand Down

0 comments on commit e6a5833

Please sign in to comment.