Skip to content

Commit

Permalink
Dual Reference working for non-orthonormal basis
Browse files Browse the repository at this point in the history
  • Loading branch information
dedean16 committed Oct 2, 2024
1 parent 7f5d768 commit 7727b33
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 20 deletions.
2 changes: 1 addition & 1 deletion openwfs/algorithms/dual_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _compute_cobasis(self):
B = np.asmatrix((phase_factor * amplitude_factor).reshape((p, m))) # Basis matrix
self._gram = B.H @ B
B_pinv = np.linalg.inv(self.gram) @ B.H # Moore-Penrose pseudo-inverse
cobasis[side] = np.asarray(B_pinv).reshape(self.phase_patterns[side].shape)
cobasis[side] = np.asarray(B_pinv.T).reshape(self.phase_patterns[side].shape)

self._cobasis = cobasis

Expand Down
61 changes: 42 additions & 19 deletions tests/test_wfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,14 +375,14 @@ def test_multidimensional_feedback_fourier():
def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape):
"""Test custom blind dual reference with an orthonormal phase-only basis.
Two types of bases are tested: plane waves and Hadamard"""
do_debug = True
do_debug = False
N = shape[0] * (shape[1] // 2)
modes_shape = (shape[0], shape[1] // 2, N)
if basis_str == "plane_wave":
# Create a full plane wave basis for one half of the SLM.
modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1))
modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) / np.sqrt(N)
elif basis_str == 'hadamard':
modes = hadamard(N).reshape(modes_shape)
modes = hadamard(N).reshape(modes_shape) / np.sqrt(N)
else:
raise f'Unknown type of basis "{basis_str}".'

Expand All @@ -400,11 +400,12 @@ def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape):
plt.figure(figsize=(12, 7))
for m in range(N):
plt.subplot(*modes_shape[0:2], m + 1)
plt.imshow(np.angle(mode_set[:, :, m]), vmin=-np.pi, vmax=np.pi)
plot_field(mode_set[:, :, m])
plt.title(f"m={m}")
plt.xticks([])
plt.yticks([])
plt.pause(0.1)
plt.suptitle('Basis')
plt.pause(0.01)

# Create aberrations
sim = SimulatedWFS(t=random_transmission_matrix(shape))
Expand All @@ -418,14 +419,17 @@ def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape):
iterations=4,
)

assert np.allclose(alg.gram, np.eye(N), atol=1e-6)

for m in range(N):
alg.cobasis

result = alg.execute()

if do_debug:
plt.figure()
for m in range(N):
plt.subplot(*modes_shape[0:2], m + 1)
plot_field(alg.cobasis[0][:, :, m])
plt.title(f'{m}')
plt.suptitle('Cobasis')
plt.pause(0.01)

plt.figure()
plt.imshow(np.angle(sim.t), vmin=-np.pi, vmax=np.pi, cmap="hsv")
plt.title("Aberrations")
Expand All @@ -436,22 +440,25 @@ def test_custom_blind_dual_reference_ortho_split(basis_str: str, shape):
plt.colorbar()
plt.show()

assert (
np.abs(field_correlation(sim.t, result.t)) > 0.99
) # todo: find out why this is not higher
# Checks for orthonormal bases
assert np.allclose(alg.gram, np.eye(N), atol=1e-6) # Gram matrix must be I
assert np.allclose(alg.cobasis[0], mode_set.conj(), atol=1e-6) # Cobasis vectors are just the complex conjugates

# todo: find out why this is not higher
assert np.abs(field_correlation(sim.t, result.t)) > 0.95


def test_custom_blind_dual_reference_non_ortho():
"""
Test custom blind dual reference with a non-orthogonal basis.
"""
do_debug = True
do_debug = False

# Create set of modes that are barely linearly independent
N1 = 6
N2 = 3
M = N1 * N2
mode_set_half = (1 / M) * (1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M))) + (1/M)
mode_set_half = np.exp(2j*np.pi/3 * np.eye(M).reshape((N1, N2, M))) / np.sqrt(M)
mode_set = np.concatenate((mode_set_half, np.zeros(shape=(N1, N2, M))), axis=1)
phases_set = np.angle(mode_set)
mask = np.concatenate((np.zeros((N1, N2)), np.ones((N1, N2))), axis=1)
Expand Down Expand Up @@ -487,23 +494,39 @@ def test_custom_blind_dual_reference_non_ortho():
feedback=sim,
slm=sim.slm,
phase_patterns=(phases_set, np.flip(phases_set, axis=1)),
amplitude='ones',
amplitude='uniform',
group_mask=mask,
phase_steps=4,
iterations=4,
)

result = alg.execute()

aberration_field = np.exp(1j * aberrations)
t_field = np.exp(1j * np.angle(result.t))

if do_debug:
plt.figure()
for m in range(M):
plt.subplot(N2, N1, m + 1)
plot_field(alg.cobasis[0][:, :, m], scale=2)
plt.title(f'{m}')
plt.suptitle('Cobasis')
plt.pause(0.01)

plt.figure()
plt.imshow(abs(alg.gram), vmin=0, vmax=1)
plt.title('Gram matrix abs values')
plt.colorbar()

plt.figure()
plt.subplot(1, 2, 1)
plot_field(np.exp(1j * aberrations))
plot_field(aberration_field)
plt.title('Aberrations')

plt.subplot(1, 2, 2)
plot_field(result.t)
plot_field(t_field)
plt.title('t')
plt.show()

assert np.abs(field_correlation(np.exp(1j * aberrations), result.t)) > 0.999
assert np.abs(field_correlation(aberration_field, t_field)) > 0.999

0 comments on commit 7727b33

Please sign in to comment.