Skip to content

Commit

Permalink
Images in Messages (pytorch#1504)
Browse files Browse the repository at this point in the history
  • Loading branch information
joecummings authored Sep 10, 2024
1 parent 6deeda9 commit eb92658
Show file tree
Hide file tree
Showing 15 changed files with 331 additions and 93 deletions.
2 changes: 2 additions & 0 deletions docs/source/api_ref_data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -88,3 +88,5 @@ Miscellaneous helper functions used in modifying data.

validate_messages
truncate
load_image
format_content_with_images
Binary file added tests/assets/dog_on_skateboard.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 5 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,10 +145,14 @@ def tokenize_messages(
return tokenized_messages, mask

def __call__(self, sample: Mapping[str, Any]) -> Mapping[str, Any]:
messages = sample.pop("messages")
messages: List[Message] = sample.pop("messages")
images = []
for message in messages:
images += message.get_media()
tokens, mask = self.tokenize_messages(messages)
sample["tokens"] = tokens
sample["mask"] = mask
sample["images"] = images
return sample

@property
Expand Down
124 changes: 109 additions & 15 deletions tests/torchtune/data/test_data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,20 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import os

import pytest
from PIL import Image

from tests.common import ASSETS
from torchtune.data import (
format_content_with_images,
Message,
PromptTemplate,
split_text_by_image_tag,
truncate,
validate_messages,
)
from torchtune.data._utils import _get_prompt_template
from torchtune.data._utils import _get_prompt_template, load_image
from torchtune.models.llama2 import Llama2ChatTemplate


Expand Down Expand Up @@ -98,47 +103,136 @@ def test_validate_messages():
validate_messages(messages)


def test_split_text_by_image_tag():
def test_format_content_with_images():
test_image_1 = Image.new(mode="RGB", size=(4, 4))
test_image_2 = Image.new(mode="RGB", size=(4, 4))
test_image_3 = Image.new(mode="RGB", size=(4, 4))

# Test single image tag in the middle
text = "hello <image>world"
assert split_text_by_image_tag(text, "<image>") == [
assert format_content_with_images(
text,
image_tag="<image>",
images=[test_image_1],
) == [
{"type": "text", "content": "hello "},
{"type": "image"},
{"type": "image", "content": test_image_1},
{"type": "text", "content": "world"},
]

# Test multiple image tags and image tag in beginning
text = "[image]hello [image]world"
assert split_text_by_image_tag(text, "[image]") == [
{"type": "image"},
assert format_content_with_images(
text,
image_tag="[image]",
images=[test_image_1, test_image_2],
) == [
{"type": "image", "content": test_image_1},
{"type": "text", "content": "hello "},
{"type": "image"},
{"type": "image", "content": test_image_2},
{"type": "text", "content": "world"},
]

# Test an image tag that is not present in the text
text = "hello world"
assert split_text_by_image_tag(text, "asdfghjkl;") == [
assert format_content_with_images(text, image_tag="asdfghjkl;", images=[]) == [
{"type": "text", "content": "hello world"}
]

# Test consecutive image tags
text = "<image><image>hello <image>world"
assert split_text_by_image_tag(text, "<image>") == [
{"type": "image"},
{"type": "image"},
assert format_content_with_images(
text,
image_tag="<image>",
images=[test_image_1, test_image_2, test_image_3],
) == [
{"type": "image", "content": test_image_1},
{"type": "image", "content": test_image_2},
{"type": "text", "content": "hello "},
{"type": "image"},
{"type": "image", "content": test_image_3},
{"type": "text", "content": "world"},
]

# Test image tag at the end
text = "hello <image>"
assert split_text_by_image_tag(text, "<image>") == [
assert format_content_with_images(
text,
image_tag="<image>",
images=[test_image_1],
) == [
{"type": "text", "content": "hello "},
{"type": "image"},
{"type": "image", "content": test_image_1},
]

# Test errors when the number of images does not match the number of image tags
text = "hello <image>world"
with pytest.raises(
ValueError,
match="does not match number of image tags",
):
format_content_with_images(
text, image_tag="<image>", images=[test_image_1, test_image_2]
)


def test_load_image(monkeypatch, tmp_path):
    """Exercise ``load_image`` on a local file, a mocked URL, and failure cases.

    Covers, in order: loading the bundled JPEG asset from disk, loading via a
    patched ``urllib.request.urlopen`` (no network access), a non-existent
    path, a file containing non-image data, a failing ``urlopen``, an
    unreadable file, and a binary file containing non-image data.

    Args:
        monkeypatch: pytest fixture used to replace ``urllib.request.urlopen``.
        tmp_path: pytest fixture providing a per-test temporary directory.
    """
    asset_image = str(ASSETS / "dog_on_skateboard.jpg")

    # Test loading from local file
    image = load_image(asset_image)
    assert isinstance(image, Image.Image)
    assert image.size == (580, 403)

    # Test loading from remote file
    # Mock urlopen to hand back a binary file handle on the local asset.
    # The handles are tracked so they can be closed here; nothing visible in
    # this test closes them otherwise.
    opened_handles = []

    def mock_urlopen(url):
        handle = open(asset_image, "rb")
        opened_handles.append(handle)
        return handle

    monkeypatch.setattr("urllib.request.urlopen", mock_urlopen)
    try:
        image = load_image("http://example.com/test_image.jpg")
        assert isinstance(image, Image.Image)
        assert image.size == (580, 403)
    finally:
        for handle in opened_handles:
            handle.close()

    # Test that a ValueError is raised when the image path is invalid
    with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
        load_image("invalid_path")

    # Test a temporary file with invalid image data
    image_path = tmp_path / "test_image.jpg"
    with open(image_path, "w") as f:
        f.write("Invalid image data")

    # Test that a ValueError is raised when the image data is invalid
    with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
        load_image(str(image_path))

    # Test that a ValueError is raised when there is an HTTP error
    # Mock the urlopen function to raise an exception
    def mock_urlopen(url):
        raise Exception("Failed to load image")

    monkeypatch.setattr("urllib.request.urlopen", mock_urlopen)
    with pytest.raises(ValueError, match="Failed to load image"):
        load_image("http://example.com/test_image.jpg")

    # Test that a ValueError is raised when there is an IO error
    # Create a temporary file that cannot be read
    image_path = tmp_path / "test_image.jpg"
    with open(image_path, "w") as f:
        f.write("Test data")
    os.chmod(image_path, 0o000)  # Remove read permissions
    try:
        # NOTE(review): when run as a privileged user the mode bits may not
        # block reads; the file content is not image data either way.
        with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
            load_image(str(image_path))
    finally:
        # Always restore permissions, even if the assertion above fails, so
        # the file can be rewritten below and cleaned up afterwards.
        os.chmod(image_path, 0o644)  # Restore read permissions

    # Test that a ValueError is raised when invalid image data is read
    # Create a temporary file with invalid image data
    image_path = tmp_path / "test_image.jpg"
    with open(image_path, "wb") as f:
        f.write(b"Invalid image data")
    with pytest.raises(ValueError, match="Failed to open image as PIL.Image"):
        load_image(str(image_path))


def test_get_prompt_template():
template = _get_prompt_template("torchtune.models.llama2.Llama2ChatTemplate")
Expand Down
18 changes: 14 additions & 4 deletions tests/torchtune/data/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# LICENSE file in the root directory of this source tree.

import pytest

from PIL import Image
from tests.test_utils import (
assert_dialogue_equal,
CHAT_SAMPLE,
Expand All @@ -26,17 +28,21 @@ def text_message(self):
return Message(role="user", content="hello world")

@pytest.fixture
def image_message(self):
def test_image(self):
return Image.new(mode="RGB", size=(4, 4))

@pytest.fixture
def image_message(self, test_image):
return Message(
role="user",
content=[
{"type": "text", "content": "hello"},
{"type": "image"},
{"type": "image", "content": test_image},
{"type": "text", "content": " world"},
],
)

def test_message_validation(self, text_message):
def test_message_validation(self, text_message, test_image):
message = text_message
assert message.role == "user"
assert message.content == [{"type": "text", "content": "hello world"}]
Expand All @@ -53,7 +59,7 @@ def test_message_validation(self, text_message):
):
message = Message(
role="user",
content=[{"type": "image"}],
content=[{"type": "image", "content": test_image}],
ipython=True,
)

Expand All @@ -69,6 +75,10 @@ def test_contains_media(self, text_message, image_message):
assert not text_message.contains_media
assert image_message.contains_media

def test_get_media(self, text_message, image_message, test_image):
    """Check ``get_media`` on a text-only message and on a message with one image.

    The text-only fixture must yield an empty list; the image fixture must
    yield a list containing exactly the ``test_image`` fixture object.
    """
    media_from_text = text_message.get_media()
    media_from_image = image_message.get_media()

    assert media_from_text == []
    assert media_from_image == [test_image]

def test_text_content(self, text_message, image_message):
assert text_message.text_content == "hello world"
assert image_message.text_content == "hello world"
Expand Down
27 changes: 23 additions & 4 deletions tests/torchtune/datasets/multimodal/test_llava_instruct_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from collections import Counter
from unittest.mock import patch

import PIL

import pytest
from datasets import Dataset

Expand All @@ -21,11 +23,22 @@ class TestLLaVAInstructDataset:
def tokenizer(self):
return DummyTokenizer()

@pytest.fixture
def test_image_pil(self):
    """Return a fresh 4x4 RGB PIL image for use as the mocked ``load_image`` result."""
    # Import the submodule explicitly: a bare ``import PIL`` does not load
    # ``PIL.Image``, so ``PIL.Image.new`` only works if some other module
    # happened to import it first.
    from PIL import Image

    return Image.new(mode="RGB", size=(4, 4))

@patch("torchtune.datasets._sft.load_dataset")
def test_label_no_masking(self, load_dataset, tokenizer):
@patch("torchtune.datasets.multimodal._llava_instruct.load_image")
def test_label_no_masking(
self, load_image, load_dataset, tokenizer, test_image_pil
):
"""
Test whether the input and the labels are correctly created when the input is not masked.
WARNING: careful with these mocks, they are applied in bottom up order
"""
# mock the call to load_image
load_image.return_value = test_image_pil

# mock the call to HF datasets
load_dataset.return_value = Dataset.from_list(
Expand Down Expand Up @@ -55,6 +68,7 @@ def test_label_no_masking(self, load_dataset, tokenizer):
model_transform=tokenizer,
train_on_input=True,
)

input, labels, images = ds[0]["tokens"], ds[0]["labels"], ds[0]["images"]

expected_count = {
Expand All @@ -76,13 +90,18 @@ def test_label_no_masking(self, load_dataset, tokenizer):

assert Counter(input) == expected_count
assert Counter(labels) == expected_count
assert images == "test_image.jpg"
assert images == [test_image_pil]

@patch("torchtune.datasets._sft.load_dataset")
def test_label_masking(self, load_dataset, tokenizer):
@patch("torchtune.datasets.multimodal._llava_instruct.load_image")
def test_label_masking(self, load_image, load_dataset, tokenizer, test_image_pil):
"""
Test whether the input and the labels are correctly created when the input is masked.
WARNING: careful with these mocks, they are applied in bottom up order
"""
# mock the call to load_image
load_image.return_value = test_image_pil

# mock the call to HF datasets
load_dataset.return_value = Dataset.from_list(
Expand Down Expand Up @@ -133,4 +152,4 @@ def test_label_masking(self, load_dataset, tokenizer):

assert Counter(input) == expected_count
assert labels.count(CROSS_ENTROPY_IGNORE_IDX) == 11
assert images == "test_image.jpg"
assert images == [test_image_pil]
Loading

0 comments on commit eb92658

Please sign in to comment.