Codecleanup (OpenNMT#2497)
* some code cleanup
vince62s authored Nov 2, 2023
1 parent be13d12 commit edf847f
Showing 6 changed files with 78 additions and 109 deletions.
onmt/bin/translate.py: 6 changes (5 additions & 1 deletion)
@@ -54,7 +54,11 @@ def main():
opt = parser.parse_args()
if opt.profile:

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
profile_memory=True,
with_stack=True,
) as prof:
with record_function("Translate"):
translate(opt)
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=40))
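
The new profiler call also tracks tensor memory and Python stacks. A minimal standalone sketch of the same pattern, with work() as a stand-in for translate(opt) (not part of this commit):

    import torch
    from torch.profiler import ProfilerActivity, profile, record_function

    def work():
        # stand-in for translate(opt)
        x = torch.randn(1024, 1024)
        (x @ x).sum().item()

    activities = [ProfilerActivity.CPU]
    if torch.cuda.is_available():
        activities.append(ProfilerActivity.CUDA)

    with profile(
        activities=activities,
        profile_memory=True,  # also record tensor allocations
        with_stack=True,  # also record Python source stacks
    ) as prof:
        with record_function("Translate"):
            work()

    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=40))
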
onmt/tests/test_beam_search.py: 16 changes (7 additions & 9 deletions)
@@ -414,7 +414,7 @@ def test_beam_returns_attn_with_correct_length(self):
False,
)
device_init = torch.zeros(1, 1)
_, _, inp_lens, _ = beam.initialize(device_init, inp_lens)
_, _, _ = beam.initialize(device_init, inp_lens)
# inp_lens is tiled in initialize, reassign to make attn match
for i in range(min_length + 2):
# non-interesting beams are going to get dummy values
@@ -465,7 +465,9 @@ def test_beam_returns_attn_with_correct_length(self):
self.assertEqual(len(beam.attention[b]), 2)
for k in range(2):
# second dim is cut down to the non-padded src length
self.assertEqual(beam.attention[b][k].shape[-1], inp_lens[b])
self.assertEqual(
beam.attention[b][k].shape[-1], beam.src_len[b]
)
# first dim is equal to the time of death
# (beam 0 died at current step - adjust for SOS)
self.assertEqual(beam.attention[b][0].shape[0], i + 1)
@@ -756,7 +758,7 @@ def test_beam_lm_increase_src_len(self):
)
device_init = torch.zeros(1, 1)
src_len = torch.randint(0, 30, (self.BATCH_SZ,))
fn_map_state, _, _, _ = beam.initialize(device_init, src_len)
fn_map_state, _, _ = beam.initialize(device_init, src_len)
expected_beam_scores = self.init_step(beam, 1)
expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
@@ -787,13 +789,9 @@ def test_beam_lm_update_src_len_when_finished(self):
)
device_init = torch.zeros(1, 1)
src_len = torch.randint(0, 30, (self.BATCH_SZ,))
fn_map_state, _, _, _ = beam.initialize(device_init, src_len)
fn_map_state, _, _ = beam.initialize(device_init, src_len)
self.init_step(beam, 1)
self.finish_first_beam_step(beam)

n_steps = beam.alive_seq.shape[-1] - 1
# I think this whole test is unnecessary because
# 1) I removed self.src_len reindexing in remove_finished_batches() of LM case
# 2) this is already done in the translator _translate_batch_with_strategy()
# self.assertTrue(beam.src_len.equal(n_steps + fn_map_state(src_len[1:], dim=0)))
self.assertTrue(beam.src_len.equal(n_steps + fn_map_state(src_len, dim=0)))
self.assertTrue(beam.src_len.equal(n_steps + fn_map_state(src_len[1:], dim=0)))
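
initialize() now returns three values: the tiled lengths are no longer handed back but kept on the strategy as beam.src_len. The tiling done by onmt's tile helper is, for a 1-D lengths tensor, equivalent to repeat_interleave; a small sketch (sizes made up):

    import torch

    src_len = torch.tensor([5, 7])  # two source lengths
    beam_size = 3
    # what tile(src_len, beam_size) stores into self.src_len
    tiled = src_len.repeat_interleave(beam_size)
    print(tiled)  # tensor([5, 5, 5, 7, 7, 7]) -- one entry per decoding path
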
onmt/translate/beam_search.py: 90 changes (39 additions & 51 deletions)
@@ -113,10 +113,8 @@ def __init__(
def initialize(self, *args, **kwargs):
raise NotImplementedError

def initialize_(self, enc_out, src_len, src_map, device, target_prefix):
super(BeamSearchBase, self).initialize(
enc_out, src_len, src_map, device, target_prefix
)
def initialize_(self, enc_out, src_map, device, target_prefix):
super(BeamSearchBase, self).initialize(device, target_prefix)
self.best_scores = [-1e10 for _ in range(self.batch_size)]
self._beam_offset = torch.arange(
0,
@@ -184,9 +182,9 @@ def _pick(self, log_probs):
return topk_scores, topk_ids

def beams_non_finished(self, i, predictions, attention, step):
b = self._batch_offset[i]

if any(self.is_finished_list[i]):
b = self._batch_offset[i]
# Store finished hypotheses for this example in the batch.
for j in [
k for k, fin in enumerate(self.is_finished_list[i]) if fin
@@ -198,7 +196,7 @@ def beams_non_finished(self, i, predictions, attention, step):
(
self.topk_scores[i, j],
predictions[i, j, 1:], # Ignore start_token.
attention[i, j, :, : self.src_len[b]]
attention[i, j, :, : self.src_len[i]]
if attention is not None
else None,
)
@@ -208,23 +206,25 @@ def beams_non_finished(self, i, predictions, attention, step):
self.hypotheses[b], key=lambda x: x[0], reverse=True
)

# End condition is when the top beam is finished and we can
# return n_best hypotheses.
if self.ratio > 0:
pred_len = self.src_len[b] * self.ratio
finish_flag = (
(self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
) or all(self.is_finished_list[i])
else:
# early stop when top beam is finished
finish_flag = self.is_finished_list[i][0]

if finish_flag and len(self.hypotheses[b]) >= self.n_best:
for score, pred, attn in self.hypotheses[b][: self.n_best]:
self.scores[b].append(score)
self.predictions[b].append(pred) # ``(batch, n_best,)``
self.attention[b].append(attn if attn is not None else [])
return False
# End condition is when the top beam is finished and we can
# return n_best hypotheses.
if self.ratio > 0:
pred_len = self.src_len[i] * self.ratio
finish_flag = (
(self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
) or all(self.is_finished_list[i])
else:
# early stop when top beam is finished
finish_flag = self.is_finished_list[i][0]

if finish_flag and len(self.hypotheses[b]) >= self.n_best:
for score, pred, attn in self.hypotheses[b][: self.n_best]:
self.scores[b].append(score)
self.predictions[b].append(pred) # ``(batch, n_best,)``
self.attention[b].append(attn if attn is not None else [])
return False
else:
return True
else:
return True
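
The ratio branch above stops an example once the best live beam, normalized by the longest allowed prediction length, can no longer beat the best finished hypothesis. A toy check of that arithmetic (all values made up):

    src_len, ratio = 20, 1.5
    pred_len = src_len * ratio  # 30.0, the longest prediction we allow
    best_finished = -4.0  # best normalized score already banked for this example
    top_live = -150.0  # raw score of the best unfinished beam
    # -150.0 / 30.0 = -5.0 <= -4.0, so decoding can stop for this example
    print((top_live / pred_len) <= best_finished)  # True
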

@@ -263,23 +263,32 @@ def update_finished(self):
_B_new, _B_old, non_finished, predictions, attention, step
)

# reset the selection for the next step
self.select_indices = self._batch_index.view(_B_new * self.beam_size)
# assert torch.equal(
# self.src_len[self.select_indices],
# self.src_len.view(_B_old, self.beam_size)[non_finished].view(
# _B_new * self.beam_size
# ),
# )
self.src_len = self.src_len[self.select_indices]
self.maybe_update_target_prefix(self.select_indices)

def remove_finished_batches(
self, _B_new, _B_old, non_finished, predictions, attention, step
):
# Remove finished batches for the next step.
self._batch_offset = self._batch_offset[non_finished]
# here we combine two selections in one
# self.topk_log_probs = self.topk_log_probs[non_finished]
# self._batch_index = self._batch_index[non_finished]
# self._batch_index = self._batch_index.index_select(0, non_finished)
self.topk_log_probs, self._batch_index = torch.unbind(
torch.stack([self.topk_log_probs, self._batch_index], dim=2)[non_finished],
dim=2,
)
self._batch_index = self._batch_index.to(torch.long)
self.select_indices = self._batch_index.view(_B_new * self.beam_size)
self.alive_seq = predictions[non_finished].view(-1, self.alive_seq.size(-1))

self.maybe_update_target_prefix(self.select_indices)
if self.alive_attn is not None:
inp_seq_len = self.alive_attn.size(-1)
self.alive_attn = attention[non_finished].view(
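
The stack/unbind above folds what used to be two separate [non_finished] selections into one indexing pass. A standalone equivalence check (shapes made up; batch_index is float here only because torch.stack needs a single dtype, hence the cast back to long):

    import torch

    B, beam_size = 4, 5
    topk_log_probs = torch.randn(B, beam_size)
    batch_index = torch.randint(0, B * beam_size, (B, beam_size)).float()
    non_finished = torch.tensor([0, 2, 3])  # batch entries still alive

    # one selection over a (B, beam_size, 2) stack ...
    lp, bi = torch.unbind(
        torch.stack([topk_log_probs, batch_index], dim=2)[non_finished], dim=2
    )
    # ... equals two independent selections
    assert torch.equal(lp, topk_log_probs[non_finished])
    assert torch.equal(bi.to(torch.long), batch_index[non_finished].to(torch.long))
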
@@ -339,7 +348,7 @@ def advance(self, log_probs, attn):
self.select_indices = self._batch_index.view(_B * self.beam_size)
self.topk_ids %= vocab_size

# Append last prediction.
# Append last prediction to reordered alive sequence
self.alive_seq = torch.cat(
[
self.alive_seq[self.select_indices],
@@ -399,11 +408,9 @@ def initialize(
if device is None:
device = self.get_device_from_enc_out(enc_out)

super(BeamSearch, self).initialize_(
enc_out, self.src_len, src_map, device, target_prefix
)
super(BeamSearch, self).initialize_(enc_out, src_map, device, target_prefix)

return fn_map_state, enc_out, self.src_len, src_map
return fn_map_state, enc_out, src_map


class BeamSearchLM(BeamSearchBase):
@@ -423,13 +430,12 @@ def initialize(self, src, src_len, src_map=None, device=None, target_prefix=None

super(BeamSearchLM, self).initialize_(
None,
self.src_len,
src_map=src_map,
device=device,
target_prefix=target_prefix,
)

return fn_map_state, src, self.src_len, src_map
return fn_map_state, src, src_map

def advance(self, log_probs, attn):
super(BeamSearchLM, self).advance(log_probs, attn)
@@ -438,24 +444,6 @@ def advance(self, log_probs, attn):
# and therefore needs to follow the generation
self.src_len += 1
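
Since an LM has no separate encoder input, the "source" is the prefix generated so far, so the tiled lengths advance one token per decoding step. A toy trace (lengths made up, beam_size 2):

    import torch

    src_len = torch.tensor([4, 4, 9, 9])  # tiled prompt lengths at step 0
    for _ in range(3):  # what three advance() calls do to src_len in the LM case
        src_len += 1
    print(src_len)  # tensor([ 7,  7, 12, 12])
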

def remove_finished_batches(
self, _B_new, _B_old, non_finished, predictions, attention, step
):
super(BeamSearchLM, self).remove_finished_batches(
_B_new, _B_old, non_finished, predictions, attention, step
)

# in LM task src_len is associated with currently generated src
# and therefore needs to follow the generation
# VN 24/10/2023 given the usage of src_len in update_finished()
# I think this is incorrect, therefore commenting it out
# indexing needs to be aligned to the original batch indexing
"""
self.src_len = self.src_len.view(_B_old, self.beam_size)[non_finished].view(
_B_new * self.beam_size
)
"""


class GNMTGlobalScorer(object):
"""NMT re-ranking.
onmt/translate/decode_strategy.py: 11 changes (5 additions & 6 deletions)
@@ -48,10 +48,10 @@ class DecodeStrategy(object):
alive_seq (LongTensor): Shape ``(B x parallel_paths, step)``.
This sequence grows in the ``step`` axis on each call to
:func:``advance()``.
is_finished (ByteTensor or NoneType): Shape ``(B, parallel_paths)``.
Initialized to ``None``.
alive_attn (FloatTensor or NoneType): If tensor, shape is
``(step, B x parallel_paths, inp_seq_len)``, where ``inp_seq_len``
``(B x parallel_paths, step, inp_seq_len)``, where ``inp_seq_len``
is the (max) length of the input sequence.
target_prefix (LongTensor or NoneType): If tensor, shape is
``(B x parallel_paths, prefix_seq_len)``, where ``prefix_seq_len``
@@ -133,14 +133,13 @@ def fn_map_state(state, dim=0):
src_map = tile(src_map, self.beam_size, dim=0)

self.src_len = tile(src_len, self.beam_size)

if target_prefix is not None:
target_prefix = tile(target_prefix, self.beam_size, dim=0)

return fn_map_state, enc_out, src_map, target_prefix

def initialize(
self, enc_out, src_len, src_map=None, device=None, target_prefix=None
):
def initialize(self, device=None, target_prefix=None):
"""DecodeStrategy subclasses should override :func:`initialize()`.
`initialize` should be called before all actions.
@@ -177,7 +176,7 @@ def initialize(
self.min_length += min(prefix_non_pad) - 1

self.target_prefix = target_prefix # NOTE: forced prefix words
return None, enc_out, src_len, src_map
return None

def __len__(self):
return self.alive_seq.shape[1]
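
With the batch-first layout documented above, each step's attention row is appended along dim 1, and per-path slices read alive_attn[b, :, :src_len[b]] (see the greedy_search.py changes below). A shape sketch with made-up sizes:

    import torch

    n_paths, inp_seq_len = 6, 13
    alive_attn = None
    for step in range(4):
        attn = torch.randn(n_paths, 1, inp_seq_len)  # one row per path and step
        alive_attn = attn if alive_attn is None else torch.cat([alive_attn, attn], 1)
    print(alive_attn.shape)  # (B x parallel_paths, step, inp_seq_len) = (6, 4, 13)
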
onmt/translate/greedy_search.py: 40 changes (17 additions & 23 deletions)
@@ -168,9 +168,7 @@ def initialize(
if device is None:
device = self.get_device_from_enc_out(enc_out)

super(GreedySearch, self).initialize(
enc_out, src_len, src_map, device, target_prefix
)
super(GreedySearch, self).initialize(device, target_prefix)
self.select_indices = torch.arange(
self.batch_size * self.beam_size, dtype=torch.long, device=device
)
@@ -180,7 +178,7 @@ def initialize(
self.beams_scores = torch.zeros(
(self.batch_size * self.beam_size, 1), dtype=torch.float, device=device
)
return fn_map_state, enc_out, self.src_len, src_map
return fn_map_state, enc_out, src_map

@property
def current_predictions(self):
@@ -245,7 +243,7 @@ def advance(self, log_probs, attn):
if self.alive_attn is None:
self.alive_attn = attn
else:
self.alive_attn = torch.cat([self.alive_attn, attn], 0)
self.alive_attn = torch.cat([self.alive_attn, attn], 1)
self.ensure_max_length()

def update_finished(self):
@@ -262,7 +260,7 @@ def update_finished(self):
score = self.beams_scores[b, 0] / length_penalty
pred = self.alive_seq[b, 1:]
attention = (
self.alive_attn[:, b, : self.src_len[b]]
self.alive_attn[b, :, : self.src_len[b]]
if self.alive_attn is not None
else []
)
@@ -279,8 +277,9 @@ def update_finished(self):
is_alive = ~self.is_finished.view(-1)
self.alive_seq = self.alive_seq[is_alive]
self.beams_scores = self.beams_scores[is_alive]
self.src_len = self.src_len[is_alive]
if self.alive_attn is not None:
self.alive_attn = self.alive_attn[:, is_alive]
self.alive_attn = self.alive_attn[is_alive]
self.select_indices = is_alive.nonzero(as_tuple=False).view(-1)
self.original_batch_idx = self.original_batch_idx[is_alive]
self.maybe_update_target_prefix(self.select_indices)
@@ -289,27 +288,22 @@
class GreedySearchLM(GreedySearch):
def update_finished(self):
super(GreedySearchLM, self).update_finished()
self.update_src_len()

def update_src_len(self):
is_alive = ~self.is_finished.view(-1)
self.src_len = self.src_len[is_alive]

def advance(self, log_probs, attn):
super(GreedySearchLM, self).advance(log_probs, attn)

# in LM task src_len is associated with currently generated src
# and therefore needs to follow the generation
self.src_len += 1

def initialize(self, src, src_len, src_map=None, device=None, target_prefix=None):
"""Initialize for decoding."""

if device is None:
device = src.device

(fn_map_state, _, self.src_len, src_map) = super(
GreedySearchLM, self
).initialize(None, src_len, src_map, device, target_prefix)
(fn_map_state, _, src_map) = super(GreedySearchLM, self).initialize(
None, src_len, src_map, device, target_prefix
)

return fn_map_state, src, src_map

def advance(self, log_probs, attn):
super(GreedySearchLM, self).advance(log_probs, attn)

return fn_map_state, src, self.src_len, src_map
# in LM task src_len is associated with currently generated src
# and therefore needs to follow the generation
self.src_len += 1
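
The batch-first attention layout also shows up in GreedySearch.update_finished() above: pruning finished paths becomes a plain mask on the first dim (alive_attn[is_alive] instead of alive_attn[:, is_alive]). A minimal sketch, sizes made up:

    import torch

    alive_attn = torch.randn(4, 3, 10)  # (paths, step, inp_seq_len)
    is_finished = torch.tensor([[True], [False], [True], [False]])
    is_alive = ~is_finished.view(-1)
    alive_attn = alive_attn[is_alive]  # keep only unfinished paths
    print(alive_attn.shape)  # torch.Size([2, 3, 10])
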