diff --git a/README.rst b/README.rst index 532e0940..428fd07e 100644 --- a/README.rst +++ b/README.rst @@ -168,8 +168,7 @@ directory): pip install -r rpi_requirements.txt # test on-device camera capture (after setting up the camera) - source lensless_env/bin/activate - python scripts/measure/on_device_capture.py + (lensless_env) python scripts/measure/on_device_capture.py You may still need to manually install ``numpy`` and/or ``scipy`` with ``pip`` in case libraries (e.g. ``libopenblas.so.0``) cannot be detected. diff --git a/configs/analyze_dataset.yaml b/configs/analyze_dataset.yaml index 53d6a130..475f452c 100644 --- a/configs/analyze_dataset.yaml +++ b/configs/analyze_dataset.yaml @@ -1,9 +1,11 @@ +# python scripts/measure/analyze_measured_dataset.py hydra: job: chdir: True # change to output folder dataset_path: null -desired_range: [150, 254] +desired_range: [150, 255] +saturation_percent: 0.05 delete_bad: False n_files: null start_idx: null diff --git a/configs/collect_dataset.yaml b/configs/collect_dataset.yaml index c898ab2b..0fa87fa2 100644 --- a/configs/collect_dataset.yaml +++ b/configs/collect_dataset.yaml @@ -24,6 +24,7 @@ min_level: 200 max_tries: 6 masks: null # for multi-mask measurements +recon: null # parameters for reconstruction (for debugging purposes, not recommended to do during actual measurement as it will significantly increase the time) # -- display parameters display: @@ -41,10 +42,12 @@ display: capture: skip: False # to test looping over displaying images - config_pause: 2 + config_pause: 3 iso: 100 res: null down: 4 exposure: 0.02 # min exposure awb_gains: [1.9, 1.2] # red, blue - # awb_gains: null \ No newline at end of file + # awb_gains: null + fact_increase: 2 # multiplicative factor to increase exposure + fact_decrease: 1.5 \ No newline at end of file diff --git a/configs/collect_mirflickr_fza.yaml b/configs/collect_mirflickr_fza.yaml new file mode 100644 index 00000000..bd6de24f --- /dev/null +++ b/configs/collect_mirflickr_fza.yaml @@ -0,0 +1,32 @@ +# python scripts/measure/collect_dataset_on_device.py -cn collect_mirflickr_fza +defaults: + - collect_dataset + - _self_ + +input_dir: /mnt/mirflickr/all +min_level: 170 + +# FOR TESTING PURPOSE +output_dir: data/fza_test # RPi won't have enough memory for full measured dataset +max_tries: 0 +recon: + psf: data/psf/tape_rgb.png # TODO: set correct PSF + n_iter: 10 +# # FOR FINAL MEASUREMENT +# max_tries: 2 +# output_dir: /mnt/mirflickr/fza_10K + +# files to measure +n_files: 25000 + +# -- display parameters +display: + image_res: [900, 1200] + vshift: -26 + brightness: 90 + delay: 1 + +capture: + down: 8 + exposure: 0.7 + awb_gains: [1.6, 1.2] # red, blue, TODO for your mask \ No newline at end of file diff --git a/lensless/recon/recon.py b/lensless/recon/recon.py index ca172ff8..aaf7093d 100644 --- a/lensless/recon/recon.py +++ b/lensless/recon/recon.py @@ -494,9 +494,9 @@ def _get_numpy_data(self, data): def apply( self, n_iter=None, - disp_iter=10, + disp_iter=-1, plot_pause=0.2, - plot=True, + plot=False, save=False, gamma=None, ax=None, diff --git a/lensless/utils/image.py b/lensless/utils/image.py index 4b159fd2..a36a0f61 100644 --- a/lensless/utils/image.py +++ b/lensless/utils/image.py @@ -221,6 +221,7 @@ def get_max_val(img, nbits=None): def bayer2rgb_cc( img, nbits, + down=None, blue_gain=None, red_gain=None, black_level=RPI_HQ_CAMERA_BLACK_LEVEL, @@ -269,6 +270,10 @@ def bayer2rgb_cc( # demosaic Bayer data img = cv2.cvtColor(img, cv2.COLOR_BayerRG2RGB) + # downsample + if down is not None: + img = resize(img[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC)[0] + # correction img = img - black_level if red_gain: @@ -277,6 +282,7 @@ def bayer2rgb_cc( img[:, :, 2] *= blue_gain img = img / (2**nbits - 1 - black_level) img[img > 1] = 1 + img = (img.reshape(-1, 3, order="F") @ ccm.T).reshape(img.shape, order="F") img[img < 0] = 0 img[img > 1] = 1 diff --git a/lensless/utils/io.py b/lensless/utils/io.py index 62fd7f2b..3beafbbf 100644 --- a/lensless/utils/io.py +++ b/lensless/utils/io.py @@ -588,31 +588,39 @@ def save_image(img, fp, max_val=255, normalize=True): img_tmp = img.copy() - if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8: - img_tmp = img_tmp.astype(np.float32) - if normalize: + + if img_tmp.dtype == np.uint16 or img_tmp.dtype == np.uint8: + img_tmp = img_tmp.astype(np.float32) + img_tmp -= img_tmp.min() img_tmp /= img_tmp.max() - else: - normalized = False - if img_tmp.min() < 0: - img_tmp -= img_tmp.min() - normalize = True - if img_tmp.max() > 1: - img_tmp /= img_tmp.max() - normalize = True - if normalized: - print(f"Warning (out of range): {fp} normalizing data to [0, 1]") - - if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32: img_tmp *= max_val img_tmp = img_tmp.astype(np.uint8) - # RGB + else: + + if img_tmp.dtype == np.float64 or img_tmp.dtype == np.float32: + # check within [0, 1] and convert to uint8 + + normalized = False + if img_tmp.min() < 0: + img_tmp -= img_tmp.min() + normalized = True + if img_tmp.max() > 1: + img_tmp /= img_tmp.max() + normalized = True + if normalized: + print(f"Warning (out of range): {fp} normalizing data to [0, 1]") + img_tmp *= max_val + img_tmp = img_tmp.astype(np.uint8) + + # save if len(img_tmp.shape) == 3 and img_tmp.shape[2] == 3: + # RGB img_tmp = Image.fromarray(img_tmp) else: + # grayscale img_tmp = Image.fromarray(img_tmp.squeeze()) img_tmp.save(fp) diff --git a/rpi_requirements.txt b/rpi_requirements.txt index 6c7bf14b..01ccda68 100644 --- a/rpi_requirements.txt +++ b/rpi_requirements.txt @@ -5,4 +5,5 @@ matplotlib>=3.4.2 hydra-code paramiko numpy>=1.24.2 -scipy>=1.6.0 \ No newline at end of file +scipy>=1.6.0 +git+https://github.com/ebezzam/slm-controller.git \ No newline at end of file diff --git a/scripts/measure/analyze_measured_dataset.py b/scripts/measure/analyze_measured_dataset.py index 6137e176..2d5c7050 100644 --- a/scripts/measure/analyze_measured_dataset.py +++ b/scripts/measure/analyze_measured_dataset.py @@ -15,6 +15,19 @@ import matplotlib.pyplot as plt import time import tqdm +import re + + +def convert(text): + return int(text) if text.isdigit() else text.lower() + + +def alphanum_key(key): + return [convert(c) for c in re.split("([0-9]+)", key)] + + +def natural_sort(arr): + return sorted(arr, key=alphanum_key) @hydra.main(version_base=None, config_path="../../configs", config_name="analyze_dataset") @@ -24,13 +37,14 @@ def analyze_dataset(config): desired_range = config.desired_range delete_bad = config.delete_bad start_idx = config.start_idx + saturation_percent = config.saturation_percent assert ( folder is not None ), "Must specify folder to analyze in config or through command line (folder=PATH)." # get all PNG files in folder - files = sorted(glob.glob(os.path.join(folder, "*.png"))) + files = natural_sort(glob.glob(os.path.join(folder, "*.png"))) print("Found {} files".format(len(files))) if start_idx is not None: files = files[start_idx:] @@ -48,10 +62,9 @@ def analyze_dataset(config): im = np.array(Image.open(fn)) max_val = im.max() max_vals.append(max_val) + saturation_ratio = np.sum(im >= desired_range[1]) / im.size - # if out of desired range, print filename - if max_val < desired_range[0] or max_val > desired_range[1]: - # print("File {} has max value {}".format(fn, max_val)) + if max_val < desired_range[0]: n_bad_files += 1 bad_files.append(fn) @@ -61,6 +74,28 @@ def analyze_dataset(config): else: print("File {} has max value {}".format(fn, max_val)) + elif saturation_ratio > saturation_percent: + n_bad_files += 1 + bad_files.append(fn) + + if delete_bad: + os.remove(fn) + print("REMOVED file {}".format(fn)) + else: + print("File {} has saturation ratio {}".format(fn, saturation_ratio)) + + # # if out of desired range, print filename + # if max_val < desired_range[0] or saturation_ratio > saturation_percent: + # # print("File {} has max value {}".format(fn, max_val)) + # n_bad_files += 1 + # bad_files.append(fn) + + # if delete_bad: + # os.remove(fn) + # print("REMOVED file {}".format(fn)) + # else: + # print("File {} has max value {}".format(fn, max_val)) + proc_time = time.time() - start_time print("Went through {} files in {:.2f} seconds".format(len(files), proc_time)) print( @@ -69,6 +104,14 @@ def analyze_dataset(config): ) ) + # plot histogram + output_folder = os.getcwd() + output_fp = os.path.join(output_folder, "max_vals.png") + plt.hist(max_vals, bins=100) + plt.savefig(output_fp) + + print("Saved histogram to {}".format(output_fp)) + # command line input on whether to delete bad files if not delete_bad: response = None @@ -80,14 +123,6 @@ def analyze_dataset(config): else: print("Not deleting bad files") - # plot histogram - output_folder = os.getcwd() - output_fp = os.path.join(output_folder, "max_vals.png") - plt.hist(max_vals, bins=100) - plt.savefig(output_fp) - - print("Saved histogram to {}".format(output_fp)) - if __name__ == "__main__": analyze_dataset() diff --git a/scripts/measure/collect_dataset_on_device.py b/scripts/measure/collect_dataset_on_device.py index 76393b9c..96c7c727 100644 --- a/scripts/measure/collect_dataset_on_device.py +++ b/scripts/measure/collect_dataset_on_device.py @@ -14,6 +14,7 @@ import numpy as np import hydra +from hydra.utils import to_absolute_path import time import os import pathlib as plib @@ -67,8 +68,8 @@ def collect_dataset(config): start_idx = config.start_idx if os.path.exists(output_dir): files = list(plib.Path(output_dir).glob(f"*.{config.output_file_ext}")) - start_idx = len(files) - print("\nNumber of completed measurements :", start_idx) + n_completed_files = len(files) + print("\nNumber of completed measurements :", n_completed_files) output_dir = plib.Path(output_dir) if config.masks is not None: mask_dir = plib.Path(output_dir) / "masks" @@ -89,6 +90,26 @@ def collect_dataset(config): mask_vals = np.random.uniform(0, 1, config.masks.shape) np.save(mask_fp, mask_vals) + recon = None + if config.recon is not None: + print("Initializing ADMM recon...") + # initialize ADMM reconstruction + from lensless import ADMM + from lensless.utils.io import load_psf + + psf, bg = load_psf( + fp=to_absolute_path(config.recon.psf), + downsample=config.capture.down, # assume full resolution PSF + return_bg=True, + ) + + print("PSF shape: ", psf.shape) + recon = ADMM(psf, n_iter=config.recon.n_iter) + + recon_dir = plib.Path(output_dir) / "recon" + recon_dir.mkdir(exist_ok=True) + print("Finished initializing ADMM recon.") + # assert input directory exists assert os.path.exists(input_dir) @@ -234,8 +255,8 @@ def collect_dataset(config): # -- take picture max_pixel_val = 0 - fact_increase = 2 - fact_decrease = 1.5 + fact_increase = config.capture.fact_increase + fact_decrease = config.capture.fact_decrease n_tries = 0 camera.shutter_speed = init_shutter_speed @@ -256,6 +277,7 @@ def collect_dataset(config): # convert to RGB output = bayer2rgb_cc( output_bayer, + down=down, nbits=12, blue_gain=float(g[1]), red_gain=float(g[0]), @@ -264,10 +286,10 @@ def collect_dataset(config): nbits_out=8, ) - if down: - output = resize( - output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC - )[0] + # if down: + # output = resize( + # output[None, ...], factor=1 / down, interpolation=cv2.INTER_CUBIC + # )[0] # save image save_image(output, output_fp, normalize=False) @@ -289,32 +311,56 @@ def collect_dataset(config): elif max_pixel_val > MAX_LEVEL: - # decrease exposure - current_shutter_speed = int(current_shutter_speed / fact_decrease) - camera.shutter_speed = current_shutter_speed - time.sleep(config.capture.config_pause) - print(f"decreasing shutter speed to {current_shutter_speed}") - - # # decrease screen brightness - # current_screen_brightness = current_screen_brightness - 10 - # screen_res = np.array(config.display.screen_res) - # hshift = config.display.hshift - # vshift = config.display.vshift - # pad = config.display.pad - # brightness = current_screen_brightness - # display_image_path = config.display.output_fp - # rot90 = config.display.rot90 - # os.system( - # f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" - # ) - # print(f"decreasing screen brightness to {current_screen_brightness}") - - # time.sleep(config.display.delay) + if current_shutter_speed > 13098: # TODO: minimum for RPi HQ + # decrease exposure + current_shutter_speed = int(current_shutter_speed / fact_decrease) + camera.shutter_speed = current_shutter_speed + time.sleep(config.capture.config_pause) + print(f"decreasing shutter speed to {current_shutter_speed}") + + else: + + # decrease screen brightness + current_screen_brightness = current_screen_brightness - 10 + screen_res = np.array(config.display.screen_res) + hshift = config.display.hshift + vshift = config.display.vshift + pad = config.display.pad + brightness = current_screen_brightness + display_image_path = config.display.output_fp + rot90 = config.display.rot90 + + display_command = f"python scripts/measure/prep_display_image.py --fp {_file} --output_path {display_image_path} --screen_res {screen_res[0]} {screen_res[1]} --hshift {hshift} --vshift {vshift} --pad {pad} --brightness {brightness} --rot90 {rot90}" + if config.display.landscape: + display_command += " --landscape" + if config.display.image_res is not None: + display_command += f" --image_res {config.display.image_res[0]} {config.display.image_res[1]}" + # print(display_command) + os.system(display_command) + + time.sleep(config.display.delay) exposure_vals.append(current_shutter_speed / 1e6) brightness_vals.append(current_screen_brightness) n_tries_vals.append(n_tries) + if recon is not None: + + # normalize and remove background + output = output.astype(np.float32) + output /= output.max() + output -= bg + output = np.clip(output, a_min=0, a_max=output.max()) + + # set data + output = output[np.newaxis, :, :, :] + recon.set_data(output) + + # reconstruct and save + res = recon.apply() + recon_fp = recon_dir / output_fp.name + save_image(res, recon_fp) + # check if runtime is exceeded if config.runtime: proc_time = time.time() - start_time