diff --git a/openwfs/algorithms/dual_reference.py b/openwfs/algorithms/dual_reference.py index 409567b..9110365 100644 --- a/openwfs/algorithms/dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -211,7 +211,7 @@ def _compute_cobasis(self): B = np.asmatrix((phase_factor * amplitude_factor).reshape((p, m))) # Basis matrix self._gram = B.H @ B B_pinv = np.linalg.inv(self.gram) @ B.H # Moore-Penrose pseudo-inverse - cobasis[side] = np.asarray(B_pinv).reshape(self.phase_patterns[side].shape) + cobasis[side] = np.asarray(B_pinv.T).reshape(self.phase_patterns[side].shape) self._cobasis = cobasis diff --git a/tests/test_wfs.py b/tests/test_wfs.py index d2f8627..3949b35 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -375,14 +375,14 @@ def test_multidimensional_feedback_fourier(): def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape): """Test custom blind dual reference with an orthonormal phase-only basis. Two types of bases are tested: plane waves and Hadamard""" - do_debug = True + do_debug = False N = shape[0] * (shape[1] // 2) modes_shape = (shape[0], shape[1] // 2, N) if basis_str == "plane_wave": # Create a full plane wave basis for one half of the SLM. - modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) + modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) / np.sqrt(N) elif basis_str == 'hadamard': - modes = hadamard(N).reshape(modes_shape) + modes = hadamard(N).reshape(modes_shape) / np.sqrt(N) else: raise f'Unknown type of basis "{basis_str}".' @@ -400,11 +400,12 @@ def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape): plt.figure(figsize=(12, 7)) for m in range(N): plt.subplot(*modes_shape[0:2], m + 1) - plt.imshow(np.angle(mode_set[:, :, m]), vmin=-np.pi, vmax=np.pi) + plot_field(mode_set[:, :, m]) plt.title(f"m={m}") plt.xticks([]) plt.yticks([]) - plt.pause(0.1) + plt.suptitle('Basis') + plt.pause(0.01) # Create aberrations sim = SimulatedWFS(t=random_transmission_matrix(shape)) @@ -418,14 +419,17 @@ def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape): iterations=4, ) - assert np.allclose(alg.gram, np.eye(N), atol=1e-6) - - for m in range(N): - alg.cobasis - result = alg.execute() if do_debug: + plt.figure() + for m in range(N): + plt.subplot(*modes_shape[0:2], m + 1) + plot_field(alg.cobasis[0][:, :, m]) + plt.title(f'{m}') + plt.suptitle('Cobasis') + plt.pause(0.01) + plt.figure() plt.imshow(np.angle(sim.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") plt.title("Aberrations") @@ -436,22 +440,25 @@ def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape): plt.colorbar() plt.show() - assert ( - np.abs(field_correlation(sim.t, result.t)) > 0.99 - ) # todo: find out why this is not higher + # Checks for orthonormal bases + assert np.allclose(alg.gram, np.eye(N), atol=1e-6) # Gram matrix must be I + assert np.allclose(alg.cobasis[0], mode_set.conj(), atol=1e-6) # Cobasis vectors are just the complex conjugates + + # todo: find out why this is not higher + assert np.abs(field_correlation(sim.t, result.t)) > 0.95 def test_custom_blind_dual_reference_non_ortho(): """ Test custom blind dual reference with a non-orthogonal basis. """ - do_debug = True + do_debug = False # Create set of modes that are barely linearly independent N1 = 6 N2 = 3 M = N1 * N2 - mode_set_half = (1 / M) * (1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M))) + (1/M) + mode_set_half = np.exp(2j*np.pi/3 * np.eye(M).reshape((N1, N2, M))) / np.sqrt(M) mode_set = np.concatenate((mode_set_half, np.zeros(shape=(N1, N2, M))), axis=1) phases_set = np.angle(mode_set) mask = np.concatenate((np.zeros((N1, N2)), np.ones((N1, N2))), axis=1) @@ -487,7 +494,7 @@ def test_custom_blind_dual_reference_non_ortho(): feedback=sim, slm=sim.slm, phase_patterns=(phases_set, np.flip(phases_set, axis=1)), - amplitude='ones', + amplitude='uniform', group_mask=mask, phase_steps=4, iterations=4, @@ -495,15 +502,31 @@ def test_custom_blind_dual_reference_non_ortho(): result = alg.execute() + aberration_field = np.exp(1j * aberrations) + t_field = np.exp(1j * np.angle(result.t)) + if do_debug: + plt.figure() + for m in range(M): + plt.subplot(N2, N1, m + 1) + plot_field(alg.cobasis[0][:, :, m], scale=2) + plt.title(f'{m}') + plt.suptitle('Cobasis') + plt.pause(0.01) + + plt.figure() + plt.imshow(abs(alg.gram), vmin=0, vmax=1) + plt.title('Gram matrix abs values') + plt.colorbar() + plt.figure() plt.subplot(1, 2, 1) - plot_field(np.exp(1j * aberrations)) + plot_field(aberration_field) plt.title('Aberrations') plt.subplot(1, 2, 2) - plot_field(result.t) + plot_field(t_field) plt.title('t') plt.show() - assert np.abs(field_correlation(np.exp(1j * aberrations), result.t)) > 0.999 + assert np.abs(field_correlation(aberration_field, t_field)) > 0.999