-
Notifications
You must be signed in to change notification settings - Fork 234
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Synthetic image generator #751
Changes from 21 commits
a8656f8
a5b6dbc
9630704
7915dd7
6176dc4
9f8a426
5178b27
09e81a8
7f5d573
c4e7c35
5673a51
d64cd27
6d5b4ea
edad485
935da0b
d8712eb
b5d4b64
fb6e982
287edba
af0a93a
e0b43fd
1ef4f71
ae66dc3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import base64 | ||
from enum import Enum, auto | ||
from io import BytesIO | ||
from pathlib import Path | ||
from typing import Generator, Optional, Tuple, cast | ||
Check notice Code scanning / CodeQL Unused import Note
Import of 'Generator' is not used.
Import of 'Tuple' is not used. Import of 'cast' is not used. |
||
|
||
import numpy as np | ||
from genai_perf.exceptions import GenAIPerfException | ||
Check notice Code scanning / CodeQL Unused import Note
Import of 'GenAIPerfException' is not used.
|
||
from PIL import Image | ||
|
||
|
||
class ImageFormat(Enum):
    """Image encodings the synthetic image generator can emit.

    The member *name* is what matters to callers in this module: it is
    passed as the ``format`` argument when saving the image and, lowercased,
    forms the MIME subtype of the resulting data URI.
    """

    JPEG = auto()
    PNG = auto()
|
||
|
||
class SyntheticImageGenerator:
    """Endless iterator that yields randomly sized images as base64 data URIs.

    Every ``next()`` call obtains a source image, resizes it to a width and
    height drawn independently from normal distributions, encodes it in
    ``image_format`` and returns ``data:image/<format>;base64,<payload>``.

    Args:
        image_width_mean: Mean of the distribution the width is drawn from.
        image_height_mean: Mean of the distribution the height is drawn from.
        image_width_stddev: Standard deviation of the width distribution.
        image_height_stddev: Standard deviation of the height distribution.
        image_format: Encoding used when serializing each image.
        rng: NumPy random generator used for sampling sizes. Pass a seeded
            generator for reproducible output; a fresh
            ``np.random.default_rng()`` is created when omitted.
    """

    def __init__(
        self,
        image_width_mean: int,
        image_height_mean: int,
        image_width_stddev: int,
        image_height_stddev: int,
        image_format: ImageFormat = ImageFormat.PNG,
        rng: Optional[np.random.Generator] = None,
    ):
        self._image_width_mean = image_width_mean
        self._image_height_mean = image_height_mean
        self._image_width_stddev = image_width_stddev
        self._image_height_stddev = image_height_stddev
        self.image_format = image_format
        # Explicit None check: only a missing generator triggers the default.
        self.rng = rng if rng is not None else np.random.default_rng()

    def __iter__(self):
        """Return ``self``; the generator is its own (never exhausted) iterator."""
        return self

    def _sample_random_positive_integer(self, mean: int, stddev: int) -> int:
        """Draw from a normal distribution until the truncated value is positive.

        Samples are truncated with ``int()`` and rejected while they are
        ``<= 0``, so a mean far below zero relative to ``stddev`` makes the
        loop run for a long time.

        Raises:
            ValueError: If ``stddev`` is zero and ``int(mean)`` is not
                positive. Every draw would then equal ``mean`` and be
                rejected, so the loop could never terminate.
        """
        if stddev == 0 and int(mean) <= 0:
            raise ValueError(
                f"cannot sample a positive integer: mean={mean} with "
                "stddev=0 never yields a value greater than 0"
            )

        while True:
            n = int(self.rng.normal(mean, stddev))
            if n > 0:
                return n

    def _random_resize(self, image):
        """Resize ``image`` to a freshly sampled (width, height)."""
        # Width is sampled before height; the order matters for
        # reproducibility with a seeded generator.
        width = self._sample_random_positive_integer(
            self._image_width_mean, self._image_width_stddev
        )
        height = self._sample_random_positive_integer(
            self._image_height_mean, self._image_height_stddev
        )
        return image.resize((width, height))

    def _get_next_image(self):
        """Return the source image: a plain white 100x100 RGB image."""
        return Image.new("RGB", (100, 100), color="white")

    def _encode(self, image):
        """Serialize ``image`` in ``self.image_format`` as a base64 data URI."""
        buffered = BytesIO()
        image.save(buffered, format=self.image_format.name)
        data = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image/{self.image_format.name.lower()};base64,{data}"

    def __next__(self) -> str:
        """Produce the next randomly sized image as a base64 data URI string."""
        image = self._get_next_image()
        image = self._random_resize(image)
        base64_string = self._encode(image)
        return base64_string
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
import base64 | ||
from io import BytesIO | ||
from pathlib import Path | ||
Check notice Code scanning / CodeQL Unused import Note test
Import of 'Path' is not used.
|
||
from unittest.mock import patch | ||
Check notice Code scanning / CodeQL Unused import Note test
Import of 'patch' is not used.
|
||
|
||
import numpy as np | ||
import pytest | ||
from genai_perf.exceptions import GenAIPerfException | ||
Check notice Code scanning / CodeQL Unused import Note test
Import of 'GenAIPerfException' is not used.
|
||
from genai_perf.llm_inputs.synthetic_image_generator import ( | ||
ImageFormat, | ||
SyntheticImageGenerator, | ||
) | ||
from PIL import Image | ||
|
||
|
||
def decode_image(base64_string):
    """Open the image carried by a ``<header>,<base64 payload>`` data URI.

    The string must contain exactly one comma; anything else fails the
    two-way unpacking with ``ValueError``.
    """
    _header, payload = base64_string.split(",")
    raw_bytes = base64.b64decode(payload)
    stream = BytesIO(raw_bytes)
    return Image.open(stream)
|
||
|
||
@pytest.mark.parametrize(
    "expected_image_size",
    [
        (100, 100),
        (200, 200),
    ],
)
def test_different_image_size(expected_image_size):
    """With zero standard deviation the output size equals the mean size."""
    width, height = expected_image_size
    generator = SyntheticImageGenerator(
        image_width_mean=width,
        image_height_mean=height,
        image_width_stddev=0,
        image_height_stddev=0,
    )

    decoded = decode_image(next(generator))

    assert decoded.size == expected_image_size, "image not resized to the target size"
|
||
|
||
def test_negative_size_is_not_selected():
    """A negative mean must still yield images with positive dimensions."""
    sut = SyntheticImageGenerator(
        image_width_mean=-1,
        image_height_mean=-1,
        image_width_stddev=10,
        image_height_stddev=10,
    )

    # An exception is raised when PIL.Image.resize is called with negative
    # values, so a bad sample already fails inside next(); the explicit
    # assertion below states the expectation instead of relying on that.
    # Several draws are checked because sampling is random.
    for _ in range(10):
        image = decode_image(next(sut))
        width, height = image.size
        assert width > 0 and height > 0, "non-positive image size selected"
|
||
|
||
def test_generator_deterministic():
    """Two generators seeded identically must emit identical image strings."""
    IMAGE_SIZE = 100, 100
    STDDEV = 100, 100
    SEED = 44
    rng1 = np.random.default_rng(seed=SEED)
    rng2 = np.random.default_rng(seed=SEED)
    sut1 = SyntheticImageGenerator(*IMAGE_SIZE, *STDDEV, rng=rng1)
    sut2 = SyntheticImageGenerator(*IMAGE_SIZE, *STDDEV, rng=rng2)

    # The generators never stop on their own; range(5) bounds the comparison
    # to the first five images from each.
    for _, img1, img2 in zip(range(5), sut1, sut2):
        assert img1 == img2, "generator is nondeterministic"
|
||
|
||
@pytest.mark.parametrize("image_format", [ImageFormat.PNG, ImageFormat.JPEG])
def test_base64_encoding_with_different_formats(image_format):
    """The output is a data URI whose payload decodes to the requested format."""
    mean_size = 100, 100
    size_stddev = 100, 100
    generator = SyntheticImageGenerator(
        *mean_size, *size_stddev, image_format=image_format
    )

    encoded = next(generator)

    expected_prefix = f"data:image/{image_format.name.lower()};base64,"
    assert encoded.startswith(expected_prefix), "unexpected prefix"
    payload = encoded[len(expected_prefix) :]

    # Decode the payload as base64 ...
    raw_bytes = base64.b64decode(payload)
    # ... and open the resulting bytes as an image.
    decoded = Image.open(BytesIO(raw_bytes))

    assert decoded.format == image_format.name
Check notice
Code scanning / CodeQL
Unused import Note