Skip to content

Commit

Permalink
Fix: num_logits_to_keep in composite models (huggingface#33168)
Browse files Browse the repository at this point in the history
* fix

* paligemma
  • Loading branch information
zucchini-nlp authored Sep 3, 2024
1 parent 5663026 commit 7ed9789
Show file tree
Hide file tree
Showing 24 changed files with 130 additions and 36 deletions.
6 changes: 4 additions & 2 deletions src/transformers/models/cohere/modeling_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
Expand Down Expand Up @@ -1169,14 +1169,16 @@ def prepare_inputs_for_generation(
batch_size=batch_size,
)

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
6 changes: 4 additions & 2 deletions src/transformers/models/dbrx/modeling_dbrx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,7 +1382,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
Expand Down Expand Up @@ -1433,14 +1433,16 @@ def prepare_inputs_for_generation(
batch_size=batch_size,
)

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
6 changes: 4 additions & 2 deletions src/transformers/models/gemma/modeling_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -1131,7 +1131,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
Expand Down Expand Up @@ -1181,14 +1181,16 @@ def prepare_inputs_for_generation(
batch_size=batch_size,
)

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
Expand Down
7 changes: 5 additions & 2 deletions src/transformers/models/gemma2/modeling_gemma2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
Expand Down Expand Up @@ -1148,14 +1148,17 @@ def prepare_inputs_for_generation(
cache_position=cache_position,
batch_size=batch_size,
)

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/idefics2/modeling_idefics2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1643,7 +1643,7 @@ def prepare_inputs_for_generation(
past_key_values=None,
attention_mask=None,
inputs_embeds=None,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
past_length = 0
Expand Down Expand Up @@ -1687,6 +1687,9 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

image_hidden_states = kwargs.get("image_hidden_states", None)
if image_hidden_states is not None:
pixel_values = None
Expand All @@ -1703,7 +1706,6 @@ def prepare_inputs_for_generation(
"pixel_values": pixel_values,
"pixel_attention_mask": pixel_attention_mask,
"image_hidden_states": image_hidden_states,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/jetmoe/modeling_jetmoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@ def prepare_inputs_for_generation(
output_router_logits=False,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
Expand All @@ -1373,6 +1373,9 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
Expand All @@ -1381,7 +1384,6 @@ def prepare_inputs_for_generation(
"use_cache": use_cache,
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -1244,7 +1244,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
Expand Down Expand Up @@ -1295,14 +1295,16 @@ def prepare_inputs_for_generation(
batch_size=batch_size,
)

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
Expand Down
10 changes: 10 additions & 0 deletions src/transformers/models/llava/modeling_llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, LlavaCausalLMOutputWithPast]:
r"""
Args:
Expand All @@ -385,6 +386,12 @@ def forward(
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
Expand Down Expand Up @@ -518,6 +525,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
)

logits = outputs[0]
Expand Down Expand Up @@ -558,6 +566,7 @@ def prepare_inputs_for_generation(
pixel_values=None,
attention_mask=None,
cache_position=None,
num_logits_to_keep=None,
**kwargs,
):
# Trigger the new behavior if we have more than image embeddings seq length tokens for images
Expand All @@ -572,6 +581,7 @@ def prepare_inputs_for_generation(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
**kwargs,
)

Expand Down
9 changes: 9 additions & 0 deletions src/transformers/models/llava_next/modeling_llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, LlavaNextCausalLMOutputWithPast]:
r"""
Args:
Expand All @@ -729,6 +730,11 @@ def forward(
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
Expand Down Expand Up @@ -890,6 +896,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
)

logits = outputs[0]
Expand Down Expand Up @@ -931,6 +938,7 @@ def prepare_inputs_for_generation(
image_sizes=None,
attention_mask=None,
cache_position=None,
num_logits_to_keep=None,
**kwargs,
):
legacy_processing = (
Expand All @@ -944,6 +952,7 @@ def prepare_inputs_for_generation(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
cache_position=cache_position,
num_logits_to_keep=num_logits_to_keep,
**kwargs,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
num_logits_to_keep: int = 0,
) -> Union[Tuple, LlavaNextVideoCausalLMOutputWithPast]:
r"""
Args:
Expand All @@ -778,6 +779,10 @@ def forward(
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Expand Down Expand Up @@ -973,6 +978,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
num_logits_to_keep=num_logits_to_keep,
)

logits = outputs[0]
Expand Down Expand Up @@ -1014,6 +1020,7 @@ def prepare_inputs_for_generation(
pixel_values_videos=None,
image_sizes=None,
attention_mask=None,
num_logits_to_keep=None,
**kwargs,
):
if past_key_values is not None:
Expand Down Expand Up @@ -1057,6 +1064,9 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids}

if "num_logits_to_keep" != None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/mistral/modeling_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1087,7 +1087,7 @@ def prepare_inputs_for_generation(
cache_position=None,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
Expand Down Expand Up @@ -1115,14 +1115,16 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
"cache_position": cache_position,
"past_key_values": past_key_values,
"use_cache": use_cache,
"attention_mask": attention_mask,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
Expand Down
6 changes: 4 additions & 2 deletions src/transformers/models/mixtral/modeling_mixtral.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,7 +1348,7 @@ def prepare_inputs_for_generation(
output_router_logits=False,
position_ids=None,
use_cache=True,
num_logits_to_keep=0,
num_logits_to_keep=None,
**kwargs,
):
# If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
Expand All @@ -1373,6 +1373,9 @@ def prepare_inputs_for_generation(
else:
model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases

if num_logits_to_keep is not None:
model_inputs["num_logits_to_keep"] = num_logits_to_keep

model_inputs.update(
{
"position_ids": position_ids,
Expand All @@ -1381,7 +1384,6 @@ def prepare_inputs_for_generation(
"use_cache": use_cache,
"attention_mask": attention_mask,
"output_router_logits": output_router_logits,
"num_logits_to_keep": num_logits_to_keep,
}
)
return model_inputs
Expand Down
Loading

0 comments on commit 7ed9789

Please sign in to comment.