Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[llama-mm] Onboard torchtune vision encoder to ExecuTorch/AOTI #6807

Merged
merged 2 commits into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .ci/scripts/gather_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"ic4": "linux.12xlarge",
"resnet50": "linux.12xlarge",
"llava": "linux.12xlarge",
"llama3_2_vision_encoder": "linux.12xlarge",
# This one causes timeout on smaller runner, the root cause is unclear (T161064121)
"dl3": "linux.12xlarge",
"emformer_join": "linux.12xlarge",
Expand Down
1 change: 1 addition & 0 deletions examples/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"emformer_join": ("emformer_rnnt", "EmformerRnntJoinerModel"),
"llama2": ("llama", "Llama2Model"),
"llama": ("llama", "Llama2Model"),
"llama3_2_vision_encoder": ("llama3_2_vision", "FlamingoVisionEncoderModel"),
"lstm": ("lstm", "LSTMModel"),
"mobilebert": ("mobilebert", "MobileBertModelExample"),
"mv2": ("mobilenet_v2", "MV2Model"),
Expand Down
12 changes: 12 additions & 0 deletions examples/models/llama3_2_vision/vision_encoder/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .model import FlamingoVisionEncoderModel, VisionEncoderConfig

__all__ = [
"FlamingoVisionEncoderModel",
"VisionEncoderConfig",
]
90 changes: 90 additions & 0 deletions examples/models/llama3_2_vision/vision_encoder/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional

import torch

from executorch.examples.models.model_base import EagerModelBase
from executorch.extension.llm.modules._position_embeddings import (
replace_tile_positional_embedding,
replace_tiled_token_positional_embedding,
)
from torchtune.models.llama3_2_vision._component_builders import llama3_2_vision_encoder

max_seq_len = 8192
in_channels = 3
tile_size = 560
max_num_tiles = 4
# how many tokens per image generated by the vision encoder
tokens_per_image = 6404
# how many images to cache in the kv cache in cross attention
kv_cache_image_num = 1
# maximum number of tokens generated by encoder and thus stored in the kv cache in cross attention
encoder_max_seq_len = tokens_per_image * kv_cache_image_num


@dataclass
class VisionEncoderConfig:
patch_size: int = 14
num_heads: int = 16
clip_embed_dim: int = 1280
clip_num_layers: int = 32
clip_hidden_states: list[int] = field(default_factory=lambda: [3, 7, 15, 23, 30])
decoder_embed_dim: int = 4096
num_layers_projection: int = 8
tile_size: int = 560
max_num_tiles: int = 4
in_channels: int = 3


class FlamingoVisionEncoderModel(EagerModelBase):
def __init__(self, config: Optional[VisionEncoderConfig] = None):
super().__init__()
if config is None:
config = VisionEncoderConfig()
self.config = config
self.model = llama3_2_vision_encoder(
patch_size=config.patch_size,
num_heads=config.num_heads,
clip_embed_dim=config.clip_embed_dim,
clip_num_layers=config.clip_num_layers,
clip_hidden_states=config.clip_hidden_states,
decoder_embed_dim=config.decoder_embed_dim,
num_layers_projection=config.num_layers_projection,
tile_size=config.tile_size,
max_num_tiles=config.max_num_tiles,
in_channels=config.in_channels,
)
self.model = replace_tile_positional_embedding(self.model)
self.model = replace_tiled_token_positional_embedding(self.model)
self.image = torch.randn(
1, 1, 4, 3, self.config.tile_size, self.config.tile_size
)
self.aspect_ratio = torch.tensor([[[1, 2]]])
self.sample_inputs = (
self.image,
self.aspect_ratio,
)

def get_eager_model(self, **kwargs):
return self.model

def get_example_inputs(self):
return self.sample_inputs

def get_dynamic_shapes(self):
dim = torch.export.Dim("num_tiles", min=1, max=self.config.max_num_tiles)
image_dynamic_dim = {
0: 1,
1: 1,
2: dim,
3: 3,
4: self.config.tile_size,
5: self.config.tile_size,
}
return (image_dynamic_dim, None)
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Export and ExecuTorch tests for CLIP vision encoder are covered by test_models.sh.
# Only test AOTI in this file
import os
import tempfile
import unittest

import torch

from executorch.examples.models.llama3_2_vision.vision_encoder import (
FlamingoVisionEncoderModel,
)
from torch._inductor.package import package_aoti
from torch.testing import assert_close


class FlamingoVisionEncoderTest(unittest.TestCase):
def setUp(self) -> None:
super().setUp()

def test_flamingo_vision_encoder(self) -> None:
model = FlamingoVisionEncoderModel()
encoder = model.model
eager_res = encoder.forward(*model.get_example_inputs())

# AOTI
so = torch._export.aot_compile(
encoder,
model.get_example_inputs(),
options={"aot_inductor.package": True},
dynamic_shapes=model.get_dynamic_shapes(),
)
with tempfile.TemporaryDirectory() as tmpdir:
path = package_aoti(os.path.join(tmpdir, "vision_encoder.pt2"), so)
larryliu0820 marked this conversation as resolved.
Show resolved Hide resolved
print(path)
encoder_aoti = torch._inductor.aoti_load_package(path)

y = encoder_aoti(*model.get_example_inputs())
assert_close(y, eager_res, rtol=1e-4, atol=1e-4)
3 changes: 2 additions & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ addopts =
devtools/
# examples
examples/models/llama/tests
examples/models/llama3_2_vision/preprocess
examples/models/llama3_2_vision/preprocess/test
examples/models/llama3_2_vision/vision_encoder/test
# examples/models/llava/test TODO: enable this
# exir
exir/_serialize/test
Expand Down
Loading