diff --git a/.ci/scripts/gather_test_models.py b/.ci/scripts/gather_test_models.py index 5f4fe7ab41..e22e196567 100755 --- a/.ci/scripts/gather_test_models.py +++ b/.ci/scripts/gather_test_models.py @@ -24,6 +24,7 @@ "ic4": "linux.12xlarge", "resnet50": "linux.12xlarge", "llava": "linux.12xlarge", + "llama3_2_vision_encoder": "linux.12xlarge", # This one causes timeout on smaller runner, the root cause is unclear (T161064121) "dl3": "linux.12xlarge", "emformer_join": "linux.12xlarge", diff --git a/examples/models/__init__.py b/examples/models/__init__.py index 80f95af89e..d3f2a74f4d 100644 --- a/examples/models/__init__.py +++ b/examples/models/__init__.py @@ -18,6 +18,7 @@ "emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"), "llama2": ("llama", "Llama2Model"), "llama": ("llama", "Llama2Model"), + "llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"), "lstm": ("lstm", "LSTMModel"), "mobilebert": ("mobilebert", "MobileBertModelExample"), "mv2": ("mobilenet_v2", "MV2Model"), diff --git a/examples/models/llama3_2_vision/vision_encoder/__init__.py b/examples/models/llama3_2_vision/vision_encoder/__init__.py new file mode 100644 index 0000000000..f08fb2c260 --- /dev/null +++ b/examples/models/llama3_2_vision/vision_encoder/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .model import FlamingoVisionEncoderModel, VisionEncoderConfig + +__all__ = [ + "FlamingoVisionEncoderModel", + "VisionEncoderConfig", +] diff --git a/examples/models/llama3_2_vision/vision_encoder/model.py b/examples/models/llama3_2_vision/vision_encoder/model.py new file mode 100644 index 0000000000..1cfe74db76 --- /dev/null +++ b/examples/models/llama3_2_vision/vision_encoder/model.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field +from typing import Optional + +import torch + +from executorch.examples.models.model_base import EagerModelBase +from executorch.extension.llm.modules._position_embeddings import ( + replace_tile_positional_embedding, + replace_tiled_token_positional_embedding, +) +from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder + +max_seq_len = 8192 +in_channels = 3 +tile_size = 560 +max_num_tiles = 4 +# how many tokens per image generated by the vision encoder +tokens_per_image = 6404 +# how many images to cache in the kv cache in cross attention +kv_cache_image_num = 1 +# maximum number of tokens generated by encoder and thus stored in the kv cache in cross attention +encoder_max_seq_len = tokens_per_image * kv_cache_image_num + + +@dataclass +class VisionEncoderConfig: + patch_size: int = 14 + num_heads: int = 16 + clip_embed_dim: int = 1280 + clip_num_layers: int = 32 + clip_hidden_states: list[int] = field(default_factory=lambda: [3, 7, 15, 23, 30]) + decoder_embed_dim: int = 4096 + num_layers_projection: int = 8 + tile_size: int = 560 + max_num_tiles: int = 4 + in_channels: int = 3 + + +class FlamingoVisionEncoderModel(EagerModelBase): + def __init__(self, config: Optional[VisionEncoderConfig] = None): + super().__init__() + if config is None: + config = VisionEncoderConfig() + self.config = config + self.model = llama3_2_vision_encoder( + patch_size=config.patch_size, + num_heads=config.num_heads, + clip_embed_dim=config.clip_embed_dim, + clip_num_layers=config.clip_num_layers, + clip_hidden_states=config.clip_hidden_states, + decoder_embed_dim=config.decoder_embed_dim, + num_layers_projection=config.num_layers_projection, + tile_size=config.tile_size, + max_num_tiles=config.max_num_tiles, + in_channels=config.in_channels, + ) + self.model = replace_tile_positional_embedding(self.model) + self.model = replace_tiled_token_positional_embedding(self.model) + self.image = torch.randn( + 1, 1, 4, 3, self.config.tile_size, self.config.tile_size + ) + self.aspect_ratio = torch.tensor([[[1, 2]]]) + self.sample_inputs = ( + self.image, + self.aspect_ratio, + ) + + def get_eager_model(self, **kwargs): + return self.model + + def get_example_inputs(self): + return self.sample_inputs + + def get_dynamic_shapes(self): + dim = torch.export.Dim("num_tiles", min=1, max=self.config.max_num_tiles) + image_dynamic_dim = { + 0: 1, + 1: 1, + 2: dim, + 3: 3, + 4: self.config.tile_size, + 5: self.config.tile_size, + } + return (image_dynamic_dim, None) diff --git a/examples/models/llama3_2_vision/vision_encoder/test/__init__.py b/examples/models/llama3_2_vision/vision_encoder/test/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py b/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py new file mode 100644 index 0000000000..c2f1e77cee --- /dev/null +++ b/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# Export and ExecuTorch tests for CLIP vision encoder are covered by test_models.sh. +# Only test AOTI in this file +import os +import tempfile +import unittest + +import torch + +from executorch.examples.models.llama3_2_vision.vision_encoder import ( + FlamingoVisionEncoderModel, +) +from torch.testing import assert_close + + +class FlamingoVisionEncoderTest(unittest.TestCase): + def setUp(self) -> None: + super().setUp() + + def test_flamingo_vision_encoder(self) -> None: + model = FlamingoVisionEncoderModel() + encoder = model.model + eager_res = encoder.forward(*model.get_example_inputs()) + + # AOTI + ep = torch.export.export( + encoder, + model.get_example_inputs(), + dynamic_shapes=model.get_dynamic_shapes(), + ) + with tempfile.TemporaryDirectory() as tmpdir: + path = torch._inductor.aoti_compile_and_package( + ep, + model.get_example_inputs(), + package_path=os.path.join(tmpdir, "vision_encoder.pt2"), + ) + print(path) + encoder_aoti = torch._inductor.aoti_load_package(path) + + y = encoder_aoti(*model.get_example_inputs()) + assert_close(y, eager_res, rtol=1e-4, atol=1e-4) diff --git a/pytest.ini b/pytest.ini index a5041504ae..25204b7dc7 100644 --- a/pytest.ini +++ b/pytest.ini @@ -16,7 +16,8 @@ addopts = devtools/ # examples examples/models/llama/tests - examples/models/llama3_2_vision/preprocess + examples/models/llama3_2_vision/preprocess/test + examples/models/llama3_2_vision/vision_encoder/test # examples/models/llava/test TODO: enable this # exir exir/_serialize/test