From 12c63f266aa6c7f6837293b8c4130045a87de1c9 Mon Sep 17 00:00:00 2001 From: Ahmed Date: Mon, 27 Nov 2023 09:46:15 +0100 Subject: [PATCH] height varying tested --- lensless/hardware/mask.py | 52 ++++++++++++++++++++------------------- test/test_masks.py | 14 ++++++++++- 2 files changed, 40 insertions(+), 26 deletions(-) diff --git a/lensless/hardware/mask.py b/lensless/hardware/mask.py index 416f71b0..4bbeef09 100644 --- a/lensless/hardware/mask.py +++ b/lensless/hardware/mask.py @@ -401,7 +401,7 @@ def phase_retrieval(target_psf, wv, d1, dz, n=1.2, n_iter=10, height_map=False): Target PSF to optimize the phase mask for. wv: float Wavelength (m). - d1: float + d1: float= Sample period on the sensor i.e. pixel size (m). dz: float Propagation distance between the mask and the sensor. @@ -473,16 +473,7 @@ def create_mask(self): class HeightVarying(Mask): - - def __init__( - self, - refractive_index = 1.2, - wavelength = 532e-9, - height_map = None, - height_range = (min, max), - seed = 0): - - """ + """ A class representing a height-varying mask for lensless imaging. Parameters @@ -512,26 +503,37 @@ def __init__( ... seed=42 ... ) """ - - super().__init__() + def __init__( + self, + refractive_index = 1.2, + wavelength = 532e-9, + height_map = None, + height_range = (1e-3, 1e-2), + seed = 0, + **kwargs): + + self.refractive_index = refractive_index self.wavelength = wavelength + self.height_range = height_range + self.seed = seed - if self.height_map is not None: + if height_map is not None: self.height_map = height_map else: + self.height_map = None np.random.seed(self.seed) - self.height_map = np.random.uniform(self.height_range[0], self.height_range[1], n) - - self.height_range = height_range - self.seed = seed - - def get_phi(self, n): - phi = self.height_map * (2*np.pi*(n-1) / self.wavelength) + + super().__init__(**kwargs) + + def get_phi(self): + phi = self.height_map * (2*np.pi*(self.refractive_index-1) / self.wavelength) phi = phi % (2*np.pi) return phi - def create_mask(self, n): - phase_mask = self.get_phi(n) - mask = np.exp(1j * self.get_phi(n)) - return mask \ No newline at end of file + def create_mask(self): + if self.height_map is None: + self.height_map = np.random.uniform(self.height_range[0], self.height_range[1], self.resolution) + assert self.height_map.shape == tuple(self.resolution) + phase_mask = self.get_phi() + self.mask = np.exp(1j * phase_mask) \ No newline at end of file diff --git a/test/test_masks.py b/test/test_masks.py index a16659d6..7c52bb46 100644 --- a/test/test_masks.py +++ b/test/test_masks.py @@ -1,7 +1,8 @@ import numpy as np -from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture +from lensless.hardware.mask import CodedAperture, PhaseContour, FresnelZoneAperture, HeightVarying from lensless.eval.metric import mse, psnr, ssim from waveprop.fresnel import fresnel_conv +from matplotlib import pyplot as plt resolution = np.array([380, 507]) @@ -90,6 +91,17 @@ def test_classmethod(): desired_psf_shape = np.array(tuple(resolution) + (len(mask3.psf_wavelength),)) assert np.all(mask3.psf.shape == desired_psf_shape) + mask5 = HeightVarying.from_sensor( + sensor_name="rpi_hq", downsample=downsample, distance_sensor=dz + ) + assert np.all(mask5.mask.shape == resolution) + desired_psf_shape = np.array(tuple(resolution) + (len(mask5.psf_wavelength),)) + assert np.all(mask5.psf.shape == desired_psf_shape) + fig, ax = plt.subplots() + im = ax.imshow(np.angle(mask5.mask), cmap="gray") + fig.colorbar(im, ax=ax, shrink=0.5, aspect=5) + plt.show() + if __name__ == "__main__": test_flatcam()