From bb62aad0441499ed409d6d8c09c4216b1061ea22 Mon Sep 17 00:00:00 2001 From: narugo1992 Date: Wed, 30 Oct 2024 17:37:35 +0800 Subject: [PATCH] dev(narugo): optimize the preprocessing of wd14 tag preprocessing --- imgutils/tagging/wd14.py | 11 +++++++++-- test/tagging/test_wd14.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/imgutils/tagging/wd14.py b/imgutils/tagging/wd14.py index 43b2145223c..7823b5b8bc3 100644 --- a/imgutils/tagging/wd14.py +++ b/imgutils/tagging/wd14.py @@ -114,15 +114,22 @@ def _mcut_threshold(probs) -> float: return thresh +def _has_alpha_channel(image: Image.Image) -> bool: + return any(band in {'A', 'a', 'P'} for band in image.getbands()) + + def _prepare_image_for_tagging(image: ImageTyping, target_size: int): - image = load_image(image, force_background='white', mode='RGB') + image = load_image(image, force_background=None, mode=None) image_shape = image.size max_dim = max(image_shape) pad_left = (max_dim - image_shape[0]) // 2 pad_top = (max_dim - image_shape[1]) // 2 padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255)) - padded_image.paste(image, (pad_left, pad_top)) + if _has_alpha_channel(image): + padded_image.paste(image, (pad_left, pad_top), mask=image) + else: + padded_image.paste(image, (pad_left, pad_top)) if max_dim != target_size: padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC) diff --git a/test/tagging/test_wd14.py b/test/tagging/test_wd14.py index 92f3b4ac9c5..cbaf33efa11 100644 --- a/test/tagging/test_wd14.py +++ b/test/tagging/test_wd14.py @@ -130,3 +130,32 @@ def test_wd14_tags_no_overlap(self): 'breasts_apart': 0.35740798711776733, 'clitoris': 0.44502270221710205 }, abs=2e-2) assert chars == pytest.approx({'surtr_(arknights)': 0.9957615733146667}, abs=2e-2) + + def test_wd14_rgba(self): + rating, tags, chars = get_wd14_tags(get_testfile('nian.png')) + assert rating == pytest.approx({ + 'general': 0.013875722885131836, 'sensitive': 0.9790834188461304, + 'questionable': 0.0004328787326812744, 'explicit': 0.00010639429092407227, + }, abs=2e-2) + assert tags == pytest.approx({ + '1girl': 0.996912956237793, 'solo': 0.9690700769424438, 'long_hair': 0.9183608293533325, + 'breasts': 0.5793432593345642, 'looking_at_viewer': 0.9029998779296875, 'smile': 0.7181373834609985, + 'open_mouth': 0.5431916117668152, 'simple_background': 0.3519788384437561, + 'long_sleeves': 0.7442969679832458, 'white_background': 0.6004813313484192, 'holding': 0.7325218319892883, + 'navel': 0.9297535419464111, 'jewelry': 0.5435991287231445, 'standing': 0.8762419819831848, + 'purple_eyes': 0.9269286394119263, 'tail': 0.8547350168228149, 'full_body': 0.9316157102584839, + 'white_hair': 0.9207442402839661, 'braid': 0.37353646755218506, 'multicolored_hair': 0.6516135931015015, + 'thighs': 0.451822429895401, ':d': 0.5130974054336548, 'red_hair': 0.5783762335777283, + 'small_breasts': 0.3563075065612793, 'boots': 0.6243380308151245, 'open_clothes': 0.8822896480560303, + 'horns': 0.965097188949585, 'shorts': 0.9586330056190491, 'shoes': 0.4847032427787781, + 'socks': 0.47281092405319214, 'tongue': 0.9029147624969482, 'pointy_ears': 0.8633939623832703, + 'belt': 0.4783763289451599, 'midriff': 0.9044876098632812, 'tongue_out': 0.9018264412879944, + 'wide_sleeves': 0.7076666951179504, 'stomach': 0.891795814037323, 'streaked_hair': 0.6510426998138428, + 'coat': 0.7965987324714661, 'crop_top': 0.6840215921401978, 'hand_on_own_hip': 0.5604047179222107, + 'strapless': 0.950110137462616, 'short_shorts': 0.6481347680091858, 'bare_legs': 0.5356456637382507, + 'white_footwear': 0.8399633169174194, 'transparent_background': 0.3643641471862793, ':p': 0.532076358795166, + 'half_updo': 0.5155724883079529, 'open_coat': 0.8147380352020264, 'beads': 0.3977043032646179, + 'white_shorts': 0.9007017612457275, 'white_coat': 0.8003122806549072, 'bandeau': 0.9671074151992798, + 'tube_top': 0.9783295392990112, 'bead_bracelet': 0.3510066270828247, 'red_bandeau': 0.8741766214370728 + }, abs=2e-2) + assert chars == pytest.approx({'nian_(arknights)': 0.9968841671943665}, abs=2e-2)