Skip to content

Commit

Permalink
Add support for uploading dataset with ambient light.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Aug 8, 2024
1 parent 7118afc commit 4fea242
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 19 deletions.
1 change: 1 addition & 0 deletions configs/upload_dataset_huggingface.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ lensless:
dir: null
ext: null # for example: .png, .jpg
eight_norm: False # save as 8-bit normalized image
ambient: False
downsample: null

lensed:
Expand Down
23 changes: 23 additions & 0 deletions configs/upload_tapecam_mirflickr_ambient.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Config for uploading the TapeCam MirFlickr dataset measured with ambient
# light to the HuggingFace Hub. Run with:
# python scripts/data/upload_dataset_huggingface.py -cn upload_tapecam_mirflickr_ambient
defaults:
  - upload_dataset_huggingface
  - _self_

repo_id: "Lensless/TapeCam-Mirflickr-Ambient"
n_files: null
test_size: 0.15
# -- to match TapeCam without ambient light
# How the test split is chosen:
#   "first" -> use the first `n_files * test_size` files for the test split;
#   an integer -> out of every group of `split` files, take `test_size * split`
#   for the test split (interleaved), as if this were a multi-mask dataset
#   with that many masks.
split: 100

lensless:
  dir: data/100_samples
  ambient: True  # measurements include corresponding ambient/background images
  ext: ".png"

lensed:
  dir: data/mirflickr/mirflickr
  ext: ".jpg"

files:
  psf: data/tape_psf.png
  measurement_config: data/collect_dataset_background.yaml
92 changes: 73 additions & 19 deletions scripts/data/upload_dataset_huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""

import hydra
from hydra.utils import to_absolute_path
import time
import os
import glob
Expand Down Expand Up @@ -52,6 +53,9 @@ def upload_dataset(config):
config.lensed.ext is not None
), "Please provide a lensed file extension, e.g. .png, .jpg, .tiff"

config.lensless.dir = to_absolute_path(config.lensless.dir)
config.lensed.dir = to_absolute_path(config.lensed.dir)

# get masks
files_masks = []
n_masks = 0
Expand All @@ -70,6 +74,7 @@ def upload_dataset(config):

# get lensed files
files_lensed = glob.glob(os.path.join(config.lensed.dir, "*" + config.lensed.ext))
print(f"Number of lensed files: {len(files_lensed)}")

# only keep if in both
bn_lensless = [os.path.basename(f).split(".")[0] for f in files_lensless]
Expand All @@ -83,9 +88,20 @@ def upload_dataset(config):
os.path.join(config.lensless.dir, f + config.lensless.ext) for f in common_files
]
lensed_files = [os.path.join(config.lensed.dir, f + config.lensed.ext) for f in common_files]
background_files = []
if config.lensless.ambient:
# check that corresponding ambient files exist
for f in common_files:
ambient_bn = "black_background" + os.path.basename(f) + config.lensless.ext
ambient_f = os.path.join(config.lensless.dir, ambient_bn)
assert os.path.exists(ambient_f), f"File {ambient_f} does not exist."
background_files.append(ambient_f)

if config.lensless.downsample is not None:

if config.lensless.ambient:
raise NotImplementedError("Downsampling not implemented for ambient files.")

tmp_dir = config.lensless.dir + "_tmp"
os.makedirs(tmp_dir, exist_ok=True)

Expand All @@ -109,6 +125,9 @@ def downsample(f, output_dir):
# convert to normalized 8 bit
if config.lensless.eight_norm:

if config.lensless.ambient:
raise NotImplementedError("Normalized 8-bit not implemented for ambient files.")

tmp_dir = config.lensless.dir + "_tmp"
os.makedirs(tmp_dir, exist_ok=True)

Expand Down Expand Up @@ -150,17 +169,21 @@ def save_8bit(f, output_dir, normalize=True):
df_attr = {"mask_label": mask_labels}

# step 1: create Dataset objects
def create_dataset(lensless_files, lensed_files, df_attr=None):
def create_dataset(lensless_files, lensed_files, df_attr=None, ambient_files=None):
    """Build a HuggingFace ``Dataset`` from parallel lists of file paths.

    Parameters
    ----------
    lensless_files : list of str
        Paths to the lensless measurement images ("lensless" column).
    lensed_files : list of str
        Paths to the corresponding lensed images ("lensed" column).
    df_attr : dict, optional
        Extra per-sample attributes; each value is a list aligned with the
        file lists, merged in as additional columns.
    ambient_files : list of str, optional
        Paths to ambient/background images; when provided, stored under an
        "ambient" column and cast to the image feature type as well.

    Returns
    -------
    Dataset
        Dataset whose file-path columns are cast to the ``Image`` feature so
        they are decoded as images.
    """
    columns = dict(lensless=lensless_files, lensed=lensed_files)
    if df_attr is not None:
        # merge extra per-sample attributes as additional columns
        columns.update(df_attr)
    if ambient_files is not None:
        columns["ambient"] = ambient_files

    ds = Dataset.from_dict(columns)
    # cast every file-path column to the Image feature type
    image_columns = ["lensless", "lensed"]
    if ambient_files is not None:
        image_columns.append("ambient")
    for col in image_columns:
        ds = ds.cast_column(col, Image())
    return ds

# train-test split
Expand All @@ -175,17 +198,23 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):
[lensless_files[i] for i in test_indices],
[lensed_files[i] for i in test_indices],
{k: [v[i] for i in test_indices] for k, v in df_attr.items()},
ambient_files=[background_files[i] for i in test_indices]
if config.lensless.ambient
else None,
)
train_dataset = create_dataset(
[lensless_files[i] for i in train_indices],
[lensed_files[i] for i in train_indices],
{k: [v[i] for i in train_indices] for k, v in df_attr.items()},
ambient_files=[background_files[i] for i in train_indices]
if config.lensless.ambient
else None,
)
elif isinstance(config.split, int):
n_test_split = int(test_size * config.split)

# get all indices
n_splits = len(lensless_files) // config.split
n_splits = np.ceil(len(lensless_files) / config.split).astype(int)
test_idx = np.array([])
for i in range(n_splits):
test_idx = np.append(test_idx, np.arange(n_test_split) + i * config.split)
Expand All @@ -197,10 +226,18 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):

# split dict into train-test
test_dataset = create_dataset(
[lensless_files[i] for i in test_idx], [lensed_files[i] for i in test_idx]
[lensless_files[i] for i in test_idx],
[lensed_files[i] for i in test_idx],
ambient_files=[background_files[i] for i in test_idx]
if config.lensless.ambient
else None,
)
train_dataset = create_dataset(
[lensless_files[i] for i in train_idx], [lensed_files[i] for i in train_idx]
[lensless_files[i] for i in train_idx],
[lensed_files[i] for i in train_idx],
ambient_files=[background_files[i] for i in train_idx]
if config.lensless.ambient
else None,
)

else:
Expand All @@ -212,9 +249,17 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):
else:
df_attr_test = None
df_attr_train = None
test_dataset = create_dataset(lensless_files[:n_test], lensed_files[:n_test], df_attr_test)
test_dataset = create_dataset(
lensless_files[:n_test],
lensed_files[:n_test],
df_attr_test,
ambient_files=background_files[:n_test] if config.lensless.ambient else None,
)
train_dataset = create_dataset(
lensless_files[n_test:], lensed_files[n_test:], df_attr_train
lensless_files[n_test:],
lensed_files[n_test:],
df_attr_train,
ambient_files=background_files[n_test:] if config.lensless.ambient else None,
)
print(f"Train size: {len(train_dataset)}")
print(f"Test size: {len(test_dataset)}")
Expand All @@ -230,7 +275,7 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):
# step 3: push to hub
if config.files is not None:
for f in config.files:
fp = config.files[f]
fp = to_absolute_path(config.files[f])
ext = os.path.splitext(fp)[1]
remote_fn = f"{f}{ext}"
upload_file(
Expand All @@ -241,18 +286,19 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):
token=hf_token,
)

# viewable version of file
img = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
local_fp = f"{f}_viewable8bit.png"
remote_fn = f"{f}_viewable8bit.png"
save_image(img, local_fp, normalize=True)
upload_file(
path_or_fileobj=local_fp,
path_in_repo=remote_fn,
repo_id=repo_id,
repo_type="dataset",
token=hf_token,
)
# viewable version of file if it is an image
if ext in [".png", ".jpg", ".jpeg", ".tiff"]:
img = cv2.imread(fp, cv2.IMREAD_UNCHANGED)
local_fp = f"{f}_viewable8bit.png"
remote_fn = f"{f}_viewable8bit.png"
save_image(img, local_fp, normalize=True)
upload_file(
path_or_fileobj=local_fp,
path_in_repo=remote_fn,
repo_id=repo_id,
repo_type="dataset",
token=hf_token,
)

dataset_dict.push_to_hub(repo_id, token=hf_token)

Expand All @@ -271,6 +317,14 @@ def create_dataset(lensless_files, lensed_files, df_attr=None):
repo_type="dataset",
token=hf_token,
)
if config.lensless.ambient:
upload_file(
path_or_fileobj=background_files[0],
path_in_repo=f"ambient_example{config.lensless.ext}",
repo_id=repo_id,
repo_type="dataset",
token=hf_token,
)

for _mask_file in files_masks:
upload_file(
Expand Down

0 comments on commit 4fea242

Please sign in to comment.