Doubts about the difference between pytorch's own ctcloss and warp-ctc #191

Open

2017ZYS opened this issue Jun 10, 2021 · 2 comments

2017ZYS commented Jun 10, 2021

My test environment is as follows:
########################################################
torch==1.4.0
torchvision==0.5.0
cuda==9.0/10.1
########################################################

My test code is as follows:
########################################################

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from warpctc_pytorch import CTCLoss as warpctc
from torch.nn import CTCLoss as pytorchctc

CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
         '新',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', 'I', 'O', '-'
         ]
alphabet = "".join(CHARS) + 'ç'  # filler char so C = len(alphabet) covers the 67 chars plus the blank
alphabet_dict = {}

for i, char in enumerate(alphabet):
    alphabet_dict[char] = i + 1  # index 0 is reserved for the CTC blank

length = []
result = []
text = ["湘E269JY", "冀PL3N67", "川R63728F", "津AD6849", "苏SDFD45464"]

for label in text:
    length.append(len(label))
    for char in label:
        result.append(alphabet_dict[char])
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
targets = torch.IntTensor(result)
targets_lengths = torch.IntTensor(length)
print(targets, targets_lengths)
########
T = 71
N = len(text)
print("N is", N)
C = len(alphabet)
outputs = torch.randn(T, N, C).to(device)
log_probs = outputs.log_softmax(2).detach().requires_grad_().to(device)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
print(input_lengths, input_lengths.shape)
########
# warp-ctc: size_average=False / length_average=False returns the loss summed over the batch
criterion_warp = warpctc(blank=0, size_average=False, length_average=False).to(device)
# warp-ctc applies softmax internally; softmax is shift-invariant, so raw
# logits and their log_softmax produce the same loss (hence the two identical values below)
loss_warp1 = criterion_warp(outputs, targets, input_lengths, targets_lengths)
loss_warp2 = criterion_warp(log_probs, targets, input_lengths, targets_lengths)
print(loss_warp1, loss_warp2)
#######
# pytorch ctc: reduction="none" returns one loss per sample
criterion_none = pytorchctc(blank=0, reduction="none")
loss_pytorch = criterion_none(log_probs, targets, input_lengths, targets_lengths)
print(loss_pytorch)

########################################################
# The printed result is:

tensor([19, 46, 34, 38, 41, 50, 64,  5, 55, 52, 35, 54, 38, 39, 23, 57, 38, 35,
        39, 34, 40, 47,  3, 42, 45, 38, 40, 36, 41, 11, 58, 45, 47, 45, 36, 37,
        36, 38, 36], dtype=torch.int32)  tensor([ 7,  7,  8,  7, 10], dtype=torch.int32)
N is 5
tensor([71, 71, 71, 71, 71]) torch.Size([5])
tensor([828.2159]) tensor([828.2159], grad_fn=<_CTCBackward>)
tensor([278.4382, 289.9979, 276.8926, 278.0664, 272.8851], device='cuda:0',grad_fn=<CudnnCtcLossBackward>)

########################################################
Q: Shouldn't the loss calculated by warp-ctc be the sum of pytorch's losses (with reduction="none")?
I find that 278.4382 + 289.9979 + 276.8926 + 278.0664 + 272.8851 is not equal to 828.2159; however, 278.4382 + 276.8926 + 272.8851 is equal to 828.2159.
This seems to mean that warp-ctc only computes a loss for the even-indexed samples and ignores the odd-indexed ones. Is it a bug?
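A quick check of the arithmetic, using the per-sample values copied from the printed output above:

# per-sample losses from pytorch's CTCLoss(reduction="none"), copied from above
per_sample = [278.4382, 289.9979, 276.8926, 278.0664, 272.8851]

print(sum(per_sample))        # 1396.2802 -- not the warp-ctc value
print(sum(per_sample[0::2]))  # 828.2159  -- matches the warp-ctc loss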

DYJNG commented Jun 18, 2021

Same question! Looking forward to a reply...

Zhang-O commented Feb 24, 2022

Q: Should the loss calculated by warp-ctc be the sum of pytorch's loss (reduction="none")?
A: Yes.
You should revise
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
in your code to
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int32)
warp-ctc reads the length buffers as int32, so an int64 tensor of five 71s is most likely reinterpreted (on a little-endian machine) as [71, 0, 71, 0, 71]: the odd-indexed samples see an input length of 0 and contribute no loss, which is exactly the even/odd pattern observed above.
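A minimal sketch of the fix, reusing outputs, log_probs, targets, targets_lengths, T, N, criterion_warp and criterion_none exactly as defined in the script above:

# the only change: build the input lengths as int32, the dtype warp-ctc expects
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int32)

loss_warp = criterion_warp(outputs, targets, input_lengths, targets_lengths)
loss_pytorch = criterion_none(log_probs, targets, input_lengths, targets_lengths)

# with the dtype fixed, the warp-ctc sum should match the sum of the
# per-sample pytorch losses (up to floating-point rounding)
print(loss_warp.item(), loss_pytorch.sum().item())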
