diff --git a/pytorch_binding/setup.py b/pytorch_binding/setup.py index a63c3ea..9a97b87 100644 --- a/pytorch_binding/setup.py +++ b/pytorch_binding/setup.py @@ -7,7 +7,8 @@ from torch.utils.ffi import create_extension import torch -extra_compile_args = ['-std=c++11', '-fPIC'] +#extra_compile_args = ['-std=c++11', '-fPIC'] +extra_compile_args = ['-std=c99', '-fPIC'] warp_ctc_path = "../build" if torch.cuda.is_available() or "CUDA_HOME" in os.environ: diff --git a/pytorch_binding/src/binding.cpp b/pytorch_binding/src/binding.cpp index ca80dfc..cc338a9 100644 --- a/pytorch_binding/src/binding.cpp +++ b/pytorch_binding/src/binding.cpp @@ -20,6 +20,8 @@ extern "C" int cpu_ctc(THFloatTensor *probs, THIntTensor *label_sizes, THIntTensor *sizes, int minibatch_size, + int blanklabel_index, + int num_threads, THFloatTensor *costs) { float *probs_ptr = probs->storage->data + probs->storageOffset; @@ -38,7 +40,8 @@ extern "C" int cpu_ctc(THFloatTensor *probs, ctcOptions options; memset(&options, 0, sizeof(options)); options.loc = CTC_CPU; - options.num_threads = 0; // will use default number of threads + options.blank_label = blanklabel_index; + options.num_threads = num_threads; // will use given number of threads #if defined(CTC_DISABLE_OMP) || defined(APPLE) // have to use at least one @@ -68,6 +71,7 @@ extern "C" int cpu_ctc(THFloatTensor *probs, THIntTensor *label_sizes, THIntTensor *sizes, int minibatch_size, + int blanklabel_index, THFloatTensor *costs) { float *probs_ptr = probs->storage->data + probs->storageOffset; @@ -86,6 +90,7 @@ extern "C" int cpu_ctc(THFloatTensor *probs, ctcOptions options; memset(&options, 0, sizeof(options)); options.loc = CTC_GPU; + options.blank_label = blanklabel_index; options.stream = THCState_getCurrentStream(state); size_t gpu_size_bytes; diff --git a/pytorch_binding/tests/test_cpu.py b/pytorch_binding/tests/test_cpu.py index efc4d15..516ee1d 100755 --- a/pytorch_binding/tests/test_cpu.py +++ b/pytorch_binding/tests/test_cpu.py @@ 
-17,6 +17,8 @@ def test_simple(): label_sizes, sizes, minibatch_size, + 0, + 0, costs) print('CPU_cost: %f' % costs.sum()) @@ -40,6 +42,8 @@ def test_medium(multiplier): label_sizes, sizes, minibatch_size, + 0, + 0, costs) print('CPU_cost: %f' % costs.sum()) @@ -62,6 +66,8 @@ def test_empty_label(): label_sizes, sizes, minibatch_size, + 0, + 0, costs) print('CPU_cost: %f' % costs.sum()) diff --git a/pytorch_binding/tests/test_gpu.py b/pytorch_binding/tests/test_gpu.py index 9369ac7..d63cc5d 100755 --- a/pytorch_binding/tests/test_gpu.py +++ b/pytorch_binding/tests/test_gpu.py @@ -18,6 +18,8 @@ def test_simple(): label_sizes, sizes, minibatch_size, + 0, + 0, costs) print('CPU_cost: %f' % costs.sum()) probs = probs.clone().cuda() @@ -29,6 +31,7 @@ def test_simple(): label_sizes, sizes, minibatch_size, + 0, costs) print('GPU_cost: %f' % costs.sum()) print(grads.view(grads.size(0) * grads.size(1), grads.size(2))) @@ -54,6 +57,8 @@ def test_medium(multiplier): label_sizes, sizes, minibatch_size, + 0, + 0, costs) print('CPU_cost: %f' % costs.sum()) probs = probs.clone().cuda() @@ -65,6 +70,7 @@ def test_medium(multiplier): label_sizes, sizes, minibatch_size, + 0, costs) print('GPU_cost: %f' % costs.sum()) print(grads.view(grads.size(0) * grads.size(1), grads.size(2))) @@ -89,6 +95,8 @@ def test_empty_label(): label_sizes, sizes, minibatch_size, + 0, + 0, costs) print('CPU_cost: %f' % costs.sum()) probs = probs.clone().cuda() @@ -100,6 +108,7 @@ def test_empty_label(): label_sizes, sizes, minibatch_size, + 0, costs) print('GPU_cost: %f' % costs.sum()) print(grads.view(grads.size(0) * grads.size(1), grads.size(2))) diff --git a/pytorch_binding/warpctc_pytorch/__init__.py b/pytorch_binding/warpctc_pytorch/__init__.py index 5f7cf74..347b794 100644 --- a/pytorch_binding/warpctc_pytorch/__init__.py +++ b/pytorch_binding/warpctc_pytorch/__init__.py @@ -9,22 +9,18 @@ class _CTC(Function): @staticmethod - def forward(ctx, acts, labels, act_lens, label_lens, size_average=False, + 
def forward(ctx, acts, labels, act_lens, label_lens, blank_label=0, num_threads=0, size_average=False, length_average=False): is_cuda = True if acts.is_cuda else False acts = acts.contiguous() - loss_func = warp_ctc.gpu_ctc if is_cuda else warp_ctc.cpu_ctc grads = torch.zeros(acts.size()).type_as(acts) minibatch_size = acts.size(1) costs = torch.zeros(minibatch_size).cpu() - loss_func(acts, - grads, - labels, - label_lens, - act_lens, - minibatch_size, - costs) - + if is_cuda: + # num_threads will be ignored in GPU mode + warp_ctc.gpu_ctc(acts, grads,labels, label_lens, act_lens, minibatch_size, blank_label, costs) + else: + warp_ctc.cpu_ctc(acts, grads,labels, label_lens, act_lens, minibatch_size, blank_label, num_threads, costs) costs = torch.FloatTensor([costs.sum()]) if length_average: