From 7937bcec6422f238b8c349d68d90687e1f5805be Mon Sep 17 00:00:00 2001 From: Eric Bezzam Date: Fri, 6 Sep 2024 16:25:24 +0000 Subject: [PATCH] Clean up example script. --- scripts/recon/hyperspectral.py | 80 +++------------------------------- 1 file changed, 6 insertions(+), 74 deletions(-) diff --git a/scripts/recon/hyperspectral.py b/scripts/recon/hyperspectral.py index a3766c89..d5d2f6d3 100644 --- a/scripts/recon/hyperspectral.py +++ b/scripts/recon/hyperspectral.py @@ -16,10 +16,7 @@ import matplotlib.pyplot as plt from lensless.utils.io import load_data from lensless import ( - GradientDescentUpdate, - GradientDescent, - NesterovGradientDescent, - FISTA, + HyperSpectralFISTA, ) @@ -60,13 +57,12 @@ def gradient_descent( data = np.expand_dims(data, axis=0) data = np.expand_dims(data, axis=-1) - # apply gradient descent - from lensless import HyperSpectralFISTA - + # apply FISTA save = config["save"] if save: save = os.getcwd() + start_time = time.time() recon = HyperSpectralFISTA( psf, mask, @@ -74,6 +70,9 @@ def gradient_descent( norm="ortho", ) recon.set_data(data) + print(f"Setup time : {time.time() - start_time} s") + + start_time = time.time() res = recon.apply( n_iter=500, disp_iter=50, @@ -81,73 +80,6 @@ def gradient_descent( gamma=1.0, plot=False, ) - - if config.torch: - img = res[0].cpu().numpy() - else: - img = res[0] - - if config["display"]["plot"]: - plt.show() - if save: - np.save(plib.Path(save) / "final_reconstruction.npy", img) - print(f"Files saved to : {save}") - - raise ValueError - - psf, data = load_data( - psf_fp=to_absolute_path(config.input.psf), - data_fp=to_absolute_path(config.input.data), - dtype=config.input.dtype, - downsample=config["preprocess"]["downsample"], - bayer=config["preprocess"]["bayer"], - blue_gain=config["preprocess"]["blue_gain"], - red_gain=config["preprocess"]["red_gain"], - plot=config["display"]["plot"], - flip=config["preprocess"]["flip"], - gamma=config["display"]["gamma"], - gray=config["preprocess"]["gray"], - single_psf=config["preprocess"]["single_psf"], - shape=config["preprocess"]["shape"], - use_torch=config.torch, - torch_device=config.torch_device, - ) - - disp = config["display"]["disp"] - if disp < 0: - disp = None - - save = config["save"] - if save: - save = os.getcwd() - - start_time = time.time() - - if config["gradient_descent"]["method"] == GradientDescentUpdate.VANILLA: - recon = GradientDescent(psf) - elif config["gradient_descent"]["method"] == GradientDescentUpdate.NESTEROV: - recon = NesterovGradientDescent( - psf, - p=config["gradient_descent"]["nesterov"]["p"], - mu=config["gradient_descent"]["nesterov"]["mu"], - ) - else: - recon = FISTA( - psf, - tk=config["gradient_descent"]["fista"]["tk"], - ) - - recon.set_data(data) - print(f"Setup time : {time.time() - start_time} s") - - start_time = time.time() - res = recon.apply( - n_iter=config["gradient_descent"]["n_iter"], - disp_iter=disp, - save=save, - gamma=config["display"]["gamma"], - plot=config["display"]["plot"], - ) print(f"Processing time : {time.time() - start_time} s") if config.torch: