diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py index 8c965c0..a5dd138 100644 --- a/onmt/translate/beam_search.py +++ b/onmt/translate/beam_search.py @@ -71,7 +71,7 @@ def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device, self.hypotheses = [[] for _ in range(batch_size)] # beam state - self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8) + self.top_beam_finished = torch.zeros([batch_size], dtype=torch.bool) self._batch_offset = torch.arange(batch_size, dtype=torch.long) self._beam_offset = torch.arange( 0, batch_size * beam_size, step=beam_size, dtype=torch.long,