Skip to content

Commit

Permalink
fix: load_image replacement
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Sep 13, 2024
1 parent 0ce3f89 commit 70b6749
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 26 deletions.
6 changes: 0 additions & 6 deletions data/aligned_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from data.utils import load_image
from data.image_folder import make_dataset
from PIL import Image
import tifffile


class AlignedDataset(BaseDataset):
Expand Down Expand Up @@ -34,11 +33,6 @@ def __init__(self, opt, phase, name=""):
"aligned dataset: domain A and domain B should have the same number of images"
)

if opt.data_image_bits > 8 and opt.model_input_nc > 1:
self.use_tiff = True # multi-channel images > 8bit
else:
self.use_tiff = False

def __getitem__(self, index):
"""Return a data point and its metadata information.
Expand Down
6 changes: 6 additions & 0 deletions data/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import imgaug.augmenters as iaa
import os
import warnings
import tifffile


class BaseDataset(data.Dataset, ABC):
Expand Down Expand Up @@ -63,6 +64,11 @@ def __init__(self, opt, phase, name=""):
self.warning_mode = self.opt.warning_mode
self.set_dataset_dirs_and_dims()

if opt.data_image_bits > 8 and opt.model_input_nc > 1:
self.use_tiff = True # multi-channel images > 8bit
else:
self.use_tiff = False

@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new dataset-specific options, and rewrite default values for existing options.
Expand Down
31 changes: 24 additions & 7 deletions data/unaligned_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os.path
from data.base_dataset import BaseDataset, get_transform
from data.base_dataset import BaseDataset, get_transform, get_params
from data.utils import load_image
from data.image_folder import make_dataset, make_ref_path_list
from PIL import Image
Expand Down Expand Up @@ -35,8 +35,12 @@ def __init__(self, opt, phase, name=""):
self.A_size = len(self.A_img_paths) # get the size of dataset A
self.B_size = len(self.B_img_paths) # get the size of dataset B

self.transform_A = get_transform(self.opt, grayscale=(self.input_nc == 1))
self.transform_B = get_transform(self.opt, grayscale=(self.output_nc == 1))
if self.opt.data_image_bits == 8:
self.grayscale = self.input_nc == 1
else: # for > 8bit, no explicit conversion
self.grayscale = False

A = load_image(self.A_img_paths[0]) # temporarily load first image

self.header = ["img"]

Expand All @@ -56,11 +60,24 @@ def get_img(
B_label_cls,
index,
):
A_img = load_image(A_img_path)
B_img = load_image(B_img_path)
A_img = load_image(A_img_path, self.opt.data_image_bits, self.use_tiff)
B_img = load_image(B_img_path, self.opt.data_image_bits, self.use_tiff)

if self.use_tiff:
transform_params = get_params(self.opt, A_img[:2])
else:
transform_params = get_params(self.opt, A_img.size)

transform_A = get_transform(
self.opt, params=transform_params, grayscale=self.grayscale
)
transform_B = get_transform(
self.opt, params=transform_params, grayscale=self.grayscale
)

# apply image transformation
A = self.transform_A(A_img)
B = self.transform_B(B_img)
A = transform_A(A_img)
B = transform_B(B_img)

result = {
"A": A,
Expand Down
38 changes: 25 additions & 13 deletions data/utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,34 @@
from PIL import Image


def load_image(img_path):
image = Image.open(img_path)
def load_image(img_path, img_bits=8, use_tiff=False):
if use_tiff:
img = tifffile.imread(img_path)
else:
img = Image.open(img_path)

# Define the color for transparency (e.g., transparent black)
transparent_black = (0, 0, 0, 0)
if img_bits == 8:
img = img.convert("RGB")

# Convert the image to RGBA mode if needed
image = image.convert("RGBA")
return img

# Create a new image with the specified color for transparency
transparent_color = Image.new("RGBA", image.size, transparent_black)

# Use alpha_composite to make the specified color transparent
result = Image.alpha_composite(transparent_color, image)
# def load_image(img_path):
# image = Image.open(img_path)

# Convert the result back to RGB mode
result_rgb = result.convert("RGB")
# # Define the color for transparency (e.g., transparent black)
# transparent_black = (0, 0, 0, 0)

return result_rgb
# # Convert the image to RGBA mode if needed
# image = image.convert("RGBA")

# # Create a new image with the specified color for transparency
# transparent_color = Image.new("RGBA", image.size, transparent_black)

# # Use alpha_composite to make the specified color transparent
# result = Image.alpha_composite(transparent_color, image)

# # Convert the result back to RGB mode
# result_rgb = result.convert("RGB")

# return result_rgb

0 comments on commit 70b6749

Please sign in to comment.