Skip to content

Commit

Permalink
fixed coverage size
Browse files Browse the repository at this point in the history
  • Loading branch information
l-k-11235 committed Oct 4, 2024
1 parent ff3914c commit 5e8b3ff
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions eole/predict/beam_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,19 +371,20 @@ def advance(self, log_probs, attn):
# update global state (step == 1)
if self._cov_pen: # coverage penalty
self._prev_penalty = torch.zeros_like(self.topk_log_probs)
self._coverage = current_attn
self._coverage = torch.zeros_like(current_attn[:, 0, :])
self.src_length = current_attn.size(1)
else:
self.alive_attn = self.alive_attn[self.select_indices]
self.alive_attn = torch.cat([self.alive_attn, current_attn], 1)
# update global state (step > 1)
if self._cov_pen:
self._coverage = self._coverage[self.select_indices]
self._coverage += current_attn
self._coverage += current_attn[:, 0, : self.src_length]
self._prev_penalty = self.global_scorer.cov_penalty(
self._coverage, beta=self.global_scorer.beta
).view(_B, self.beam_size)

if self._vanilla_cov_pen:
if self._vanilla_cov_pen and step > 1:
# shape: (batch_size x beam_size, 1)
cov_penalty = self.global_scorer.cov_penalty(
self._coverage, beta=self.global_scorer.beta
Expand Down

0 comments on commit 5e8b3ff

Please sign in to comment.