diff --git a/optimization_modules.py b/optimization_modules.py index 3dbb27a..9a4d291 100644 --- a/optimization_modules.py +++ b/optimization_modules.py @@ -25,7 +25,13 @@ def __init__(self, model_name): def on_validation_start(self,stage=None): print("---VALIDATION START---") self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml" - # self.config = "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi_shifted.yml" + self.shift_bool = True + if self.shift_bool: + self.crop_value_left = 8 + self.crop_value_right = 8 + else: + self.crop_value_left = 8 + self.crop_value_right = 8 config_system = load_yaml_config(self.config) self.config_patterns = load_yaml_config("simca/configs/pattern.yml") self.cassi_system = CassiSystemOptim(system_config=config_system) @@ -35,8 +41,12 @@ def forward(self, x): print("---FORWARD---") hyperspectral_cube, wavelengths = x - hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1).to(self.device) + hyperspectral_cube = hyperspectral_cube.permute(0, 2, 3, 1).to(self.device) batch_size, H, W, C = hyperspectral_cube.shape + fig, ax = plt.subplots(1, 1) + plt.title(f"entry cube") + ax.imshow(hyperspectral_cube[0, :, :, 0].cpu().detach().numpy()) + plt.show() # print(f"batch size:{batch_size}") # generate pattern pattern = self.cassi_system.generate_2D_pattern(self.config_patterns,nb_of_patterns=batch_size) @@ -61,23 +71,25 @@ def forward(self, x): # process first acquisition with reconstruction model # TODO : replace by the real reconstruction model - if self.config == "simca/configs/cassi_system_optim_optics_full_triplet_sd_cassi.yml": + if not self.shift_bool: acquired_cubes = acquired_image1.unsqueeze(1).repeat((1, 28, 1, 1)).float().to(self.device) # b x W x R x C acquired_cubes = torch.flip(acquired_cubes, dims=(2,3)) # -1 magnification + fig, ax = plt.subplots(1, 2) + plt.title(f"true cube cropped vs measurement") + ax[0].imshow(hyperspectral_cube[0, self.crop_value_left:-self.crop_value_right, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy()) + ax[1].imshow(acquired_cubes[0, 0, :, :].cpu().detach().numpy()) + plt.show() reconstructed_cube = self.reconstruction_model(acquired_cubes, filtering_cubes) else: mask_3d = expand_mask_3d(pattern.flip(dims=(1, 2))).float().to(self.device) - shifted_image = shift_back(acquired_image1.flip(dims=(1, 2)), displacement_in_pix).float().to(self.device) - - - for i in range(10): + shifted_image = self.shift_back(acquired_image1.flip(dims=(1, 2)), displacement_in_pix).float().to(self.device) - fig,ax = plt.subplots(1,2) - # plt.title(f"acquired_cubes {i}") - ax[0].imshow(shifted_image[0, i, :, :].cpu().detach().numpy()) - ax[1].imshow(filtering_cubes[0, i, :, :].cpu().detach().numpy()) - plt.show() + fig,ax = plt.subplots(1,2) + plt.title(f"true cube cropped vs measurement") + ax[0].imshow(hyperspectral_cube[0, 8:-8, self.crop_value_left:-self.crop_value_right, 0].cpu().detach().numpy()) + ax[1].imshow(shifted_image[0, 0, :, :].cpu().detach().numpy()) + plt.show() reconstructed_cube = self.reconstruction_model(shifted_image, mask_3d) @@ -144,11 +156,12 @@ def _common_step(self, batch, batch_idx): hyperspectral_cube, wavelengths = batch #hyperspectral_cube = hyperspectral_cube.permute(0, 3, 2, 1) - hyperspectral_cube = hyperspectral_cube[:,:, 8:-8, 8:-8] + hyperspectral_cube = hyperspectral_cube[:,:, self.crop_value_left:-self.crop_value_right, self.crop_value_left:-self.crop_value_right] - fig, ax = plt.subplots(1, 1) - # plt.title(f"acquired_cubes {i}") - ax.imshow(hyperspectral_cube[0, 0, :, :].cpu().detach().numpy()) + fig, ax = plt.subplots(1, 2) + plt.title(f"cubes") + ax[0].imshow(hyperspectral_cube[0, 0, :, :].cpu().detach().numpy()) + ax[1].imshow(y_hat[0, 0, :, :].cpu().detach().numpy()) plt.show() #print("y_hat shape", y_hat.shape) @@ -161,6 +174,23 @@ def _common_step(self, batch, batch_idx): def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=4e-4) return optimizer + + def shift_back(self, inputs, d): # input [bs,256,310], [bs, 28] output [bs, 28, 256, 256] + [bs, row, col] = inputs.shape + nC = 28 + d = d[0] + d -= d.min() + self.crop_value_right += int(np.round(d.max())) + output = torch.zeros(bs, nC, row, col - int(np.round(d.max()))).float().to(self.device) + for i in range(nC): + shift = int(np.round(d[i])) + #output[:, i, :, :] = inputs[:, :, step * i:step * i + col - 27 * step] step = 2 + # if shift >=0: + # output[:, i, :, :] = inputs[:, :, shift:row+shift] + # else: + # output[:, i, :, :] = inputs[:, :, shift-row:shift] + output[:, i, :, :] = inputs[:, :, shift:shift + col - int(np.round(d.max()))] + return output def subsample(input, origin_sampling, target_sampling): [bs, row, col, nC] = input.shape @@ -178,22 +208,6 @@ def expand_mask_3d(mask_batch): mask3d = mask_batch.repeat((1, 1, 1, 28)) mask3d = torch.permute(mask3d, (0, 3, 1, 2)) return mask3d - -def shift_back(inputs, d): # input [bs,256,310], [bs, 28] output [bs, 28, 256, 256] - [bs, row, col] = inputs.shape - nC = 28 - output = torch.zeros(bs, nC, row, row).cuda().float() - d = d[0] - d -= d.min() - for i in range(nC): - shift = int(np.round(d[i])) - #output[:, i, :, :] = inputs[:, :, step * i:step * i + col - 27 * step] step = 2 - # if shift >=0: - # output[:, i, :, :] = inputs[:, :, shift:row+shift] - # else: - # output[:, i, :, :] = inputs[:, :, shift-row:shift] - output[:, i, :, :] = inputs[:, :, shift:shift + row] - return output class EmptyModule(nn.Module): def __init__(self): diff --git a/simca/CassiSystem_lightning.py b/simca/CassiSystem_lightning.py index 81d76ee..c6eb4b2 100755 --- a/simca/CassiSystem_lightning.py +++ b/simca/CassiSystem_lightning.py @@ -308,10 +308,6 @@ def image_acquisition(self, hyperspectral_cube, pattern,wavelengths,use_psf=Fals self.last_filtered_interpolated_scene = sd_measurement self.interpolated_scene = scene - for i in range(5): - plt.imshow(sd_measurement[0,:,:,i*8].cpu().numpy()) - plt.title("SD measurement") - plt.show() if dataset_labels is not None: scene_labels = torch.from_numpy(match_dataset_labels_to_instrument(dataset_labels, self.last_filtered_interpolated_scene)) diff --git a/training_simca_reconstruction.py b/training_simca_reconstruction.py index 26a295f..07f3344 100644 --- a/training_simca_reconstruction.py +++ b/training_simca_reconstruction.py @@ -6,7 +6,7 @@ data_dir = "./datasets_reconstruction/mst_datasets/cave_1024_28" #data_dir = "/local/users/ademaio/lpaillet/mst_datasets/cave_1024_28" -datamodule = CubesDataModule(data_dir, batch_size=5, num_workers=11) +datamodule = CubesDataModule(data_dir, batch_size=2, num_workers=5) name = "testing_simca_reconstruction"