From 5c7963bf65f1f72df6d3a79ac844be54255abcc3 Mon Sep 17 00:00:00 2001 From: Siddharth Dalmia Date: Wed, 4 Apr 2018 22:29:54 -0400 Subject: [PATCH] Fix bug in backward pass It was mentioned in an old PR (https://github.com/SeanNaren/warp-ctc/pull/6/files), which now has conflicts. --- pytorch_binding/warpctc_pytorch/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_binding/warpctc_pytorch/__init__.py b/pytorch_binding/warpctc_pytorch/__init__.py index 532595b..475b530 100644 --- a/pytorch_binding/warpctc_pytorch/__init__.py +++ b/pytorch_binding/warpctc_pytorch/__init__.py @@ -43,7 +43,8 @@ def forward(ctx, acts, labels, act_lens, label_lens, size_average=False, @staticmethod def backward(ctx, grad_output): - return ctx.grads, None, None, None, None, None + grad_input = ctx.grads*grad_output.type_as(ctx.grads) + return grad_input, None, None, None, None, None class CTCLoss(Module):