Add position ids in forward pass to opt model (huggingface#33121)
* start working on adding position ids

* add docs

* Refactor modeling_biogpt.py and modeling_opt.py for code consistency

* fix 2 PR comments

* move position_ids to end of args

* remove trailing white space

* add comment with TODO

* fix gradient checkpointing bug

* fixup

* fix a spot where position_ids was missed

* remove _attention_to_position_ids and refactor embedding class

* remove redundant code

---------

Co-authored-by: Avishai Elmakies <[email protected]>
avishaiElmakies and Avishai Elmakies authored Oct 7, 2024
1 parent 1bd604d commit 4953ddf
Showing 2 changed files with 55 additions and 11 deletions.
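
For context, here is a minimal usage sketch of what this change enables: passing explicit `position_ids` into an OPT forward call. The checkpoint name and inputs are illustrative, and the default behaviour when `position_ids` is omitted stays the same.

```python
import torch
from transformers import AutoTokenizer, OPTForCausalLM

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = OPTForCausalLM.from_pretrained("facebook/opt-350m")

inputs = tokenizer(["Hello world"], return_tensors="pt")
attention_mask = inputs["attention_mask"]

# Same derivation the model now performs internally when position_ids is None:
# cumulative sum over the attention mask, with padding positions set to -1.
position_ids = (torch.cumsum(attention_mask, dim=1) * attention_mask - 1).long()

outputs = model(
    input_ids=inputs["input_ids"],
    attention_mask=attention_mask,
    position_ids=position_ids,  # new optional argument introduced by this commit
)
```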
src/transformers/models/biogpt/modeling_biogpt.py (3 changes: 2 additions & 1 deletion)
@@ -47,7 +47,8 @@
_CONFIG_FOR_DOC = "BioGptConfig"


# Copied from transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding with OPT->BioGpt
# copied from transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding with OPT->BioGpt
# TODO @ArthurZucker bring copied from back
class BioGptLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
src/transformers/models/opt/modeling_opt.py (63 changes: 53 additions & 10 deletions)
@@ -72,17 +72,21 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)

def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
def forward(
self,
attention_mask: torch.LongTensor,
past_key_values_length: int = 0,
position_ids: Optional[torch.LongTensor] = None,
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
attention_mask = attention_mask.long()

# create positions depending on attention_mask
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1

# cut positions if `past_key_values_length` is > 0
positions = positions[:, past_key_values_length:]
if position_ids is None:
position_ids = torch.cumsum(attention_mask, dim=1)
position_ids = (position_ids * attention_mask - 1).long()
# cut positions if `past_key_values_length` is > 0
position_ids = position_ids[:, past_key_values_length:]

return super().forward(positions + self.offset)
return super().forward(position_ids + self.offset)


class OPTAttention(nn.Module):
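
For reference, a small sketch of the default `position_ids` computation above, using an illustrative left-padded batch. Padding slots come out as -1, which the `self.offset = 2` shift then maps onto a dedicated embedding index rather than a real position:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                               [1, 1, 1, 1, 1]])  # full-length sequence

position_ids = (torch.cumsum(attention_mask, dim=1) * attention_mask - 1).long()
# tensor([[-1, -1,  0,  1,  2],
#         [ 0,  1,  2,  3,  4]])

# The embedding lookup then uses position_ids + offset, so real positions
# start at index 2 and padding positions land on index 1.
```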
@@ -128,6 +132,8 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
# not needed for eager attention, but required by flash attention, so it is kept to keep the signatures consistent
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

@@ -265,6 +271,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

@@ -346,6 +353,7 @@ def forward(
value_states,
attention_mask,
query_length,
position_ids=position_ids,
dropout=attn_dropout,
is_causal=self.is_causal,
use_top_left_mask=self._flash_attn_uses_top_left_mask,
@@ -392,6 +400,7 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
position_ids: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
Expand Down Expand Up @@ -419,6 +428,7 @@ def forward(
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
past_key_value=past_key_value,
position_ids=position_ids,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
@@ -561,6 +571,11 @@ def _init_weights(self, module):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. For padding, use -1.
[What are position IDs?](../glossary#position-ids)
"""


@@ -627,6 +642,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
r"""
Args:
@@ -674,6 +690,11 @@ def forward(
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. For padding, use -1.
[What are position IDs?](../glossary#position-ids)
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -724,7 +745,13 @@ def forward(
attention_mask, input_shape, inputs_embeds, past_key_values_length
)

pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
if position_ids is None:
position_ids = torch.cumsum(attention_mask, dim=1)
position_ids = (position_ids * attention_mask - 1).long()
# cut positions if `past_key_values_length` is > 0
position_ids = position_ids[:, past_key_values_length:]

pos_embeds = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)

if self.project_in is not None:
inputs_embeds = self.project_in(inputs_embeds)
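
A short sketch of the slicing step above during cached decoding (numbers are illustrative): position ids are computed over the full attention mask and then cut down so only the newly fed tokens are embedded.

```python
import torch

# 4 tokens already in the KV cache, 1 new token being fed to the decoder
attention_mask = torch.ones(1, 5, dtype=torch.long)
past_key_values_length = 4

position_ids = (torch.cumsum(attention_mask, dim=1) * attention_mask - 1).long()
# tensor([[0, 1, 2, 3, 4]])
position_ids = position_ids[:, past_key_values_length:]
# tensor([[4]])  -> only the position of the new token is kept
```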
@@ -773,11 +800,13 @@ def forward(
None,
output_attentions,
use_cache,
position_ids,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=causal_attention_mask,
position_ids=position_ids,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
past_key_value=past_key_value,
output_attentions=output_attentions,
@@ -851,6 +880,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -863,6 +893,7 @@ def forward(
decoder_outputs = self.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -927,6 +958,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
@@ -982,6 +1014,11 @@ def forward(
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Indices of positions of each input sequence token in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`. For padding, use -1.
[What are position IDs?](../glossary#position-ids)
Returns:
@@ -1012,6 +1049,7 @@ def forward(
outputs = self.model.decoder(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
@@ -1047,7 +1085,7 @@ def forward(
)

def prepare_inputs_for_generation(
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, position_ids=None, **kwargs
):
if past_key_values is not None:
past_length = past_key_values[0][0].shape[2]
@@ -1072,6 +1110,7 @@ def prepare_inputs_for_generation(
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"attention_mask": attention_mask,
"position_ids": position_ids,
}
)
return model_inputs
@@ -1131,6 +1170,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1144,6 +1184,7 @@ def forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
@@ -1248,6 +1289,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
position_ids: Optional[torch.LongTensor] = None,
) -> Union[Tuple, QuestionAnsweringModelOutput]:
r"""
start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1298,6 +1340,7 @@ def forward(
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
