Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

upgrade to cuda9(Volta) GPU arch #118

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 51 additions & 34 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,36 @@ ENDIF()

project(ctc_release)

IF (NOT APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp -O2")
ENDIF()

IF (APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2")
add_definitions(-DAPPLE)
ENDIF()

include_directories(include)

FIND_PACKAGE(CUDA 6.5)
FIND_PACKAGE(Torch)

MESSAGE(STATUS "cuda found ${CUDA_FOUND}")
MESSAGE(STATUS "Torch found ${Torch_DIR}")

option(WITH_GPU "compile warp-ctc with cuda." ${CUDA_FOUND})
option(WITH_OMP "compile warp-ctc with openmp." ON)
option(WITH_GPU "compile warp-ctc with CUDA." ${CUDA_FOUND})
option(WITH_TORCH "compile warp-ctc with Torch." ${Torch_FOUND})
option(WITH_OMP "compile warp-ctc with OpenMP." ON)
option(BUILD_TESTS "build warp-ctc unit tests." ON)
option(BUILD_SHARED "build warp-ctc shared library." ON)

if(BUILD_SHARED)
set(WARPCTC_SHARED "SHARED")
else(BUILD_SHARED)
set(WARPCTC_SHARED "STATIC")
endif(BUILD_SHARED)

# Set c++ flags
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2")
if(APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11")
add_definitions(-DAPPLE)
endif()

if(NOT WITH_OMP)
if(WITH_OMP AND NOT APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp")
else()
add_definitions(-DCTC_DISABLE_OMP)
endif()

Expand All @@ -34,20 +46,22 @@ set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52")

IF (CUDA_VERSION GREATER 7.6)
IF (CUDA_VERSION VERSION_GREATER "7.6")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_62,code=sm_62")
ENDIF()

if (NOT APPLE)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --std=c++11")
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fopenmp")
IF ((CUDA_VERSION VERSION_GREATER "9.0") OR (CUDA_VERSION VERSION_EQUAL "9.0"))
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70")
ENDIF()

FIND_PACKAGE(Torch)

MESSAGE(STATUS "Torch found ${Torch_DIR}")
IF(NOT APPLE)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --std=c++11")
if(WITH_OMP)
set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fopenmp")
endif()
ENDIF()

IF (APPLE)
EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION)
Expand All @@ -69,18 +83,21 @@ ENDIF()
IF (WITH_GPU)

MESSAGE(STATUS "Building shared library with GPU support")
MESSAGE(STATUS "NVCC_ARCH_FLAGS" ${CUDA_NVCC_FLAGS})

CUDA_ADD_LIBRARY(warpctc SHARED src/ctc_entrypoint.cu src/reduce.cu)
IF (!Torch_FOUND)
CUDA_ADD_LIBRARY(warpctc ${WARPCTC_SHARED} src/ctc_entrypoint.cu src/reduce.cu)
IF (!WITH_TORCH)
TARGET_LINK_LIBRARIES(warpctc ${CUDA_curand_LIBRARY})
ENDIF()

add_executable(test_cpu tests/test_cpu.cpp )
TARGET_LINK_LIBRARIES(test_cpu warpctc)
SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11")
if(BUILD_TESTS)
add_executable(test_cpu tests/test_cpu.cpp )
TARGET_LINK_LIBRARIES(test_cpu warpctc)
SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11")

cuda_add_executable(test_gpu tests/test_gpu.cu)
TARGET_LINK_LIBRARIES(test_gpu warpctc ${CUDA_curand_LIBRARY})
cuda_add_executable(test_gpu tests/test_gpu.cu)
TARGET_LINK_LIBRARIES(test_gpu warpctc ${CUDA_curand_LIBRARY})
endif(BUILD_TESTS)

INSTALL(TARGETS warpctc
RUNTIME DESTINATION "bin"
Expand All @@ -89,7 +106,7 @@ IF (WITH_GPU)

INSTALL(FILES include/ctc.h DESTINATION "include")

IF (Torch_FOUND)
IF (WITH_TORCH)
MESSAGE(STATUS "Building Torch Bindings with GPU support")
INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS} "${CUDA_TOOLKIT_ROOT_DIR}/samples/common/inc")
INCLUDE_DIRECTORIES(${Torch_INSTALL_INCLUDE} ${Torch_INSTALL_INCLUDE}/TH ${Torch_INSTALL_INCLUDE}/THC)
Expand All @@ -105,26 +122,26 @@ IF (WITH_GPU)

ADD_TORCH_PACKAGE(warp_ctc "${src}" "${luasrc}")
IF (APPLE)

TARGET_LINK_LIBRARIES(warp_ctc warpctc luajit luaT THC TH ${CUDA_curand_LIBRARY})
ELSE()
TARGET_LINK_LIBRARIES(warp_ctc warpctc luajit luaT THC TH ${CUDA_curand_LIBRARY} gomp)
ENDIF()
ENDIF()


ELSE()
MESSAGE(STATUS "Building shared library with no GPU support")

if (NOT APPLE)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2")
ENDIF()

ADD_LIBRARY(warpctc SHARED src/ctc_entrypoint.cpp)
ADD_LIBRARY(warpctc ${WARPCTC_SHARED} src/ctc_entrypoint.cpp)

add_executable(test_cpu tests/test_cpu.cpp )
TARGET_LINK_LIBRARIES(test_cpu warpctc)
SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11")
if(BUILD_TESTS)
add_executable(test_cpu tests/test_cpu.cpp )
TARGET_LINK_LIBRARIES(test_cpu warpctc)
SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11")
endif(BUILD_TESTS)

INSTALL(TARGETS warpctc
RUNTIME DESTINATION "bin"
Expand All @@ -133,7 +150,7 @@ ELSE()

INSTALL(FILES include/ctc.h DESTINATION "include")

IF (Torch_FOUND)
IF (WITH_TORCH)
MESSAGE(STATUS "Building Torch Bindings with no GPU support")
add_definitions(-DTORCH_NOGPU)
INCLUDE_DIRECTORIES(${Torch_INSTALL_INCLUDE} ${Torch_INSTALL_INCLUDE}/TH)
Expand Down
6 changes: 5 additions & 1 deletion include/detail/cpu_ctc.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ template<typename ProbT>
void
CpuCTC<ProbT>::softmax(const ProbT* const activations, ProbT* probs,
const int* const input_lengths) {
ProbT min_T = std::numeric_limits<ProbT>::min();

#pragma omp parallel for
for (int mb = 0; mb < minibatch_; ++mb) {
for(int c = 0; c < input_lengths[mb]; ++c) {
Expand All @@ -179,6 +181,9 @@ CpuCTC<ProbT>::softmax(const ProbT* const activations, ProbT* probs,

for(int r = 0; r < alphabet_size_; ++r) {
probs[r + col_offset] /= denom;
if (probs[r + col_offset] < min_T) {
probs[r + col_offset] = min_T;
}
}
}
}
Expand Down Expand Up @@ -226,7 +231,6 @@ ProbT CpuCTC<ProbT>::compute_alphas(const ProbT* probs, int repeats, int S, int
const int* const s_inc,
const int* const labels,
ProbT* alphas) {

int start = (((S /2) + repeats - T) < 0) ? 0 : 1,
end = S > 1 ? 2 : 1;

Expand Down
3 changes: 3 additions & 0 deletions include/detail/gpu_ctc.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,9 @@ GpuCTC<ProbT>::compute_probs(const ProbT* const activations) {
(ctc_helper::exponential<ProbT>(), probs_,
denoms_, out_dim_, num_elements);

truncate_probs_kernel<ProbT, VT><<<grid_size, NT, 0, stream_>>>
(probs_, num_elements);

return CTC_STATUS_SUCCESS;
}

Expand Down
21 changes: 19 additions & 2 deletions include/detail/gpu_ctc_kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ template<typename ProbT, int NT, int VT>
__global__
void compute_alpha_kernel (const ProbT* probs, const int *label_sizes,
const int *utt_length, const int *repeats_in_labels,
const int *labels_without_blanks, const int *label_offsets,
int *labels_with_blanks, ProbT *alphas,
const int *labels_without_blanks, const int *label_offsets,
int *labels_with_blanks, ProbT *alphas,
ProbT* nll_forward, int stride, int out_dim,
int S_memoffset, int T_memoffset, int blank_label) {

Expand Down Expand Up @@ -469,6 +469,23 @@ __global__ void compute_probs_kernel(Op f, ProbT* probs,
}
}

template <typename ProbT, int VT = 1>
__global__ void truncate_probs_kernel(ProbT* probs, int count) {

int idx = blockDim.x * blockIdx.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
ProbT min_T = numeric_limits<ProbT>::min();
#pragma unroll
for(int i = 0; i < VT; i++) {
if (idx < count) {
if (min_T > probs[idx]) {
probs[idx] = min_T;
}
}
idx += stride;
}
}

template <typename ProbT, int VT = 1, typename Op>
__global__ void prepare_stable_SM_kernel(Op f, ProbT* probs,
const ProbT* const col_max,
Expand Down
17 changes: 17 additions & 0 deletions include/detail/hostdevice.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,20 @@
#else
#define HOSTDEVICE
#endif

// NOTE(dzhwinter)
// the warp primitive is different in cuda9(Volta) GPU.
// add a wrapper to compatible with cuda7 to cuda9
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900
#define DEFAULT_MASK 0u
template<typename T>
__forceinline__ __device__ T __shfl_down(T input, int delta) {
return __shfl_down_sync(DEFAULT_MASK, input, delta);
}

template<typename T>
__forceinline__ __device__ T __shfl_up(T input, int delta) {
return __shfl_up_sync(DEFAULT_MASK, input, delta);
}

#endif
1 change: 0 additions & 1 deletion src/ctc_entrypoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ ctcStatus_t compute_ctc_loss(const float* const activations,
float *costs,
void *workspace,
ctcOptions options) {

if (activations == nullptr ||
flat_labels == nullptr ||
label_lengths == nullptr ||
Expand Down