diff --git a/main_script_optim.py b/main_script_optim.py deleted file mode 100644 index 58f353b..0000000 --- a/main_script_optim.py +++ /dev/null @@ -1,329 +0,0 @@ -from simca import load_yaml_config -from simca.CassiSystemOptim import CassiSystemOptim -from simca.CassiSystem import CassiSystem -import numpy as np -import snoop -import matplotlib.pyplot as plt -import matplotlib.animation as anim -#import matplotlib -import torch -import time, datetime -import os -from pprint import pprint -from simca.cost_functions import evaluate_slit_scanning_straightness, evaluate_center, evaluate_mean_lighting, evaluate_max_lighting -from simca.functions_optim import optim_smile, optim_width - -#matplotlib.use('Agg') -config_dataset = load_yaml_config("simca/configs/dataset.yml") -config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_max_center.yml") -config_patterns = load_yaml_config("simca/configs/pattern.yml") -config_acquisition = load_yaml_config("simca/configs/acquisition.yml") - -dataset_name = "washington_mall" - -test = "SMILE" - -algo = "ADAM" - -if test=="SMILE": - config_system = load_yaml_config("simca/configs/cassi_system_simple_optim.yml") - aspect = 0.2 -elif test=="EQUAL_LIGHT" or test=="MAX_CENTER": - config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_max_center.yml") - aspect = 1 -elif test == "SMILE_mono": - config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_smile_mono.yml") - aspect = 1 -elif test == "MAX_LIGHT": - config_system = load_yaml_config("simca/configs/cassi_system_simple_optim.yml") - aspect = 1 - -config_system = load_yaml_config("simca/configs/cassi_system_optim_optics_full_triplet.yml") -#config_system = load_yaml_config("simca/configs/cassi_system_satur_test.yml") - -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - -if __name__ == '__main__': - time_start = time.time() - # Initialize the CASSI system - cassi_system = CassiSystemOptim(system_config=config_system) - - # time0 = time.time() - # DATASET : Load the hyperspectral dataset - list_dataset_data = cassi_system.load_dataset(dataset_name, config_dataset["datasets directory"]) - """ plt.imshow(np.sum(list_dataset_data[0], axis=2)) - plt.show() """ - - # Loop beginning if optics optim. - cassi_system.update_optical_model(system_config=config_system) - X_vec_out, Y_vec_out = cassi_system.propagate_coded_aperture_grid() - sigma = 1.5 - - cassi_system.X_coordinates_propagated_coded_aperture = cassi_system.X_coordinates_propagated_coded_aperture.to(device) - cassi_system.Y_coordinates_propagated_coded_aperture = cassi_system.Y_coordinates_propagated_coded_aperture.to(device) - cassi_system.X_detector_coordinates_grid = cassi_system.X_detector_coordinates_grid.to(device) - cassi_system.Y_detector_coordinates_grid = cassi_system.Y_detector_coordinates_grid.to(device) - - num_iterations = 1000 # Define num_iterations as needed - - first_pos = 0.38 - last_pos = 0.8 - step_pos = 0.2/1 - - pattern_pos = [0.68, 0.58, 0.48, 0.38] - pos_slit_detector_list = [20/145, 40/145, 60/145, 80/145] - image_counter = 0 - - - - gen_randint = torch.Generator() - gen_randint.manual_seed(2009) - list_of_rand_seeds = torch.randint(low=0, high=2009, size=(100,), generator=gen_randint) - for seed_i in range(0, len(list_of_rand_seeds)): - seed = int(list_of_rand_seeds[seed_i]) - - pattern_pos = [0.76] - pattern_pos = [0.2] - pattern_pos = [0.68] - pos_slit_detector_list = [50/145] # 64 if pixel_size_Y=200 - pos_slit_detector_list = [33/145] - patterns1 = [] - corrected_patterns1 = [] - smile_positions = [] - corrected_smile_positions = [] - patterns2 = [] - start_width_values = [] - width_values = [] - cubes1 = [] - corrected_cubes1 = [] - cubes2 = [] - acquisitions = [] - - prev_position = None - - position = pattern_pos[0] - pos_slit_detector = pos_slit_detector_list[0] - - image_counter += 1 - print(f"===== Start of image acquisition {image_counter} =====") - max_iter_cnt = 25 - - cassi_system = CassiSystemOptim(system_config=config_system) - cassi_system.device = device - - # time0 = time.time() - # DATASET : Load the hyperspectral dataset - cassi_system.load_dataset(dataset_name, config_dataset["datasets directory"]) - - # Loop beginning if optics optim. - cassi_system.update_optical_model(system_config=config_system) - X_vec_out, Y_vec_out = cassi_system.propagate_coded_aperture_grid() - sigma = 1.5 - - cassi_system.X_coordinates_propagated_coded_aperture = cassi_system.X_coordinates_propagated_coded_aperture.to(device) - cassi_system.Y_coordinates_propagated_coded_aperture = cassi_system.Y_coordinates_propagated_coded_aperture.to(device) - cassi_system.X_detector_coordinates_grid = cassi_system.X_detector_coordinates_grid.to(device) - cassi_system.Y_detector_coordinates_grid = cassi_system.Y_detector_coordinates_grid.to(device) - - # Adjust the learning rate - if algo == "LBFGS": - lr = 0.005 # default: 0.05 - elif algo == "ADAM": - lr = 0.005 # default: 0.005 - - """ if position == 0.5: - pos_slit_detector = 0.41 - elif position == 0.7: - pos_slit_detector = 0.124 """ - - data = np.load(f"./results/24-03-04_19h33/results.npz") - if data is None: - cassi_system = optim_smile(cassi_system, position, pos_slit_detector, sigma, device, algo, lr, num_iterations, max_iter_cnt, prev_position = prev_position, plot_frequency=None) - - pattern = cassi_system.pattern.detach().to('cpu').numpy() - cube = cassi_system.filtering_cube.detach().to('cpu').numpy() - - patterns1.append(pattern) - cubes1.append(cube) - start_position = cassi_system.array_x_positions.detach().to('cpu').numpy() - smile_positions.append(start_position) - - diffs = np.diff(start_position) - diffs_ind = np.nonzero(diffs)[0] - pos_middle = start_position[diffs_ind.min()+1:diffs_ind.max()+1] - poly_coeffs = np.polyfit(np.linspace(1,2, len(pos_middle)), pos_middle, deg = 2) - poly = np.poly1d(poly_coeffs) - start_position[diffs_ind.min()+1:diffs_ind.max()+1] = poly(np.linspace(1,2, len(pos_middle))) - - corrected_smile_positions.append(start_position) - - start_position = torch.tensor(start_position) - - cassi_system.array_x_positions.data = start_position - cassi_system.generate_custom_slit_pattern() - cassi_system.generate_filtering_cube() - - pattern = cassi_system.pattern.detach().to('cpu').numpy() - cube = cassi_system.filtering_cube.detach().to('cpu').numpy() - corrected_patterns1.append(pattern) - corrected_cubes1.append(cube) - - prev_position = (cassi_system.array_x_positions.detach()-position) - else: - patterns1.append(data["patterns_smile"][0]) - cubes1.append(data["cubes_smile"][0]) - smile_positions.append(data["smile_positions"][0]) - - corrected_smile_positions.append(data["corrected_smile_positions"][0]) - start_position = torch.tensor(data["corrected_smile_positions"][0]) - - corrected_patterns1.append(data["corrected_patterns_smile"][0]) - corrected_cubes1.append(data["corrected_cubes_smile"][0]) - - prev_position = (start_position - position) - - # Adjust the learning rate - if algo == "LBFGS": - lr = 0.1 # default: 0.05 - elif algo == "ADAM": - lr = 0.01 # default: 0.01 - - target = 100000 - max_iter_cnt = 25 - - #num_iterations = 1 - gen = torch.Generator() - gen.manual_seed(seed) - start_width = torch.rand(size=(1,cassi_system.system_config["detector"]["number of pixels along Y"]), generator=gen)/3 - start_width_values.append(start_width.detach().to('cpu').numpy()) - - # Create first histogram - cassi_system.generate_custom_pattern_parameters_slit_width(nb_slits=1, nb_rows=cassi_system.system_config["coded aperture"]["number of pixels along Y"], start_width = start_width) - cassi_system.generate_custom_slit_pattern_width(start_pattern = "corrected", start_position = start_position) - cassi_system.generate_filtering_cube() - cassi_system.filtering_cube = cassi_system.filtering_cube.to(device) - acq = cassi_system.image_acquisition(use_psf=False, chunck_size=cassi_system.system_config["detector"]["number of pixels along Y"]).detach().to('cpu').numpy() - fig_first_histo = plt.figure() - #plt.imshow(torch.sum(cassi_system.dataset, dim=2)) - #plt.imshow(acq, aspect=0.2) - plt.hist(acq[acq>100].flatten(), bins=100) - - # Run optimization - cassi_system = optim_width(cassi_system, start_position, target, cassi_system.system_config["coded aperture"]["number of pixels along Y"], start_width, device, algo, lr, num_iterations, max_iter_cnt, plot_frequency = None) - - pattern = cassi_system.pattern.detach().to('cpu').numpy() - cube = cassi_system.filtering_cube.detach().to('cpu').numpy() - acquisition = cassi_system.measurement.detach().to('cpu').numpy() - - patterns2.append(pattern) - cubes2.append(cube) - acquisitions.append(acquisition) - width_values.append(cassi_system.array_x_positions.detach().to('cpu').numpy()) - - - - #print(torch.std(cassi_system.measurement.detach())) - - print(f"Exec time: {time.time() - time_start}s") - - fig1 = plt.figure() - im1 = plt.imshow(patterns1[0], animated = True, aspect=aspect) - plt.colorbar() - - fig1bis = plt.figure() - im1bis = plt.imshow(corrected_patterns1[0], animated = True, aspect=aspect) - plt.colorbar() - - fig2 = plt.figure() - im2 = plt.imshow(cubes1[0][:,:,cubes1[0].shape[2]//2], animated = True, aspect=aspect) - plt.colorbar() - - fig2bis = plt.figure() - im2bis = plt.imshow(corrected_cubes1[0][:,:,corrected_cubes1[0].shape[2]//2], animated = True, aspect=aspect) - plt.colorbar() - - fig3 = plt.figure() - im3 = plt.imshow(patterns2[0], animated = True, aspect=aspect) - plt.colorbar() - - fig4 = plt.figure() - im4 = plt.imshow(cubes2[0][:,:,cubes2[0].shape[2]//2], animated = True, aspect=aspect) - plt.colorbar() - - fig5 = plt.figure() - im5 = plt.imshow(np.clip(acquisitions[0], 1, None), animated = True, aspect=aspect, cmap="gray", norm="log") - plt.colorbar() - - def update1(i): - im1.set_array(patterns1[i]) - return im1, - def update1bis(i): - im1bis.set_array(corrected_patterns1[i]) - return im1bis, - def update2(i): - im2.set_array(cubes1[i][:,:,cubes1[0].shape[2]//2]) - return im2, - def update2bis(i): - im2bis.set_array(corrected_cubes1[i][:,:,corrected_cubes1[0].shape[2]//2]) - return im2bis, - def update3(i): - im3.set_array(patterns2[i]) - return im3, - def update4(i): - im4.set_array(cubes2[i][:,:,cubes2[0].shape[2]//2]) - return im4, - def update5(i): - im5.set_array(acquisitions[i]) - return im5, - - animation_fig1 = anim.FuncAnimation(fig1, update1, frames=len(patterns1), interval = 1000, repeat=True) - animation_fig1bis = anim.FuncAnimation(fig1bis, update1bis, frames=len(corrected_patterns1), interval = 1000, repeat=True) - animation_fig2 = anim.FuncAnimation(fig2, update2, frames=len(cubes1), interval = 1000, repeat=True) - animation_fig2bis = anim.FuncAnimation(fig2bis, update2bis, frames=len(corrected_cubes1), interval = 1000, repeat=True) - animation_fig3 = anim.FuncAnimation(fig3, update3, frames=len(patterns2), interval = 1000, repeat=True) - animation_fig4 = anim.FuncAnimation(fig4, update4, frames=len(cubes2), interval = 1000, repeat=True) - animation_fig5 = anim.FuncAnimation(fig5, update5, frames=len(acquisitions), interval = 1000, repeat=True) - - #print("Var: ", np.var(acquisitions[0][int(pos_slit_detector*145)-2:int(pos_slit_detector*145)+2].flatten())) - print("Var: ", np.var(acquisitions[0][acquisitions[0]>100].flatten())) - - - fig6 = plt.figure() - plt.hist(acquisitions[0][acquisitions[0]>100].flatten(), bins=100) - - plt.show() - - - - - folder = datetime.datetime.now().strftime('%y-%m-%d_%Hh%M') - os.makedirs(f"./results/{folder}") - - fig6.savefig(f"./results/{folder}/histo.png") - fig_first_histo.savefig(f"./results/{folder}/first_histo.png") - fig1.savefig(f"./results/{folder}/patterns_smile.png") - fig1bis.savefig(f"./results/{folder}/corrected_patterns_smile.png") - fig2.savefig(f"./results/{folder}/cubes_smile.png") - fig2bis.savefig(f"./results/{folder}/corrected_cubes_smile.png") - fig3.savefig(f"./results/{folder}/patterns_width.png") - fig4.savefig(f"./results/{folder}/cubes_width.png") - fig5.savefig(f"./results/{folder}/acquisitions.png") - #animation_fig1.save(f"./results/{folder}/patterns_smile.gif") - #animation_fig1bis.save(f"./results/{folder}/corrected_patterns_smile.gif") - #animation_fig2.save(f"./results/{folder}/cubes_smile.gif") - #animation_fig2bis.save(f"./results/{folder}/corrected_cubes_smile.gif") - #animation_fig3.save(f"./results/{folder}/patterns_width.gif") - #animation_fig4.save(f"./results/{folder}/cubes_width.gif") - #animation_fig5.save(f"./results/{folder}/acquisitions.gif") - - """np.savez(f"./results/{folder}/results.npz", smile_positions=np.stack(smile_positions, axis=0), patterns_smile=np.stack(patterns1, axis=0), cubes_smile = np.stack(cubes1, axis=0), - corrected_smile_positions=np.stack(corrected_smile_positions, axis=0), corrected_patterns_smile=np.stack(corrected_patterns1, axis=0), corrected_cubes_smile=np.stack(corrected_cubes1, axis=0), - start_width_values=np.stack(start_width_values, axis=0), width_values=np.stack(width_values, axis=0), patterns_width=np.stack(patterns2, axis=0), cubes_width = np.stack(cubes2, axis=0), - acquisitions=np.stack(acquisitions, axis=0), - variance=np.var(acquisitions[0][acquisitions[0]>500].flatten()))""" - np.savez(f"./results/{folder}/results.npz", smile_positions=np.stack(smile_positions, axis=0), patterns_smile=np.stack(patterns1, axis=0), - corrected_smile_positions=np.stack(corrected_smile_positions, axis=0), corrected_patterns_smile=np.stack(corrected_patterns1, axis=0), - start_width_values=np.stack(start_width_values, axis=0), width_values=np.stack(width_values, axis=0), patterns_width=np.stack(patterns2, axis=0), - acquisitions=np.stack(acquisitions, axis=0), - variance=np.var(acquisitions[0][acquisitions[0]>500].flatten())) - #np.savez(f"./results/{folder}/results.npz", patterns_smile=np.stack(patterns1, axis=0), cubes_smile = np.stack(cubes1, axis=0)) diff --git a/optimization_modules_with_resnet_v2.py b/optimization_modules_with_resnet_v2.py deleted file mode 100755 index 62081bf..0000000 --- a/optimization_modules_with_resnet_v2.py +++ /dev/null @@ -1,385 +0,0 @@ -import pytorch_lightning as pl -import torch -import torch.nn as nn -from simca.CassiSystem_lightning import CassiSystemOptim -from MST.simulation.train_code.architecture import * -from simca import load_yaml_config -import matplotlib.pyplot as plt -import torchvision -import numpy as np -from simca.functions_acquisition import * -from piqa import SSIM -from torch.utils.tensorboard import SummaryWriter -import io -import torchvision.transforms as transforms -from PIL import Image -import segmentation_models_pytorch as smp -import torch.nn.functional as F - -class UnetModel(nn.Module): - def __init__(self,encoder_name="resnet18",encoder_weights="",in_channels=1,classes=2,index=0): - super().__init__() - self.i= index - self.model= smp.Unet(encoder_name= encoder_name, in_channels=in_channels,encoder_weights=encoder_weights,classes=classes,activation='sigmoid') - def forward(self,x): - x= self.model(x) - return x - -class JointReconstructionModule_V3(pl.LightningModule): - def __init__(self, recon_lightning_module, log_dir="tb_logs",resnet_checkpoint=None): - super().__init__() - - self.reconstruction_module = recon_lightning_module - self.mask_generation = UnetModel(classes=1,encoder_weights=None,in_channels=1) - - if resnet_checkpoint is not None: - # Load the weights from the checkpoint into self.seg_model - checkpoint = torch.load(resnet_checkpoint, map_location=self.device) - # Adjust the keys - adjusted_state_dict = {key.replace('mask_generation.', ''): value - for key, value in checkpoint['state_dict'].items()} - # Filter out unexpected keys - model_keys = set(self.mask_generation.state_dict().keys()) - filtered_state_dict = {k: v for k, v in adjusted_state_dict.items() if k in model_keys} - self.mask_generation.load_state_dict(filtered_state_dict) - - # Freeze the seg_model parameters - # for param in self.mask_generation.parameters(): - # param.requires_grad = False - - self.loss_fn = nn.MSELoss() - self.ssim_loss = SSIM(window_size=11, n_channels=28) - self.reconstruction_module.ssim_loss = SSIM(window_size=11, n_channels=28) - - self.writer = SummaryWriter(log_dir) - - # for param in self.reconstruction_model.parameters(): - # param.requires_grad = False - - def on_validation_start(self,stage=None): - print("---VALIDATION START---") - self.config = "simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml" - config_system = load_yaml_config(self.config) - self.reconstruction_module.config_patterns = load_yaml_config("simca/configs/pattern.yml") - self.reconstruction_module.cassi_system = CassiSystemOptim(system_config=config_system) - self.reconstruction_module.cassi_system.propagate_coded_aperture_grid() - - self.first_config_patterns = load_yaml_config("simca/configs/pattern.yml") - self.first_cassi_system = CassiSystemOptim(system_config=config_system) - self.first_cassi_system.propagate_coded_aperture_grid() - - - def _normalize_data_by_itself(self, data): - # Calculate the mean and std for each batch individually - # Keep dimensions for broadcasting - mean = torch.mean(data, dim=[1, 2], keepdim=True) - std = torch.std(data, dim=[1, 2], keepdim=True) - - # Normalize each batch by its mean and std - normalized_data = (data - mean) / std - return normalized_data - - def forward(self, x): - print("---FORWARD---") - - hyperspectral_cube, wavelengths = x - hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device) - batch_size, H, W, C = hyperspectral_cube.shape - - - self.pattern = self.first_cassi_system.generate_2D_pattern(self.first_config_patterns, nb_of_patterns=batch_size) - self.pattern = self.pattern.to(self.device) - filtering_cube = self.first_cassi_system.generate_filtering_cube().to(self.device) - self.acquired_image1 = self.first_cassi_system.image_acquisition(hyperspectral_cube, self.pattern, wavelengths).to(self.device) - - self.acquired_image1 = pad_tensor(self.acquired_image1, (128, 128)) - - # flip along second and thrid axis - self.acquired_image1 = self.acquired_image1.flip(1) - self.acquired_image1 = self.acquired_image1.flip(2) - self.acquired_image1 = self.acquired_image1.unsqueeze(1).float() - #self.acquired_image1 = self._normalize_data_by_itself(self.acquired_image1) - - self.pattern = self.mask_generation(self.acquired_image1).squeeze(1) - self.pattern = BinarizeFunction.apply(self.pattern) - self.pattern = pad_tensor(self.pattern, (131, 131)) - #self.reconstruction_module.pattern = self.pattern.to(self.device) - #self.reconstruction_module.cassi_system.pattern = self.pattern.to(self.device) - - reconstructed_cube = self.reconstruction_module.forward(x, self.pattern) - - self.acquired_image2 = self.reconstruction_module.acquired_image1 - - return reconstructed_cube - - def training_step(self, batch, batch_idx): - print("Training step") - - loss, ssim_loss, reconstructed_cube, ref_cube = self._common_step(batch, batch_idx) - - - - output_images = self._convert_output_to_images(self._normalize_image_tensor(self.acquired_image2)) - patterns = self._convert_output_to_images(self._normalize_image_tensor(self.pattern)) - input_images = self._convert_output_to_images(self._normalize_image_tensor(ref_cube[:,:,:,0])) - reconstructed_image = self._convert_output_to_images(self._normalize_image_tensor(reconstructed_cube[:,:,:,0])) - - if self.global_step % 30 == 0: - self._log_images('train/acquisition2', output_images, self.global_step) - self._log_images('train/ground_truth', input_images, self.global_step) - self._log_images('train/reconstructed', reconstructed_image, self.global_step) - self._log_images('train/patterns', patterns, self.global_step) - - plt.imshow(self.pattern[0,:,:].cpu().detach().numpy()) - plt.colorbar() - plt.savefig('./pattern.png') - - spectral_filter_plot = self.plot_spectral_filter(ref_cube,reconstructed_cube) - self.log_gradients(self.global_step) - self.writer.add_image('Spectral Filter', spectral_filter_plot, self.global_step) - - self.log_dict( - { "train_loss": loss, - }, - on_step=True, - on_epoch=True, - prog_bar=True, - ) - - self.log_dict( - { "train_ssim_loss": ssim_loss, - }, - on_step=True, - on_epoch=True, - prog_bar=True, - ) - - return {"loss": loss} - - def _normalize_image_tensor(self, tensor): - # Normalize the tensor to the range [0, 1] - min_val = tensor.min() - max_val = tensor.max() - normalized_tensor = (tensor - min_val) / (max_val - min_val) - return normalized_tensor - - def validation_step(self, batch, batch_idx): - - print("Validation step") - loss, ssim_loss, reconstructed_cube, ref_cube= self._common_step(batch, batch_idx) - - self.log_dict( - { "val_loss": loss, - }, - on_step=True, - on_epoch=True, - prog_bar=True, - ) - - self.log_dict( - { "val_ssim_loss": ssim_loss, - }, - on_step=True, - on_epoch=True, - prog_bar=True, - ) - - return {"loss": loss} - - def test_step(self, batch, batch_idx): - print("Test step") - loss, ssim_loss, reconstructed_cube, ref_cube= self._common_step(batch, batch_idx) - self.log_dict( - { "test_loss": loss, - }, - on_step=True, - on_epoch=True, - prog_bar=True, - ) - - self.log_dict( - { "test_ssim_loss": ssim_loss, - }, - on_step=True, - on_epoch=True, - prog_bar=True, - ) - - return {"loss": loss} - - def predict_step(self, batch, batch_idx): - print("Predict step") - loss, ssim_loss, reconstructed_cube, ref_cube= self._common_step(batch, batch_idx) - print("Predict loss: ", loss.item()) - print("Predict ssim loss: ", ssim_loss) - #self.log('predict_step', loss,on_step=True, on_epoch=True, prog_bar=True, logger=True) - return loss - - def _common_step(self, batch, batch_idx): - - reconstructed_cube = self.forward(batch) - hyperspectral_cube, wavelengths = batch - - hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device) - reconstructed_cube = reconstructed_cube.permute(0, 2, 3, 1).to(self.device) - ref_cube = match_dataset_to_instrument(hyperspectral_cube, reconstructed_cube[0, :, :,0]) - - # fig, ax = plt.subplots(1, 2) - # plt.title(f"true cube vs reconstructed cube") - # ax[0].imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy()) - # ax[1].imshow(reconstructed_cube[0, :, :, 0].cpu().detach().numpy()) - # plt.show() - total_sum_pattern = torch.sum(self.pattern, dim=(1, 2)) - total_half_pattern_equal_1 = torch.sum(torch.ones_like(self.pattern), dim=(1, 2)) / 2 - - print(f"total_sum_pattern {total_sum_pattern}") - print(f"total_half_pattern_equal_1 {total_half_pattern_equal_1}") - - - - loss1 = torch.sqrt(self.loss_fn(reconstructed_cube, ref_cube)) - loss2 = torch.sum(torch.abs((total_sum_pattern - total_half_pattern_equal_1)/(self.pattern.shape[1]*self.pattern.shape[2]))**2) - loss = loss1 - - ssim_loss = self.ssim_loss(torch.clamp(reconstructed_cube.permute(0, 3, 1, 2), 0, 1), ref_cube.permute(0, 3, 1, 2)) - print(f"loss1 {loss1}") - print(f"loss2 {loss2}") - return loss, ssim_loss, reconstructed_cube, ref_cube - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters(), lr=1e-4) - return { "optimizer":optimizer, - "lr_scheduler":{ - "scheduler":torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 500, eta_min=1e-6), - "interval": "epoch" - } - } - - def _log_images(self, tag, images, global_step): - # Convert model output to image grid and log to TensorBoard - img_grid = torchvision.utils.make_grid(images) - self.writer.add_image(tag, img_grid, global_step) - - def _convert_output_to_images(self, acquired_images): - - acquired_images = acquired_images.unsqueeze(1) - - # Create a grid of images for visualization - img_grid = torchvision.utils.make_grid(acquired_images) - return img_grid - - def log_gradients(self, step): - for name, param in self.mask_generation.named_parameters(): - if param.grad is not None: - self.writer.add_scalar(f"Gradients/{name}", param.grad.norm(), step) - - - def plot_spectral_filter(self,ref_hyperspectral_cube,recontructed_hyperspectral_cube): - - - batch_size, y,x, lmabda_ = ref_hyperspectral_cube.shape - - # Create a figure with subplots arranged horizontally - fig, axs = plt.subplots(1, batch_size, figsize=(batch_size * 5, 4)) # Adjust figure size as needed - - # Check if batch_size is 1, axs might not be iterable - if batch_size == 1: - axs = [axs] - - # Plot each spectral filter in its own subplot - for i in range(batch_size): - colors = ['b', 'g', 'r'] - for j in range(3): - pix_j_row_value = np.random.randint(0,y) - pix_j_col_value = np.random.randint(0,x) - - pix_j_ref = ref_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy() - pix_j_reconstructed = recontructed_hyperspectral_cube[i, pix_j_row_value,pix_j_col_value,:].cpu().detach().numpy() - axs[i].plot(pix_j_reconstructed, label="pix reconstructed" + str(j),c=colors[j]) - axs[i].plot(pix_j_ref, label="pix" + str(j), linestyle='--',c=colors[j]) - - axs[i].set_title(f"Reconstruction quality") - - axs[i].set_xlabel("Wavelength index") - axs[i].set_ylabel("pix values") - axs[i].grid(True) - - plt.legend() - # Adjust layout - plt.tight_layout() - - # Create a buffer to save plot - buf = io.BytesIO() - plt.savefig(buf, format='png') - plt.close(fig) - buf.seek(0) - - # Convert PNG buffer to PIL Image - image = Image.open(buf) - - # Convert PIL Image to Tensor - image_tensor = transforms.ToTensor()(image) - return image_tensor - - -def subsample(input, origin_sampling, target_sampling): - [bs, row, col, nC] = input.shape - indices = torch.zeros(len(target_sampling), dtype=torch.int) - for i in range(len(target_sampling)): - sample = target_sampling[i] - idx = torch.abs(origin_sampling-sample).argmin() - indices[i] = idx - return input[:,:,:,indices] - -def expand_mask_3d(mask_batch): - if len(mask_batch.shape)==3: - mask3d = mask_batch.unsqueeze(-1).repeat((1, 1, 1, 28)) - else: - mask3d = mask_batch.repeat((1, 1, 1, 28)) - mask3d = torch.permute(mask3d, (0, 3, 1, 2)) - return mask3d - -class EmptyModule(nn.Module): - def __init__(self): - super().__init__() - self.useless_linear = nn.Linear(1, 1) - def forward(self, x): - return x - - -def pad_tensor(input_tensor, target_shape): - [bs, row, col] = input_tensor.shape - [target_row, target_col] = target_shape - - # Calculate padding for rows - pad_row_total = max(target_row - row, 0) - pad_row_top = pad_row_total // 2 - pad_row_bottom = pad_row_total - pad_row_top - - # Calculate padding for columns - pad_col_total = max(target_col - col, 0) - pad_col_left = pad_col_total // 2 - pad_col_right = pad_col_total - pad_col_left - - # Apply padding - padded_tensor = F.pad(input_tensor, (pad_col_left, pad_col_right, pad_row_top, pad_row_bottom), 'constant', 0) - return padded_tensor - -def crop_tensor(input_tensor, target_shape): - [bs, row, col] = input_tensor.shape - [target_row, target_col] = target_shape - crop_row = (row-target_row)//2 - crop_col = (col-target_col)//2 - return input_tensor[:,crop_row:crop_row+target_row,crop_col:crop_col+target_col] - - -class BinarizeFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, input): - # Forward pass is the binary threshold operation - return (input > 0.5).float() - - @staticmethod - def backward(ctx, grad_output): - # For backward pass, just pass the gradients through unchanged - return grad_output diff --git a/playing_with_TF.py b/playing_with_TF.py deleted file mode 100755 index b873dd9..0000000 --- a/playing_with_TF.py +++ /dev/null @@ -1,34 +0,0 @@ -import matplotlib.pyplot as plt - -from simca.functions_general_purpose import load_yaml_config -from simca.CassiSystemOptim import CassiSystemOptim -import torch - -config = "simca/configs/cassi_system_optim_optics_full_triplet_dd_cassi.yml" -config_patterns = load_yaml_config("simca/configs/pattern.yml") - -config_system = load_yaml_config(config) -cassi_system = CassiSystemOptim(system_config=config_system) -cassi_system.propagate_coded_aperture_grid() - -pattern = cassi_system.generate_2D_pattern(config_patterns,nb_of_patterns=1) - - -# Compute the 2D FFT of the tensor -fft_result = torch.fft.fft2(pattern) - -# Calculate the Power Spectrum -power_spectrum = torch.abs(fft_result)**2 - -# Calculate the Geometric Mean of the power spectrum -# Use torch.log and torch.exp for differentiability, adding a small epsilon to avoid log(0) -epsilon = 1e-10 -geometric_mean = torch.exp(torch.mean(torch.log(power_spectrum + epsilon))) - -# Calculate the Arithmetic Mean of the power spectrum -arithmetic_mean = torch.mean(power_spectrum) - -# Compute the Spectral Flatness -spectral_flatness = geometric_mean / arithmetic_mean - -print("spectral_flatness: ", spectral_flatness) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt deleted file mode 100755 index 612ed2d..0000000 --- a/requirements.txt +++ /dev/null @@ -1,17 +0,0 @@ -h5py -pyqt5 -pyqtgraph -pyyaml -scipy -seaborn -tqdm -imageio -spectral -scikit-learn -matplotlib -opticalglass -pyyaml -snoop -#torch -#torch-geometric -#torch-cluster \ No newline at end of file diff --git a/simple_script_optim_testing.py b/simple_script_optim_testing.py deleted file mode 100644 index dbf8e1f..0000000 --- a/simple_script_optim_testing.py +++ /dev/null @@ -1,221 +0,0 @@ -from simca import load_yaml_config -from simca.CassiSystemOptim import CassiSystemOptim -from simca.CassiSystem import CassiSystem -import numpy as np -import snoop -import matplotlib.pyplot as plt -#import matplotlib -import torch -import time -from pprint import pprint -from simca.cost_functions import evaluate_slit_scanning_straightness, evaluate_center, evaluate_mean_lighting, evaluate_max_lighting -from simca.functions_optim import optim_smile, optim_width - -#matplotlib.use('Agg') -config_dataset = load_yaml_config("simca/configs/dataset.yml") -config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_max_center.yml") -config_patterns = load_yaml_config("simca/configs/pattern.yml") -config_acquisition = load_yaml_config("simca/configs/acquisition.yml") - -dataset_name = "indian_pines" - -test = "SMILE" - -algo = "LBFGS" - -if test=="SMILE": - config_system = load_yaml_config("simca/configs/cassi_system_simple_optim.yml") - aspect = 1 -elif test=="EQUAL_LIGHT" or test=="MAX_CENTER": - config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_max_center.yml") - aspect = 1 -elif test == "MAX_LIGHT": - config_system = load_yaml_config("simca/configs/cassi_system_simple_optim.yml") - aspect = 1 -elif test == "SMILE_mono": - config_system = load_yaml_config("simca/configs/cassi_system_simple_optim_smile_mono.yml") - aspect = 1 - - -if __name__ == '__main__': - time_start = time.time() - # Initialize the CASSI system - cassi_system = CassiSystemOptim(system_config=config_system) - - # time0 = time.time() - # DATASET : Load the hyperspectral dataset - cassi_system.load_dataset(dataset_name, config_dataset["datasets directory"]) - - # Loop beginning if optics optim. - cassi_system.update_optical_model(system_config=config_system) - X_vec_out, Y_vec_out = cassi_system.propagate_coded_aperture_grid() - sigma = 0.75 - - num_iterations = 3000 # Define num_iterations as needed - - """ convergence_counter = 0 # Counter to check convergence - max_cnt = 100 - min_cost_value = np.inf - - if test == "EQUAL_LIGHT": - cassi_system.generate_custom_pattern_parameters_slit_width(nb_slits=2, nb_rows=2, start_width = 1) - elif test == "MAX_LIGHT": - cassi_system.generate_custom_pattern_parameters_slit_width(nb_slits=1, nb_rows=cassi_system.system_config["detector"]["number of pixels along Y"], start_width = sigma) - elif (test == "SMILE") or (test == "SMILE_mono"): - cassi_system.generate_custom_pattern_parameters_slit(position=0.5) - - # Ensure array_x_positions is a tensor with gradient tracking - cassi_system.array_x_positions = cassi_system.array_x_positions.clone().detach().requires_grad_(True) - - # Define the optimizer - lr = 0.005 # default: 0.005 - optimizer = torch.optim.Adam([cassi_system.array_x_positions], lr=lr) # Adjust the learning rate as needed - - # Main optimization loop - for iteration in range(num_iterations): # Define num_iterations as needed - optimizer.zero_grad() # Clear previous gradients - if test == "EQUAL_LIGHT": - pattern = cassi_system.generate_custom_slit_pattern_width(start_pattern = "line", start_position = 0) - elif test == "MAX_LIGHT": - start_position = torch.tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, - 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.4923, - 0.4859, 0.4915, 0.4934, 0.4972, 0.4996, 0.5003, 0.5009, 0.5013, 0.5033, - 0.5041, 0.5064, 0.5068, 0.5078, 0.5070, 0.5099, 0.5103, 0.5146, 0.5145, - 0.5146, 0.5152, 0.5173, 0.5195, 0.5215, 0.5208, 0.5208, 0.5222, 0.5247, - 0.5272, 0.5287, 0.5285, 0.5240, 0.5283, 0.5282, 0.5288, 0.5288, 0.5291, - 0.5289, 0.5282, 0.5334, 0.5314, 0.5324, 0.5370, 0.5323, 0.5322, 0.5341, - 0.5329, 0.5361, 0.5364, 0.5346, 0.5333, 0.5340, 0.5333, 0.5339, 0.5345, - 0.5359, 0.5349, 0.5364, 0.5344, 0.5341, 0.5346, 0.5353, 0.5345, 0.5347, - 0.5346, 0.5362, 0.5363, 0.5330, 0.5321, 0.5323, 0.5299, 0.5315, 0.5318, - 0.5298, 0.5291, 0.5291, 0.5298, 0.5292, 0.5256, 0.5270, 0.5283, 0.5268, - 0.5255, 0.5245, 0.5200, 0.5205, 0.5207, 0.5190, 0.5188, 0.5144, 0.5122, - 0.5138, 0.5133, 0.5131, 0.5122, 0.5111, 0.5156, 0.5118, 0.5091, 0.5077, - 0.5068, 0.5030, 0.5003, 0.5000, 0.4992, 0.4970, 0.4968, 0.4947, 0.4949, - 0.4935, 0.4983, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, - 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, - 0.5000]) - pattern = cassi_system.generate_custom_slit_pattern_width(start_pattern = "corrected", start_position = start_position) - else: - pattern = cassi_system.generate_custom_slit_pattern() - #print(pattern[:, pattern.shape[1]//2-4:pattern.shape[1]//2+4]) - cassi_system.generate_filtering_cube() - pos_slit = 0.4625 - pos_slit = 0.41 - if (test == "SMILE"): - cost_value = evaluate_slit_scanning_straightness(cassi_system.filtering_cube, sigma = sigma, pos_slit = pos_slit) - elif (test == "SMILE_mono"): - cost_value = evaluate_slit_scanning_straightness(cassi_system.filtering_cube, sigma = sigma, pos_slit = pos_slit) - cassi_system.image_acquisition(use_psf = False, chunck_size = 50) - elif test == "MAX_CENTER": - cassi_system.image_acquisition(use_psf=False, chunck_size=50) - cost_value = evaluate_center(cassi_system.measurement) - elif test == "EQUAL_LIGHT": - cassi_system.image_acquisition(use_psf=False, chunck_size=50) - cost_value = evaluate_mean_lighting(cassi_system.measurement) - elif test == "MAX_LIGHT": - cassi_system.image_acquisition(use_psf=False, chunck_size=50) - cost_value = evaluate_max_lighting(cassi_system.measurement, pos_slit) - - if cost_value < min_cost_value: - min_cost_value = cost_value - convergence_counter = 0 - best_x = cassi_system.array_x_positions.clone().detach() - else: - convergence_counter+=1 - - if (iteration >= 50) and (convergence_counter >= max_cnt): # If loss didn't decrease in 25 steps, break - break - - cost_value.backward() # Perform backpropagation - # print("Gradients after backward:", cassi_system.array_x_positions.grad) - optimizer.step() # Update x positions - cassi_system.array_x_positions.data = torch.relu(cassi_system.array_x_positions.data) # Prevent the parameters to be negative - # Optional: Print cost_value every few iterations to monitor progress - if iteration % 5 == 0: # Adjust printing frequency as needed - print(f"Iteration {iteration}, Cost: {cost_value.item()}") - - if iteration % 200 == 0: - print(f"Exec time: {time.time() - time_start}s") - plt.imshow(pattern.detach().numpy(), aspect=aspect) - plt.show() - - plt.imshow(cassi_system.filtering_cube[:, :, 0].detach().numpy(), aspect=aspect) - plt.show() - - plt.plot(np.sum(cassi_system.filtering_cube[:, :, 0].detach().numpy(),axis=0)) - plt.show() - - if (test=="MAX_CENTER") or (test == "EQUAL_LIGHT") or (test == "SMILE_mono") or (test == "MAX_LIGHT"): - #plt.imshow(cassi_system.panchro.detach().numpy(), zorder=3) - plt.imshow(cassi_system.measurement.detach().numpy(), zorder=5) - plt.show() - #plt.imshow(cassi_system.panchro.detach().numpy()) - #plt.show() - cassi_system.array_x_positions.data = torch.relu(best_x.data) - - if test == "EQUAL_LIGHT": - pattern = cassi_system.generate_custom_slit_pattern_width(start_pattern = "line", start_position = 0) - elif test == "MAX_LIGHT": - pattern = cassi_system.generate_custom_slit_pattern_width(start_pattern = "corrected", start_position = start_position) - else: - pattern = cassi_system.generate_custom_slit_pattern() - cassi_system.generate_filtering_cube() - print(cassi_system.array_x_positions) - #print(torch.std(cassi_system.measurement.detach())) - - print(f"Min cost: {min_cost_value}") - print(f"Exec time: {time.time() - time_start}s") - plt.imshow(pattern.detach().numpy(), aspect=aspect) - plt.show() - - plt.imshow(cassi_system.filtering_cube[:, :, 0].detach().numpy(), aspect=aspect) - plt.show() - - plt.plot(np.sum(cassi_system.filtering_cube[:, :, 0].detach().numpy(),axis=0)) - plt.show() - - if (test=="MAX_CENTER") or (test == "EQUAL_LIGHT") or (test == "SMILE_mono") or (test == "MAX_LIGHT"): - plt.imshow(cassi_system.measurement.detach().numpy()) - plt.show()""" - algo = "ADAM" - if algo == "LBFGS": - lr = 0.002 # default: 0.05 - elif algo == "ADAM": - lr = 0.005 # default: 0.005 - - max_cnt = 25 - - #cassi_system = optim_smile(cassi_system, 0.5, 0.41, sigma, 'cpu', algo, lr, num_iterations, max_cnt, plot_frequency=200) - start_position = torch.tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, - 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.4923, - 0.4859, 0.4915, 0.4934, 0.4972, 0.4996, 0.5003, 0.5009, 0.5013, 0.5033, - 0.5041, 0.5064, 0.5068, 0.5078, 0.5070, 0.5099, 0.5103, 0.5146, 0.5145, - 0.5146, 0.5152, 0.5173, 0.5195, 0.5215, 0.5208, 0.5208, 0.5222, 0.5247, - 0.5272, 0.5287, 0.5285, 0.5240, 0.5283, 0.5282, 0.5288, 0.5288, 0.5291, - 0.5289, 0.5282, 0.5334, 0.5314, 0.5324, 0.5370, 0.5323, 0.5322, 0.5341, - 0.5329, 0.5361, 0.5364, 0.5346, 0.5333, 0.5340, 0.5333, 0.5339, 0.5345, - 0.5359, 0.5349, 0.5364, 0.5344, 0.5341, 0.5346, 0.5353, 0.5345, 0.5347, - 0.5346, 0.5362, 0.5363, 0.5330, 0.5321, 0.5323, 0.5299, 0.5315, 0.5318, - 0.5298, 0.5291, 0.5291, 0.5298, 0.5292, 0.5256, 0.5270, 0.5283, 0.5268, - 0.5255, 0.5245, 0.5200, 0.5205, 0.5207, 0.5190, 0.5188, 0.5144, 0.5122, - 0.5138, 0.5133, 0.5131, 0.5122, 0.5111, 0.5156, 0.5118, 0.5091, 0.5077, - 0.5068, 0.5030, 0.5003, 0.5000, 0.4992, 0.4970, 0.4968, 0.4947, 0.4949, - 0.4935, 0.4983, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, - 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, - 0.5000]) - cassi_system = optim_width(cassi_system, start_position, 0.41, cassi_system.system_config["detector"]["number of pixels along Y"], sigma, 'cpu', algo, lr, num_iterations, max_cnt, plot_frequency = 200) - - print(f"Exec time: {time.time() - time_start:.3f}s") - print(cassi_system.array_x_positions) - plt.imshow(cassi_system.pattern.detach().numpy(), aspect=aspect) - plt.show() - - plt.imshow(cassi_system.filtering_cube[:, :, 0].detach().numpy(), aspect=aspect) - plt.show() - - plt.plot(np.sum(cassi_system.filtering_cube[:, :, 0].detach().numpy(),axis=0)) - plt.show() - - if (test=="MAX_CENTER") or (test == "EQUAL_LIGHT") or (test == "SMILE_mono") or (test == "MAX_LIGHT"): - plt.imshow(cassi_system.measurement.detach().numpy()) - plt.show()