diff --git a/extension/llm/modules/_position_embeddings.py b/extension/llm/modules/_position_embeddings.py
index adcd90c999..0c6a4f6ed9 100644
--- a/extension/llm/modules/_position_embeddings.py
+++ b/extension/llm/modules/_position_embeddings.py
@@ -188,10 +188,13 @@ def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
             torch._check(n_tiles_w >= 1)
             torch._check(n_tiles_h <= self.max_num_tiles)
             torch._check(n_tiles_w <= self.max_num_tiles)
+            # TODO: Remove this once pytorch/pytorch#120288 is fixed
             padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
             pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]
-            # Add pos encoding to the non padded tiles.
+            # We need to do a clone here in order to make this model export
+            # friendly as the reshape is collapsing dim 0 and dim 1 into a
+            # single dim.
             pos_embed = pos_embed.clone()
             pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)
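
Standalone sketch (not part of the patch above): a minimal module, using hypothetical names such as TileEmbedSketch, that mirrors the pad -> slice -> clone -> reshape pattern from the diff so the data-dependent reshape stays export friendly. The export call at the end is illustrative only; exact behavior can vary across PyTorch versions.

import torch
import torch.nn as nn
import torch.nn.functional as F


class TileEmbedSketch(nn.Module):
    """Toy tile positional embedding following the pattern in the diff."""

    def __init__(self, max_num_tiles: int = 4, embed_dim: int = 8):
        super().__init__()
        self.max_num_tiles = max_num_tiles
        self.embed_dim = embed_dim
        self.embedding = nn.Parameter(
            torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim)
        )

    def forward(self, aspect_ratio: torch.Tensor) -> torch.Tensor:
        # aspect_ratio holds (n_tiles_h, n_tiles_w) for a single image.
        n_tiles_h = aspect_ratio[0].item()
        n_tiles_w = aspect_ratio[1].item()
        # Bound the data-dependent sizes so the exporter can reason about them.
        torch._check_is_size(n_tiles_h)
        torch._check_is_size(n_tiles_w)
        torch._check(n_tiles_h >= 1)
        torch._check(n_tiles_w >= 1)
        torch._check(n_tiles_h <= self.max_num_tiles)
        torch._check(n_tiles_w <= self.max_num_tiles)
        n_non_padded_tiles = n_tiles_h * n_tiles_w
        # Pad one extra tile row/column, then slice out the region we need.
        padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
        pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]
        # Clone before reshape: the reshape collapses dim 0 and dim 1 into a
        # single data-dependent dim, so materialize the sliced view first.
        pos_embed = pos_embed.clone()
        return pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)


if __name__ == "__main__":
    # Eager check, then an export attempt with a concrete aspect ratio.
    module = TileEmbedSketch()
    print(module(torch.tensor([2, 3])).shape)  # torch.Size([6, 1, 8])
    exported = torch.export.export(module, (torch.tensor([2, 3]),))
    print(exported)

The clone materializes the sliced view before the reshape collapses its leading dims, which is the behavior the added comment in the diff is calling out.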