diff --git a/extension/llm/modules/README.md b/extension/llm/modules/README.md
new file mode 100644
index 0000000000..3694f8b155
--- /dev/null
+++ b/extension/llm/modules/README.md
@@ -0,0 +1,14 @@
+## Export Friendly Modules
+
+Modules in this directory:
+* Extend `torch.nn.Module`.
+* Are guaranteed to work out of the box with `torch.export.export()` and `torch._export.aot_compile()`.
+* Are guaranteed to work with ExecuTorch.
+
+All modules should be covered by unit tests to make sure they:
+1. give the same output as the reference implementation in PyTorch or torchtune,
+2. are export friendly,
+3. are AOTI friendly,
+4. are ExecuTorch friendly.
+
+Notice that these modules are subject to change (they may be upstreamed to torchtune), so proceed with caution.
diff --git a/extension/llm/modules/__init__.py b/extension/llm/modules/__init__.py
new file mode 100644
index 0000000000..38245bf935
--- /dev/null
+++ b/extension/llm/modules/__init__.py
@@ -0,0 +1,15 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from ._position_embeddings import (
+    replace_tile_positional_embedding,
+    TilePositionalEmbedding,
+)
+
+__all__ = [
+    "TilePositionalEmbedding",
+    "replace_tile_positional_embedding",
+]
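For context, a minimal usage sketch of the module exported from this package (not part of the diff; the sizes are illustrative and mirror the unit tests added further down):

```python
import torch

from executorch.extension.llm.modules import TilePositionalEmbedding

# Instantiate the export-friendly tile positional embedding (illustrative sizes).
tpe = TilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)

# Inputs: (bsz * n_imgs, n_tiles, n_tokens, embed_dim) plus one aspect ratio per image.
x = torch.randn(1, 4, 1600, 1280)
aspect_ratio = torch.tensor([[1, 1]])

# The module is expected to export out of the box.
ep = torch.export.export(tpe, (x, aspect_ratio))
print(ep.module()(x, aspect_ratio).shape)  # torch.Size([1, 4, 1600, 1280])
```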
diff --git a/extension/llm/modules/_position_embeddings.py b/extension/llm/modules/_position_embeddings.py
new file mode 100644
index 0000000000..0c6a4f6ed9
--- /dev/null
+++ b/extension/llm/modules/_position_embeddings.py
@@ -0,0 +1,243 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# A torch.export()-friendly version of torchtune's positional embeddings.
+# Adds torch._check() calls to make sure guards on symints are enforced.
+# See https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/_position_embeddings.py
+
+import logging
+from typing import Any, Dict, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+class TilePositionalEmbedding(nn.Module):
+    """
+    Positional embedding for tiles: different for every tile, the same for every token within a tile.
+
+    Notice that a tile is different from a patch (token). For details, please check the documentation of
+    :class:`torchtune.modules.vision_transformer.VisionTransformer`.
+
+    Args:
+        max_num_tiles (int): The maximum number of tiles an image can be divided into.
+        embed_dim (int): The dimensionality of each tile embedding.
+    """
+
+    def __init__(
+        self,
+        max_num_tiles: int,
+        embed_dim: int,
+    ):
+        super().__init__()
+        self.max_num_tiles = max_num_tiles
+        self.embed_dim = embed_dim
+
+        scale = embed_dim**-0.5
+        self.embedding = nn.Parameter(
+            scale * torch.randn(max_num_tiles, max_num_tiles, 1, embed_dim)
+        )
+        self.gate = nn.Parameter(torch.zeros(1))
+
+        # Register load hook to interpolate positional embeddings
+        self._register_load_state_dict_pre_hook(self._load_state_dict_hook)
+
+    # TODO: Switch to public method after 2.5 is stable
+    @torch.no_grad()
+    def _load_state_dict_hook(
+        self,
+        state_dict: Dict[str, Any],
+        prefix: str,
+        *args: Tuple[Any],
+        **kwargs: Dict[str, Any],
+    ):
+        """
+        Interpolates positional embeddings to accommodate a different number of tiles,
+        in case the model was instantiated with different
+        settings than the checkpoint you are loading the state dict from.
+
+        For more info, check the _resize_position_embedding method.
+
+        Args:
+            state_dict (Dict[str, Any]): The state dict to load.
+            prefix (str): The prefix of the state dict.
+            *args (Tuple[Any]): Additional positional arguments.
+            **kwargs (Dict[str, Any]): Additional keyword arguments.
+
+        Raises:
+            ValueError: if the shape of the loaded embedding is not compatible with the current embedding.
+            ValueError: if max_num_tiles_x, max_num_tiles_y are not equal.
+            ValueError: if after interpolation, the shape of the loaded embedding is not compatible with the current embedding.
+        """
+
+        embedding = state_dict.get(prefix + "embedding")
+
+        if embedding is not None:
+
+            # instantiated pos emb (target shape)
+            (
+                tgt_max_num_tiles_x,
+                tgt_max_num_tiles_y,
+                tgt_num_tokens,
+                tgt_emb,
+            ) = self.embedding.shape
+
+            # ckpt pos emb (input shape)
+            (
+                inpt_max_num_tiles_x,
+                inpt_max_num_tiles_y,
+                inpt_num_tokens,
+                inpt_emb,
+            ) = state_dict[prefix + "embedding"].shape
+
+            # sanity check
+            if inpt_num_tokens != tgt_num_tokens or inpt_emb != tgt_emb:
+                raise ValueError(
+                    "Expected num_tokens and embed_dim of the loaded embedding to match the current embedding,"
+                    f" but found shapes {self.embedding.shape} and {state_dict[prefix + 'embedding'].shape}"
+                )
+
+            if inpt_max_num_tiles_x != inpt_max_num_tiles_y:
+                raise ValueError(
+                    "Expected max_num_tiles_x and max_num_tiles_y of the loaded embedding to be equal,"
+                    f" but found (max_num_tiles_x, max_num_tiles_y, 1, embed_dim) = {state_dict[prefix + 'embedding'].shape}"
+                )
+
+            # resize ckpt to match instantiated shape
+            embedding_new = self._resize_position_embedding(
+                embedding, tgt_max_num_tiles=tgt_max_num_tiles_x
+            )
+
+            # update state dict
+            state_dict[prefix + "embedding"] = embedding_new
+            if embedding_new.shape != self.embedding.shape:
+                raise ValueError(
+                    "Expected the interpolated embedding shape to match the current embedding,"
+                    f" but found shapes {self.embedding.shape} and {embedding_new.shape}"
+                )
+
+    @staticmethod
+    def _resize_position_embedding(
+        embedding: torch.Tensor, tgt_max_num_tiles: int
+    ) -> torch.Tensor:
+        """
+        Interpolates positional embeddings to accommodate a different max_num_tiles. These
+        are the only dimensions that change during interpolation.
+
+        Args:
+            embedding (torch.Tensor): torch.Tensor with shape (max_num_tiles, max_num_tiles, 1, embed_dim).
+            tgt_max_num_tiles (int): The number of tiles to resize to.
+
+        Returns:
+            torch.Tensor: The resized embedding.
+
+        Example:
+            >>> import torch
+            >>> # create dummy embedding
+            >>> embedding = torch.arange(2*2*2*2).reshape(2, 2, 2, 2).float()
+            >>> resized_embed = _resize_position_embedding(embedding, tgt_max_num_tiles=1)
+            >>> print(resized_embed.shape)
+            torch.Size([1, 1, 2, 2])
+        """
+        # move the max_num_tiles dims to the last two dimensions
+        embedding = embedding.permute(2, 3, 0, 1)
+
+        embedding = F.interpolate(
+            embedding,
+            size=(tgt_max_num_tiles, tgt_max_num_tiles),
+            mode="bilinear",
+            align_corners=True,
+        )
+        # permute back to the original dimension order
+        embedding = embedding.permute(2, 3, 0, 1)
+        return embedding
+
+    def forward(self, x: torch.Tensor, aspect_ratio: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, n_tiles, n_tokens, embed_dim).
+            aspect_ratio (torch.Tensor): torch.Tensor with shape (bsz * n_imgs, 2),
+                representing the aspect ratio of the image before tile-cropping, e.g. (2, 1).
+        Returns:
+            torch.Tensor: The input tensor with added positional embeddings.
+        """
+        bsz_and_n_imgs, n_tiles, n_tokens, embed_dim = x.shape
+        torch._check(n_tiles <= self.max_num_tiles)
+
+        for batch_idx, (n_tiles_h, n_tiles_w) in enumerate(aspect_ratio):
+            # When we batch images, all are padded to the same number of tiles.
+            # The aspect_ratio lets us know the non-padded tiles for each image.
+            # We only add positional encoding to those.
+            n_tiles_h = n_tiles_h.item()
+            n_tiles_w = n_tiles_w.item()
+
+            n_non_padded_tiles = int(n_tiles_h * n_tiles_w)
+
+            # We get only the positional encoding for non-padded tiles,
+            # i.e. n_tiles_h, n_tiles_w.
+            torch._check_is_size(n_tiles_h)
+            torch._check_is_size(n_tiles_w)
+            torch._check(n_tiles_h >= 1)
+            torch._check(n_tiles_w >= 1)
+            torch._check(n_tiles_h <= self.max_num_tiles)
+            torch._check(n_tiles_w <= self.max_num_tiles)
+            # TODO: Remove this once pytorch/pytorch#120288 is fixed
+            padded_embedding = F.pad(self.embedding, (0, 0, 0, 0, 0, 1, 0, 1))
+            pos_embed = padded_embedding[:n_tiles_h, :n_tiles_w, :, :]
+
+            # We need to do a clone here in order to make this model export
+            # friendly as the reshape is collapsing dim 0 and dim 1 into a
+            # single dim.
+            pos_embed = pos_embed.clone()
+            pos_embed = pos_embed.reshape(n_non_padded_tiles, 1, self.embed_dim)
+
+            x = F.pad(x, (0, 0, 0, 0, 0, 1, 0, 0))
+            torch._check_is_size(n_non_padded_tiles)
+            torch._check(n_non_padded_tiles < x.size(1))
+            x[batch_idx, :n_non_padded_tiles, :, :] += pos_embed * self.gate.tanh()
+            x = x[:, :n_tiles, :, :]
+
+        return x
+
+
+def replace_tile_positional_embedding(model: nn.Module) -> nn.Module:
+    """
+    Replace the tile positional embedding from torchtune with an export-friendly one.
+    Recursively searches the submodules of the model and replaces the tile positional embedding if found.
+    Args:
+        model (nn.Module): The model to replace the tile positional embedding in.
+
+    Returns:
+        nn.Module: The model after replacing the tile positional embedding.
+
+    """
+    from torchtune.models.clip._position_embeddings import (
+        TilePositionalEmbedding as TuneTilePositionalEmbedding,
+    )
+
+    for name, module in model.named_children():
+        if isinstance(module, TuneTilePositionalEmbedding):
+            logging.info(
+                f"Replacing tile positional embedding in {name} with export-friendly one."
+            )
+            max_num_tiles, _, _, embed_dim = module.embedding.shape
+            mod = TilePositionalEmbedding(
+                max_num_tiles=max_num_tiles,
+                embed_dim=embed_dim,
+            )
+            mod.load_state_dict(module.state_dict())
+            setattr(
+                model,
+                name,
+                mod,
+            )
+        else:
+            replace_tile_positional_embedding(module)
+    return model
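As an illustration of the state-dict load hook above (a sketch, not part of the diff): loading a checkpoint saved with a different `max_num_tiles` bilinearly interpolates the tile embedding to the instantiated shape before it is copied into the module.

```python
import torch

from executorch.extension.llm.modules import TilePositionalEmbedding

# Checkpoint produced with 2x2 tiles, model instantiated for 4x4 tiles.
small = TilePositionalEmbedding(max_num_tiles=2, embed_dim=1280)
large = TilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)

# The pre-hook resizes the (2, 2, 1, 1280) checkpoint embedding to
# (4, 4, 1, 1280), so strict loading succeeds.
large.load_state_dict(small.state_dict())
print(large.embedding.shape)  # torch.Size([4, 4, 1, 1280])
```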
diff --git a/extension/llm/modules/test/__init__.py b/extension/llm/modules/test/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/extension/llm/modules/test/test_position_embeddings.py b/extension/llm/modules/test/test_position_embeddings.py
new file mode 100644
index 0000000000..cf4e7e7f05
--- /dev/null
+++ b/extension/llm/modules/test/test_position_embeddings.py
@@ -0,0 +1,118 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import tempfile
+import unittest
+
+import torch
+from executorch.exir import EdgeCompileConfig, to_edge
+from executorch.extension.llm.modules import (
+    replace_tile_positional_embedding,
+    TilePositionalEmbedding,
+)
+from executorch.runtime import Runtime
+from torch._inductor.package import load_package, package_aoti
+from torchtune.models.clip import TilePositionalEmbedding as TuneTilePositionalEmbedding
+
+
+class TilePositionalEmbeddingTest(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        self.tpe = TilePositionalEmbedding(4, 1280)
+        self.ref_tpe = TuneTilePositionalEmbedding(4, 1280)
+        self.x = torch.randn(1, 4, 1600, 1280)
+        self.aspect_ratio = torch.tensor([[1, 1]])
+        num_tiles_dim = torch.export.Dim("num_tiles", min=1, max=4)
+        num_tokens = torch.export.Dim("num_tokens", min=1, max=1600)
+
+        self.dynamic_shape = {
+            0: 1,  # batch
+            1: num_tiles_dim,  # num tiles
+            2: num_tokens,  # num tokens
+            3: 1280,  # embedding dim
+        }
+
+    def test_tile_positional_embedding_smoke(self):
+        y = self.tpe(self.x, self.aspect_ratio)
+        ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+        self.assertTrue(torch.allclose(y, ref_y))
+
+    def test_tile_positional_embedding_export(self):
+
+        tpe_ep = torch.export.export(
+            self.tpe,
+            (self.x, self.aspect_ratio),
+            dynamic_shapes=(
+                self.dynamic_shape,
+                None,
+            ),  # assuming aspect ratio is static
+        )
+
+        y = tpe_ep.module()(self.x, self.aspect_ratio)
+        ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+        self.assertTrue(torch.allclose(y, ref_y))
+
+    def test_tile_positional_embedding_aoti(self):
+        so = torch._export.aot_compile(
+            self.tpe,
+            args=(self.x, self.aspect_ratio),
+            options={"aot_inductor.package": True},
+            dynamic_shapes=(
+                self.dynamic_shape,
+                None,
+            ),  # assuming aspect ratio is static
+        )
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = package_aoti(os.path.join(tmpdir, "tpe.pt2"), so)
+            tpe_aoti = load_package(path)
+
+            y = tpe_aoti(self.x, self.aspect_ratio)
+            ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+            self.assertTrue(torch.allclose(y, ref_y))
+
+    def test_tile_positional_embedding_et(self):
+        tpe_ep = torch.export.export(
+            self.tpe,
+            (self.x, self.aspect_ratio),
+            dynamic_shapes=(
+                self.dynamic_shape,
+                None,
+            ),  # assuming aspect ratio is static
+        )
+        et_program = to_edge(
+            tpe_ep,
+            compile_config=EdgeCompileConfig(
+                _core_aten_ops_exception_list=[
+                    torch.ops.aten.sym_constrain_range_for_size.default,
+                    torch.ops.aten._assert_scalar.default,
+                    torch.ops.aten._local_scalar_dense.default,
+                ]
+            ),
+        ).to_executorch()
+        runtime = Runtime.get()
+        program = runtime.load_program(et_program.buffer)
+        method = program.load_method("forward")
+        y = method.execute((self.x, self.aspect_ratio))
+        ref_y = self.ref_tpe(self.x, self.aspect_ratio)
+
+        self.assertTrue(torch.allclose(y[0], ref_y))
+
+    def test_replace_tile_positional_embedding(self):
+        class Module(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.tpe = TuneTilePositionalEmbedding(4, 1280)
+
+            def forward(self, x, aspect_ratio):
+                return self.tpe(x, aspect_ratio)
+
+        m = Module()
+        m = replace_tile_positional_embedding(m)
+        self.assertTrue(isinstance(m.tpe, TilePositionalEmbedding))
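For context on the `_core_aten_ops_exception_list` in the ExecuTorch test above: the `torch._check()` / `torch._check_is_size()` guards in `forward()` surface as assertion ops in the exported graph, and those ops are not core ATen, so `to_edge()` must be told to allow them. A quick way to inspect them (a sketch; the exact op set can vary across PyTorch versions):

```python
import torch

from executorch.extension.llm.modules import TilePositionalEmbedding

tpe = TilePositionalEmbedding(max_num_tiles=4, embed_dim=1280)
ep = torch.export.export(tpe, (torch.randn(1, 4, 1600, 1280), torch.tensor([[1, 1]])))

# Assertion ops emitted for the data-dependent guards; these are the ops
# listed on the exception list in the test above.
assert_ops = {
    torch.ops.aten.sym_constrain_range_for_size.default,
    torch.ops.aten._assert_scalar.default,
    torch.ops.aten._local_scalar_dense.default,
}
found = [n.target for n in ep.graph.nodes if n.op == "call_function" and n.target in assert_ops]
print(found)
```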
diff --git a/pytest.ini b/pytest.ini
index 3666c9c879..a5041504ae 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -38,6 +38,7 @@ addopts =
     # backends/xnnpack
     backends/xnnpack/test
    # extension/
+    extension/llm/modules/test
    extension/pybindings/test
    # Runtime
    runtime
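With the pytest.ini entry above, the new test directory is picked up by the default test run. To target just these tests programmatically (a hypothetical invocation, assuming pytest is installed and run from the repo root):

```python
# Run only the new position-embedding tests; equivalent to invoking pytest
# on the path added to pytest.ini above.
import sys

import pytest

sys.exit(pytest.main(["extension/llm/modules/test/test_position_embeddings.py", "-v"]))
```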