Llama-Vision: Enable tracing, refactor generation code (#15005)
- Refactored Llama-Vision demos
  - Implemented prefill/decode wrapper in `vision_generator.py` (see the usage sketch below)
  - Used the new generator wrapper in all demos
- Added `simple_vision_demo.py` for easy testing and e2e perf measurement
- Refactored Llama cross-attention tests
  - Added support for batch>1 xattn cache generation
- Enabled tracing in Llama-Vision (see the capture/replay sketch below)
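
A minimal usage sketch of the new `LlamaVision` wrapper, assembled from the demo diffs below. The `create_multimodal_model` signature, the construction order, and the constructor arguments are taken from the updated demos; the wrapper function itself is illustrative:

```python
import os
from pathlib import Path

from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer

from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision


def build_tt_generator(mesh_device, max_batch_size=1, max_seq_len=512):
    # Program cache and async mode are enabled before model creation,
    # exactly as the refactored demos do.
    mesh_device.enable_program_cache()
    mesh_device.enable_async(True)

    model_args, model = create_multimodal_model(
        mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len
    )
    tokenizer = Tokenizer(model_path=str(Path(os.environ["LLAMA_DIR"]) / "tokenizer.model"))
    formatter = ChatFormat(tokenizer)
    return LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)
```

The returned generator then stands in for the reference `Llama` object in the demos, serving `chat_completion(...)` with the same `max_gen_len`/`temperature`/`top_p` arguments.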
cglagovichTT authored Nov 14, 2024
1 parent a8da990 commit 758f8c9
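
The tracing change itself is not visible in the two diffs shown below. For orientation, decode-step tracing in ttnn typically follows a capture-once/replay pattern; a rough sketch under assumed names (`decode_forward` and `decode_input` are illustrative, not this commit's actual code):

```python
import ttnn


def traced_decode(mesh_device, decode_forward, decode_input):
    # Warm-up run: compiles kernels into the program cache so the
    # captured trace contains no compilation work.
    _ = decode_forward(decode_input)

    # Capture a single decode iteration into a trace.
    trace_id = ttnn.begin_trace_capture(mesh_device, cq_id=0)
    output = decode_forward(decode_input)
    ttnn.end_trace_capture(mesh_device, trace_id, cq_id=0)

    # Replay: write new inputs into decode_input's device buffer, then
    # re-execute the captured command stream without host-side dispatch.
    ttnn.execute_trace(mesh_device, trace_id, cq_id=0, blocking=True)
    return trace_id, output
```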
Showing 16 changed files with 1,662 additions and 557 deletions.
82 changes: 38 additions & 44 deletions models/demos/llama3/demo/multimodal_demo_chat.py
@@ -8,19 +8,21 @@
from PIL import Image as PIL_Image
from termcolor import cprint

from models.demos.llama3.demo.multimodal_demo_text import create_multimodal_model
import llama_models.llama3.reference_impl.generation as llama_reference_generation
import pytest
import os
import ttnn

import llama_models.llama3.reference_impl.generation as llama_reference_generation
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import ImageMedia, UserMessage

from pkg_resources import resource_filename

IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/"))

import torch
import pytest
import os
import ttnn
from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision
from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model


@pytest.mark.parametrize(
@@ -36,39 +38,36 @@
"target",
("tt", "cpu"),
)
@pytest.mark.parametrize(
"warmup_iters",
(0, 1),
)
def test_llama_multimodal_demo_chat(
mesh_device,
target,
warmup_iters,
temperature: float = 0.5,
top_p: float = 0.9,
max_seq_len: int = 512,
max_batch_size: int = 4,
max_batch_size: int = 1,
max_gen_len: Optional[int] = 200,
model_parallel_size: Optional[int] = None,
):
mesh_device.enable_program_cache()
mesh_device.enable_async(True)
ckpt_dir = os.environ["LLAMA_DIR"]
tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")

logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'")
generator = llama_reference_generation.Llama.build(
ckpt_dir,
tokenizer_path=tokenizer_path,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
model_parallel_size=model_parallel_size,
)

if target == "tt":
if target == "cpu":
generator = llama_reference_generation.Llama.build(
ckpt_dir,
tokenizer_path=tokenizer_path,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
model_parallel_size=model_parallel_size,
)
else:
logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices")
model = create_multimodal_model(generator.args, mesh_device)
generator.model = model
mesh_device.enable_program_cache()
mesh_device.enable_async(True)
model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)
generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)

# image understanding
dialogs = []
@@ -85,26 +84,21 @@ def test_llama_multimodal_demo_chat(
)
],
]
# text only
dialogs += [
[UserMessage(content="what is the recipe of mayonnaise in two sentences?")],
]

print(f"Running text completion on {target}")
for _ in range(warmup_iters + 1):
for dialog in dialogs:
result = generator.chat_completion(
dialog,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
)
for dialog in dialogs:
result = generator.chat_completion(
dialog,
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
)

for msg in dialog:
print(f"{msg.role.capitalize()}: {msg.content}\n")
for msg in dialog:
print(f"{msg.role.capitalize()}: {msg.content}\n")

out_message = result.generation
print(f"> {out_message.role.capitalize()}: {out_message.content}")
for t in out_message.tool_calls:
print(f" Tool call: {t.tool_name} ({t.arguments})")
print("\n==================================\n")
out_message = result.generation
print(f"> {out_message.role.capitalize()}: {out_message.content}")
for t in out_message.tool_calls:
print(f" Tool call: {t.tool_name} ({t.arguments})")
print("\n==================================\n")
64 changes: 25 additions & 39 deletions models/demos/llama3/demo/multimodal_demo_text.py
@@ -8,36 +8,22 @@
from PIL import Image as PIL_Image
from termcolor import cprint

import llama_models.llama3.reference_impl.generation as llama_reference_generation
import pytest
import os
import ttnn

import llama_models.llama3.reference_impl.generation as llama_reference_generation
from llama_models.llama3.api.datatypes import ImageMedia
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.llama3.api.chat_format import ChatFormat


from pkg_resources import resource_filename

IMG_PATH = Path(resource_filename("llama_models", "scripts/resources/"))

import torch
import pytest
import os
import ttnn


def create_multimodal_model(model_args, mesh_device, dtype=ttnn.bfloat16):
from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer
from models.demos.llama3.tt.model_config import TtModelArgs

tt_model_args = TtModelArgs(mesh_device)
checkpoint = torch.load(tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True)
model = CrossAttentionTransformer(
model_args,
mesh_device,
checkpoint,
weight_cache_path=tt_model_args.weight_cache_path(dtype),
dtype=dtype,
configuration=tt_model_args,
)
model.setup_cache(model_args.max_batch_size, torch.float32)
return model
from models.demos.llama3.demo.simple_vision_demo import create_multimodal_model
from models.demos.llama3.tt.multimodal.vision_generator import LlamaVision
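
The helper removed above now lives in `simple_vision_demo.py`. Its new body is not shown on this page; a hedged sketch, assuming it keeps the removed logic but derives its configuration from `TtModelArgs` and returns `(model_args, model)` as the updated call sites expect (the attribute overrides and argument order here are assumptions):

```python
import torch
import ttnn

from models.demos.llama3.tt.model_config import TtModelArgs
from models.demos.llama3.tt.multimodal.llama_vision_model import CrossAttentionTransformer


def create_multimodal_model(mesh_device, max_batch_size, max_seq_len, dtype=ttnn.bfloat16):
    tt_model_args = TtModelArgs(mesh_device)
    # Assumed: demo-facing limits are applied onto the TT configuration.
    tt_model_args.max_batch_size = max_batch_size
    tt_model_args.max_seq_len = max_seq_len
    checkpoint = torch.load(
        tt_model_args.consolidated_weights_path, map_location="cpu", weights_only=True
    )
    model = CrossAttentionTransformer(
        tt_model_args,  # assumed: TtModelArgs stands in for the reference model args
        mesh_device,
        checkpoint,
        weight_cache_path=tt_model_args.weight_cache_path(dtype),
        dtype=dtype,
        configuration=tt_model_args,
    )
    model.setup_cache(tt_model_args.max_batch_size, torch.float32)
    return tt_model_args, model
```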


@pytest.mark.parametrize(
@@ -64,28 +50,30 @@ def test_llama_multimodal_demo_text(
temperature: float = 0.5,
top_p: float = 0.9,
max_seq_len: int = 512,
max_batch_size: int = 4,
max_batch_size: int = 1,
max_gen_len: Optional[int] = 200,
model_parallel_size: Optional[int] = None,
):
mesh_device.enable_program_cache()
mesh_device.enable_async(True)
ckpt_dir = os.environ["LLAMA_DIR"]
tokenizer_path = str(Path(ckpt_dir) / "tokenizer.model")

logger.info(f"Creating reference model from checkpoint in '{ckpt_dir}'")
generator = llama_reference_generation.Llama.build(
ckpt_dir,
tokenizer_path=tokenizer_path,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
model_parallel_size=model_parallel_size,
)

if target == "tt":
if target == "cpu":
generator = llama_reference_generation.Llama.build(
ckpt_dir,
tokenizer_path=tokenizer_path,
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
model_parallel_size=model_parallel_size,
)
else:
logger.info(f"Creating TT model on {len(mesh_device.get_devices())} devices")
model = create_multimodal_model(generator.args, mesh_device)
generator.model = model
mesh_device.enable_program_cache()
mesh_device.enable_async(True)
model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len)
tokenizer = Tokenizer(model_path=tokenizer_path)
formatter = ChatFormat(tokenizer)
generator = LlamaVision(model, model_args, mesh_device, tokenizer=tokenizer, formatter=formatter)

with open(IMG_PATH / "dog.jpg", "rb") as f:
img = PIL_Image.open(f).convert("RGB")
@@ -100,8 +88,6 @@ def test_llama_multimodal_demo_text(
clutter = PIL_Image.open(f).convert("RGB")

interleaved_contents = [
# text only
"The color of the sky is blue but sometimes it can also be",
# image understanding
[ImageMedia(image=img), "If I had to write a haiku for this one"],
[ImageMedia(image=img2), "Couting the number of individual spaghetti strands in this image"],
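The rest of this demo's diff is truncated above. Assuming the generation loop mirrors the chat demo and the reference implementation's text-completion flow, a plausible continuation (the `text_completion` call and the printing are assumptions, not shown on this page):

```python
for content in interleaved_contents:
    result = generator.text_completion(
        content,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p,
    )
    cprint(f"{content}", end="")
    cprint(f"{result.generation}", color="yellow")
    print("\n==================================\n")
```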
(diffs for the remaining 14 changed files not shown)
