From c7741e3bfd07ae1097fcfcb163994440da11dc43 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Fri, 27 Apr 2018 20:47:12 -0700 Subject: [PATCH] cherry picked commits --- CMakeLists.txt | 85 +++++++++++++++++++------------- include/detail/cpu_ctc.h | 6 ++- include/detail/gpu_ctc.h | 3 ++ include/detail/gpu_ctc_kernels.h | 21 +++++++- include/detail/hostdevice.h | 17 +++++++ src/ctc_entrypoint.cpp | 1 - 6 files changed, 95 insertions(+), 38 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cdb4b3e..cf582d7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,24 +6,36 @@ ENDIF() project(ctc_release) -IF (NOT APPLE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp -O2") -ENDIF() - -IF (APPLE) - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2") - add_definitions(-DAPPLE) -ENDIF() - include_directories(include) FIND_PACKAGE(CUDA 6.5) +FIND_PACKAGE(Torch) + MESSAGE(STATUS "cuda found ${CUDA_FOUND}") +MESSAGE(STATUS "Torch found ${Torch_DIR}") -option(WITH_GPU "compile warp-ctc with cuda." ${CUDA_FOUND}) -option(WITH_OMP "compile warp-ctc with openmp." ON) +option(WITH_GPU "compile warp-ctc with CUDA." ${CUDA_FOUND}) +option(WITH_TORCH "compile warp-ctc with Torch." ${Torch_FOUND}) +option(WITH_OMP "compile warp-ctc with OpenMP." ON) +option(BUILD_TESTS "build warp-ctc unit tests." ON) +option(BUILD_SHARED "build warp-ctc shared library." ON) + +if(BUILD_SHARED) + set(WARPCTC_SHARED "SHARED") +else(BUILD_SHARED) + set(WARPCTC_SHARED "STATIC") +endif(BUILD_SHARED) + +# Set c++ flags +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O2") +if(APPLE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11") + add_definitions(-DAPPLE) +endif() -if(NOT WITH_OMP) +if(WITH_OMP AND NOT APPLE) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fopenmp") +else() add_definitions(-DCTC_DISABLE_OMP) endif() @@ -34,20 +46,22 @@ set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_35,code=sm_35") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_50,code=sm_50") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_52,code=sm_52") -IF (CUDA_VERSION GREATER 7.6) +IF (CUDA_VERSION VERSION_GREATER "7.6") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_60,code=sm_60") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_61,code=sm_61") set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_62,code=sm_62") ENDIF() -if (NOT APPLE) - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --std=c++11") - set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fopenmp") +IF ((CUDA_VERSION VERSION_GREATER "9.0") OR (CUDA_VERSION VERSION_EQUAL "9.0")) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -gencode arch=compute_70,code=sm_70") ENDIF() -FIND_PACKAGE(Torch) - -MESSAGE(STATUS "Torch found ${Torch_DIR}") +IF(NOT APPLE) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} --std=c++11") + if(WITH_OMP) + set(CUDA_NVCC_FLAGS "${CUDA_NVCC_FLAGS} -Xcompiler -fopenmp") + endif() +ENDIF() IF (APPLE) EXEC_PROGRAM(uname ARGS -v OUTPUT_VARIABLE DARWIN_VERSION) @@ -69,18 +83,21 @@ ENDIF() IF (WITH_GPU) MESSAGE(STATUS "Building shared library with GPU support") + MESSAGE(STATUS "NVCC_ARCH_FLAGS" ${CUDA_NVCC_FLAGS}) - CUDA_ADD_LIBRARY(warpctc SHARED src/ctc_entrypoint.cu src/reduce.cu) - IF (!Torch_FOUND) + CUDA_ADD_LIBRARY(warpctc ${WARPCTC_SHARED} src/ctc_entrypoint.cu src/reduce.cu) + IF (!WITH_TORCH) TARGET_LINK_LIBRARIES(warpctc ${CUDA_curand_LIBRARY}) ENDIF() - add_executable(test_cpu tests/test_cpu.cpp ) - TARGET_LINK_LIBRARIES(test_cpu warpctc) - SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") + if(BUILD_TESTS) + add_executable(test_cpu tests/test_cpu.cpp ) + TARGET_LINK_LIBRARIES(test_cpu warpctc) + SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") - cuda_add_executable(test_gpu tests/test_gpu.cu) - TARGET_LINK_LIBRARIES(test_gpu warpctc ${CUDA_curand_LIBRARY}) + cuda_add_executable(test_gpu tests/test_gpu.cu) + TARGET_LINK_LIBRARIES(test_gpu warpctc ${CUDA_curand_LIBRARY}) + endif(BUILD_TESTS) INSTALL(TARGETS warpctc RUNTIME DESTINATION "bin" @@ -89,7 +106,7 @@ IF (WITH_GPU) INSTALL(FILES include/ctc.h DESTINATION "include") - IF (Torch_FOUND) + IF (WITH_TORCH) MESSAGE(STATUS "Building Torch Bindings with GPU support") INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS} "${CUDA_TOOLKIT_ROOT_DIR}/samples/common/inc") INCLUDE_DIRECTORIES(${Torch_INSTALL_INCLUDE} ${Torch_INSTALL_INCLUDE}/TH ${Torch_INSTALL_INCLUDE}/THC) @@ -105,14 +122,12 @@ IF (WITH_GPU) ADD_TORCH_PACKAGE(warp_ctc "${src}" "${luasrc}") IF (APPLE) - TARGET_LINK_LIBRARIES(warp_ctc warpctc luajit luaT THC TH ${CUDA_curand_LIBRARY}) ELSE() TARGET_LINK_LIBRARIES(warp_ctc warpctc luajit luaT THC TH ${CUDA_curand_LIBRARY} gomp) ENDIF() ENDIF() - ELSE() MESSAGE(STATUS "Building shared library with no GPU support") @@ -120,11 +135,13 @@ ELSE() set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -O2") ENDIF() - ADD_LIBRARY(warpctc SHARED src/ctc_entrypoint.cpp) + ADD_LIBRARY(warpctc ${WARPCTC_SHARED} src/ctc_entrypoint.cpp) - add_executable(test_cpu tests/test_cpu.cpp ) - TARGET_LINK_LIBRARIES(test_cpu warpctc) - SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") + if(BUILD_TESTS) + add_executable(test_cpu tests/test_cpu.cpp ) + TARGET_LINK_LIBRARIES(test_cpu warpctc) + SET_TARGET_PROPERTIES(test_cpu PROPERTIES COMPILE_FLAGS "${CMAKE_CXX_FLAGS} --std=c++11") + endif(BUILD_TESTS) INSTALL(TARGETS warpctc RUNTIME DESTINATION "bin" @@ -133,7 +150,7 @@ ELSE() INSTALL(FILES include/ctc.h DESTINATION "include") - IF (Torch_FOUND) + IF (WITH_TORCH) MESSAGE(STATUS "Building Torch Bindings with no GPU support") add_definitions(-DTORCH_NOGPU) INCLUDE_DIRECTORIES(${Torch_INSTALL_INCLUDE} ${Torch_INSTALL_INCLUDE}/TH) diff --git a/include/detail/cpu_ctc.h b/include/detail/cpu_ctc.h index 8aae3a6..08621d6 100644 --- a/include/detail/cpu_ctc.h +++ b/include/detail/cpu_ctc.h @@ -163,6 +163,8 @@ template void CpuCTC::softmax(const ProbT* const activations, ProbT* probs, const int* const input_lengths) { + ProbT min_T = std::numeric_limits::min(); + #pragma omp parallel for for (int mb = 0; mb < minibatch_; ++mb) { for(int c = 0; c < input_lengths[mb]; ++c) { @@ -179,6 +181,9 @@ CpuCTC::softmax(const ProbT* const activations, ProbT* probs, for(int r = 0; r < alphabet_size_; ++r) { probs[r + col_offset] /= denom; + if (probs[r + col_offset] < min_T) { + probs[r + col_offset] = min_T; + } } } } @@ -226,7 +231,6 @@ ProbT CpuCTC::compute_alphas(const ProbT* probs, int repeats, int S, int const int* const s_inc, const int* const labels, ProbT* alphas) { - int start = (((S /2) + repeats - T) < 0) ? 0 : 1, end = S > 1 ? 2 : 1; diff --git a/include/detail/gpu_ctc.h b/include/detail/gpu_ctc.h index 0f1d239..2149d99 100644 --- a/include/detail/gpu_ctc.h +++ b/include/detail/gpu_ctc.h @@ -395,6 +395,9 @@ GpuCTC::compute_probs(const ProbT* const activations) { (ctc_helper::exponential(), probs_, denoms_, out_dim_, num_elements); + truncate_probs_kernel<<>> + (probs_, num_elements); + return CTC_STATUS_SUCCESS; } diff --git a/include/detail/gpu_ctc_kernels.h b/include/detail/gpu_ctc_kernels.h index cf6dba9..07412d0 100644 --- a/include/detail/gpu_ctc_kernels.h +++ b/include/detail/gpu_ctc_kernels.h @@ -88,8 +88,8 @@ template __global__ void compute_alpha_kernel (const ProbT* probs, const int *label_sizes, const int *utt_length, const int *repeats_in_labels, - const int *labels_without_blanks, const int *label_offsets, - int *labels_with_blanks, ProbT *alphas, + const int *labels_without_blanks, const int *label_offsets, + int *labels_with_blanks, ProbT *alphas, ProbT* nll_forward, int stride, int out_dim, int S_memoffset, int T_memoffset, int blank_label) { @@ -469,6 +469,23 @@ __global__ void compute_probs_kernel(Op f, ProbT* probs, } } +template +__global__ void truncate_probs_kernel(ProbT* probs, int count) { + + int idx = blockDim.x * blockIdx.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + ProbT min_T = numeric_limits::min(); +#pragma unroll + for(int i = 0; i < VT; i++) { + if (idx < count) { + if (min_T > probs[idx]) { + probs[idx] = min_T; + } + } + idx += stride; + } +} + template __global__ void prepare_stable_SM_kernel(Op f, ProbT* probs, const ProbT* const col_max, diff --git a/include/detail/hostdevice.h b/include/detail/hostdevice.h index 7bec1e0..3bc318c 100644 --- a/include/detail/hostdevice.h +++ b/include/detail/hostdevice.h @@ -5,3 +5,20 @@ #else #define HOSTDEVICE #endif + +// NOTE(dzhwinter) +// the warp primitive is different in cuda9(Volta) GPU. +// add a wrapper to compatible with cuda7 to cuda9 +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 +#define DEFAULT_MASK 0u +template +__forceinline__ __device__ T __shfl_down(T input, int delta) { + return __shfl_down_sync(DEFAULT_MASK, input, delta); +} + +template +__forceinline__ __device__ T __shfl_up(T input, int delta) { + return __shfl_up_sync(DEFAULT_MASK, input, delta); +} + +#endif diff --git a/src/ctc_entrypoint.cpp b/src/ctc_entrypoint.cpp index a68ef84..e1476d8 100644 --- a/src/ctc_entrypoint.cpp +++ b/src/ctc_entrypoint.cpp @@ -46,7 +46,6 @@ ctcStatus_t compute_ctc_loss(const float* const activations, float *costs, void *workspace, ctcOptions options) { - if (activations == nullptr || flat_labels == nullptr || label_lengths == nullptr ||