diff --git a/configs/benchmark_diffusercam_mirflickr.yaml b/configs/benchmark_diffusercam_mirflickr.yaml new file mode 100644 index 00000000..bf07a777 --- /dev/null +++ b/configs/benchmark_diffusercam_mirflickr.yaml @@ -0,0 +1,47 @@ +# python scripts/eval/benchmark_recon.py -cn benchmark_diffusercam_mirflickr +defaults: + - benchmark + - _self_ + +dataset: HFDataset +batchsize: 4 +device: "cuda:3" + +huggingface: + repo: "bezzam/DiffuserCam-Lensless-Mirflickr-Dataset-NORM" + psf: psf.tiff + image_res: null + rotate: False # if measurement is upside-down + alignment: null + downsample: 2 + downsample_lensed: 2 + flipud: True + flip_lensed: True + single_channel_psf: True + +algorithms: [ + # "ADMM", + + # ## - -- reconstructions trained on DiffuserCam + # "hf:diffusercam:mirflickr:U5+Unet8M", + # "hf:diffusercam:mirflickr:TrainInv+Unet8M", + # "hf:diffusercam:mirflickr:MMCN4M+Unet4M", + # "hf:diffusercam:mirflickr:MWDN8M", + # "hf:diffusercam:mirflickr:Unet4M+U5+Unet4M", + # "hf:diffusercam:mirflickr:Unet4M+TrainInv+Unet4M", + # "hf:diffusercam:mirflickr:Unet2M+MMCN+Unet2M", + # "hf:diffusercam:mirflickr:Unet2M+MWDN6M", + # "hf:diffusercam:mirflickr:Unet4M+U10+Unet4M", + + # ## - -- reconstructions trained on other datasets/systems + # # "hf:tapecam:mirflickr:Unet4M+U10+Unet4M", + # # "hf:digicam:mirflickr_single_25k:Unet4M+U5+Unet4M_wave", + # # "hf:digicam:celeba_26k:Unet4M+U5+Unet4M_wave", + # "hf:tapecam:mirflickr:Unet4M+U5+Unet4M", + # "hf:digicam:mirflickr_single_25k:Unet4M+U10+Unet4M_wave", + "hf:tapecam:mirflickr:Unet4M+U5+Unet4M_flips", +] + +save_idx: [0, 1, 3, 4, 8] +n_iter_range: [100] # for ADMM + diff --git a/lensless/recon/model_dict.py b/lensless/recon/model_dict.py index 8a0a21d9..99205e37 100644 --- a/lensless/recon/model_dict.py +++ b/lensless/recon/model_dict.py @@ -142,6 +142,7 @@ "Unet2M+MMCN+Unet2M": "bezzam/tapecam-mirflickr-unet2M-mmcn-unet2M", "Unet2M+MWDN6M": "bezzam/tapecam-mirflickr-unet2M-mwdn-6M", "Unet4M+U10+Unet4M": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm10-unet4M", + "Unet4M+U5+Unet4M_flips": "bezzam/tapecam-mirflickr-unet4M-unrolled-admm5-unet4M-flips", }, }, } diff --git a/scripts/recon/train_learning_based.py b/scripts/recon/train_learning_based.py index d923b9a6..7992fbbb 100644 --- a/scripts/recon/train_learning_based.py +++ b/scripts/recon/train_learning_based.py @@ -334,7 +334,7 @@ def train_learned(config): lensed = rotate_HWC(lensed, rotate_angle) psf_recon = rotate_HWC(psf_recon, rotate_angle) - save_image(psf_recon[0].numpy(), f"psf_{_idx}.png") + save_image(psf_recon[0].cpu().numpy(), f"psf_{_idx}.png") recon = ADMM(psf_recon)