diff --git a/configs/upload_dataset_huggingface.yaml b/configs/upload_dataset_huggingface.yaml index 31ec9192..61c73b55 100644 --- a/configs/upload_dataset_huggingface.yaml +++ b/configs/upload_dataset_huggingface.yaml @@ -15,6 +15,7 @@ lensless: dir: null ext: null # for example: .png, .jpg eight_norm: False # save as 8-bit normalized image + downsample: null lensed: dir: null diff --git a/lensless/utils/io.py b/lensless/utils/io.py index a51feff5..47fd94f4 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -37,6 +37,7 @@ def load_image( shape=None, dtype=None, normalize=True, + bgr_input=True, ): """ Load image as numpy array. @@ -151,7 +152,7 @@ def load_image( ) else: - if len(img.shape) == 3: + if len(img.shape) == 3 and bgr_input: img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) original_dtype = img.dtype @@ -223,6 +224,7 @@ def load_psf( single_psf=False, shape=None, use_3d=False, + bgr_input=True, ): """ Load and process PSF for analysis or for reconstruction. @@ -296,6 +298,7 @@ def load_psf( blue_gain=blue_gain, red_gain=red_gain, nbits_out=nbits_out, + bgr_input=bgr_input, ) original_dtype = psf.dtype @@ -391,6 +394,7 @@ def load_data( torch=False, torch_device="cpu", normalize=False, + bgr_input=True, ): """ Load data for image reconstruction. @@ -471,6 +475,7 @@ def load_data( single_psf=single_psf, shape=shape, use_3d=use_3d, + bgr_input=bgr_input, ) # load and process raw measurement @@ -485,6 +490,7 @@ def load_data( return_float=return_float, shape=shape, normalize=normalize, + bgr_input=bgr_input, ) if data.shape != psf.shape: diff --git a/scripts/data/upload_dataset_huggingface.py b/scripts/data/upload_dataset_huggingface.py index 8edf598d..2b212d97 100644 --- a/scripts/data/upload_dataset_huggingface.py +++ b/scripts/data/upload_dataset_huggingface.py @@ -24,6 +24,8 @@ from lensless.utils.dataset import natural_sort from tqdm import tqdm from lensless.utils.io import save_image +import cv2 +from joblib import Parallel, delayed @hydra.main( @@ -82,24 +84,44 @@ def upload_dataset(config): ] lensed_files = [os.path.join(config.lensed.dir, f + config.lensed.ext) for f in common_files] + if config.lensless.downsample is not None: + + tmp_dir = config.lensless.dir + "_tmp" + os.makedirs(tmp_dir, exist_ok=True) + + def downsample(f, output_dir): + img = cv2.imread(f, cv2.IMREAD_UNCHANGED) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize( + img, + (0, 0), + fx=1 / config.lensless.downsample, + fy=1 / config.lensless.downsample, + interpolation=cv2.INTER_LINEAR, + ) + new_fp = os.path.join(output_dir, os.path.basename(f)) + new_fp = new_fp.split(".")[0] + config.lensless.ext + save_image(img, new_fp, normalize=False) + + Parallel(n_jobs=n_jobs)(delayed(downsample)(f, tmp_dir) for f in tqdm(lensless_files)) + lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}")) + # convert to normalized 8 bit if config.lensless.eight_norm: - import cv2 - from joblib import Parallel, delayed - tmp_dir = config.lensless.dir + "_tmp" os.makedirs(tmp_dir, exist_ok=True) # -- parallelize with joblib def save_8bit(f, output_dir, normalize=True): img = cv2.imread(f, cv2.IMREAD_UNCHANGED) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) new_fp = os.path.join(output_dir, os.path.basename(f)) - new_fp = new_fp.split(".")[0] + ".png" + new_fp = new_fp.split(".")[0] + config.lensless.ext save_image(img, new_fp, normalize=normalize) Parallel(n_jobs=n_jobs)(delayed(save_8bit)(f, tmp_dir) for f in tqdm(lensless_files)) - lensless_files = glob.glob(os.path.join(tmp_dir, "*png")) + lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}")) # check for attribute df_attr = None @@ -222,6 +244,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None): upload_file( path_or_fileobj=lensless_files[0], + # path_in_repo=f"lensless_example{config.lensless.ext}" if not config.lensless.eight_norm else f"lensless_example.png", path_in_repo=f"lensless_example{config.lensless.ext}", repo_id=repo_id, repo_type="dataset", @@ -248,7 +271,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None): print(f"Total time: {(time.time() - start_time) / 60} minutes") # delete PNG files - if config.lensless.eight_norm: + if config.lensless.eight_norm or config.lensless.downsample: os.system(f"rm -rf {tmp_dir}")