diff --git a/s2s-ft/s2s_ft/modeling_decoding.py b/s2s-ft/s2s_ft/modeling_decoding.py index 89dec6565..c235be143 100644 --- a/s2s-ft/s2s_ft/modeling_decoding.py +++ b/s2s-ft/s2s_ft/modeling_decoding.py @@ -1656,7 +1656,7 @@ def get_dup_ngram_candidates(seq, n): forbid_word_mask = torch.tensor( buf_matrix, dtype=log_scores.dtype) forbid_word_mask = torch.reshape( - forbid_word_mask, [batch_size * K, 1, vocab_size]).cuda() + forbid_word_mask, [batch_size * K, 1, vocab_size]).to(input_ids.device) else: forbid_word_mask = None next_pos += 1