Skip to content

Commit

Permalink
fixed formatting inconsistencies
Browse files Browse the repository at this point in the history
  • Loading branch information
IvoVellekoop committed Oct 2, 2024
1 parent 9462f3e commit b7c6e73
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 29 deletions.
27 changes: 7 additions & 20 deletions openwfs/algorithms/dual_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,10 @@ def phase_patterns(self, value):
# find the modes in A and B that correspond to flat wavefronts with phase 0
try:
a0_index = next(
i
for i in range(value[0].shape[2])
if np.allclose(value[0][:, :, i], 0)
i for i in range(value[0].shape[2]) if np.allclose(value[0][:, :, i], 0)
)
b0_index = next(
i
for i in range(value[1].shape[2])
if np.allclose(value[1][:, :, i], 0)
i for i in range(value[1].shape[2]) if np.allclose(value[1][:, :, i], 0)
)
self.zero_indices = (a0_index, b0_index)
except StopIteration:
Expand All @@ -134,9 +130,7 @@ def phase_patterns(self, value):
)

if (value[0].shape[0:2] != self._shape) or (value[1].shape[0:2] != self._shape):
raise ValueError(
"The phase patterns and group mask must all have the same shape."
)
raise ValueError("The phase patterns and group mask must all have the same shape.")

self._phase_patterns = (
value[0].astype(np.float32),
Expand All @@ -162,8 +156,7 @@ def execute(

# Current estimate of the transmission matrix (start with all 0)
cobasis = [
np.exp(-1j * self.phase_patterns[side])
* np.expand_dims(self.masks[side], axis=2)
np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2)
for side in range(2)
]

Expand Down Expand Up @@ -200,9 +193,7 @@ def execute(

if self.optimized_reference:
# use the best estimate so far to construct an optimized reference
t_this_side = self.compute_t_set(
results_all[it].t, cobasis[side]
).squeeze()
t_this_side = self.compute_t_set(results_all[it].t, cobasis[side]).squeeze()
ref_phases[self.masks[side]] = -np.angle(t_this_side[self.masks[side]])

# Try full pattern
Expand All @@ -220,9 +211,7 @@ def execute(
relative = results_all[0].t[self.zero_indices[0], ...] + np.conjugate(
results_all[1].t[self.zero_indices[1], ...]
)
factor = (relative / np.abs(relative)).reshape(
(1, *self.feedback.data_shape)
)
factor = (relative / np.abs(relative)).reshape((1, *self.feedback.data_shape))

t_full = self.compute_t_set(results_all[0].t, cobasis[0]) + self.compute_t_set(
factor * results_all[1].t, cobasis[1]
Expand Down Expand Up @@ -258,9 +247,7 @@ def _single_side_experiment(
WFSResult: An object containing the computed SLM transmission matrix and related data.
"""
num_modes = mod_phases.shape[2]
measurements = np.zeros(
(num_modes, self.phase_steps, *self.feedback.data_shape)
)
measurements = np.zeros((num_modes, self.phase_steps, *self.feedback.data_shape))

for m in range(num_modes):
phases = ref_phases.copy()
Expand Down
36 changes: 27 additions & 9 deletions tests/test_wfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,9 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str):
sim.slm.set_phases(-np.angle(result.t[:, :, b]))
I_opt[b] = feedback.read()[b]
t_correlation += abs(np.vdot(result.t[:, :, b], sim.t[:, :, b])) ** 2
t_norm += abs(np.vdot(result.t[:, :, b], result.t[:, :, b]) * np.vdot(sim.t[:, :, b], sim.t[:, :, b]))
t_norm += abs(
np.vdot(result.t[:, :, b], result.t[:, :, b]) * np.vdot(sim.t[:, :, b], sim.t[:, :, b])
)
t_correlation /= t_norm

# a correlation of 1 means optimal reconstruction of the N modulated modes, which may be less than the total number of inputs in the transmission matrix
Expand All @@ -95,11 +97,15 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str):
theoretical_t_correlation = theoretical_noise_fidelity * alg_fidelity
estimated_t_correlation = result.fidelity_noise * result.fidelity_calibration * alg_fidelity
tolerance = 2.0 / np.sqrt(M)
print(f"\nenhancement: \ttheoretical= {theoretical_enhancement},\testimated={estimated_enhancement},\tactual: {enhancement}")
print(
f"\nenhancement: \ttheoretical= {theoretical_enhancement},\testimated={estimated_enhancement},\tactual: {enhancement}"
)
print(
f"t-matrix fidelity:\ttheoretical = {theoretical_t_correlation},\testimated = {estimated_t_correlation},\tactual = {t_correlation}"
)
print(f"noise fidelity: \ttheoretical = {theoretical_noise_fidelity},\testimated = {result.fidelity_noise}")
print(
f"noise fidelity: \ttheoretical = {theoretical_noise_fidelity},\testimated = {result.fidelity_noise}"
)
print(f"comparing at relative tolerance: {tolerance}")

assert np.allclose(
Expand Down Expand Up @@ -146,7 +152,9 @@ def test_fourier2():
slm_shape = (1000, 1000)
aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0)
sim = SimulatedWFS(aberrations=aberrations)
alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=7.5, phase_steps=3)
alg = FourierDualReference(
feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=7.5, phase_steps=3
)
controller = WFSController(alg)
controller.wavefront = WFSController.State.SHAPED_WAVEFRONT
scaled_aberration = zoom(aberrations, np.array(slm_shape) / aberrations.shape)
Expand Down Expand Up @@ -174,7 +182,9 @@ def test_fourier_microscope():
)
cam = sim.get_camera(analog_max=100)
roi_detector = SingleRoi(cam, pos=(250, 250)) # Only measure that specific point
alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_radius=1.5, phase_steps=3)
alg = FourierDualReference(
feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_radius=1.5, phase_steps=3
)
controller = WFSController(alg)
controller.wavefront = WFSController.State.FLAT_WAVEFRONT
before = roi_detector.read()
Expand Down Expand Up @@ -274,7 +284,9 @@ def test_flat_wf_response_fourier(optimized_reference, step):
# test the optimized wavefront by checking if it has irregularities.
measured_aberrations = np.squeeze(np.angle(t))
measured_aberrations += aberrations[0, 0] - measured_aberrations[0, 0]
assert np.allclose(measured_aberrations, aberrations, atol=0.02) # The measured wavefront is not flat.
assert np.allclose(
measured_aberrations, aberrations, atol=0.02
) # The measured wavefront is not flat.


def test_flat_wf_response_ssa():
Expand All @@ -292,7 +304,9 @@ def test_flat_wf_response_ssa():

# Assert that the standard deviation of the optimized wavefront is below the threshold,
# indicating that it is effectively flat
assert np.std(optimised_wf) < 0.001, f"Response flat wavefront not flat, std: {np.std(optimised_wf)}"
assert (
np.std(optimised_wf) < 0.001
), f"Response flat wavefront not flat, std: {np.std(optimised_wf)}"


def test_multidimensional_feedback_ssa():
Expand Down Expand Up @@ -429,7 +443,9 @@ def test_custom_blind_dual_reference_ortho_split(type: str, shape):
plt.colorbar()
plt.show()

assert np.abs(field_correlation(sim.t, result.t)) > 0.99 # todo: find out why this is not higher
assert (
np.abs(field_correlation(sim.t, result.t)) > 0.99
) # todo: find out why this is not higher


def test_custom_blind_dual_reference_non_ortho():
Expand Down Expand Up @@ -464,7 +480,9 @@ def test_custom_blind_dual_reference_non_ortho():
# Create aberrations
x = np.linspace(-1, 1, 1 * N1).reshape((1, -1))
y = np.linspace(-1, 1, 1 * N1).reshape((-1, 1))
aberrations = (np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)) % (2 * np.pi)
aberrations = (
np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)
) % (2 * np.pi)
aberrations[0:1, :] = 0
aberrations[:, 0:2] = 0

Expand Down

0 comments on commit b7c6e73

Please sign in to comment.