-
Notifications
You must be signed in to change notification settings - Fork 347
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
578 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
// | ||
// Created by huangyuyang on 8/2/24. | ||
// | ||
|
||
#include "fastllm.h" | ||
|
||
std::vector <long long> FastllmCudaGetFreeSizes(); | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
void FastllmMultiCudaSetDevice(std::vector <int> ids); | ||
|
||
bool FastllmMultiCudaHalfMatMul(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k); | ||
bool FastllmMultiCudaMatMul(const fastllm::Data &input, fastllm::Data &weight, const fastllm::Data &bias, fastllm::Data &output, int n, int m, int k); | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
// | ||
// Created by huangyuyang on 8/2/24. | ||
// | ||
|
||
#ifndef FASTLLM_MULTICUDADEVICE_H | ||
#define FASTLLM_MULTICUDADEVICE_H | ||
|
||
#include "device.h" | ||
|
||
namespace fastllm { | ||
class MultiCudaDevice : BaseDevice { | ||
public: | ||
MultiCudaDevice (); | ||
|
||
bool Malloc (void **ret, size_t size); // 分配尺寸为size的空间 | ||
bool Free(void *ret); // 释放ret | ||
|
||
bool CopyDataToCPU(void *dst, void *src, size_t size); | ||
bool CopyDataFromCPU(void *dst, void *src, size_t size); | ||
}; | ||
|
||
class MultiCudaLinearOp : CudaLinearOp { | ||
void Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams); | ||
}; | ||
} | ||
|
||
#endif //FASTLLM_MULTICUDADEVICE_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
// | ||
// Created by huangyuyang on 8/2/24. | ||
// | ||
|
||
#include "devices/cpu/cpudevice.h" | ||
#include "devices/cuda/cudadevice.h" | ||
#include "devices/cuda/fastllm-cuda.cuh" | ||
#include "devices/multicuda/multicudadevice.h" | ||
|
||
#include "fastllm-multicuda.cuh" | ||
|
||
#include "utils.h" | ||
|
||
namespace fastllm { | ||
MultiCudaDevice::MultiCudaDevice() { | ||
this->deviceType = "multicuda"; | ||
|
||
this->ops["Linear"] = (BaseOperator*)(new MultiCudaLinearOp()); | ||
} | ||
|
||
bool MultiCudaDevice::Malloc(void **ret, size_t size) { | ||
*ret = FastllmCudaMalloc(size); | ||
return true; | ||
} | ||
|
||
bool MultiCudaDevice::Free(void *ret) { | ||
FastllmCudaFree(ret); | ||
return true; | ||
} | ||
|
||
bool MultiCudaDevice::CopyDataFromCPU(void *dst, void *src, size_t size) { | ||
FastllmCudaCopyFromHostToDevice(dst, src, size); | ||
return true; | ||
} | ||
|
||
bool MultiCudaDevice::CopyDataToCPU(void *dst, void *src, size_t size) { | ||
FastllmCudaCopyFromDeviceToHost(dst, src, size); | ||
return true; | ||
} | ||
|
||
void MultiCudaLinearOp::Run(const std::string &opType, const DataDict &datas, const FloatDict &floatParams, const IntDict &intParams) { | ||
// auto st = std::chrono::system_clock::now(); | ||
Data &input = *(datas.find("input")->second); | ||
Data &output = *(datas.find("output")->second); | ||
Data &weight = *(datas.find("weight")->second); | ||
Data &bias = *(datas.find("bias")->second); | ||
|
||
output.Allocate(); | ||
int n = input.Count(0) / input.dims.back(); | ||
int m = input.dims.back(); | ||
int k = output.dims.back(); | ||
|
||
if (input.dataType == DataType::FLOAT16) { | ||
if (weight.dataType == DataType::FLOAT16 || | ||
weight.dataType == DataType::INT8 || | ||
weight.dataType == DataType::INT4_NOZERO || | ||
weight.dataType == DataType::INT4_GROUP) { | ||
FastllmMultiCudaHalfMatMul(input, weight, bias, output, n, m, k); | ||
} else { | ||
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n"); | ||
} | ||
} else if (input.dataType == DataType::FLOAT32) { | ||
if (weight.dataType == DataType::FLOAT32) { | ||
FastllmCudaMatMulFloat32(input, weight, bias, output, n, m, k); | ||
} else if (weight.dataType == DataType::FLOAT16 || | ||
weight.dataType == DataType::INT8 || | ||
weight.dataType == DataType::INT4_NOZERO || | ||
weight.dataType == DataType::INT4_GROUP) { | ||
FastllmMultiCudaMatMul(input, weight, bias, output, n, m, k); | ||
} else if (weight.dataType == DataType::INT4) { | ||
FastllmCudaMatMulFloatInt4(input, weight, bias, output, n, m, k); | ||
} else { | ||
ErrorInFastLLM("Linear error: unsupport weight's dataType.\n"); | ||
} | ||
} else { | ||
ErrorInFastLLM("Linear error: unsupport input's dataType.\n"); | ||
} | ||
// float spend = GetSpan(st, std::chrono::system_clock::now()); | ||
// float gops = (float)n * m * k / spend / 1e9; | ||
// printf("n = %d, m = %d, k = %d, spend %f s, gops = %f\n", n, m, k, spend, gops); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters