From df3312f341bdcae3cd4faf93a4b355ef020ffca6 Mon Sep 17 00:00:00 2001 From: Daniel Cox Date: Tue, 1 Oct 2024 14:27:29 +0200 Subject: [PATCH] Add plot functionality --- .../algorithms/custom_iter_dual_reference.py | 220 ++++++++++++++++++ openwfs/plot_utilities.py | 156 +++++++++++++ tests/test_wfs.py | 28 +-- 3 files changed, 388 insertions(+), 16 deletions(-) create mode 100644 openwfs/algorithms/custom_iter_dual_reference.py diff --git a/openwfs/algorithms/custom_iter_dual_reference.py b/openwfs/algorithms/custom_iter_dual_reference.py new file mode 100644 index 0000000..73b083b --- /dev/null +++ b/openwfs/algorithms/custom_iter_dual_reference.py @@ -0,0 +1,220 @@ +from typing import Optional + +import numpy as np +from numpy import ndarray as nd + +from .utilities import analyze_phase_stepping, WFSResult +from ..core import Detector, PhaseSLM + + +def weighted_average(a, b, wa, wb): + """ + Compute the weighted average of two values. + + Args: + a: The first value. + b: The second value. + wa: The weight of the first value. + wb: The weight of the second value. + """ + return (a * wa + b * wb) / (wa + wb) + + +class IterativeDualReference: + """ + A generic iterative dual reference WFS algorithm, which can use a custom set of basis functions. + + This algorithm is adapted from [1], with the addition of the ability to use custom basis functions and specify the number of iterations. + + In this algorithm, the SLM pixels are divided into two groups: A and B, as indicated by the boolean group_mask argument. + The algorithm first keeps the pixels in group B fixed, and displays a sequence on patterns on the pixels of group A. + It uses these measurements to construct an optimized wavefront that is displayed on the pixels of group A. + This process is then repeated for the pixels of group B, now using the *optimized* wavefront on group A as reference. + Optionally, the process can be repeated for a number of iterations, which each iteration using the current correction + pattern as a reference. This makes this algorithm suitable for non-linear feedback, such as multi-photon + excitation fluorescence [2]. + + This algorithm assumes a phase-only SLM. Hence, the input modes are defined by passing the corresponding phase + patterns (in radians) as input argument. + + [1]: X. Tao, T. Lam, B. Zhu, et al., “Three-dimensional focusing through scattering media using conjugate adaptive + optics with remote focusing (CAORF),” Opt. Express 25, 10368–10383 (2017). + + [2]: Gerwin Osnabrugge, Lyubov V. Amitonova, and Ivo M. Vellekoop. "Blind focusing through strongly scattering media + using wavefront shaping with nonlinear feedback", Optics Express, 27(8):11673–11688, 2019. + https://opg.optica.org/oe/ abstract.cfm?uri=oe-27-8-1167 + """ + + def __init__(self, feedback: Detector, slm: PhaseSLM, phase_patterns: tuple[nd, nd], group_mask: nd, + phase_steps: int = 4, iterations: int = 4, analyzer: Optional[callable] = analyze_phase_stepping): + """ + Args: + feedback: The feedback source, usually a detector that provides measurement data. + slm: Spatial light modulator object. + phase_patterns: A tuple of two 3D arrays, containing the phase patterns for group A and group B, respectively. + The first two dimensions are the spatial dimensions, and should match the size of group_mask. + The 3rd dimension in the array is index of the phase pattern. The number of phase patterns in A and B may be different. + group_mask: A 2D bool array of that defines the pixels used by group A with False and elements used by + group B with True. + phase_steps: The number of phase steps for each mode (default is 4). Depending on the type of + non-linear feedback and the SNR, more might be required. + iterations: Number of times to measure a mode set, e.g. when iterations = 5, the measurements are + A, B, A, B, A. Should be at least 2 + analyzer: The function used to analyze the phase stepping data. Must return a WFSResult object. Defaults to `analyze_phase_stepping` + """ + if (phase_patterns[0].shape[0:2] != group_mask.shape) or (phase_patterns[1].shape[0:2] != group_mask.shape): + raise ValueError("The phase patterns and group mask must all have the same shape.") + if iterations < 2: + raise ValueError("The number of iterations must be at least 2.") + if np.prod(feedback.data_shape) != 1: + raise ValueError("The feedback detector should return a single scalar value.") + + self.slm = slm + self.feedback = feedback + self.phase_steps = phase_steps + self.iterations = iterations + self.analyzer = analyzer + self.phase_patterns = (phase_patterns[0].astype(np.float32), phase_patterns[1].astype(np.float32)) + mask = group_mask.astype(bool) + self.masks = (~mask, mask) # masks[0] is True for group A, mask[1] is True for group B + + # Pre-compute the conjugate modes for reconstruction + self.modes = [np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in + range(2)] + + def execute(self, capture_intermediate_results: bool = False, progress_bar=None) -> WFSResult: + """ + Executes the blind focusing dual reference algorithm and compute the SLM transmission matrix. + capture_intermediate_results: When True, measures the feedback from the optimized wavefront after each iteration. + This can be useful to determine how many iterations are needed to converge to an optimal pattern. + This data is stored as the 'intermediate_results' field in the results + progress_bar: Optional progress bar object. Following the convention for tqdm progress bars, + this object should have a `total` attribute and an `update()` function. + + Returns: + WFSResult: An object containing the computed SLM transmission matrix and related data. The amplitude profile + of each mode is assumed to be 1. If a different amplitude profile is desired, this can be obtained by + multiplying that amplitude profile with this transmission matrix. + """ + + # Current estimate of the transmission matrix (start with all 0) + t_full = np.zeros(shape=self.modes[0].shape[0:2]) + t_other_side = t_full + + # Initialize storage lists + t_set_all = [None] * self.iterations + results_all = [None] * self.iterations # List to store all results + results_latest = [None, None] # The two latest results. Used for computing fidelity factors. + intermediate_results = np.zeros(self.iterations) # List to store feedback from full patterns + + # Prepare progress bar + if progress_bar: + num_measurements = np.ceil(self.iterations / 2) * self.modes[0].shape[2] \ + + np.floor(self.iterations / 2) * self.modes[1].shape[2] + progress_bar.total = num_measurements + + # Switch the phase sets back and forth multiple times + for it in range(self.iterations): + side = it % 2 # pick set A or B for phase stepping + ref_phases = -np.angle(t_full) # use the best estimate so far to construct an optimized reference + side_mask = self.masks[side] + # Perform WFS experiment on one side, keeping the other side sized at the ref_phases + result = self._single_side_experiment(mod_phases=self.phase_patterns[side], ref_phases=ref_phases, + mod_mask=side_mask, progress_bar=progress_bar) + + # Compute transmission matrix for the current side and update + # estimated transmission matrix + t_this_side = self.compute_t_set(result, self.modes[side]) + t_full = t_this_side + t_other_side + t_other_side = t_this_side + + # Store results + t_set_all[it] = t_this_side # Store transmission matrix + results_all[it] = result # Store result + results_latest[side] = result # Store latest result for this set + + # Try full pattern + if capture_intermediate_results: + self.slm.set_phases(-np.angle(t_full)) + intermediate_results[it] = self.feedback.read() + + # Compute average fidelity factors + fidelity_noise = weighted_average(results_latest[0].fidelity_noise, + results_latest[1].fidelity_noise, results_latest[0].n, + results_latest[1].n) + fidelity_amplitude = weighted_average(results_latest[0].fidelity_amplitude, + results_latest[1].fidelity_amplitude, results_latest[0].n, + results_latest[1].n) + fidelity_calibration = weighted_average(results_latest[0].fidelity_calibration, + results_latest[1].fidelity_calibration, results_latest[0].n, + results_latest[1].n) + + result = WFSResult(t=t_full, + t_f=None, + n=self.modes[0].shape[2] + self.modes[1].shape[2], + axis=2, + fidelity_noise=fidelity_noise, + fidelity_amplitude=fidelity_amplitude, + fidelity_calibration=fidelity_calibration) + + # TODO: document the t_set_all and results_all attributes + result.t_set_all = t_set_all + result.results_all = results_all + result.intermediate_results = intermediate_results + return result + + def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, + progress_bar=None) -> WFSResult: + """ + Conducts experiments on one part of the SLM. + + Args: + mod_phases: 3D array containing the phase patterns of each mode. Axis 0 and 1 are used as spatial axis. + Axis 2 is used for the 'phase pattern index' or 'mode index'. + ref_phases: 2D array containing the reference phase pattern. + mod_mask: 2D array containing a boolean mask, where True indicates the modulated part of the SLM. + progress_bar: Optional progress bar object. Following the convention for tqdm progress bars, + this object should have a `total` attribute and an `update()` function. + + Returns: + WFSResult: An object containing the computed SLM transmission matrix and related data. + """ + num_modes = mod_phases.shape[2] + measurements = np.zeros((num_modes, self.phase_steps)) + + for m in range(num_modes): + phases = ref_phases.copy() + modulated = mod_phases[:, :, m] + for p in range(self.phase_steps): + phi = p * 2 * np.pi / self.phase_steps + # set the modulated pixel values to the values corresponding to mode m and phase offset phi + phases[mod_mask] = modulated[mod_mask] + phi + self.slm.set_phases(phases) + self.feedback.trigger(out=measurements[m, p, ...]) + + if progress_bar is not None: + progress_bar.update() + + self.feedback.wait() + return self.analyzer(measurements, axis=1) + + @staticmethod + def compute_t_set(wfs_result: WFSResult, mode_set: nd) -> nd: + """ + Compute the transmission matrix in SLM space from transmission matrix in input mode space. + + Note 1: This function computes the transmission matrix for one mode set, and thus returns one part of the full + transmission matrix. The elements that are not part of the mode set will be 0. The full transmission matrix can + be obtained by simply adding the parts, i.e. t_full = t_set0 + t_set1. + + Note 2: As this is a blind focusing WFS algorithm, there may be only one target or 'output mode'. + + Args: + wfs_result (WFSResult): The result of the WFS algorithm. This contains the transmission matrix in the space + of input modes. + mode_set: 3D array with set of modes. + """ + t = wfs_result.t.squeeze().reshape((1, 1, mode_set.shape[2])) + norm_factor = np.prod(mode_set.shape[0:2]) + t_set = (t * mode_set).sum(axis=2) / norm_factor + return t_set diff --git a/openwfs/plot_utilities.py b/openwfs/plot_utilities.py index 1a9a6bd..e11d1e8 100644 --- a/openwfs/plot_utilities.py +++ b/openwfs/plot_utilities.py @@ -1,5 +1,11 @@ +from typing import Tuple + +import numpy as np +from numpy import ndarray as nd from astropy import units as u from matplotlib import pyplot as plt +from matplotlib.colors import hsv_to_rgb +from matplotlib.axes import Axes from .core import Detector from .utilities import get_extent @@ -48,3 +54,153 @@ def scale_prefix(value: u.Quantity) -> u.Quantity: return value.to(u.s) else: return value + + +def slope_step(a: nd, width: nd | float) -> nd: + """ + A sloped step function from 0 to 1. + + Args: + a: Input array + width: width of the sloped step. + + Returns: + An array the size of a, with the result of the sloped step function. + """ + return (a >= width) + a/width * (0 < a) * (a < width) + + +def linear_blend(a: nd, b: nd, blend: nd | float) -> nd: + """ + Return a linear, element-wise blend between two arrays a and b. + + Args: + a: Input array a. + b: Input array b. + blend: Blend factor. Value of 1.0 -> return a. Value of 0.0 -> return b. + + Returns: + A linear combination of a and b, corresponding to the blend factor. a*blend + b*(1-blend) + """ + return a*blend + b*(1-blend) + + +def complex_to_rgb(array: nd, scale: float | nd | None = None, axis: int = 2) -> nd: + """ + Generate RGB color values to represent values of a complex array. + + The complex values are mapped to HSV colorspace and then converted to RGB. Hue represents phase and Value represents + amplitude. Saturation is set to 1. + + Args: + array: Array to create RGB values for. + scale: Scaling factor for the array values. When None, scale = 1/max(abs(array)) is used. + axis: Array axis to use for the RGB dimension. + + Returns: + An RGB array representing the complex input array. + """ + if scale is None: + scale = 1 / np.max(abs(array)) + h = np.expand_dims(np.angle(array) / (2 * np.pi) + 0.5, axis=axis) + s = np.ones_like(h) + v = np.expand_dims(np.abs(array) * scale, axis=axis).clip(min=0, max=1) + hsv = np.concatenate((h, s, v), axis=axis) + rgb = hsv_to_rgb(hsv) + return rgb + + +def plot_field(array, scale: float | nd | None = None, imshow_kwargs: dict | None = None): + """ + Plot a complex array as an RGB image. + + The phase is represented by the hue, and the magnitude by the value, i.e. black = zero, brightness shows amplitude, + and the colors represent the phase. + + Args: + array(ndarray): complex array to be plotted. + scale(float): scaling factor for the magnitude. The final value is clipped to the range [0, 1]. + imshow_kwargs: Keyword arguments for matplotlib's imshow. + """ + if imshow_kwargs is None: + imshow_kwargs = {} + rgb = complex_to_rgb(array, scale) + plt.imshow(rgb, **imshow_kwargs) + + +def plot_scatter_field(x, y, array, scale, scatter_kwargs=None): + """ + Plot complex scattered data as RGB values. + """ + if scatter_kwargs is None: + scatter_kwargs = {'s': 80} + rgb = complex_to_rgb(array, scale, axis=1) + plt.scatter(x, y, c=rgb, **scatter_kwargs) + + +def complex_colorbar(scale, width_inverse: int = 15): + """ + Create an rgb colorbar for complex numbers and return its Axes handle. + """ + amp = np.linspace(0, 1.01, 10).reshape((1, -1)) + phase = np.linspace(0, 249 / 250 * 2 * np.pi, 250).reshape(-1, 1) - np.pi + z = amp * np.exp(1j * phase) + rgb = complex_to_rgb(z, 1) + ax = plt.subplot(1, width_inverse, width_inverse) + plt.imshow(rgb, aspect='auto', extent=(0, scale, -np.pi, np.pi)) + + # Ticks and labels + ax.set_yticks((-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi), ('$-\\pi$', '$-\\pi/2$', '0', '$\\pi/2$', '$\\pi$')) + ax.set_xlabel('amp.') + ax.set_ylabel('phase (rad)') + ax.yaxis.tick_right() + ax.yaxis.set_label_position("right") + return ax + + +def complex_colorwheel(ax: Axes = None, shape: Tuple[int, int] = (100, 100), imshow_kwargs: dict = {}, + arrow_props: dict = {}, text_kwargs: dict = {}, amplitude_str: str = 'A', + phase_str: str = '$\\phi$'): + """ + Create an rgb image for a colorwheel representing the complex unit circle. + + Args: + ax: Matplotlib Axes. + shape: Number of pixels in each dimension. + imshow_kwargs: Keyword arguments for matplotlib's imshow. + arrow_props: Keyword arguments for the arrows. + text_kwargs: Keyword arguments for the text labels. + amplitude_str: Text label for the amplitude arrow. + phase_str: Text label for the phase arrow. + + Returns: + rgb_wheel: rgb image of the colorwheel. + """ + if ax is None: + ax = plt.gca() + + x = np.linspace(-1, 1, shape[1]).reshape(1, -1) + y = np.linspace(-1, 1, shape[0]).reshape(-1, 1) + z = x + 1j*y + rgb = complex_to_rgb(z, scale=1) + step_width = 1.5 / shape[1] + blend = np.expand_dims(slope_step(1 - np.abs(z) - step_width, width=step_width), axis=2) + rgba_wheel = np.concatenate((rgb, blend), axis=2) + ax.imshow(rgba_wheel, extent=(-1, 1, -1, 1), **imshow_kwargs) + + # Add arrows with annotations + ax.annotate('', xy=(-0.98/np.sqrt(2),)*2, xytext=(0, 0), arrowprops={'color': 'white', 'width': 1.8, + 'headwidth': 5.0, 'headlength': 6.0, **arrow_props}) + ax.text(**{'x': -0.4, 'y': -0.8, 's': amplitude_str, 'color': 'white', 'fontsize': 15, **text_kwargs}) + ax.annotate('', xy=(0, 0.9), xytext=(0.9, 0), + arrowprops={'connectionstyle': 'arc3,rad=0.4', 'color': 'white', 'width': 1.8, 'headwidth': 5.0, + 'headlength': 6.0, **arrow_props}) + ax.text(**{'x': 0.1, 'y': 0.5, 's': phase_str, 'color': 'white', 'fontsize': 15, **text_kwargs}) + + # Hide axes spines and ticks + ax.set_xticks([]) + ax.set_yticks([]) + ax.spines['left'].set_visible(False) + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + ax.spines['bottom'].set_visible(False) diff --git a/tests/test_wfs.py b/tests/test_wfs.py index f134287..f2ddce5 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -9,8 +9,9 @@ from ..openwfs.algorithms.troubleshoot import field_correlation from ..openwfs.algorithms.utilities import WFSController from ..openwfs.processors import SingleRoi -from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope from ..openwfs.simulation.mockdevices import GaussianNoise +from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope +from ..openwfs.plot_utilities import plot_field @pytest.mark.parametrize("shape", [(4, 7), (10, 7), (20, 31)]) @@ -436,15 +437,13 @@ def test_custom_blind_dual_reference_non_ortho(): """ Test custom blind dual reference with a non-orthogonal basis. """ - do_debug = False + do_debug = True # Create set of modes that are barely linearly independent N1 = 6 N2 = 3 M = N1 * N2 - mode_set_half = (1 / M) * ( - 1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M)) - ) + mode_set_half = (1 / M) * (1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M))) + (1/M) mode_set = np.concatenate((mode_set_half, np.zeros(shape=(N1, N2, M))), axis=1) phases_set = np.angle(mode_set) mask = np.concatenate((np.zeros((N1, N2)), np.ones((N1, N2))), axis=1) @@ -456,8 +455,8 @@ def test_custom_blind_dual_reference_non_ortho(): plt.figure(figsize=(12, 7)) for m in range(M): plt.subplot(N2, N1, m + 1) - plt.imshow(phases_set[:, :, m], vmin=-np.pi, vmax=np.pi) - plt.title(f"m={m}") + plot_field(mode_set[:, :, m]) + plt.title(f'm={m}') plt.xticks([]) plt.yticks([]) plt.pause(0.01) @@ -489,16 +488,13 @@ def test_custom_blind_dual_reference_non_ortho(): if do_debug: plt.figure() - plt.imshow( - np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap="hsv" - ) - plt.title("Aberrations") - plt.colorbar() + plt.subplot(1, 2, 1) + plot_field(np.exp(1j * aberrations)) + plt.title('Aberrations') - plt.figure() - plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") - plt.title("t") - plt.colorbar() + plt.subplot(1, 2, 2) + plot_field(result.t) + plt.title('t') plt.show() assert np.abs(field_correlation(np.exp(1j * aberrations), result.t)) > 0.999