Skip to content

Commit

Permalink
use alive_attn and select_indices only if return attn
Browse files Browse the repository at this point in the history
  • Loading branch information
l-k-11235 committed Oct 7, 2024
1 parent 5e8b3ff commit 642db27
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions eole/predict/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,21 +365,26 @@ def advance(self, log_probs, attn):
self.maybe_update_forbidden_tokens()

if self.return_attention or self._cov_pen:
current_attn = attn[self.select_indices]
if self.return_attention:
current_attn = attn[self.select_indices]
if step == 1:
self.alive_attn = current_attn
if self.return_attention:
self.alive_attn = current_attn
# update global state (step == 1)
if self._cov_pen: # coverage penalty
self._prev_penalty = torch.zeros_like(self.topk_log_probs)
self._coverage = torch.zeros_like(current_attn[:, 0, :])
self.src_length = current_attn.size(1)
self._coverage = torch.zeros(
(_B * self.beam_size, attn.size(1)), device=attn.device
)
self.src_length = attn.size(1)
else:
self.alive_attn = self.alive_attn[self.select_indices]
self.alive_attn = torch.cat([self.alive_attn, current_attn], 1)
if self.return_attention:
self.alive_attn = self.alive_attn[self.select_indices]
self.alive_attn = torch.cat([self.alive_attn, current_attn], 1)
# update global state (step > 1)
if self._cov_pen:
self._coverage = self._coverage[self.select_indices]
self._coverage += current_attn[:, 0, : self.src_length]
self._coverage += attn[:, 0, : self.src_length]
self._prev_penalty = self.global_scorer.cov_penalty(
self._coverage, beta=self.global_scorer.beta
).view(_B, self.beam_size)
Expand Down

0 comments on commit 642db27

Please sign in to comment.