Skip to content

Commit

Permalink
Add plot functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
dedean16 committed Oct 2, 2024
1 parent 52c4334 commit df3312f
Show file tree
Hide file tree
Showing 3 changed files with 388 additions and 16 deletions.
220 changes: 220 additions & 0 deletions openwfs/algorithms/custom_iter_dual_reference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
from typing import Optional

import numpy as np
from numpy import ndarray as nd

from .utilities import analyze_phase_stepping, WFSResult
from ..core import Detector, PhaseSLM


def weighted_average(a, b, wa, wb):
"""
Compute the weighted average of two values.
Args:
a: The first value.
b: The second value.
wa: The weight of the first value.
wb: The weight of the second value.
"""
return (a * wa + b * wb) / (wa + wb)


class IterativeDualReference:
"""
A generic iterative dual reference WFS algorithm, which can use a custom set of basis functions.
This algorithm is adapted from [1], with the addition of the ability to use custom basis functions and specify the number of iterations.
In this algorithm, the SLM pixels are divided into two groups: A and B, as indicated by the boolean group_mask argument.
The algorithm first keeps the pixels in group B fixed, and displays a sequence on patterns on the pixels of group A.
It uses these measurements to construct an optimized wavefront that is displayed on the pixels of group A.
This process is then repeated for the pixels of group B, now using the *optimized* wavefront on group A as reference.
Optionally, the process can be repeated for a number of iterations, which each iteration using the current correction
pattern as a reference. This makes this algorithm suitable for non-linear feedback, such as multi-photon
excitation fluorescence [2].
This algorithm assumes a phase-only SLM. Hence, the input modes are defined by passing the corresponding phase
patterns (in radians) as input argument.
[1]: X. Tao, T. Lam, B. Zhu, et al., “Three-dimensional focusing through scattering media using conjugate adaptive
optics with remote focusing (CAORF),” Opt. Express 25, 10368–10383 (2017).
[2]: Gerwin Osnabrugge, Lyubov V. Amitonova, and Ivo M. Vellekoop. "Blind focusing through strongly scattering media
using wavefront shaping with nonlinear feedback", Optics Express, 27(8):11673–11688, 2019.
https://opg.optica.org/oe/ abstract.cfm?uri=oe-27-8-1167
"""

def __init__(self, feedback: Detector, slm: PhaseSLM, phase_patterns: tuple[nd, nd], group_mask: nd,
phase_steps: int = 4, iterations: int = 4, analyzer: Optional[callable] = analyze_phase_stepping):
"""
Args:
feedback: The feedback source, usually a detector that provides measurement data.
slm: Spatial light modulator object.
phase_patterns: A tuple of two 3D arrays, containing the phase patterns for group A and group B, respectively.
The first two dimensions are the spatial dimensions, and should match the size of group_mask.
The 3rd dimension in the array is index of the phase pattern. The number of phase patterns in A and B may be different.
group_mask: A 2D bool array of that defines the pixels used by group A with False and elements used by
group B with True.
phase_steps: The number of phase steps for each mode (default is 4). Depending on the type of
non-linear feedback and the SNR, more might be required.
iterations: Number of times to measure a mode set, e.g. when iterations = 5, the measurements are
A, B, A, B, A. Should be at least 2
analyzer: The function used to analyze the phase stepping data. Must return a WFSResult object. Defaults to `analyze_phase_stepping`
"""
if (phase_patterns[0].shape[0:2] != group_mask.shape) or (phase_patterns[1].shape[0:2] != group_mask.shape):
raise ValueError("The phase patterns and group mask must all have the same shape.")
if iterations < 2:
raise ValueError("The number of iterations must be at least 2.")
if np.prod(feedback.data_shape) != 1:
raise ValueError("The feedback detector should return a single scalar value.")

self.slm = slm
self.feedback = feedback
self.phase_steps = phase_steps
self.iterations = iterations
self.analyzer = analyzer
self.phase_patterns = (phase_patterns[0].astype(np.float32), phase_patterns[1].astype(np.float32))
mask = group_mask.astype(bool)
self.masks = (~mask, mask) # masks[0] is True for group A, mask[1] is True for group B

# Pre-compute the conjugate modes for reconstruction
self.modes = [np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in
range(2)]

def execute(self, capture_intermediate_results: bool = False, progress_bar=None) -> WFSResult:
"""
Executes the blind focusing dual reference algorithm and compute the SLM transmission matrix.
capture_intermediate_results: When True, measures the feedback from the optimized wavefront after each iteration.
This can be useful to determine how many iterations are needed to converge to an optimal pattern.
This data is stored as the 'intermediate_results' field in the results
progress_bar: Optional progress bar object. Following the convention for tqdm progress bars,
this object should have a `total` attribute and an `update()` function.
Returns:
WFSResult: An object containing the computed SLM transmission matrix and related data. The amplitude profile
of each mode is assumed to be 1. If a different amplitude profile is desired, this can be obtained by
multiplying that amplitude profile with this transmission matrix.
"""

# Current estimate of the transmission matrix (start with all 0)
t_full = np.zeros(shape=self.modes[0].shape[0:2])
t_other_side = t_full

# Initialize storage lists
t_set_all = [None] * self.iterations
results_all = [None] * self.iterations # List to store all results
results_latest = [None, None] # The two latest results. Used for computing fidelity factors.
intermediate_results = np.zeros(self.iterations) # List to store feedback from full patterns

# Prepare progress bar
if progress_bar:
num_measurements = np.ceil(self.iterations / 2) * self.modes[0].shape[2] \
+ np.floor(self.iterations / 2) * self.modes[1].shape[2]
progress_bar.total = num_measurements

# Switch the phase sets back and forth multiple times
for it in range(self.iterations):
side = it % 2 # pick set A or B for phase stepping
ref_phases = -np.angle(t_full) # use the best estimate so far to construct an optimized reference
side_mask = self.masks[side]
# Perform WFS experiment on one side, keeping the other side sized at the ref_phases
result = self._single_side_experiment(mod_phases=self.phase_patterns[side], ref_phases=ref_phases,
mod_mask=side_mask, progress_bar=progress_bar)

# Compute transmission matrix for the current side and update
# estimated transmission matrix
t_this_side = self.compute_t_set(result, self.modes[side])
t_full = t_this_side + t_other_side
t_other_side = t_this_side

# Store results
t_set_all[it] = t_this_side # Store transmission matrix
results_all[it] = result # Store result
results_latest[side] = result # Store latest result for this set

# Try full pattern
if capture_intermediate_results:
self.slm.set_phases(-np.angle(t_full))
intermediate_results[it] = self.feedback.read()

# Compute average fidelity factors
fidelity_noise = weighted_average(results_latest[0].fidelity_noise,
results_latest[1].fidelity_noise, results_latest[0].n,
results_latest[1].n)
fidelity_amplitude = weighted_average(results_latest[0].fidelity_amplitude,
results_latest[1].fidelity_amplitude, results_latest[0].n,
results_latest[1].n)
fidelity_calibration = weighted_average(results_latest[0].fidelity_calibration,
results_latest[1].fidelity_calibration, results_latest[0].n,
results_latest[1].n)

result = WFSResult(t=t_full,
t_f=None,
n=self.modes[0].shape[2] + self.modes[1].shape[2],
axis=2,
fidelity_noise=fidelity_noise,
fidelity_amplitude=fidelity_amplitude,
fidelity_calibration=fidelity_calibration)

# TODO: document the t_set_all and results_all attributes
result.t_set_all = t_set_all
result.results_all = results_all
result.intermediate_results = intermediate_results
return result

def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd,
progress_bar=None) -> WFSResult:
"""
Conducts experiments on one part of the SLM.
Args:
mod_phases: 3D array containing the phase patterns of each mode. Axis 0 and 1 are used as spatial axis.
Axis 2 is used for the 'phase pattern index' or 'mode index'.
ref_phases: 2D array containing the reference phase pattern.
mod_mask: 2D array containing a boolean mask, where True indicates the modulated part of the SLM.
progress_bar: Optional progress bar object. Following the convention for tqdm progress bars,
this object should have a `total` attribute and an `update()` function.
Returns:
WFSResult: An object containing the computed SLM transmission matrix and related data.
"""
num_modes = mod_phases.shape[2]
measurements = np.zeros((num_modes, self.phase_steps))

for m in range(num_modes):
phases = ref_phases.copy()
modulated = mod_phases[:, :, m]
for p in range(self.phase_steps):
phi = p * 2 * np.pi / self.phase_steps
# set the modulated pixel values to the values corresponding to mode m and phase offset phi
phases[mod_mask] = modulated[mod_mask] + phi
self.slm.set_phases(phases)
self.feedback.trigger(out=measurements[m, p, ...])

if progress_bar is not None:
progress_bar.update()

self.feedback.wait()
return self.analyzer(measurements, axis=1)

@staticmethod
def compute_t_set(wfs_result: WFSResult, mode_set: nd) -> nd:
"""
Compute the transmission matrix in SLM space from transmission matrix in input mode space.
Note 1: This function computes the transmission matrix for one mode set, and thus returns one part of the full
transmission matrix. The elements that are not part of the mode set will be 0. The full transmission matrix can
be obtained by simply adding the parts, i.e. t_full = t_set0 + t_set1.
Note 2: As this is a blind focusing WFS algorithm, there may be only one target or 'output mode'.
Args:
wfs_result (WFSResult): The result of the WFS algorithm. This contains the transmission matrix in the space
of input modes.
mode_set: 3D array with set of modes.
"""
t = wfs_result.t.squeeze().reshape((1, 1, mode_set.shape[2]))
norm_factor = np.prod(mode_set.shape[0:2])
t_set = (t * mode_set).sum(axis=2) / norm_factor
return t_set
156 changes: 156 additions & 0 deletions openwfs/plot_utilities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from typing import Tuple

import numpy as np
from numpy import ndarray as nd
from astropy import units as u
from matplotlib import pyplot as plt
from matplotlib.colors import hsv_to_rgb
from matplotlib.axes import Axes

from .core import Detector
from .utilities import get_extent
Expand Down Expand Up @@ -48,3 +54,153 @@ def scale_prefix(value: u.Quantity) -> u.Quantity:
return value.to(u.s)
else:
return value


def slope_step(a: nd, width: nd | float) -> nd:
"""
A sloped step function from 0 to 1.
Args:
a: Input array
width: width of the sloped step.
Returns:
An array the size of a, with the result of the sloped step function.
"""
return (a >= width) + a/width * (0 < a) * (a < width)


def linear_blend(a: nd, b: nd, blend: nd | float) -> nd:
"""
Return a linear, element-wise blend between two arrays a and b.
Args:
a: Input array a.
b: Input array b.
blend: Blend factor. Value of 1.0 -> return a. Value of 0.0 -> return b.
Returns:
A linear combination of a and b, corresponding to the blend factor. a*blend + b*(1-blend)
"""
return a*blend + b*(1-blend)


def complex_to_rgb(array: nd, scale: float | nd | None = None, axis: int = 2) -> nd:
"""
Generate RGB color values to represent values of a complex array.
The complex values are mapped to HSV colorspace and then converted to RGB. Hue represents phase and Value represents
amplitude. Saturation is set to 1.
Args:
array: Array to create RGB values for.
scale: Scaling factor for the array values. When None, scale = 1/max(abs(array)) is used.
axis: Array axis to use for the RGB dimension.
Returns:
An RGB array representing the complex input array.
"""
if scale is None:
scale = 1 / np.max(abs(array))
h = np.expand_dims(np.angle(array) / (2 * np.pi) + 0.5, axis=axis)
s = np.ones_like(h)
v = np.expand_dims(np.abs(array) * scale, axis=axis).clip(min=0, max=1)
hsv = np.concatenate((h, s, v), axis=axis)
rgb = hsv_to_rgb(hsv)
return rgb


def plot_field(array, scale: float | nd | None = None, imshow_kwargs: dict | None = None):
"""
Plot a complex array as an RGB image.
The phase is represented by the hue, and the magnitude by the value, i.e. black = zero, brightness shows amplitude,
and the colors represent the phase.
Args:
array(ndarray): complex array to be plotted.
scale(float): scaling factor for the magnitude. The final value is clipped to the range [0, 1].
imshow_kwargs: Keyword arguments for matplotlib's imshow.
"""
if imshow_kwargs is None:
imshow_kwargs = {}
rgb = complex_to_rgb(array, scale)
plt.imshow(rgb, **imshow_kwargs)


def plot_scatter_field(x, y, array, scale, scatter_kwargs=None):
"""
Plot complex scattered data as RGB values.
"""
if scatter_kwargs is None:
scatter_kwargs = {'s': 80}
rgb = complex_to_rgb(array, scale, axis=1)
plt.scatter(x, y, c=rgb, **scatter_kwargs)


def complex_colorbar(scale, width_inverse: int = 15):
"""
Create an rgb colorbar for complex numbers and return its Axes handle.
"""
amp = np.linspace(0, 1.01, 10).reshape((1, -1))
phase = np.linspace(0, 249 / 250 * 2 * np.pi, 250).reshape(-1, 1) - np.pi
z = amp * np.exp(1j * phase)
rgb = complex_to_rgb(z, 1)
ax = plt.subplot(1, width_inverse, width_inverse)
plt.imshow(rgb, aspect='auto', extent=(0, scale, -np.pi, np.pi))

# Ticks and labels
ax.set_yticks((-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi), ('$-\\pi$', '$-\\pi/2$', '0', '$\\pi/2$', '$\\pi$'))
ax.set_xlabel('amp.')
ax.set_ylabel('phase (rad)')
ax.yaxis.tick_right()
ax.yaxis.set_label_position("right")
return ax


def complex_colorwheel(ax: Axes = None, shape: Tuple[int, int] = (100, 100), imshow_kwargs: dict = {},
arrow_props: dict = {}, text_kwargs: dict = {}, amplitude_str: str = 'A',
phase_str: str = '$\\phi$'):
"""
Create an rgb image for a colorwheel representing the complex unit circle.
Args:
ax: Matplotlib Axes.
shape: Number of pixels in each dimension.
imshow_kwargs: Keyword arguments for matplotlib's imshow.
arrow_props: Keyword arguments for the arrows.
text_kwargs: Keyword arguments for the text labels.
amplitude_str: Text label for the amplitude arrow.
phase_str: Text label for the phase arrow.
Returns:
rgb_wheel: rgb image of the colorwheel.
"""
if ax is None:
ax = plt.gca()

x = np.linspace(-1, 1, shape[1]).reshape(1, -1)
y = np.linspace(-1, 1, shape[0]).reshape(-1, 1)
z = x + 1j*y
rgb = complex_to_rgb(z, scale=1)
step_width = 1.5 / shape[1]
blend = np.expand_dims(slope_step(1 - np.abs(z) - step_width, width=step_width), axis=2)
rgba_wheel = np.concatenate((rgb, blend), axis=2)
ax.imshow(rgba_wheel, extent=(-1, 1, -1, 1), **imshow_kwargs)

# Add arrows with annotations
ax.annotate('', xy=(-0.98/np.sqrt(2),)*2, xytext=(0, 0), arrowprops={'color': 'white', 'width': 1.8,
'headwidth': 5.0, 'headlength': 6.0, **arrow_props})
ax.text(**{'x': -0.4, 'y': -0.8, 's': amplitude_str, 'color': 'white', 'fontsize': 15, **text_kwargs})
ax.annotate('', xy=(0, 0.9), xytext=(0.9, 0),
arrowprops={'connectionstyle': 'arc3,rad=0.4', 'color': 'white', 'width': 1.8, 'headwidth': 5.0,
'headlength': 6.0, **arrow_props})
ax.text(**{'x': 0.1, 'y': 0.5, 's': phase_str, 'color': 'white', 'fontsize': 15, **text_kwargs})

# Hide axes spines and ticks
ax.set_xticks([])
ax.set_yticks([])
ax.spines['left'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)
ax.spines['bottom'].set_visible(False)
Loading

0 comments on commit df3312f

Please sign in to comment.