Skip to content

Commit

Permalink
Merge pull request #178 from ysk24ok/add_CTCLoss_test
Browse files Browse the repository at this point in the history
Add a test for CTCLoss
  • Loading branch information
SeanNaren authored Oct 1, 2020
2 parents 33a97b2 + fb6c4b4 commit e2609d8
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
14 changes: 14 additions & 0 deletions pytorch_binding/tests/test_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,5 +69,19 @@ def test_empty_label():
print('CPU_cost: %f' % costs.sum())


def test_CTCLoss():
probs = torch.FloatTensor([[
[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]
]]).transpose(0, 1).contiguous()
labels = torch.IntTensor([1, 2])
label_sizes = torch.IntTensor([2])
probs_sizes = torch.IntTensor([2])
probs.requires_grad_(True)

ctc_loss = warp_ctc.CTCLoss()
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()


if __name__ == '__main__':
pytest.main([__file__])
15 changes: 15 additions & 0 deletions pytorch_binding/tests/test_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,5 +111,20 @@ def test_empty_label():
print(grads.view(grads.size(0) * grads.size(1), grads.size(2)))


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU")
def test_CTCLoss():
probs = torch.FloatTensor([[
[0.1, 0.6, 0.1, 0.1, 0.1], [0.1, 0.1, 0.6, 0.1, 0.1]
]]).transpose(0, 1).contiguous().cuda()
labels = torch.IntTensor([1, 2])
label_sizes = torch.IntTensor([2])
probs_sizes = torch.IntTensor([2])
probs.requires_grad_(True)

ctc_loss = warp_ctc.CTCLoss()
cost = ctc_loss(probs, labels, probs_sizes, label_sizes)
cost.backward()


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit e2609d8

Please sign in to comment.