My test environment is as follows:
########################################################
torch==1.4.0
torchvision==0.5.0
cuda==9.0/10.1
########################################################
My test code is as follows:
########################################################
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import torch
from warpctc_pytorch import CTCLoss as warpctc
from torch.nn import CTCLoss as pytorchctc
# province abbreviations used on Chinese license plates, plus digits and letters
CHARS = ['京', '沪', '津', '渝', '冀', '晋', '蒙', '辽', '吉', '黑',
         '苏', '浙', '皖', '闽', '赣', '鲁', '豫', '鄂', '湘', '粤',
         '桂', '琼', '川', '贵', '云', '藏', '陕', '甘', '青', '宁',
         '新',
         '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
         'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'J', 'K',
         'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'U', 'V',
         'W', 'X', 'Y', 'Z', 'I', 'O', '-']
alphabet = "".join(CHARS) + 'ç'
alphabet_dict = {}
for i, char in enumerate(alphabet):
    alphabet_dict[char] = i + 1  # start at 1; index 0 is reserved for the CTC blank
length = []
result = []
text = ["湘E269JY","冀PL3N67","川R63728F","津AD6849","苏SDFD45464"]
for s in text:  # renamed from `str` to avoid shadowing the built-in
    length.append(len(s))
    for char in s:
        index = alphabet_dict[char]
        result.append(index)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
targets = torch.IntTensor(result)
targets_lengths = torch.IntTensor(length)
print(targets, targets_lengths)
########
T = 71
N = len(text)
print("N is",N)
C = len(alphabet)
outputs = torch.randn(T,N,C).to(device)
log_probs = outputs.log_softmax(2).detach().requires_grad_().to(device)
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
print(input_lengths,input_lengths.shape)
########
#warp_ctc
criterion_warp = warpctc(blank=0, size_average=False, length_average=False).to(device)
loss_warp1 = criterion_warp(outputs, targets, input_lengths, targets_lengths)    # raw activations; warp-ctc applies softmax internally
loss_warp2 = criterion_warp(log_probs, targets, input_lengths, targets_lengths)  # log-probs, for comparison
print(loss_warp1, loss_warp2)
########
#pytorch ctc
criterion_none = pytorchctc(blank=0,reduction="none")
loss_pytorch = criterion_none(log_probs,targets,input_lengths,targets_lengths)
print(loss_pytorch)
########################################################
#The printed result is
########################################################
Q: Should the loss calculated by warp-ctc be the sum of PyTorch's losses (with reduction="none")?
I find that 278.4382 + 289.9979 + 276.8926 + 278.0664 + 272.8851 does not equal 828.2159; however, 278.4382 + 276.8926 + 272.8851 does equal 828.2159.
This seems to mean that warp_ctc only computes a loss for the even-indexed samples and skips the odd-indexed ones. Is this a bug?
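For clarity, here is the arithmetic spelled out (per-sample values copied from the printout above):
########################################################
per_sample = [278.4382, 289.9979, 276.8926, 278.0664, 272.8851]
print(round(sum(per_sample), 4))        # 1396.2802, not 828.2159
print(round(sum(per_sample[0::2]), 4))  # 828.2159: even-indexed samples only
########################################################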
Q: Should the loss calculated by warp-ctc be the sum of PyTorch's losses (with reduction="none")?
A: Yes.
You should revise "input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)" in your code to "input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int32)". warp-ctc's C interface reads the length arrays as 32-bit ints, so a torch.long (int64) tensor is most likely being reinterpreted on the C side, which is exactly why only every other sample contributes to the total.
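A minimal sketch of the corrected line, plus a numpy check (assuming a little-endian machine) of what warp-ctc's int32 C interface would see when handed an int64 buffer:
########################################################
import numpy as np
import torch

T, N = 71, 5

# The fix: warp-ctc needs 32-bit int lengths, matching targets/targets_lengths.
input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.int32)

# What goes wrong with torch.long: the C side reads the same bytes as int32,
# so each 8-byte value 71 is seen as the pair (71, 0) on little-endian hardware.
lengths64 = np.full(N, T, dtype=np.int64)
print(lengths64.view(np.int32)[:N])  # -> [71  0 71  0 71]
########################################################
With the lengths read as [71, 0, 71, 0, 71], samples 1 and 3 contribute zero loss, which matches 828.2159 being the sum of only the even-indexed per-sample losses.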