From 163709ea970a294c8f9d8e11ff11c65967a02c92 Mon Sep 17 00:00:00 2001 From: Ravindra Marella Date: Tue, 29 Aug 2023 05:06:24 +0530 Subject: [PATCH] Add ROCm support --- CMakeLists.txt | 40 +++++++++++++++++++++++++++++++++++++++- README.md | 8 ++++++++ setup.py | 2 +- 3 files changed, 48 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 307476f..d7e3364 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,16 +5,18 @@ set(CT_INSTRUCTIONS "avx2" CACHE STRING "avx2 | avx | basic") option(CT_CUBLAS "Use cuBLAS" OFF) option(CT_CUDA_FORCE_DMMV "use dmmv instead of mmvq CUDA kernels" OFF) -option(CT_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) +option(CT_CUDA_F16 "use 16 bit floats for some calculations" OFF) set(CT_CUDA_MMQ_Y "64" CACHE STRING "y tile size for mmq CUDA kernels") set(CT_CUDA_DMMV_X "32" CACHE STRING "x stride for dmmv CUDA kernels") set(CT_CUDA_MMV_Y "1" CACHE STRING "y block size for mmv CUDA kernels") set(CT_CUDA_KQUANTS_ITER "2" CACHE STRING "iters/thread per block for Q2_K/Q6_K") +option(CT_HIPBLAS "Use hipBLAS" OFF) option(CT_METAL "Use Metal" OFF) message(STATUS "CT_INSTRUCTIONS: ${CT_INSTRUCTIONS}") message(STATUS "CT_CUBLAS: ${CT_CUBLAS}") +message(STATUS "CT_HIPBLAS: ${CT_HIPBLAS}") message(STATUS "CT_METAL: ${CT_METAL}") set(BUILD_SHARED_LIBS ON) @@ -162,6 +164,42 @@ if (CT_CUBLAS) endif() endif() +if (CT_HIPBLAS) + list(APPEND CMAKE_PREFIX_PATH /opt/rocm) + + if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang") + endif() + if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang") + message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++") + endif() + + find_package(hip) + find_package(hipblas) + find_package(rocblas) + + if (${hipblas_FOUND} AND ${hip_FOUND}) + message(STATUS "HIP and hipBLAS found") + set_source_files_properties(models/ggml/ggml-cuda.cu PROPERTIES LANGUAGE CXX) + + target_sources(ctransformers PRIVATE models/ggml/ggml-cuda.cu) + target_link_libraries(ctransformers PRIVATE hip::device hip::host roc::rocblas roc::hipblas) + + target_compile_definitions(ctransformers PRIVATE GGML_USE_HIPBLAS) + target_compile_definitions(ctransformers PRIVATE GGML_USE_CUBLAS) + target_compile_definitions(ctransformers PRIVATE GGML_CUDA_MMQ_Y=${CT_CUDA_MMQ_Y}) + target_compile_definitions(ctransformers PRIVATE GGML_CUDA_DMMV_X=${CT_CUDA_DMMV_X}) + target_compile_definitions(ctransformers PRIVATE GGML_CUDA_MMV_Y=${CT_CUDA_MMV_Y}) + target_compile_definitions(ctransformers PRIVATE K_QUANTS_PER_ITERATION=${CT_CUDA_KQUANTS_ITER}) + target_compile_definitions(ctransformers PRIVATE CC_TURING=1000000000) + if (CT_CUDA_FORCE_DMMV) + target_compile_definitions(ctransformers PRIVATE GGML_CUDA_FORCE_DMMV) + endif() + else() + message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm") + endif() +endif() + if (CT_METAL) find_library(FOUNDATION_LIBRARY Foundation REQUIRED) find_library(METAL_FRAMEWORK Metal REQUIRED) diff --git a/README.md b/README.md index 5d9e1b2..3754a55 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,14 @@ Install CUDA libraries using: pip install ctransformers[cuda] ``` +#### ROCm + +To enable ROCm support, install the `ctransformers` package using: + +```sh +CT_HIPBLAS=1 pip install ctransformers --no-binary ctransformers +``` + #### Metal To enable Metal support, install the `ctransformers` package using: diff --git a/setup.py b/setup.py index e495bb3..ad5726c 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,7 @@ from skbuild import setup cmake_args = [] - for key in ["CT_INSTRUCTIONS", "CT_CUBLAS", "CT_METAL"]: + for key in ["CT_INSTRUCTIONS", "CT_CUBLAS", "CT_HIPBLAS", "CT_METAL"]: value = os.environ.get(key) if value: cmake_args.append(f"-D{key}={value}")