[llama-mm] Onboard Llama3.2 mm vision encoder
Summary: Add the Llama 3.2 multimodal (mm) vision encoder to examples/models.

We need to do module swapping for TilePositionEmbedding to make sure the
vision encoder is exportable.
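
For illustration only (not part of this commit): a minimal sketch of how the swapped encoder could be exported, reusing the helpers this diff adds (FlamingoVisionEncoderModel, get_eager_model, get_example_inputs, get_dynamic_shapes) together with torch.export:

    import torch

    from executorch.examples.models.llama3_2_vision.vision_encoder import (
        FlamingoVisionEncoderModel,
    )

    model = FlamingoVisionEncoderModel()
    # get_eager_model() applies the TilePositionEmbedding module swap.
    encoder = model.get_eager_model()
    ep = torch.export.export(
        encoder,
        model.get_example_inputs(),
        # Only the number-of-tiles dimension is dynamic (1..max_num_tiles).
        dynamic_shapes=model.get_dynamic_shapes(),
    )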

Test Plan: Unit tests.

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 4c2a30e6d6b5932a972c34778fea8b3152372e58
Pull Request resolved: #6653
larryliu0820 committed Nov 5, 2024
1 parent 95ffb45 commit d2425d6
Showing 6 changed files with 145 additions and 1 deletion.
1 change: 1 addition & 0 deletions examples/models/__init__.py
@@ -18,6 +18,7 @@
"emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
"llama2": ("llama", "Llama2Model"),
"llama": ("llama", "Llama2Model"),
"llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
"lstm": ("lstm", "LSTMModel"),
"mobilebert": ("mobilebert", "MobileBertModelExample"),
"mv2": ("mobilenet_v2", "MV2Model"),
12 changes: 12 additions & 0 deletions examples/models/llama3_2_vision/vision_encoder/__init__.py
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import FlamingoVisionEncoderModel, VisionEncoderConfig

__all__ = [
    "FlamingoVisionEncoderModel",
    "VisionEncoderConfig",
]
85 changes: 85 additions & 0 deletions examples/models/llama3_2_vision/vision_encoder/model.py
@@ -0,0 +1,85 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field

import torch

from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules._position_embeddings import (
    replace_tile_positional_embedding,
)
from torchtune.models.flamingo._component_builders import flamingo_vision_encoder

max_seq_len = 8192
in_channels = 3
tile_size = 560
max_num_tiles = 4
# how many tokens per image generated by the vision encoder
tokens_per_image = 6404
# how many images to cache in the kv cache in cross attention
kv_cache_image_num = 1
# maximum number of tokens generated by encoder and thus stored in the kv cache in cross attention
encoder_max_seq_len = tokens_per_image * kv_cache_image_num


@dataclass
class VisionEncoderConfig:
    patch_size: int = 14
    num_heads: int = 16
    clip_embed_dim: int = 1280
    clip_num_layers: int = 32
    clip_hidden_states: list[int] = field(default_factory=lambda: [3, 7, 15, 23, 30])
    decoder_embed_dim: int = 4096
    num_layers_projection: int = 8
    tile_size: int = 560
    max_num_tiles: int = 4
    in_channels: int = 3


class FlamingoVisionEncoderModel(EagerModelBase):
    def __init__(self, config: VisionEncoderConfig = VisionEncoderConfig()):
        super().__init__()
        self.config = config
        self.model = flamingo_vision_encoder(
            patch_size=config.patch_size,
            num_heads=config.num_heads,
            clip_embed_dim=config.clip_embed_dim,
            clip_num_layers=config.clip_num_layers,
            clip_hidden_states=config.clip_hidden_states,
            decoder_embed_dim=config.decoder_embed_dim,
            num_layers_projection=config.num_layers_projection,
            tile_size=config.tile_size,
            max_num_tiles=config.max_num_tiles,
            in_channels=config.in_channels,
        )
        # Example input: (batch, num_images, num_tiles, channels, tile_size, tile_size).
        self.image = torch.randn(
            1, 1, 4, 3, self.config.tile_size, self.config.tile_size
        )
        self.aspect_ratio = torch.tensor([[[1, 2]]])
        self.sample_inputs = (
            self.image,
            self.aspect_ratio,
        )

    def get_eager_model(self, **kwargs):
        # Swap the tile positional embedding module for an export-friendly implementation.
        self.model = replace_tile_positional_embedding(self.model)
        return self.model

    def get_example_inputs(self):
        return self.sample_inputs

    def get_dynamic_shapes(self):
        # Only the number of tiles (dim 2 of the image input) is dynamic.
        dim = torch.export.Dim("num_tiles", min=1, max=self.config.max_num_tiles)
        image_dynamic_dim = {
            0: 1,
            1: 1,
            2: dim,
            3: 3,
            4: self.config.tile_size,
            5: self.config.tile_size,
        }
        return (image_dynamic_dim, None)
Empty file.
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Export and ExecuTorch tests for the CLIP vision encoder are covered by test_models.sh.
# Only AOTI is tested in this file.
import os
import tempfile
import unittest

import torch

from executorch.examples.models.llama3_2_vision.vision_encoder import (
    FlamingoVisionEncoderModel,
    VisionEncoderConfig,
)
from torch._inductor.package import load_package, package_aoti


class FlamingoVisionEncoderTest(unittest.TestCase):
    def setUp(self) -> None:
        super().setUp()

    def test_flamingo_vision_encoder(self) -> None:
        model = FlamingoVisionEncoderModel(VisionEncoderConfig())
        encoder = model.model
        eager_res = encoder.forward(*model.get_example_inputs())

        # AOTI: compile, package, reload, and compare against the eager output.
        so = torch._export.aot_compile(
            encoder,
            model.get_example_inputs(),
            options={"aot_inductor.package": True},
            dynamic_shapes=model.get_dynamic_shapes(),
        )
        with tempfile.TemporaryDirectory() as tmpdir:
            path = package_aoti(os.path.join(tmpdir, "vision_encoder.pt2"), so)
            print(path)
            encoder_aoti = load_package(path)

            y = encoder_aoti(*model.get_example_inputs())

        self.assertTrue(torch.allclose(y, eager_res))
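
For reference only (not part of this diff): a hedged sketch of what the ExecuTorch lowering exercised by test_models.sh might look like, assuming the standard executorch.exir to_edge/to_executorch flow; the exact flags used by test_models.sh may differ, and the output filename here is purely illustrative:

    import torch

    from executorch.examples.models.llama3_2_vision.vision_encoder import (
        FlamingoVisionEncoderModel,
    )
    from executorch.exir import to_edge

    model = FlamingoVisionEncoderModel()
    ep = torch.export.export(
        model.get_eager_model(),
        model.get_example_inputs(),
        dynamic_shapes=model.get_dynamic_shapes(),
    )
    # Lower to an ExecuTorch program and serialize it to a .pte file.
    et_program = to_edge(ep).to_executorch()
    with open("llama3_2_vision_encoder.pte", "wb") as f:
        f.write(et_program.buffer)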
3 changes: 2 additions & 1 deletion pytest.ini
@@ -16,7 +16,8 @@ addopts =
    devtools/
    # examples
    examples/models/llama/tests
    examples/models/llama3_2_vision/preprocess
    examples/models/llama3_2_vision/preprocess/test
    examples/models/llama3_2_vision/vision_encoder/test
    # examples/models/llava/test TODO: enable this
    # exir
    exir/_serialize/test
