From d2425d6a49749d0dad36db27e4de8582c0b48bfe Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Mon, 4 Nov 2024 17:26:48 -0800
Subject: [PATCH] [llama-mm] Onboard Llama3.2 mm vision encoder

Summary:
Add llama3.2 mm vision encoder to examples/models. We need to do module
swapping for TilePositionalEmbedding to make sure the vision encoder is
exportable.

Test Plan:
Unit tests.

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 4c2a30e6d6b5932a972c34778fea8b3152372e58
Pull Request resolved: https://github.com/pytorch/executorch/pull/6653
---
 examples/models/__init__.py                  |  1 +
 .../vision_encoder/__init__.py               | 12 +++
 .../llama3_2_vision/vision_encoder/model.py  | 85 +++++++++++++++++++
 .../vision_encoder/test/__init__.py          |  0
 .../test/test_vision_encoder.py              | 45 ++++++++++
 pytest.ini                                   |  3 +-
 6 files changed, 145 insertions(+), 1 deletion(-)
 create mode 100644 examples/models/llama3_2_vision/vision_encoder/__init__.py
 create mode 100644 examples/models/llama3_2_vision/vision_encoder/model.py
 create mode 100644 examples/models/llama3_2_vision/vision_encoder/test/__init__.py
 create mode 100644 examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py

diff --git a/examples/models/__init__.py b/examples/models/__init__.py
index 80f95af89e..d3f2a74f4d 100644
--- a/examples/models/__init__.py
+++ b/examples/models/__init__.py
@@ -18,6 +18,7 @@
     "emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
     "llama2": ("llama", "Llama2Model"),
     "llama": ("llama", "Llama2Model"),
+    "llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
     "lstm": ("lstm", "LSTMModel"),
     "mobilebert": ("mobilebert", "MobileBertModelExample"),
     "mv2": ("mobilenet_v2", "MV2Model"),
diff --git a/examples/models/llama3_2_vision/vision_encoder/__init__.py b/examples/models/llama3_2_vision/vision_encoder/__init__.py
new file mode 100644
index 0000000000..f08fb2c260
--- /dev/null
+++ b/examples/models/llama3_2_vision/vision_encoder/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .model import FlamingoVisionEncoderModel, VisionEncoderConfig
+
+__all__ = [
+    "FlamingoVisionEncoderModel",
+    "VisionEncoderConfig",
+]
diff --git a/examples/models/llama3_2_vision/vision_encoder/model.py b/examples/models/llama3_2_vision/vision_encoder/model.py
new file mode 100644
index 0000000000..56a6733071
--- /dev/null
+++ b/examples/models/llama3_2_vision/vision_encoder/model.py
@@ -0,0 +1,85 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass, field
+
+import torch
+
+from executorch.examples.models.model_base import EagerModelBase
+from executorch.extension.llm.modules._position_embeddings import (
+    replace_tile_positional_embedding,
+)
+from torchtune.models.flamingo._component_builders import flamingo_vision_encoder
+
+max_seq_len = 8192
+in_channels = 3
+tile_size = 560
+max_num_tiles = 4
+# how many tokens per image generated by the vision encoder
+tokens_per_image = 6404
+# how many images to cache in the kv cache in cross attention
+kv_cache_image_num = 1
+# maximum number of tokens generated by encoder and thus stored in the kv cache in cross attention
+encoder_max_seq_len = tokens_per_image * kv_cache_image_num
+
+
+@dataclass
+class VisionEncoderConfig:
+    patch_size: int = 14
+    num_heads: int = 16
+    clip_embed_dim: int = 1280
+    clip_num_layers: int = 32
+    clip_hidden_states: list[int] = field(default_factory=lambda: [3, 7, 15, 23, 30])
+    decoder_embed_dim: int = 4096
+    num_layers_projection: int = 8
+    tile_size: int = 560
+    max_num_tiles: int = 4
+    in_channels: int = 3
+
+
+class FlamingoVisionEncoderModel(EagerModelBase):
+    def __init__(self, config: VisionEncoderConfig = VisionEncoderConfig()):
+        super().__init__()
+        self.config = config
+        self.model = flamingo_vision_encoder(
+            patch_size=config.patch_size,
+            num_heads=config.num_heads,
+            clip_embed_dim=config.clip_embed_dim,
+            clip_num_layers=config.clip_num_layers,
+            clip_hidden_states=config.clip_hidden_states,
+            decoder_embed_dim=config.decoder_embed_dim,
+            num_layers_projection=config.num_layers_projection,
+            tile_size=config.tile_size,
+            max_num_tiles=config.max_num_tiles,
+            in_channels=config.in_channels,
+        )
+        self.image = torch.randn(
+            1, 1, 4, 3, self.config.tile_size, self.config.tile_size
+        )
+        self.aspect_ratio = torch.tensor([[[1, 2]]])
+        self.sample_inputs = (
+            self.image,
+            self.aspect_ratio,
+        )
+
+    def get_eager_model(self, **kwargs):
+        self.model = replace_tile_positional_embedding(self.model)
+        return self.model
+
+    def get_example_inputs(self):
+        return self.sample_inputs
+
+    def get_dynamic_shapes(self):
+        dim = torch.export.Dim("num_tiles", min=1, max=self.config.max_num_tiles)
+        image_dynamic_dim = {
+            0: 1,
+            1: 1,
+            2: dim,
+            3: 3,
+            4: self.config.tile_size,
+            5: self.config.tile_size,
+        }
+        return (image_dynamic_dim, None)
diff --git a/examples/models/llama3_2_vision/vision_encoder/test/__init__.py b/examples/models/llama3_2_vision/vision_encoder/test/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py b/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py
new file mode 100644
index 0000000000..89800aa244
--- /dev/null
+++ b/examples/models/llama3_2_vision/vision_encoder/test/test_vision_encoder.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Export and ExecuTorch tests for CLIP vision encoder are covered by test_models.sh.
+# Only test AOTI in this file
+import os
+import tempfile
+import unittest
+
+import torch
+
+from executorch.examples.models.llama3_2_vision.vision_encoder import (
+    FlamingoVisionEncoderModel,
+    VisionEncoderConfig,
+)
+from torch._inductor.package import load_package, package_aoti
+
+
+class FlamingoVisionEncoderTest(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+    def test_flamingo_vision_encoder(self) -> None:
+        model = FlamingoVisionEncoderModel(VisionEncoderConfig())
+        encoder = model.model
+        eager_res = encoder.forward(*model.get_example_inputs())
+
+        # AOTI
+        so = torch._export.aot_compile(
+            encoder,
+            model.get_example_inputs(),
+            options={"aot_inductor.package": True},
+            dynamic_shapes=model.get_dynamic_shapes(),
+        )
+        with tempfile.TemporaryDirectory() as tmpdir:
+            path = package_aoti(os.path.join(tmpdir, "vision_encoder.pt2"), so)
+            print(path)
+            encoder_aoti = load_package(path)
+
+            y = encoder_aoti(*model.get_example_inputs())
+
+        self.assertTrue(torch.allclose(y, eager_res))
diff --git a/pytest.ini b/pytest.ini
index a5041504ae..25204b7dc7 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -16,7 +16,8 @@ addopts =
     devtools/
     # examples
     examples/models/llama/tests
-    examples/models/llama3_2_vision/preprocess
+    examples/models/llama3_2_vision/preprocess/test
+    examples/models/llama3_2_vision/vision_encoder/test
    # examples/models/llava/test TODO: enable this
    # exir
    exir/_serialize/test
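
Usage sketch (not part of this patch): the module swap described in the summary
and the dynamic-shape spec returned by get_dynamic_shapes() come together as
follows when exporting the encoder with torch.export. Only APIs added in this
diff plus stock torch.export are used; torchtune and the ExecuTorch
extension/llm modules are assumed to be installed.

    import torch

    from executorch.examples.models.llama3_2_vision.vision_encoder import (
        FlamingoVisionEncoderModel,
        VisionEncoderConfig,
    )

    model = FlamingoVisionEncoderModel(VisionEncoderConfig())

    # get_eager_model() calls replace_tile_positional_embedding(), swapping
    # torchtune's tile positional embedding for the export-friendly module.
    encoder = model.get_eager_model()

    # Dim 2 of the image input ("num_tiles") is dynamic (1..max_num_tiles);
    # the aspect-ratio input stays static, hence the trailing None in the spec.
    ep = torch.export.export(
        encoder,
        model.get_example_inputs(),
        dynamic_shapes=model.get_dynamic_shapes(),
    )
    print(ep)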
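
The pytest.ini change adds the new test directory to the default pytest
collection. A minimal way to run the same test directly with unittest (class
name and module path taken from the diff above):

    import unittest

    from executorch.examples.models.llama3_2_vision.vision_encoder.test.test_vision_encoder import (
        FlamingoVisionEncoderTest,
    )

    # Build and run a suite containing only the AOTI round-trip test.
    suite = unittest.TestLoader().loadTestsFromTestCase(FlamingoVisionEncoderTest)
    unittest.TextTestRunner(verbosity=2).run(suite)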