Fix remaining issues in beam score calculation (huggingface#27808)
* Fix issues in add and is_done for BeamHypotheses

* make newly added arguments optional for better compatibility

* Directly use cur_len as generated_len, add note for retrocompatibility

* update test expectation

* make cur_len represent the length of the entire sequence, including the decoder prompt

* remove redundant if/else in testing
VsonicV authored Dec 8, 2023
1 parent 3ac9945 commit b31905d
Showing 3 changed files with 31 additions and 34 deletions.
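In short, the hypothesis scores stored by `BeamHypotheses.add` and the "highest attainable score" used by `is_done` must be normalized by the same length, namely the number of generated tokens with the decoder prompt excluded. A minimal sketch of the mismatch this commit addresses, using made-up numbers (illustration only, not library code):

```python
# Illustration only: the finished-hypothesis score and the stopping bound
# must be denominated in the same length (generated tokens, prompt excluded).
sum_logprobs = -6.0        # hypothetical cumulative log-prob of the best beam
length_penalty = 1.0
decoder_prompt_len = 8     # decoder-only models keep the prompt in input_ids
max_length = 20            # total length budget, prompt included
generated_len = 10         # tokens produced after the prompt

# Score recorded for a finished hypothesis (normalized by generated length):
hyp_score = sum_logprobs / (generated_len ** length_penalty)              # -0.60

# Bound on any future hypothesis. Normalizing by the full max_length (old
# behavior) makes the bound look better than anything actually attainable:
loose_bound = sum_logprobs / (max_length ** length_penalty)               # -0.30
# Normalizing by the generated portion only (behavior after this commit)
# keeps the bound comparable with hyp_score:
consistent_bound = sum_logprobs / ((max_length - decoder_prompt_len)
                                   ** length_penalty)                     # -0.50
print(hyp_score, loose_bound, consistent_bound)
```

Because `is_done` only returns True when the worst kept score is at least this bound, an overly optimistic bound delays early stopping; the commit makes every normalization subtract `decoder_prompt_len`.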
57 changes: 29 additions & 28 deletions src/transformers/generation/beam_search.py
@@ -224,8 +224,8 @@ def process(
group_index: Optional[int] = 0,
decoder_prompt_len: Optional[int] = 0,
) -> Dict[str, torch.Tensor]:
# add up to the length which the next_scores is calculated on
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
# add up to the length which the next_scores is calculated on (including decoder prompt)
cur_len = input_ids.shape[-1] + 1
batch_size = len(self._beam_hyps) // self.num_beam_groups

if not (batch_size == (input_ids.shape[0] // self.group_size)):
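With this change, `cur_len` counts the entire sequence, decoder prompt included, and the generated length is recovered by subtracting `decoder_prompt_len` where needed. A small sketch with made-up shapes (illustration only, not library code):

```python
import torch

# Hypothetical batch: 2 beams, sequences of 9 tokens, of which the first 4
# are the decoder prompt (as in a decoder-only model).
input_ids = torch.zeros((2, 9), dtype=torch.long)
decoder_prompt_len = 4

# New behavior: cur_len spans the whole sequence plus the token being scored.
cur_len = input_ids.shape[-1] + 1                   # 10
generated_len = cur_len - decoder_prompt_len        # 6, used for the length penalty

# Old behavior folded the subtraction into cur_len itself.
old_cur_len = input_ids.shape[-1] - decoder_prompt_len + 1  # 6
```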
@@ -279,15 +279,11 @@ def process(
else:
beam_index = None

# skip the corner case where the very first generated token is eos_token
if decoder_prompt_len == input_ids.shape[-1]:
continue

self._beam_hyps[batch_group_idx].add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
decoder_prompt_len=decoder_prompt_len,
generated_len=cur_len - decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
@@ -308,7 +304,7 @@ def process(

# Check if we are done so that we can save a pad step if all(done)
self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
next_scores[batch_idx].max().item(), cur_len
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
)

return UserDict(
@@ -348,7 +344,8 @@ def finalize(
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len)
generated_len = final_tokens.shape[-1] - decoder_prompt_len
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

# select the best hypotheses
sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
@@ -560,8 +557,8 @@ def process(
indicating to which beam the next tokens shall be added.
"""

# add up to the length which the next_scores is calculated on
cur_len = input_ids.shape[-1] - decoder_prompt_len + 1
# add up to the length which the next_scores is calculated on (including decoder prompt)
cur_len = input_ids.shape[-1] + 1
batch_size = len(self._beam_hyps)
if not (batch_size == (input_ids.shape[0] // self.group_size)):
if self.num_beam_groups > 1:
@@ -617,16 +614,11 @@ def process(
else:
beam_index = None

# skip the corner case where the only constraint token is
# eos_token and the very first generated token is eos_token
if decoder_prompt_len == input_ids.shape[-1]:
continue

beam_hyp.add(
input_ids[batch_beam_idx].clone(),
next_score.item(),
beam_indices=beam_index,
decoder_prompt_len=decoder_prompt_len,
generated_len=cur_len - decoder_prompt_len,
)
else:
# add next predicted token since it is not eos_token
@@ -660,7 +652,7 @@ def process(

# Check if we are done so that we can save a pad step if all(done)
self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
next_scores[batch_idx].max().item(), cur_len
next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
)

return UserDict(
@@ -846,9 +838,8 @@ def finalize(
completes_constraint = self.check_completes_constraints(final_tokens.cpu().tolist())
if completes_constraint:
beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
beam_hyp.add(
final_tokens, final_score, beam_indices=beam_index, decoder_prompt_len=decoder_prompt_len
)
generated_len = final_tokens.shape[-1] - decoder_prompt_len
beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
ids_collect.append(beam_id)

# due to overly complex constraints or other factors, sometimes we can't guarantee a successful
@@ -859,7 +850,8 @@ def finalize(
batch_beam_idx = batch_idx * self.num_beams + beam_id
final_score = final_beam_scores[batch_beam_idx].item()
final_tokens = input_ids[batch_beam_idx]
beam_hyp.add(final_tokens, final_score, decoder_prompt_len=decoder_prompt_len)
generated_len = final_tokens.shape[-1] - decoder_prompt_len
beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
if len(ids_collect) >= self.num_beam_hyps_to_keep:
break

@@ -956,12 +948,17 @@ def add(
hyp: torch.LongTensor,
sum_logprobs: float,
beam_indices: Optional[torch.LongTensor] = None,
decoder_prompt_len: Optional[int] = 0,
generated_len: Optional[int] = None,
):
"""
Add a new hypothesis to the list.
"""
score = sum_logprobs / ((hyp.shape[-1] - decoder_prompt_len) ** self.length_penalty)
if generated_len is not None:
score = sum_logprobs / (generated_len**self.length_penalty)
# This 'else' case exists for retrocompatibility
else:
score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)

if len(self) < self.num_beams or score > self.worst_score:
self.beams.append((score, hyp, beam_indices))
if len(self) > self.num_beams:
@@ -971,7 +968,7 @@ def add(
else:
self.worst_score = min(score, self.worst_score)
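The updated scoring in `add` can be read in isolation as follows. This is a simplified sketch of the same logic (not the exact class), with the fallback branch kept for callers that still omit `generated_len`:

```python
from typing import Optional

import torch


def length_penalized_score(
    hyp: torch.LongTensor,
    sum_logprobs: float,
    length_penalty: float,
    generated_len: Optional[int] = None,
) -> float:
    # Preferred path: normalize by the number of generated tokens.
    if generated_len is not None:
        return sum_logprobs / (generated_len ** length_penalty)
    # Retrocompatibility path: normalize by the full hypothesis length,
    # decoder prompt included (the pre-existing behavior).
    return sum_logprobs / (hyp.shape[-1] ** length_penalty)


# Example: a 12-token sequence whose first 4 tokens are the decoder prompt.
hyp = torch.zeros(12, dtype=torch.long)
print(length_penalized_score(hyp, -6.0, 1.0, generated_len=8))  # -0.75
print(length_penalized_score(hyp, -6.0, 1.0))                   # -0.5 (legacy)
```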

def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: Optional[int] = 0) -> bool:
"""
If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
one in the heap, then we are done with this sentence.
@@ -987,7 +984,7 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
# when `length_penalty` is positive. See the discussion below for more details.
# https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
elif self.early_stopping is False:
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
# `"never"`: compute the best possible score, depending on the signal of `length_penalty`
@@ -996,9 +993,13 @@ def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
# abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
# its max this way
if self.length_penalty > 0.0:
highest_attainable_score = best_sum_logprobs / self.max_length**self.length_penalty
if self.max_length <= decoder_prompt_len:
raise ValueError("max_length is not larger than decoder prompt length")
highest_attainable_score = (
best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
)
# the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
else:
highest_attainable_score = best_sum_logprobs / cur_len**self.length_penalty
highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
ret = self.worst_score >= highest_attainable_score
return ret
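Putting the `is_done` change together: with `early_stopping="never"` the bound is now computed over the generated portion of the sequence only. A condensed, standalone sketch of that branch (simplified from the class above; names and numbers are illustrative):

```python
def is_done_never(
    best_sum_logprobs: float,
    cur_len: int,
    worst_score: float,
    max_length: int,
    length_penalty: float,
    decoder_prompt_len: int = 0,
) -> bool:
    if length_penalty > 0.0:
        if max_length <= decoder_prompt_len:
            raise ValueError("max_length is not larger than decoder prompt length")
        # Longest possible generation gives the best attainable score.
        bound = best_sum_logprobs / (max_length - decoder_prompt_len) ** length_penalty
    else:
        # With a non-positive penalty, the current length gives the best score.
        bound = best_sum_logprobs / (cur_len - decoder_prompt_len) ** length_penalty
    return worst_score >= bound


# 8-token prompt, max_length=20 -> the bound is taken over 12 generated tokens.
print(is_done_never(-6.0, cur_len=15, worst_score=-0.45,
                    max_length=20, length_penalty=1.0, decoder_prompt_len=8))  # True
```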
6 changes: 1 addition & 5 deletions tests/generation/test_framework_agnostic.py
@@ -633,11 +633,7 @@ def test_eos_token_id_int_and_list_beam_search(self):
"do_sample": False,
"num_beams": 3,
}
if is_pt:
expectation = 20
else:
# TODO (joao): fix me
expectation = 13
expectation = 13

tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
text = """Hello, my dog is cute and"""
@@ -800,7 +800,7 @@ def generate_step(pixel_values):

preds, scores = generate_step(pixel_values)

EXPECTED_SCORES = np.array([-0.64145195])
EXPECTED_SCORES = np.array([-0.5956343])
max_diff = np.amax(np.abs(scores - EXPECTED_SCORES))
self.assertLessEqual(max_diff, 1e-4)

