diff --git a/openwfs/algorithms/dual_reference.py b/openwfs/algorithms/dual_reference.py index 8191176..f2fc8f5 100644 --- a/openwfs/algorithms/dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -118,14 +118,10 @@ def phase_patterns(self, value): # find the modes in A and B that correspond to flat wavefronts with phase 0 try: a0_index = next( - i - for i in range(value[0].shape[2]) - if np.allclose(value[0][:, :, i], 0) + i for i in range(value[0].shape[2]) if np.allclose(value[0][:, :, i], 0) ) b0_index = next( - i - for i in range(value[1].shape[2]) - if np.allclose(value[1][:, :, i], 0) + i for i in range(value[1].shape[2]) if np.allclose(value[1][:, :, i], 0) ) self.zero_indices = (a0_index, b0_index) except StopIteration: @@ -134,9 +130,7 @@ def phase_patterns(self, value): ) if (value[0].shape[0:2] != self._shape) or (value[1].shape[0:2] != self._shape): - raise ValueError( - "The phase patterns and group mask must all have the same shape." - ) + raise ValueError("The phase patterns and group mask must all have the same shape.") self._phase_patterns = ( value[0].astype(np.float32), @@ -162,8 +156,7 @@ def execute( # Current estimate of the transmission matrix (start with all 0) cobasis = [ - np.exp(-1j * self.phase_patterns[side]) - * np.expand_dims(self.masks[side], axis=2) + np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in range(2) ] @@ -200,9 +193,7 @@ def execute( if self.optimized_reference: # use the best estimate so far to construct an optimized reference - t_this_side = self.compute_t_set( - results_all[it].t, cobasis[side] - ).squeeze() + t_this_side = self.compute_t_set(results_all[it].t, cobasis[side]).squeeze() ref_phases[self.masks[side]] = -np.angle(t_this_side[self.masks[side]]) # Try full pattern @@ -220,9 +211,7 @@ def execute( relative = results_all[0].t[self.zero_indices[0], ...] + np.conjugate( results_all[1].t[self.zero_indices[1], ...] ) - factor = (relative / np.abs(relative)).reshape( - (1, *self.feedback.data_shape) - ) + factor = (relative / np.abs(relative)).reshape((1, *self.feedback.data_shape)) t_full = self.compute_t_set(results_all[0].t, cobasis[0]) + self.compute_t_set( factor * results_all[1].t, cobasis[1] @@ -258,9 +247,7 @@ def _single_side_experiment( WFSResult: An object containing the computed SLM transmission matrix and related data. """ num_modes = mod_phases.shape[2] - measurements = np.zeros( - (num_modes, self.phase_steps, *self.feedback.data_shape) - ) + measurements = np.zeros((num_modes, self.phase_steps, *self.feedback.data_shape)) for m in range(num_modes): phases = ref_phases.copy() diff --git a/tests/test_wfs.py b/tests/test_wfs.py index f6d8326..d04992d 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -80,7 +80,9 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str): sim.slm.set_phases(-np.angle(result.t[:, :, b])) I_opt[b] = feedback.read()[b] t_correlation += abs(np.vdot(result.t[:, :, b], sim.t[:, :, b])) ** 2 - t_norm += abs(np.vdot(result.t[:, :, b], result.t[:, :, b]) * np.vdot(sim.t[:, :, b], sim.t[:, :, b])) + t_norm += abs( + np.vdot(result.t[:, :, b], result.t[:, :, b]) * np.vdot(sim.t[:, :, b], sim.t[:, :, b]) + ) t_correlation /= t_norm # a correlation of 1 means optimal reconstruction of the N modulated modes, which may be less than the total number of inputs in the transmission matrix @@ -95,11 +97,15 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str): theoretical_t_correlation = theoretical_noise_fidelity * alg_fidelity estimated_t_correlation = result.fidelity_noise * result.fidelity_calibration * alg_fidelity tolerance = 2.0 / np.sqrt(M) - print(f"\nenhancement: \ttheoretical= {theoretical_enhancement},\testimated={estimated_enhancement},\tactual: {enhancement}") + print( + f"\nenhancement: \ttheoretical= {theoretical_enhancement},\testimated={estimated_enhancement},\tactual: {enhancement}" + ) print( f"t-matrix fidelity:\ttheoretical = {theoretical_t_correlation},\testimated = {estimated_t_correlation},\tactual = {t_correlation}" ) - print(f"noise fidelity: \ttheoretical = {theoretical_noise_fidelity},\testimated = {result.fidelity_noise}") + print( + f"noise fidelity: \ttheoretical = {theoretical_noise_fidelity},\testimated = {result.fidelity_noise}" + ) print(f"comparing at relative tolerance: {tolerance}") assert np.allclose( @@ -146,7 +152,9 @@ def test_fourier2(): slm_shape = (1000, 1000) aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=7.5, phase_steps=3) + alg = FourierDualReference( + feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=7.5, phase_steps=3 + ) controller = WFSController(alg) controller.wavefront = WFSController.State.SHAPED_WAVEFRONT scaled_aberration = zoom(aberrations, np.array(slm_shape) / aberrations.shape) @@ -174,7 +182,9 @@ def test_fourier_microscope(): ) cam = sim.get_camera(analog_max=100) roi_detector = SingleRoi(cam, pos=(250, 250)) # Only measure that specific point - alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_radius=1.5, phase_steps=3) + alg = FourierDualReference( + feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_radius=1.5, phase_steps=3 + ) controller = WFSController(alg) controller.wavefront = WFSController.State.FLAT_WAVEFRONT before = roi_detector.read() @@ -274,7 +284,9 @@ def test_flat_wf_response_fourier(optimized_reference, step): # test the optimized wavefront by checking if it has irregularities. measured_aberrations = np.squeeze(np.angle(t)) measured_aberrations += aberrations[0, 0] - measured_aberrations[0, 0] - assert np.allclose(measured_aberrations, aberrations, atol=0.02) # The measured wavefront is not flat. + assert np.allclose( + measured_aberrations, aberrations, atol=0.02 + ) # The measured wavefront is not flat. def test_flat_wf_response_ssa(): @@ -292,7 +304,9 @@ def test_flat_wf_response_ssa(): # Assert that the standard deviation of the optimized wavefront is below the threshold, # indicating that it is effectively flat - assert np.std(optimised_wf) < 0.001, f"Response flat wavefront not flat, std: {np.std(optimised_wf)}" + assert ( + np.std(optimised_wf) < 0.001 + ), f"Response flat wavefront not flat, std: {np.std(optimised_wf)}" def test_multidimensional_feedback_ssa(): @@ -429,7 +443,9 @@ def test_custom_blind_dual_reference_ortho_split(type: str, shape): plt.colorbar() plt.show() - assert np.abs(field_correlation(sim.t, result.t)) > 0.99 # todo: find out why this is not higher + assert ( + np.abs(field_correlation(sim.t, result.t)) > 0.99 + ) # todo: find out why this is not higher def test_custom_blind_dual_reference_non_ortho(): @@ -464,7 +480,9 @@ def test_custom_blind_dual_reference_non_ortho(): # Create aberrations x = np.linspace(-1, 1, 1 * N1).reshape((1, -1)) y = np.linspace(-1, 1, 1 * N1).reshape((-1, 1)) - aberrations = (np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)) % (2 * np.pi) + aberrations = ( + np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y) + ) % (2 * np.pi) aberrations[0:1, :] = 0 aberrations[:, 0:2] = 0