Skip to content

Commit

Permalink
Merge pull request #113 from deepghs/dev/tagging
Browse files Browse the repository at this point in the history
dev(narugo): optimize the preprocessing of wd14 tagger
  • Loading branch information
narugo1992 authored Oct 30, 2024
2 parents 5f14d6a + bb62aad commit 74d4a2d
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 2 deletions.
11 changes: 9 additions & 2 deletions imgutils/tagging/wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,15 +114,22 @@ def _mcut_threshold(probs) -> float:
return thresh


def _has_alpha_channel(image: Image.Image) -> bool:
    """
    Tell whether any band of the image is one of ``'A'``, ``'a'`` or ``'P'``.

    :param image: Image whose band names (from ``image.getbands()``) are inspected.
    :return: ``True`` as soon as one band name matches, ``False`` otherwise.
    """
    alpha_like_bands = {'A', 'a', 'P'}
    for band_name in image.getbands():
        if band_name in alpha_like_bands:
            return True
    return False


def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
image = load_image(image, force_background='white', mode='RGB')
image = load_image(image, force_background=None, mode=None)
image_shape = image.size
max_dim = max(image_shape)
pad_left = (max_dim - image_shape[0]) // 2
pad_top = (max_dim - image_shape[1]) // 2

padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
padded_image.paste(image, (pad_left, pad_top))
if _has_alpha_channel(image):
padded_image.paste(image, (pad_left, pad_top), mask=image)
else:
padded_image.paste(image, (pad_left, pad_top))

if max_dim != target_size:
padded_image = padded_image.resize((target_size, target_size), Image.BICUBIC)
Expand Down
29 changes: 29 additions & 0 deletions test/tagging/test_wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,32 @@ def test_wd14_tags_no_overlap(self):
'breasts_apart': 0.35740798711776733, 'clitoris': 0.44502270221710205
}, abs=2e-2)
assert chars == pytest.approx({'surtr_(arknights)': 0.9957615733146667}, abs=2e-2)

def test_wd14_rgba(self):
    """
    Tag the ``nian.png`` test file and compare the rating, tag and character
    scores against recorded reference values.

    NOTE(review): the test name suggests ``nian.png`` is an RGBA image; confirm
    against the test fixtures.
    """
    # Absolute tolerance shared by all three comparisons.
    tolerance = 2e-2

    expected_rating = {
        'general': 0.013875722885131836,
        'sensitive': 0.9790834188461304,
        'questionable': 0.0004328787326812744,
        'explicit': 0.00010639429092407227,
    }
    expected_tags = {
        '1girl': 0.996912956237793,
        'solo': 0.9690700769424438,
        'long_hair': 0.9183608293533325,
        'breasts': 0.5793432593345642,
        'looking_at_viewer': 0.9029998779296875,
        'smile': 0.7181373834609985,
        'open_mouth': 0.5431916117668152,
        'simple_background': 0.3519788384437561,
        'long_sleeves': 0.7442969679832458,
        'white_background': 0.6004813313484192,
        'holding': 0.7325218319892883,
        'navel': 0.9297535419464111,
        'jewelry': 0.5435991287231445,
        'standing': 0.8762419819831848,
        'purple_eyes': 0.9269286394119263,
        'tail': 0.8547350168228149,
        'full_body': 0.9316157102584839,
        'white_hair': 0.9207442402839661,
        'braid': 0.37353646755218506,
        'multicolored_hair': 0.6516135931015015,
        'thighs': 0.451822429895401,
        ':d': 0.5130974054336548,
        'red_hair': 0.5783762335777283,
        'small_breasts': 0.3563075065612793,
        'boots': 0.6243380308151245,
        'open_clothes': 0.8822896480560303,
        'horns': 0.965097188949585,
        'shorts': 0.9586330056190491,
        'shoes': 0.4847032427787781,
        'socks': 0.47281092405319214,
        'tongue': 0.9029147624969482,
        'pointy_ears': 0.8633939623832703,
        'belt': 0.4783763289451599,
        'midriff': 0.9044876098632812,
        'tongue_out': 0.9018264412879944,
        'wide_sleeves': 0.7076666951179504,
        'stomach': 0.891795814037323,
        'streaked_hair': 0.6510426998138428,
        'coat': 0.7965987324714661,
        'crop_top': 0.6840215921401978,
        'hand_on_own_hip': 0.5604047179222107,
        'strapless': 0.950110137462616,
        'short_shorts': 0.6481347680091858,
        'bare_legs': 0.5356456637382507,
        'white_footwear': 0.8399633169174194,
        'transparent_background': 0.3643641471862793,
        ':p': 0.532076358795166,
        'half_updo': 0.5155724883079529,
        'open_coat': 0.8147380352020264,
        'beads': 0.3977043032646179,
        'white_shorts': 0.9007017612457275,
        'white_coat': 0.8003122806549072,
        'bandeau': 0.9671074151992798,
        'tube_top': 0.9783295392990112,
        'bead_bracelet': 0.3510066270828247,
        'red_bandeau': 0.8741766214370728,
    }
    expected_chars = {'nian_(arknights)': 0.9968841671943665}

    image_path = get_testfile('nian.png')
    rating, tags, chars = get_wd14_tags(image_path)

    assert rating == pytest.approx(expected_rating, abs=tolerance)
    assert tags == pytest.approx(expected_tags, abs=tolerance)
    assert chars == pytest.approx(expected_chars, abs=tolerance)

0 comments on commit 74d4a2d

Please sign in to comment.