Skip to content

Commit

Permalink
Clean up example script.
Browse files Browse the repository at this point in the history
  • Loading branch information
ebezzam committed Sep 6, 2024
1 parent 189aaca commit 7937bce
Showing 1 changed file with 6 additions and 74 deletions.
80 changes: 6 additions & 74 deletions scripts/recon/hyperspectral.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,7 @@
import matplotlib.pyplot as plt
from lensless.utils.io import load_data
from lensless import (
GradientDescentUpdate,
GradientDescent,
NesterovGradientDescent,
FISTA,
HyperSpectralFISTA,
)


Expand Down Expand Up @@ -60,94 +57,29 @@ def gradient_descent(
data = np.expand_dims(data, axis=0)
data = np.expand_dims(data, axis=-1)

# apply gradient descent
from lensless import HyperSpectralFISTA

# apply FISTA
save = config["save"]
if save:
save = os.getcwd()

start_time = time.time()
recon = HyperSpectralFISTA(
psf,
mask,
# norm=None,
norm="ortho",
)
recon.set_data(data)
print(f"Setup time : {time.time() - start_time} s")

start_time = time.time()
res = recon.apply(
n_iter=500,
disp_iter=50,
save=save,
gamma=1.0,
plot=False,
)

if config.torch:
img = res[0].cpu().numpy()
else:
img = res[0]

if config["display"]["plot"]:
plt.show()
if save:
np.save(plib.Path(save) / "final_reconstruction.npy", img)
print(f"Files saved to : {save}")

raise ValueError

psf, data = load_data(
psf_fp=to_absolute_path(config.input.psf),
data_fp=to_absolute_path(config.input.data),
dtype=config.input.dtype,
downsample=config["preprocess"]["downsample"],
bayer=config["preprocess"]["bayer"],
blue_gain=config["preprocess"]["blue_gain"],
red_gain=config["preprocess"]["red_gain"],
plot=config["display"]["plot"],
flip=config["preprocess"]["flip"],
gamma=config["display"]["gamma"],
gray=config["preprocess"]["gray"],
single_psf=config["preprocess"]["single_psf"],
shape=config["preprocess"]["shape"],
use_torch=config.torch,
torch_device=config.torch_device,
)

disp = config["display"]["disp"]
if disp < 0:
disp = None

save = config["save"]
if save:
save = os.getcwd()

start_time = time.time()

if config["gradient_descent"]["method"] == GradientDescentUpdate.VANILLA:
recon = GradientDescent(psf)
elif config["gradient_descent"]["method"] == GradientDescentUpdate.NESTEROV:
recon = NesterovGradientDescent(
psf,
p=config["gradient_descent"]["nesterov"]["p"],
mu=config["gradient_descent"]["nesterov"]["mu"],
)
else:
recon = FISTA(
psf,
tk=config["gradient_descent"]["fista"]["tk"],
)

recon.set_data(data)
print(f"Setup time : {time.time() - start_time} s")

start_time = time.time()
res = recon.apply(
n_iter=config["gradient_descent"]["n_iter"],
disp_iter=disp,
save=save,
gamma=config["display"]["gamma"],
plot=config["display"]["plot"],
)
print(f"Processing time : {time.time() - start_time} s")

if config.torch:
Expand Down

0 comments on commit 7937bce

Please sign in to comment.