diff --git a/imgutils/data/image.py b/imgutils/data/image.py index d9e5354f91a..67b7ad6aa5c 100644 --- a/imgutils/data/image.py +++ b/imgutils/data/image.py @@ -181,10 +181,10 @@ def add_background_for_rgba(image: ImageTyping, background: str = 'white'): 'RGB' """ image = load_image(image, force_background=None, mode=None) - if has_alpha_channel(image): + try: ret_image = Image.new('RGBA', image.size, background) ret_image.paste(image, (0, 0), mask=image) - else: + except ValueError: ret_image = image if ret_image.mode != 'RGB': ret_image = ret_image.convert('RGB') diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py index 454a860b55c..0d00aad1666 100644 --- a/imgutils/tagging/wd14.py +++ b/imgutils/tagging/wd14.py @@ -122,9 +122,9 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int): pad_top = (max_dim - image_shape[1]) // 2 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) - if has_alpha_channel(image): + try: padded_image.paste(image, (pad_left, pad_top), mask=image) - else: + except ValueError: padded_image.paste(image, (pad_left, pad_top)) if max_dim != target_size: diff --git a/test/detect/test_text.py b/test/detect/test_text.py index e2092b5ecd1..600c9280f85 100644 --- a/test/detect/test_text.py +++ b/test/detect/test_text.py @@ -1,5 +1,6 @@ import pytest +from imgutils.detect import detection_similarity from imgutils.detect.text import _open_text_detect_model, detect_text from test.testings import get_testfile @@ -18,32 +19,34 @@ def test_detect_text(self): detections = detect_text(get_testfile('ml1.png')) assert len(detections) == 4 - values = [] - for bbox, label, score in detections: - assert label in {'text'} - values.append((bbox, int(score * 1000) / 1000)) - - assert values == pytest.approx([ - ((866, 45, 959, 69), 0.543), - ((222, 68, 313, 102), 0.543), - ((424, 82, 508, 113), 0.541), - ((691, 101, 776, 129), 0.471) - ]) + assert detection_similarity( + detections, + [ + ((866, 45, 959, 69), 'text', 0.543), + ((222, 68, 313, 102), 'text', 0.543), + ((424, 82, 508, 113), 'text', 0.541), + ((691, 101, 776, 129), 'text', 0.471) + ], + ) >= 0.9 def test_detect_text_without_resize(self): detections = detect_text(get_testfile('ml2.jpg'), max_area_size=None) assert len(detections) == 9 - values = [] - for bbox, label, score in detections: - assert label in {'text'} - values.append((bbox, int(score * 1000) / 1000)) - - assert values == pytest.approx([ - ((360, 218, 474, 250), 0.686), ((119, 218, 203, 240), 0.653), ((392, 47, 466, 76), 0.617), - ((593, 174, 666, 204), 0.616), ((179, 451, 672, 472), 0.591), ((633, 314, 747, 337), 0.59), - ((392, 369, 517, 386), 0.589), ((621, 81, 681, 102), 0.566), ((209, 92, 281, 122), 0.423), - ]) + assert detection_similarity( + detections, + [ + ((360, 218, 474, 250), 'text', 0.686), + ((119, 218, 203, 240), 'text', 0.653), + ((392, 47, 466, 76), 'text', 0.617), + ((593, 174, 666, 204), 'text', 0.616), + ((179, 451, 672, 472), 'text', 0.591), + ((633, 314, 747, 337), 'text', 0.59), + ((392, 369, 517, 386), 'text', 0.589), + ((621, 81, 681, 102), 'text', 0.566), + ((209, 92, 281, 122), 'text', 0.423), + ] + ) >= 0.9 def test_detect_text_none(self): assert detect_text(get_testfile('png_full.png')) == []