Option to downsample data before HF upload. #128

Merged · 4 commits · May 2, 2024
1 change: 1 addition & 0 deletions configs/upload_dataset_huggingface.yaml
@@ -15,6 +15,7 @@ lensless:
dir: null
ext: null # for example: .png, .jpg
eight_norm: False # save as 8-bit normalized image
downsample: null

lensed:
dir: null
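For context, a minimal sketch of how the new downsample field behaves: null leaves measurements at native resolution, while a numeric factor shrinks each lensless image by that factor before upload. The helper name maybe_downsample is hypothetical and not part of the package.

import cv2

def maybe_downsample(img, factor):
    # factor mirrors config.lensless.downsample: None means no resizing,
    # otherwise shrink by 1/factor in both dimensions (bilinear interpolation).
    if factor is None:
        return img
    return cv2.resize(
        img, (0, 0), fx=1 / factor, fy=1 / factor, interpolation=cv2.INTER_LINEAR
    )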
8 changes: 7 additions & 1 deletion lensless/utils/io.py
@@ -37,6 +37,7 @@ def load_image(
shape=None,
dtype=None,
normalize=True,
bgr_input=True,
):
"""
Load image as numpy array.
@@ -151,7 +152,7 @@ def load_image(
)

else:
if len(img.shape) == 3:
if len(img.shape) == 3 and bgr_input:
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

original_dtype = img.dtype
@@ -223,6 +224,7 @@ def load_psf(
single_psf=False,
shape=None,
use_3d=False,
bgr_input=True,
):
"""
Load and process PSF for analysis or for reconstruction.
@@ -296,6 +298,7 @@
blue_gain=blue_gain,
red_gain=red_gain,
nbits_out=nbits_out,
bgr_input=bgr_input,
)

original_dtype = psf.dtype
@@ -391,6 +394,7 @@ def load_data(
torch=False,
torch_device="cpu",
normalize=False,
bgr_input=True,
):
"""
Load data for image reconstruction.
@@ -471,6 +475,7 @@
single_psf=single_psf,
shape=shape,
use_3d=use_3d,
bgr_input=bgr_input,
)

# load and process raw measurement
@@ -485,6 +490,7 @@
return_float=return_float,
shape=shape,
normalize=normalize,
bgr_input=bgr_input,
)

if data.shape != psf.shape:
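The new bgr_input flag (default True) tells load_image, load_psf, and load_data whether the file on disk uses OpenCV's BGR channel order; passing False skips the BGR-to-RGB conversion for images already stored in RGB order. A usage sketch (file paths are placeholders):

from lensless.utils.io import load_image

# Capture written by OpenCV, BGR order on disk: keep the default conversion.
raw = load_image("raw_measurement.png")

# File already stored in RGB order: skip the channel swap.
rgb = load_image("rgb_measurement.png", bgr_input=False)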
35 changes: 29 additions & 6 deletions scripts/data/upload_dataset_huggingface.py
@@ -24,6 +24,8 @@
from lensless.utils.dataset import natural_sort
from tqdm import tqdm
from lensless.utils.io import save_image
import cv2
from joblib import Parallel, delayed


@hydra.main(
@@ -82,24 +84,44 @@ def upload_dataset(config):
]
lensed_files = [os.path.join(config.lensed.dir, f + config.lensed.ext) for f in common_files]

if config.lensless.downsample is not None:

tmp_dir = config.lensless.dir + "_tmp"
os.makedirs(tmp_dir, exist_ok=True)

def downsample(f, output_dir):
img = cv2.imread(f, cv2.IMREAD_UNCHANGED)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = cv2.resize(
img,
(0, 0),
fx=1 / config.lensless.downsample,
fy=1 / config.lensless.downsample,
interpolation=cv2.INTER_LINEAR,
)
new_fp = os.path.join(output_dir, os.path.basename(f))
new_fp = new_fp.split(".")[0] + config.lensless.ext
save_image(img, new_fp, normalize=False)

Parallel(n_jobs=n_jobs)(delayed(downsample)(f, tmp_dir) for f in tqdm(lensless_files))
lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}"))

# convert to normalized 8 bit
if config.lensless.eight_norm:

import cv2
from joblib import Parallel, delayed

tmp_dir = config.lensless.dir + "_tmp"
os.makedirs(tmp_dir, exist_ok=True)

# -- parallelize with joblib
def save_8bit(f, output_dir, normalize=True):
img = cv2.imread(f, cv2.IMREAD_UNCHANGED)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
new_fp = os.path.join(output_dir, os.path.basename(f))
new_fp = new_fp.split(".")[0] + ".png"
new_fp = new_fp.split(".")[0] + config.lensless.ext
save_image(img, new_fp, normalize=normalize)

Parallel(n_jobs=n_jobs)(delayed(save_8bit)(f, tmp_dir) for f in tqdm(lensless_files))
lensless_files = glob.glob(os.path.join(tmp_dir, "*png"))
lensless_files = glob.glob(os.path.join(tmp_dir, f"*{config.lensless.ext[1:]}"))

# check for attribute
df_attr = None
@@ -222,6 +244,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):

upload_file(
path_or_fileobj=lensless_files[0],
# path_in_repo=f"lensless_example{config.lensless.ext}" if not config.lensless.eight_norm else f"lensless_example.png",
path_in_repo=f"lensless_example{config.lensless.ext}",
repo_id=repo_id,
repo_type="dataset",
@@ -248,7 +271,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):
print(f"Total time: {(time.time() - start_time) / 60} minutes")

# delete PNG files
if config.lensless.eight_norm:
if config.lensless.eight_norm or config.lensless.downsample:
os.system(f"rm -rf {tmp_dir}")
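In short, when config.lensless.downsample is set, the script resizes every lensless measurement into a temporary directory, uploads those copies instead of the originals, and removes the temporary directory at the end (the cleanup now triggers for either eight_norm or downsample). A standalone sketch of the resizing step, assuming a local directory of 3-channel images; downsample_dir is a hypothetical helper, not part of the script:

import glob
import os

import cv2
from joblib import Parallel, delayed

from lensless.utils.io import save_image


def downsample_dir(in_dir, out_dir, factor, ext=".png", n_jobs=4):
    # Mirror of the temporary-directory step in upload_dataset_huggingface.py:
    # read each file, convert BGR -> RGB, shrink by 1/factor, save the copy.
    os.makedirs(out_dir, exist_ok=True)
    files = sorted(glob.glob(os.path.join(in_dir, f"*{ext}")))

    def _resize_one(f):
        img = cv2.imread(f, cv2.IMREAD_UNCHANGED)   # assumes 3-channel captures
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(
            img, (0, 0), fx=1 / factor, fy=1 / factor, interpolation=cv2.INTER_LINEAR
        )
        save_image(img, os.path.join(out_dir, os.path.basename(f)), normalize=False)

    Parallel(n_jobs=n_jobs)(delayed(_resize_one)(f) for f in files)
    return sorted(glob.glob(os.path.join(out_dir, f"*{ext}")))

The returned file list plays the role of the reassigned lensless_files in the script, so the downsampled copies are what end up in the Hugging Face dataset.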

