[llama-mm] Onboard torchtune vision encoder to ExecuTorch/AOTI (#6807)
[llama-mm] Onboard torchtune vision encoder to ExecuTorch

Summary: As titled. This PR adds `llama3_2_vision_encoder` to
`examples/models/llama3_2_vision/vision_encoder` and adds CI jobs.

Test Plan: Rely on the newly added CI jobs.

Reviewers:

Subscribers:

Tasks:

Tags:

[ghstack-poisoned]

(cherry picked from commit a6cfb03)
larryliu0820 authored and pytorchbot committed Nov 13, 2024
1 parent ee32ea3 commit e714576
Showing 7 changed files with 152 additions and 1 deletion.
1 change: 1 addition & 0 deletions .ci/scripts/gather_test_models.py
@@ -24,6 +24,7 @@
"ic4": "linux.12xlarge",
"resnet50": "linux.12xlarge",
"llava": "linux.12xlarge",
"llama3_2_vision_encoder": "linux.12xlarge",
# This one causes timeout on smaller runner, the root cause is unclear (T161064121)
"dl3": "linux.12xlarge",
"emformer_join": "linux.12xlarge",
1 change: 1 addition & 0 deletions examples/models/__init__.py
@@ -18,6 +18,7 @@
"emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
"llama2": ("llama", "Llama2Model"),
"llama": ("llama", "Llama2Model"),
"llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
"lstm": ("lstm", "LSTMModel"),
"mobilebert": ("mobilebert", "MobileBertModelExample"),
"mv2": ("mobilenet_v2", "MV2Model"),
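For context, the registry entry added above maps a model name to a (package, class) pair under `examples/models`. A hypothetical sketch of how such an entry can be resolved into a model instance follows; the helper name and import root are assumptions for illustration, not the actual ExecuTorch loader.

```python
# Illustrative only: resolve a (package, class_name) registry entry into a
# model instance. The helper name and import root are assumptions.
import importlib


def load_example_model(package: str, class_name: str):
    module = importlib.import_module(f"executorch.examples.models.{package}")
    return getattr(module, class_name)()


# e.g. load_example_model("llama3_2_vision", "FlamingoVisionEncoderModel")
```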
12 changes: 12 additions & 0 deletions examples/models/llama3_2_vision/vision_encoder/__init__.py
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import FlamingoVisionEncoderModel, VisionEncoderConfig

__all__ = [
"FlamingoVisionEncoderModel",
"VisionEncoderConfig",
]
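
A small usage sketch of the exports above; the config values shown are simply the dataclass defaults defined in model.py below, written out for illustration.

```python
# Usage sketch (values are the VisionEncoderConfig defaults, shown explicitly
# for illustration).
from executorch.examples.models.llama3_2_vision.vision_encoder import (
    FlamingoVisionEncoderModel,
    VisionEncoderConfig,
)

config = VisionEncoderConfig(clip_embed_dim=1280, decoder_embed_dim=4096)
model = FlamingoVisionEncoderModel(config)

encoder = model.get_eager_model()                 # torchtune vision encoder module
image, aspect_ratio = model.get_example_inputs()
output = encoder(image, aspect_ratio)             # eager forward pass
```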
90 changes: 90 additions & 0 deletions examples/models/llama3_2_vision/vision_encoder/model.py
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional

import torch

from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules._position_embeddings import (
replace_tile_positional_embedding,
replace_tiled_token_positional_embedding,
)
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder

max_seq_len = 8192
in_channels = 3
tile_size = 560
max_num_tiles = 4
# how many tokens per image generated by the vision encoder
tokens_per_image = 6404
# how many images to cache in the kv cache in cross attention
kv_cache_image_num = 1
# maximum number of tokens generated by encoder and thus stored in the kv cache in cross attention
encoder_max_seq_len = tokens_per_image * kv_cache_image_num


@dataclass
class VisionEncoderConfig:
patch_size: int = 14
num_heads: int = 16
clip_embed_dim: int = 1280
clip_num_layers: int = 32
clip_hidden_states: list[int] = field(default_factory=lambda: [3, 7, 15, 23, 30])
decoder_embed_dim: int = 4096
num_layers_projection: int = 8
tile_size: int = 560
max_num_tiles: int = 4
in_channels: int = 3


class FlamingoVisionEncoderModel(EagerModelBase):
def __init__(self, config: Optional[VisionEncoderConfig] = None):
super().__init__()
if config is None:
config = VisionEncoderConfig()
self.config = config
self.model = llama3_2_vision_encoder(
patch_size=config.patch_size,
num_heads=config.num_heads,
clip_embed_dim=config.clip_embed_dim,
clip_num_layers=config.clip_num_layers,
clip_hidden_states=config.clip_hidden_states,
decoder_embed_dim=config.decoder_embed_dim,
num_layers_projection=config.num_layers_projection,
tile_size=config.tile_size,
max_num_tiles=config.max_num_tiles,
in_channels=config.in_channels,
)
self.model = replace_tile_positional_embedding(self.model)
self.model = replace_tiled_token_positional_embedding(self.model)
self.image = torch.randn(
1, 1, 4, 3, self.config.tile_size, self.config.tile_size
)
self.aspect_ratio = torch.tensor([[[1, 2]]])
self.sample_inputs = (
self.image,
self.aspect_ratio,
)

def get_eager_model(self, **kwargs):
return self.model

def get_example_inputs(self):
return self.sample_inputs

def get_dynamic_shapes(self):
dim = torch.export.Dim("num_tiles", min=1, max=self.config.max_num_tiles)
image_dynamic_dim = {
0: 1,
1: 1,
2: dim,
3: 3,
4: self.config.tile_size,
5: self.config.tile_size,
}
return (image_dynamic_dim, None)
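
The model.py wrapper above only defines the eager model, example inputs, and dynamic shapes; the actual lowering is driven by the export scripts exercised in CI. As a hedged sketch, assuming the generic `executorch.exir` flow with no backend partitioner (the CI pipeline may differ), the wrapper could be lowered to a `.pte` program like this:

```python
# Minimal sketch, assuming the plain exir flow (no delegation); the export
# scripts used by the CI jobs may apply a different pipeline.
import torch
from executorch.exir import to_edge

from executorch.examples.models.llama3_2_vision.vision_encoder import (
    FlamingoVisionEncoderModel,
)

model = FlamingoVisionEncoderModel()
ep = torch.export.export(
    model.get_eager_model(),
    model.get_example_inputs(),
    dynamic_shapes=model.get_dynamic_shapes(),
)
et_program = to_edge(ep).to_executorch()
with open("llama3_2_vision_encoder.pte", "wb") as f:
    f.write(et_program.buffer)
```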
Empty file.
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Export and ExecuTorch tests for CLIP vision encoder are covered by test_models.sh.
# Only test AOTI in this file
import os
import tempfile
import unittest

import torch

from executorch.examples.models.llama3_2_vision.vision_encoder import (
FlamingoVisionEncoderModel,
)
from torch.testing import assert_close


class FlamingoVisionEncoderTest(unittest.TestCase):
def setUp(self) -> None:
super().setUp()

def test_flamingo_vision_encoder(self) -> None:
model = FlamingoVisionEncoderModel()
encoder = model.model
eager_res = encoder.forward(*model.get_example_inputs())

# AOTI
ep = torch.export.export(
encoder,
model.get_example_inputs(),
dynamic_shapes=model.get_dynamic_shapes(),
)
with tempfile.TemporaryDirectory() as tmpdir:
path = torch._inductor.aoti_compile_and_package(
ep,
model.get_example_inputs(),
package_path=os.path.join(tmpdir, "vision_encoder.pt2"),
)
print(path)
encoder_aoti = torch._inductor.aoti_load_package(path)

y = encoder_aoti(*model.get_example_inputs())
assert_close(y, eager_res, rtol=1e-4, atol=1e-4)
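
Because `num_tiles` is exported as a dynamic dimension (1 to `max_num_tiles`), the packaged model should also accept a different tile count; a hypothetical continuation of the test above:

```python
# Hypothetical follow-up (not part of the diff): run the AOTI-compiled encoder
# with two tiles and a matching 1x2 aspect ratio.
image_two_tiles = torch.randn(1, 1, 2, 3, 560, 560)
aspect_ratio_1x2 = torch.tensor([[[1, 2]]])
y_two_tiles = encoder_aoti(image_two_tiles, aspect_ratio_1x2)
```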
3 changes: 2 additions & 1 deletion pytest.ini
@@ -16,7 +16,8 @@ addopts =
devtools/
# examples
examples/models/llama/tests
-examples/models/llama3_2_vision/preprocess
+examples/models/llama3_2_vision/preprocess/test
+examples/models/llama3_2_vision/vision_encoder/test
# examples/models/llava/test TODO: enable this
# exir
exir/_serialize/test
