Skip to content

Commit

Permalink
Merge pull request #116 from deepghs/dev/bg
Browse files Browse the repository at this point in the history
dev(narugo): optimize background loading safefy
  • Loading branch information
narugo1992 authored Oct 30, 2024
2 parents d39b451 + 146f940 commit 902ab07
Show file tree
Hide file tree
Showing 16 changed files with 252 additions and 52 deletions.
7 changes: 7 additions & 0 deletions docs/source/api_doc/data/image.rst
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,10 @@ load_images
.. autofunction:: load_images


has_alpha_channel
------------------------------

.. autofunction:: has_alpha_channel



98 changes: 91 additions & 7 deletions imgutils/data/image.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,77 @@
"""
This module provides utility functions for image processing and manipulation using the PIL (Python Imaging Library) library.
It includes functions for loading images from various sources, handling multiple images, adding backgrounds to RGBA images,
and checking for alpha channels. The module is designed to simplify common image-related tasks in Python applications.
Key features:
- Loading images from different sources (file paths, binary data, file-like objects)
- Handling multiple images at once
- Adding backgrounds to RGBA images
- Checking for alpha channels in images
This module is particularly useful for applications that require image preprocessing or manipulation before further processing or analysis.
"""

from os import PathLike
from typing import Union, BinaryIO, List, Tuple, Optional

from PIL import Image

__all__ = [
'ImageTyping', 'load_image',
'MultiImagesTyping', 'load_images',
'ImageTyping',
'load_image',
'MultiImagesTyping',
'load_images',
'add_background_for_rgba',
'has_alpha_channel',
]


def _is_readable(obj):
"""
Check if an object is readable (has 'read' and 'seek' methods).
:param obj: The object to check for readability.
:type obj: Any
:return: True if the object is readable, False otherwise.
:rtype: bool
"""
return hasattr(obj, 'read') and hasattr(obj, 'seek')


ImageTyping = Union[str, PathLike, bytes, bytearray, BinaryIO, Image.Image]
MultiImagesTyping = Union[ImageTyping, List[ImageTyping], Tuple[ImageTyping, ...]]


def _has_alpha_channel(image: Image.Image) -> bool:
return any(band in {'A', 'a', 'P'} for band in image.getbands())
def has_alpha_channel(image: Image.Image) -> bool:
"""
Determine if the given Pillow image object has an alpha channel (transparency)
:param image: Pillow image object
:type image: Image.Image
:return: Boolean, True if it has an alpha channel, False otherwise
:rtype: bool
"""
# Get the image mode
mode = image.mode

# Modes that directly include an alpha channel
if mode in ('RGBA', 'LA', 'PA'):
return True

if getattr(image, 'palette'):
# Check if there's a transparent palette
try:
image.palette.getcolor((0, 0, 0, 0))
return True # cannot find a line to trigger this
except ValueError:
pass

# For other modes, check if 'transparency' key exists in image info
return 'transparency' in image.info


def load_image(image: ImageTyping, mode=None, force_background: Optional[str] = 'white'):
Expand All @@ -43,6 +95,16 @@ def load_image(image: ImageTyping, mode=None, force_background: Optional[str] =
:return: The loaded and transformed image.
:rtype: Image.Image
:raises TypeError: If the provided image type is not supported.
:example:
>>> from PIL import Image
>>> img = load_image('path/to/image.png', mode='RGB', force_background='white')
>>> isinstance(img, Image.Image)
True
>>> img.mode
'RGB'
"""
if isinstance(image, (str, PathLike, bytes, bytearray, BinaryIO)) or _is_readable(image):
image = Image.open(image)
Expand All @@ -51,7 +113,7 @@ def load_image(image: ImageTyping, mode=None, force_background: Optional[str] =
else:
raise TypeError(f'Unknown image type - {image!r}.')

if _has_alpha_channel(image) and force_background is not None:
if has_alpha_channel(image) and force_background is not None:
image = add_background_for_rgba(image, force_background)

if mode is not None and image.mode != mode:
Expand Down Expand Up @@ -79,6 +141,14 @@ def load_images(images: MultiImagesTyping, mode=None, force_background: Optional
:return: A list of loaded and transformed images.
:rtype: List[Image.Image]
:example:
>>> img_paths = ['path/to/image1.png', 'path/to/image2.jpg']
>>> loaded_images = load_images(img_paths, mode='RGB')
>>> len(loaded_images)
2
>>> all(isinstance(img, Image.Image) for img in loaded_images)
True
"""
if not isinstance(images, (list, tuple)):
images = [images]
Expand All @@ -102,6 +172,20 @@ def add_background_for_rgba(image: ImageTyping, background: str = 'white'):
:return: The image with the added background, converted to RGB.
:rtype: Image.Image
:example:
>>> from PIL import Image
>>> rgba_image = Image.new('RGBA', (100, 100), (255, 0, 0, 128))
>>> rgb_image = add_background_for_rgba(rgba_image, background='blue')
>>> rgb_image.mode
'RGB'
"""
from .layer import istack
return istack(background, image).convert('RGB')
image = load_image(image, force_background=None, mode=None)
try:
ret_image = Image.new('RGBA', image.size, background)
ret_image.paste(image, (0, 0), mask=image)
except ValueError:
ret_image = image
if ret_image.mode != 'RGB':
ret_image = ret_image.convert('RGB')
return ret_image
19 changes: 3 additions & 16 deletions imgutils/generic/enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,13 @@
import numpy as np
from PIL import Image

from ..data import ImageTyping, load_image
from ..data import ImageTyping, load_image, has_alpha_channel

__all__ = [
'ImageEnhancer',
]


def _has_alpha_channel(image: Image.Image) -> bool:
"""
Check if the image has an alpha channel.
:param image: The image to check.
:type image: Image.Image
:return: True if the image has an alpha channel, False otherwise.
:rtype: bool
"""
return any(band in {'A', 'a', 'P'} for band in image.getbands())


class ImageEnhancer:
"""
Enhances images by applying various processing techniques.
Expand Down Expand Up @@ -103,10 +90,10 @@ def process(self, image: ImageTyping):
:rtype: Image.Image
"""
image = load_image(image, mode=None, force_background=None)
mode = 'RGBA' if _has_alpha_channel(image) else 'RGB'
mode = 'RGBA' if has_alpha_channel(image) else 'RGB'
image = load_image(image, mode=mode, force_background=None)
input_array = (np.array(image).astype(np.float32) / 255.0).transpose((2, 0, 1))
if _has_alpha_channel(image):
if has_alpha_channel(image):
output_array = self._process_rgba(input_array)
else:
output_array = self._process_rgb(input_array)
Expand Down
10 changes: 3 additions & 7 deletions imgutils/tagging/wd14.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .format import remove_underline
from .overlap import drop_overlap_tags
from ..data import load_image, ImageTyping
from ..data import load_image, ImageTyping, has_alpha_channel
from ..utils import open_onnx_model, vreplace

SWIN_MODEL_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
Expand Down Expand Up @@ -114,10 +114,6 @@ def _mcut_threshold(probs) -> float:
return thresh


def _has_alpha_channel(image: Image.Image) -> bool:
return any(band in {'A', 'a', 'P'} for band in image.getbands())


def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
image = load_image(image, force_background=None, mode=None)
image_shape = image.size
Expand All @@ -126,9 +122,9 @@ def _prepare_image_for_tagging(image: ImageTyping, target_size: int):
pad_top = (max_dim - image_shape[1]) // 2

padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
if _has_alpha_channel(image):
try:
padded_image.paste(image, (pad_left, pad_top), mask=image)
else:
except ValueError:
padded_image.paste(image, (pad_left, pad_top))

if max_dim != target_size:
Expand Down
125 changes: 124 additions & 1 deletion test/data/test_image.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest
from PIL import Image

from imgutils.data import load_image
from imgutils.data import load_image, has_alpha_channel, add_background_for_rgba
from test.testings import get_testfile

_FILENAME = get_testfile('6125785.png')
Expand All @@ -23,3 +23,126 @@ def test_load_image(self, image_, result, image_diff):
assert load_image(image_, force_background=None) is image_
else:
assert image_diff(load_image(image_), result, throw_exception=False) < 1e-2

@pytest.mark.parametrize(['color'], [
('white',),
('green',),
('red',),
('blue',),
('black',),
])
def test_load_image_bg_rgba(self, image_diff, color):
image = load_image(get_testfile('nian.png'), force_background=color, mode='RGB')
expected = Image.open(get_testfile(f'nian_bg_{color}.png'))
assert image_diff(image, expected, throw_exception=False) < 1e-2

@pytest.mark.parametrize(['color'], [
('white',),
('green',),
('red',),
('blue',),
('black',),
])
def test_add_background_for_rgba_rgba(self, image_diff, color):
image = add_background_for_rgba(get_testfile('nian.png'), background=color)
assert image.mode == 'RGB'
expected = Image.open(get_testfile(f'nian_bg_{color}.png'))
assert image_diff(image, expected, throw_exception=False) < 1e-2

@pytest.mark.parametrize(['color'], [
('white',),
('green',),
('red',),
('blue',),
('black',),
])
def test_load_image_bg_rgb(self, image_diff, color):
image = load_image(get_testfile('mostima_post.jpg'), force_background=color, mode='RGB')
expected = Image.open(get_testfile(f'mostima_post_bg_{color}.png'))
assert image_diff(image, expected, throw_exception=False) < 1e-2

@pytest.mark.parametrize(['color'], [
('white',),
('green',),
('red',),
('blue',),
('black',),
])
def test_add_backround_for_rgba_rgb(self, image_diff, color):
image = add_background_for_rgba(get_testfile('mostima_post.jpg'), background=color)
assert image.mode == 'RGB'
expected = Image.open(get_testfile(f'mostima_post_bg_{color}.png'))
assert image_diff(image, expected, throw_exception=False) < 1e-2


@pytest.fixture
def rgba_image():
img = Image.new('RGBA', (10, 10), (255, 0, 0, 128))
return img


@pytest.fixture
def rgb_image():
img = Image.new('RGB', (10, 10), (255, 0, 0))
return img


@pytest.fixture
def la_image():
img = Image.new('LA', (10, 10), (128, 128))
return img


@pytest.fixture
def l_image():
img = Image.new('L', (10, 10), 128)
return img


@pytest.fixture
def p_image_with_transparency():
width, height = 200, 200
image = Image.new('P', (width, height))

palette = []
for i in range(256):
palette.extend((i, i, i)) # 灰度调色板

palette[:3] = (0, 0, 0) # 黑色
image.info['transparency'] = 0

image.putpalette(palette)
return image


@pytest.fixture
def p_image_without_transparency():
img = Image.new('P', (10, 10))
palette = [255, 0, 0, 255, 0, 0] # No transparent color
img.putpalette(palette)
return img


@pytest.mark.unittest
class TestHasAlphaChannel:
def test_rgba_image(self, rgba_image):
assert has_alpha_channel(rgba_image)

def test_rgb_image(self, rgb_image):
assert not has_alpha_channel(rgb_image)

def test_la_image(self, la_image):
assert has_alpha_channel(la_image)

def test_l_image(self, l_image):
assert not has_alpha_channel(l_image)

def test_p_image_with_transparency(self, p_image_with_transparency):
assert has_alpha_channel(p_image_with_transparency)

def test_p_image_without_transparency(self, p_image_without_transparency):
assert not has_alpha_channel(p_image_without_transparency)

def test_pa_image(self):
pa_image = Image.new('PA', (10, 10))
assert has_alpha_channel(pa_image)
Loading

0 comments on commit 902ab07

Please sign in to comment.