Merge pull request #379 from wangyumu/iluvatar-gpu-support
Support domestic GPU cards
ztxz16 authored Dec 26, 2023
2 parents 82a3035 + 734afd7 commit a03448d
Showing 3 changed files with 24 additions and 14 deletions.
10 changes: 10 additions & 0 deletions CMakeLists.txt
@@ -5,16 +5,21 @@ project(fastllm LANGUAGES CXX)
option(USE_CUDA "use cuda" OFF)

option(PY_API "python api" OFF)

option(USE_MMAP "use mmap" OFF)

option(USE_SENTENCEPIECE "use sentencepiece" OFF)

option(USE_IVCOREX "use iluvatar corex gpu" OFF)

message(STATUS "USE_CUDA: ${USE_CUDA}")

message(STATUS "PYTHON_API: ${PY_API}")

message(STATUS "USE_SENTENCEPIECE: ${USE_SENTENCEPIECE}")

message(STATUS "USE_IVCOREX: ${USE_IVCOREX}")

set(CMAKE_BUILD_TYPE "Release")

if (CMAKE_CXX_COMPILER_ID STREQUAL "Clang")
@@ -55,6 +60,11 @@ if (USE_CUDA)
set(CMAKE_CUDA_ARCHITECTURES "native")
endif()

if (USE_IVCOREX)
set(FASTLLM_LINKED_LIBS ${FASTLLM_LINKED_LIBS} cudart)
set(CMAKE_CUDA_ARCHITECTURES ${IVCOREX_ARCH})
endif()

if (PY_API)
set(PYBIND third_party/pybind11)
add_subdirectory(${PYBIND})
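The new USE_IVCOREX path reuses the CUDA build but links against Iluvatar's cudart and takes its target architecture from IVCOREX_ARCH, which is never declared as an option above and so must be supplied at configure time. A hypothetical configure step, assuming an Iluvatar CoreX toolchain is installed (the architecture value and the pairing with USE_CUDA=ON are assumptions, not taken from this commit):

cmake -DUSE_CUDA=ON -DUSE_IVCOREX=ON -DIVCOREX_ARCH=ivcore11 ..

Since the USE_IVCOREX block runs after the USE_CUDA block, its CMAKE_CUDA_ARCHITECTURES assignment overrides the "native" value set there.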
2 changes: 1 addition & 1 deletion include/fastllm.h
@@ -153,7 +153,7 @@ namespace fastllm {
uint8_t quantization(const float &realNumber) const {
if (type == 0) {
return (uint8_t) (std::min((double) ((1 << bit) - 1),
- std::max(realNumber / scale + zeroPoint + 0.5, 0.0)));
+ (double) std::max(realNumber / scale + zeroPoint + 0.5, 0.0)));
} else {
return (uint8_t) (std::max(0.f, std::min(15.f, (realNumber - min) / scale + 0.5f)));
}
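The only change here is the explicit (double) cast on the std::max argument. std::max is a single-type template, so both arguments must deduce to the same type. In standard C++ the 0.5 literal already promotes the left-hand expression to double, but a toolchain that treats floating literals as single precision (a plausible reason for this fix, not confirmed by the commit) would deduce float against the double 0.0 and reject the call. A minimal illustration of that failure mode, not code from this repository:

#include <algorithm>

double clampLow(float x) {
    // std::max(x, 0.0);              // ill-formed: T deduces to float and double
    return std::max((double) x, 0.0);  // OK: both arguments are double
}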
26 changes: 13 additions & 13 deletions src/devices/cuda/fastllm-cuda.cu
@@ -1156,7 +1156,7 @@ bool FastllmCudaMatMulFloatInt8(const fastllm::Data &input, fastllm::Data &weigh
cublasStatus_t status;

int len = n * m;
- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaFp16Input, len);

len = k * m;
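This and the remaining changes in this file qualify bare min calls as std::min. The unqualified form compiles under NVCC only because the CUDA headers inject a global-namespace ::min overload; it is not standard C++, and other CUDA-compatible toolchains (presumably Iluvatar's, given this PR) can reject it in host code. A portable sketch of the pattern, not code from this commit:

#include <algorithm>

int pickBlockSize(int len) {
    // std::min comes from <algorithm> on every conforming toolchain;
    // a bare min(256, len) relies on CUDA's non-standard global overload.
    return std::min(256, len);
}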
@@ -1300,7 +1300,7 @@ bool FastllmCudaMatMulFloatInt4NoZero(const fastllm::Data &input, fastllm::Data
cublasStatus_t status;

int len = n * m;
- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaFp16Input,
len);

@@ -1428,7 +1428,7 @@ bool FastllmCudaMatMulFloat16(const fastllm::Data &input, fastllm::Data &weight,
cublasStatus_t status;

int len = n * m;
- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmCudaFloat2HalfKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaFp16Input,
len);

@@ -1716,7 +1716,7 @@ bool FastllmCudaGeluNew(const fastllm::Data &input, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmGeluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
Expand All @@ -1727,7 +1727,7 @@ bool FastllmCudaSilu(const fastllm::Data &input, fastllm::Data &output) {
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmSiluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len);
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
@@ -1740,7 +1740,7 @@ bool FastllmCudaSwiglu(const fastllm::Data &input, fastllm::Data &output) {
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
int spatial = input.Count(input.dims.size() - 1), mid = spatial / 2;

- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmSwigluKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, len, spatial, mid);

FastllmCudaFinishInput(input, cudaInput);
@@ -1752,7 +1752,7 @@ bool FastllmCudaMul(const fastllm::Data &input, float v, fastllm::Data &output)
int len = input.Count(0);
float *cudaInput = (float *) FastllmCudaPrepareInput(input);
float *cudaOutput = (float *) FastllmCudaPrepareOutput(output);
- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmMulKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaInput, cudaOutput, v, len);
FastllmCudaFinishInput(input, cudaInput);
FastllmCudaFinishOutput(output, cudaOutput);
@@ -1764,7 +1764,7 @@ bool FastllmCudaAddTo(fastllm::Data &input0, const fastllm::Data &input1, float
float *cudaData = (float *) FastllmCudaPrepareInput(input0);
float *input1Data = (float *) FastllmCudaPrepareInput(input1);

- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmAddToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
FastllmCudaFinishInput(input1, input1Data);
FastllmCudaFinishOutput(input0, cudaData);
@@ -1776,7 +1776,7 @@ bool FastllmCudaMulTo(fastllm::Data &input0, const fastllm::Data &input1, float
float *cudaData = (float *) FastllmCudaPrepareInput(input0);
float *input1Data = (float *) FastllmCudaPrepareInput(input1);

- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmMulToKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock>>>(cudaData, input1Data, alpha, len);
FastllmCudaFinishInput(input1, input1Data);
FastllmCudaFinishOutput(input0, cudaData);
@@ -2005,7 +2005,7 @@ bool FastllmCudaPermute(fastllm::Data &input, const std::vector<int> &axis) {

int *cudaTemp = (int *) FastllmCudaMalloc(temp.size() * sizeof(int));
cudaMemcpy(cudaTemp, temp.data(), temp.size() * sizeof(int), cudaMemcpyHostToDevice);
- int threadPerBlock = min(256, len);
+ int threadPerBlock = std::min(256, len);
FastllmPermuteKernel <<< (len - 1) / threadPerBlock + 1, threadPerBlock >>>((float *) input.cudaData,
tempData, cudaTemp,
(int) axis.size(), len);
@@ -2228,7 +2228,7 @@ bool FastllmCudaRotatePosition2D(fastllm::Data &data, const fastllm::Data &posit
int spatial = data.Count(2);
int len = data.dims[0], bs = data.dims[1];
int n = data.dims[2], m = data.dims[3];
- FastllmRotatePosition2DKernel <<< outer * 2 * n, min(rotaryDim, m / 4) >>> (cudaData, cudaPositionIds, cudaSin, cudaCos,
+ FastllmRotatePosition2DKernel <<< outer * 2 * n, std::min(rotaryDim, m / 4) >>> (cudaData, cudaPositionIds, cudaSin, cudaCos,
len, bs, spatial, n, m,
(int)positionIds.dims.back(), (int)sinData.dims[1], rotaryDim);

@@ -2251,7 +2251,7 @@ bool FastllmCudaNearlyRotatePosition2D(fastllm::Data &data, const fastllm::Data
int spatial = data.Count(2);
int len = data.dims[0], bs = data.dims[1];
int n = data.dims[2], m = data.dims[3];
- FastllmNearlyRotatePosition2DKernel <<< outer * n, min(rotaryDim, m / 4) >>> (cudaData, cudaPositionIds, cudaSin, cudaCos,
+ FastllmNearlyRotatePosition2DKernel <<< outer * n, std::min(rotaryDim, m / 4) >>> (cudaData, cudaPositionIds, cudaSin, cudaCos,
len, bs, spatial, n, m,
(int)positionIds.dims.back(), (int)sinData.dims[1], rotaryDim);

@@ -2273,7 +2273,7 @@ bool FastllmCudaLlamaRotatePosition2D(fastllm::Data &data, const fastllm::Data &
int spatial = data.Count(2);
int bs = data.dims[0], len = data.dims[1];
int n = data.dims[2], m = data.dims[3];
- FastllmLlamaRotatePosition2DKernel <<< outer * n, min(rotaryDim, m / 2) >>> (cudaData, cudaPositionIds, cudaSin, cudaCos,
+ FastllmLlamaRotatePosition2DKernel <<< outer * n, std::min(rotaryDim, m / 2) >>> (cudaData, cudaPositionIds, cudaSin, cudaCos,
len, bs, spatial, n, m,
(int)positionIds.dims.back(), (int)sinData.dims[1], rotaryDim);

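Every launch in this file sizes its grid as (len - 1) / threadPerBlock + 1, the ceiling division that makes the blocks cover all len elements, so each kernel must bounds-check the overhang in the last block. A minimal sketch of the pattern with a hypothetical kernel, not code from this commit:

__global__ void scaleKernel(float *data, float v, int len) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < len) {   // the last block may run past len; guard it
        data[i] *= v;
    }
}

// int threadPerBlock = std::min(256, len);
// scaleKernel<<<(len - 1) / threadPerBlock + 1, threadPerBlock>>>(data, v, len);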
