From edf847f7e070dceb749ef45c30361a9f3c6e3759 Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Thu, 2 Nov 2023 12:22:50 +0100
Subject: [PATCH] Codecleanup (#2497)

* some code cleanup
---
 onmt/bin/translate.py             |  6 ++-
 onmt/tests/test_beam_search.py    | 16 +++---
 onmt/translate/beam_search.py     | 90 ++++++++++++++-----------------
 onmt/translate/decode_strategy.py | 11 ++--
 onmt/translate/greedy_search.py   | 40 ++++++--------
 onmt/translate/translator.py      | 24 ++-------
 6 files changed, 78 insertions(+), 109 deletions(-)

diff --git a/onmt/bin/translate.py b/onmt/bin/translate.py
index aae076ec85..0cd2176d4c 100644
--- a/onmt/bin/translate.py
+++ b/onmt/bin/translate.py
@@ -54,7 +54,11 @@ def main():
     opt = parser.parse_args()
 
     if opt.profile:
-        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            profile_memory=True,
+            with_stack=True,
+        ) as prof:
             with record_function("Translate"):
                 translate(opt)
         print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=40))
diff --git a/onmt/tests/test_beam_search.py b/onmt/tests/test_beam_search.py
index 4b0e4df8ad..fadd5dca93 100644
--- a/onmt/tests/test_beam_search.py
+++ b/onmt/tests/test_beam_search.py
@@ -414,7 +414,7 @@ def test_beam_returns_attn_with_correct_length(self):
             False,
         )
         device_init = torch.zeros(1, 1)
-        _, _, inp_lens, _ = beam.initialize(device_init, inp_lens)
+        _, _, _ = beam.initialize(device_init, inp_lens)
         # inp_lens is tiled in initialize, reassign to make attn match
         for i in range(min_length + 2):
             # non-interesting beams are going to get dummy values
@@ -465,7 +465,9 @@ def test_beam_returns_attn_with_correct_length(self):
                 self.assertEqual(len(beam.attention[b]), 2)
                 for k in range(2):
                     # second dim is cut down to the non-padded src length
-                    self.assertEqual(beam.attention[b][k].shape[-1], inp_lens[b])
+                    self.assertEqual(
+                        beam.attention[b][k].shape[-1], beam.src_len[b]
+                    )
                 # first dim is equal to the time of death
                 # (beam 0 died at current step - adjust for SOS)
                 self.assertEqual(beam.attention[b][0].shape[0], i + 1)
@@ -756,7 +758,7 @@ def test_beam_lm_increase_src_len(self):
         )
         device_init = torch.zeros(1, 1)
         src_len = torch.randint(0, 30, (self.BATCH_SZ,))
-        fn_map_state, _, _, _ = beam.initialize(device_init, src_len)
+        fn_map_state, _, _ = beam.initialize(device_init, src_len)
         expected_beam_scores = self.init_step(beam, 1)
         expected_beam_scores = self.first_step(beam, expected_beam_scores, 1)
         expected_beam_scores = self.second_step(beam, expected_beam_scores, 1)
@@ -787,13 +789,9 @@ def test_beam_lm_update_src_len_when_finished(self):
         )
         device_init = torch.zeros(1, 1)
         src_len = torch.randint(0, 30, (self.BATCH_SZ,))
-        fn_map_state, _, _, _ = beam.initialize(device_init, src_len)
+        fn_map_state, _, _ = beam.initialize(device_init, src_len)
         self.init_step(beam, 1)
         self.finish_first_beam_step(beam)
         n_steps = beam.alive_seq.shape[-1] - 1
-        # I think this all test is unnecessary because
-        # 1) I removed self.src_len reindexing in remove_finished_batches() of LM case
-        # 2) this is already done in the translator _translate_batch_with_strategy()
-        # self.assertTrue(beam.src_len.equal(n_steps + fn_map_state(src_len[1:], dim=0)))
-        self.assertTrue(beam.src_len.equal(n_steps + fn_map_state(src_len, dim=0)))
+        self.assertTrue(beam.src_len.equal(n_steps + fn_map_state(src_len[1:], dim=0)))
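The test updates above pin down the new API contract: initialize() now returns a 3-tuple, and the tiled source lengths live on the strategy object itself. A minimal sketch of the convention the tests exercise (beam, device_init and src_len as constructed in the tests; the final assertion is illustrative, not from the suite):

    fn_map_state, enc_out, src_map = beam.initialize(device_init, src_len)
    # src_len is tiled beam_size times inside initialize() and kept on the
    # strategy, so assertions read beam.src_len instead of a returned copy.
    assert beam.src_len.shape[0] == beam.batch_size * beam.beam_size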
diff --git a/onmt/translate/beam_search.py b/onmt/translate/beam_search.py
index f9f13a89d8..43865f1467 100644
--- a/onmt/translate/beam_search.py
+++ b/onmt/translate/beam_search.py
@@ -113,10 +113,8 @@ def __init__(
     def initialize(self, *args, **kwargs):
         raise NotImplementedError
 
-    def initialize_(self, enc_out, src_len, src_map, device, target_prefix):
-        super(BeamSearchBase, self).initialize(
-            enc_out, src_len, src_map, device, target_prefix
-        )
+    def initialize_(self, enc_out, src_map, device, target_prefix):
+        super(BeamSearchBase, self).initialize(device, target_prefix)
         self.best_scores = [-1e10 for _ in range(self.batch_size)]
         self._beam_offset = torch.arange(
             0,
@@ -184,9 +182,9 @@ def _pick(self, log_probs):
         return topk_scores, topk_ids
 
     def beams_non_finished(self, i, predictions, attention, step):
-        b = self._batch_offset[i]
         if any(self.is_finished_list[i]):
+            b = self._batch_offset[i]
             # Store finished hypotheses for this example in the batch.
             for j in [
                 k for k, fin in enumerate(self.is_finished_list[i]) if fin
@@ -198,7 +196,7 @@ def beams_non_finished(self, i, predictions, attention, step):
                 (
                     self.topk_scores[i, j],
                     predictions[i, j, 1:],  # Ignore start_token.
-                    attention[i, j, :, : self.src_len[b]]
+                    attention[i, j, :, : self.src_len[i]]
                     if attention is not None
                     else None,
                 )
@@ -208,23 +206,25 @@ def beams_non_finished(self, i, predictions, attention, step):
                 self.hypotheses[b], key=lambda x: x[0], reverse=True
             )
 
-        # End condition is the top beam finished and we can return
-        # n_best hypotheses.
-        if self.ratio > 0:
-            pred_len = self.src_len[b] * self.ratio
-            finish_flag = (
-                (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
-            ) or all(self.is_finished_list[i])
-        else:
-            # early stop when top beam is finished
-            finish_flag = self.is_finished_list[i][0]
-
-        if finish_flag and len(self.hypotheses[b]) >= self.n_best:
-            for score, pred, attn in self.hypotheses[b][: self.n_best]:
-                self.scores[b].append(score)
-                self.predictions[b].append(pred)  # ``(batch, n_best,)``
-                self.attention[b].append(attn if attn is not None else [])
-            return False
+            # End condition is the top beam finished and we can return
+            # n_best hypotheses.
+            if self.ratio > 0:
+                pred_len = self.src_len[i] * self.ratio
+                finish_flag = (
+                    (self.topk_scores[i, 0] / pred_len) <= self.best_scores[b]
+                ) or all(self.is_finished_list[i])
+            else:
+                # early stop when top beam is finished
+                finish_flag = self.is_finished_list[i][0]
+
+            if finish_flag and len(self.hypotheses[b]) >= self.n_best:
+                for score, pred, attn in self.hypotheses[b][: self.n_best]:
+                    self.scores[b].append(score)
+                    self.predictions[b].append(pred)  # ``(batch, n_best,)``
+                    self.attention[b].append(attn if attn is not None else [])
+                return False
+            else:
+                return True
         else:
             return True
@@ -263,6 +263,17 @@ def update_finished(self):
             _B_new, _B_old, non_finished, predictions, attention, step
         )
 
+        # reset the selection for the next step
+        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
+        # assert torch.equal(
+        #     self.src_len[self.select_indices],
+        #     self.src_len.view(_B_old, self.beam_size)[non_finished].view(
+        #         _B_new * self.beam_size
+        #     ),
+        # )
+        self.src_len = self.src_len[self.select_indices]
+        self.maybe_update_target_prefix(self.select_indices)
+
     def remove_finished_batches(
         self, _B_new, _B_old, non_finished, predictions, attention, step
     ):
@@ -270,16 +281,14 @@ def remove_finished_batches(
         self._batch_offset = self._batch_offset[non_finished]
         # here we combine two slections in one
        # self.topk_log_probs = self.topk_log_probs[non_finished]
-        # self._batch_index = self._batch_index[non_finished]
+        # self._batch_index = self._batch_index.index_select(0, non_finished)
         self.topk_log_probs, self._batch_index = torch.unbind(
            torch.stack([self.topk_log_probs, self._batch_index], dim=2)[non_finished],
            dim=2,
        )
         self._batch_index = self._batch_index.to(torch.long)
-        self.select_indices = self._batch_index.view(_B_new * self.beam_size)
         self.alive_seq = predictions[non_finished].view(-1, self.alive_seq.size(-1))
-        self.maybe_update_target_prefix(self.select_indices)
         if self.alive_attn is not None:
             inp_seq_len = self.alive_attn.size(-1)
             self.alive_attn = attention[non_finished].view(
@@ -339,7 +348,7 @@ def advance(self, log_probs, attn):
         self.select_indices = self._batch_index.view(_B * self.beam_size)
         self.topk_ids %= vocab_size
 
-        # Append last prediction.
+        # Append last prediction to reordered alive sequence
         self.alive_seq = torch.cat(
             [
                 self.alive_seq[self.select_indices],
@@ -399,11 +408,9 @@ def initialize(
         if device is None:
             device = self.get_device_from_enc_out(enc_out)
 
-        super(BeamSearch, self).initialize_(
-            enc_out, self.src_len, src_map, device, target_prefix
-        )
+        super(BeamSearch, self).initialize_(enc_out, src_map, device, target_prefix)
 
-        return fn_map_state, enc_out, self.src_len, src_map
+        return fn_map_state, enc_out, src_map
 
 
 class BeamSearchLM(BeamSearchBase):
@@ -423,13 +430,12 @@ def initialize(self, src, src_len, src_map=None, device=None, target_prefix=None
 
         super(BeamSearchLM, self).initialize_(
             None,
-            self.src_len,
             src_map=src_map,
             device=device,
             target_prefix=target_prefix,
         )
 
-        return fn_map_state, src, self.src_len, src_map
+        return fn_map_state, src, src_map
 
     def advance(self, log_probs, attn):
         super(BeamSearchLM, self).advance(log_probs, attn)
@@ -438,24 +444,6 @@ def advance(self, log_probs, attn):
         # and therefore needs to follow the generation
         self.src_len += 1
 
-    def remove_finished_batches(
-        self, _B_new, _B_old, non_finished, predictions, attention, step
-    ):
-        super(BeamSearchLM, self).remove_finished_batches(
-            _B_new, _B_old, non_finished, predictions, attention, step
-        )
-
-        # in LM task src_len is associated with currently generated src
-        # and therefore needs to follow the generation
-        # VN 24/10/2023 given the usage of src_len in update_finished()
-        # I think this is incorrect therefore commenting
-        # indexing needs to be aligned to original batch indexing
-        """
-        self.src_len = self.src_len.view(_B_old, self.beam_size)[non_finished].view(
-            _B_new * self.beam_size
-        )
-        """
-
 
 class GNMTGlobalScorer(object):
     """NMT re-ranking.
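With the LM-specific override removed, update_finished() is now the single place where src_len is re-aligned after finished hypotheses are pruned: select_indices is rebuilt from _batch_index and used to reorder src_len so it keeps tracking the surviving beams. A standalone toy example of that reindexing (hypothetical sizes and values, not OpenNMT code):

    import torch

    # source lengths tiled per beam, as after initialize(): (batch * beam_size,)
    src_len = torch.tensor([5, 5, 7, 7, 4, 4])  # batch=3, beam_size=2
    # suppose the second batch example finished; surviving flat indices:
    select_indices = torch.tensor([0, 1, 4, 5])
    src_len = src_len[select_indices]  # -> tensor([5, 5, 4, 4])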
diff --git a/onmt/translate/decode_strategy.py b/onmt/translate/decode_strategy.py
index 86a33ed1c2..ec39710b20 100644
--- a/onmt/translate/decode_strategy.py
+++ b/onmt/translate/decode_strategy.py
@@ -48,10 +48,10 @@ class DecodeStrategy(object):
         alive_seq (LongTensor): Shape ``(B x parallel_paths, step)``.
             This sequence grows in the ``step`` axis on each call to
             :func:``advance()``.
-        is_finished (ByteTensor or NoneType): Shape ``(B, parallel_paths)``.
+        is_finished (ByteTensor or NoneType): Shape ``(B, parallel_paths)``.
             Initialized to ``None``.
         alive_attn (FloatTensor or NoneType): If tensor, shape is
-            ``(step, B x parallel_paths, inp_seq_len)``, where ``inp_seq_len``
+            ``(B x parallel_paths, step, inp_seq_len)``, where ``inp_seq_len``
             is the (max) length of the input sequence.
         target_prefix (LongTensor or NoneType): If tensor, shape is
             ``(B x parallel_paths, prefix_seq_len)``, where ``prefix_seq_len``
@@ -133,14 +133,13 @@ def fn_map_state(state, dim=0):
             src_map = tile(src_map, self.beam_size, dim=0)
 
         self.src_len = tile(src_len, self.beam_size)
+
         if target_prefix is not None:
             target_prefix = tile(target_prefix, self.beam_size, dim=0)
 
         return fn_map_state, enc_out, src_map, target_prefix
 
-    def initialize(
-        self, enc_out, src_len, src_map=None, device=None, target_prefix=None
-    ):
+    def initialize(self, device=None, target_prefix=None):
         """DecodeStrategy subclasses should override :func:`initialize()`.
 
         `initialize` should be called before all actions.
@@ -177,7 +176,7 @@ def initialize(
             self.min_length += min(prefix_non_pad) - 1
 
         self.target_prefix = target_prefix  # NOTE: forced prefix words
-        return None, enc_out, src_len, src_map
+        return None
 
     def __len__(self):
         return self.alive_seq.shape[1]
diff --git a/onmt/translate/greedy_search.py b/onmt/translate/greedy_search.py
index 0d016df454..8a5707ffa8 100644
--- a/onmt/translate/greedy_search.py
+++ b/onmt/translate/greedy_search.py
@@ -168,9 +168,7 @@ def initialize(
         if device is None:
             device = self.get_device_from_enc_out(enc_out)
 
-        super(GreedySearch, self).initialize(
-            enc_out, src_len, src_map, device, target_prefix
-        )
+        super(GreedySearch, self).initialize(device, target_prefix)
         self.select_indices = torch.arange(
             self.batch_size * self.beam_size, dtype=torch.long, device=device
         )
@@ -180,7 +178,7 @@ def initialize(
         self.beams_scores = torch.zeros(
             (self.batch_size * self.beam_size, 1), dtype=torch.float, device=device
         )
-        return fn_map_state, enc_out, self.src_len, src_map
+        return fn_map_state, enc_out, src_map
 
     @property
     def current_predictions(self):
@@ -245,7 +243,7 @@ def advance(self, log_probs, attn):
         if self.alive_attn is None:
             self.alive_attn = attn
         else:
-            self.alive_attn = torch.cat([self.alive_attn, attn], 0)
+            self.alive_attn = torch.cat([self.alive_attn, attn], 1)
         self.ensure_max_length()
 
     def update_finished(self):
@@ -262,7 +260,7 @@ def update_finished(self):
             score = self.beams_scores[b, 0] / length_penalty
             pred = self.alive_seq[b, 1:]
             attention = (
-                self.alive_attn[:, b, : self.src_len[b]]
+                self.alive_attn[b, :, : self.src_len[b]]
                 if self.alive_attn is not None
                 else []
             )
@@ -279,8 +277,9 @@ def update_finished(self):
         is_alive = ~self.is_finished.view(-1)
         self.alive_seq = self.alive_seq[is_alive]
         self.beams_scores = self.beams_scores[is_alive]
+        self.src_len = self.src_len[is_alive]
         if self.alive_attn is not None:
-            self.alive_attn = self.alive_attn[:, is_alive]
+            self.alive_attn = self.alive_attn[is_alive]
         self.select_indices = is_alive.nonzero(as_tuple=False).view(-1)
         self.original_batch_idx = self.original_batch_idx[is_alive]
         self.maybe_update_target_prefix(self.select_indices)
@@ -289,18 +288,6 @@ class GreedySearchLM(GreedySearch):
     def update_finished(self):
         super(GreedySearchLM, self).update_finished()
-        self.update_src_len()
-
-    def update_src_len(self):
-        is_alive = ~self.is_finished.view(-1)
-        self.src_len = self.src_len[is_alive]
-
-    def advance(self, log_probs, attn):
-        super(GreedySearchLM, self).advance(log_probs, attn)
-
-        # in LM task src_len is associated with currently generated src
-        # and therefore needs to follow the generation
-        self.src_len += 1
 
     def initialize(self, src, src_len, src_map=None, device=None, target_prefix=None):
         """Initialize for decoding."""
@@ -308,8 +295,15 @@ def initialize(self, src, src_len, src_map=None, device=None, target_prefix=None
         if device is None:
             device = src.device
 
-        (fn_map_state, _, self.src_len, src_map) = super(
-            GreedySearchLM, self
-        ).initialize(None, src_len, src_map, device, target_prefix)
+        (fn_map_state, _, src_map) = super(GreedySearchLM, self).initialize(
+            None, src_len, src_map, device, target_prefix
+        )
+
+        return fn_map_state, src, src_map
+
+    def advance(self, log_probs, attn):
+        super(GreedySearchLM, self).advance(log_probs, attn)
 
-        return fn_map_state, src, self.src_len, src_map
+        # in LM task src_len is associated with currently generated src
+        # and therefore needs to follow the generation
+        self.src_len += 1
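The greedy hunks flip alive_attn from step-first to batch-first, matching the updated docstring: each new step's attention is concatenated along dim 1, and a finished row b is cut to its true source length with alive_attn[b, :, : src_len[b]]. A toy illustration with hypothetical shapes:

    import torch

    B, inp_seq_len = 4, 9
    alive_attn = torch.rand(B, 1, inp_seq_len)     # after step 1: (B, step, inp_seq_len)
    attn = torch.rand(B, 1, inp_seq_len)           # attention of the new step
    alive_attn = torch.cat([alive_attn, attn], 1)  # (B, 2, inp_seq_len)

    src_len = torch.tensor([9, 5, 7, 3])
    b = 1
    attention_b = alive_attn[b, :, : src_len[b]]   # (2, 5): padding columns dropped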
diff --git a/onmt/translate/translator.py b/onmt/translate/translator.py
index 41c607bf34..be6903bcaf 100644
--- a/onmt/translate/translator.py
+++ b/onmt/translate/translator.py
@@ -714,9 +714,6 @@ def report_results(
         gold_score,
         batch,
         batch_size,
-        src,
-        src_len,
-        use_src_map,
         decode_strategy,
     ):
         results = {
@@ -901,7 +898,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
         # (2) prep decode_strategy. Possibly repeat src objects.
         src_map = batch["src_map"] if use_src_map else None
         target_prefix = batch["tgt"] if self.tgt_file_prefix else None
-        (fn_map_state, enc_out, src_len_tiled, src_map,) = decode_strategy.initialize(
+        (fn_map_state, enc_out, src_map) = decode_strategy.initialize(
             enc_out, src_len, src_map, target_prefix=target_prefix
         )
@@ -916,7 +913,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
                 decoder_input,
                 enc_out,
                 batch,
-                src_len=src_len_tiled,
+                src_len=decode_strategy.src_len,
                 src_map=src_map,
                 step=step,
                 batch_offset=decode_strategy.batch_offset,
@@ -941,8 +938,6 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
                 else:
                     enc_out = enc_out[select_indices]
 
-                src_len_tiled = src_len_tiled[select_indices]
-
                 if src_map is not None:
                     src_map = src_map[select_indices]
@@ -953,9 +948,6 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
             gold_score,
             batch,
             batch_size,
-            src,
-            src_len,
-            use_src_map,
             decode_strategy,
         )
@@ -1107,7 +1099,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
 
         # (3) prep decode_strategy. Possibly repeat src objects.
         src_map = batch["src_map"] if use_src_map else None
-        (fn_map_state, src, src_len_tiled, src_map,) = decode_strategy.initialize(
+        (fn_map_state, src, src_map) = decode_strategy.initialize(
             src,
             src_len,
             src_map,
@@ -1124,7 +1116,7 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
                 decoder_input,
                 None,
                 batch,
-                src_len=src_len_tiled.clone(),
+                src_len=decode_strategy.src_len,
                 src_map=src_map,
                 step=step if step == 0 else step + src_len[0].item(),
                 batch_offset=decode_strategy.batch_offset,
@@ -1143,13 +1135,10 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
                 decode_strategy.update_finished()
                 if decode_strategy.done:
                     break
-            select_indices = decode_strategy.select_indices
-            src_len_tiled += 1
+            select_indices = decode_strategy.select_indices
 
             if any_finished:
                 # Reorder states.
-                src_len_tiled = src_len_tiled[select_indices]
-
                 if src_map is not None:
                     src_map = src_map[select_indices]
@@ -1161,9 +1150,6 @@ def _translate_batch_with_strategy(self, batch, decode_strategy):
             gold_score,
             batch,
             batch_size,
-            src,
-            src_len,
-            use_src_map,
             decode_strategy,
         )
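Net effect on the call sites in translator.py, condensed into a sketch (names taken from the hunks above; the surrounding loop and unrelated arguments are elided, and the decoder_input line is schematic): the per-beam length tensor is no longer threaded through the loop by hand, since decode_strategy.src_len is kept aligned with the alive beams by update_finished().

    (fn_map_state, enc_out, src_map) = decode_strategy.initialize(
        enc_out, src_len, src_map, target_prefix=target_prefix
    )
    for step in range(decode_strategy.max_length):
        decoder_input = decode_strategy.current_predictions.view(-1, 1, 1)
        log_probs, attn = self._decode_and_generate(
            decoder_input,
            enc_out,
            batch,
            src_len=decode_strategy.src_len,  # always matches surviving beams
            src_map=src_map,
            step=step,
            batch_offset=decode_strategy.batch_offset,
        )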