Skip to content

Commit

Permalink
Add ROCm support
Browse files Browse the repository at this point in the history
  • Loading branch information
marella committed Aug 28, 2023
1 parent a71db2c commit 163709e
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 2 deletions.
40 changes: 39 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,18 @@ set(CT_INSTRUCTIONS "avx2" CACHE STRING "avx2 | avx | basic")

# Backend switches and CUDA kernel tuning knobs. Every switch defaults to OFF.
option(CT_CUBLAS "Use cuBLAS" OFF)
option(CT_CUDA_FORCE_DMMV "use dmmv instead of mmvq CUDA kernels" OFF)
# NOTE(review): CT_CUDA_F16 is declared twice with different help text. This looks
# like the removed and the added line of a diff shown together - confirm that only
# one declaration exists in the real file.
option(CT_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF)
option(CT_CUDA_F16 "use 16 bit floats for some calculations" OFF)
# Numeric tuning values are cache strings; the HIP backend block forwards them to
# the compiler as GGML_CUDA_MMQ_Y, GGML_CUDA_DMMV_X, GGML_CUDA_MMV_Y and
# K_QUANTS_PER_ITERATION definitions.
set(CT_CUDA_MMQ_Y "64" CACHE STRING "y tile size for mmq CUDA kernels")
set(CT_CUDA_DMMV_X "32" CACHE STRING "x stride for dmmv CUDA kernels")
set(CT_CUDA_MMV_Y "1" CACHE STRING "y block size for mmv CUDA kernels")
set(CT_CUDA_KQUANTS_ITER "2" CACHE STRING "iters/thread per block for Q2_K/Q6_K")

# Non-CUDA accelerator backends.
option(CT_HIPBLAS "Use hipBLAS" OFF)
option(CT_METAL "Use Metal" OFF)

# Print the selected configuration at configure time.
message(STATUS "CT_INSTRUCTIONS: ${CT_INSTRUCTIONS}")
message(STATUS "CT_CUBLAS: ${CT_CUBLAS}")
message(STATUS "CT_HIPBLAS: ${CT_HIPBLAS}")
message(STATUS "CT_METAL: ${CT_METAL}")

set(BUILD_SHARED_LIBS ON)
Expand Down Expand Up @@ -162,6 +164,42 @@ if (CT_CUBLAS)
endif()
endif()

# ROCm / hipBLAS backend.
#
# When CT_HIPBLAS is ON, locate HIP, hipBLAS and rocBLAS and, if all three are
# available, compile models/ggml/ggml-cuda.cu into the ctransformers target as
# C++ and link it against the ROCm libraries. If any package is missing, only a
# warning is emitted and the target is left without the GPU backend.
if (CT_HIPBLAS)
    # Appended (not prepended), so prefixes supplied by the user stay ahead of
    # the default ROCm location in the search list.
    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)

    # Pass the variable NAME to if() rather than expanding it with ${...}: an
    # empty compiler ID would otherwise leave a malformed condition, and a
    # non-empty one could be dereferenced a second time as a variable name.
    if (NOT CMAKE_C_COMPILER_ID MATCHES "Clang")
        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
    endif()
    if (NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang")
        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
    endif()

    # Deliberately not REQUIRED: the not-found case is handled below.
    find_package(hip)
    find_package(hipblas)
    find_package(rocblas)

    # rocBLAS is checked as well because roc::rocblas is linked below; without
    # the check a missing rocBLAS becomes a hard "target not found" error
    # instead of the warning in the else() branch.
    if (hip_FOUND AND hipblas_FOUND AND rocblas_FOUND)
        message(STATUS "HIP and hipBLAS found")

        # The CUDA source is built by the C++ compiler in the HIP configuration.
        set_source_files_properties(models/ggml/ggml-cuda.cu PROPERTIES LANGUAGE CXX)
        target_sources(ctransformers PRIVATE models/ggml/ggml-cuda.cu)
        target_link_libraries(ctransformers PRIVATE
            hip::device
            hip::host
            roc::rocblas
            roc::hipblas
        )

        # GGML_USE_CUBLAS is defined alongside GGML_USE_HIPBLAS; presumably the
        # HIP build reuses the cuBLAS code path in ggml - confirm in ggml-cuda.cu.
        # NOTE(review): CC_TURING is set to a very large value, presumably so
        # Turing-specific kernels are never selected on AMD hardware - confirm.
        target_compile_definitions(ctransformers PRIVATE
            GGML_USE_HIPBLAS
            GGML_USE_CUBLAS
            GGML_CUDA_MMQ_Y=${CT_CUDA_MMQ_Y}
            GGML_CUDA_DMMV_X=${CT_CUDA_DMMV_X}
            GGML_CUDA_MMV_Y=${CT_CUDA_MMV_Y}
            K_QUANTS_PER_ITERATION=${CT_CUDA_KQUANTS_ITER}
            CC_TURING=1000000000
        )
        if (CT_CUDA_FORCE_DMMV)
            target_compile_definitions(ctransformers PRIVATE GGML_CUDA_FORCE_DMMV)
        endif()
    else()
        message(WARNING "hipBLAS, rocBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
    endif()
endif()

if (CT_METAL)
find_library(FOUNDATION_LIBRARY Foundation REQUIRED)
find_library(METAL_FRAMEWORK Metal REQUIRED)
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,14 @@ Install CUDA libraries using:
pip install ctransformers[cuda]
```

#### ROCm

To enable ROCm support, install the `ctransformers` package using:

```sh
CT_HIPBLAS=1 pip install ctransformers --no-binary ctransformers
```

#### Metal

To enable Metal support, install the `ctransformers` package using:
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from skbuild import setup

cmake_args = []
for key in ["CT_INSTRUCTIONS", "CT_CUBLAS", "CT_METAL"]:
for key in ["CT_INSTRUCTIONS", "CT_CUBLAS", "CT_HIPBLAS", "CT_METAL"]:
value = os.environ.get(key)
if value:
cmake_args.append(f"-D{key}={value}")
Expand Down

0 comments on commit 163709e

Please sign in to comment.