Update on "[llama-mm] Add export-friendly tile position embedding"
Summary:

Until a decision is made on whether torchtune takes this
export-friendly version of `TilePositionEmbedding`, we put it under
`extension/llm` so that users can start using it.

Added unit tests to make sure the behavior matches the reference
implementation in torchtune and that export, AOTI, and ET all work properly.
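
For illustration only, a minimal sketch of the kind of parity check described above; the import paths, class names, and shapes below are assumptions, not the actual test code.

# Minimal parity-check sketch; import paths, class names, and shapes are
# illustrative assumptions, not the actual unit test.
import torch

from torchtune.models.clip._position_embeddings import (
    TilePositionalEmbedding as TuneTilePositionalEmbedding,  # reference impl (assumed path)
)
from executorch.extension.llm.modules import (
    TilePositionalEmbedding as ETTilePositionalEmbedding,  # export-friendly impl (assumed path)
)

tune_module = TuneTilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)
et_module = ETTilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)
# Load the reference weights so both modules compute on identical parameters.
et_module.load_state_dict(tune_module.state_dict())

# x: (bsz * n_imgs, n_tiles, n_tokens, embed_dim), aspect_ratio: (bsz * n_imgs, 2)
x = torch.randn(1, 4, 1600, 1280)
aspect_ratio = torch.tensor([[2, 2]])

assert torch.allclose(tune_module(x, aspect_ratio), et_module(x, aspect_ratio))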

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]
larryliu0820 committed Nov 5, 2024
1 parent df66f00 commit b0d9e3f
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion extension/llm/modules/_position_embeddings.py
@@ -188,10 +188,13 @@ def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
torch._check(n_tiles_w >= 1)
torch._check(n_tiles_h <= self.max_num_tiles)
torch._check(n_tiles_w <= self.max_num_tiles)
# TODO: Remove this once pytorch/pytorch#120288 is fixed
padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]

# Add pos encoding to the non-padded tiles.
# We need to clone here to keep this model export-friendly, as the
# reshape below collapses dim 0 and dim 1 into a single dim.
pos_embed = pos_embed.clone()
pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)


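For context, a minimal export sketch under assumed names and shapes (not the actual test); the `torch._check` bounds, the `F.pad` workaround, and the clone before the reshape are what this change relies on to make the data-dependent tile counts traceable.

# Minimal export sketch; module path, class name, and shapes are assumptions.
import torch
from executorch.extension.llm.modules import TilePositionalEmbedding

module = TilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)
x = torch.randn(1, 4, 1600, 1280)       # (bsz * n_imgs, n_tiles, n_tokens, embed_dim)
aspect_ratio = torch.tensor([[2, 2]])   # (bsz * n_imgs, 2)

# torch.export traces the data-dependent tile counts thanks to the
# torch._check bounds and the clone-before-reshape in forward().
ep = torch.export.export(module, (x, aspect_ratio))
print(ep)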