Skip to content

Commit

Permalink
Option to downsample data before HF upload. (#128)
Browse files Browse the repository at this point in the history
* Add option for BGR conversion.

* Add downsample option before Hugging Face upload.
  • Loading branch information
ebezzam authored May 2, 2024
1 parent 410510a commit d44ab57
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 7 deletions.
1 change: 1 addition & 0 deletions configs/upload_dataset_huggingface.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ lensless:
dir: null
ext: null # for example: .png, .jpg
eight_norm: False # save as 8-bit normalized image
downsample: null

lensed:
dir: null
Expand Down
8 changes: 7 additions & 1 deletion lensless/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def load_image(
shape=None,
dtype=None,
normalize=True,
bgr_input=True,
):
"""
Load image as numpy array.
Expand Down Expand Up @@ -151,7 +152,7 @@ def load_image(
)

else:
if len(img.shape) == 3:
if len(img.shape) == 3 and bgr_input:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

original_dtype = img.dtype
Expand Down Expand Up @@ -223,6 +224,7 @@ def load_psf(
single_psf=False,
shape=None,
use_3d=False,
bgr_input=True,
):
"""
Load and process PSF for analysis or for reconstruction.
Expand Down Expand Up @@ -296,6 +298,7 @@ def load_psf(
blue_gain=blue_gain,
red_gain=red_gain,
nbits_out=nbits_out,
bgr_input=bgr_input,
)

original_dtype = psf.dtype
Expand Down Expand Up @@ -391,6 +394,7 @@ def load_data(
torch=False,
torch_device="cpu",
normalize=False,
bgr_input=True,
):
"""
Load data for image reconstruction.
Expand Down Expand Up @@ -471,6 +475,7 @@ def load_data(
single_psf=single_psf,
shape=shape,
use_3d=use_3d,
bgr_input=bgr_input,
)

# load and process raw measurement
Expand All @@ -485,6 +490,7 @@ def load_data(
return_float=return_float,
shape=shape,
normalize=normalize,
bgr_input=bgr_input,
)

if data.shape != psf.shape:
Expand Down
35 changes: 29 additions & 6 deletions scripts/data/upload_dataset_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
from lensless.utils.dataset import natural_sort
from tqdm import tqdm
from lensless.utils.io import save_image
import cv2
from joblib import Parallel, delayed


@hydra.main(
Expand Down Expand Up @@ -82,24 +84,44 @@ def upload_dataset(config):
]
lensed_files = [os.path.join(config.lensed.dir, f + config.lensed.ext) for f in common_files]

if config.lensless.downsample is not None:

tmp_dir = config.lensless.dir + "_tmp"
os.makedirs(tmp_dir, exist_ok=True)

def downsample(f, output_dir):
    """
    Downsample one image file and save the result into ``output_dir``.

    The image is resized by a factor of ``1 / config.lensless.downsample``
    along both axes (bilinear interpolation) and written under the same
    base name as ``f`` with the extension ``config.lensless.ext``.

    Parameters
    ----------
    f : str
        Path to the input image.
    output_dir : str
        Directory in which the downsampled image is written.

    Raises
    ------
    ValueError
        If the image at ``f`` cannot be read.
    """
    img = cv2.imread(f, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread returns None instead of raising when a file is unreadable
        raise ValueError(f"Could not read image: {f}")

    # only multi-channel images need BGR -> RGB (same guard as in `load_image`)
    if len(img.shape) == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    factor = 1 / config.lensless.downsample
    img = cv2.resize(
        img,
        (0, 0),
        fx=factor,
        fy=factor,
        interpolation=cv2.INTER_LINEAR,
    )

    # strip only the file's own extension; splitting the full path on "."
    # breaks as soon as a directory name contains a dot (e.g. "./data_tmp")
    stem = os.path.splitext(os.path.basename(f))[0]
    new_fp = os.path.join(output_dir, stem + config.lensless.ext)
    save_image(img, new_fp, normalize=False)

Parallel(n_jobs=n_jobs)(delayed(downsample)(f, tmp_dir) for f in tqdm(lensless_files))
lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}"))

# convert to normalized 8 bit
if config.lensless.eight_norm:

import cv2
from joblib import Parallel, delayed

tmp_dir = config.lensless.dir + "_tmp"
os.makedirs(tmp_dir, exist_ok=True)

# -- parallelize with joblib
def save_8bit(f, output_dir, normalize=True):
    """
    Re-save one image file into ``output_dir`` via ``save_image``.

    The output keeps the base name of ``f`` and uses the extension
    ``config.lensless.ext``.

    Parameters
    ----------
    f : str
        Path to the input image.
    output_dir : str
        Directory in which the converted image is written.
    normalize : bool, optional
        Forwarded to ``save_image`` as its ``normalize`` argument.

    Raises
    ------
    ValueError
        If the image at ``f`` cannot be read.
    """
    img = cv2.imread(f, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread returns None instead of raising when a file is unreadable
        raise ValueError(f"Could not read image: {f}")

    # only multi-channel images need BGR -> RGB (same guard as in `load_image`)
    if len(img.shape) == 3:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # strip only the file's own extension; splitting the full path on "."
    # breaks as soon as a directory name contains a dot (e.g. "./data_tmp")
    stem = os.path.splitext(os.path.basename(f))[0]
    new_fp = os.path.join(output_dir, stem + config.lensless.ext)
    save_image(img, new_fp, normalize=normalize)

Parallel(n_jobs=n_jobs)(delayed(save_8bit)(f, tmp_dir) for f in tqdm(lensless_files))
lensless_files = glob.glob(os.path.join(tmp_dir, "*png"))
lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}"))

# check for attribute
df_attr = None
Expand Down Expand Up @@ -222,6 +244,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):

upload_file(
path_or_fileobj=lensless_files[0],
# path_in_repo=f"lensless_example{config.lensless.ext}" if not config.lensless.eight_norm else f"lensless_example.png",
path_in_repo=f"lensless_example{config.lensless.ext}",
repo_id=repo_id,
repo_type="dataset",
Expand All @@ -248,7 +271,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):
print(f"Total time: {(time.time() - start_time) / 60} minutes")

# delete temporary directory holding the downsampled and/or 8-bit converted files
if config.lensless.eight_norm:
if config.lensless.eight_norm or config.lensless.downsample:
os.system(f"rm -rf {tmp_dir}")


Expand Down

0 comments on commit d44ab57

Please sign in to comment.