Skip to content

Commit

Permalink
Merge branch 'vision-language' into hwoo-image-cli
Browse files Browse the repository at this point in the history
  • Loading branch information
nv-hwoo committed Jul 15, 2024
2 parents 68bacca + 92b2f3d commit a06edf9
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import base64
from enum import Enum, auto
from io import BytesIO
from typing import Optional, cast

import numpy as np
from PIL import Image


class ImageFormat(Enum):
JPEG = auto()
PNG = auto()


class SyntheticImageGenerator:
def __init__(
self,
image_width_mean: int,
image_height_mean: int,
image_width_stddev: int,
image_height_stddev: int,
image_format: ImageFormat = ImageFormat.PNG,
rng: Optional[np.random.Generator] = None,
):
self._image_width_mean = image_width_mean
self._image_height_mean = image_height_mean
self._image_width_stddev = image_width_stddev
self._image_height_stddev = image_height_stddev
self.image_format = image_format
self.rng = cast(np.random.Generator, rng or np.random.default_rng())

def __iter__(self):
return self

def _sample_random_positive_integer(self, mean: int, stddev: int) -> int:
while True:
n = int(self.rng.normal(mean, stddev))
if n > 0:
break
return n

def _get_next_image(self):
width = self._sample_random_positive_integer(
self._image_width_mean, self._image_width_stddev
)
height = self._sample_random_positive_integer(
self._image_height_mean, self._image_height_stddev
)
shape = width, height, 3
noise = self.rng.integers(0, 256, shape, dtype=np.uint8)
return Image.fromarray(noise)

def _encode(self, image):
buffered = BytesIO()
image.save(buffered, format=self.image_format.name)
data = base64.b64encode(buffered.getvalue()).decode("utf-8")
return f"data:image/{self.image_format.name.lower()};base64,{data}"

def __next__(self) -> str:
image = self._get_next_image()
base64_string = self._encode(image)
return base64_string
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import base64
from io import BytesIO

import numpy as np
import pytest
from genai_perf.llm_inputs.synthetic_image_generator import (
ImageFormat,
SyntheticImageGenerator,
)
from PIL import Image


def decode_image(base64_string):
_, data = base64_string.split(",")
decoded_data = base64.b64decode(data)
return Image.open(BytesIO(decoded_data))


@pytest.mark.parametrize(
"expected_image_size",
[
(100, 100),
(200, 200),
],
)
def test_different_image_size(expected_image_size):
expected_width, expected_height = expected_image_size
sut = SyntheticImageGenerator(
image_width_mean=expected_width,
image_height_mean=expected_height,
image_width_stddev=0,
image_height_stddev=0,
)

base64_string = next(sut)
image = decode_image(base64_string)

assert image.size == expected_image_size, "image not resized to the target size"


def test_negative_size_is_not_selected():
sut = SyntheticImageGenerator(
image_width_mean=-1,
image_height_mean=-1,
image_width_stddev=10,
image_height_stddev=10,
)

# exception is raised, when PIL.Image.resize is called with negative values
next(sut)


def test_generator_deterministic():
IMAGE_SIZE = 100, 100
STDDEV = 100, 100
SEED = 44
rng1 = np.random.default_rng(seed=SEED)
rng2 = np.random.default_rng(seed=SEED)
sut1 = SyntheticImageGenerator(*IMAGE_SIZE, *STDDEV, rng=rng1)
sut2 = SyntheticImageGenerator(*IMAGE_SIZE, *STDDEV, rng=rng2)

for _, img1, img2 in zip(range(5), sut1, sut2):
assert img1 == img2, "generator is nondererministic"


@pytest.mark.parametrize("image_format", [ImageFormat.PNG, ImageFormat.JPEG])
def test_base64_encoding_with_different_formats(image_format):
IMAGE_SIZE = 100, 100
STDDEV = 100, 100
sut = SyntheticImageGenerator(*IMAGE_SIZE, *STDDEV, image_format=image_format)

base64String = next(sut)

base64prefix = f"data:image/{image_format.name.lower()};base64,"
assert base64String.startswith(base64prefix), "unexpected prefix"
data = base64String[len(base64prefix) :]

# test if generator encodes to base64
img_data = base64.b64decode(data)
img_bytes = BytesIO(img_data)
# test if an image is encoded
image = Image.open(img_bytes)

assert image.format == image_format.name

0 comments on commit a06edf9

Please sign in to comment.