From 066f78e31531ee011694996a5a47b59e4738d6f7 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Thu, 26 Sep 2024 18:06:34 +0200 Subject: [PATCH 01/15] reviewing IterativeDualReference algorithm --- openwfs/algorithms/__init__.py | 6 +- .../algorithms/custom_iter_dual_reference.py | 160 ++++++++---------- pyproject.toml | 2 +- tests/test_wfs.py | 18 +- 4 files changed, 86 insertions(+), 100 deletions(-) diff --git a/openwfs/algorithms/__init__.py b/openwfs/algorithms/__init__.py index dd7e333..066a232 100644 --- a/openwfs/algorithms/__init__.py +++ b/openwfs/algorithms/__init__.py @@ -1,5 +1,5 @@ -from .ssa import StepwiseSequential -from .fourier import FourierBase from .basic_fourier import FourierDualReference, FourierDualReferenceCircle +from .custom_iter_dual_reference import IterativeDualReference +from .fourier import FourierBase +from .ssa import StepwiseSequential from .troubleshoot import troubleshoot -from .custom_iter_dual_reference import CustomIterativeDualReference diff --git a/openwfs/algorithms/custom_iter_dual_reference.py b/openwfs/algorithms/custom_iter_dual_reference.py index 5519595..521441a 100644 --- a/openwfs/algorithms/custom_iter_dual_reference.py +++ b/openwfs/algorithms/custom_iter_dual_reference.py @@ -1,9 +1,7 @@ from typing import Optional -import matplotlib.pyplot as plt import numpy as np from numpy import ndarray as nd -from tqdm import tqdm from .utilities import analyze_phase_stepping, WFSResult from ..core import Detector, PhaseSLM @@ -22,83 +20,76 @@ def weighted_average(a, b, wa, wb): return (a * wa + b * wb) / (wa + wb) -class CustomIterativeDualReference: +class IterativeDualReference: """ A generic iterative dual reference WFS algorithm, which can use a custom set of basis functions. - Similar to the Fourier Dual Reference algorithm [1], the SLM is divided in two large segments (e.g. both halves, - split in the middle). The blind focusing dual reference algorithm switches back and forth multiple times between two - large segments of the SLM (A and B). The segment shape is defined with a binary mask. Each segment has a - corresponding set of phase patterns to measure. With these measurements, a correction pattern for one segment can - be computed. To achieve convergence or 'blind focusing' [2], in each iteration we use the previously constructed - correction pattern as reference. This makes this algorithm suitable for non-linear feedback, such as multi-photon - excitation fluorescence, and unsuitable for multi-target optimization. + This algorithm is adapted from [1], with the addition of the ability to use custom basis functions and specify the number of iterations. + + In this algorithm, the SLM pixels are divided into two groups: A and B, as indicated by the boolean group_mask argument. + The algorithm first keeps the pixels in group B fixed, and displays a sequence on patterns on the pixels of group A. + It uses these measurements to construct an optimized wavefront that is displayed on the pixels of group A. + This process is then repeated for the pixels of group B, now using the *optimized* wavefront on group A as reference. + Optionally, the process can be repeated for a number of iterations, which each iteration using the current correction + pattern as a reference. This makes this algorithm suitable for non-linear feedback, such as multi-photon + excitation fluorescence [2]. This algorithm assumes a phase-only SLM. Hence, the input modes are defined by passing the corresponding phase - patterns as input argument. + patterns (in radians) as input argument. + + [1]: X. Tao, T. Lam, B. Zhu, et al., “Three-dimensional focusing through scattering media using conjugate adaptive + optics with remote focusing (CAORF),” Opt. Express 25, 10368–10383 (2017). - [1]: Bahareh Mastiani, Gerwin Osnabrugge, and Ivo M. Vellekoop, - "Wavefront shaping for forward scattering", Optics Express 30, 37436-37445 (2022) [2]: Gerwin Osnabrugge, Lyubov V. Amitonova, and Ivo M. Vellekoop. "Blind focusing through strongly scattering media using wavefront shaping with nonlinear feedback", Optics Express, 27(8):11673–11688, 2019. https://opg.optica.org/oe/ abstract.cfm?uri=oe-27-8-1167 """ - def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape: tuple[int, int], phases: tuple[nd, nd], set1_mask: - nd, phase_steps: int = 4, iterations: int = 4, analyzer: Optional[callable] = analyze_phase_stepping, - do_try_full_patterns=False, do_progress_bar=True, progress_bar_kwargs={}): + def __init__(self, feedback: Detector, slm: PhaseSLM, phase_patterns: tuple[nd, nd], group_mask: nd, + phase_steps: int = 4, iterations: int = 4, analyzer: Optional[callable] = analyze_phase_stepping): """ Args: - feedback (Detector): The feedback source, usually a detector that provides measurement data. - slm (PhaseSLM): slm object. - The slm may have the `extent` property set to indicate the extent of the back pupil of the microscope - objective in slm coordinates. By default, a value of 2.0, 2.0 is used (indicating that the pupil - corresponds to a circle of radius 1.0 on the SLM). However, to prevent artefacts at the edges of the - SLM,it may be overfilled, such that the `phases` image is mapped to an extent of e. g. (2.2, 2.2), - i. e. 10% larger than the back pupil. - slm_shape (tuple[int, int]): The shape of the SLM patterns and transmission matrices. - phases (tuple): A tuple of two 3D arrays. We will refer to these as set A and B (phases[0] and - phases[1] respectively). The 3D arrays contain the set of phases to measure set A and B. With both of - these 3D arrays, axis 0 and 1 are used as spatial axes. Axis 2 is used as phase pattern index. - E.g. phases[1][:, :, 4] is the 4th phase pattern of set B. - set1_mask: A 2D array of that defines the elements used by set A (= modes[0]) with 0s and elements used by - set B (= modes[1]) with 1s. - phase_steps (int): The number of phase steps for each mode (default is 4). Depending on the type of + feedback: The feedback source, usually a detector that provides measurement data. + slm: Spatial light modulator object. + phase_patterns: A tuple of two 3D arrays, containing the phase patterns for group A and group B, respectively. + The first two dimensions are the spatial dimensions, and should match the size of group_mask. + The 3rd dimension in the array is index of the phase pattern. The number of phase patterns in A and B may be different. + group_mask: A 2D bool array of that defines the pixels used by group A with False and elements used by + group B with True. + phase_steps: The number of phase steps for each mode (default is 4). Depending on the type of non-linear feedback and the SNR, more might be required. - iterations (int): Number of times to measure a mode set, e.g. when iterations = 5, the measurements are - A, B, A, B, A. - analyzer (callable): The function used to analyze the phase stepping data. Must return a WFSResult object. - do_try_full_patterns (bool): Whether to measure feedback from the full patterns each iteration. This can - be useful to determine how many iterations are needed to converge to an optimal pattern. - do_progress_bar (bool): Whether to print a tqdm progress bar during algorithm execution. - progress_bar_kwargs (dict): Dictionary containing keyword arguments for the tqdm progress bar. + iterations: Number of times to measure a mode set, e.g. when iterations = 5, the measurements are + A, B, A, B, A. Should be at least 2 + analyzer: The function used to analyze the phase stepping data. Must return a WFSResult object. Defaults to `analyze_phase_stepping` """ + if (phase_patterns[0].shape[0:2] != group_mask.shape) or (phase_patterns[1].shape[0:2] != group_mask.shape): + raise ValueError("The phase patterns and group mask must all have the same shape.") + if iterations < 2: + raise ValueError("The number of iterations must be at least 2.") + if np.prod(feedback.data_shape) != 1: + raise ValueError("The feedback detector should return a single scalar value.") + self.slm = slm self.feedback = feedback self.phase_steps = phase_steps self.iterations = iterations - self.slm_shape = slm_shape self.analyzer = analyzer - self.do_try_full_patterns = do_try_full_patterns - self.do_progress_bar = do_progress_bar - self.progress_bar_kwargs = progress_bar_kwargs - - assert (phases[0].shape[0] == phases[1].shape[0]) and (phases[0].shape[1] == phases[1].shape[1]) - self.phases = (phases[0].astype(np.float32), phases[1].astype(np.float32)) - - # Pre-compute set0 mask - mask1 = set1_mask.astype(dtype=np.float32) - mask0 = 1.0 - mask1 - self.set_masks = (mask0, mask1) + self.phase_patterns = (phase_patterns[0].astype(np.float32), phase_patterns[1].astype(np.float32)) + mask = group_mask.astype(bool) + self.masks = (~mask, mask) # mask[0] is True for group A, mask[1] is True for group B # Pre-compute the conjugate modes for reconstruction - modes0 = np.exp(-1j * self.phases[0]) * np.expand_dims(mask0, axis=2) - modes1 = np.exp(-1j * self.phases[1]) * np.expand_dims(mask1, axis=2) - self.modes = (modes0, modes1) + self.modes = [np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in + range(2)] - def execute(self) -> WFSResult: + def execute(self, capture_intermediate_results: bool = False, progress_bar=None) -> WFSResult: """ Executes the blind focusing dual reference algorithm and compute the SLM transmission matrix. + capture_intermediate_results: When True, measures the feedback from the optimized wavefront after each iteration. + This can be useful to determine how many iterations are needed to converge to an optimal pattern. + This data is stored as the 'intermediate_results' field in the results + progress_bar: Optional progress bar object. Following the convention for tqdm progress bars, + this object should have a `total` attribute and an `update()` function. Returns: WFSResult: An object containing the computed SLM transmission matrix and related data. The amplitude profile @@ -114,38 +105,35 @@ def execute(self) -> WFSResult: t_set_all = [None] * self.iterations results_all = [None] * self.iterations # List to store all results results_latest = [None, None] # The two latest results. Used for computing fidelity factors. - full_pattern_feedback = np.zeros(self.iterations) # List to store feedback from full patterns + intermediate_results = np.zeros(self.iterations) # List to store feedback from full patterns # Prepare progress bar - if self.do_progress_bar: + if progress_bar: num_measurements = np.ceil(self.iterations / 2) * self.modes[0].shape[2] \ - + np.floor(self.iterations / 2) * self.modes[1].shape[2] - progress_bar = tqdm(total=num_measurements, **self.progress_bar_kwargs) - else: - progress_bar = None + + np.floor(self.iterations / 2) * self.modes[1].shape[2] + progress_bar.total = num_measurements # Switch the phase sets back and forth multiple times for it in range(self.iterations): - s = it % 2 # Set id: 0 or 1. Used to pick set A or B for phase stepping - mod_mask = self.set_masks[s] + side = it % 2 # pick set A or B for phase stepping t_prev = t_set ref_phases = -np.angle(t_prev) # Shaped reference phase pattern from transmission matrix # Measure and compute - result = self._single_side_experiment(mod_phases=self.phases[s], ref_phases=ref_phases, - mod_mask=mod_mask, progress_bar=progress_bar) - t_set = self.compute_t_set(result, self.modes[s]) # Compute transmission matrix from measurements + result = self._single_side_experiment(mod_phases=self.phase_patterns[side], ref_phases=ref_phases, + mod_mask=self.masks[side], progress_bar=progress_bar) + t_set = self.compute_t_set(result, self.modes[side]) # Compute transmission matrix from measurements # Store results t_full = t_prev + t_set t_set_all[it] = t_set # Store transmission matrix results_all[it] = result # Store result - results_latest[s] = result # Store latest result for this set + results_latest[side] = result # Store latest result for this set # Try full pattern - if self.do_try_full_patterns: + if capture_intermediate_results: self.slm.set_phases(-np.angle(t_full)) - full_pattern_feedback[it] = self.feedback.read() + intermediate_results[it] = self.feedback.read() # Compute average fidelity factors fidelity_noise = weighted_average(results_latest[0].fidelity_noise, @@ -160,20 +148,20 @@ def execute(self) -> WFSResult: result = WFSResult(t=t_full, t_f=None, - n=self.modes[0].shape[2]+self.modes[1].shape[2], + n=self.modes[0].shape[2] + self.modes[1].shape[2], axis=2, fidelity_noise=fidelity_noise, fidelity_amplitude=fidelity_amplitude, fidelity_calibration=fidelity_calibration) - # TODO: This is a dirty way to add attributes. Find better way. + # TODO: document the t_set_all and results_all attributes result.t_set_all = t_set_all result.results_all = results_all - result.full_pattern_feedback = full_pattern_feedback + result.intermediate_results = intermediate_results return result def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, - progress_bar: Optional[tqdm] = None) -> WFSResult: + progress_bar=None) -> WFSResult: """ Conducts experiments on one part of the SLM. @@ -181,28 +169,24 @@ def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, mod_phases: 3D array containing the phase patterns of each mode. Axis 0 and 1 are used as spatial axis. Axis 2 is used for the 'phase pattern index' or 'mode index'. ref_phases: 2D array containing the reference phase pattern. - mod_mask: 2D array containing a mask of 1s and 0s, where 1s indicate the modulated part of the SLM. - progress_bar: An optional tqdm progress bar. + mod_mask: 2D array containing a boolean mask, where True indicates the modulated part of the SLM. + progress_bar: Optional progress bar object. Following the convention for tqdm progress bars, + this object should have a `total` attribute and an `update()` function. Returns: WFSResult: An object containing the computed SLM transmission matrix and related data. - - Note: In order to speed up calculations, I used np.float32 phases with a mask, instead of adding complex128 - fields and taking the np.angle. I did a quick test for this (on AMD Ryzen 7 3700X) for a 1000x1000 array. This - brings down the phase pattern computation time from ~26ms to ~2ms. - Code comparison: - With complex128: phase_pattern = np.angle(field_B + field_A * np.exp(step)) ~26ms per phase pattern - With float32: phase_pattern = phases_B + (phases_A + step) * mask ~2ms per phase pattern """ - num_of_modes = mod_phases.shape[2] - measurements = np.zeros((num_of_modes, self.phase_steps, *self.feedback.data_shape)) - ref_phases_masked = (1.0 - mod_mask) * ref_phases # Pre-compute masked reference phase pattern + num_modes = mod_phases.shape[2] + measurements = np.zeros((num_modes, self.phase_steps)) - for m in range(num_of_modes): + for m in range(num_modes): + phases = ref_phases.copy() + modulated = mod_phases[:, :, m] for p in range(self.phase_steps): - phase_step = p * 2 * np.pi / self.phase_steps - phase_pattern = ref_phases_masked + mod_mask * (mod_phases[:, :, m] + phase_step) - self.slm.set_phases(phase_pattern) + phi = p * 2 * np.pi / self.phase_steps + # set the modulated pixel values to the values corresponding to mode m and phase offset phi + phases[mod_mask] = modulated[mod_mask] + phi + self.slm.set_phases(phases) self.feedback.trigger(out=measurements[m, p, ...]) if progress_bar is not None: diff --git a/pyproject.toml b/pyproject.toml index ee43077..7b125f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ numpy = ">=1.25.2" astropy = ">=5.3.4" glfw = ">=2.5.9" opencv-python = ">=4.9.0.80" -matplotlib = ">=3.7.3" +matplotlib = ">=3.7.3" # TODO: remove dependency? scipy = ">=1.11.3" annotated-types = "~0.7.0" tqdm = "^4.66.2" # TODO: remove dependency diff --git a/tests/test_wfs.py b/tests/test_wfs.py index 7a4febf..a08cff2 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -8,7 +8,7 @@ from skimage.transform import resize from ..openwfs.algorithms import StepwiseSequential, FourierDualReference, FourierDualReferenceCircle, \ - CustomIterativeDualReference, troubleshoot + IterativeDualReference, troubleshoot from ..openwfs.algorithms.troubleshoot import field_correlation from ..openwfs.algorithms.utilities import WFSController from ..openwfs.processors import SingleRoi @@ -476,9 +476,10 @@ def test_custom_blind_dual_reference_ortho_split(construct_basis: callable): sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) - alg = CustomIterativeDualReference(feedback=sim, slm=sim.slm, slm_shape=aberrations.shape, - phases=(phases_set, np.flip(phases_set, axis=1)), set1_mask=mask, phase_steps=4, - iterations=4) + alg = IterativeDualReference(feedback=sim, slm=sim.slm, + phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, + phase_steps=4, + iterations=4) result = alg.execute() @@ -500,7 +501,7 @@ def test_custom_blind_dual_reference_non_ortho(): """ Test custom blind dual reference with a non-orthogonal basis. """ - do_debug = True + do_debug = False # Create set of modes that are barely linearly independent N1 = 6 @@ -533,9 +534,10 @@ def test_custom_blind_dual_reference_non_ortho(): sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) - alg = CustomIterativeDualReference(feedback=sim, slm=sim.slm, slm_shape=aberrations.shape, - phases=(phases_set, np.flip(phases_set, axis=1)), set1_mask=mask, phase_steps=4, - iterations=4) + alg = IterativeDualReference(feedback=sim, slm=sim.slm, + phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, + phase_steps=4, + iterations=4) result = alg.execute() From 56de490ae2704f8d3a615cf1faac0a9b4966dc94 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Thu, 26 Sep 2024 20:36:07 +0200 Subject: [PATCH 02/15] reviewing IterativeDualReference algorithm --- .../algorithms/custom_iter_dual_reference.py | 23 +++++++++++-------- tests/test_wfs.py | 4 ++-- 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/openwfs/algorithms/custom_iter_dual_reference.py b/openwfs/algorithms/custom_iter_dual_reference.py index 521441a..9873fd2 100644 --- a/openwfs/algorithms/custom_iter_dual_reference.py +++ b/openwfs/algorithms/custom_iter_dual_reference.py @@ -97,11 +97,11 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) multiplying that amplitude profile with this transmission matrix. """ - # Initial transmission matrix for reference is constant phase + # Current estimate of the transmission matrix (start with all 0) t_full = np.zeros(shape=self.modes[0].shape[0:2]) + t_other_side = t_full # Initialize storage lists - t_set = t_full t_set_all = [None] * self.iterations results_all = [None] * self.iterations # List to store all results results_latest = [None, None] # The two latest results. Used for computing fidelity factors. @@ -116,17 +116,20 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) # Switch the phase sets back and forth multiple times for it in range(self.iterations): side = it % 2 # pick set A or B for phase stepping - t_prev = t_set - ref_phases = -np.angle(t_prev) # Shaped reference phase pattern from transmission matrix - - # Measure and compute + ref_phases = -np.angle(t_full) # use the best estimate so far to construct an optimized reference + side_mask = self.masks[side] + # Perform WFS experiment on one side, keeping the other side sized at the ref_phases result = self._single_side_experiment(mod_phases=self.phase_patterns[side], ref_phases=ref_phases, - mod_mask=self.masks[side], progress_bar=progress_bar) - t_set = self.compute_t_set(result, self.modes[side]) # Compute transmission matrix from measurements + mod_mask=side_mask, progress_bar=progress_bar) + + # Compute transmission matrix for the current side and update + # estimated transmission matrix + t_this_side = self.compute_t_set(result, self.modes[side]) + t_full = t_this_side + t_other_side + t_other_side = t_this_side # Store results - t_full = t_prev + t_set - t_set_all[it] = t_set # Store transmission matrix + t_set_all[it] = t_this_side # Store transmission matrix results_all[it] = result # Store result results_latest[side] = result # Store latest result for this set diff --git a/tests/test_wfs.py b/tests/test_wfs.py index a08cff2..464616f 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -474,7 +474,7 @@ def test_custom_blind_dual_reference_ortho_split(construct_basis: callable): aberrations[0:2, :] = 0 aberrations[:, 0:2] = 0 - sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) + sim = SimulatedWFS(aberrations=aberrations) alg = IterativeDualReference(feedback=sim, slm=sim.slm, phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, @@ -532,7 +532,7 @@ def test_custom_blind_dual_reference_non_ortho(): aberrations[0:1, :] = 0 aberrations[:, 0:2] = 0 - sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) + sim = SimulatedWFS(aberrations=aberrations) alg = IterativeDualReference(feedback=sim, slm=sim.slm, phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, From b056b66334fd9c7af5c107c8be74acf7edd4881f Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Fri, 27 Sep 2024 14:48:34 +0200 Subject: [PATCH 03/15] refactored WFSResult.combine --- .../algorithms/custom_iter_dual_reference.py | 32 ++----------------- openwfs/algorithms/utilities.py | 31 +++++++++++++++++- 2 files changed, 32 insertions(+), 31 deletions(-) diff --git a/openwfs/algorithms/custom_iter_dual_reference.py b/openwfs/algorithms/custom_iter_dual_reference.py index 9873fd2..5ec94da 100644 --- a/openwfs/algorithms/custom_iter_dual_reference.py +++ b/openwfs/algorithms/custom_iter_dual_reference.py @@ -7,19 +7,6 @@ from ..core import Detector, PhaseSLM -def weighted_average(a, b, wa, wb): - """ - Compute the weighted average of two values. - - Args: - a: The first value. - b: The second value. - wa: The weight of the first value. - wb: The weight of the second value. - """ - return (a * wa + b * wb) / (wa + wb) - - class IterativeDualReference: """ A generic iterative dual reference WFS algorithm, which can use a custom set of basis functions. @@ -139,23 +126,8 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) intermediate_results[it] = self.feedback.read() # Compute average fidelity factors - fidelity_noise = weighted_average(results_latest[0].fidelity_noise, - results_latest[1].fidelity_noise, results_latest[0].n, - results_latest[1].n) - fidelity_amplitude = weighted_average(results_latest[0].fidelity_amplitude, - results_latest[1].fidelity_amplitude, results_latest[0].n, - results_latest[1].n) - fidelity_calibration = weighted_average(results_latest[0].fidelity_calibration, - results_latest[1].fidelity_calibration, results_latest[0].n, - results_latest[1].n) - - result = WFSResult(t=t_full, - t_f=None, - n=self.modes[0].shape[2] + self.modes[1].shape[2], - axis=2, - fidelity_noise=fidelity_noise, - fidelity_amplitude=fidelity_amplitude, - fidelity_calibration=fidelity_calibration) + result = WFSResult.combine(results_latest) + result.t = t_full # TODO: document the t_set_all and results_all attributes result.t_set_all = t_set_all diff --git a/openwfs/algorithms/utilities.py b/openwfs/algorithms/utilities.py index 2b9c9e3..95f54b8 100644 --- a/openwfs/algorithms/utilities.py +++ b/openwfs/algorithms/utilities.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Optional +from typing import Optional, Sequence import numpy as np from numpy.typing import ArrayLike @@ -108,6 +108,35 @@ def select_target(self, b) -> 'WFSResult': n=self.n, ) + @staticmethod + def combine(results: Sequence['WFSResult']): + """Merges the results for several sub-experiments. + + Currently, this just computes the average of the fidelities, weighted + by the number of segments used in each sub-experiment. + + Note: the matrix t is also averaged, but this is not always meaningful. + The caller can replace the `.t` attribute of the result with a more meaningful value. + """ + n = sum(r.n for r in results) + axis = results[0].axis + if any(r.axis != axis for r in results): + raise ValueError("All results must have the same axis") + + def weighted_average(attribute): + data = getattr(results[0], attribute) * results[0].n + for r in results[1:]: + data += getattr(r, attribute) * r.n / n + return data + + return WFSResult(t=weighted_average('t'), + t_f=None, + n=n, + axis=axis, + fidelity_noise=weighted_average('fidelity_noise'), + fidelity_amplitude=weighted_average('fidelity_amplitude'), + fidelity_calibration=weighted_average('fidelity_calibration')) + def analyze_phase_stepping(measurements: np.ndarray, axis: int, A: Optional[float] = None): """Analyzes the result of phase stepping measurements, returning matrix `t` and noise statistics From c5b8f192d9787c1a90fddb21254d7704788d98f2 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Mon, 30 Sep 2024 12:28:14 +0200 Subject: [PATCH 04/15] implemented FourierDualReference as specialization of DualReference --- openwfs/algorithms/__init__.py | 5 +- openwfs/algorithms/basic_fourier.py | 195 +++++------------ ...er_dual_reference.py => dual_reference.py} | 146 +++++++++---- openwfs/algorithms/fourier.py | 203 ------------------ openwfs/algorithms/utilities.py | 4 +- tests/test_wfs.py | 87 +++++--- 6 files changed, 212 insertions(+), 428 deletions(-) rename openwfs/algorithms/{custom_iter_dual_reference.py => dual_reference.py} (55%) delete mode 100644 openwfs/algorithms/fourier.py diff --git a/openwfs/algorithms/__init__.py b/openwfs/algorithms/__init__.py index 066a232..00bee96 100644 --- a/openwfs/algorithms/__init__.py +++ b/openwfs/algorithms/__init__.py @@ -1,5 +1,4 @@ -from .basic_fourier import FourierDualReference, FourierDualReferenceCircle -from .custom_iter_dual_reference import IterativeDualReference -from .fourier import FourierBase +from .basic_fourier import FourierDualReference +from .dual_reference import DualReference from .ssa import StepwiseSequential from .troubleshoot import troubleshoot diff --git a/openwfs/algorithms/basic_fourier.py b/openwfs/algorithms/basic_fourier.py index 20d5dd6..52860f5 100644 --- a/openwfs/algorithms/basic_fourier.py +++ b/openwfs/algorithms/basic_fourier.py @@ -1,119 +1,33 @@ from typing import Optional import numpy as np -import matplotlib.pyplot as plt +from .dual_reference import DualReference from .utilities import analyze_phase_stepping -from .fourier import FourierBase from ..core import Detector, PhaseSLM +from ..utilities import tilt -def build_square_k_space(k_min, k_max, k_step=1.0): +class FourierDualReference(DualReference): """ - Constructs the k-space by creating a set of (k_x, k_y) coordinates. - Fills the k_left and k_right matrices with the same k-space. (k_x, k_y) denote the k-space coordinates of the whole - pupil. Only half SLM (and thus pupil) is modulated at a time, hence k_y (axis=1) must make steps of 2. + Fourier double reference algorithm, based on Mastiani et al. [1]. - Args: - k_min: Minimum value for k_x and k_y, without k_step scaling applied. - k_max: Maximum value for k_x and k_y, without k_step scaling applied. - k_step: Scaling factor for the steps in k-space. - - Returns: - k_space (np.ndarray): A 2xN array of k-space coordinates. - """ - # Generate kx and ky coordinates - kx_angles = np.arange(k_min, k_max + 1, 1) - k_angles_min_even = (k_min if k_min % 2 == 0 else k_min + 1) # Must be even - ky_angles = np.arange(k_angles_min_even, k_max + 1, 2) # Steps of 2 - - # Combine kx and ky coordinates into pairs - k_x = np.repeat(np.array(kx_angles)[np.newaxis, :], len(ky_angles), axis=0).flatten() - k_y = np.repeat(np.array(ky_angles)[:, np.newaxis], len(kx_angles), axis=1).flatten() - k_space = np.vstack((k_x, k_y)) * k_step - return k_space - - -class FourierDualReference(FourierBase): - """Fourier double reference algorithm, based on Mastiani et al. [1]. - - It constructs a square k-space coordinate grid for the algorithm. For custom k-spaces, you should use the - FourierBase class. The k-space coordinates denote the entire pupil, not just one half. The k-space is normalized - such that (1, 0) yields a -π to π gradient over the entire pupil. - diffraction limit. + Improvements over [1]: + * The set of plane waves is taken from a disk in k-space instead of a square. + * No overlap between the two halves is needed, instead the final stitching step is done using measurements already in the data set. + * When only a single target is optimized, the algorithm can be used in an iterative version to increase SNR during the measurument, + similar to [2]. [1]: Bahareh Mastiani, Gerwin Osnabrugge, and Ivo M. Vellekoop, "Wavefront shaping for forward scattering," Opt. Express 30, 37436-37445 (2022) - """ - - def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), phase_steps=4, k_angles_min: int = -3, - k_angles_max: int = 3, analyzer: Optional[callable] = analyze_phase_stepping): - """ - Args: - feedback (Detector): Source of feedback - slm (PhaseSLM): The spatial light modulator - slm_shape (tuple of two ints): The shape that the SLM patterns & transmission matrices are calculated for, - does not necessarily have to be the actual pixel dimensions as the SLM. - phase_steps (int): The number of phase steps. - k_angles_min (int): The minimum k-angle. - k_angles_max (int): The maximum k-angle. - """ - super().__init__(feedback, slm, slm_shape, np.array((0, 0)), np.array((0, 0)), phase_steps=phase_steps, - analyzer=analyzer) - self._k_angles_min = k_angles_min - self._k_angles_max = k_angles_max - - self._build_k_space() - - def _build_k_space(self): - """ - Constructs the k-space by creating Cartesian products of k_x and k_y angles. - Fills the k_left and k_right matrices with the same k-space. (k_x, k_y) denote the k-space coords of the whole - SLM. Only half the SLM is modulated at a time, hence ky must make steps of 2. - - Returns: - None: The function updates the instance attributes. - """ - k_space = build_square_k_space(self.k_angles_min, self.k_angles_max) - self.k_left = k_space - self.k_right = k_space - @property - def k_angles_min(self) -> int: - """The lower bound of the range of angles in x and y direction""" - return self._k_angles_min - - @k_angles_min.setter - def k_angles_min(self, value): - """Sets the lower bound of the range of angles in x and y direction, triggers the building of the internal - k-space properties. - """ - self._k_angles_min = value - self._build_k_space() - - @property - def k_angles_max(self) -> int: - """The higher bound of the range of angles in x and y direction""" - return self._k_angles_max - - @k_angles_max.setter - def k_angles_max(self, value): - """Sets the higher bound of the range of angles in x and y direction, triggers the building of the internal - k-space properties.""" - self._k_angles_max = value - self._build_k_space() - - -class FourierDualReferenceCircle(FourierBase): + [2]: X. Tao, T. Lam, B. Zhu, et al., “Three-dimensional focusing through scattering media using conjugate adaptive + optics with remote focusing (CAORF),” Opt. Express 25, 10368–10383 (2017). """ - Slightly altered version of Fourier double reference algorithm, based on Mastiani et al. [1]. - In this version, the k-space coordinates are restricted to lie within a circle of chosen radius. - [1]: Bahareh Mastiani, Gerwin Osnabrugge, and Ivo M. Vellekoop, - "Wavefront shaping for forward scattering," Opt. Express 30, 37436-37445 (2022) - """ - def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), phase_steps=4, k_radius: float = 3.2, - k_step: float = 1.0, analyzer: Optional[callable] = analyze_phase_stepping): + def __init__(self, *, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), phase_steps=4, k_radius: float = 3.2, + k_step: float = 1.0, iterations: int = 2, analyzer: Optional[callable] = analyze_phase_stepping, + optimized_reference: Optional[bool] = None): """ Args: feedback (Detector): Source of feedback @@ -124,37 +38,46 @@ def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), phas k_step (float): Make steps in k-space of this value. 1 corresponds to diffraction limited tilt. phase_steps (int): The number of phase steps. """ - # TODO: Could be rewritten more compactly if we ditch the settable properties: - # first build the k_space, then call super().__init__ with k_left=k_space, k_right=k_space. - # TODO: Add custom grid spacing - - super().__init__(feedback=feedback, slm=slm, slm_shape=slm_shape, k_left=np.array((0, 0)), - k_right=np.array((0, 0)), phase_steps=phase_steps, analyzer=analyzer) - self._k_radius = k_radius self.k_step = k_step - self._build_k_space() - - def _build_k_space(self): - """ - Constructs the k-space by creating Cartesian products of k_x and k_y angles. - Fills the k_left and k_right matrices with the same k-space. (k_x, k_y) denote the k-space coordinates of the - whole SLM. Only half the SLM is modulated at a time, hence k_y must make steps of 2. - - Returns: - None: The function updates the instance attributes k_left and k_right. - """ - k_radius = self.k_radius - k_step = self.k_step - k_max = int(np.floor(k_radius)) - k_space_square = build_square_k_space(-k_max, k_max, k_step=k_step) - - # Filter out k-space coordinates that are outside the circle of radius k_radius - k_mask = (np.linalg.norm(k_space_square, axis=0) <= k_radius) - k_space = k_space_square[:, k_mask] - - self.k_left = k_space - self.k_right = k_space + self._slm_shape = slm_shape + group_mask = np.zeros(slm_shape, dtype=bool) + group_mask[:, slm_shape[1] // 2:] = True + super().__init__(feedback=feedback, slm=slm, + phase_patterns=None, group_mask=group_mask, + phase_steps=phase_steps, + iterations=iterations, + optimized_reference=optimized_reference, + analyzer=analyzer) + self._update_modes() + + def _update_modes(self): + """Constructs the set of plane wave modes.""" + + # start with a grid of k-values + # then filter out the ones that are outside the circle + # in the grid, the spacing in the kx direction is twice the spacing in the ky direction + # because we subdivide the SLM into two halves along the x direction, + # which effectively doubles the number of kx values + int_radius_x = np.ceil(self.k_radius / (self.k_step * 2)) + int_radius_y = np.ceil(self.k_radius / self.k_step) + kx, ky = np.meshgrid( + np.arange(-int_radius_x, int_radius_x + 1) * (self.k_step * 2), + np.arange(-int_radius_y, int_radius_y + 1) * self.k_step) + + # only keep the points within the circle + mask = kx ** 2 + ky ** 2 <= self.k_radius ** 2 + k = np.stack((ky[mask], kx[mask])).T + + # construct the modes for these kx ky values + modes = np.zeros((*self._slm_shape, len(k)), dtype=np.float32) + for i, k_i in enumerate(k): + # tilt generates a pattern from -2.0 to 2.0 (The convention for Zernike modes normalized to an RMS of 1). + # The natural step to take is the Abbe diffraction limit of the modulated part, which corresponds to a gradient + # from -π to π over the modulated part. + modes[..., i] = tilt(self._slm_shape, g=k_i * 0.5 * np.pi) + + self.phase_patterns = (modes, modes) @property def k_radius(self) -> float: @@ -165,16 +88,4 @@ def k_radius(self) -> float: def k_radius(self, value): """Sets the maximum radius of the k-space circle, triggers the building of the internal k-space properties.""" self._k_radius = value - self._build_k_space() - - def plot_k_space(self): - """Plots the k-space coordinates. Useful for debugging.""" - phi = np.linspace(0, 2 * np.pi, 200) - x = self.k_radius * np.cos(phi) - y = self.k_radius * np.sin(phi) - plt.plot(x, y, 'k') - plt.plot(self.k_left[0, :], self.k_left[1, :], 'ob', label='k_left') - plt.plot(self.k_right[0, :], self.k_right[1, :], '.r', label='k_right') - plt.xlabel('k_x') - plt.ylabel('k_y') - plt.gca().set_aspect('equal') + self._update_modes() diff --git a/openwfs/algorithms/custom_iter_dual_reference.py b/openwfs/algorithms/dual_reference.py similarity index 55% rename from openwfs/algorithms/custom_iter_dual_reference.py rename to openwfs/algorithms/dual_reference.py index 5ec94da..a1e554c 100644 --- a/openwfs/algorithms/custom_iter_dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -7,7 +7,7 @@ from ..core import Detector, PhaseSLM -class IterativeDualReference: +class DualReference: """ A generic iterative dual reference WFS algorithm, which can use a custom set of basis functions. @@ -32,8 +32,9 @@ class IterativeDualReference: https://opg.optica.org/oe/ abstract.cfm?uri=oe-27-8-1167 """ - def __init__(self, feedback: Detector, slm: PhaseSLM, phase_patterns: tuple[nd, nd], group_mask: nd, - phase_steps: int = 4, iterations: int = 4, analyzer: Optional[callable] = analyze_phase_stepping): + def __init__(self, *, feedback: Detector, slm: PhaseSLM, phase_patterns: Optional[tuple[nd, nd]], group_mask: nd, + phase_steps: int = 4, iterations: int = 2, + analyzer: Optional[callable] = analyze_phase_stepping, optimized_reference: Optional[bool] = None): """ Args: feedback: The feedback source, usually a detector that provides measurement data. @@ -41,33 +42,76 @@ def __init__(self, feedback: Detector, slm: PhaseSLM, phase_patterns: tuple[nd, phase_patterns: A tuple of two 3D arrays, containing the phase patterns for group A and group B, respectively. The first two dimensions are the spatial dimensions, and should match the size of group_mask. The 3rd dimension in the array is index of the phase pattern. The number of phase patterns in A and B may be different. + When None, the phase_patterns attribute must be set before executing the algorithm. group_mask: A 2D bool array of that defines the pixels used by group A with False and elements used by group B with True. phase_steps: The number of phase steps for each mode (default is 4). Depending on the type of non-linear feedback and the SNR, more might be required. - iterations: Number of times to measure a mode set, e.g. when iterations = 5, the measurements are - A, B, A, B, A. Should be at least 2 - analyzer: The function used to analyze the phase stepping data. Must return a WFSResult object. Defaults to `analyze_phase_stepping` + iterations: Number of times to optimize a mode set, e.g. when iterations = 5, the measurements are + A, B, A, B, A. + optimized_reference: When `True`, during each iteration the other half of the SLM displays the optimized pattern so far (as in [1]). + When `False`, the algorithm optimizes A with a flat wavefront on B, and then optimizes B with a flat wavefront on A. + This mode also allows for multi-target optimization, where the algorithm optimizes multiple targets in parallel. + The two halves are then combined (stitched) to form the full transmission matrix. + In this mode, it is essential that both A and B include a flat wavefront as mode 0. The measurement for + mode A0 and for B0 both give contain relative phase between group A and B, so there is a slight redundancy. + These two measurements are combined to find the final phase for stitching. + When set to `None` (default), the algorithm uses True if there is a single target, and False if there are multiple targets. + + analyzer: The function used to analyze the phase stepping data. + Must return a WFSResult object. Defaults to `analyze_phase_stepping` + + [1]: X. Tao, T. Lam, B. Zhu, et al., “Three-dimensional focusing through scattering media using conjugate adaptive + optics with remote focusing (CAORF),” Opt. Express 25, 10368–10383 (2017). + """ - if (phase_patterns[0].shape[0:2] != group_mask.shape) or (phase_patterns[1].shape[0:2] != group_mask.shape): - raise ValueError("The phase patterns and group mask must all have the same shape.") + if optimized_reference is None: # 'auto' mode + optimized_reference = np.prod(feedback.data_shape) == 1 + elif optimized_reference and np.prod(feedback.data_shape) != 1: + raise ValueError( + "When using an optimized reference, the feedback detector should return a single scalar value.") + if iterations < 2: raise ValueError("The number of iterations must be at least 2.") - if np.prod(feedback.data_shape) != 1: - raise ValueError("The feedback detector should return a single scalar value.") + if not optimized_reference and iterations != 2: + raise ValueError("When not using an optimized reference, the number of iterations must be 2.") self.slm = slm self.feedback = feedback self.phase_steps = phase_steps + self.optimized_reference = optimized_reference self.iterations = iterations - self.analyzer = analyzer - self.phase_patterns = (phase_patterns[0].astype(np.float32), phase_patterns[1].astype(np.float32)) + self._analyzer = analyzer + self._phase_patterns = None + self._shape = group_mask.shape mask = group_mask.astype(bool) self.masks = (~mask, mask) # mask[0] is True for group A, mask[1] is True for group B + self.phase_patterns = phase_patterns + + @property + def phase_patterns(self) -> tuple[nd, nd]: + return self._phase_patterns + + @phase_patterns.setter + def phase_patterns(self, value): + """Sets the phase patterns for group A and group B. This also updates the conjugate modes.""" + if value is None: + self._phase_patterns = None + return + + if not self.optimized_reference: + # find the modes in A and B that correspond to flat wavefronts with phase 0 + try: + a0_index = next(i for i in range(value[0].shape[2]) if np.allclose(value[0][:, :, i], 0)) + b0_index = next(i for i in range(value[1].shape[2]) if np.allclose(value[1][:, :, i], 0)) + self.zero_indices = (a0_index, b0_index) + except StopIteration: + raise ("For multi-target optimization, the both sets must contain a flat wavefront with phase 0.") + + if (value[0].shape[0:2] != self._shape) or (value[1].shape[0:2] != self._shape): + raise ValueError("The phase patterns and group mask must all have the same shape.") - # Pre-compute the conjugate modes for reconstruction - self.modes = [np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in - range(2)] + self._phase_patterns = (value[0].astype(np.float32), value[1].astype(np.float32)) def execute(self, capture_intermediate_results: bool = False, progress_bar=None) -> WFSResult: """ @@ -85,52 +129,65 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) """ # Current estimate of the transmission matrix (start with all 0) - t_full = np.zeros(shape=self.modes[0].shape[0:2]) - t_other_side = t_full + cobasis = [np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in + range(2)] + + ref_phases = np.zeros(self._shape) # Initialize storage lists - t_set_all = [None] * self.iterations results_all = [None] * self.iterations # List to store all results - results_latest = [None, None] # The two latest results. Used for computing fidelity factors. intermediate_results = np.zeros(self.iterations) # List to store feedback from full patterns # Prepare progress bar if progress_bar: - num_measurements = np.ceil(self.iterations / 2) * self.modes[0].shape[2] \ - + np.floor(self.iterations / 2) * self.modes[1].shape[2] + num_measurements = np.ceil(self.iterations / 2) * self.phase_patterns[0].shape[2] \ + + np.floor(self.iterations / 2) * self.phase_patterns[1].shape[2] progress_bar.total = num_measurements # Switch the phase sets back and forth multiple times for it in range(self.iterations): side = it % 2 # pick set A or B for phase stepping - ref_phases = -np.angle(t_full) # use the best estimate so far to construct an optimized reference side_mask = self.masks[side] # Perform WFS experiment on one side, keeping the other side sized at the ref_phases - result = self._single_side_experiment(mod_phases=self.phase_patterns[side], ref_phases=ref_phases, - mod_mask=side_mask, progress_bar=progress_bar) + results_all[it] = self._single_side_experiment(mod_phases=self.phase_patterns[side], ref_phases=ref_phases, + mod_mask=side_mask, progress_bar=progress_bar) # Compute transmission matrix for the current side and update # estimated transmission matrix - t_this_side = self.compute_t_set(result, self.modes[side]) - t_full = t_this_side + t_other_side - t_other_side = t_this_side - # Store results - t_set_all[it] = t_this_side # Store transmission matrix - results_all[it] = result # Store result - results_latest[side] = result # Store latest result for this set + if self.optimized_reference: + # use the best estimate so far to construct an optimized reference + t_this_side = self.compute_t_set(results_all[it].t, cobasis[side]).squeeze() + ref_phases[self.masks[side]] = -np.angle(t_this_side[self.masks[side]]) # Try full pattern if capture_intermediate_results: - self.slm.set_phases(-np.angle(t_full)) + self.slm.set_phases(ref_phases) intermediate_results[it] = self.feedback.read() + if self.optimized_reference: + factor = 1.0 + else: + # when not using optimized reference, we need to stitch the + # two halves of the wavefront together. For that, we need the + # relative phase between the two sides, which we extract from + # the measurements of the flat wavefronts. + relative = results_all[0].t[self.zero_indices[0], ...] + np.conjugate( + results_all[1].t[self.zero_indices[1], ...]) + factor = (relative / np.abs(relative)).reshape((1, *self.feedback.data_shape)) + print(np.angle(factor)) + + t_full = (self.compute_t_set(results_all[0].t, cobasis[0]) + + self.compute_t_set(factor * results_all[1].t, cobasis[1])) + # Compute average fidelity factors - result = WFSResult.combine(results_latest) + # subtract 1 from n, because both sets (usually) contain a flat wavefront, + # so there is one redundant measurement + result = WFSResult.combine(results_all[-2:]) + result.n = result.n - 1 result.t = t_full - # TODO: document the t_set_all and results_all attributes - result.t_set_all = t_set_all + # TODO: document the results_all attribute result.results_all = results_all result.intermediate_results = intermediate_results return result @@ -152,7 +209,7 @@ def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, WFSResult: An object containing the computed SLM transmission matrix and related data. """ num_modes = mod_phases.shape[2] - measurements = np.zeros((num_modes, self.phase_steps)) + measurements = np.zeros((num_modes, self.phase_steps, *self.feedback.data_shape)) for m in range(num_modes): phases = ref_phases.copy() @@ -168,10 +225,10 @@ def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, progress_bar.update() self.feedback.wait() - return self.analyzer(measurements, axis=1) + return self._analyzer(measurements, axis=1) @staticmethod - def compute_t_set(wfs_result: WFSResult, mode_set: nd) -> nd: + def compute_t_set(t, cobasis: nd) -> nd: """ Compute the transmission matrix in SLM space from transmission matrix in input mode space. @@ -182,11 +239,10 @@ def compute_t_set(wfs_result: WFSResult, mode_set: nd) -> nd: Note 2: As this is a blind focusing WFS algorithm, there may be only one target or 'output mode'. Args: - wfs_result (WFSResult): The result of the WFS algorithm. This contains the transmission matrix in the space - of input modes. - mode_set: 3D array with set of modes. + t: transmission matrix in mode-index space. The first axis corresponds to the input modes. + cobasis: 3D array with set of modes (conjugated) + Returns: + nd: The transmission matrix in SLM space. The last two axes correspond to SLM coordinates """ - t = wfs_result.t.squeeze().reshape((1, 1, mode_set.shape[2])) - norm_factor = np.prod(mode_set.shape[0:2]) - t_set = (t * mode_set).sum(axis=2) / norm_factor - return t_set + norm_factor = np.prod(cobasis.shape[0:2]) + return np.tensordot(cobasis, t, 1) / norm_factor diff --git a/openwfs/algorithms/fourier.py b/openwfs/algorithms/fourier.py deleted file mode 100644 index 71f52c8..0000000 --- a/openwfs/algorithms/fourier.py +++ /dev/null @@ -1,203 +0,0 @@ -from typing import Optional - -import numpy as np - -from .utilities import analyze_phase_stepping, WFSResult -from ..core import Detector, PhaseSLM -from ..utilities.patterns import tilt - - -class FourierBase: - """Base class definition for the Fourier algorithms as described in [1]. - - This algorithm optimises the wavefront in a Fourier-basis. The modes that are tested are provided into a 'k-space' - of which each 'k-vector' represents a certain angled wavefront that will be tested. (more detailed explanation is - found in _get_phase_pattern). - - As described in [1], these modes are measured by interfering a certain mode on one half of the SLM with a - 'reference beam'. This is done by not modulating the other half of the SLM. In order to find a full corrective - wavefront therefore, the experiment has to be repeated twice for each side of the SLM. Finally, the two wavefronts - are combined. - - [1]: Bahareh Mastiani, Gerwin Osnabrugge, and Ivo M. Vellekoop, - "Wavefront shaping for forward scattering," Opt. Express 30, 37436-37445 (2022) - """ - - def __init__(self, feedback: Detector, slm: PhaseSLM, slm_shape: tuple[int, int], k_left: np.ndarray, - k_right: np.ndarray, phase_steps: int = 4, analyzer: Optional[callable] = analyze_phase_stepping): - """ - - Args: - feedback (Detector): The feedback source, usually a detector that provides measurement data. - slm (PhaseSLM): slm object. - The slm may have the `extent` property set to indicate the extent of the back pupil of the microscope - objective in slm coordinates. By default, a value of 2.0, 2.0 is used (indicating that the pupil - corresponds to a circle of radius 1.0 on the SLM). However, to prevent artefacts at the edges of the - SLM,it may be overfilled, such that the `phases` image is mapped to an extent of e. g. (2.2, 2.2), i. e. - 10% larger than the back pupil. - slm_shape (tuple[int, int]): The shape of the SLM patterns and transmission matrices. - k_left (numpy.ndarray): 2-row matrix containing the y, and x components of the spatial frequencies - used as a basis for the left-hand side of the SLM. - The frequencies are defined such that a frequency of (1,0) or (0,1) corresponds to - a phase gradient of -π to π over the back pupil of the microscope objective, which results in - a displacement in the focal plane of exactly a distance corresponding to the Abbe diffraction limit. - k_right (numpy.ndarray): 2-row matrix containing the y and x components of the spatial frequencies - for the right-hand side of the SLM. - The number of frequencies need not be equal for k_left and k_right. - phase_steps (int): The number of phase steps for each mode (default is 4). - analyzer (callable): The function used to analyze the phase stepping data. Must return a WFSResult object. - """ - self._execute_button = False - self.slm = slm - self.feedback = feedback - self.phase_steps = phase_steps - self.k_left = k_left - self.k_right = k_right - self.slm_shape = slm_shape - self.analyzer = analyzer - - def execute(self) -> WFSResult: - """ - Executes the FourierDualRef algorithm, computing the SLM transmission matrix. - - Returns: - WFSResult: An object containing the computed SLM transmission matrix and related data. - """ - # left side experiment - data_left = self._single_side_experiment(self.k_left, 0) - - # right side experiment - data_right = self._single_side_experiment(self.k_right, 1) - - # Compute transmission matrix (=field at SLM), as well as noise statistics - results = self.compute_t(data_left, data_right, self.k_left, self.k_right) - results.left = data_left - results.right = data_right - results.k_left = self.k_left - results.k_right = self.k_right - return results - - def _single_side_experiment(self, k_set: np.ndarray, side: int) -> WFSResult: - """ - Conducts experiments on one side of the SLM, generating measurements for each spatial frequency and phase step. - - Args: - k_set (np.ndarray): An array of spatial frequencies to use in the experiment. - side (int): Indicates which side of the SLM to use (0 for the left hand side, 1 for right hand side). - - Returns: - WFSResult: An object containing the computed SLM transmission matrix and related data. - """ - measurements = np.zeros((k_set.shape[1], self.phase_steps, *self.feedback.data_shape)) - - for i in range(k_set.shape[1]): - for p in range(self.phase_steps): - phase_offset = p * 2 * np.pi / self.phase_steps - phase_pattern = self._get_phase_pattern(k_set[:, i], phase_offset, side) - self.slm.set_phases(phase_pattern) - self.feedback.trigger(out=measurements[i, p, ...]) - - self.feedback.wait() - return self.analyzer(measurements, axis=1) - - def _get_phase_pattern(self, - k: np.ndarray, - phase_offset: float, - side: int) -> np.ndarray: - """ - Generates a phase pattern for the SLM based on the given spatial frequency, phase offset, and side. - - Args: - k (np.ndarray): A 2-element array representing the spatial frequency over the whole pupil plane. - phase_offset (float): The phase offset to apply to the pattern. - side (int): Indicates the side of the SLM for the pattern (0 for left, 1 for right). - - Returns: - np.ndarray: The generated phase pattern. - """ - # tilt generates a pattern from -2.0 to 2.0 (The convention for Zernike modes normalized to an RMS of 1). - # The natural step to take is the Abbe diffraction limit of the modulated part, which corresponds to a gradient - # from -π to π over the modulated part. - num_columns = self.slm_shape[1] // 2 - tilted_front = tilt([self.slm_shape[0], num_columns], k * (0.5 * np.pi), extent=(2.0, 1.0), - phase_offset=phase_offset) - - # Handle side-dependent pattern - - empty_part = np.zeros((self.slm_shape[0], self.slm_shape[1] - num_columns)) - - # Concatenate based on the side - if side == 0: - # Place the pattern on the left - result = np.concatenate((tilted_front, empty_part), axis=1) - else: - # Place the pattern on the right - result = np.concatenate((empty_part, tilted_front), axis=1) - - return result - - def compute_t(self, left: WFSResult, right: WFSResult, k_left, k_right) -> WFSResult: - """ - Computes the SLM transmission matrix by combining the Fourier transmission matrices from both sides of the SLM. - - Args: - left (WFSResult): The wavefront shaping result for the left side. - right (WFSResult): The wavefront shaping result for the right side. - k_left (np.ndarray): The spatial frequency matrix for the left side. - k_right (np.ndarray): The spatial frequency matrix for the right side. - - Returns: - WFSResult: An object containing the combined transmission matrix and related statistics. - """ - - # TODO: determine noise - # Initialize transmission matrices - t1 = np.zeros((*self.slm_shape, *self.feedback.data_shape), dtype='complex128') - t2 = np.zeros((*self.slm_shape, *self.feedback.data_shape), dtype='complex128') - - # Calculate phase difference between the two halves - # We have two phase stepping measurements where both halves are flat (k=0) - # Locate these measurements, and use the phase difference between them to connect the two halves - # of the corrected wavefront. - - # Find the index of the (0,0) mode in k_left and k_right - index_0_left = np.argmin(k_left[0] ** 2 + k_left[1] ** 2) - index_0_right = np.argmin(k_right[0] ** 2 + k_left[1] ** 2) - if not np.all(k_left[:, index_0_left] == 0.0) or not np.all(k_right[:, index_0_right] == 0.0): - raise Exception("k=(0,0) component missing from the measurement set, cannot determine relative phase.") - - # average the measurements for better accuracy - # TODO: absolute values are not the same in simulation, 'A' scaling is off? - relative = 0.5 * (left.t[index_0_left, ...] + np.conjugate(right.t[index_0_right, ...])) - - # Apply phase correction to the right side - phase_correction = relative / np.abs(relative) - - # Construct the transmission matrices - normalisation = 1.0 / (0.5 * self.slm_shape[0] * self.slm_shape[1]) - for n, t in enumerate(left.t): - phi = self._get_phase_pattern(k_left[:, n], 0, 0) - t1 += np.tensordot(np.exp(-1j * phi), t * normalisation, 0) - - for n, t in enumerate(right.t): - phi = self._get_phase_pattern(k_right[:, n], 0, 1) - t2 += np.tensordot(np.exp(-1j * phi), t * (normalisation * phase_correction), 0) - - # Combine the left and right sides - t_full = np.concatenate([t1[:, :self.slm_shape[0] // 2, ...], t2[:, self.slm_shape[0] // 2:, ...]], axis=1) - t_f_full = np.concatenate([left.t_f, right.t_f], - axis=1) # also store raw data (not normalized or corrected yet!) - - # return combined result, along with a course estimate of the snr and expected enhancement - # TODO: not accurate yet - # for the estimated_improvement, first convert to field improvement, then back to intensity improvement - def weighted_average(x_left, x_right): - return (left.n * x_left + right.n * x_right) / (left.n + right.n) - - return WFSResult(t=t_full, - t_f=t_f_full, - n=left.n + right.n, - axis=2, - fidelity_noise=weighted_average(left.fidelity_noise, right.fidelity_noise), - fidelity_amplitude=weighted_average(left.fidelity_amplitude, right.fidelity_amplitude), - fidelity_calibration=weighted_average(left.fidelity_calibration, right.fidelity_calibration)) diff --git a/openwfs/algorithms/utilities.py b/openwfs/algorithms/utilities.py index 95f54b8..e565165 100644 --- a/openwfs/algorithms/utilities.py +++ b/openwfs/algorithms/utilities.py @@ -124,13 +124,13 @@ def combine(results: Sequence['WFSResult']): raise ValueError("All results must have the same axis") def weighted_average(attribute): - data = getattr(results[0], attribute) * results[0].n + data = getattr(results[0], attribute) * results[0].n / n for r in results[1:]: data += getattr(r, attribute) * r.n / n return data return WFSResult(t=weighted_average('t'), - t_f=None, + t_f=weighted_average('t_f'), n=n, axis=axis, fidelity_noise=weighted_average('fidelity_noise'), diff --git a/tests/test_wfs.py b/tests/test_wfs.py index 464616f..26b45dc 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -7,8 +7,8 @@ from scipy.ndimage import zoom from skimage.transform import resize -from ..openwfs.algorithms import StepwiseSequential, FourierDualReference, FourierDualReferenceCircle, \ - IterativeDualReference, troubleshoot +from ..openwfs.algorithms import StepwiseSequential, FourierDualReference, \ + DualReference, troubleshoot from ..openwfs.algorithms.troubleshoot import field_correlation from ..openwfs.algorithms.utilities import WFSController from ..openwfs.processors import SingleRoi @@ -24,7 +24,7 @@ def assert_enhancement(slm, feedback, wfs_results, t_correct=None): slm.set_phases(optimised_wf) after = feedback.read() ratio = after / before - estimated_ratio = wfs_results.estimated_optimized_intensity / before + estimated_ratio = wfs_results.estimated_enhancement # wfs_results.estimated_optimized_intensity / before print(f"expected: {estimated_ratio}, actual: {ratio}") assert estimated_ratio * 0.5 <= ratio <= estimated_ratio * 2.0, f""" The SSA algorithm did not enhance the focus as much as expected. @@ -153,8 +153,7 @@ def test_fourier(n_x): """ aberrations = skimage.data.camera() * (2.0 * np.pi / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_angles_min=-n_x, - k_angles_max=n_x, + alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=n_x, phase_steps=4) results = alg.execute() assert_enhancement(sim.slm, sim, results, np.exp(1j * aberrations)) @@ -165,9 +164,7 @@ def test_fourier2(): slm_shape = (1000, 1000) aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_angles_min=-5, - k_angles_max=5, - phase_steps=3) + alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=7.5, phase_steps=3) controller = WFSController(alg) controller.wavefront = WFSController.State.SHAPED_WAVEFRONT scaled_aberration = zoom(aberrations, np.array(slm_shape) / aberrations.shape) @@ -197,8 +194,27 @@ def test_fourier_circle(k_radius, g): """ aberrations = tilt(shape=(100, 100), extent=(2, 2), g=g, phase_offset=0.5) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReferenceCircle(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=k_radius, - phase_steps=4) + alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=k_radius, + phase_steps=4) + + do_debug = False + if do_debug: + # Plot the modes + import matplotlib.pyplot as plt + plt.figure(figsize=(12, 7)) + patterns = alg.phase_patterns[0] * np.expand_dims(alg.masks[0], axis=-1) + N = patterns.shape[2] + Nsqrt = int(np.ceil(np.sqrt(N))) + for m in range(N): + plt.subplot(Nsqrt, Nsqrt, m + 1) + plt.imshow(np.cos(patterns[:, :, m]), vmin=-1.0, vmax=1.0) + plt.title(f'm={m}') + plt.xticks([]) + plt.yticks([]) + plt.colorbar() + plt.pause(0.01) + plt.suptitle('Phase of basis functions for one half') + results = alg.execute() assert_enhancement(sim.slm, sim, results, np.exp(1j * aberrations)) @@ -218,7 +234,7 @@ def test_fourier_microscope(): wavelength=800 * u.nm) cam = sim.get_camera(analog_max=100) roi_detector = SingleRoi(cam, pos=(250, 250)) # Only measure that specific point - alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_angles_min=-1, k_angles_max=1, + alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_radius=1.5, phase_steps=3) controller = WFSController(alg) controller.wavefront = WFSController.State.FLAT_WAVEFRONT @@ -237,8 +253,7 @@ def test_fourier_correction_field(): """ aberrations = skimage.data.camera() * (2.0 * np.pi / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_angles_min=-2, - k_angles_max=2, + alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=3.0, phase_steps=3) t = alg.execute().t @@ -253,11 +268,11 @@ def test_phase_shift_correction(): """ Test the effect of shifting the found correction of the Fourier-based algorithm. Without the bug, a phase shift of the entire correction should not influence the measurement. + TODO: move to test of SimulatedWFS, since it is not testing the WFS algorithm itself """ aberrations = skimage.data.camera() * (2.0 * np.pi / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_angles_min=-1, - k_angles_max=1, + alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=1.5, phase_steps=3) t = alg.execute().t @@ -275,28 +290,34 @@ def test_phase_shift_correction(): signal = sim.read() signals.append(signal) - assert np.std(signals) < 0.0001 * before, f"""The simulated response of the Fourier algorithm is sensitive to a - flat - phase-shift. This is incorrect behaviour""" + assert np.std(signals) / np.mean(signals) < 0.001, f"""The response of SimulatedWFS is sensitive to a flat + phase shift. This is incorrect behaviour""" -def test_flat_wf_response_fourier(): +@pytest.mark.parametrize("optimized_reference", [True, False]) +@pytest.mark.parametrize("step", [True, False]) +def test_flat_wf_response_fourier(optimized_reference, step): """ Test the response of the Fourier-based WFS method when the solution is flat A flat solution means that the optimal correction is no correction. + Also tests if stitching is done correctly by having an aberration pattern which is flat (but different) on the two halves. test the optimized wavefront by checking if it has irregularities. """ - aberrations = np.zeros(shape=(512, 512)) + aberrations = np.ones(shape=(4, 4)) + if step: + aberrations[:, 2:] = 2.0 sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_angles_min=-1, - k_angles_max=1, phase_steps=3) + alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=1.5, phase_steps=3, + optimized_reference=optimized_reference) t = alg.execute().t # test the optimized wavefront by checking if it has irregularities. - assert np.std(t) < 0.001 # The measured wavefront is not flat. + measured_aberrations = np.squeeze(np.angle(t)) + measured_aberrations += aberrations[0, 0] - measured_aberrations[0, 0] + assert np.allclose(measured_aberrations, aberrations, atol=0.02) # The measured wavefront is not flat. def test_flat_wf_response_ssa(): @@ -345,7 +366,7 @@ def test_multidimensional_feedback_fourier(): sim = SimulatedWFS(aberrations=aberrations) # input the camera as a feedback object, such that it is multidimensional - alg = FourierDualReference(feedback=sim, slm=sim.slm, k_angles_min=-1, k_angles_max=1, phase_steps=3) + alg = FourierDualReference(feedback=sim, slm=sim.slm, k_radius=3.5, phase_steps=3) t = alg.execute().t # compute the phase pattern to optimize the intensity in target 0 @@ -476,10 +497,10 @@ def test_custom_blind_dual_reference_ortho_split(construct_basis: callable): sim = SimulatedWFS(aberrations=aberrations) - alg = IterativeDualReference(feedback=sim, slm=sim.slm, - phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, - phase_steps=4, - iterations=4) + alg = DualReference(feedback=sim, slm=sim.slm, + phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, + phase_steps=4, + iterations=4) result = alg.execute() @@ -518,7 +539,7 @@ def test_custom_blind_dual_reference_non_ortho(): plt.figure(figsize=(12, 7)) for m in range(M): plt.subplot(N2, N1, m + 1) - plt.imshow(np.angle(mode_set[:, :, m]), vmin=-np.pi, vmax=np.pi) + plt.imshow(phases_set[:, :, m], vmin=-np.pi, vmax=np.pi) plt.title(f'm={m}') plt.xticks([]) plt.yticks([]) @@ -534,10 +555,10 @@ def test_custom_blind_dual_reference_non_ortho(): sim = SimulatedWFS(aberrations=aberrations) - alg = IterativeDualReference(feedback=sim, slm=sim.slm, - phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, - phase_steps=4, - iterations=4) + alg = DualReference(feedback=sim, slm=sim.slm, + phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, + phase_steps=4, + iterations=4) result = alg.execute() From 679bddb7149b5343d07b5357b0ce9b50d4e23083 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Mon, 30 Sep 2024 16:14:21 +0200 Subject: [PATCH 05/15] translating last uses of old code --- examples/wfs_demonstration_experimental.py | 2 +- tests/test_wfs.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/wfs_demonstration_experimental.py b/examples/wfs_demonstration_experimental.py index ef845c3..44eb784 100644 --- a/examples/wfs_demonstration_experimental.py +++ b/examples/wfs_demonstration_experimental.py @@ -24,7 +24,7 @@ # we are using a setup with an SLM that produces 2pi phase shift # at a gray value of 142 slm.lookup_table = range(142) -alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=[800, 800], k_angles_min=-5, k_angles_max=5) +alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=[800, 800], k_radius=7) result = alg.execute() print(result) diff --git a/tests/test_wfs.py b/tests/test_wfs.py index 26b45dc..cc27d87 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -178,8 +178,7 @@ def test_fourier3(): slm_shape = (32, 32) aberrations = np.random.uniform(0.0, 2 * np.pi, slm_shape) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_angles_min=-32, - k_angles_max=32, + alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=45, phase_steps=3) controller = WFSController(alg) controller.wavefront = WFSController.State.SHAPED_WAVEFRONT From 20d903a89da74f48f5ca0e46b698ea26009180a9 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 13:23:59 +0200 Subject: [PATCH 06/15] added conversion factor read-only attribute to ADC converter for convenience --- openwfs/simulation/mockdevices.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/openwfs/simulation/mockdevices.py b/openwfs/simulation/mockdevices.py index 1f002a7..32f7b13 100644 --- a/openwfs/simulation/mockdevices.py +++ b/openwfs/simulation/mockdevices.py @@ -176,6 +176,11 @@ def digital_max(self) -> int: """ return self._digital_max + @property + def conversion_factor(self) -> float: + """Conversion factor between analog and digital values.""" + return self.digital_max / self.analog_max + @digital_max.setter def digital_max(self, value): if value < 0 or value > 0xFFFF: From 1afb8a549daf24e5f492ce90d2a787f2b75a4fb9 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 13:27:34 +0200 Subject: [PATCH 07/15] made t attribute of SimulatedWFS public --- openwfs/simulation/transmission.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/openwfs/simulation/transmission.py b/openwfs/simulation/transmission.py index 88e5a85..aed1ecb 100644 --- a/openwfs/simulation/transmission.py +++ b/openwfs/simulation/transmission.py @@ -43,8 +43,8 @@ def __init__(self, *, t: Optional[np.ndarray] = None, aberrations: Optional[np.n """ # transmission matrix (normalized so that the maximum transmission is 1) - self._t = t if t is not None else np.exp(1.0j * aberrations) / (aberrations.shape[0] * aberrations.shape[1]) - self.slm = slm if slm is not None else SLM(self._t.shape[0:2]) + self.t = t if t is not None else np.exp(1.0j * aberrations) / (aberrations.shape[0] * aberrations.shape[1]) + self.slm = slm if slm is not None else SLM(self.t.shape[0:2]) super().__init__(self.slm.field, multi_threaded=multi_threaded) self.beam_amplitude = beam_amplitude @@ -64,9 +64,9 @@ def _fetch(self, incident_field): # noqa np.ndarray: A numpy array containing the calculated intensity in the focus. """ - field = np.tensordot(incident_field * self.beam_amplitude, self._t, 2) + field = np.tensordot(incident_field * self.beam_amplitude, self.t, 2) return np.abs(field) ** 2 @property def data_shape(self): - return self._t.shape[2:] + return self.t.shape[2:] From 284c83165cb753c90878ce7b77cfd28f4e2318dd Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 13:28:23 +0200 Subject: [PATCH 08/15] implemented GaussianNoise processor for adding noise to a signal --- openwfs/simulation/mockdevices.py | 33 +++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/openwfs/simulation/mockdevices.py b/openwfs/simulation/mockdevices.py index 32f7b13..3f1f593 100644 --- a/openwfs/simulation/mockdevices.py +++ b/openwfs/simulation/mockdevices.py @@ -344,3 +344,36 @@ def open(self, value: bool): def _fetch(self, source: np.ndarray) -> np.ndarray: # noqa return source if self._open else 0.0 * source + + +class GaussianNoise(Processor): + """Adds gaussian noise of a specified standard deviation to the signal + Args: + source (Detector): The source detector object to process the data from. + std (float): The standard deviation of the gaussian noise. + multi_threaded: Whether to perform processing in a worker thread. + """ + + def __init__(self, source: Detector, std: float, multi_threaded: bool = True): + super().__init__(source, multi_threaded=multi_threaded) + self._std = std + + @property + def std(self) -> float: + return self._std + + @std.setter + def std(self, value: float): + if value < 0.0: + raise ValueError("Standard deviation must be non-negative") + self._std = float(value) + + def _fetch(self, data: np.ndarray) -> np.ndarray: # noqa + """ + Args: + data (ndarray): source data + + Returns: the out array containing the image with added noise. + + """ + return data + np.random.normal(0.0, self.std, data.shape) From 4a26f5f93f3de75c7404a42d22da8c99564f2883 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 13:28:37 +0200 Subject: [PATCH 09/15] added missing documentation line --- openwfs/core.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/openwfs/core.py b/openwfs/core.py index 8657b37..3aba36a 100644 --- a/openwfs/core.py +++ b/openwfs/core.py @@ -507,6 +507,11 @@ class Processor(Detector, ABC): The `latency` and `duration` properties are computed from the latency and duration of the inputs and cannot be set. By default, the `pixel_size` and `data_shape` are the same as the `pixel_size` and `data_shape` of the first input. To override this behavior, override the `pixel_size` and `data_shape` properties. + + Args: + multi_threaded: If True, `_fetch` is called from a worker thread. Otherwise, `_fetch` is called + directly from `trigger`. If the device is not thread-safe, or threading provides no benefit, + or for easy debugging, set this to False. """ def __init__(self, *args, multi_threaded: bool): From 935f7440484882bfe9aae6af0b051e17baa7fccf Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 13:33:20 +0200 Subject: [PATCH 10/15] fixed noise fidelity estimation and implemented proper test. Other tests need fixing. --- openwfs/algorithms/utilities.py | 13 +- tests/test_wfs.py | 268 +++++++++++++------------------- 2 files changed, 118 insertions(+), 163 deletions(-) diff --git a/openwfs/algorithms/utilities.py b/openwfs/algorithms/utilities.py index e565165..abd0ea4 100644 --- a/openwfs/algorithms/utilities.py +++ b/openwfs/algorithms/utilities.py @@ -215,12 +215,17 @@ def analyze_phase_stepping(measurements: np.ndarray, axis: int, A: Optional[floa # (which occurs twice, ideally in the +1 and -1 components of the Fourier transform), # but this factor of two is already included in the 'signal_energy' calculation. # an offset, and the rest is noise. - offset_energy = np.sum(np.take(t_f, 0, axis=axis) ** 2) - total_energy = np.sum(np.abs(t_f) ** 2) - + # average over all targets to get the most accurate result (assuming all targets are similar) + axes = tuple([i for i in range(t_f.ndim) if i != axis]) + energies = np.sum(np.abs(t_f) ** 2, axis=axes) + offset_energy = energies[0] + total_energy = np.sum(energies) + signal_energy = energies[1] + energies[-1] if phase_steps > 3: + # estimate the noise energy as the energy that is not explained + # by the signal or the offset. noise_energy = (total_energy - signal_energy - offset_energy) / (phase_steps - 3) - noise_factor = np.abs(np.maximum(signal_energy - noise_energy, 0.0) / signal_energy) + noise_factor = np.abs(np.maximum(signal_energy - 2 * noise_energy, 0.0) / signal_energy) else: noise_factor = 1.0 # cannot estimate reliably diff --git a/tests/test_wfs.py b/tests/test_wfs.py index cc27d87..c328905 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -2,21 +2,20 @@ import numpy as np import pytest import skimage -from numpy import ndarray as nd from scipy.linalg import hadamard from scipy.ndimage import zoom -from skimage.transform import resize +from openwfs.simulation.mockdevices import GaussianNoise from ..openwfs.algorithms import StepwiseSequential, FourierDualReference, \ DualReference, troubleshoot from ..openwfs.algorithms.troubleshoot import field_correlation from ..openwfs.algorithms.utilities import WFSController from ..openwfs.processors import SingleRoi -from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope, ADCProcessor, Shutter +from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope, Shutter from ..openwfs.utilities import set_pixel_size, tilt -def assert_enhancement(slm, feedback, wfs_results, t_correct=None): +def assert_enhancement(slm, feedback, wfs_results): """Helper function to check if the intensity in the target focus increases as much as expected""" optimised_wf = -np.angle(wfs_results.t) slm.set_phases(0.0) @@ -38,125 +37,115 @@ def assert_enhancement(slm, feedback, wfs_results, t_correct=None): assert corr > 1.0 - 2.0 / np.sqrt(wfs_results.n) -def half_plane_wave_basis(N1: int, N2: int) -> nd: +@pytest.mark.parametrize("shape", [(4, 7), (10, 7), (20, 31)]) +@pytest.mark.parametrize("noise", [0.0, 0.1]) +def test_ssa(shape, noise: float): """ - Create a plane wave basis for one half of the SLM. + Test the SSA algorithm. - N1: shape[0] of SLM pattern half - N2: shape[1] of SLM pattern half + This tests checks if the algorithm achieves the theoretical enhancement, + and it also verifies that the enhancement and noise fidelity + are estimated correctly by the algorithm. """ - M = N1 * N2 - return np.fft.fft2(np.eye(M).reshape((N1, N2, M)), axes=(0, 1)) - - -def half_hadamard_basis(N1: int, N2: int) -> nd: - """ - Create a Hadamard basis for one half of the SLM. N1 and N2 must be powers of 2. - - N1: shape[0] of SLM pattern half - N2: shape[1] of SLM pattern half - """ - M = N1 * N2 - return hadamard(M, dtype=np.complex128).reshape((N1, N2, M)) + np.random.seed(42) # for reproducibility + M = 100 # number of targets + phase_steps = 6 -def half_split_mask(N1: int, N2: int) -> nd: - """ - Create a mask that splits the slm in the center in a left and right half. - - N1: shape[0] of slm pattern half - N2: shape[1] of slm pattern half - """ - return np.concatenate((np.zeros((N1, N2)), np.ones((N1, N2))), axis=1) + N = np.prod(shape) # number of input modes + sim = SimulatedWFS(t=random_transmission_matrix((*shape, M))) + I_0 = np.mean(sim.read()) + # create feedback object, with noise if needed + if noise > 0.0: + sim.slm.set_phases(0.0) + feedback = GaussianNoise(sim, std=I_0 * noise) + signal = (N - 1) / N ** 2 + theoretical_noise_fidelity = signal / (signal + noise ** 2 / phase_steps) + else: + feedback = sim + theoretical_noise_fidelity = 1.0 -@pytest.mark.parametrize("n_y, n_x", [(5, 5), (7, 11), (6, 4), (30, 20)]) -def test_ssa(n_y, n_x): - """ - Test the enhancement performance of the SSA algorithm. - Note, for low N, the improvement estimate is not accurate, - and the test may sometimes fail due to statistical fluctuations. - """ - aberrations = np.random.uniform(0.0, 2 * np.pi, (n_y, n_x)) - sim = SimulatedWFS(aberrations=aberrations) - alg = StepwiseSequential(feedback=sim, slm=sim.slm, n_x=n_x, n_y=n_y, phase_steps=4) - result = alg.execute() - print(np.mean(np.abs(result.t))) - assert_enhancement(sim.slm, sim, result, np.exp(1j * aberrations)) - - -@pytest.mark.parametrize("n_y, n_x", [(5, 5), (7, 11), (6, 4)]) -def test_ssa_noise(n_y, n_x): - """ - Test the enhancement prediction with noisy SSA. - - Note: this test fails if a smooth image is shown, indicating that the estimators - only work well for strong scattering at the moment. - """ - generator = np.random.default_rng(seed=12345) - aberrations = generator.uniform(0.0, 2 * np.pi, (n_y, n_x)) - sim_no_noise = SimulatedWFS(aberrations=aberrations) - slm = sim_no_noise.slm - scale = np.max(sim_no_noise.read()) - sim = ADCProcessor(sim_no_noise, analog_max=scale * 200.0, digital_max=10000, shot_noise=True, generator=generator) - alg = StepwiseSequential(feedback=sim, slm=slm, n_x=n_x, n_y=n_y, phase_steps=10) + # Execute the SSA algorithm to get the optimized wavefront + # for all targets simultaneously + alg = StepwiseSequential(feedback=feedback, slm=sim.slm, n_x=shape[1], n_y=shape[0], phase_steps=phase_steps) + alg_fidelity = (N - 1) / N # SSA is inaccurate if N is low result = alg.execute() - print(result.fidelity_noise) - - assert_enhancement(slm, sim, result) - -def test_ssa_enhancement(): - input_shape = (40, 40) - output_shape = (200, 200) # todo: resize - rng = np.random.default_rng(seed=12345) - - def get_random_aberrations(): - return resize(rng.uniform(size=input_shape) * 2 * np.pi, output_shape, order=0) + # Determine the optimized intensities in each of the targets individually + # Also estimate the fidelity of the transmission matrix reconstruction + # This fidelity is determined row by row, since we need to compensate + # the unknown phases. The normalization of the correlation function + # is performed on all rows together, not per row, to increase + # the accuracy of the estimate. + I_opt = np.zeros((M,)) + t_correlation = 0.0 + t_norm = 0.0 + for b in range(M): + sim.slm.set_phases(-np.angle(result.t[:, :, b])) + I_opt[b] = feedback.read()[b] + t_correlation += abs(np.vdot(result.t[:, :, b], sim.t[:, :, b])) ** 2 + t_norm += np.vdot(result.t[:, :, b], result.t[:, :, b]) * np.vdot(sim.t[:, :, b], sim.t[:, :, b]) + t_correlation /= t_norm + + # Check the enhancement, noise fidelity and + # the fidelity of the transmission matrix reconstruction + enhancement = I_opt.mean() / I_0 + theoretical_enhancement = np.pi / 4 * theoretical_noise_fidelity * alg_fidelity * (N - 1) + 1 + estimated_enhancement = result.estimated_enhancement.mean() * alg_fidelity + theoretical_t_correlation = theoretical_noise_fidelity * alg_fidelity + estimated_t_correlation = result.fidelity_noise * result.fidelity_calibration * alg_fidelity + tolerance = 2.0 / np.sqrt(M) + print( + f"\nenhancement: \ttheoretical= {theoretical_enhancement},\testimated={estimated_enhancement},\tactual: {enhancement}") + print( + f"t-matrix fidelity:\ttheoretical = {theoretical_t_correlation},\testimated = {estimated_t_correlation},\tactual = {t_correlation}") + print(f"noise fidelity: \ttheoretical = {theoretical_noise_fidelity},\testimated = {result.fidelity_noise}") + print(f"comparing at relative tolerance: {tolerance}") - # Define mock hardware and algorithm - slm = SLM(shape=output_shape) + assert np.allclose(enhancement, theoretical_enhancement, rtol=tolerance), f""" + The SSA algorithm did not enhance the focus as much as expected. + Theoretical {theoretical_enhancement}, got {enhancement}""" - # Find average background intensity - unshaped_intensities = np.zeros((30,)) - for n in range(len(unshaped_intensities)): - signal = SimulatedWFS(aberrations=get_random_aberrations(), slm=slm) - unshaped_intensities[n] = signal.read() + assert np.allclose(estimated_enhancement, enhancement, rtol=tolerance), f""" + The SSA algorithm did not estimate the enhancement correctly. + Estimated {estimated_enhancement}, got {enhancement}""" - num_runs = 10 - shaped_intensities_ssa = np.zeros(num_runs) - for r in range(num_runs): - sim = SimulatedWFS(aberrations=get_random_aberrations(), slm=slm) + assert np.allclose(t_correlation, theoretical_t_correlation, rtol=tolerance), f""" + The SSA algorithm did not measure the transmission matrix correctly. + Expected {theoretical_t_correlation}, got {t_correlation}""" - # SSA - print(f'SSA run {r + 1}/{num_runs}') - alg_ssa = StepwiseSequential(feedback=sim, slm=sim.slm, n_x=13, n_y=13, phase_steps=6) - wfs_result_ssa = alg_ssa.execute() - sim.slm.set_phases(-np.angle(wfs_result_ssa.t)) - shaped_intensities_ssa[r] = sim.read() + assert np.allclose(estimated_t_correlation, theoretical_t_correlation, rtol=tolerance), f""" + The SSA algorithm did not estimate the fidelity of the transmission matrix correctly. + Expected {theoretical_t_correlation}, got {estimated_t_correlation}""" - # Compute enhancements and error margins - enhancement_ssa = shaped_intensities_ssa.mean() / unshaped_intensities.mean() - enhancement_ssa_std = shaped_intensities_ssa.std() / unshaped_intensities.mean() + assert np.allclose(result.fidelity_noise, theoretical_noise_fidelity, rtol=tolerance), f""" + The SSA algorithm did not estimate the noise correctly. + Expected {theoretical_noise_fidelity}, got {result.fidelity_noise}""" - print( - f'SSA enhancement (squared signal): {enhancement_ssa:.2f}, std={enhancement_ssa_std:.2f}, with {wfs_result_ssa.n} modes') - assert enhancement_ssa > 100.0 +def random_transmission_matrix(shape): + """ + Create a random transmission matrix with the given shape. + """ + return np.random.normal(size=shape) + 1j * np.random.normal(size=shape) -@pytest.mark.parametrize("n_x", [2, 3]) -def test_fourier(n_x): +@pytest.mark.parametrize("k_radius", [2, 3]) +def test_fourier(k_radius): """ Test the enhancement performance of the Fourier-based algorithm. - Use the 'cameraman' test image since it is relatively smooth. + Check if the estimated enhancement is close to the actual enhancement. + Check if the measured transmission matrix is close to the actual transmission matrix. + For this check, compare two situations: one with a completely random aberration pattern, + and one with a smooth aberration pattern. In the latter case, the measured transmission matrix + should match the actual transmission matrix better than for the completely random one """ - aberrations = skimage.data.camera() * (2.0 * np.pi / 255.0) - sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=n_x, - phase_steps=4) + shape = (16, 15) + sim = SimulatedWFS(t=random_transmission_matrix(shape)) + alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=shape, k_radius=k_radius) results = alg.execute() - assert_enhancement(sim.slm, sim, results, np.exp(1j * aberrations)) + assert_enhancement(sim, results) def test_fourier2(): @@ -425,62 +414,30 @@ def test_ssa_fidelity(gaussian_noise_std): assert np.isclose(trouble.measured_enhancement, trouble.expected_enhancement, rtol=0.2) -def test_ssa_aberration_reconstruction(): - """Test if SSA can closely reconstruct the ground truth aberrations.""" +@pytest.mark.parametrize("type", ('plane_wave', 'hadamard')) +@pytest.mark.parametrize("shape", ((8, 8), (6, 4))) +def test_custom_blind_dual_reference_ortho_split(type: str, shape): + """Test custom blind dual reference with an orthonormal phase-only basis. + Two types of bases are tested: plane waves and Hadamard""" do_debug = False - - n_x = 6 - n_y = 5 - - # Create aberrations - x = np.linspace(-1, 1, n_x).reshape((1, -1)) - y = np.linspace(-1, 1, n_y).reshape((-1, 1)) - aberrations = (np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)) % (2 * np.pi) - aberrations[0:1, :] = 0 - aberrations[:, 0:1] = 0 - - # Initialize simulation and algorithm - sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) - - alg = StepwiseSequential(feedback=sim, slm=sim.slm, n_x=n_x, n_y=n_y, phase_steps=4) - - result = alg.execute() - - if do_debug: - import matplotlib.pyplot as plt - plt.figure() - plt.imshow(np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('Aberrations') - - plt.figure() - plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('t') - plt.colorbar() - plt.show() - - assert np.abs(field_correlation(np.exp(1j * aberrations), result.t)) > 0.99 - - -@pytest.mark.parametrize("construct_basis", (half_plane_wave_basis, half_hadamard_basis)) -def test_custom_blind_dual_reference_ortho_split(construct_basis: callable): - """Test custom blind dual reference with an orthonormal phase-only basis.""" - do_debug = False - - # Create set of phase-only orthonormal modes - N1 = 8 - N2 = 4 - M = N1 * N2 - mode_set_half = construct_basis(N1, N2) - mode_set = np.concatenate((mode_set_half, np.zeros(shape=(N1, N2, M))), axis=1) + N = shape[0] * (shape[1] // 2) + modes_shape = (shape[0], shape[1] // 2, N) + if type == 'plane_wave': + # Create a full plane wave basis for one half of the SLM. + modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) + else: # type == 'hadamard': + modes = hadamard(N).reshape(modes_shape) + + mask = np.concatenate((np.zeros(modes_shape[0:1], dtype=bool), np.ones(modes_shape[0:1], dtype=bool)), axis=1) + mode_set = np.concatenate((modes, np.zeros(shape=modes_shape)), axis=1) phases_set = np.angle(mode_set) - mask = half_split_mask(N1, N2) if do_debug: # Plot the modes import matplotlib.pyplot as plt plt.figure(figsize=(12, 7)) - for m in range(M): - plt.subplot(N2, N1, m + 1) + for m in range(N): + plt.subplot(*modes_shape[0:1], m + 1) plt.imshow(np.angle(mode_set[:, :, m]), vmin=-np.pi, vmax=np.pi) plt.title(f'm={m}') plt.xticks([]) @@ -488,24 +445,17 @@ def test_custom_blind_dual_reference_ortho_split(construct_basis: callable): plt.pause(0.1) # Create aberrations - x = np.linspace(-1, 1, 1 * N1).reshape((1, -1)) - y = np.linspace(-1, 1, 1 * N1).reshape((-1, 1)) - aberrations = (np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)) % (2 * np.pi) - aberrations[0:2, :] = 0 - aberrations[:, 0:2] = 0 - - sim = SimulatedWFS(aberrations=aberrations) + sim = SimulatedWFS(t=random_transmission_matrix(shape)) alg = DualReference(feedback=sim, slm=sim.slm, phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, - phase_steps=4, iterations=4) result = alg.execute() if do_debug: plt.figure() - plt.imshow(np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap='hsv') + plt.imshow(np.angle(sim.t), vmin=-np.pi, vmax=np.pi, cmap='hsv') plt.title('Aberrations') plt.figure() From e8fa459513c8d928ac4f0cc4e60257d3428e07e7 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 14:02:40 +0200 Subject: [PATCH 11/15] cleanup --- openwfs/algorithms/dual_reference.py | 1 - 1 file changed, 1 deletion(-) diff --git a/openwfs/algorithms/dual_reference.py b/openwfs/algorithms/dual_reference.py index a1e554c..f04df3c 100644 --- a/openwfs/algorithms/dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -175,7 +175,6 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) relative = results_all[0].t[self.zero_indices[0], ...] + np.conjugate( results_all[1].t[self.zero_indices[1], ...]) factor = (relative / np.abs(relative)).reshape((1, *self.feedback.data_shape)) - print(np.angle(factor)) t_full = (self.compute_t_set(results_all[0].t, cobasis[0]) + self.compute_t_set(factor * results_all[1].t, cobasis[1])) From 6771d873fe92cc6070a3f3b520b17bc109e450b2 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 14:20:27 +0200 Subject: [PATCH 12/15] cleanup of tests --- tests/test_wfs.py | 186 ++++++++-------------------------------------- 1 file changed, 33 insertions(+), 153 deletions(-) diff --git a/tests/test_wfs.py b/tests/test_wfs.py index c328905..5e422d7 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -5,43 +5,21 @@ from scipy.linalg import hadamard from scipy.ndimage import zoom -from openwfs.simulation.mockdevices import GaussianNoise from ..openwfs.algorithms import StepwiseSequential, FourierDualReference, \ - DualReference, troubleshoot + DualReference from ..openwfs.algorithms.troubleshoot import field_correlation from ..openwfs.algorithms.utilities import WFSController from ..openwfs.processors import SingleRoi -from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope, Shutter -from ..openwfs.utilities import set_pixel_size, tilt - - -def assert_enhancement(slm, feedback, wfs_results): - """Helper function to check if the intensity in the target focus increases as much as expected""" - optimised_wf = -np.angle(wfs_results.t) - slm.set_phases(0.0) - before = feedback.read() - slm.set_phases(optimised_wf) - after = feedback.read() - ratio = after / before - estimated_ratio = wfs_results.estimated_enhancement # wfs_results.estimated_optimized_intensity / before - print(f"expected: {estimated_ratio}, actual: {ratio}") - assert estimated_ratio * 0.5 <= ratio <= estimated_ratio * 2.0, f""" - The SSA algorithm did not enhance the focus as much as expected. - Expected at least 0.5 * {estimated_ratio}, got {ratio}""" - - if t_correct is not None: - # Check if we correctly measured the transmission matrix. - # The correlation will be less for fewer segments, hence an (ad hoc) factor of 2/sqrt(n) - t = wfs_results.t[:] - corr = np.abs(np.vdot(t_correct, t) / np.sqrt(np.vdot(t_correct, t_correct) * np.vdot(t, t))) - assert corr > 1.0 - 2.0 / np.sqrt(wfs_results.n) +from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope +from ..openwfs.simulation.mockdevices import GaussianNoise @pytest.mark.parametrize("shape", [(4, 7), (10, 7), (20, 31)]) @pytest.mark.parametrize("noise", [0.0, 0.1]) -def test_ssa(shape, noise: float): +@pytest.mark.parametrize("algorithm", ['ssa', 'fourier']) +def test_multi_target_algorithms(shape, noise: float, algorithm: str): """ - Test the SSA algorithm. + Test the multi-target capable algorithms (SSA and Fourier dual ref). This tests checks if the algorithm achieves the theoretical enhancement, and it also verifies that the enhancement and noise fidelity @@ -52,24 +30,26 @@ def test_ssa(shape, noise: float): M = 100 # number of targets phase_steps = 6 - N = np.prod(shape) # number of input modes + # create feedback object, with noise if needed sim = SimulatedWFS(t=random_transmission_matrix((*shape, M))) + sim.slm.set_phases(0.0) I_0 = np.mean(sim.read()) - - # create feedback object, with noise if needed - if noise > 0.0: - sim.slm.set_phases(0.0) - feedback = GaussianNoise(sim, std=I_0 * noise) - signal = (N - 1) / N ** 2 - theoretical_noise_fidelity = signal / (signal + noise ** 2 / phase_steps) - else: - feedback = sim - theoretical_noise_fidelity = 1.0 - - # Execute the SSA algorithm to get the optimized wavefront + feedback = GaussianNoise(sim, std=I_0 * noise) + + if algorithm == 'ssa': + alg = StepwiseSequential(feedback=feedback, slm=sim.slm, n_x=shape[1], n_y=shape[0], phase_steps=phase_steps) + N = np.prod(shape) # number of input modes + alg_fidelity = (N - 1) / N # SSA is inaccurate if N is low + signal = (N - 1) / N ** 2 # for estimating SNR + else: # 'fourier': + alg = FourierDualReference(feedback=feedback, slm=sim.slm, slm_shape=shape, k_radius=(np.min(shape) - 1) // 2, + phase_steps=phase_steps) + N = alg.phase_patterns[0].shape[2] + alg.phase_patterns[1].shape[2] # number of input modes + alg_fidelity = 1.0 # Fourier is accurate for any N + signal = 1 / 2 # for estimating SNR. + + # Execute the algorithm to get the optimized wavefront # for all targets simultaneously - alg = StepwiseSequential(feedback=feedback, slm=sim.slm, n_x=shape[1], n_y=shape[0], phase_steps=phase_steps) - alg_fidelity = (N - 1) / N # SSA is inaccurate if N is low result = alg.execute() # Determine the optimized intensities in each of the targets individually @@ -85,11 +65,15 @@ def test_ssa(shape, noise: float): sim.slm.set_phases(-np.angle(result.t[:, :, b])) I_opt[b] = feedback.read()[b] t_correlation += abs(np.vdot(result.t[:, :, b], sim.t[:, :, b])) ** 2 - t_norm += np.vdot(result.t[:, :, b], result.t[:, :, b]) * np.vdot(sim.t[:, :, b], sim.t[:, :, b]) + t_norm += abs(np.vdot(result.t[:, :, b], result.t[:, :, b]) * np.vdot(sim.t[:, :, b], sim.t[:, :, b])) t_correlation /= t_norm + # a correlation of 1 means optimal reconstruction of the N modulated modes, which may be less than the total number of inputs in the transmission matrix + t_correlation *= np.prod(shape) / N + # Check the enhancement, noise fidelity and # the fidelity of the transmission matrix reconstruction + theoretical_noise_fidelity = signal / (signal + noise ** 2 / phase_steps) enhancement = I_opt.mean() / I_0 theoretical_enhancement = np.pi / 4 * theoretical_noise_fidelity * alg_fidelity * (N - 1) + 1 estimated_enhancement = result.estimated_enhancement.mean() * alg_fidelity @@ -131,23 +115,7 @@ def random_transmission_matrix(shape): return np.random.normal(size=shape) + 1j * np.random.normal(size=shape) -@pytest.mark.parametrize("k_radius", [2, 3]) -def test_fourier(k_radius): - """ - Test the enhancement performance of the Fourier-based algorithm. - Check if the estimated enhancement is close to the actual enhancement. - Check if the measured transmission matrix is close to the actual transmission matrix. - For this check, compare two situations: one with a completely random aberration pattern, - and one with a smooth aberration pattern. In the latter case, the measured transmission matrix - should match the actual transmission matrix better than for the completely random one - """ - shape = (16, 15) - sim = SimulatedWFS(t=random_transmission_matrix(shape)) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=shape, k_radius=k_radius) - results = alg.execute() - assert_enhancement(sim, results) - - +@pytest.mark.skip("Not implemented") def test_fourier2(): """Test the Fourier dual reference algorithm using WFSController.""" slm_shape = (1000, 1000) @@ -160,53 +128,7 @@ def test_fourier2(): assert_enhancement(sim.slm, sim, controller._result, np.exp(1j * scaled_aberration)) -@pytest.mark.skip(reason="This test is is not passing yet and needs further inspection to see if the test itself is " - "correct.") -def test_fourier3(): - """Test the Fourier dual reference algorithm using WFSController.""" - slm_shape = (32, 32) - aberrations = np.random.uniform(0.0, 2 * np.pi, slm_shape) - sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=45, - phase_steps=3) - controller = WFSController(alg) - controller.wavefront = WFSController.State.SHAPED_WAVEFRONT - scaled_aberration = zoom(aberrations, np.array(slm_shape) / aberrations.shape) - assert_enhancement(sim.slm, sim, controller._result, np.exp(1j * scaled_aberration)) - - -@pytest.mark.parametrize("k_radius, g", [[2.5, (1.0, 0.0)], [2.5, (0.0, 2.0)]], ) -def test_fourier_circle(k_radius, g): - """ - Test Fourier dual reference algorithm with a circular k-space, with a tilt 'aberration'. - """ - aberrations = tilt(shape=(100, 100), extent=(2, 2), g=g, phase_offset=0.5) - sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=k_radius, - phase_steps=4) - - do_debug = False - if do_debug: - # Plot the modes - import matplotlib.pyplot as plt - plt.figure(figsize=(12, 7)) - patterns = alg.phase_patterns[0] * np.expand_dims(alg.masks[0], axis=-1) - N = patterns.shape[2] - Nsqrt = int(np.ceil(np.sqrt(N))) - for m in range(N): - plt.subplot(Nsqrt, Nsqrt, m + 1) - plt.imshow(np.cos(patterns[:, :, m]), vmin=-1.0, vmax=1.0) - plt.title(f'm={m}') - plt.xticks([]) - plt.yticks([]) - plt.colorbar() - plt.pause(0.01) - plt.suptitle('Phase of basis functions for one half') - - results = alg.execute() - assert_enhancement(sim.slm, sim, results, np.exp(1j * aberrations)) - - +@pytest.mark.skip("Not implemented") def test_fourier_microscope(): aberration_phase = skimage.data.camera() * ((2 * np.pi) / 255.0) + np.pi aberration = StaticSource(aberration_phase, pixel_size=2.0 / np.array(aberration_phase.shape)) @@ -372,50 +294,8 @@ def test_multidimensional_feedback_fourier(): Expected at least 3.0, got {enhancement}""" -@pytest.mark.parametrize("gaussian_noise_std", (0.0, 0.1, 0.5, 3.0)) -def test_ssa_fidelity(gaussian_noise_std): - """Test fidelity prediction for WFS simulation with various noise levels.""" - # === Define virtual devices for a WFS simulation === - # Define aberration as a pattern of random phases at the pupil plane - aberrations = np.random.uniform(size=(80, 80)) * 2 * np.pi - - # Define specimen as an image with several bright pixels - specimen_img = np.zeros((240, 240)) - specimen_img[120, 120] = 2e5 - specimen = set_pixel_size(specimen_img, pixel_size=100 * u.nm) - - # The SLM is conjugated to the back pupil plane - slm = SLM(shape=(80, 80)) - # Also simulate a shutter that can turn off the light - shutter = Shutter(slm.field) - - # Simulate a WFS microscope looking at the specimen - sim = Microscope(source=specimen, incident_field=shutter, aberrations=aberrations, wavelength=800 * u.nm) - - # Simulate a camera device with gaussian noise and shot noise - cam = sim.get_camera(analog_max=1e4, shot_noise=False, gaussian_noise_std=gaussian_noise_std) - - # Define feedback as circular region of interest in the center of the frame - roi_detector = SingleRoi(cam, radius=1) - - # === Run wavefront shaping experiment === - # Use the stepwise sequential (SSA) WFS algorithm - n_x = 10 - n_y = 10 - alg = StepwiseSequential(feedback=roi_detector, slm=slm, n_x=n_x, n_y=n_y, phase_steps=8) - - # Define a region of interest to determine average speckle intensity - roi_background = SingleRoi(cam, radius=50) - - # Run WFS troubleshooter and output a report to the console - trouble = troubleshoot(algorithm=alg, background_feedback=roi_background, - frame_source=cam, shutter=shutter) - - assert np.isclose(trouble.measured_enhancement, trouble.expected_enhancement, rtol=0.2) - - @pytest.mark.parametrize("type", ('plane_wave', 'hadamard')) -@pytest.mark.parametrize("shape", ((8, 8), (6, 4))) +@pytest.mark.parametrize("shape", ((8, 8), (16, 4))) def test_custom_blind_dual_reference_ortho_split(type: str, shape): """Test custom blind dual reference with an orthonormal phase-only basis. Two types of bases are tested: plane waves and Hadamard""" @@ -428,7 +308,7 @@ def test_custom_blind_dual_reference_ortho_split(type: str, shape): else: # type == 'hadamard': modes = hadamard(N).reshape(modes_shape) - mask = np.concatenate((np.zeros(modes_shape[0:1], dtype=bool), np.ones(modes_shape[0:1], dtype=bool)), axis=1) + mask = np.concatenate((np.zeros(modes_shape[0:2], dtype=bool), np.ones(modes_shape[0:2], dtype=bool)), axis=1) mode_set = np.concatenate((modes, np.zeros(shape=modes_shape)), axis=1) phases_set = np.angle(mode_set) @@ -464,7 +344,7 @@ def test_custom_blind_dual_reference_ortho_split(type: str, shape): plt.colorbar() plt.show() - assert np.abs(field_correlation(np.exp(1j * aberrations), result.t)) > 0.999 + assert np.abs(field_correlation(sim.t, result.t)) > 0.99 # todo: find out why this is not higher def test_custom_blind_dual_reference_non_ortho(): From 0f326ad3660552f7d2dfd62c86a2cc70ddbe7799 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 15:30:21 +0200 Subject: [PATCH 13/15] added subpackages to pyproject.toml --- pyproject.toml | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index ee43077..7457674 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,14 @@ classifiers = [ 'License :: OSI Approved :: BSD License', 'Operating System :: OS Independent', ] +packages = [ + { include = "openwfs" }, + { include = "openwfs.algorithms" }, + { include = "openwfs.devices" }, + { include = "openwfs.processors" }, + { include = "openwfs.simulation" }, + { include = "openwfs.utilities" } +] [build-system] requires = ["poetry-core"] From 32c46de5bba5087360aca5a31260b3f2cce37e94 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 15:36:24 +0200 Subject: [PATCH 14/15] auto formatting with black --- docs/source/conf.py | 128 +++++--- examples/hello_simulation.py | 1 + examples/hello_wfs.py | 3 +- examples/mm_scanning_microscope.py | 48 ++- examples/sample_microscope.py | 33 +- examples/slm_demo.py | 4 +- examples/slm_disk.py | 2 +- examples/troubleshooter_demo.py | 15 +- examples/wfs_demonstration_experimental.py | 8 +- openwfs/algorithms/basic_fourier.py | 39 ++- openwfs/algorithms/dual_reference.py | 112 +++++-- openwfs/algorithms/ssa.py | 18 +- openwfs/algorithms/troubleshoot.py | 331 +++++++++++++-------- openwfs/algorithms/utilities.py | 137 ++++++--- openwfs/core.py | 132 ++++++-- openwfs/devices/camera.py | 73 +++-- openwfs/devices/galvo_scanner.py | 284 ++++++++++++------ openwfs/devices/nidaq_gain.py | 11 +- openwfs/devices/slm/context.py | 3 +- openwfs/devices/slm/geometry.py | 37 ++- openwfs/devices/slm/patch.py | 113 +++++-- openwfs/devices/slm/slm.py | 250 +++++++++++----- openwfs/devices/slm/texture.py | 88 +++++- openwfs/plot_utilities.py | 8 +- openwfs/processors/__init__.py | 9 +- openwfs/processors/processors.py | 135 ++++++--- openwfs/simulation/__init__.py | 9 +- openwfs/simulation/microscope.py | 105 +++++-- openwfs/simulation/mockdevices.py | 90 ++++-- openwfs/simulation/slm.py | 104 ++++--- openwfs/simulation/transmission.py | 18 +- openwfs/utilities/__init__.py | 14 +- openwfs/utilities/patterns.py | 48 ++- openwfs/utilities/utilities.py | 179 +++++++---- tests/test_algorithms_troubleshoot.py | 151 +++++++--- tests/test_camera.py | 18 +- tests/test_core.py | 52 ++-- tests/test_processors.py | 63 ++-- tests/test_scanning_microscope.py | 118 +++++--- tests/test_simulation.py | 120 ++++++-- tests/test_slm.py | 46 ++- tests/test_utilities.py | 68 +++-- tests/test_wfs.py | 230 ++++++++++---- 43 files changed, 2435 insertions(+), 1020 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index ba1c421..060e523 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -19,24 +19,36 @@ # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration -extensions = ['sphinx.ext.napoleon', 'sphinx.ext.autodoc', 'sphinx.ext.mathjax', - 'sphinx.ext.viewcode', 'sphinx_autodoc_typehints', 'sphinxcontrib.bibtex', 'sphinx.ext.autosectionlabel', - 'sphinx_markdown_builder', 'sphinx_gallery.gen_gallery'] +extensions = [ + "sphinx.ext.napoleon", + "sphinx.ext.autodoc", + "sphinx.ext.mathjax", + "sphinx.ext.viewcode", + "sphinx_autodoc_typehints", + "sphinxcontrib.bibtex", + "sphinx.ext.autosectionlabel", + "sphinx_markdown_builder", + "sphinx_gallery.gen_gallery", +] # basic project information -project = 'OpenWFS' -copyright = '2023-, Ivo Vellekoop, Daniël W. S. Cox, and Jeroen H. Doornbos, University of Twente' -author = 'Jeroen H. Doornbos, Daniël W. S. Cox, Tom Knop, Harish Sasikumar, Ivo M. Vellekoop' -release = '0.1.0rc2' -html_title = "OpenWFS - a library for conducting and simulating wavefront shaping experiments" +project = "OpenWFS" +copyright = "2023-, Ivo Vellekoop, Daniël W. S. Cox, and Jeroen H. Doornbos, University of Twente" +author = ( + "Jeroen H. Doornbos, Daniël W. S. Cox, Tom Knop, Harish Sasikumar, Ivo M. Vellekoop" +) +release = "0.1.0rc2" +html_title = ( + "OpenWFS - a library for conducting and simulating wavefront shaping experiments" +) # \renewenvironment{sphinxtheindex}{\setbox0\vbox\bgroup\begin{theindex}}{\end{theindex}} # latex configuration latex_elements = { - 'preamble': r""" + "preamble": r""" \usepackage{authblk} """, - 'maketitle': r""" + "maketitle": r""" \author[1]{Daniël~W.~S.~Cox} \author[1]{Tom~Knop} \author[1,2]{Harish~Sasikumar} @@ -67,40 +79,52 @@ } \maketitle """, - 'tableofcontents': "", - 'makeindex': "", - 'printindex': "", - 'figure_align': "", - 'extraclassoptions': 'notitlepage', + "tableofcontents": "", + "makeindex": "", + "printindex": "", + "figure_align": "", + "extraclassoptions": "notitlepage", } latex_docclass = { - 'manual': 'scrartcl', - 'howto': 'scrartcl', + "manual": "scrartcl", + "howto": "scrartcl", } -latex_documents = [('index_latex', 'OpenWFS.tex', - 'OpenWFS - a library for conducting and simulating wavefront shaping experiments', - 'Jeroen H. Doornbos', 'howto')] -latex_toplevel_sectioning = 'section' -bibtex_default_style = 'unsrt' -bibtex_bibfiles = ['references.bib'] +latex_documents = [ + ( + "index_latex", + "OpenWFS.tex", + "OpenWFS - a library for conducting and simulating wavefront shaping experiments", + "Jeroen H. Doornbos", + "howto", + ) +] +latex_toplevel_sectioning = "section" +bibtex_default_style = "unsrt" +bibtex_bibfiles = ["references.bib"] numfig = True -templates_path = ['_templates'] -exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', 'acknowledgements.rst', 'sg_execution_times.rst'] -master_doc = '' -include_patterns = ['**'] +templates_path = ["_templates"] +exclude_patterns = [ + "_build", + "Thumbs.db", + ".DS_Store", + "acknowledgements.rst", + "sg_execution_times.rst", +] +master_doc = "" +include_patterns = ["**"] napoleon_use_rtype = False napoleon_use_param = True typehints_document_rtype = False -latex_engine = 'xelatex' -html_theme = 'sphinx_rtd_theme' +latex_engine = "xelatex" +html_theme = "sphinx_rtd_theme" add_module_names = False autodoc_preserve_defaults = True sphinx_gallery_conf = { - 'examples_dirs': '../../examples', # path to your example scripts - 'ignore_pattern': 'set_path.py', - 'gallery_dirs': 'auto_examples', # path to where to save gallery generated output + "examples_dirs": "../../examples", # path to your example scripts + "ignore_pattern": "set_path.py", + "gallery_dirs": "auto_examples", # path to where to save gallery generated output } # importing this module without OpenGL installed will fail, @@ -117,7 +141,7 @@ def skip(app, what, name, obj, skip, options): def visit_citation(self, node): """Patch-in function for markdown builder to support citations.""" - id = node['ids'][0] + id = node["ids"][0] self.add(f'') @@ -141,29 +165,35 @@ def setup(app): def source_read(app, docname, source): - if docname == 'readme' or docname == 'conclusion': - if (app.builder.name == 'latex') == (docname == 'conclusion'): - source[0] = source[0].replace('%endmatter%', '.. include:: acknowledgements.rst') + if docname == "readme" or docname == "conclusion": + if (app.builder.name == "latex") == (docname == "conclusion"): + source[0] = source[0].replace( + "%endmatter%", ".. include:: acknowledgements.rst" + ) else: - source[0] = source[0].replace('%endmatter%', '') + source[0] = source[0].replace("%endmatter%", "") def builder_inited(app): - if app.builder.name == 'html': - exclude_patterns.extend(['conclusion.rst', 'index_latex.rst', 'index_markdown.rst']) - app.config.master_doc = 'index' - elif app.builder.name == 'latex': - exclude_patterns.extend(['auto_examples/*', 'index_markdown.rst', 'index.rst', 'api*']) - app.config.master_doc = 'index_latex' - elif app.builder.name == 'markdown': + if app.builder.name == "html": + exclude_patterns.extend( + ["conclusion.rst", "index_latex.rst", "index_markdown.rst"] + ) + app.config.master_doc = "index" + elif app.builder.name == "latex": + exclude_patterns.extend( + ["auto_examples/*", "index_markdown.rst", "index.rst", "api*"] + ) + app.config.master_doc = "index_latex" + elif app.builder.name == "markdown": include_patterns.clear() - include_patterns.extend(['readme.rst', 'index_markdown.rst']) - app.config.master_doc = 'index_markdown' + include_patterns.extend(["readme.rst", "index_markdown.rst"]) + app.config.master_doc = "index_markdown" def copy_readme(app, exception): """Copy the readme file to the root of the documentation directory.""" - if exception is None and app.builder.name == 'markdown': - source_file = Path(app.outdir) / 'readme.md' - destination_dir = Path(app.confdir).parents[1] / 'README.md' + if exception is None and app.builder.name == "markdown": + source_file = Path(app.outdir) / "readme.md" + destination_dir = Path(app.confdir).parents[1] / "README.md" shutil.copy(source_file, destination_dir) diff --git a/examples/hello_simulation.py b/examples/hello_simulation.py index 57926a7..722e097 100644 --- a/examples/hello_simulation.py +++ b/examples/hello_simulation.py @@ -2,6 +2,7 @@ =============================================== Simulates a wavefront shaping experiment using a SimulatedWFS object, which acts both as a spatial light modulator (SLM) and a detector.""" + import numpy as np from openwfs.algorithms import StepwiseSequential diff --git a/examples/hello_wfs.py b/examples/hello_wfs.py index 28912d5..be7015f 100644 --- a/examples/hello_wfs.py +++ b/examples/hello_wfs.py @@ -7,6 +7,7 @@ and a spatial light modulator (SLM) connected to the secondary video output. """ + import numpy as np from openwfs.algorithms import StepwiseSequential @@ -18,7 +19,7 @@ # Connect to a GenICam camera, average pixels to get feedback signal camera = Camera(R"C:\Program Files\Basler\pylon 7\Runtime\x64\ProducerU3V.cti") -feedback = SingleRoi(camera, pos=(320, 320), mask_type='disk', radius=2.5) +feedback = SingleRoi(camera, pos=(320, 320), mask_type="disk", radius=2.5) # Run the algorithm alg = StepwiseSequential(feedback=feedback, slm=slm, n_x=10, n_y=10, phase_steps=4) diff --git a/examples/mm_scanning_microscope.py b/examples/mm_scanning_microscope.py index ec32c9d..3bb96d1 100644 --- a/examples/mm_scanning_microscope.py +++ b/examples/mm_scanning_microscope.py @@ -20,31 +20,51 @@ optical_deflection=1.0 / (0.22 * u.V / u.deg), galvo_to_pupil_magnification=2, objective_magnification=16, - reference_tube_lens=200 * u.mm) + reference_tube_lens=200 * u.mm, +) acceleration = Axis.compute_acceleration( optical_deflection=1.0 / (0.22 * u.V / u.deg), - torque_constant=2.8E5 * u.dyne * u.cm / u.A, - rotor_inertia=8.25 * u.g * u.cm ** 2, - maximum_current=4 * u.A) + torque_constant=2.8e5 * u.dyne * u.cm / u.A, + rotor_inertia=8.25 * u.g * u.cm**2, + maximum_current=4 * u.A, +) # scale = 440 * u.um / u.V (calibrated) sample_rate = 0.5 * u.MHz reference_zoom = 1.2 -y_axis = Axis(channel='Dev4/ao0', v_min=-2.0 * u.V, v_max=2.0 * u.V, maximum_acceleration=acceleration, scale=scale) -x_axis = Axis(channel='Dev4/ao1', v_min=-2.0 * u.V, v_max=2.0 * u.V, maximum_acceleration=acceleration, scale=scale) -input_channel = InputChannel('Dev4/ai0', -1.0 * u.V, 1.0 * u.V) +y_axis = Axis( + channel="Dev4/ao0", + v_min=-2.0 * u.V, + v_max=2.0 * u.V, + maximum_acceleration=acceleration, + scale=scale, +) +x_axis = Axis( + channel="Dev4/ao1", + v_min=-2.0 * u.V, + v_max=2.0 * u.V, + maximum_acceleration=acceleration, + scale=scale, +) +input_channel = InputChannel("Dev4/ai0", -1.0 * u.V, 1.0 * u.V) test_image = skimage.data.hubble_deep_field() * 256 -scanner = ScanningMicroscope(sample_rate=sample_rate, - input=input_channel, y_axis=y_axis, x_axis=x_axis, - test_pattern='image', reference_zoom=reference_zoom, - resolution=1024, test_image=test_image) +scanner = ScanningMicroscope( + sample_rate=sample_rate, + input=input_channel, + y_axis=y_axis, + x_axis=x_axis, + test_pattern="image", + reference_zoom=reference_zoom, + resolution=1024, + test_image=test_image, +) -if __name__ == '__main__': +if __name__ == "__main__": scanner.binning = 4 - plt.imshow(scanner.read(), cmap='gray') + plt.imshow(scanner.read(), cmap="gray") plt.colorbar() plt.show() else: - devices = {'microscope': scanner} + devices = {"microscope": scanner} diff --git a/examples/sample_microscope.py b/examples/sample_microscope.py index bdc2057..9259594 100644 --- a/examples/sample_microscope.py +++ b/examples/sample_microscope.py @@ -41,27 +41,42 @@ # Code img = set_pixel_size( - np.maximum(np.random.randint(-10000, 100, (img_size_y, img_size_x), dtype=np.int16), 0), - 60 * u.nm) + np.maximum( + np.random.randint(-10000, 100, (img_size_y, img_size_x), dtype=np.int16), 0 + ), + 60 * u.nm, +) src = StaticSource(img) -mic = Microscope(src, magnification=magnification, numerical_aperture=numerical_aperture, wavelength=wavelength) +mic = Microscope( + src, + magnification=magnification, + numerical_aperture=numerical_aperture, + wavelength=wavelength, +) # simulate shot noise in an 8-bit camera with auto-exposure: -cam = mic.get_camera(shot_noise=True, digital_max=255, data_shape=camera_resolution, pixel_size=pixel_size) -devices = {'camera': cam, 'stage': mic.xy_stage} +cam = mic.get_camera( + shot_noise=True, + digital_max=255, + data_shape=camera_resolution, + pixel_size=pixel_size, +) +devices = {"camera": cam, "stage": mic.xy_stage} -if __name__ == '__main__': +if __name__ == "__main__": import matplotlib.pyplot as plt plt.subplot(1, 2, 1) imshow(img) - plt.title('Original image') + plt.title("Original image") plt.subplot(1, 2, 2) - plt.title('Scanned image') + plt.title("Scanned image") ax = None for p in range(p_limit): mic.xy_stage.x = p * 1 * u.um mic.numerical_aperture = 1.0 * (p + 1) / p_limit # NA increases to 1.0 ax = grab_and_show(cam, ax) - plt.title(f"NA: {mic.numerical_aperture}, δ: {mic.abbe_limit.to_value(u.um):2.2} μm") + plt.title( + f"NA: {mic.numerical_aperture}, δ: {mic.abbe_limit.to_value(u.um):2.2} μm" + ) plt.pause(0.2) diff --git a/examples/slm_demo.py b/examples/slm_demo.py index bbb64c9..aba50bf 100644 --- a/examples/slm_demo.py +++ b/examples/slm_demo.py @@ -34,7 +34,9 @@ p4.phases = 1 p4.additive_blend = False -pf.phases = patterns.lens(100, f=1 * u.m, wavelength=0.8 * u.um, extent=(10 * u.mm, 10 * u.mm)) +pf.phases = patterns.lens( + 100, f=1 * u.m, wavelength=0.8 * u.um, extent=(10 * u.mm, 10 * u.mm) +) rng = np.random.default_rng() for n in range(200): random_data = rng.random([10, 10], np.float32) * 2.0 * np.pi diff --git a/examples/slm_disk.py b/examples/slm_disk.py index cbe79c0..87d8260 100644 --- a/examples/slm_disk.py +++ b/examples/slm_disk.py @@ -29,4 +29,4 @@ # read back the pixels and store in a file pixels = slm.pixels.read() -cv2.imwrite('slm_disk.png', pixels) +cv2.imwrite("slm_disk.png", pixels) diff --git a/examples/troubleshooter_demo.py b/examples/troubleshooter_demo.py index 706e1ee..abf08d6 100644 --- a/examples/troubleshooter_demo.py +++ b/examples/troubleshooter_demo.py @@ -31,12 +31,16 @@ # Simulate an SLM with incorrect phase response # Also simulate a shutter that can turn off the light # The SLM is conjugated to the back pupil plane -slm = SLM(shape=(100, 100), - phase_response=(np.arange(256) / 128 * np.pi) * 1.2) +slm = SLM(shape=(100, 100), phase_response=(np.arange(256) / 128 * np.pi) * 1.2) shutter = Shutter(slm.field) # Simulate a WFS microscope looking at the specimen -sim = Microscope(source=specimen, incident_field=shutter, aberrations=aberrations, wavelength=800 * u.nm) +sim = Microscope( + source=specimen, + incident_field=shutter, + aberrations=aberrations, + wavelength=800 * u.nm, +) # Simulate a camera device with gaussian noise and shot noise cam = sim.get_camera(analog_max=1e4, shot_noise=True, gaussian_noise_std=4.0) @@ -52,6 +56,7 @@ roi_background = SingleRoi(cam, radius=10) # Run WFS troubleshooter and output a report to the console -trouble = troubleshoot(algorithm=alg, background_feedback=roi_background, - frame_source=cam, shutter=shutter) +trouble = troubleshoot( + algorithm=alg, background_feedback=roi_background, frame_source=cam, shutter=shutter +) trouble.report() diff --git a/examples/wfs_demonstration_experimental.py b/examples/wfs_demonstration_experimental.py index 44eb784..b2a16f2 100644 --- a/examples/wfs_demonstration_experimental.py +++ b/examples/wfs_demonstration_experimental.py @@ -19,12 +19,16 @@ # constructs the actual slm for wavefront shaping, and a monitor window to display the current phase pattern slm = SLM(monitor_id=2, duration=2) -monitor = slm.clone(monitor_id=0, pos=(0, 0), shape=(slm.shape[0] // 4, slm.shape[1] // 4)) +monitor = slm.clone( + monitor_id=0, pos=(0, 0), shape=(slm.shape[0] // 4, slm.shape[1] // 4) +) # we are using a setup with an SLM that produces 2pi phase shift # at a gray value of 142 slm.lookup_table = range(142) -alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=[800, 800], k_radius=7) +alg = FourierDualReference( + feedback=roi_detector, slm=slm, slm_shape=[800, 800], k_radius=7 +) result = alg.execute() print(result) diff --git a/openwfs/algorithms/basic_fourier.py b/openwfs/algorithms/basic_fourier.py index 52860f5..2b1f32b 100644 --- a/openwfs/algorithms/basic_fourier.py +++ b/openwfs/algorithms/basic_fourier.py @@ -25,9 +25,19 @@ class FourierDualReference(DualReference): optics with remote focusing (CAORF),” Opt. Express 25, 10368–10383 (2017). """ - def __init__(self, *, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), phase_steps=4, k_radius: float = 3.2, - k_step: float = 1.0, iterations: int = 2, analyzer: Optional[callable] = analyze_phase_stepping, - optimized_reference: Optional[bool] = None): + def __init__( + self, + *, + feedback: Detector, + slm: PhaseSLM, + slm_shape=(500, 500), + phase_steps=4, + k_radius: float = 3.2, + k_step: float = 1.0, + iterations: int = 2, + analyzer: Optional[callable] = analyze_phase_stepping, + optimized_reference: Optional[bool] = None + ): """ Args: feedback (Detector): Source of feedback @@ -42,13 +52,17 @@ def __init__(self, *, feedback: Detector, slm: PhaseSLM, slm_shape=(500, 500), p self.k_step = k_step self._slm_shape = slm_shape group_mask = np.zeros(slm_shape, dtype=bool) - group_mask[:, slm_shape[1] // 2:] = True - super().__init__(feedback=feedback, slm=slm, - phase_patterns=None, group_mask=group_mask, - phase_steps=phase_steps, - iterations=iterations, - optimized_reference=optimized_reference, - analyzer=analyzer) + group_mask[:, slm_shape[1] // 2 :] = True + super().__init__( + feedback=feedback, + slm=slm, + phase_patterns=None, + group_mask=group_mask, + phase_steps=phase_steps, + iterations=iterations, + optimized_reference=optimized_reference, + analyzer=analyzer, + ) self._update_modes() def _update_modes(self): @@ -63,10 +77,11 @@ def _update_modes(self): int_radius_y = np.ceil(self.k_radius / self.k_step) kx, ky = np.meshgrid( np.arange(-int_radius_x, int_radius_x + 1) * (self.k_step * 2), - np.arange(-int_radius_y, int_radius_y + 1) * self.k_step) + np.arange(-int_radius_y, int_radius_y + 1) * self.k_step, + ) # only keep the points within the circle - mask = kx ** 2 + ky ** 2 <= self.k_radius ** 2 + mask = kx**2 + ky**2 <= self.k_radius**2 k = np.stack((ky[mask], kx[mask])).T # construct the modes for these kx ky values diff --git a/openwfs/algorithms/dual_reference.py b/openwfs/algorithms/dual_reference.py index f04df3c..1f6fd92 100644 --- a/openwfs/algorithms/dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -32,9 +32,18 @@ class DualReference: https://opg.optica.org/oe/ abstract.cfm?uri=oe-27-8-1167 """ - def __init__(self, *, feedback: Detector, slm: PhaseSLM, phase_patterns: Optional[tuple[nd, nd]], group_mask: nd, - phase_steps: int = 4, iterations: int = 2, - analyzer: Optional[callable] = analyze_phase_stepping, optimized_reference: Optional[bool] = None): + def __init__( + self, + *, + feedback: Detector, + slm: PhaseSLM, + phase_patterns: Optional[tuple[nd, nd]], + group_mask: nd, + phase_steps: int = 4, + iterations: int = 2, + analyzer: Optional[callable] = analyze_phase_stepping, + optimized_reference: Optional[bool] = None + ): """ Args: feedback: The feedback source, usually a detector that provides measurement data. @@ -58,9 +67,9 @@ def __init__(self, *, feedback: Detector, slm: PhaseSLM, phase_patterns: Optiona These two measurements are combined to find the final phase for stitching. When set to `None` (default), the algorithm uses True if there is a single target, and False if there are multiple targets. - analyzer: The function used to analyze the phase stepping data. - Must return a WFSResult object. Defaults to `analyze_phase_stepping` - + analyzer: The function used to analyze the phase stepping data. + Must return a WFSResult object. Defaults to `analyze_phase_stepping` + [1]: X. Tao, T. Lam, B. Zhu, et al., “Three-dimensional focusing through scattering media using conjugate adaptive optics with remote focusing (CAORF),” Opt. Express 25, 10368–10383 (2017). @@ -69,12 +78,15 @@ def __init__(self, *, feedback: Detector, slm: PhaseSLM, phase_patterns: Optiona optimized_reference = np.prod(feedback.data_shape) == 1 elif optimized_reference and np.prod(feedback.data_shape) != 1: raise ValueError( - "When using an optimized reference, the feedback detector should return a single scalar value.") + "When using an optimized reference, the feedback detector should return a single scalar value." + ) if iterations < 2: raise ValueError("The number of iterations must be at least 2.") if not optimized_reference and iterations != 2: - raise ValueError("When not using an optimized reference, the number of iterations must be 2.") + raise ValueError( + "When not using an optimized reference, the number of iterations must be 2." + ) self.slm = slm self.feedback = feedback @@ -85,7 +97,10 @@ def __init__(self, *, feedback: Detector, slm: PhaseSLM, phase_patterns: Optiona self._phase_patterns = None self._shape = group_mask.shape mask = group_mask.astype(bool) - self.masks = (~mask, mask) # mask[0] is True for group A, mask[1] is True for group B + self.masks = ( + ~mask, + mask, + ) # mask[0] is True for group A, mask[1] is True for group B self.phase_patterns = phase_patterns @property @@ -102,18 +117,35 @@ def phase_patterns(self, value): if not self.optimized_reference: # find the modes in A and B that correspond to flat wavefronts with phase 0 try: - a0_index = next(i for i in range(value[0].shape[2]) if np.allclose(value[0][:, :, i], 0)) - b0_index = next(i for i in range(value[1].shape[2]) if np.allclose(value[1][:, :, i], 0)) + a0_index = next( + i + for i in range(value[0].shape[2]) + if np.allclose(value[0][:, :, i], 0) + ) + b0_index = next( + i + for i in range(value[1].shape[2]) + if np.allclose(value[1][:, :, i], 0) + ) self.zero_indices = (a0_index, b0_index) except StopIteration: - raise ("For multi-target optimization, the both sets must contain a flat wavefront with phase 0.") + raise ( + "For multi-target optimization, the both sets must contain a flat wavefront with phase 0." + ) if (value[0].shape[0:2] != self._shape) or (value[1].shape[0:2] != self._shape): - raise ValueError("The phase patterns and group mask must all have the same shape.") + raise ValueError( + "The phase patterns and group mask must all have the same shape." + ) - self._phase_patterns = (value[0].astype(np.float32), value[1].astype(np.float32)) + self._phase_patterns = ( + value[0].astype(np.float32), + value[1].astype(np.float32), + ) - def execute(self, capture_intermediate_results: bool = False, progress_bar=None) -> WFSResult: + def execute( + self, capture_intermediate_results: bool = False, progress_bar=None + ) -> WFSResult: """ Executes the blind focusing dual reference algorithm and compute the SLM transmission matrix. capture_intermediate_results: When True, measures the feedback from the optimized wavefront after each iteration. @@ -129,19 +161,26 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) """ # Current estimate of the transmission matrix (start with all 0) - cobasis = [np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in - range(2)] + cobasis = [ + np.exp(-1j * self.phase_patterns[side]) + * np.expand_dims(self.masks[side], axis=2) + for side in range(2) + ] ref_phases = np.zeros(self._shape) # Initialize storage lists results_all = [None] * self.iterations # List to store all results - intermediate_results = np.zeros(self.iterations) # List to store feedback from full patterns + intermediate_results = np.zeros( + self.iterations + ) # List to store feedback from full patterns # Prepare progress bar if progress_bar: - num_measurements = np.ceil(self.iterations / 2) * self.phase_patterns[0].shape[2] \ - + np.floor(self.iterations / 2) * self.phase_patterns[1].shape[2] + num_measurements = ( + np.ceil(self.iterations / 2) * self.phase_patterns[0].shape[2] + + np.floor(self.iterations / 2) * self.phase_patterns[1].shape[2] + ) progress_bar.total = num_measurements # Switch the phase sets back and forth multiple times @@ -149,15 +188,21 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) side = it % 2 # pick set A or B for phase stepping side_mask = self.masks[side] # Perform WFS experiment on one side, keeping the other side sized at the ref_phases - results_all[it] = self._single_side_experiment(mod_phases=self.phase_patterns[side], ref_phases=ref_phases, - mod_mask=side_mask, progress_bar=progress_bar) + results_all[it] = self._single_side_experiment( + mod_phases=self.phase_patterns[side], + ref_phases=ref_phases, + mod_mask=side_mask, + progress_bar=progress_bar, + ) # Compute transmission matrix for the current side and update # estimated transmission matrix if self.optimized_reference: # use the best estimate so far to construct an optimized reference - t_this_side = self.compute_t_set(results_all[it].t, cobasis[side]).squeeze() + t_this_side = self.compute_t_set( + results_all[it].t, cobasis[side] + ).squeeze() ref_phases[self.masks[side]] = -np.angle(t_this_side[self.masks[side]]) # Try full pattern @@ -173,11 +218,15 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) # relative phase between the two sides, which we extract from # the measurements of the flat wavefronts. relative = results_all[0].t[self.zero_indices[0], ...] + np.conjugate( - results_all[1].t[self.zero_indices[1], ...]) - factor = (relative / np.abs(relative)).reshape((1, *self.feedback.data_shape)) + results_all[1].t[self.zero_indices[1], ...] + ) + factor = (relative / np.abs(relative)).reshape( + (1, *self.feedback.data_shape) + ) - t_full = (self.compute_t_set(results_all[0].t, cobasis[0]) + - self.compute_t_set(factor * results_all[1].t, cobasis[1])) + t_full = self.compute_t_set(results_all[0].t, cobasis[0]) + self.compute_t_set( + factor * results_all[1].t, cobasis[1] + ) # Compute average fidelity factors # subtract 1 from n, because both sets (usually) contain a flat wavefront, @@ -191,8 +240,9 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) result.intermediate_results = intermediate_results return result - def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, - progress_bar=None) -> WFSResult: + def _single_side_experiment( + self, mod_phases: nd, ref_phases: nd, mod_mask: nd, progress_bar=None + ) -> WFSResult: """ Conducts experiments on one part of the SLM. @@ -208,7 +258,9 @@ def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, WFSResult: An object containing the computed SLM transmission matrix and related data. """ num_modes = mod_phases.shape[2] - measurements = np.zeros((num_modes, self.phase_steps, *self.feedback.data_shape)) + measurements = np.zeros( + (num_modes, self.phase_steps, *self.feedback.data_shape) + ) for m in range(num_modes): phases = ref_phases.copy() diff --git a/openwfs/algorithms/ssa.py b/openwfs/algorithms/ssa.py index c86c36a..4e3eeee 100644 --- a/openwfs/algorithms/ssa.py +++ b/openwfs/algorithms/ssa.py @@ -1,6 +1,7 @@ import numpy as np -from ..core import Detector, PhaseSLM + from .utilities import analyze_phase_stepping, WFSResult +from ..core import Detector, PhaseSLM class StepwiseSequential: @@ -15,7 +16,14 @@ class StepwiseSequential: [2]: Ivo M. Vellekoop, "Feedback-based wavefront shaping," Opt. Express 23, 12189-12206 (2015) """ - def __init__(self, feedback: Detector, slm: PhaseSLM, phase_steps: int = 4, n_x: int = 4, n_y: int = None): + def __init__( + self, + feedback: Detector, + slm: PhaseSLM, + phase_steps: int = 4, + n_x: int = 4, + n_y: int = None, + ): """ This algorithm systematically modifies the phase pattern of each SLM element and measures the resulting feedback. @@ -39,8 +47,10 @@ def execute(self) -> WFSResult: Returns: WFSResult: An object containing the computed transmission matrix and statistics. """ - phase_pattern = np.zeros((self.n_y, self.n_x), 'float32') - measurements = np.zeros((self.n_y, self.n_x, self.phase_steps, *self.feedback.data_shape)) + phase_pattern = np.zeros((self.n_y, self.n_x), "float32") + measurements = np.zeros( + (self.n_y, self.n_x, self.phase_steps, *self.feedback.data_shape) + ) for y in range(self.n_y): for x in range(self.n_x): diff --git a/openwfs/algorithms/troubleshoot.py b/openwfs/algorithms/troubleshoot.py index 662f74d..4ae3cd4 100644 --- a/openwfs/algorithms/troubleshoot.py +++ b/openwfs/algorithms/troubleshoot.py @@ -47,8 +47,9 @@ def cnr(signal_with_noise: np.ndarray, noise: np.ndarray) -> np.float64: return signal_std(signal_with_noise, noise) / noise.std() -def contrast_enhancement(signal_with_noise: np.ndarray, reference_with_noise: np.ndarray, - noise: np.ndarray) -> float: +def contrast_enhancement( + signal_with_noise: np.ndarray, reference_with_noise: np.ndarray, noise: np.ndarray +) -> float: """ Compute noise corrected contrast enhancement. The noise is assumed to be uncorrelated with the signal, such that var(measured) = var(signal) + var(noise). @@ -124,7 +125,9 @@ def frame_correlation(a: np.ndarray, b: np.ndarray) -> float: return np.mean(a * b) / (np.mean(a) * np.mean(b)) - 1 -def pearson_correlation(a: np.ndarray, b: np.ndarray, noise_var: np.ndarray = 0.0) -> float: +def pearson_correlation( + a: np.ndarray, b: np.ndarray, noise_var: np.ndarray = 0.0 +) -> float: """ Compute Pearson correlation. @@ -160,9 +163,19 @@ class StabilityResult: framestack: 3D array containing all recorded frames. Is None unless saving frames was requested. """ - def __init__(self, pixel_shifts_first, correlations_first, correlations_disattenuated_first, contrast_ratios_first, - pixel_shifts_prev, correlations_prev, correlations_disattenuated_prev, contrast_ratios_prev, - abs_timestamps, framestack): + def __init__( + self, + pixel_shifts_first, + correlations_first, + correlations_disattenuated_first, + contrast_ratios_first, + pixel_shifts_prev, + correlations_prev, + correlations_disattenuated_prev, + contrast_ratios_prev, + abs_timestamps, + framestack, + ): # Comparison with first frame self.pixel_shifts_first = pixel_shifts_first self.correlations_first = correlations_first @@ -186,47 +199,66 @@ def plot(self): """ # Comparisons with first frame plt.figure() - plt.plot(self.timestamps, self.pixel_shifts_first, '.-', label='image-shift (pix)') - plt.title('Stability - Image shift w.r.t. first frame') - plt.ylabel('Image shift (pix)') - plt.xlabel('time (s)') + plt.plot( + self.timestamps, self.pixel_shifts_first, ".-", label="image-shift (pix)" + ) + plt.title("Stability - Image shift w.r.t. first frame") + plt.ylabel("Image shift (pix)") + plt.xlabel("time (s)") plt.figure() - plt.plot(self.timestamps, self.correlations_first, '.-', label='correlation') - plt.plot(self.timestamps, self.correlations_disattenuated_first, '.-', label='correlation disattenuated') - plt.title('Stability - Correlation with first frame') - plt.xlabel('time (s)') + plt.plot(self.timestamps, self.correlations_first, ".-", label="correlation") + plt.plot( + self.timestamps, + self.correlations_disattenuated_first, + ".-", + label="correlation disattenuated", + ) + plt.title("Stability - Correlation with first frame") + plt.xlabel("time (s)") plt.legend() plt.figure() - plt.plot(self.timestamps, self.contrast_ratios_first, '.-', label='contrast ratio') - plt.title('Stability - Contrast ratio with first frame') - plt.xlabel('time (s)') + plt.plot( + self.timestamps, self.contrast_ratios_first, ".-", label="contrast ratio" + ) + plt.title("Stability - Contrast ratio with first frame") + plt.xlabel("time (s)") # Comparisons with previous frame plt.figure() - plt.plot(self.timestamps, self.pixel_shifts_prev, '.-', label='image-shift (pix)') - plt.title('Stability - Image shift w.r.t. previous frame') - plt.ylabel('Image shift (pix)') - plt.xlabel('time (s)') + plt.plot( + self.timestamps, self.pixel_shifts_prev, ".-", label="image-shift (pix)" + ) + plt.title("Stability - Image shift w.r.t. previous frame") + plt.ylabel("Image shift (pix)") + plt.xlabel("time (s)") plt.figure() - plt.plot(self.timestamps, self.correlations_prev, '.-', label='correlation') - plt.plot(self.timestamps, self.correlations_disattenuated_prev, '.-', label='correlation disattenuated') - plt.title('Stability - Correlation with previous frame') - plt.xlabel('time (s)') + plt.plot(self.timestamps, self.correlations_prev, ".-", label="correlation") + plt.plot( + self.timestamps, + self.correlations_disattenuated_prev, + ".-", + label="correlation disattenuated", + ) + plt.title("Stability - Correlation with previous frame") + plt.xlabel("time (s)") plt.legend() plt.figure() - plt.plot(self.timestamps, self.contrast_ratios_prev, '.-', label='contrast ratio') - plt.title('Stability - Contrast ratio with previous frame') - plt.xlabel('time (s)') + plt.plot( + self.timestamps, self.contrast_ratios_prev, ".-", label="contrast ratio" + ) + plt.title("Stability - Contrast ratio with previous frame") + plt.xlabel("time (s)") plt.show() -def measure_setup_stability(frame_source, sleep_time_s, num_of_frames, dark_frame, - do_save_frames=False) -> StabilityResult: +def measure_setup_stability( + frame_source, sleep_time_s, num_of_frames, dark_frame, do_save_frames=False +) -> StabilityResult: """Test the setup stability by repeatedly reading frames.""" first_frame = frame_source.read() prev_frame = first_frame @@ -260,14 +292,22 @@ def measure_setup_stability(frame_source, sleep_time_s, num_of_frames, dark_fram # Compare with first frame pixel_shifts_first[n, :] = find_pixel_shift(first_frame, new_frame) correlations_first[n] = pearson_correlation(first_frame, new_frame) - correlations_disattenuated_first[n] = pearson_correlation(first_frame, new_frame, noise_var=dark_var) - contrast_ratios_first[n] = contrast_enhancement(new_frame, first_frame, dark_frame) + correlations_disattenuated_first[n] = pearson_correlation( + first_frame, new_frame, noise_var=dark_var + ) + contrast_ratios_first[n] = contrast_enhancement( + new_frame, first_frame, dark_frame + ) # Compare with previous frame pixel_shifts_prev[n, :] = find_pixel_shift(prev_frame, new_frame) correlations_prev[n] = pearson_correlation(prev_frame, new_frame) - correlations_disattenuated_prev[n] = pearson_correlation(prev_frame, new_frame, noise_var=dark_var) - contrast_ratios_prev[n] = contrast_enhancement(new_frame, prev_frame, dark_frame) + correlations_disattenuated_prev[n] = pearson_correlation( + prev_frame, new_frame, noise_var=dark_var + ) + contrast_ratios_prev[n] = contrast_enhancement( + new_frame, prev_frame, dark_frame + ) abs_timestamps[n] = time.perf_counter() # Save frame if requested @@ -276,19 +316,23 @@ def measure_setup_stability(frame_source, sleep_time_s, num_of_frames, dark_fram prev_frame = new_frame - return StabilityResult(pixel_shifts_first=pixel_shifts_first, - correlations_first=correlations_first, - correlations_disattenuated_first=correlations_disattenuated_first, - contrast_ratios_first=contrast_ratios_first, - pixel_shifts_prev=pixel_shifts_prev, - correlations_prev=correlations_prev, - correlations_disattenuated_prev=correlations_disattenuated_prev, - contrast_ratios_prev=contrast_ratios_prev, - abs_timestamps=abs_timestamps, - framestack=framestack) - - -def measure_modulated_light_dual_phase_stepping(slm: PhaseSLM, feedback: Detector, phase_steps: int, num_blocks: int): + return StabilityResult( + pixel_shifts_first=pixel_shifts_first, + correlations_first=correlations_first, + correlations_disattenuated_first=correlations_disattenuated_first, + contrast_ratios_first=contrast_ratios_first, + pixel_shifts_prev=pixel_shifts_prev, + correlations_prev=correlations_prev, + correlations_disattenuated_prev=correlations_disattenuated_prev, + contrast_ratios_prev=contrast_ratios_prev, + abs_timestamps=abs_timestamps, + framestack=framestack, + ) + + +def measure_modulated_light_dual_phase_stepping( + slm: PhaseSLM, feedback: Detector, phase_steps: int, num_blocks: int +): """ Measure the ratio of modulated light with the dual phase stepping method. @@ -323,12 +367,16 @@ def measure_modulated_light_dual_phase_stepping(slm: PhaseSLM, feedback: Detecto measurements[p, q] = feedback.read() # 2D Fourier transform the modulation measurements - f = np.fft.fft2(measurements) / phase_steps ** 2 + f = np.fft.fft2(measurements) / phase_steps**2 # Compute fidelity factor due to modulated light eps = 1e-6 # Epsilon term to prevent division by zero - m1_m2_ratio = (np.abs(f[0, 1]) ** 2 + eps) / (np.abs(f[1, 0]) ** 2 + eps) # Ratio of modulated intensities - fidelity_modulated = (1 + m1_m2_ratio) / (1 + m1_m2_ratio + np.abs(f[0, 1]) ** 2 / np.abs(f[1, -1]) ** 2) + m1_m2_ratio = (np.abs(f[0, 1]) ** 2 + eps) / ( + np.abs(f[1, 0]) ** 2 + eps + ) # Ratio of modulated intensities + fidelity_modulated = (1 + m1_m2_ratio) / ( + 1 + m1_m2_ratio + np.abs(f[0, 1]) ** 2 / np.abs(f[1, -1]) ** 2 + ) return fidelity_modulated @@ -362,7 +410,9 @@ def measure_modulated_light(slm: PhaseSLM, feedback: Detector, phase_steps: int) f = np.fft.fft(measurements) # Compute ratio of modulated light over total - fidelity_modulated = 0.5 * (1.0 + np.sqrt(np.clip(1.0 - 4.0 * np.abs(f[1] / f[0]) ** 2, 0, None))) + fidelity_modulated = 0.5 * ( + 1.0 + np.sqrt(np.clip(1.0 - 4.0 * np.abs(f[1] / f[0]) ** 2, 0, None)) + ) return fidelity_modulated @@ -438,72 +488,88 @@ def report(self, do_plots=True): Args: do_plots (bool): Plot some results as graphs. """ - print(f'\n===========================') - print(f'{time.ctime(self.timestamp)}\n') - print(f'=== Feedback metrics ===') - print(f'number of modes (N): {self.wfs_result.n:.3f}') - print(f'fidelity_amplitude: {self.wfs_result.fidelity_amplitude.squeeze():.3f}') - print(f'fidelity_noise: {self.wfs_result.fidelity_noise.squeeze():.3f}') - print(f'fidelity_non_modulated: {self.fidelity_non_modulated:.3f}') - print(f'fidelity_phase_calibration: {self.wfs_result.fidelity_calibration.squeeze():.3f}') - print(f'fidelity_decorrelation: {self.fidelity_decorrelation:.3f}') - print(f'expected enhancement: {self.expected_enhancement:.3f}') - print(f'measured enhancement: {self.measured_enhancement:.3f}') - print(f'') - print(f'=== Frame metrics ===') - print(f'signal std, before: {self.frame_signal_std_before:.2f}') - print(f'signal std, after: {self.frame_signal_std_after:.2f}') - print(f'signal std, with shaped wavefront: {self.frame_signal_std_shaped_wf:.2f}') + print(f"\n===========================") + print(f"{time.ctime(self.timestamp)}\n") + print(f"=== Feedback metrics ===") + print(f"number of modes (N): {self.wfs_result.n:.3f}") + print(f"fidelity_amplitude: {self.wfs_result.fidelity_amplitude.squeeze():.3f}") + print(f"fidelity_noise: {self.wfs_result.fidelity_noise.squeeze():.3f}") + print(f"fidelity_non_modulated: {self.fidelity_non_modulated:.3f}") + print( + f"fidelity_phase_calibration: {self.wfs_result.fidelity_calibration.squeeze():.3f}" + ) + print(f"fidelity_decorrelation: {self.fidelity_decorrelation:.3f}") + print(f"expected enhancement: {self.expected_enhancement:.3f}") + print(f"measured enhancement: {self.measured_enhancement:.3f}") + print(f"") + print(f"=== Frame metrics ===") + print(f"signal std, before: {self.frame_signal_std_before:.2f}") + print(f"signal std, after: {self.frame_signal_std_after:.2f}") + print( + f"signal std, with shaped wavefront: {self.frame_signal_std_shaped_wf:.2f}" + ) if self.dark_frame is not None: - print(f'average offset (dark frame): {self.dark_frame.mean():.2f}') - print(f'median offset (dark frame): {np.median(self.dark_frame):.2f}') - print(f'noise std (dark frame): {np.std(self.dark_frame):.2f}') - print(f'frame repeatability: {self.frame_repeatability:.3f}') - print(f'contrast to noise ratio before: {self.frame_cnr_before:.3f}') - print(f'contrast to noise ratio after: {self.frame_cnr_after:.3f}') - print(f'contrast to noise ratio with shaped wavefront: {self.frame_cnr_shaped_wf:.3f}') - print(f'contrast enhancement: {self.frame_contrast_enhancement:.3f}') - print(f'photobleaching ratio: {self.frame_photobleaching_ratio:.3f}') + print(f"average offset (dark frame): {self.dark_frame.mean():.2f}") + print(f"median offset (dark frame): {np.median(self.dark_frame):.2f}") + print(f"noise std (dark frame): {np.std(self.dark_frame):.2f}") + print(f"frame repeatability: {self.frame_repeatability:.3f}") + print(f"contrast to noise ratio before: {self.frame_cnr_before:.3f}") + print(f"contrast to noise ratio after: {self.frame_cnr_after:.3f}") + print( + f"contrast to noise ratio with shaped wavefront: {self.frame_cnr_shaped_wf:.3f}" + ) + print(f"contrast enhancement: {self.frame_contrast_enhancement:.3f}") + print(f"photobleaching ratio: {self.frame_photobleaching_ratio:.3f}") if do_plots and self.stability is not None: self.stability.plot() - if (do_plots and self.dark_frame is not None and self.after_frame is not None and - self.shaped_wf_frame is not None): + if ( + do_plots + and self.dark_frame is not None + and self.after_frame is not None + and self.shaped_wf_frame is not None + ): max_value = max(self.after_frame.max(), self.shaped_wf_frame.max()) # Plot dark frame plt.figure() plt.imshow(self.dark_frame, vmin=0, vmax=max_value) - plt.title('Dark frame') + plt.title("Dark frame") plt.colorbar() - plt.xlabel('x (pix)') - plt.ylabel('y (pix)') + plt.xlabel("x (pix)") + plt.ylabel("y (pix)") plt.figure() # Plot after frame with flat wf plt.imshow(self.after_frame, vmin=0, vmax=max_value) - plt.title('Frame with flat wavefront') + plt.title("Frame with flat wavefront") plt.colorbar() - plt.xlabel('x (pix)') - plt.ylabel('y (pix)') + plt.xlabel("x (pix)") + plt.ylabel("y (pix)") # Plot shaped wf frame plt.figure() plt.imshow(self.shaped_wf_frame, vmin=0, vmax=max_value) - plt.title('Frame with shaped wavefront') + plt.title("Frame with shaped wavefront") plt.colorbar() - plt.xlabel('x (pix)') - plt.ylabel('y (pix)') + plt.xlabel("x (pix)") + plt.ylabel("y (pix)") plt.show() -def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detector, shutter, - do_frame_capture=True, do_long_stability_test=False, - stability_sleep_time_s=0.5, - stability_num_of_frames=500, - stability_do_save_frames=False, - measure_non_modulated_phase_steps=16) -> WFSTroubleshootResult: +def troubleshoot( + algorithm, + background_feedback: Detector, + frame_source: Detector, + shutter, + do_frame_capture=True, + do_long_stability_test=False, + stability_sleep_time_s=0.5, + stability_num_of_frames=500, + stability_do_save_frames=False, + measure_non_modulated_phase_steps=16, +) -> WFSTroubleshootResult: """ Run a series of basic checks to find common sources of error in a WFS experiment. Quantifies several types of fidelity reduction. @@ -532,7 +598,7 @@ def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detecto trouble = WFSTroubleshootResult() if do_frame_capture: - logging.info('Capturing frames before WFS...') + logging.info("Capturing frames before WFS...") # Capture frames before WFS algorithm.slm.set_phases(0.0) # Flat wavefront @@ -543,12 +609,16 @@ def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detecto before_frame_2 = frame_source.read() # Frame metrics - trouble.frame_signal_std_before = signal_std(trouble.before_frame, trouble.dark_frame) + trouble.frame_signal_std_before = signal_std( + trouble.before_frame, trouble.dark_frame + ) trouble.frame_cnr_before = cnr(trouble.before_frame, trouble.dark_frame) - trouble.frame_repeatability = pearson_correlation(trouble.before_frame, before_frame_2) + trouble.frame_repeatability = pearson_correlation( + trouble.before_frame, before_frame_2 + ) if do_long_stability_test and do_frame_capture: - logging.info('Run long stability test...') + logging.info("Run long stability test...") # Test setup stability trouble.stability = measure_setup_stability( @@ -556,12 +626,13 @@ def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detecto sleep_time_s=stability_sleep_time_s, num_of_frames=stability_num_of_frames, dark_frame=trouble.dark_frame, - do_save_frames=stability_do_save_frames) + do_save_frames=stability_do_save_frames, + ) trouble.feedback_before = algorithm.feedback.read() # WFS experiment - logging.info('Run WFS algorithm...') + logging.info("Run WFS algorithm...") trouble.wfs_result = algorithm.execute() # Execute WFS algorithm # Flat wavefront @@ -570,39 +641,61 @@ def troubleshoot(algorithm, background_feedback: Detector, frame_source: Detecto trouble.feedback_after = algorithm.feedback.read() if do_frame_capture: - logging.info('Capturing frames after WFS...') + logging.info("Capturing frames after WFS...") trouble.after_frame = frame_source.read() # After frame (flat wf) # Shaped wavefront algorithm.slm.set_phases(-np.angle(trouble.wfs_result.t)) trouble.feedback_shaped_wf = algorithm.feedback.read() - trouble.measured_enhancement = trouble.feedback_shaped_wf / trouble.average_background + trouble.measured_enhancement = ( + trouble.feedback_shaped_wf / trouble.average_background + ) if do_frame_capture: trouble.shaped_wf_frame = frame_source.read() # Shaped wavefront frame # Frame metrics - logging.info('Compute frame metrics...') - trouble.frame_signal_std_after = signal_std(trouble.after_frame, trouble.dark_frame) - trouble.frame_signal_std_shaped_wf = signal_std(trouble.shaped_wf_frame, trouble.dark_frame) - trouble.frame_cnr_after = cnr(trouble.after_frame, trouble.dark_frame) # Frame CNR after - trouble.frame_cnr_shaped_wf = cnr(trouble.shaped_wf_frame, trouble.dark_frame) # Frame CNR shaped wf - trouble.frame_contrast_enhancement = \ - contrast_enhancement(trouble.shaped_wf_frame, trouble.after_frame, trouble.dark_frame) - trouble.frame_photobleaching_ratio = \ - contrast_enhancement(trouble.after_frame, trouble.before_frame, trouble.dark_frame) - trouble.fidelity_decorrelation = \ - pearson_correlation(trouble.before_frame, trouble.after_frame, noise_var=trouble.dark_frame.var()) - - trouble.fidelity_non_modulated = \ - measure_modulated_light(slm=algorithm.slm, feedback=algorithm.feedback, - phase_steps=measure_non_modulated_phase_steps) + logging.info("Compute frame metrics...") + trouble.frame_signal_std_after = signal_std( + trouble.after_frame, trouble.dark_frame + ) + trouble.frame_signal_std_shaped_wf = signal_std( + trouble.shaped_wf_frame, trouble.dark_frame + ) + trouble.frame_cnr_after = cnr( + trouble.after_frame, trouble.dark_frame + ) # Frame CNR after + trouble.frame_cnr_shaped_wf = cnr( + trouble.shaped_wf_frame, trouble.dark_frame + ) # Frame CNR shaped wf + trouble.frame_contrast_enhancement = contrast_enhancement( + trouble.shaped_wf_frame, trouble.after_frame, trouble.dark_frame + ) + trouble.frame_photobleaching_ratio = contrast_enhancement( + trouble.after_frame, trouble.before_frame, trouble.dark_frame + ) + trouble.fidelity_decorrelation = pearson_correlation( + trouble.before_frame, + trouble.after_frame, + noise_var=trouble.dark_frame.var(), + ) + + trouble.fidelity_non_modulated = measure_modulated_light( + slm=algorithm.slm, + feedback=algorithm.feedback, + phase_steps=measure_non_modulated_phase_steps, + ) trouble.expected_enhancement = np.squeeze( - trouble.wfs_result.n * trouble.wfs_result.fidelity_amplitude * trouble.wfs_result.fidelity_noise - * trouble.fidelity_non_modulated * trouble.wfs_result.fidelity_calibration * trouble.fidelity_decorrelation) + trouble.wfs_result.n + * trouble.wfs_result.fidelity_amplitude + * trouble.wfs_result.fidelity_noise + * trouble.fidelity_non_modulated + * trouble.wfs_result.fidelity_calibration + * trouble.fidelity_decorrelation + ) # Analyze the WFS result - logging.info('Analyze WFS result...') + logging.info("Analyze WFS result...") return trouble diff --git a/openwfs/algorithms/utilities.py b/openwfs/algorithms/utilities.py index abd0ea4..f79b7f8 100644 --- a/openwfs/algorithms/utilities.py +++ b/openwfs/algorithms/utilities.py @@ -27,15 +27,17 @@ class WFSResult: This is the offset that is caused by a bias in the detector signal, stray light, etc. Default value: 0.0. """ - def __init__(self, - t: np.ndarray, - t_f: np.ndarray, - axis: int, - fidelity_noise: ArrayLike, - fidelity_amplitude: ArrayLike, - fidelity_calibration: ArrayLike, - n: Optional[int] = None, - intensity_offset: Optional[ArrayLike] = 0.0): + def __init__( + self, + t: np.ndarray, + t_f: np.ndarray, + axis: int, + fidelity_noise: ArrayLike, + fidelity_amplitude: ArrayLike, + fidelity_calibration: ArrayLike, + n: Optional[int] = None, + intensity_offset: Optional[ArrayLike] = 0.0, + ): """ Args: t(ndarray): measured transmission matrix. @@ -65,20 +67,39 @@ def __init__(self, self.fidelity_amplitude = np.atleast_1d(fidelity_amplitude) self.fidelity_calibration = np.atleast_1d(fidelity_calibration) self.estimated_enhancement = np.atleast_1d( - 1.0 + (self.n - 1) * self.fidelity_amplitude * self.fidelity_noise * self.fidelity_calibration) - self.intensity_offset = intensity_offset * np.ones(self.fidelity_calibration.shape) if np.isscalar( - intensity_offset) \ + 1.0 + + (self.n - 1) + * self.fidelity_amplitude + * self.fidelity_noise + * self.fidelity_calibration + ) + self.intensity_offset = ( + intensity_offset * np.ones(self.fidelity_calibration.shape) + if np.isscalar(intensity_offset) else intensity_offset - after = np.sum(np.abs(t), tuple( - range(self.axis))) ** 2 * self.fidelity_noise * self.fidelity_calibration + intensity_offset + ) + after = ( + np.sum(np.abs(t), tuple(range(self.axis))) ** 2 + * self.fidelity_noise + * self.fidelity_calibration + + intensity_offset + ) self.estimated_optimized_intensity = np.atleast_1d(after) def __str__(self) -> str: - noise_warning = "OK" if self.fidelity_noise > 0.5 else "WARNING low signal quality." - amplitude_warning = "OK" if self.fidelity_amplitude > 0.5 else "WARNING uneven contribution of optical modes." - calibration_fidelity_warning = "OK" if self.fidelity_calibration > 0.5 else ( - "WARNING non-linear phase response, check " - "lookup table.") + noise_warning = ( + "OK" if self.fidelity_noise > 0.5 else "WARNING low signal quality." + ) + amplitude_warning = ( + "OK" + if self.fidelity_amplitude > 0.5 + else "WARNING uneven contribution of optical modes." + ) + calibration_fidelity_warning = ( + "OK" + if self.fidelity_calibration > 0.5 + else ("WARNING non-linear phase response, check " "lookup table.") + ) return f""" Wavefront shaping results: fidelity_noise: {self.fidelity_noise} {noise_warning} @@ -88,7 +109,7 @@ def __str__(self) -> str: estimated_optimized_intensity: {self.estimated_optimized_intensity} """ - def select_target(self, b) -> 'WFSResult': + def select_target(self, b) -> "WFSResult": """ Returns the wavefront shaping results for a single target @@ -98,18 +119,19 @@ def select_target(self, b) -> 'WFSResult': Returns: WFSResults data for the specified target """ - return WFSResult(t=self.t.reshape((*self.t.shape[0:2], -1))[:, :, b], - t_f=self.t_f.reshape((*self.t_f.shape[0:2], -1))[:, :, b], - axis=self.axis, - intensity_offset=self.intensity_offset[:][b], - fidelity_noise=self.fidelity_noise[:][b], - fidelity_amplitude=self.fidelity_amplitude[:][b], - fidelity_calibration=self.fidelity_calibration[:][b], - n=self.n, - ) + return WFSResult( + t=self.t.reshape((*self.t.shape[0:2], -1))[:, :, b], + t_f=self.t_f.reshape((*self.t_f.shape[0:2], -1))[:, :, b], + axis=self.axis, + intensity_offset=self.intensity_offset[:][b], + fidelity_noise=self.fidelity_noise[:][b], + fidelity_amplitude=self.fidelity_amplitude[:][b], + fidelity_calibration=self.fidelity_calibration[:][b], + n=self.n, + ) @staticmethod - def combine(results: Sequence['WFSResult']): + def combine(results: Sequence["WFSResult"]): """Merges the results for several sub-experiments. Currently, this just computes the average of the fidelities, weighted @@ -129,16 +151,20 @@ def weighted_average(attribute): data += getattr(r, attribute) * r.n / n return data - return WFSResult(t=weighted_average('t'), - t_f=weighted_average('t_f'), - n=n, - axis=axis, - fidelity_noise=weighted_average('fidelity_noise'), - fidelity_amplitude=weighted_average('fidelity_amplitude'), - fidelity_calibration=weighted_average('fidelity_calibration')) + return WFSResult( + t=weighted_average("t"), + t_f=weighted_average("t_f"), + n=n, + axis=axis, + fidelity_noise=weighted_average("fidelity_noise"), + fidelity_amplitude=weighted_average("fidelity_amplitude"), + fidelity_calibration=weighted_average("fidelity_calibration"), + ) -def analyze_phase_stepping(measurements: np.ndarray, axis: int, A: Optional[float] = None): +def analyze_phase_stepping( + measurements: np.ndarray, axis: int, A: Optional[float] = None +): """Analyzes the result of phase stepping measurements, returning matrix `t` and noise statistics This function assumes that all measurements were made using the same reference field `A` @@ -195,7 +221,9 @@ def analyze_phase_stepping(measurements: np.ndarray, axis: int, A: Optional[floa # compute the effect of amplitude variations. # for perfectly developed speckle, and homogeneous illumination, this factor will be pi/4 - amplitude_factor = np.mean(np.abs(t), segments) ** 2 / np.mean(np.abs(t) ** 2, segments) + amplitude_factor = np.mean(np.abs(t), segments) ** 2 / np.mean( + np.abs(t) ** 2, segments + ) # estimate the calibration error # we first construct a matrix that can be used to fit @@ -224,15 +252,26 @@ def analyze_phase_stepping(measurements: np.ndarray, axis: int, A: Optional[floa if phase_steps > 3: # estimate the noise energy as the energy that is not explained # by the signal or the offset. - noise_energy = (total_energy - signal_energy - offset_energy) / (phase_steps - 3) - noise_factor = np.abs(np.maximum(signal_energy - 2 * noise_energy, 0.0) / signal_energy) + noise_energy = (total_energy - signal_energy - offset_energy) / ( + phase_steps - 3 + ) + noise_factor = np.abs( + np.maximum(signal_energy - 2 * noise_energy, 0.0) / signal_energy + ) else: noise_factor = 1.0 # cannot estimate reliably calibration_fidelity = np.abs(c[1]) ** 2 / np.sum(np.abs(c[1:]) ** 2) - return WFSResult(t, t_f=t_f, axis=axis, fidelity_amplitude=amplitude_factor, fidelity_noise=noise_factor, - fidelity_calibration=calibration_fidelity, n=n) + return WFSResult( + t, + t_f=t_f, + axis=axis, + fidelity_amplitude=amplitude_factor, + fidelity_noise=noise_factor, + fidelity_calibration=calibration_fidelity, + n=n, + ) class WFSController: @@ -298,7 +337,9 @@ def wavefront(self, value): self._amplitude_factor = result.fidelity_amplitude self._estimated_enhancement = result.estimated_enhancement self._calibration_fidelity = result.fidelity_calibration - self._estimated_optimized_intensity = result.estimated_optimized_intensity + self._estimated_optimized_intensity = ( + result.estimated_optimized_intensity + ) self._snr = 1.0 / (1.0 / result.fidelity_noise - 1.0) self._result = result self.algorithm.slm.set_phases(self._optimized_wavefront) @@ -357,12 +398,12 @@ def snr(self) -> float: @property def recompute_wavefront(self) -> bool: - """Returns: bool that indicates whether the wavefront needs to be recomputed. """ + """Returns: bool that indicates whether the wavefront needs to be recomputed.""" return self._recompute_wavefront @recompute_wavefront.setter def recompute_wavefront(self, value): - """Sets the bool that indicates whether the wavefront needs to be recomputed. """ + """Sets the bool that indicates whether the wavefront needs to be recomputed.""" self._recompute_wavefront = value @property @@ -389,6 +430,8 @@ def test_wavefront(self, value): feedback_flat = self.algorithm.feedback.read().copy() self.wavefront = WFSController.State.SHAPED_WAVEFRONT feedback_shaped = self.algorithm.feedback.read().copy() - self._feedback_enhancement = float(feedback_shaped.sum() / feedback_flat.sum()) + self._feedback_enhancement = float( + feedback_shaped.sum() / feedback_flat.sum() + ) self._test_wavefront = value diff --git a/openwfs/core.py b/openwfs/core.py index 3aba36a..3151216 100644 --- a/openwfs/core.py +++ b/openwfs/core.py @@ -17,12 +17,21 @@ class Device(ABC): """Base class for detectors and actuators - See :ref:`key_concepts` for more information. + See :ref:`key_concepts` for more information. """ - __slots__ = ('_end_time_ns', '_timeout_margin', '_locking_thread', '_error', - '__weakref__', '_latency', '_duration', '_multi_threaded') - _workers = ThreadPoolExecutor(thread_name_prefix='Device._workers') + + __slots__ = ( + "_end_time_ns", + "_timeout_margin", + "_locking_thread", + "_error", + "__weakref__", + "_latency", + "_duration", + "_multi_threaded", + ) + _workers = ThreadPoolExecutor(thread_name_prefix="Device._workers") _moving = False _state_lock = threading.Lock() _devices: "Set[Device]" = WeakSet() @@ -62,15 +71,25 @@ def _start(self): else: logging.debug("switch to MOVING requested by %s.", self) - same_type = [device for device in Device._devices if device._is_actuator == self._is_actuator] - other_type = [device for device in Device._devices if device._is_actuator != self._is_actuator] + same_type = [ + device + for device in Device._devices + if device._is_actuator == self._is_actuator + ] + other_type = [ + device + for device in Device._devices + if device._is_actuator != self._is_actuator + ] # compute the minimum latency of same_type # for instance, when switching to 'measuring', this number tells us how long it takes before any of the # detectors actually starts a measurement. # If this is a positive number, we can make the switch to 'measuring' slightly _before_ # all actuators have stabilized. - latency = min([device.latency for device in same_type], default=0.0 * u.ns) # noqa - incorrect warning + latency = min( + [device.latency for device in same_type], default=0.0 * u.ns + ) # noqa - incorrect warning # wait until all devices of the other type have (almost) finished for device in other_type: @@ -85,7 +104,11 @@ def _start(self): # also store the time we expect the operation to finish # note: it may finish slightly earlier since (latency + duration) is a maximum value - self._end_time_ns = time.time_ns() + self.latency.to_value(u.ns) + self.duration.to_value(u.ns) + self._end_time_ns = ( + time.time_ns() + + self.latency.to_value(u.ns) + + self.duration.to_value(u.ns) + ) @property def latency(self) -> Quantity[u.ms]: @@ -173,13 +196,15 @@ def wait(self, up_to: Optional[Quantity[u.ms]] = None) -> None: while self.busy(): time.sleep(0.01) if time.time_ns() - start > timeout: - raise TimeoutError("Timeout in %s (tid %i)", self, threading.get_ident()) + raise TimeoutError( + "Timeout in %s (tid %i)", self, threading.get_ident() + ) else: time_to_wait = self._end_time_ns - time.time_ns() if up_to is not None: time_to_wait -= up_to.to_value(u.ns) if time_to_wait > 0: - time.sleep(time_to_wait / 1.0E9) + time.sleep(time_to_wait / 1.0e9) def busy(self) -> bool: """Returns true if the device is measuring or moving (see `wait()`). @@ -210,8 +235,8 @@ def timeout(self, value): class Actuator(Device, ABC): - """Base class for all actuators - """ + """Base class for all actuators""" + __slots__ = () @final @@ -224,10 +249,23 @@ class Detector(Device, ABC): See :numref:`Detectors` in the documentation for more information. """ - __slots__ = ('_measurements_pending', '_lock_condition', '_pixel_size', '_data_shape') - def __init__(self, *, data_shape: Optional[tuple[int, ...]], pixel_size: Optional[Quantity], - duration: Optional[Quantity[u.ms]], latency: Optional[Quantity[u.ms]], multi_threaded: bool = True): + __slots__ = ( + "_measurements_pending", + "_lock_condition", + "_pixel_size", + "_data_shape", + ) + + def __init__( + self, + *, + data_shape: Optional[tuple[int, ...]], + pixel_size: Optional[Quantity], + duration: Optional[Quantity[u.ms]], + latency: Optional[Quantity[u.ms]], + multi_threaded: bool = True + ): """ Constructor for the Detector class. @@ -362,11 +400,19 @@ def __do_fetch(self, out_, *args_, **kwargs_): """Helper function that awaits all futures in the keyword argument list, and then calls _fetch""" try: if len(args_) > 0 or len(kwargs_) > 0: - logging.debug("awaiting inputs for %s (tid: %i).", self, threading.get_ident()) - awaited_args = [(arg.result() if isinstance(arg, Future) else arg) for arg in args_] - awaited_kwargs = {key: (arg.result() if isinstance(arg, Future) else arg) for (key, arg) in - kwargs_.items()} - logging.debug("fetching data of %s ((tid: %i)).", self, threading.get_ident()) + logging.debug( + "awaiting inputs for %s (tid: %i).", self, threading.get_ident() + ) + awaited_args = [ + (arg.result() if isinstance(arg, Future) else arg) for arg in args_ + ] + awaited_kwargs = { + key: (arg.result() if isinstance(arg, Future) else arg) + for (key, arg) in kwargs_.items() + } + logging.debug( + "fetching data of %s ((tid: %i)).", self, threading.get_ident() + ) data = self._fetch(*awaited_args, **awaited_kwargs) data = set_pixel_size(data, self.pixel_size) assert data.shape == self.data_shape @@ -404,7 +450,7 @@ def __setattr__(self, key, value): """ # note: the check needs to be in this order, otherwise we cannot initialize set _multi_threaded - if not key.startswith('_') and self._multi_threaded: + if not key.startswith("_") and self._multi_threaded: with self._lock_condition: while self._measurements_pending > 0: self._lock_condition.wait() @@ -477,10 +523,16 @@ def coordinates(self, dimension: int) -> Quantity: Args: dimension: Dimension for which to return the coordinates. """ - unit = u.dimensionless_unscaled if self.pixel_size is None else self.pixel_size[dimension] + unit = ( + u.dimensionless_unscaled + if self.pixel_size is None + else self.pixel_size[dimension] + ) shape = np.ones_like(self.data_shape) shape[dimension] = self.data_shape[dimension] - return np.arange(0.5, 0.5 + self.data_shape[dimension], 1.0).reshape(shape) * unit + return ( + np.arange(0.5, 0.5 + self.data_shape[dimension], 1.0).reshape(shape) * unit + ) @final @property @@ -520,19 +572,30 @@ def __init__(self, *args, multi_threaded: bool): # when the settings of one of the source detectors is changed. # Therefore, we pass 'None' for all parameters, and override # data_shape, pixel_size, duration and latency in the properties. - super().__init__(data_shape=None, pixel_size=None, duration=None, latency=None, multi_threaded=multi_threaded) + super().__init__( + data_shape=None, + pixel_size=None, + duration=None, + latency=None, + multi_threaded=multi_threaded, + ) def trigger(self, *args, immediate=False, **kwargs): """Triggers all sources at the same time (regardless of latency), and schedules a call to `_fetch()`""" - future_data = [(source.trigger(immediate=immediate) if source is not None else None) for source in - self._sources] + future_data = [ + (source.trigger(immediate=immediate) if source is not None else None) + for source in self._sources + ] return super().trigger(*future_data, *args, **kwargs) @final @property def latency(self) -> Quantity[u.ms]: """Returns the shortest latency for all detectors.""" - return min((source.latency for source in self._sources if source is not None), default=0.0 * u.ms) + return min( + (source.latency for source in self._sources if source is not None), + default=0.0 * u.ms, + ) @final @property @@ -543,11 +606,16 @@ def duration(self) -> Quantity[u.ms]: Note that `latency` is allowed to vary over time for devices that can only be triggered periodically, so this `duration` may also vary over time. """ - times = [(source.duration, source.latency) for source in self._sources if source is not None] + times = [ + (source.duration, source.latency) + for source in self._sources + if source is not None + ] if len(times) == 0: return 0.0 * u.ms - return (max([duration + latency for (duration, latency) in times]) - - min([latency for (duration, latency) in times])) + return max([duration + latency for (duration, latency) in times]) - min( + [latency for (duration, latency) in times] + ) @property def data_shape(self): @@ -561,8 +629,8 @@ def pixel_size(self) -> Optional[Quantity]: class PhaseSLM(ABC): - """Base class for phase-only SLMs - """ + """Base class for phase-only SLMs""" + __slots__ = () @abstractmethod diff --git a/openwfs/devices/camera.py b/openwfs/devices/camera.py index c28a7d9..3e35550 100644 --- a/openwfs/devices/camera.py +++ b/openwfs/devices/camera.py @@ -13,7 +13,8 @@ ```pip install harvesters``` Alternatively, specify the genicam dependency when installing openwfs: ```pip install openwfs[genicam]``` - """) + """ + ) from ..core import Detector @@ -42,29 +43,37 @@ class Camera(Detector): >>> camera = Camera(cti_file=R"C:\\Program Files\\Basler\\pylon 7\\Runtime\\x64\\ProducerU3V.cti") >>> camera.exposure_time = 10 * u.ms >>> frame = camera.read() + """ + + def __init__( + self, + cti_file: str, + serial_number: Optional[str] = None, + multi_threaded=True, + **kwargs, + ): """ - - def __init__(self, cti_file: str, serial_number: Optional[str] = None, multi_threaded=True, **kwargs): - """ - Initialize the Camera object. - - Args: - cti_file: The path to the GenTL producer file. - This path depends on where the driver for the camera is installed. - For Basler cameras, this is typically located in - R"C:\\Program Files\\Basler\\pylon 7\\Runtime\\x64\\ProducerU3V.cti". - - serial_number: The serial number of the camera. - When omitted, the first camera found is selected. - **kwargs: Additional keyword arguments. - These arguments are transferred to the node map of the camera. + Initialize the Camera object. + + Args: + cti_file: The path to the GenTL producer file. + This path depends on where the driver for the camera is installed. + For Basler cameras, this is typically located in + R"C:\\Program Files\\Basler\\pylon 7\\Runtime\\x64\\ProducerU3V.cti". + + serial_number: The serial number of the camera. + When omitted, the first camera found is selected. + **kwargs: Additional keyword arguments. + These arguments are transferred to the node map of the camera. """ self._harvester = Harvester() self._harvester.add_file(cti_file, check_validity=True) self._harvester.update() # open the camera, use the serial_number to select the camera if it is specified. - search_key = {'serial_number': serial_number} if serial_number is not None else None + search_key = ( + {"serial_number": serial_number} if serial_number is not None else None + ) self._camera = self._harvester.create(search_key=search_key) nodes = self._camera.remote_device.node_map @@ -72,10 +81,10 @@ def __init__(self, cti_file: str, serial_number: Optional[str] = None, multi_thr # set triggering to 'Software', so that we can trigger the camera by calling `trigger`. # turn off auto exposure so that `duration` accurately reflects the required measurement time. - nodes.TriggerMode.value = 'On' - nodes.TriggerSource.value = 'Software' - nodes.ExposureMode.value = 'Timed' - nodes.ExposureAuto.value = 'Off' + nodes.TriggerMode.value = "On" + nodes.TriggerSource.value = "Software" + nodes.ExposureMode.value = "Timed" + nodes.ExposureAuto.value = "Off" nodes.BinningHorizontal.value = 1 nodes.BinningVertical.value = 1 nodes.OffsetX.value = 0 @@ -104,22 +113,30 @@ def __init__(self, cti_file: str, serial_number: Optional[str] = None, multi_thr try: setattr(nodes, key, value) except AttributeError: - print(f'Warning: could not set camera property {key} to {value}') + print(f"Warning: could not set camera property {key} to {value}") try: - pixel_size = [nodes.SensorPixelHeight.value, nodes.SensorPixelWidth.value] * u.um + pixel_size = [ + nodes.SensorPixelHeight.value, + nodes.SensorPixelWidth.value, + ] * u.um except AttributeError: # the SensorPixelWidth feature is optional pixel_size = None - super().__init__(multi_threaded=multi_threaded, data_shape=None, pixel_size=pixel_size, duration=None, - latency=0.0 * u.ms) + super().__init__( + multi_threaded=multi_threaded, + data_shape=None, + pixel_size=pixel_size, + duration=None, + latency=0.0 * u.ms, + ) self._camera.start() def __del__(self): - if hasattr(self, '_camera'): + if hasattr(self, "_camera"): self._camera.stop() self._camera.destroy() - if hasattr(self, '_harvester'): + if hasattr(self, "_harvester"): self._harvester.reset() def _do_trigger(self): @@ -138,7 +155,7 @@ def _fetch(self, *args, **kwargs) -> np.ndarray: buffer = self._camera.fetch() frame = buffer.payload.components[0].data.reshape(self.data_shape) if frame.size == 0: - raise Exception('Camera returned an empty frame') + raise Exception("Camera returned an empty frame") data = frame.copy() buffer.queue() # give back buffer to the camera driver return data diff --git a/openwfs/devices/galvo_scanner.py b/openwfs/devices/galvo_scanner.py index 8fd2e7e..0b5db7c 100644 --- a/openwfs/devices/galvo_scanner.py +++ b/openwfs/devices/galvo_scanner.py @@ -19,7 +19,8 @@ ```pip install nidaqmx``` Alternatively, specify the genicam dependency when installing openwfs: ```pip install openwfs[nidaq]``` - """) + """ + ) from ..core import Detector from ..utilities import unitless @@ -36,6 +37,7 @@ class InputChannel: terminal_configuration: The terminal configuration of the channel, defaults to `TerminalConfiguration.DEFAULT` """ + channel: str v_min: Quantity[u.V] v_max: Quantity[u.V] @@ -65,11 +67,12 @@ class Axis: terminal_configuration: The terminal configuration of the channel, defaults to `TerminalConfiguration.DEFAULT` """ + channel: str v_min: Quantity[u.V] v_max: Quantity[u.V] scale: Quantity[u.um / u.V] - maximum_acceleration: Quantity[u.V / u.s ** 2] + maximum_acceleration: Quantity[u.V / u.s**2] terminal_configuration: TerminalConfiguration = TerminalConfiguration.DEFAULT def to_volt(self, pos: Union[np.ndarray, float]) -> Quantity[u.V]: @@ -103,16 +106,22 @@ def maximum_scan_speed(self, linear_range: float): Quantity[u.V / u.s]: maximum scan speed """ # x = 0.5 · a · t² = 0.5 (v_max - v_min) · (1 - linear_range) - t_accel = np.sqrt((self.v_max - self.v_min) * (1 - linear_range) / self.maximum_acceleration) + t_accel = np.sqrt( + (self.v_max - self.v_min) * (1 - linear_range) / self.maximum_acceleration + ) hardware_limit = t_accel * self.maximum_acceleration # t_linear = linear_range · (v_max - v_min) / maximum_speed # t_accel = maximum_speed / maximum_acceleration # 0.5·t_linear == t_accel => 0.5·linear_range · (v_max-v_min) · maximum_acceleration = maximum_speed² - practical_limit = np.sqrt(0.5 * linear_range * (self.v_max - self.v_min) * self.maximum_acceleration) + practical_limit = np.sqrt( + 0.5 * linear_range * (self.v_max - self.v_min) * self.maximum_acceleration + ) return np.minimum(hardware_limit, practical_limit) - def step(self, start: float, stop: float, sample_rate: Quantity[u.Hz]) -> Quantity[u.V]: + def step( + self, start: float, stop: float, sample_rate: Quantity[u.Hz] + ) -> Quantity[u.V]: """ Generate a voltage sequence to move from `start` to `stop` in the fastest way possible. @@ -138,15 +147,23 @@ def step(self, start: float, stop: float, sample_rate: Quantity[u.Hz]) -> Quanti # `t` is measured in samples # `a` is measured in volt/sample² - a = self.maximum_acceleration / sample_rate ** 2 * np.sign(v_end - v_start) + a = self.maximum_acceleration / sample_rate**2 * np.sign(v_end - v_start) t_total = unitless(2.0 * np.sqrt((v_end - v_start) / a)) - t = np.arange(np.ceil(t_total + 1E-6)) # add a small number to deal with case t=0 (start=end) - v_accel = v_start + 0.5 * a * t[:len(t) // 2] ** 2 # acceleration part - v_decel = v_end - 0.5 * a * (t_total - t[len(t) // 2:]) ** 2 # deceleration part + t = np.arange( + np.ceil(t_total + 1e-6) + ) # add a small number to deal with case t=0 (start=end) + v_accel = v_start + 0.5 * a * t[: len(t) // 2] ** 2 # acceleration part + v_decel = ( + v_end - 0.5 * a * (t_total - t[len(t) // 2 :]) ** 2 + ) # deceleration part v_decel[-1] = v_end # fix last point because t may be > t_total due to rounding - return np.clip(np.concatenate((v_accel, v_decel)), self.v_min, self.v_max) # noqa ignore incorrect type warning + return np.clip( + np.concatenate((v_accel, v_decel)), self.v_min, self.v_max + ) # noqa ignore incorrect type warning - def scan(self, start: float, stop: float, sample_count: int, sample_rate: Quantity[u.Hz]): + def scan( + self, start: float, stop: float, sample_count: int, sample_rate: Quantity[u.Hz] + ): """ Generate a voltage sequence to scan with a constant velocity from start to stop, including acceleration and deceleration. @@ -172,7 +189,12 @@ def scan(self, start: float, stop: float, sample_count: int, sample_rate: Quanti """ v_start = self.to_volt(start) if start == stop: # todo: tolerance? - return np.ones((sample_count,)) * v_start, start, start, slice(0, sample_count) + return ( + np.ones((sample_count,)) * v_start, + start, + start, + slice(0, sample_count), + ) v_end = self.to_volt(stop) scan_speed = (v_end - v_start) / sample_count # V per sample @@ -181,9 +203,13 @@ def scan(self, start: float, stop: float, sample_count: int, sample_rate: Quanti # we start by constructing a sequence with a maximum acceleration. # This sequence may be up to 1 sample longer than needed to reach the scan speed. # This last sample is replaced by movement at a linear scan speed - a = self.maximum_acceleration / sample_rate ** 2 * np.sign(scan_speed) # V per sample² + a = ( + self.maximum_acceleration / sample_rate**2 * np.sign(scan_speed) + ) # V per sample² t_launch = np.arange(np.ceil(unitless(scan_speed / a))) # in samples - v_accel = 0.5 * a * t_launch ** 2 # last sample may have faster scan speed than needed + v_accel = ( + 0.5 * a * t_launch**2 + ) # last sample may have faster scan speed than needed if len(v_accel) > 1 and np.abs(v_accel[-1] - v_accel[-2]) > np.abs(scan_speed): v_accel[-1] = v_accel[-2] + scan_speed v_launch = v_start - v_accel[-1] - 0.5 * scan_speed # launch point @@ -200,8 +226,13 @@ def scan(self, start: float, stop: float, sample_count: int, sample_rate: Quanti return v, launch, land, slice(len(v_accel), len(v_accel) + sample_count) @staticmethod - def compute_scale(*, optical_deflection: Quantity[u.deg / u.V], galvo_to_pupil_magnification: float, - objective_magnification: float, reference_tube_lens: Quantity[u.mm]) -> Quantity[u.um / u.V]: + def compute_scale( + *, + optical_deflection: Quantity[u.deg / u.V], + galvo_to_pupil_magnification: float, + objective_magnification: float, + reference_tube_lens: Quantity[u.mm], + ) -> Quantity[u.um / u.V]: """Computes the conversion factor between voltage and displacement in the object plane. Args: @@ -224,12 +255,18 @@ def compute_scale(*, optical_deflection: Quantity[u.deg / u.V], galvo_to_pupil_m """ f_objective = reference_tube_lens / objective_magnification angle_to_displacement = f_objective / u.rad - return ((optical_deflection / galvo_to_pupil_magnification) * angle_to_displacement).to(u.um / u.V) + return ( + (optical_deflection / galvo_to_pupil_magnification) * angle_to_displacement + ).to(u.um / u.V) @staticmethod - def compute_acceleration(*, optical_deflection: Quantity[u.deg / u.V], torque_constant: Quantity[u.N * u.m / u.A], - rotor_inertia: Quantity[u.kg * u.m ** 2], - maximum_current: Quantity[u.A]) -> Quantity[u.V / u.s ** 2]: + def compute_acceleration( + *, + optical_deflection: Quantity[u.deg / u.V], + torque_constant: Quantity[u.N * u.m / u.A], + rotor_inertia: Quantity[u.kg * u.m**2], + maximum_current: Quantity[u.A], + ) -> Quantity[u.V / u.s**2]: """Computes the angular acceleration of the focus of the galvo mirror. The result is returned in the unit V / second², @@ -247,16 +284,19 @@ def compute_acceleration(*, optical_deflection: Quantity[u.deg / u.V], torque_co maximum_current (Quantity[u.A]): The maximum current that can be applied to the galvo mirror. """ - angular_acceleration = (torque_constant * maximum_current / rotor_inertia).to(u.s ** -2) * u.rad - return (angular_acceleration / optical_deflection).to(u.V / u.s ** 2) + angular_acceleration = (torque_constant * maximum_current / rotor_inertia).to( + u.s**-2 + ) * u.rad + return (angular_acceleration / optical_deflection).to(u.V / u.s**2) class TestPatternType(Enum): """Type of test pattern to use for simulation.""" - NONE = 'none' - HORIZONTAL = 'horizontal' - VERTICAL = 'vertical' - IMAGE = 'image' + + NONE = "none" + HORIZONTAL = "horizontal" + VERTICAL = "vertical" + IMAGE = "image" class ScanningMicroscope(Detector): @@ -308,19 +348,22 @@ class ScanningMicroscope(Detector): parameter can be used. """ - def __init__(self, - input: InputChannel, - y_axis: Axis, - x_axis: Axis, - sample_rate: Quantity[u.MHz], - resolution: int, - reference_zoom: float, *, - delay: Quantity[u.us] = 0.0 * u.us, - bidirectional: bool = True, - multi_threaded: bool = True, - preprocessor: Optional[callable] = None, - test_pattern: Union[TestPatternType, str] = TestPatternType.NONE, - test_image=None): + def __init__( + self, + input: InputChannel, + y_axis: Axis, + x_axis: Axis, + sample_rate: Quantity[u.MHz], + resolution: int, + reference_zoom: float, + *, + delay: Quantity[u.us] = 0.0 * u.us, + bidirectional: bool = True, + multi_threaded: bool = True, + preprocessor: Optional[callable] = None, + test_pattern: Union[TestPatternType, str] = TestPatternType.NONE, + test_image=None, + ): """ Args: resolution: number of pixels (height and width) in the full field of view. @@ -353,8 +396,12 @@ def __init__(self, self._resolution = int(resolution) self._roi_top = 0 # in pixels self._roi_left = 0 # in pixels - self._center_x = 0.5 # in relative coordinates (relative to the full field of view) - self._center_y = 0.5 # in relative coordinates (relative to the full field of view) + self._center_x = ( + 0.5 # in relative coordinates (relative to the full field of view) + ) + self._center_y = ( + 0.5 # in relative coordinates (relative to the full field of view) + ) self._delay = delay.to(u.us) self._reference_zoom = float(reference_zoom) self._zoom = 1.0 @@ -365,9 +412,9 @@ def __init__(self, self._test_pattern = TestPatternType(test_pattern) self._test_image = None if test_image is not None: - self._test_image = np.array(test_image, dtype='uint16') + self._test_image = np.array(test_image, dtype="uint16") while self._test_image.ndim > 2: - self._test_image = np.mean(self._test_image, 2).astype('uint16') + self._test_image = np.mean(self._test_image, 2).astype("uint16") self._preprocessor = preprocessor @@ -379,9 +426,13 @@ def __init__(self, # the pixel size and duration are computed dynamically # data_shape just returns self._data shape, and latency = 0.0 ms - super().__init__(data_shape=(resolution, resolution), pixel_size=None, duration=None, - latency=0.0 * u.ms, - multi_threaded=multi_threaded) + super().__init__( + data_shape=(resolution, resolution), + pixel_size=None, + duration=None, + latency=0.0 * u.ms, + multi_threaded=multi_threaded, + ) self._update() def _update(self): @@ -411,7 +462,9 @@ def _update(self): # Compute the retrace pattern for the slow axis # The scan starts at half a pixel after roi_bottom and ends half a pixel before roi_top - v_yr = self._y_axis.step(roi_bottom - 0.5 * roi_scale, roi_top + 0.5 * roi_scale, self._sample_rate) + v_yr = self._y_axis.step( + roi_bottom - 0.5 * roi_scale, roi_top + 0.5 * roi_scale, self._sample_rate + ) # Compute the scan pattern for the fast axis # The naive speed is the scan speed assuming one pixel per sample @@ -419,22 +472,33 @@ def _update(self): # (at least, without spending more time on accelerating and decelerating than the scan itself) # The user can set the scan speed relative to the maximum speed. # If this set speed is lower than naive scan speed, multiple samples are taken per pixel. - naive_speed = (self._x_axis.v_max - self._x_axis.v_min) * roi_scale * self._sample_rate - max_speed = self._x_axis.maximum_scan_speed(1.0 / actual_zoom) * self._scan_speed_factor + naive_speed = ( + (self._x_axis.v_max - self._x_axis.v_min) * roi_scale * self._sample_rate + ) + max_speed = ( + self._x_axis.maximum_scan_speed(1.0 / actual_zoom) * self._scan_speed_factor + ) if max_speed == 0.0: # this may happen if the ROI reaches to or beyond [0,1]. In this case, the mirror has no time to accelerate # TODO: implement an auto-adjust option instead of raising an error - raise ValueError("Maximum scan speed is zero. " - "This may be because the region of interest exceeds the maximum voltage range") + raise ValueError( + "Maximum scan speed is zero. " + "This may be because the region of interest exceeds the maximum voltage range" + ) self._oversampling = int(np.ceil(unitless(naive_speed / max_speed))) oversampled_width = width * self._oversampling - v_x_even, x_launch, x_land, self._mask = self._x_axis.scan(roi_left, roi_right, oversampled_width, - self._sample_rate) + v_x_even, x_launch, x_land, self._mask = self._x_axis.scan( + roi_left, roi_right, oversampled_width, self._sample_rate + ) if self._bidirectional: - v_x_odd, _, _, _ = self._x_axis.scan(roi_right, roi_left, oversampled_width, self._sample_rate) + v_x_odd, _, _, _ = self._x_axis.scan( + roi_right, roi_left, oversampled_width, self._sample_rate + ) else: - v_xr = self._x_axis.step(x_land, x_launch, self._sample_rate) # horizontal retrace + v_xr = self._x_axis.step( + x_land, x_launch, self._sample_rate + ) # horizontal retrace v_x_even = np.concatenate((v_x_even, v_xr)) v_x_odd = v_x_even @@ -444,7 +508,7 @@ def _update(self): # For bidirectional mode, the scan pattern is padded to always have an even number of scan lines # The horizontal pattern is repeated continuously, so even during the # vertical retrace. In bidirectional scan mode, th - n_rows = self._data_shape[0] + np.ceil(len(v_yr) / len(v_x_odd)).astype('int32') + n_rows = self._data_shape[0] + np.ceil(len(v_yr) / len(v_x_odd)).astype("int32") self._n_cols = len(v_x_odd) if self._bidirectional and n_rows % 2 == 1: n_rows += 1 @@ -464,8 +528,8 @@ def _update(self): # which is essential for resonant scanning. if len(v_yr) > 0: retrace = scan_pattern[0, height:, :].reshape(-1) - retrace[0:len(v_yr)] = v_yr - retrace[len(v_yr):] = v_yr[-1] + retrace[0 : len(v_yr)] = v_yr + retrace[len(v_yr) :] = v_yr[-1] self._scan_pattern = scan_pattern.reshape(2, -1) if self._test_pattern != TestPatternType.NONE: @@ -489,25 +553,39 @@ def _update(self): sample_count = self._scan_pattern.shape[1] # Configure the analog output task (two channels) - self._write_task.ao_channels.add_ao_voltage_chan(self._x_axis.channel, - min_val=self._x_axis.v_min.to_value(u.V), - max_val=self._x_axis.v_max.to_value(u.V)) - self._write_task.ao_channels.add_ao_voltage_chan(self._y_axis.channel, - min_val=self._y_axis.v_min.to_value(u.V), - max_val=self._y_axis.v_max.to_value(u.V)) - self._write_task.timing.cfg_samp_clk_timing(sample_rate, samps_per_chan=sample_count) + self._write_task.ao_channels.add_ao_voltage_chan( + self._x_axis.channel, + min_val=self._x_axis.v_min.to_value(u.V), + max_val=self._x_axis.v_max.to_value(u.V), + ) + self._write_task.ao_channels.add_ao_voltage_chan( + self._y_axis.channel, + min_val=self._y_axis.v_min.to_value(u.V), + max_val=self._y_axis.v_max.to_value(u.V), + ) + self._write_task.timing.cfg_samp_clk_timing( + sample_rate, samps_per_chan=sample_count + ) # Configure the analog input task (one channel) - self._read_task.ai_channels.add_ai_voltage_chan(self._input_channel.channel, - min_val=self._input_channel.v_min.to_value(u.V), - max_val=self._input_channel.v_max.to_value(u.V), - terminal_config=self._input_channel.terminal_configuration) - self._read_task.timing.cfg_samp_clk_timing(sample_rate, samps_per_chan=sample_count) - self._read_task.triggers.start_trigger.cfg_dig_edge_start_trig(self._write_task.triggers.start_trigger.term) + self._read_task.ai_channels.add_ai_voltage_chan( + self._input_channel.channel, + min_val=self._input_channel.v_min.to_value(u.V), + max_val=self._input_channel.v_max.to_value(u.V), + terminal_config=self._input_channel.terminal_configuration, + ) + self._read_task.timing.cfg_samp_clk_timing( + sample_rate, samps_per_chan=sample_count + ) + self._read_task.triggers.start_trigger.cfg_dig_edge_start_trig( + self._write_task.triggers.start_trigger.term + ) delay = self._delay.to_value(u.s) if delay > 0.0: self._read_task.triggers.start_trigger.delay = delay - self._read_task.triggers.start_trigger.delay_units = DigitalWidthUnits.SECONDS + self._read_task.triggers.start_trigger.delay_units = ( + DigitalWidthUnits.SECONDS + ) self._writer = AnalogMultiChannelWriter(self._write_task.out_stream) self._valid = True @@ -541,21 +619,27 @@ def _raw_to_cropped(self, raw: np.ndarray) -> np.ndarray: flips the even rows back if scanned in bidirectional mode. """ # convert data to 2-d, discard padding - cropped = raw.reshape(-1, self._n_cols)[:self._data_shape[0], self._mask] + cropped = raw.reshape(-1, self._n_cols)[: self._data_shape[0], self._mask] # down sample along fast axis if needed if self._oversampling > 1: # remove samples if not divisible by oversampling factor - cropped = cropped[:, :(cropped.shape[1] // self._oversampling) * self._oversampling] + cropped = cropped[ + :, : (cropped.shape[1] // self._oversampling) * self._oversampling + ] cropped = cropped.reshape(cropped.shape[0], -1, self._oversampling) - cropped = np.round(np.mean(cropped, 2)).astype(cropped.dtype) # todo: faster alternative? + cropped = np.round(np.mean(cropped, 2)).astype( + cropped.dtype + ) # todo: faster alternative? # Change the data type into uint16 if necessary if cropped.dtype == np.int16: # add 32768 to go from -32768-32767 to 0-65535 - cropped = cropped.view('uint16') + 0x8000 + cropped = cropped.view("uint16") + 0x8000 elif cropped.dtype != np.uint16: - raise ValueError(f'Only int16 and uint16 data types are supported at the moment, got type {cropped.dtype}.') + raise ValueError( + f"Only int16 and uint16 data types are supported at the moment, got type {cropped.dtype}." + ) if self._bidirectional: # note: requires the mask to be symmetrical cropped[1::2, :] = cropped[1::2, ::-1] @@ -569,31 +653,43 @@ def _fetch(self) -> np.ndarray: # noqa self._read_task.stop() self._write_task.stop() elif self._test_pattern == TestPatternType.HORIZONTAL: - raw = np.round(self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) * 10000).astype('int16') + raw = np.round( + self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) * 10000 + ).astype("int16") elif self._test_pattern == TestPatternType.VERTICAL: - raw = np.round(self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) * 10000).astype('int16') + raw = np.round( + self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) * 10000 + ).astype("int16") elif self._test_pattern == TestPatternType.IMAGE: if self._test_image is None: - raise ValueError('No test image was provided for the image simulation.') + raise ValueError("No test image was provided for the image simulation.") # todo: cache the test image row = np.floor( - self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) * (self._test_image.shape[0] - 1)).astype( - 'int32') + self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) + * (self._test_image.shape[0] - 1) + ).astype("int32") column = np.floor( - self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) * (self._test_image.shape[1] - 1)).astype( - 'int32') + self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) + * (self._test_image.shape[1] - 1) + ).astype("int32") raw = self._test_image[row, column] else: - raise ValueError(f"Invalid simulation option {self._test_pattern}. " - "Should be 'horizontal', 'vertical', 'image', or 'None'") + raise ValueError( + f"Invalid simulation option {self._test_pattern}. " + "Should be 'horizontal', 'vertical', 'image', or 'None'" + ) # Preprocess raw data if a preprocess function is set if self._preprocessor is None: preprocessed_raw = raw elif callable(self._preprocessor): - preprocessed_raw = self._preprocessor(data=raw, sample_rate=self._sample_rate) + preprocessed_raw = self._preprocessor( + data=raw, sample_rate=self._sample_rate + ) else: - raise TypeError(f"Invalid type for {self._preprocessor}. Should be callable or None.") + raise TypeError( + f"Invalid type for {self._preprocessor}. Should be callable or None." + ) return self._raw_to_cropped(preprocessed_raw) def close(self): @@ -622,7 +718,9 @@ def preprocessor(self): @preprocessor.setter def preprocessor(self, value: Optional[callable]): if not callable(value) and value is not None: - raise TypeError(f"Invalid type for {self._preprocessor}. Should be callable or None.") + raise TypeError( + f"Invalid type for {self._preprocessor}. Should be callable or None." + ) self._preprocessor = value @property @@ -631,8 +729,10 @@ def pixel_size(self) -> Quantity: # TODO: make extent a read-only attribute of Axis extent_y = (self._y_axis.v_max - self._y_axis.v_min) * self._y_axis.scale extent_x = (self._x_axis.v_max - self._x_axis.v_min) * self._x_axis.scale - return (Quantity(extent_y, extent_x) / ( - self._reference_zoom * self._zoom * self._resolution)).to(u.um) + return ( + Quantity(extent_y, extent_x) + / (self._reference_zoom * self._zoom * self._resolution) + ).to(u.um) @property def duration(self) -> Quantity[u.ms]: @@ -832,7 +932,7 @@ def binning(self) -> int: @binning.setter def binning(self, value: int): if value < 1: - raise ValueError('Binning value should be a positive integer') + raise ValueError("Binning value should be a positive integer") self._scale_roi(self._binning / int(value)) self._binning = int(value) diff --git a/openwfs/devices/nidaq_gain.py b/openwfs/devices/nidaq_gain.py index e2971e3..b0cb23c 100644 --- a/openwfs/devices/nidaq_gain.py +++ b/openwfs/devices/nidaq_gain.py @@ -1,9 +1,9 @@ -import nidaqmx as ni -from nidaqmx.constants import LineGrouping -from typing import Annotated import time + import astropy.units as u +import nidaqmx as ni from astropy.units import Quantity +from nidaqmx.constants import LineGrouping class Gain: @@ -55,7 +55,9 @@ def check_overload(self): def on_reset(self, value): if value: with ni.Task() as task: - task.do_channels.add_do_chan(self.port_do, line_grouping=LineGrouping.CHAN_FOR_ALL_LINES) + task.do_channels.add_do_chan( + self.port_do, line_grouping=LineGrouping.CHAN_FOR_ALL_LINES + ) task.write([True]) time.sleep(1) task.write([False]) @@ -84,4 +86,3 @@ def gain(self, value: Quantity[u.V]): channel.ao_min = 0 channel.ao_max = 0.9 write_task.write(self._gain.to_value(u.V)) - diff --git a/openwfs/devices/slm/context.py b/openwfs/devices/slm/context.py index b36764b..0e33cf7 100644 --- a/openwfs/devices/slm/context.py +++ b/openwfs/devices/slm/context.py @@ -3,7 +3,7 @@ import glfw -SLM = 'slm.SLM' +SLM = "slm.SLM" class Context: @@ -15,6 +15,7 @@ class Context: one thread can use OpenGL at the same time. This class holds a weak ref to the SLM object, so that the SLM object can be garbage collected. """ + _lock = threading.RLock() def __init__(self, slm): diff --git a/openwfs/devices/slm/geometry.py b/openwfs/devices/slm/geometry.py index 0506cbe..b6f73a1 100644 --- a/openwfs/devices/slm/geometry.py +++ b/openwfs/devices/slm/geometry.py @@ -110,15 +110,26 @@ def rectangle(extent: ExtentType, center: CoordinateType = (0, 0)) -> Geometry: right = center[1] + 0.5 * extent[1] bottom = center[0] + 0.5 * extent[0] - vertices = np.array(([left, top, 0.0, 0.0], [right, top, 1.0, 0.0], - [left, bottom, 0.0, 1.0], [right, bottom, 1.0, 1.0]), dtype=np.float32) + vertices = np.array( + ( + [left, top, 0.0, 0.0], + [right, top, 1.0, 0.0], + [left, bottom, 0.0, 1.0], + [right, bottom, 1.0, 1.0], + ), + dtype=np.float32, + ) indices = Geometry.compute_indices_for_grid((1, 1)) return Geometry(vertices, indices) -def circular(radii: Sequence[float], segments_per_ring: Sequence[int], edge_count: int = 256, - center: CoordinateType = (0, 0)) -> Geometry: +def circular( + radii: Sequence[float], + segments_per_ring: Sequence[int], + edge_count: int = 256, + center: CoordinateType = (0, 0), +) -> Geometry: """Creates a circular geometry with the specified extent. This geometry maps a texture to a disk or a ring. @@ -153,7 +164,8 @@ def circular(radii: Sequence[float], segments_per_ring: Sequence[int], edge_coun if len(segments_per_ring) != ring_count: raise ValueError( "The length of `radii` and `segments_per_ring` should both equal the number of rings (counting " - "the inner disk as the first ring).") + "the inner disk as the first ring)." + ) # construct coordinates of points on a circle of radius 1.0 # the start and end point coincide @@ -172,15 +184,19 @@ def circular(radii: Sequence[float], segments_per_ring: Sequence[int], edge_coun segments_inside = 0 total_segments = np.sum(segments_per_ring) for r in range(ring_count): - x_outside = x * radii[r + 1] # coordinates of the vertices at the outside of the ring + x_outside = ( + x * radii[r + 1] + ) # coordinates of the vertices at the outside of the ring y_outside = y * radii[r + 1] segments = segments_inside + segments_per_ring[r] vertices[r, 0, :, 0] = x_inside + center[1] vertices[r, 0, :, 1] = y_inside + center[0] vertices[r, 1, :, 0] = x_outside + center[1] vertices[r, 1, :, 1] = y_outside + center[0] - vertices[r, :, :, 2] = np.linspace(segments_inside, segments, edge_count + 1).reshape( - (1, -1)) / total_segments # tx + vertices[r, :, :, 2] = ( + np.linspace(segments_inside, segments, edge_count + 1).reshape((1, -1)) + / total_segments + ) # tx x_inside = x_outside y_inside = y_outside segments_inside = segments @@ -190,6 +206,9 @@ def circular(radii: Sequence[float], segments_per_ring: Sequence[int], edge_coun # construct indices for a single ring, and repeat for all rings with the appropriate offset indices = Geometry.compute_indices_for_grid((1, edge_count)).reshape((1, -1)) - indices = indices + np.arange(ring_count).reshape((-1, 1)) * vertices.shape[1] * vertices.shape[2] + indices = ( + indices + + np.arange(ring_count).reshape((-1, 1)) * vertices.shape[1] * vertices.shape[2] + ) indices[:, -1] = 0xFFFF return Geometry(vertices.reshape((-1, 4)), indices.reshape(-1)) diff --git a/openwfs/devices/slm/patch.py b/openwfs/devices/slm/patch.py index a32c14c..7c55840 100644 --- a/openwfs/devices/slm/patch.py +++ b/openwfs/devices/slm/patch.py @@ -8,18 +8,44 @@ try: import OpenGL.GL as GL - from OpenGL.GL import glGenBuffers, glBindBuffer, glBufferData, glDeleteBuffers, glEnable, glBlendFunc, \ - glBlendEquation, glDisable, glUseProgram, glBindVertexBuffer, glDrawElements, glGenFramebuffers, \ - glBindFramebuffer, glFramebufferTexture2D, glCheckFramebufferStatus, glDeleteFramebuffers, \ - glEnableVertexAttribArray, glVertexAttribFormat, glVertexAttribBinding, glEnableVertexAttribArray, \ - glPrimitiveRestartIndex, glActiveTexture, glBindTexture, glGenVertexArrays, glBindVertexArray + from OpenGL.GL import ( + glGenBuffers, + glBindBuffer, + glBufferData, + glDeleteBuffers, + glEnable, + glBlendFunc, + glBlendEquation, + glDisable, + glUseProgram, + glBindVertexBuffer, + glDrawElements, + glGenFramebuffers, + glBindFramebuffer, + glFramebufferTexture2D, + glCheckFramebufferStatus, + glDeleteFramebuffers, + glEnableVertexAttribArray, + glVertexAttribFormat, + glVertexAttribBinding, + glEnableVertexAttribArray, + glPrimitiveRestartIndex, + glActiveTexture, + glBindTexture, + glGenVertexArrays, + glBindVertexArray, + ) from OpenGL.GL import shaders except AttributeError: warnings.warn("OpenGL not found, SLM will not work") from .geometry import rectangle, Geometry -from .shaders import default_vertex_shader, default_fragment_shader, \ - post_process_fragment_shader, post_process_vertex_shader +from .shaders import ( + default_vertex_shader, + default_fragment_shader, + post_process_fragment_shader, + post_process_vertex_shader, +) from .texture import Texture from ...core import PhaseSLM @@ -27,8 +53,13 @@ class Patch(PhaseSLM): _PHASES_TEXTURE = 0 # indices of the phases texture in the _texture array - def __init__(self, slm, geometry=None, vertex_shader=default_vertex_shader, - fragment_shader=default_fragment_shader): + def __init__( + self, + slm, + geometry=None, + vertex_shader=default_vertex_shader, + fragment_shader=default_fragment_shader, + ): """ Constructs a new patch (a shape) that can be drawn on the screen. By default, the patch is a square with 'radius' 1.0 (width and height 2.0) centered at 0.0, 0.0 @@ -81,7 +112,9 @@ def _draw(self): # perform the actual drawing glBindBuffer(GL.GL_ELEMENT_ARRAY_BUFFER, self._indices) glBindVertexBuffer(0, self._vertices, 0, 16) - glDrawElements(GL.GL_TRIANGLE_STRIP, self._index_count, GL.GL_UNSIGNED_SHORT, None) + glDrawElements( + GL.GL_TRIANGLE_STRIP, self._index_count, GL.GL_UNSIGNED_SHORT, None + ) def set_phases(self, values: ArrayLike, update=True): """ @@ -123,9 +156,19 @@ def geometry(self, value: Geometry): (self._vertices, self._indices) = glGenBuffers(2) self._index_count = value.indices.size glBindBuffer(GL.GL_ARRAY_BUFFER, self._vertices) - glBufferData(GL.GL_ARRAY_BUFFER, value.vertices.size * 4, value.vertices, GL.GL_DYNAMIC_DRAW) + glBufferData( + GL.GL_ARRAY_BUFFER, + value.vertices.size * 4, + value.vertices, + GL.GL_DYNAMIC_DRAW, + ) glBindBuffer(GL.GL_ELEMENT_ARRAY_BUFFER, self._indices) - glBufferData(GL.GL_ELEMENT_ARRAY_BUFFER, value.indices.size * 2, value.indices, GL.GL_DYNAMIC_DRAW) + glBufferData( + GL.GL_ELEMENT_ARRAY_BUFFER, + value.indices.size * 2, + value.indices, + GL.GL_DYNAMIC_DRAW, + ) class FrameBufferPatch(Patch): @@ -137,22 +180,34 @@ class FrameBufferPatch(Patch): _textures: list[Texture] def __init__(self, slm, lookup_table: Sequence[int]): - super().__init__(slm, fragment_shader=post_process_fragment_shader, - vertex_shader=post_process_vertex_shader) + super().__init__( + slm, + fragment_shader=post_process_fragment_shader, + vertex_shader=post_process_vertex_shader, + ) # Create a frame buffer object to render to. The frame buffer holds a texture that is the same size as the # window. All patches are first rendered to this texture. The texture # is then processed as a whole (applying the software lookup table) and displayed on the screen. self._frame_buffer = glGenFramebuffers(1) - self.set_phases(np.zeros(self.context.slm.shape, dtype=np.float32), update=False) + self.set_phases( + np.zeros(self.context.slm.shape, dtype=np.float32), update=False + ) glBindFramebuffer(GL.GL_FRAMEBUFFER, self._frame_buffer) - glFramebufferTexture2D(GL.GL_FRAMEBUFFER, GL.GL_COLOR_ATTACHMENT0, GL.GL_TEXTURE_2D, - self._textures[Patch._PHASES_TEXTURE].handle, 0) + glFramebufferTexture2D( + GL.GL_FRAMEBUFFER, + GL.GL_COLOR_ATTACHMENT0, + GL.GL_TEXTURE_2D, + self._textures[Patch._PHASES_TEXTURE].handle, + 0, + ) if glCheckFramebufferStatus(GL.GL_FRAMEBUFFER) != GL.GL_FRAMEBUFFER_COMPLETE: raise Exception("Could not construct frame buffer") glBindFramebuffer(GL.GL_FRAMEBUFFER, 0) - self._textures.append(Texture(self.context, GL.GL_TEXTURE_1D)) # create texture for lookup table + self._textures.append( + Texture(self.context, GL.GL_TEXTURE_1D) + ) # create texture for lookup table self._lookup_table = None self.lookup_table = lookup_table self.additive_blend = False @@ -164,7 +219,7 @@ def __del__(self): @property def lookup_table(self): - """1-D array """ + """1-D array""" return self._lookup_table @lookup_table.setter @@ -195,15 +250,25 @@ class VertexArray: # Since we have a fixed vertex format, we only need to bind the VertexArray once, and not bother with # updating, binding, or even deleting it def __init__(self): - self._vertex_array = glGenVertexArrays(1) # no need to destroy explicitly, destroyed when window is destroyed + self._vertex_array = glGenVertexArrays( + 1 + ) # no need to destroy explicitly, destroyed when window is destroyed glBindVertexArray(self._vertex_array) glEnableVertexAttribArray(0) glEnableVertexAttribArray(1) - glVertexAttribFormat(0, 2, GL.GL_FLOAT, GL.GL_FALSE, 0) # first two float32 are screen coordinates - glVertexAttribFormat(1, 2, GL.GL_FLOAT, GL.GL_FALSE, 8) # second two are texture coordinates + glVertexAttribFormat( + 0, 2, GL.GL_FLOAT, GL.GL_FALSE, 0 + ) # first two float32 are screen coordinates + glVertexAttribFormat( + 1, 2, GL.GL_FLOAT, GL.GL_FALSE, 8 + ) # second two are texture coordinates glVertexAttribBinding(0, 0) # use binding index 0 for both attributes - glVertexAttribBinding(1, 0) # the attribute format can now be used with glBindVertexBuffer + glVertexAttribBinding( + 1, 0 + ) # the attribute format can now be used with glBindVertexBuffer # enable primitive restart, so that we can draw multiple triangle strips with a single draw call glEnable(GL.GL_PRIMITIVE_RESTART) - glPrimitiveRestartIndex(0xFFFF) # this is the index we use to separate individual triangle strips + glPrimitiveRestartIndex( + 0xFFFF + ) # this is the index we use to separate individual triangle strips diff --git a/openwfs/devices/slm/slm.py b/openwfs/devices/slm/slm.py index e177069..d1a4750 100644 --- a/openwfs/devices/slm/slm.py +++ b/openwfs/devices/slm/slm.py @@ -13,8 +13,19 @@ try: import OpenGL.GL as GL - from OpenGL.GL import glViewport, glClearColor, glClear, glGenBuffers, glReadBuffer, glReadPixels, glFinish, \ - glBindBuffer, glBufferData, glBindBufferBase, glBindFramebuffer + from OpenGL.GL import ( + glViewport, + glClearColor, + glClear, + glGenBuffers, + glReadBuffer, + glReadPixels, + glFinish, + glBindBuffer, + glBufferData, + glBindBufferBase, + glBindFramebuffer, + ) except AttributeError: warnings.warn("OpenGL not found, SLM will not work") from .patch import FrameBufferPatch, Patch, VertexArray @@ -31,9 +42,27 @@ class SLM(Actuator, PhaseSLM): See :numref:`section-slms` for more information. """ - __slots__ = ['_vertex_array', '_frame_buffer', '_monitor_id', '_position', '_refresh_rate', - '_transform', '_shape', '_window', '_globals', '_frame_buffer', 'patches', 'primary_patch', - '_coordinate_system', '_pixel_reader', '_phase_reader', '_field_reader', '_context', '_clones'] + + __slots__ = [ + "_vertex_array", + "_frame_buffer", + "_monitor_id", + "_position", + "_refresh_rate", + "_transform", + "_shape", + "_window", + "_globals", + "_frame_buffer", + "patches", + "primary_patch", + "_coordinate_system", + "_pixel_reader", + "_phase_reader", + "_field_reader", + "_context", + "_clones", + ] _active_slms = WeakSet() """Keep track of all active SLMs. This is done for two reasons. First, to check if we are not putting two @@ -43,10 +72,17 @@ class SLM(Actuator, PhaseSLM): WINDOWED = 0 patches: list[Patch] - def __init__(self, monitor_id: int = WINDOWED, shape: Optional[tuple[int, int]] = None, - pos: tuple[int, int] = (0, 0), refresh_rate: Optional[Quantity[u.Hz]] = None, - latency: TimeType = 2, duration: TimeType = 1, coordinate_system: str = 'short', - transform: Optional[Transform] = None): + def __init__( + self, + monitor_id: int = WINDOWED, + shape: Optional[tuple[int, int]] = None, + pos: tuple[int, int] = (0, 0), + refresh_rate: Optional[Quantity[u.Hz]] = None, + latency: TimeType = 2, + duration: TimeType = 1, + coordinate_system: str = "short", + transform: Optional[Transform] = None, + ): """ Constructs a new SLM window. @@ -86,7 +122,9 @@ def __init__(self, monitor_id: int = WINDOWED, shape: Optional[tuple[int, int]] self._position = pos (default_shape, default_rate, _) = SLM._current_mode(monitor_id) self._shape = default_shape if shape is None else shape - self._refresh_rate = default_rate if refresh_rate is None else refresh_rate.to_value(u.Hz) + self._refresh_rate = ( + default_rate if refresh_rate is None else refresh_rate.to_value(u.Hz) + ) self._frame_buffer = None self._window = None self._globals = -1 @@ -121,20 +159,32 @@ def _assert_window_available(self, monitor_id) -> None: Exception: If a full screen SLM is already present on the target monitor. """ if monitor_id == SLM.WINDOWED: - if any([slm.monitor_id == 1 for slm in SLM._active_slms if slm is not self]): + if any( + [slm.monitor_id == 1 for slm in SLM._active_slms if slm is not self] + ): raise RuntimeError( - f"Cannot create an SLM window because a full-screen SLM is already active on monitor 1") + f"Cannot create an SLM window because a full-screen SLM is already active on monitor 1" + ) else: # we cannot have multiple full screen windows on the same monitor. Also, we cannot have # a full screen window on monitor 1 if there are already windowed SLMs. - if any([slm.monitor_id == monitor_id or - (monitor_id == 1 and slm.monitor_id == SLM.WINDOWED) - for slm in SLM._active_slms if slm is not self]): - raise RuntimeError(f"Cannot create a full-screen SLM window on monitor {monitor_id} because a " - f"window is already displayed on that monitor") + if any( + [ + slm.monitor_id == monitor_id + or (monitor_id == 1 and slm.monitor_id == SLM.WINDOWED) + for slm in SLM._active_slms + if slm is not self + ] + ): + raise RuntimeError( + f"Cannot create a full-screen SLM window on monitor {monitor_id} because a " + f"window is already displayed on that monitor" + ) if monitor_id > len(glfw.get_monitors()): - raise IndexError(f"Monitor {monitor_id} not found, only {len(glfw.get_monitors())} monitor(s) " - f"are connected.") + raise IndexError( + f"Monitor {monitor_id} not found, only {len(glfw.get_monitors())} monitor(s) " + f"are connected." + ) @staticmethod def _current_mode(monitor_id: int): @@ -152,7 +202,11 @@ def _current_mode(monitor_id: int): mode = glfw.get_video_mode(monitor) shape = (mode.size[1], mode.size[0]) - return shape, mode.refresh_rate, min([mode.bits.red, mode.bits.green, mode.bits.blue]) + return ( + shape, + mode.refresh_rate, + min([mode.bits.red, mode.bits.green, mode.bits.blue]), + ) def _on_resize(self): """Updates shape and refresh rate to the actual values of the window. @@ -169,7 +223,11 @@ def _on_resize(self): """ # create a new frame buffer, re-use the old one if one was present, otherwise use a default of range(256) # re-use the lookup table if possible, otherwise create a default one ranging from 0 to 255. - old_lut = self._frame_buffer.lookup_table if self._frame_buffer is not None else range(256) + old_lut = ( + self._frame_buffer.lookup_table + if self._frame_buffer is not None + else range(256) + ) self._frame_buffer = FrameBufferPatch(self, old_lut) glViewport(0, 0, self._shape[1], self._shape[0]) # tell openGL to wait for the vertical retrace when swapping buffers (it appears need to do this @@ -180,22 +238,29 @@ def _on_resize(self): (fb_width, fb_height) = glfw.get_framebuffer_size(self._window) fb_shape = (fb_height, fb_width) if self._shape != fb_shape: - warnings.warn(f"Actual resolution {fb_shape} does not match requested resolution {self._shape}.") + warnings.warn( + f"Actual resolution {fb_shape} does not match requested resolution {self._shape}." + ) self._shape = fb_shape - (current_size, current_rate, current_bit_depth) = SLM._current_mode(self._monitor_id) + (current_size, current_rate, current_bit_depth) = SLM._current_mode( + self._monitor_id + ) # verify that the bit depth is at least 8 bit if current_bit_depth < 8: warnings.warn( f"Bit depth is less than 8 bits " - f"You may not be able to use the full phase resolution of your SLM.") + f"You may not be able to use the full phase resolution of your SLM." + ) # verify the refresh rate is correct # Then update the refresh rate to the actual value if int(self._refresh_rate) != current_rate: - warnings.warn(f"Actual refresh rate of {current_rate} Hz does not match set rate " - f"of {self._refresh_rate} Hz") + warnings.warn( + f"Actual refresh rate of {current_rate} Hz does not match set rate " + f"of {self._refresh_rate} Hz" + ) self._refresh_rate = current_rate @staticmethod @@ -208,19 +273,29 @@ def _init_glfw(): trouble if the user of our library also uses glfw for something else. """ glfw.init() - glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) # Required on Mac. Doesn't hurt on Windows - glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, glfw.TRUE) # Required on Mac. Useless on Windows + glfw.window_hint( + glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE + ) # Required on Mac. Doesn't hurt on Windows + glfw.window_hint( + glfw.OPENGL_FORWARD_COMPAT, glfw.TRUE + ) # Required on Mac. Useless on Windows glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 4) # request at least opengl 4.2 glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 2) glfw.window_hint(glfw.FLOATING, glfw.TRUE) # Keep window on top glfw.window_hint(glfw.DECORATED, glfw.FALSE) # Disable window border - glfw.window_hint(glfw.AUTO_ICONIFY, glfw.FALSE) # Prevent window minimization during task switch + glfw.window_hint( + glfw.AUTO_ICONIFY, glfw.FALSE + ) # Prevent window minimization during task switch glfw.window_hint(glfw.FOCUSED, glfw.FALSE) glfw.window_hint(glfw.DOUBLEBUFFER, glfw.TRUE) - glfw.window_hint(glfw.RED_BITS, 8) # require at least 8 bits per color channel (256 gray values) + glfw.window_hint( + glfw.RED_BITS, 8 + ) # require at least 8 bits per color channel (256 gray values) glfw.window_hint(glfw.GREEN_BITS, 8) glfw.window_hint(glfw.BLUE_BITS, 8) - glfw.window_hint(glfw.COCOA_RETINA_FRAMEBUFFER, glfw.FALSE) # disable retina multisampling on Mac (untested) + glfw.window_hint( + glfw.COCOA_RETINA_FRAMEBUFFER, glfw.FALSE + ) # disable retina multisampling on Mac (untested) glfw.window_hint(glfw.SAMPLES, 0) # disable multisampling def _create_window(self): @@ -232,10 +307,18 @@ def _create_window(self): shared = other._window if other is not None else None # noqa: ok to use _window SLM._active_slms.add(self) - monitor = glfw.get_monitors()[self._monitor_id - 1] if self._monitor_id != SLM.WINDOWED else None + monitor = ( + glfw.get_monitors()[self._monitor_id - 1] + if self._monitor_id != SLM.WINDOWED + else None + ) glfw.window_hint(glfw.REFRESH_RATE, int(self._refresh_rate)) - self._window = glfw.create_window(self._shape[1], self._shape[0], "OpenWFS SLM", monitor, shared) - glfw.set_input_mode(self._window, glfw.CURSOR, glfw.CURSOR_HIDDEN) # disable cursor + self._window = glfw.create_window( + self._shape[1], self._shape[0], "OpenWFS SLM", monitor, shared + ) + glfw.set_input_mode( + self._window, glfw.CURSOR, glfw.CURSOR_HIDDEN + ) # disable cursor if monitor: # full screen mode glfw.set_gamma(monitor, 1.0) else: # windowed mode @@ -300,8 +383,7 @@ def refresh_rate(self) -> Quantity[u.Hz]: @property def period(self) -> Quantity[u.ms]: - """The period of the refresh rate in milliseconds (read only). - """ + """The period of the refresh rate in milliseconds (read only).""" return (1000 / self._refresh_rate) * u.ms @property @@ -331,9 +413,15 @@ def monitor_id(self, value): monitor = glfw.get_monitors()[value - 1] if value != SLM.WINDOWED else None # move window to new monitor - glfw.set_window_monitor(self._window, monitor, self._position[1], self._position[0], self._shape[1], - self._shape[0], - int(self._refresh_rate)) + glfw.set_window_monitor( + self._window, + monitor, + self._position[1], + self._position[0], + self._shape[1], + self._shape[0], + int(self._refresh_rate), + ) self._on_resize() def __del__(self): @@ -361,7 +449,9 @@ def update(self): """ with self._context: # first draw all patches into the frame buffer - glBindFramebuffer(GL.GL_FRAMEBUFFER, self._frame_buffer._frame_buffer) # noqa - ok to access 'friend class' + glBindFramebuffer( + GL.GL_FRAMEBUFFER, self._frame_buffer._frame_buffer + ) # noqa - ok to access 'friend class' glClear(GL.GL_COLOR_BUFFER_BIT) for patch in self.patches: patch._draw() # noqa - ok to access 'friend class' @@ -373,7 +463,9 @@ def update(self): glfw.poll_events() # process window messages if len(self._clones) > 0: - self._context.__exit__(None, None, None) # release context before updating clones + self._context.__exit__( + None, None, None + ) # release context before updating clones for clone in self._clones: with clone.slm._context: # noqa self._frame_buffer._draw() # noqa - ok to access 'friend class' @@ -413,31 +505,31 @@ def duration(self, value: Quantity[u.ms]): def coordinate_system(self) -> str: """Specifies the base coordinate system that is used to map vertex coordinates to the SLM window. - Possible values are 'full', 'short' and 'long'. + Possible values are 'full', 'short' and 'long'. - 'full' means that the coordinate range (-1,-1) to (1,1) is mapped to the entire SLM window. - If the window is not square, this means that the coordinates are anisotropic. + 'full' means that the coordinate range (-1,-1) to (1,1) is mapped to the entire SLM window. + If the window is not square, this means that the coordinates are anisotropic. - 'short' and 'long' map the coordinate range (-1,-1) to (1,1) to a square. - 'short' means that the square is scaled to fill the short side of the SLM (introducing zero-padding at the - edges). + 'short' and 'long' map the coordinate range (-1,-1) to (1,1) to a square. + 'short' means that the square is scaled to fill the short side of the SLM (introducing zero-padding at the + edges). - 'long' means that the square is scaled to fill the long side of the SLM - (causing part of the coordinate range to be cropped because these coordinates correspond to points outside - the SLM window). + 'long' means that the square is scaled to fill the long side of the SLM + (causing part of the coordinate range to be cropped because these coordinates correspond to points outside + the SLM window). - For a square SLM, 'full', 'short' and 'long' are all equivalent. + For a square SLM, 'full', 'short' and 'long' are all equivalent. - In all three cases, (-1,-1) corresponds to the top-left corner of the screen, and (1,-1) to the - bottom-left corner. This convention is consistent with that used in numpy/matplotlib + In all three cases, (-1,-1) corresponds to the top-left corner of the screen, and (1,-1) to the + bottom-left corner. This convention is consistent with that used in numpy/matplotlib - To further modify the mapping system, use the `transform` property. + To further modify the mapping system, use the `transform` property. """ return self._coordinate_system @coordinate_system.setter def coordinate_system(self, value: str): - if value not in ['full', 'short', 'long']: + if value not in ["full", "short", "long"]: raise ValueError(f"Unsupported coordinate system {value}") self._coordinate_system = value self.transform = self._transform # trigger update of transform matrix on gpu @@ -466,21 +558,29 @@ def transform(self, value: Transform): self._field_reader = None # update matrix stored on the gpu - if self._coordinate_system == 'full': + if self._coordinate_system == "full": transform = self._transform else: - scale_width = (width > height) == (self._coordinate_system == 'short') + scale_width = (width > height) == (self._coordinate_system == "short") if scale_width: - root_transform = Transform(np.array(((1.0, 0.0), (0.0, height / width)))) + root_transform = Transform( + np.array(((1.0, 0.0), (0.0, height / width))) + ) else: - root_transform = Transform(np.array(((width / height, 0.0), (0.0, 1.0)))) + root_transform = Transform( + np.array(((width / height, 0.0), (0.0, 1.0))) + ) transform = self._transform @ root_transform padded = transform.opencl_matrix() with self._context: glBindBuffer(GL.GL_UNIFORM_BUFFER, self._globals) - glBufferData(GL.GL_UNIFORM_BUFFER, padded.size * 4, padded, GL.GL_STATIC_DRAW) - glBindBufferBase(GL.GL_UNIFORM_BUFFER, 1, self._globals) # connect buffer to binding point 1 + glBufferData( + GL.GL_UNIFORM_BUFFER, padded.size * 4, padded, GL.GL_STATIC_DRAW + ) + glBindBufferBase( + GL.GL_UNIFORM_BUFFER, 1, self._globals + ) # connect buffer to binding point 1 @property def lookup_table(self) -> Sequence[int]: @@ -527,8 +627,12 @@ def phases(self) -> Detector: self._phase_reader = FrameBufferReader(self) return self._phase_reader - def clone(self, monitor_id: int = WINDOWED, shape: Optional[tuple[int, int]] = None, - pos: tuple[int, int] = (0, 0)): + def clone( + self, + monitor_id: int = WINDOWED, + shape: Optional[tuple[int, int]] = None, + pos: tuple[int, int] = (0, 0), + ): """Creates a new SLM window that mirrors the content of this SLM window. This is useful for demonstration and debugging purposes. @@ -560,8 +664,13 @@ def __init__(self, slm: SLM): class FrontBufferReader(Detector): def __init__(self, slm): self._context = Context(slm) - super().__init__(data_shape=None, pixel_size=None, duration=0.0 * u.ms, latency=0.0 * u.ms, - multi_threaded=False) + super().__init__( + data_shape=None, + pixel_size=None, + duration=0.0 * u.ms, + latency=0.0 * u.ms, + multi_threaded=False, + ) @property def data_shape(self): @@ -571,7 +680,7 @@ def _fetch(self, *args, **kwargs) -> np.ndarray: with self._context: glReadBuffer(GL.GL_FRONT) shape = self.data_shape - data = np.empty(shape, dtype='uint8') + data = np.empty(shape, dtype="uint8") glReadPixels(0, 0, shape[1], shape[0], GL.GL_RED, GL.GL_UNSIGNED_BYTE, data) # flip data upside down, because the OpenGL convention is to have the origin at the bottom left, # but we want it at the top left (like in numpy) @@ -581,8 +690,13 @@ def _fetch(self, *args, **kwargs) -> np.ndarray: class FrameBufferReader(Detector): def __init__(self, slm): self._context = Context(slm) - super().__init__(data_shape=None, pixel_size=None, duration=0.0 * u.ms, latency=0.0 * u.ms, - multi_threaded=False) + super().__init__( + data_shape=None, + pixel_size=None, + duration=0.0 * u.ms, + latency=0.0 * u.ms, + multi_threaded=False, + ) @property def data_shape(self): diff --git a/openwfs/devices/slm/texture.py b/openwfs/devices/slm/texture.py index 6909abe..97c83f1 100644 --- a/openwfs/devices/slm/texture.py +++ b/openwfs/devices/slm/texture.py @@ -6,8 +6,19 @@ try: import OpenGL.GL as GL - from OpenGL.GL import glGenTextures, glBindTexture, glTexImage2D, glTexSubImage2D, glTexImage1D, glTexSubImage1D, \ - glTexParameteri, glActiveTexture, glDeleteTextures, glGetTextureImage, glPixelStorei + from OpenGL.GL import ( + glGenTextures, + glBindTexture, + glTexImage2D, + glTexSubImage2D, + glTexImage1D, + glTexSubImage1D, + glTexParameteri, + glActiveTexture, + glDeleteTextures, + glGetTextureImage, + glPixelStorei, + ) except AttributeError: warnings.warn("OpenGL not found, SLM will not work"), @@ -17,7 +28,9 @@ def __init__(self, slm, texture_type=GL.GL_TEXTURE_2D): self.context = Context(slm) self.handle = glGenTextures(1) self.type = texture_type - self.synchronized = False # self.data is not yet synchronized with texture in GPU memory + self.synchronized = ( + False # self.data is not yet synchronized with texture in GPU memory + ) self._data_shape = None # current size of the texture, to see if we need to make a new texture or # overwrite the exiting one @@ -37,22 +50,28 @@ def __del__(self): glDeleteTextures(1, [self.handle]) def _bind(self, idx): - """ Bind texture to texture unit idx. Assumes that the OpenGL context is already active.""" + """Bind texture to texture unit idx. Assumes that the OpenGL context is already active.""" glActiveTexture(GL.GL_TEXTURE0 + idx) glBindTexture(self.type, self.handle) def set_data(self, value): - """ Set texture data. + """Set texture data. The texture data is directly copied to the GPU memory, so the original data array can be modified or deleted. """ - value = np.array(value, dtype=np.float32, order='C', copy=False) + value = np.array(value, dtype=np.float32, order="C", copy=False) with self.context: glBindTexture(self.type, self.handle) - glPixelStorei(GL.GL_UNPACK_ALIGNMENT, 4) # alignment is at least four bytes since we use float32 - (internal_format, data_format, data_type) = (GL.GL_R32F, GL.GL_RED, GL.GL_FLOAT) + glPixelStorei( + GL.GL_UNPACK_ALIGNMENT, 4 + ) # alignment is at least four bytes since we use float32 + (internal_format, data_format, data_type) = ( + GL.GL_R32F, + GL.GL_RED, + GL.GL_FLOAT, + ) if self.type == GL.GL_TEXTURE_1D: # check if data has the correct dimension, convert scalars to arrays of correct dimension @@ -62,11 +81,28 @@ def set_data(self, value): raise ValueError("Data should be a 1-d array or a scalar") if value.shape != self._data_shape: # create a new texture - glTexImage1D(GL.GL_TEXTURE_1D, 0, internal_format, value.shape[0], 0, data_format, data_type, value) + glTexImage1D( + GL.GL_TEXTURE_1D, + 0, + internal_format, + value.shape[0], + 0, + data_format, + data_type, + value, + ) self._data_shape = value.shape else: # overwrite existing texture - glTexSubImage1D(GL.GL_TEXTURE_1D, 0, 0, value.shape[0], data_format, data_type, value) + glTexSubImage1D( + GL.GL_TEXTURE_1D, + 0, + 0, + value.shape[0], + data_format, + data_type, + value, + ) elif self.type == GL.GL_TEXTURE_2D: if value.ndim == 0: @@ -74,17 +110,37 @@ def set_data(self, value): elif value.ndim != 2: raise ValueError("Data should be a 2-D array or a scalar") if value.shape != self._data_shape: - glTexImage2D(GL.GL_TEXTURE_2D, 0, internal_format, value.shape[1], value.shape[0], 0, - data_format, data_type, value) + glTexImage2D( + GL.GL_TEXTURE_2D, + 0, + internal_format, + value.shape[1], + value.shape[0], + 0, + data_format, + data_type, + value, + ) self._data_shape = value.shape else: - glTexSubImage2D(GL.GL_TEXTURE_2D, 0, 0, 0, value.shape[1], value.shape[0], data_format, - data_type, value) + glTexSubImage2D( + GL.GL_TEXTURE_2D, + 0, + 0, + 0, + value.shape[1], + value.shape[0], + data_format, + data_type, + value, + ) else: raise ValueError("Texture type not supported") def get_data(self): with self.context: - data = np.empty(self._data_shape, dtype='float32') - glGetTextureImage(self.handle, 0, GL.GL_RED, GL.GL_FLOAT, data.size * 4, data) + data = np.empty(self._data_shape, dtype="float32") + glGetTextureImage( + self.handle, 0, GL.GL_RED, GL.GL_FLOAT, data.size * 4, data + ) return data diff --git a/openwfs/plot_utilities.py b/openwfs/plot_utilities.py index 27d7516..1a9a6bd 100644 --- a/openwfs/plot_utilities.py +++ b/openwfs/plot_utilities.py @@ -14,11 +14,11 @@ def imshow(data, axis=None): e0 = scale_prefix(extent[0]) e1 = scale_prefix(extent[1]) if axis is None: - plt.imshow(data, extent=(0.0, e1.value, 0.0, e0.value), cmap='gray') + plt.imshow(data, extent=(0.0, e1.value, 0.0, e0.value), cmap="gray") plt.colorbar() axis = plt.gca() else: - axis.imshow(data, extent=(0.0, e1.value, 0.0, e0.value), cmap='gray') + axis.imshow(data, extent=(0.0, e1.value, 0.0, e0.value), cmap="gray") plt.ylabel(e0.unit.to_string()) plt.xlabel(e1.unit.to_string()) plt.show(block=False) @@ -28,7 +28,7 @@ def imshow(data, axis=None): def scale_prefix(value: u.Quantity) -> u.Quantity: """Scale a quantity to the most appropriate prefix unit.""" - if value.unit.physical_type == 'length': + if value.unit.physical_type == "length": if value < 100 * u.nm: return value.to(u.nm) if value < 100 * u.um: @@ -37,7 +37,7 @@ def scale_prefix(value: u.Quantity) -> u.Quantity: return value.to(u.mm) else: return value.to(u.m) - elif value.unit.physical_type == 'time': + elif value.unit.physical_type == "time": if value < 100 * u.ns: return value.to(u.ns) if value < 100 * u.us: diff --git a/openwfs/processors/__init__.py b/openwfs/processors/__init__.py index 1c2a16a..ca18915 100644 --- a/openwfs/processors/__init__.py +++ b/openwfs/processors/__init__.py @@ -1,2 +1,9 @@ from . import processors -from .processors import CropProcessor, SingleRoi, MultipleRoi, TransformProcessor, Roi, select_roi +from .processors import ( + CropProcessor, + SingleRoi, + MultipleRoi, + TransformProcessor, + Roi, + select_roi, +) diff --git a/openwfs/processors/processors.py b/openwfs/processors/processors.py index d6babae..d270a77 100644 --- a/openwfs/processors/processors.py +++ b/openwfs/processors/processors.py @@ -17,8 +17,9 @@ class Roi: radius, mask type, and parameters specific to the mask type. """ - def __init__(self, pos, radius=0.1, mask_type: str = 'disk', waist=None, - source_shape=None): + def __init__( + self, pos, radius=0.1, mask_type: str = "disk", waist=None, source_shape=None + ): """ Initialize the Roi object. @@ -36,10 +37,16 @@ def __init__(self, pos, radius=0.1, mask_type: str = 'disk', waist=None, """ if pos is None: pos = (source_shape[0] // 2, source_shape[1] // 2) - if round(pos[0] - radius) < 0 or round(pos[1] - radius) < 0 or source_shape is not None and ( - round(pos[0] + radius) >= source_shape[0] or - round(pos[1] + radius) >= source_shape[1]): - raise ValueError('ROI does not fit inside source image') + if ( + round(pos[0] - radius) < 0 + or round(pos[1] - radius) < 0 + or source_shape is not None + and ( + round(pos[0] + radius) >= source_shape[0] + or round(pos[1] + radius) >= source_shape[1] + ) + ): + raise ValueError("ROI does not fit inside source image") self._pos = pos self._radius = radius @@ -99,7 +106,7 @@ def mask_type(self) -> str: @waist.setter def waist(self, value: str): - if value not in ['disk', 'gaussian', 'square']: + if value not in ["disk", "gaussian", "square"]: raise ValueError("mask_type must be 'disk', 'gaussian', or 'square'") self._mask_type = value self._mask = None # need to re-compute mask @@ -124,10 +131,10 @@ def apply(self, image: np.ndarray, order: float = 1.0): # for circular masks, always use an odd number of pixels so that we have a clearly # defined center. # for square masks, instead use the actual size - if self.mask_type == 'disk': + if self.mask_type == "disk": d = round(self._radius) * 2 + 1 self._mask = disk(d, r) - elif self.mask_type == 'gaussian': + elif self.mask_type == "gaussian": d = round(self._radius) * 2 + 1 self._mask = gaussian(d, self._waist) else: # square @@ -138,13 +145,15 @@ def apply(self, image: np.ndarray, order: float = 1.0): image_start = np.array(self.pos) - int(0.5 * self._mask.shape[0] - 0.5) image_cropped = image[ - image_start[0]:image_start[0] + self._mask.shape[0], - image_start[1]:image_start[1] + self._mask.shape[1]] + image_start[0] : image_start[0] + self._mask.shape[0], + image_start[1] : image_start[1] + self._mask.shape[1], + ] if image_cropped.shape != self._mask.shape: raise ValueError( f"ROI is larger than the possible area. ROI shape: {self._mask.shape}, " - + f"Cropped image shape: {image_cropped.shape}") + + f"Cropped image shape: {image_cropped.shape}" + ) if order != 1.0: image_cropped = np.power(image_cropped, order) @@ -200,8 +209,15 @@ def pixel_size(self) -> None: class SingleRoi(MultipleRoi): - def __init__(self, source, pos=None, radius=0.1, mask_type: str = 'disk', waist=0.5, - multi_threaded: bool = True): + def __init__( + self, + source, + pos=None, + radius=0.1, + mask_type: str = "disk", + waist=0.5, + multi_threaded: bool = True, + ): """ Processor that averages a signal over a single region of interest (ROI). @@ -228,8 +244,14 @@ class CropProcessor(Processor): the data is padded with 'padding_value' """ - def __init__(self, source: Detector, shape: Optional[Sequence[int]] = None, - pos: Optional[Sequence[int]] = None, padding_value=0.0, multi_threaded: bool = False): + def __init__( + self, + source: Detector, + shape: Optional[Sequence[int]] = None, + pos: Optional[Sequence[int]] = None, + padding_value=0.0, + multi_threaded: bool = False, + ): """ Args: @@ -244,7 +266,11 @@ def __init__(self, source: Detector, shape: Optional[Sequence[int]] = None, """ super().__init__(source, multi_threaded=multi_threaded) self._data_shape = tuple(shape) if shape is not None else source.data_shape - self._pos = np.array(pos) if pos is not None else np.zeros((len(self.data_shape),), dtype=int) + self._pos = ( + np.array(pos) + if pos is not None + else np.zeros((len(self.data_shape),), dtype=int) + ) self._padding_value = padding_value @property @@ -272,16 +298,19 @@ def _fetch(self, image: np.ndarray) -> np.ndarray: # noqa Returns: the out array containing the cropped image. """ - src_start = np.maximum(self._pos, 0).astype('int32') - src_end = np.minimum(self._pos + self._data_shape, image.shape).astype('int32') - dst_start = np.maximum(-self._pos, 0).astype('int32') + src_start = np.maximum(self._pos, 0).astype("int32") + src_end = np.minimum(self._pos + self._data_shape, image.shape).astype("int32") + dst_start = np.maximum(-self._pos, 0).astype("int32") dst_end = dst_start + src_end - src_start src_select = tuple( - slice(start, end) for (start, end) in zip(src_start, src_end)) + slice(start, end) for (start, end) in zip(src_start, src_end) + ) src = image.__getitem__(src_select) if any(dst_start != 0) or any(dst_end != self._data_shape): dst = np.zeros(self._data_shape) + self._padding_value - dst_select = tuple(slice(start, end) for (start, end) in zip(dst_start, dst_end)) + dst_select = tuple( + slice(start, end) for (start, end) in zip(dst_start, dst_end) + ) dst.__setitem__(dst_select, src) else: dst = src @@ -293,10 +322,17 @@ def select_roi(source: Detector, mask_type: str): """ Opens a window that allows the user to select a region of interest. """ - if mask_type not in ['disk', 'gaussian', 'square']: + if mask_type not in ["disk", "gaussian", "square"]: raise ValueError("mask_type must be 'disk', 'gaussian', or 'square'") - image = cv2.normalize(source.read(), None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U) + image = cv2.normalize( + source.read(), + None, + alpha=0, + beta=255, + norm_type=cv2.NORM_MINMAX, + dtype=cv2.CV_8U, + ) title = "Select ROI and press c to continue or ESC to cancel" cv2.namedWindow(title) cv2.imshow(title, image) @@ -312,10 +348,18 @@ def mouse_callback(event, x, y, flags, _param): elif event == cv2.EVENT_MOUSEMOVE and cv2.EVENT_FLAG_LBUTTON & flags: roi_size = np.minimum(x - roi_start[0], y - roi_start[1]) rect_image = image.copy() - if mask_type == 'square': - cv2.rectangle(rect_image, roi_start, roi_start + roi_size, (0.0, 0.0, 255.0), 2) + if mask_type == "square": + cv2.rectangle( + rect_image, roi_start, roi_start + roi_size, (0.0, 0.0, 255.0), 2 + ) else: - cv2.circle(rect_image, roi_start + roi_size // 2, abs(roi_size) // 2, (0.0, 0.0, 255.0), 2) + cv2.circle( + rect_image, + roi_start + roi_size // 2, + abs(roi_size) // 2, + (0.0, 0.0, 255.0), + 2, + ) cv2.imshow(title, rect_image) cv2.setMouseCallback(title, mouse_callback) @@ -344,19 +388,24 @@ class TransformProcessor(Processor): should match the unit of the input data after applying the transform. """ - def __init__(self, source: Detector, - transform: Transform = None, - data_shape: Optional[Sequence[int]] = None, - pixel_size: Optional[Quantity] = None, - multi_threaded: bool = True): + def __init__( + self, + source: Detector, + transform: Transform = None, + data_shape: Optional[Sequence[int]] = None, + pixel_size: Optional[Quantity] = None, + multi_threaded: bool = True, + ): """ Args: transform: Transform object that describes the transformation from the source to the target image data_shape: Shape of the output. If omitted, the shape of the input data is used. multi_threaded: Whether to perform processing in a worker thread. - """ - if (data_shape is not None and len(data_shape) != 2) or len(source.data_shape) != 2: + """ + if (data_shape is not None and len(data_shape) != 2) or len( + source.data_shape + ) != 2: raise ValueError("TransformProcessor only supports 2-D data") if transform is None: transform = Transform() @@ -364,10 +413,14 @@ def __init__(self, source: Detector, # check if input and output pixel sizes are compatible dst_unit = transform.destination_unit(source.pixel_size.unit) if pixel_size is not None and not pixel_size.unit.is_equivalent(dst_unit): - raise ValueError("Pixel size unit does not match the unit of the transformed data") + raise ValueError( + "Pixel size unit does not match the unit of the transformed data" + ) if pixel_size is None and not source.pixel_size.unit.is_equivalent(dst_unit): - raise ValueError("The transform changes the unit of the coordinates." - " An output pixel_size must be provided.") + raise ValueError( + "The transform changes the unit of the coordinates." + " An output pixel_size must be provided." + ) self.transform = transform super().__init__(source, multi_threaded=multi_threaded) @@ -390,5 +443,9 @@ def _fetch(self, source: np.ndarray) -> np.ndarray: # noqa Returns: ndarray that has been transformed TODO: Fix and add test, or remove """ - return project(source, transform=self.transform, out_shape=self.data_shape, - out_extent=self.extent) + return project( + source, + transform=self.transform, + out_shape=self.data_shape, + out_extent=self.extent, + ) diff --git a/openwfs/simulation/__init__.py b/openwfs/simulation/__init__.py index d4e122f..5f1807f 100644 --- a/openwfs/simulation/__init__.py +++ b/openwfs/simulation/__init__.py @@ -4,6 +4,13 @@ from . import transmission from .microscope import Microscope -from .mockdevices import XYStage, StaticSource, Camera, ADCProcessor, Shutter, NoiseSource +from .mockdevices import ( + XYStage, + StaticSource, + Camera, + ADCProcessor, + Shutter, + NoiseSource, +) from .slm import SLM, PhaseToField from .transmission import SimulatedWFS diff --git a/openwfs/simulation/microscope.py b/openwfs/simulation/microscope.py index c1cdfe0..93cae59 100644 --- a/openwfs/simulation/microscope.py +++ b/openwfs/simulation/microscope.py @@ -11,7 +11,14 @@ from ..plot_utilities import imshow # noqa - for debugging from ..processors import TransformProcessor from ..simulation.mockdevices import XYStage, Camera, StaticSource -from ..utilities import project, place, Transform, get_pixel_size, patterns, CoordinateType +from ..utilities import ( + project, + place, + Transform, + get_pixel_size, + patterns, + CoordinateType, +) class Microscope(Processor): @@ -31,15 +38,22 @@ class Microscope(Processor): that has the same total intensity as the source image. """ - def __init__(self, source: Union[Detector, np.ndarray], *, data_shape=None, - numerical_aperture: float = 1.0, - wavelength: Quantity[u.nm], - magnification: float = 1.0, xy_stage=None, z_stage=None, - incident_field: Union[Detector, ArrayLike, None] = None, - incident_transform: Optional[Transform] = None, - aberrations: Union[Detector, np.ndarray, None] = None, - aberration_transform: Optional[Transform] = None, - multi_threaded: bool = True): + def __init__( + self, + source: Union[Detector, np.ndarray], + *, + data_shape=None, + numerical_aperture: float = 1.0, + wavelength: Quantity[u.nm], + magnification: float = 1.0, + xy_stage=None, + z_stage=None, + incident_field: Union[Detector, ArrayLike, None] = None, + incident_transform: Optional[Transform] = None, + aberrations: Union[Detector, np.ndarray, None] = None, + aberration_transform: Optional[Transform] = None, + multi_threaded: bool = True + ): """ Args: source: 2-D image (must have `pixel_size` metadata), or @@ -93,7 +107,9 @@ def __init__(self, source: Union[Detector, np.ndarray], *, data_shape=None, if get_pixel_size(aberrations) is None: aberrations = StaticSource(aberrations) - super().__init__(source, aberrations, incident_field, multi_threaded=multi_threaded) + super().__init__( + source, aberrations, incident_field, multi_threaded=multi_threaded + ) self._magnification = magnification self._data_shape = data_shape if data_shape is not None else source.data_shape self.numerical_aperture = numerical_aperture @@ -105,8 +121,12 @@ def __init__(self, source: Union[Detector, np.ndarray], *, data_shape=None, self.z_stage = z_stage # or MockStage() self._psf = None - def _fetch(self, source: np.ndarray, aberrations: np.ndarray, # noqa - incident_field: np.ndarray) -> np.ndarray: + def _fetch( + self, + source: np.ndarray, + aberrations: np.ndarray, # noqa + incident_field: np.ndarray, + ) -> np.ndarray: """ Updates the image on the camera sensor @@ -133,7 +153,9 @@ def _fetch(self, source: np.ndarray, aberrations: np.ndarray, # noqa source_pixel_size = get_pixel_size(source) target_pixel_size = self.pixel_size / self.magnification if np.any(source_pixel_size > target_pixel_size): - warnings.warn("The resolution of the specimen image is worse than that of the output.") + warnings.warn( + "The resolution of the specimen image is worse than that of the output." + ) # Note: there seems to be a bug (feature?) in `fftconvolve` that shifts the image by one pixel # when the 'same' option is used. To compensate for this feature, @@ -166,23 +188,40 @@ def _fetch(self, source: np.ndarray, aberrations: np.ndarray, # noqa # Compute the field in the pupil plane # The aberrations and the SLM phase pattern are both mapped to the pupil plane coordinates - pupil_field = patterns.disk(pupil_shape, radius=self.numerical_aperture, extent=pupil_extent) - pupil_area = np.sum(pupil_field) # TODO (efficiency): compute area directly from radius + pupil_field = patterns.disk( + pupil_shape, radius=self.numerical_aperture, extent=pupil_extent + ) + pupil_area = np.sum( + pupil_field + ) # TODO (efficiency): compute area directly from radius # Project aberrations if aberrations is not None: # use default of 2.0 * NA for the extent of the aberration map if no pixel size is provided - aberration_extent = (2.0 * self.numerical_aperture,) * 2 if get_pixel_size(aberrations) is None else None - pupil_field = pupil_field * np.exp(1.0j * project(aberrations, - source_extent=aberration_extent, - out_extent=pupil_extent, - out_shape=pupil_shape, - transform=self.aberration_transform)) + aberration_extent = ( + (2.0 * self.numerical_aperture,) * 2 + if get_pixel_size(aberrations) is None + else None + ) + pupil_field = pupil_field * np.exp( + 1.0j + * project( + aberrations, + source_extent=aberration_extent, + out_extent=pupil_extent, + out_shape=pupil_shape, + transform=self.aberration_transform, + ) + ) # Project SLM fields if incident_field is not None: - pupil_field = pupil_field * project(incident_field, out_extent=pupil_extent, out_shape=pupil_shape, - transform=self.slm_transform) + pupil_field = pupil_field * project( + incident_field, + out_extent=pupil_extent, + out_shape=pupil_shape, + transform=self.slm_transform, + ) # Compute the point spread function # This is done by Fourier transforming the pupil field and taking the absolute value squared @@ -193,7 +232,7 @@ def _fetch(self, source: np.ndarray, aberrations: np.ndarray, # noqa psf = np.fft.ifftshift(psf) * (psf.size / pupil_area) self._psf = psf # store psf for later inspection - return fftconvolve(source, psf, 'same') + return fftconvolve(source, psf, "same") @property def magnification(self) -> float: @@ -225,10 +264,14 @@ def data_shape(self): """Returns the shape of the image in the image plane""" return self._data_shape - def get_camera(self, *, transform: Optional[Transform] = None, - data_shape: Optional[tuple[int, int]] = None, - pixel_size: Optional[CoordinateType] = None, - **kwargs) -> Detector: + def get_camera( + self, + *, + transform: Optional[Transform] = None, + data_shape: Optional[tuple[int, int]] = None, + pixel_size: Optional[CoordinateType] = None, + **kwargs + ) -> Detector: """ Returns a simulated camera that observes the microscope image. @@ -247,6 +290,8 @@ def get_camera(self, *, transform: Optional[Transform] = None, if transform is None and data_shape is None and pixel_size is None: src = self else: - src = TransformProcessor(self, data_shape=data_shape, pixel_size=pixel_size, transform=transform) + src = TransformProcessor( + self, data_shape=data_shape, pixel_size=pixel_size, transform=transform + ) return Camera(src, **kwargs) diff --git a/openwfs/simulation/mockdevices.py b/openwfs/simulation/mockdevices.py index 3f1f593..1117e61 100644 --- a/openwfs/simulation/mockdevices.py +++ b/openwfs/simulation/mockdevices.py @@ -15,8 +15,15 @@ class StaticSource(Detector): Detector that returns pre-set data. Also simulates latency and measurement duration. """ - def __init__(self, data: np.ndarray, pixel_size: Optional[ExtentType] = None, extent: Optional[ExtentType] = None, - latency: Quantity[u.ms] = 0 * u.ms, duration: Quantity[u.ms] = 0 * u.ms, multi_threaded: bool = None): + def __init__( + self, + data: np.ndarray, + pixel_size: Optional[ExtentType] = None, + extent: Optional[ExtentType] = None, + latency: Quantity[u.ms] = 0 * u.ms, + duration: Quantity[u.ms] = 0 * u.ms, + multi_threaded: bool = None, + ): """ Initializes the MockSource TODO: factor out the latency and duration into a separate class? @@ -35,15 +42,24 @@ def __init__(self, data: np.ndarray, pixel_size: Optional[ExtentType] = None, ex else: pixel_size = get_pixel_size(data) - if pixel_size is not None and (np.isscalar(pixel_size) or pixel_size.size == 1) and data.ndim > 1: + if ( + pixel_size is not None + and (np.isscalar(pixel_size) or pixel_size.size == 1) + and data.ndim > 1 + ): pixel_size = pixel_size.repeat(data.ndim) if multi_threaded is None: multi_threaded = latency > 0 * u.ms or duration > 0 * u.ms self._data = data - super().__init__(data_shape=data.shape, pixel_size=pixel_size, latency=latency, duration=duration, - multi_threaded=multi_threaded) + super().__init__( + data_shape=data.shape, + pixel_size=pixel_size, + latency=latency, + duration=duration, + multi_threaded=multi_threaded, + ) def _fetch(self) -> np.ndarray: # noqa total_time_s = self.latency.to_value(u.s) + self.duration.to_value(u.s) @@ -72,22 +88,34 @@ def data(self, value): class NoiseSource(Detector): - def __init__(self, noise_type: str, *, data_shape: tuple[int, ...], pixel_size: Quantity, multi_threaded=True, - generator=None, - **kwargs): + def __init__( + self, + noise_type: str, + *, + data_shape: tuple[int, ...], + pixel_size: Quantity, + multi_threaded=True, + generator=None, + **kwargs, + ): self._noise_type = noise_type self._noise_arguments = kwargs self._rng = generator if generator is not None else np.random.default_rng() - super().__init__(data_shape=data_shape, pixel_size=pixel_size, latency=0 * u.ms, duration=0 * u.ms, - multi_threaded=multi_threaded) + super().__init__( + data_shape=data_shape, + pixel_size=pixel_size, + latency=0 * u.ms, + duration=0 * u.ms, + multi_threaded=multi_threaded, + ) def _fetch(self) -> np.ndarray: # noqa - if self._noise_type == 'uniform': + if self._noise_type == "uniform": return self._rng.uniform(**self._noise_arguments, size=self.data_shape) - elif self._noise_type == 'gaussian': + elif self._noise_type == "gaussian": return self._rng.normal(**self._noise_arguments, size=self.data_shape) else: - raise ValueError(f'Unknown noise type: {self._noise_type}') + raise ValueError(f"Unknown noise type: {self._noise_type}") @Detector.data_shape.setter def data_shape(self, value): @@ -100,9 +128,16 @@ class ADCProcessor(Processor): At the moment, only positive input and output values are supported. """ - def __init__(self, source: Detector, analog_max: float = 0.0, digital_max: int = 0xFFFF, - shot_noise: bool = False, gaussian_noise_std: float = 0.0, multi_threaded: bool = True, - generator=None): + def __init__( + self, + source: Detector, + analog_max: float = 0.0, + digital_max: int = 0xFFFF, + shot_noise: bool = False, + gaussian_noise_std: float = 0.0, + multi_threaded: bool = True, + generator=None, + ): """ Initializes the ADCProcessor class, which mimics an analog-digital converter. @@ -140,7 +175,9 @@ def _fetch(self, data) -> np.ndarray: # noqa if self.analog_max == 0.0: # auto scaling max_value = np.max(data) if max_value > 0.0: - data = data * (self.digital_max / max_value) # auto-scale to maximum value + data = data * ( + self.digital_max / max_value + ) # auto-scale to maximum value else: data = data * (self.digital_max / self.analog_max) @@ -148,9 +185,11 @@ def _fetch(self, data) -> np.ndarray: # noqa data = self._rng.poisson(data) if self._gaussian_noise_std > 0.0: - data = data + self._rng.normal(scale=self._gaussian_noise_std, size=data.shape) + data = data + self._rng.normal( + scale=self._gaussian_noise_std, size=data.shape + ) - return np.clip(np.rint(data), 0, self.digital_max).astype('uint16') + return np.clip(np.rint(data), 0, self.digital_max).astype("uint16") @property def analog_max(self) -> Optional[float]: @@ -165,7 +204,7 @@ def analog_max(self) -> Optional[float]: @analog_max.setter def analog_max(self, value): if value < 0.0: - raise ValueError('analog_max cannot be negative') + raise ValueError("analog_max cannot be negative") self._analog_max = value @property @@ -184,7 +223,7 @@ def conversion_factor(self) -> float: @digital_max.setter def digital_max(self, value): if value < 0 or value > 0xFFFF: - raise ValueError('digital_max must be between 0 and 0xFFFF') + raise ValueError("digital_max must be between 0 and 0xFFFF") self._digital_max = int(value) @property @@ -214,8 +253,13 @@ class Camera(ADCProcessor): Conversion to uint16 is implemented in the ADCProcessor base class. """ - def __init__(self, source: Detector, shape: Optional[Sequence[int]] = None, - pos: Optional[Sequence[int]] = None, **kwargs): + def __init__( + self, + source: Detector, + shape: Optional[Sequence[int]] = None, + pos: Optional[Sequence[int]] = None, + **kwargs, + ): """ Args: source (Detector): The source detector to be wrapped. diff --git a/openwfs/simulation/slm.py b/openwfs/simulation/slm.py index a5905a5..eaa04b2 100644 --- a/openwfs/simulation/slm.py +++ b/openwfs/simulation/slm.py @@ -17,8 +17,12 @@ class PhaseToField(Processor): Computes `amplitude * (exp(1j * phase) + non_modulated_field_fraction)` """ - def __init__(self, slm_phases: Detector, field_amplitude: ArrayLike = 1.0, - non_modulated_field_fraction: float = 0.0): + def __init__( + self, + slm_phases: Detector, + field_amplitude: ArrayLike = 1.0, + non_modulated_field_fraction: float = 0.0, + ): """ Args: slm_phases: The `Detector` that returns the phases of the slm pixels. @@ -34,7 +38,9 @@ def _fetch(self, slm_phases: np.ndarray) -> np.ndarray: # noqa Updates the complex field output of the SLM. The output field is the sum of the modulated field and the non-modulated field. """ - return self.modulated_field_amplitude * (np.exp(1j * slm_phases) + self.non_modulated_field) + return self.modulated_field_amplitude * ( + np.exp(1j * slm_phases) + self.non_modulated_field + ) class _SLMTiming(Detector): @@ -45,15 +51,22 @@ class _SLMTiming(Detector): the refresh rate, or the conversion of gray values to phases. """ - def __init__(self, - shape: tuple[int, ...], - update_latency: Quantity[u.ms] = 0.0 * u.ms, - update_duration: Quantity[u.ms] = 0.0 * u.ms): + def __init__( + self, + shape: tuple[int, ...], + update_latency: Quantity[u.ms] = 0.0 * u.ms, + update_duration: Quantity[u.ms] = 0.0 * u.ms, + ): if len(shape) != 2: raise ValueError("Shape of the SLM should be 2-dimensional.") - super().__init__(data_shape=shape, pixel_size=Quantity(2.0 / np.min(shape)), latency=0 * u.ms, - duration=0 * u.ms, multi_threaded=False) + super().__init__( + data_shape=shape, + pixel_size=Quantity(2.0 / np.min(shape)), + latency=0 * u.ms, + duration=0 * u.ms, + multi_threaded=False, + ) self.update_latency = update_latency self.update_duration = update_duration @@ -139,20 +152,29 @@ class SLM(PhaseSLM, Actuator): A mock version of a phase-only spatial light modulator. Some properties are available to simulate physical phenomena such as imperfect phase response, and front reflections (which cause non-modulated light). """ - __slots__ = ('_hardware_fields', '_hardware_phases', '_hardware_timing', '_back_buffer', - 'refresh_rate', '_first_update_ns', '_lookup_table') - - def __init__(self, - shape: tuple[int, ...], - latency: Quantity[u.ms] = 0.0 * u.ms, - duration: Quantity[u.ms] = 0.0 * u.ms, - update_latency: Quantity[u.ms] = 0.0 * u.ms, - update_duration: Quantity[u.ms] = 0.0 * u.ms, - refresh_rate: Quantity[u.Hz] = 0 * u.Hz, - field_amplitude: Union[np.ndarray, float, None] = 1.0, - non_modulated_field_fraction: float = 0.0, - phase_response: Optional[np.ndarray] = None, - ): + + __slots__ = ( + "_hardware_fields", + "_hardware_phases", + "_hardware_timing", + "_back_buffer", + "refresh_rate", + "_first_update_ns", + "_lookup_table", + ) + + def __init__( + self, + shape: tuple[int, ...], + latency: Quantity[u.ms] = 0.0 * u.ms, + duration: Quantity[u.ms] = 0.0 * u.ms, + update_latency: Quantity[u.ms] = 0.0 * u.ms, + update_duration: Quantity[u.ms] = 0.0 * u.ms, + refresh_rate: Quantity[u.Hz] = 0 * u.Hz, + field_amplitude: Union[np.ndarray, float, None] = 1.0, + non_modulated_field_fraction: float = 0.0, + phase_response: Optional[np.ndarray] = None, + ): """ Args: @@ -169,16 +191,20 @@ def __init__(self, Choose a value different from `duration` to simulate incorrect timing. refresh_rate: Simulated refresh rate. Affects the timing of the `update` method, since this will wait until the next vertical retrace. Keep at 0 to disable this feature. - """ + """ super().__init__(latency=latency, duration=duration) self.refresh_rate = refresh_rate # Simulates transferring frames to the SLM self._hardware_timing = _SLMTiming(shape, update_latency, update_duration) - self._hardware_phases = _SLMPhaseResponse(self._hardware_timing, - phase_response) # Simulates reading the phase from the SLM - self._hardware_fields = PhaseToField(self._hardware_phases, field_amplitude, - non_modulated_field_fraction) # Simulates reading the field from the SLM - self._lookup_table = None # index = input phase (scaled to -> [0, 255]), value = grey value + self._hardware_phases = _SLMPhaseResponse( + self._hardware_timing, phase_response + ) # Simulates reading the phase from the SLM + self._hardware_fields = PhaseToField( + self._hardware_phases, field_amplitude, non_modulated_field_fraction + ) # Simulates reading the field from the SLM + self._lookup_table = ( + None # index = input phase (scaled to -> [0, 255]), value = grey value + ) self._first_update_ns = time.time_ns() self._back_buffer = np.zeros(shape, dtype=np.float32) @@ -188,8 +214,12 @@ def update(self): self._start() # wait for detectors to finish if self.refresh_rate > 0: # wait for the vertical retrace - time_in_frames = unitless((time.time_ns() - self._first_update_ns) * u.ns * self.refresh_rate) - time_to_next_frame = (np.ceil(time_in_frames) - time_in_frames) / self.refresh_rate + time_in_frames = unitless( + (time.time_ns() - self._first_update_ns) * u.ns * self.refresh_rate + ) + time_to_next_frame = ( + np.ceil(time_in_frames) - time_in_frames + ) / self.refresh_rate time.sleep(time_to_next_frame.tovalue(u.s)) # update the start time (this is also done in the actual SLM) self._start() @@ -205,7 +235,9 @@ def update(self): if self._lookup_table is None: grey_values = (256 * tx).astype(np.uint8) else: - lookup_index = (self._lookup_table.shape[0] * tx).astype(np.uint8) # index into lookup table + lookup_index = (self._lookup_table.shape[0] * tx).astype( + np.uint8 + ) # index into lookup table grey_values = self._lookup_table[lookup_index] self._hardware_timing.send(grey_values) @@ -242,8 +274,12 @@ def set_phases(self, values: ArrayLike, update=True): # no docstring, use documentation from base class # Copy the phase image to the back buffer, scaling it as necessary - project(np.atleast_2d(values).astype('float32'), out=self._back_buffer, source_extent=(2.0, 2.0), - out_extent=(2.0, 2.0)) + project( + np.atleast_2d(values).astype("float32"), + out=self._back_buffer, + source_extent=(2.0, 2.0), + out_extent=(2.0, 2.0), + ) if update: self.update() diff --git a/openwfs/simulation/transmission.py b/openwfs/simulation/transmission.py index aed1ecb..d3b3ec4 100644 --- a/openwfs/simulation/transmission.py +++ b/openwfs/simulation/transmission.py @@ -18,8 +18,15 @@ class SimulatedWFS(Processor): For a more advanced (but slower) simulation, use `Microscope` """ - def __init__(self, *, t: Optional[np.ndarray] = None, aberrations: Optional[np.ndarray] = None, slm=None, - multi_threaded=True, beam_amplitude: ScalarType = 1.0): + def __init__( + self, + *, + t: Optional[np.ndarray] = None, + aberrations: Optional[np.ndarray] = None, + slm=None, + multi_threaded=True, + beam_amplitude: ScalarType = 1.0 + ): """ Initializes the optical system with specified aberrations and optionally a Gaussian beam profile. @@ -43,7 +50,12 @@ def __init__(self, *, t: Optional[np.ndarray] = None, aberrations: Optional[np.n """ # transmission matrix (normalized so that the maximum transmission is 1) - self.t = t if t is not None else np.exp(1.0j * aberrations) / (aberrations.shape[0] * aberrations.shape[1]) + self.t = ( + t + if t is not None + else np.exp(1.0j * aberrations) + / (aberrations.shape[0] * aberrations.shape[1]) + ) self.slm = slm if slm is not None else SLM(self.t.shape[0:2]) super().__init__(self.slm.field, multi_threaded=multi_threaded) diff --git a/openwfs/utilities/__init__.py b/openwfs/utilities/__init__.py index 3765cba..3136f24 100644 --- a/openwfs/utilities/__init__.py +++ b/openwfs/utilities/__init__.py @@ -1,5 +1,15 @@ from . import patterns from . import utilities from .patterns import coordinate_range, disk, gaussian, tilt -from .utilities import ExtentType, CoordinateType, unitless, get_pixel_size, \ - set_pixel_size, Transform, project, place, set_extent, get_extent +from .utilities import ( + ExtentType, + CoordinateType, + unitless, + get_pixel_size, + set_pixel_size, + Transform, + project, + place, + set_extent, + get_extent, +) diff --git a/openwfs/utilities/patterns.py b/openwfs/utilities/patterns.py index 1dc9d47..818c352 100644 --- a/openwfs/utilities/patterns.py +++ b/openwfs/utilities/patterns.py @@ -42,8 +42,9 @@ """ -def coordinate_range(shape: ShapeType, extent: ExtentType, offset: Optional[CoordinateType] = None) -> (Quantity, - Quantity): +def coordinate_range( + shape: ShapeType, extent: ExtentType, offset: Optional[CoordinateType] = None +) -> (Quantity, Quantity): """ Returns coordinate vectors for the two coordinates (y and x). @@ -72,8 +73,10 @@ def c_range(res, ex, cx): dx = ex / res return np.arange(res) * dx + (0.5 * dx - 0.5 * ex + cx) - return (c_range(shape[0], extent[0], offset[0]).reshape((-1, 1)), - c_range(shape[1], extent[1], offset[1]).reshape((1, -1))) + return ( + c_range(shape[0], extent[0], offset[0]).reshape((-1, 1)), + c_range(shape[1], extent[1], offset[1]).reshape((1, -1)), + ) def r2_range(shape: ShapeType, extent: ExtentType): @@ -81,10 +84,15 @@ def r2_range(shape: ShapeType, extent: ExtentType): Equivalent to computing cx^2 + cy^2 """ c0, c1 = coordinate_range(shape, extent) - return c0 ** 2 + c1 ** 2 + return c0**2 + c1**2 -def tilt(shape: ShapeType, g: ExtentType, extent: ExtentType = (2.0, 2.0), phase_offset: float = 0.0): +def tilt( + shape: ShapeType, + g: ExtentType, + extent: ExtentType = (2.0, 2.0), + phase_offset: float = 0.0, +): """Constructs a linear gradient pattern φ=2 g·r Args: @@ -115,11 +123,17 @@ def lens(shape: ShapeType, f: ScalarType, wavelength: ScalarType, extent: Extent extent(ExtentType): physical extent of the SLM, same units as `f` and `wavelength` """ r_sqr = r2_range(shape, extent) - return unitless((f - np.sqrt(f ** 2 + r_sqr)) * (2 * np.pi / wavelength)) + return unitless((f - np.sqrt(f**2 + r_sqr)) * (2 * np.pi / wavelength)) -def propagation(shape: ShapeType, distance: ScalarType, numerical_aperture: ScalarType, - refractive_index: ScalarType, wavelength: ScalarType, extent: ExtentType = (2.0, 2.0)): +def propagation( + shape: ShapeType, + distance: ScalarType, + numerical_aperture: ScalarType, + refractive_index: ScalarType, + wavelength: ScalarType, + extent: ExtentType = (2.0, 2.0), +): """Computes a wavefront that corresponds to digitally propagating the field in the object plane. k_z = sqrt(n² k_0²-k_x²-k_y²) @@ -139,7 +153,9 @@ def propagation(shape: ShapeType, distance: ScalarType, numerical_aperture: Scal # convert pupil coordinates to absolute k_x, k_y coordinates k_0 = 2.0 * np.pi / wavelength extent_k = Quantity(extent) * numerical_aperture * k_0 - k_z = np.sqrt(np.maximum((refractive_index * k_0) ** 2 - r2_range(shape, extent_k), 0.0)) + k_z = np.sqrt( + np.maximum((refractive_index * k_0) ** 2 - r2_range(shape, extent_k), 0.0) + ) return unitless(k_z * distance) @@ -153,11 +169,15 @@ def disk(shape: ShapeType, radius: ScalarType = 1.0, extent: ExtentType = (2.0, radius (ScalarType): radius of the disk, should have the same unit as `extent`. extent: see module documentation """ - return 1.0 * (r2_range(shape, extent) < radius ** 2) + return 1.0 * (r2_range(shape, extent) < radius**2) -def gaussian(shape: ShapeType, waist: ScalarType, - truncation_radius: ScalarType = None, extent: ExtentType = (2.0, 2.0)): +def gaussian( + shape: ShapeType, + waist: ScalarType, + truncation_radius: ScalarType = None, + extent: ExtentType = (2.0, 2.0), +): """Constructs an image of a centered Gaussian `waist`, `extent` and the optional `truncation_radius` should all have the same unit. @@ -172,7 +192,7 @@ def gaussian(shape: ShapeType, waist: ScalarType, """ r_sqr = r2_range(shape, extent) - w2inv = -1.0 / waist ** 2 + w2inv = -1.0 / waist**2 gauss = np.exp(unitless(r_sqr * w2inv)) if truncation_radius is not None: gauss = gauss * disk(shape, truncation_radius, extent=extent) diff --git a/openwfs/utilities/utilities.py b/openwfs/utilities/utilities.py index 1b7b5c9..3572e46 100644 --- a/openwfs/utilities/utilities.py +++ b/openwfs/utilities/utilities.py @@ -89,16 +89,25 @@ class Transform: """ - def __init__(self, transform: Optional[TransformType] = None, - source_origin: Optional[CoordinateType] = None, - destination_origin: Optional[CoordinateType] = None): + def __init__( + self, + transform: Optional[TransformType] = None, + source_origin: Optional[CoordinateType] = None, + destination_origin: Optional[CoordinateType] = None, + ): self.transform = Quantity(transform if transform is not None else np.eye(2)) - self.source_origin = Quantity(source_origin) if source_origin is not None else None - self.destination_origin = Quantity(destination_origin) if destination_origin is not None else None + self.source_origin = ( + Quantity(source_origin) if source_origin is not None else None + ) + self.destination_origin = ( + Quantity(destination_origin) if destination_origin is not None else None + ) if source_origin is not None: - self.destination_unit(self.source_origin.unit) # check if the units are consistent + self.destination_unit( + self.source_origin.unit + ) # check if the units are consistent def destination_unit(self, src_unit: u.Unit) -> u.Unit: """Computes the unit of the output of the transformation, given the unit of the input. @@ -107,20 +116,28 @@ def destination_unit(self, src_unit: u.Unit) -> u.Unit: ValueError: If src_unit does not match the unit of the source_origin (if specified) or if dst_unit does not match the unit of the destination_origin (if specified). """ - if self.source_origin is not None and not self.source_origin.unit.is_equivalent(src_unit): + if ( + self.source_origin is not None + and not self.source_origin.unit.is_equivalent(src_unit) + ): raise ValueError("src_unit must match the units of source_origin.") dst_unit = (self.transform[0, 0] * src_unit).unit - if self.destination_origin is not None and not self.destination_origin.unit.is_equivalent(dst_unit): + if ( + self.destination_origin is not None + and not self.destination_origin.unit.is_equivalent(dst_unit) + ): raise ValueError("dst_unit must match the units of destination_origin.") return dst_unit - def cv2_matrix(self, - source_shape: Sequence[int], - source_pixel_size: CoordinateType, - destination_shape: Sequence[int], - destination_pixel_size: CoordinateType) -> np.ndarray: + def cv2_matrix( + self, + source_shape: Sequence[int], + source_pixel_size: CoordinateType, + destination_shape: Sequence[int], + destination_pixel_size: CoordinateType, + ) -> np.ndarray: """Returns the transformation matrix in the format used by cv2.warpAffine.""" # correct the origin. OpenCV uses the _center_ of the top-left corner as the origin @@ -133,30 +150,42 @@ def cv2_matrix(self, if self.source_origin is not None: source_origin += self.source_origin - destination_origin = 0.5 * (np.array(destination_shape) - 1.0) * destination_pixel_size + destination_origin = ( + 0.5 * (np.array(destination_shape) - 1.0) * destination_pixel_size + ) if self.destination_origin is not None: destination_origin += self.destination_origin - centered_transform = Transform(transform=self.transform, - source_origin=source_origin, - destination_origin=destination_origin) + centered_transform = Transform( + transform=self.transform, + source_origin=source_origin, + destination_origin=destination_origin, + ) # then convert the transform to a matrix, using the specified pixel sizes - transform_matrix = centered_transform.to_matrix(source_pixel_size=source_pixel_size, - destination_pixel_size=destination_pixel_size) + transform_matrix = centered_transform.to_matrix( + source_pixel_size=source_pixel_size, + destination_pixel_size=destination_pixel_size, + ) # finally, convert the matrix to the format used by cv2.warpAffine by swapping x and y columns and rows transform_matrix = transform_matrix[[1, 0], :] transform_matrix = transform_matrix[:, [1, 0, 2]] return transform_matrix - def to_matrix(self, source_pixel_size: CoordinateType, destination_pixel_size: CoordinateType) -> np.ndarray: + def to_matrix( + self, source_pixel_size: CoordinateType, destination_pixel_size: CoordinateType + ) -> np.ndarray: matrix = np.zeros((2, 3)) - matrix[0:2, 0:2] = unitless(self.transform * source_pixel_size / destination_pixel_size) + matrix[0:2, 0:2] = unitless( + self.transform * source_pixel_size / destination_pixel_size + ) if self.destination_origin is not None: matrix[0:2, 2] = unitless(self.destination_origin / destination_pixel_size) if self.source_origin is not None: - matrix[0:2, 2] -= unitless((self.transform @ self.source_origin) / destination_pixel_size) + matrix[0:2, 2] -= unitless( + (self.transform @ self.source_origin) / destination_pixel_size + ) return matrix def opencl_matrix(self) -> np.ndarray: @@ -167,9 +196,15 @@ def opencl_matrix(self) -> np.ndarray: # to construct the homogeneous transformation matrix # convert to opencl format: swap x and y columns (note: the rows were # already swapped in the construction of t2), and flip the sign of the y-axis. - transform = np.eye(3, 4, dtype='float32', order='C') - transform[0, 0:3] = matrix[1, [1, 0, 2],] - transform[1, 0:3] = -matrix[0, [1, 0, 2],] + transform = np.eye(3, 4, dtype="float32", order="C") + transform[0, 0:3] = matrix[ + 1, + [1, 0, 2], + ] + transform[1, 0:3] = -matrix[ + 0, + [1, 0, 2], + ] return transform @staticmethod @@ -188,7 +223,8 @@ def __matmul__(self, other): def apply(self, vector: CoordinateType) -> CoordinateType: """Applies the transformation to a column vector. - If `vector` is a 2-D array, applies the transformation to each column of `vector` individually.""" + If `vector` is a 2-D array, applies the transformation to each column of `vector` individually. + """ if self.source_origin is not None: vector = vector - self.source_origin vector = self.transform @ vector @@ -198,7 +234,8 @@ def apply(self, vector: CoordinateType) -> CoordinateType: def inverse(self): """Compute the inverse transformation, - such that the composition of the transformation and its inverse is the identity.""" + such that the composition of the transformation and its inverse is the identity. + """ # invert the transform matrix if self.transform is not None: @@ -207,9 +244,13 @@ def inverse(self): transform = None # swap source and destination origins - return Transform(transform, source_origin=self.destination_origin, destination_origin=self.source_origin) + return Transform( + transform, + source_origin=self.destination_origin, + destination_origin=self.source_origin, + ) - def compose(self, other: 'Transform'): + def compose(self, other: "Transform"): """Compose two transformations. Args: @@ -220,7 +261,11 @@ def compose(self, other: 'Transform'): """ transform = self.transform @ other.transform source_origin = other.source_origin - destination_origin = self.apply(other.destination_origin) if other.destination_origin is not None else None + destination_origin = ( + self.apply(other.destination_origin) + if other.destination_origin is not None + else None + ) return Transform(transform, source_origin, destination_origin) def _standard_input(self) -> Quantity: @@ -232,8 +277,13 @@ def identity(cls): return Transform() -def place(out_shape: tuple[int, ...], out_pixel_size: Quantity, source: np.ndarray, offset: Optional[Quantity] = None, - out: Optional[np.ndarray] = None): +def place( + out_shape: tuple[int, ...], + out_pixel_size: Quantity, + source: np.ndarray, + offset: Optional[Quantity] = None, + out: Optional[np.ndarray] = None, +): """Takes a source array and places it in an otherwise empty array of specified shape and pixel size. The source array must have a pixel_size property (see set_pixel_size). @@ -251,16 +301,20 @@ def place(out_shape: tuple[int, ...], out_pixel_size: Quantity, source: np.ndarr """ out_extent = out_pixel_size * np.array(out_shape) transform = Transform(destination_origin=offset) - return project(source, out_extent=out_extent, out_shape=out_shape, transform=transform, out=out) + return project( + source, out_extent=out_extent, out_shape=out_shape, transform=transform, out=out + ) def project( - source: np.ndarray, *, - source_extent: Optional[ExtentType] = None, - transform: Optional[Transform] = None, - out: Optional[np.ndarray] = None, - out_extent: Optional[ExtentType] = None, - out_shape: Optional[tuple[int, ...]] = None) -> np.ndarray: + source: np.ndarray, + *, + source_extent: Optional[ExtentType] = None, + transform: Optional[Transform] = None, + out: Optional[np.ndarray] = None, + out_extent: Optional[ExtentType] = None, + out_shape: Optional[tuple[int, ...]] = None +) -> np.ndarray: """Projects the input image onto an array with specified shape and resolution. The input image is scaled so that the pixel sizes match those of the output, @@ -281,7 +335,9 @@ def project( transform = transform if transform is not None else Transform() if out is not None: if out_shape is not None and out_shape != out.shape: - raise ValueError("out_shape and out.shape must match. Note that out_shape may be omitted") + raise ValueError( + "out_shape and out.shape must match. Note that out_shape may be omitted" + ) if out.dtype != source.dtype: raise ValueError("out and source must have the same dtype") out_shape = out.shape @@ -289,7 +345,9 @@ def project( if out_shape is None: raise ValueError("Either out_shape or out must be specified") if out_extent is None: - raise ValueError("Either out_extent or the pixel_size metadata of out must be specified") + raise ValueError( + "Either out_extent or the pixel_size metadata of out must be specified" + ) source_extent = source_extent if source_extent is not None else get_extent(source) source_ps = source_extent / np.array(source.shape) out_ps = out_extent / np.array(out_shape) @@ -301,17 +359,38 @@ def project( if out is None: out = np.zeros(out_shape, dtype=source.dtype) # real part - out.real = cv2.warpAffine(source.real, t, out_size, flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, borderValue=(0.0,)) + out.real = cv2.warpAffine( + source.real, + t, + out_size, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0.0,), + ) # imaginary part - out.imag = cv2.warpAffine(source.imag, t, out_size, flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, borderValue=(0.0,)) + out.imag = cv2.warpAffine( + source.imag, + t, + out_size, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0.0,), + ) else: - dst = cv2.warpAffine(source, t, out_size, dst=out, flags=cv2.INTER_NEAREST, - borderMode=cv2.BORDER_CONSTANT, borderValue=(0.0,)) + dst = cv2.warpAffine( + source, + t, + out_size, + dst=out, + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=(0.0,), + ) if out is not None and out is not dst: - raise ValueError("OpenCV did not use the specified output array. This should not happen.") + raise ValueError( + "OpenCV did not use the specified output array. This should not happen." + ) out = dst return set_pixel_size(out, out_ps) @@ -339,7 +418,7 @@ def set_pixel_size(data: ArrayLike, pixel_size: Optional[Quantity]) -> np.ndarra if pixel_size is not None and pixel_size.size == 1: pixel_size = pixel_size * np.ones(data.ndim) - data.dtype = np.dtype(data.dtype, metadata={'pixel_size': pixel_size}) + data.dtype = np.dtype(data.dtype, metadata={"pixel_size": pixel_size}) return data @@ -364,7 +443,7 @@ def get_pixel_size(data: np.ndarray) -> Optional[Quantity]: metadata = data.dtype.metadata if metadata is None: return None - return data.dtype.metadata.get('pixel_size', None) + return data.dtype.metadata.get("pixel_size", None) def get_extent(data: np.ndarray) -> Quantity: diff --git a/tests/test_algorithms_troubleshoot.py b/tests/test_algorithms_troubleshoot.py index 0a76681..8822a1f 100644 --- a/tests/test_algorithms_troubleshoot.py +++ b/tests/test_algorithms_troubleshoot.py @@ -4,9 +4,16 @@ from .test_simulation import phase_response_test_function, lookup_table_test_function from ..openwfs.algorithms import StepwiseSequential -from ..openwfs.algorithms.troubleshoot import cnr, signal_std, find_pixel_shift, \ - field_correlation, frame_correlation, pearson_correlation, \ - measure_modulated_light, measure_modulated_light_dual_phase_stepping +from ..openwfs.algorithms.troubleshoot import ( + cnr, + signal_std, + find_pixel_shift, + field_correlation, + frame_correlation, + pearson_correlation, + measure_modulated_light, + measure_modulated_light_dual_phase_stepping, +) from ..openwfs.processors import SingleRoi from ..openwfs.simulation import SimulatedWFS, StaticSource, SLM, Microscope @@ -18,8 +25,12 @@ def test_signal_std(): a = np.random.rand(400, 400) b = np.random.rand(400, 400) assert signal_std(a, a) < 1e-6 # Test noise only - assert np.abs(signal_std(a + b, b) - a.std()) < 0.005 # Test signal+uncorrelated noise - assert np.abs(signal_std(a + a, a) - np.sqrt(3) * a.std()) < 0.005 # Test signal+correlated noise + assert ( + np.abs(signal_std(a + b, b) - a.std()) < 0.005 + ) # Test signal+uncorrelated noise + assert ( + np.abs(signal_std(a + a, a) - np.sqrt(3) * a.std()) < 0.005 + ) # Test signal+correlated noise def test_cnr(): @@ -30,8 +41,12 @@ def test_cnr(): b = np.random.randn(800, 800) cnr_gt = 3.0 # Ground Truth assert cnr(a, a) < 1e-6 # Test noise only - assert np.abs(cnr(cnr_gt * a + b, b) - cnr_gt) < 0.01 # Test signal+uncorrelated noise - assert np.abs(cnr(cnr_gt * a + a, a) - np.sqrt((cnr_gt + 1) ** 2 - 1)) < 0.01 # Test signal+correlated noise + assert ( + np.abs(cnr(cnr_gt * a + b, b) - cnr_gt) < 0.01 + ) # Test signal+uncorrelated noise + assert ( + np.abs(cnr(cnr_gt * a + a, a) - np.sqrt((cnr_gt + 1) ** 2 - 1)) < 0.01 + ) # Test signal+correlated noise def test_find_pixel_shift(): @@ -80,8 +95,12 @@ def test_field_correlation(): assert field_correlation(a, a) == 1.0 # Self-correlation assert field_correlation(2 * a, a) == 1.0 # Invariant under scalar-multiplication assert field_correlation(a, b) == 0.0 # Orthogonal arrays - assert np.abs(field_correlation(a + b, b) - np.sqrt(0.5)) < 1e-10 # Self+orthogonal array - assert np.abs(field_correlation(b, c) - np.conj(field_correlation(c, b))) < 1e-10 # Arguments swapped + assert ( + np.abs(field_correlation(a + b, b) - np.sqrt(0.5)) < 1e-10 + ) # Self+orthogonal array + assert ( + np.abs(field_correlation(b, c) - np.conj(field_correlation(c, b))) < 1e-10 + ) # Arguments swapped def test_frame_correlation(): @@ -152,12 +171,16 @@ def test_pearson_correlation_noise_compensated(): assert np.isclose(noise1.var(), noise2.var(), atol=2e-3) assert np.isclose(corr_AA, 1, atol=2e-3) assert np.isclose(corr_AB, 0, atol=2e-3) - A_spearman = 1 / np.sqrt((1 + noise1.var() / A1.var()) * (1 + noise2.var() / A2.var())) + A_spearman = 1 / np.sqrt( + (1 + noise1.var() / A1.var()) * (1 + noise2.var() / A2.var()) + ) assert np.isclose(corr_AA_with_noise, A_spearman, atol=2e-3) -@pytest.mark.parametrize("n_y, n_x, phase_steps, b, c, gamma", - [(11, 9, 8, -0.05, 1.5, 0.8), (4, 4, 10, -0.05, 1.5, 0.8)]) +@pytest.mark.parametrize( + "n_y, n_x, phase_steps, b, c, gamma", + [(11, 9, 8, -0.05, 1.5, 0.8), (4, 4, 10, -0.05, 1.5, 0.8)], +) def test_fidelity_phase_calibration_ssa_noise_free(n_y, n_x, phase_steps, b, c, gamma): """ Test computing phase calibration fidelity factor, with the SSA algorithm. Noise-free scenarios. @@ -165,7 +188,9 @@ def test_fidelity_phase_calibration_ssa_noise_free(n_y, n_x, phase_steps, b, c, # Perfect SLM, noise-free aberrations = np.random.uniform(0.0, 2 * np.pi, (n_y, n_x)) sim = SimulatedWFS(aberrations=aberrations) - alg = StepwiseSequential(feedback=sim, slm=sim.slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps) + alg = StepwiseSequential( + feedback=sim, slm=sim.slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps + ) result = alg.execute() assert result.fidelity_calibration > 0.99 @@ -181,8 +206,12 @@ def test_fidelity_phase_calibration_ssa_noise_free(n_y, n_x, phase_steps, b, c, assert result.fidelity_calibration > 0.99 -@pytest.mark.parametrize("n_y, n_x, phase_steps, gaussian_noise_std", [(4, 4, 10, 0.2), (6, 6, 12, 1.0)]) -def test_fidelity_phase_calibration_ssa_with_noise(n_y, n_x, phase_steps, gaussian_noise_std): +@pytest.mark.parametrize( + "n_y, n_x, phase_steps, gaussian_noise_std", [(4, 4, 10, 0.2), (6, 6, 12, 1.0)] +) +def test_fidelity_phase_calibration_ssa_with_noise( + n_y, n_x, phase_steps, gaussian_noise_std +): """ Test estimation of phase calibration fidelity factor, with the SSA algorithm. With noise. """ @@ -197,26 +226,39 @@ def test_fidelity_phase_calibration_ssa_with_noise(n_y, n_x, phase_steps, gaussi # SLM, simulation, camera, ROI detector slm = SLM(shape=(80, 80)) - sim = Microscope(source=src, incident_field=slm.field, magnification=1, - numerical_aperture=numerical_aperture, - aberrations=aberration, wavelength=800 * u.nm) + sim = Microscope( + source=src, + incident_field=slm.field, + magnification=1, + numerical_aperture=numerical_aperture, + aberrations=aberration, + wavelength=800 * u.nm, + ) cam = sim.get_camera(analog_max=1e4, gaussian_noise_std=gaussian_noise_std) roi_detector = SingleRoi(cam, radius=0) # Only measure that specific point # Define and run WFS algorithm - alg = StepwiseSequential(feedback=roi_detector, slm=slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps) + alg = StepwiseSequential( + feedback=roi_detector, slm=slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps + ) result_good = alg.execute() assert result_good.fidelity_calibration > 0.9 # SLM with incorrect phase response linear_phase = np.arange(0, 2 * np.pi, 2 * np.pi / 256) - slm.phase_response = phase_response_test_function(linear_phase, b=0.05, c=0.6, gamma=1.5) + slm.phase_response = phase_response_test_function( + linear_phase, b=0.05, c=0.6, gamma=1.5 + ) result_good = alg.execute() assert result_good.fidelity_calibration < 0.9 -@pytest.mark.parametrize("num_blocks, phase_steps, expected_fid, atol", [(10, 8, 1, 1e-6)]) -def test_measure_modulated_light_dual_phase_stepping_noise_free(num_blocks, phase_steps, expected_fid, atol): +@pytest.mark.parametrize( + "num_blocks, phase_steps, expected_fid, atol", [(10, 8, 1, 1e-6)] +) +def test_measure_modulated_light_dual_phase_stepping_noise_free( + num_blocks, phase_steps, expected_fid, atol +): """Test fidelity estimation due to amount of modulated light. Noise-free.""" # Perfect SLM, noise-free aberrations = np.random.uniform(0.0, 2 * np.pi, (20, 20)) @@ -224,12 +266,18 @@ def test_measure_modulated_light_dual_phase_stepping_noise_free(num_blocks, phas # Measure the amount of modulated light (no non-modulated light present) fidelity_modulated = measure_modulated_light_dual_phase_stepping( - slm=sim.slm, feedback=sim, phase_steps=phase_steps, num_blocks=num_blocks) + slm=sim.slm, feedback=sim, phase_steps=phase_steps, num_blocks=num_blocks + ) assert np.isclose(fidelity_modulated, expected_fid, atol=atol) -@pytest.mark.parametrize("num_blocks, phase_steps, gaussian_noise_std, atol", [(10, 6, 0.0, 1e-6), (6, 8, 2.0, 1e-3)]) -def test_measure_modulated_light_dual_phase_stepping_with_noise(num_blocks, phase_steps, gaussian_noise_std, atol): +@pytest.mark.parametrize( + "num_blocks, phase_steps, gaussian_noise_std, atol", + [(10, 6, 0.0, 1e-6), (6, 8, 2.0, 1e-3)], +) +def test_measure_modulated_light_dual_phase_stepping_with_noise( + num_blocks, phase_steps, gaussian_noise_std, atol +): """Test fidelity estimation due to amount of modulated light. Can test with noise.""" # === Define mock hardware, perfect SLM === # Aberration and image source @@ -239,38 +287,55 @@ def test_measure_modulated_light_dual_phase_stepping_with_noise(num_blocks, phas # SLM, simulation, camera, ROI detector slm = SLM(shape=(100, 100)) - sim = Microscope(source=src, incident_field=slm.field, magnification=1, numerical_aperture=1.0, - wavelength=800 * u.nm) + sim = Microscope( + source=src, + incident_field=slm.field, + magnification=1, + numerical_aperture=1.0, + wavelength=800 * u.nm, + ) cam = sim.get_camera(analog_max=1e4, gaussian_noise_std=gaussian_noise_std) roi_detector = SingleRoi(cam, radius=0) # Only measure that specific point # Measure the amount of modulated light (no non-modulated light present) fidelity_modulated = measure_modulated_light_dual_phase_stepping( - slm=slm, feedback=roi_detector, phase_steps=phase_steps, num_blocks=num_blocks) + slm=slm, feedback=roi_detector, phase_steps=phase_steps, num_blocks=num_blocks + ) assert np.isclose(fidelity_modulated, 1, atol=atol) @pytest.mark.parametrize( - "phase_steps, modulated_field_amplitude, non_modulated_field", [(6, 1.0, 0.0), (8, 0.5, 0.5), (8, 1.0, 0.25)]) -def test_measure_modulated_light_noise_free(phase_steps, modulated_field_amplitude, non_modulated_field): + "phase_steps, modulated_field_amplitude, non_modulated_field", + [(6, 1.0, 0.0), (8, 0.5, 0.5), (8, 1.0, 0.25)], +) +def test_measure_modulated_light_noise_free( + phase_steps, modulated_field_amplitude, non_modulated_field +): """Test fidelity estimation due to amount of modulated light. Noise-free.""" # Perfect SLM, noise-free aberrations = np.random.uniform(0.0, 2 * np.pi, (20, 20)) - slm = SLM(aberrations.shape, field_amplitude=modulated_field_amplitude, - non_modulated_field_fraction=non_modulated_field) + slm = SLM( + aberrations.shape, + field_amplitude=modulated_field_amplitude, + non_modulated_field_fraction=non_modulated_field, + ) sim = SimulatedWFS(aberrations=aberrations, slm=slm) # Measure the amount of modulated light (no non-modulated light present) - fidelity_modulated = measure_modulated_light(slm=sim.slm, feedback=sim, phase_steps=phase_steps) - expected_fid = 1.0 / (1.0 + non_modulated_field ** 2) + fidelity_modulated = measure_modulated_light( + slm=sim.slm, feedback=sim, phase_steps=phase_steps + ) + expected_fid = 1.0 / (1.0 + non_modulated_field**2) assert np.isclose(fidelity_modulated, expected_fid, rtol=0.1) @pytest.mark.parametrize( "phase_steps, gaussian_noise_std, modulated_field_amplitude, non_modulated_field", - [(8, 0.0, 0.5, 0.4), (6, 0.0, 1.0, 0.0), (12, 2.0, 1.0, 0.25)]) + [(8, 0.0, 0.5, 0.4), (6, 0.0, 1.0, 0.0), (12, 2.0, 1.0, 0.25)], +) def test_measure_modulated_light_dual_phase_stepping_with_noise( - phase_steps, gaussian_noise_std, modulated_field_amplitude, non_modulated_field): + phase_steps, gaussian_noise_std, modulated_field_amplitude, non_modulated_field +): """Test fidelity estimation due to amount of modulated light. Can test with noise.""" # === Define mock hardware, perfect SLM === # Aberration and image source @@ -279,14 +344,18 @@ def test_measure_modulated_light_dual_phase_stepping_with_noise( src = StaticSource(img, 200 * u.nm) # SLM, simulation, camera, ROI detector - slm = SLM(shape=(100, 100), - field_amplitude=modulated_field_amplitude, - non_modulated_field_fraction=non_modulated_field) + slm = SLM( + shape=(100, 100), + field_amplitude=modulated_field_amplitude, + non_modulated_field_fraction=non_modulated_field, + ) sim = Microscope(source=src, incident_field=slm.field, wavelength=800 * u.nm) cam = sim.get_camera(analog_max=1e3, gaussian_noise_std=gaussian_noise_std) roi_detector = SingleRoi(cam, radius=0) # Only measure that specific point # Measure the amount of modulated light (no non-modulated light present) - expected_fid = 1.0 / (1.0 + non_modulated_field ** 2) - fidelity_modulated = measure_modulated_light(slm=slm, feedback=roi_detector, phase_steps=phase_steps) + expected_fid = 1.0 / (1.0 + non_modulated_field**2) + fidelity_modulated = measure_modulated_light( + slm=slm, feedback=roi_detector, phase_steps=phase_steps + ) assert np.isclose(fidelity_modulated, expected_fid, rtol=0.1) diff --git a/tests/test_camera.py b/tests/test_camera.py index d953491..f635c90 100644 --- a/tests/test_camera.py +++ b/tests/test_camera.py @@ -1,7 +1,9 @@ import pytest -pytest.importorskip('harvesters', - reason='harvesters is required for the Camera module, install with pip install harvesters') +pytest.importorskip( + "harvesters", + reason="harvesters is required for the Camera module, install with pip install harvesters", +) from ..openwfs.devices import Camera @@ -36,8 +38,16 @@ def test_roi(camera, binning, top, left): # take care that the size will be a multiple of the increment, # and that setting the binning will round this number down camera.binning = binning - expected_width = (original_shape[1] // binning) // camera._nodes.Width.inc * camera._nodes.Width.inc - expected_height = (original_shape[0] // binning) // camera._nodes.Height.inc * camera._nodes.Height.inc + expected_width = ( + (original_shape[1] // binning) + // camera._nodes.Width.inc + * camera._nodes.Width.inc + ) + expected_height = ( + (original_shape[0] // binning) + // camera._nodes.Height.inc + * camera._nodes.Height.inc + ) assert camera.data_shape == (expected_height, expected_width) # check if setting the ROI works diff --git a/tests/test_core.py b/tests/test_core.py index 847d032..a089f27 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,11 +1,13 @@ import logging import time + +import astropy.units as u +import numpy as np import pytest + +from ..openwfs.processors import CropProcessor from ..openwfs.simulation import StaticSource, NoiseSource, SLM from ..openwfs.utilities import set_pixel_size, get_pixel_size -from ..openwfs.processors import CropProcessor -import numpy as np -import astropy.units as u def test_set_pixel_size(): @@ -75,15 +77,17 @@ def test_timing_detector(caplog, duration): assert np.allclose(f0.result(), image0) t5 = time.time_ns() - assert np.allclose(t1 - t0, 0.0, atol=0.1E9) - assert np.allclose(t2 - t1, duration.to_value(u.ns), atol=0.1E9) - assert np.allclose(t3 - t2, 0.0, atol=0.1E9) - assert np.allclose(t4 - t3, duration.to_value(u.ns), atol=0.1E9) - assert np.allclose(t5 - t4, 0.0, atol=0.1E9) + assert np.allclose(t1 - t0, 0.0, atol=0.1e9) + assert np.allclose(t2 - t1, duration.to_value(u.ns), atol=0.1e9) + assert np.allclose(t3 - t2, 0.0, atol=0.1e9) + assert np.allclose(t4 - t3, duration.to_value(u.ns), atol=0.1e9) + assert np.allclose(t5 - t4, 0.0, atol=0.1e9) def test_noise_detector(): - source = NoiseSource('uniform', data_shape=(10, 11, 20), low=-1.0, high=1.0, pixel_size=4 * u.um) + source = NoiseSource( + "uniform", data_shape=(10, 11, 20), low=-1.0, high=1.0, pixel_size=4 * u.um + ) data = source.read() assert data.shape == (10, 11, 20) assert np.min(data) >= -1.0 @@ -98,18 +102,31 @@ def test_noise_detector(): def test_mock_slm(): slm = SLM((4, 4)) slm.set_phases(0.5) - assert np.allclose(slm.pixels.read(), round(0.5 * 256 / (2 * np.pi)), atol=0.5 / 256) + assert np.allclose( + slm.pixels.read(), round(0.5 * 256 / (2 * np.pi)), atol=0.5 / 256 + ) discretized_phase = slm.phases.read() assert np.allclose(discretized_phase, 0.5, atol=1.1 * np.pi / 256) - assert np.allclose(slm.field.read(), np.exp(1j * discretized_phase[0, 0]), rtol=2 / 256) + assert np.allclose( + slm.field.read(), np.exp(1j * discretized_phase[0, 0]), rtol=2 / 256 + ) slm.set_phases(np.array(((0.1, 0.2), (0.3, 0.4))), update=False) - assert np.allclose(slm.phases.read(), 0.5, atol=1.1 * np.pi / 256) # slm.update() not yet called, so should be 0.5 + assert np.allclose( + slm.phases.read(), 0.5, atol=1.1 * np.pi / 256 + ) # slm.update() not yet called, so should be 0.5 slm.update() - assert np.allclose(slm.phases.read(), np.array(( - (0.1, 0.1, 0.2, 0.2), - (0.1, 0.1, 0.2, 0.2), - (0.3, 0.3, 0.4, 0.4), - (0.3, 0.3, 0.4, 0.4))), atol=1.1 * np.pi / 256) + assert np.allclose( + slm.phases.read(), + np.array( + ( + (0.1, 0.1, 0.2, 0.2), + (0.1, 0.1, 0.2, 0.2), + (0.3, 0.3, 0.4, 0.4), + (0.3, 0.3, 0.4, 0.4), + ) + ), + atol=1.1 * np.pi / 256, + ) def test_crop(): @@ -160,6 +177,7 @@ def test_crop_1d(): assert c3.shape == cropped.data_shape assert np.all(c3 == data[4:6]) + # TODO: translate the tests below. # They should test the SingleROI processor, checking if the returned averaged value is correct. # diff --git a/tests/test_processors.py b/tests/test_processors.py index de74867..591a2c5 100644 --- a/tests/test_processors.py +++ b/tests/test_processors.py @@ -1,24 +1,25 @@ -import pytest +import astropy.units as u import numpy as np -from ..openwfs.simulation.mockdevices import StaticSource -from ..openwfs.processors import SingleRoi, select_roi, Roi, MultipleRoi +import pytest import skimage as sk -import astropy.units as u + +from ..openwfs.processors import SingleRoi, select_roi, Roi, MultipleRoi +from ..openwfs.simulation.mockdevices import StaticSource -@pytest.mark.skip(reason="This is an interactive test: skip by default. TODO: actually test if the roi was " - "selected correctly.") +@pytest.mark.skip( + reason="This is an interactive test: skip by default. TODO: actually test if the roi was " + "selected correctly." +) def test_croppers(): img = sk.data.camera() src = StaticSource(img, 50 * u.nm) - roi = select_roi(src, 'disk') - assert roi.mask_type == 'disk' + roi = select_roi(src, "disk") + assert roi.mask_type == "disk" def test_single_roi_simple_case(): - data = np.array([[1, 2, 3], - [4, 5, 6], - [7, 8, 9]]) + data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) pixel_size = 1 * np.ones(2) mock_source = StaticSource(data, pixel_size=pixel_size) roi_processor = SingleRoi(mock_source, radius=np.sqrt(2)) @@ -29,8 +30,9 @@ def test_single_roi_simple_case(): print("Mask:", roi_processor._rois[()]._mask) expected_value = np.mean(data[0:3, 0:3]) # Assuming this is how the ROI is defined - assert np.isclose(result, - expected_value), f"ROI average value is incorrect. Expected: {expected_value}, Got: {result}" + assert np.isclose( + result, expected_value + ), f"ROI average value is incorrect. Expected: {expected_value}, Got: {result}" def create_mock_source_with_data(): @@ -38,36 +40,41 @@ def create_mock_source_with_data(): return StaticSource(data, pixel_size=1 * u.um) -@pytest.mark.parametrize("x, y, radius, expected_avg", [ - (2, 2, 1, 12), # Center ROI in 5x5 matrix - (0, 0, 0, 0) # Top-left corner ROI in 5x5 matrix -]) +@pytest.mark.parametrize( + "x, y, radius, expected_avg", + [ + (2, 2, 1, 12), # Center ROI in 5x5 matrix + (0, 0, 0, 0), # Top-left corner ROI in 5x5 matrix + ], +) def test_single_roi(x, y, radius, expected_avg): mock_source = create_mock_source_with_data() roi_processor = SingleRoi(mock_source, (y, x), radius) roi_processor.trigger() result = roi_processor.read() - assert np.isclose(result, expected_avg), f"ROI average value is incorrect. Expected: {expected_avg}, Got: {result}" + assert np.isclose( + result, expected_avg + ), f"ROI average value is incorrect. Expected: {expected_avg}, Got: {result}" def test_multiple_roi_simple_case(): - data = np.array([[1, 2, 3], - [4, 5, 6], - [7, 8, 9]]) + data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) pixel_size = 1 * np.ones(2) mock_source = StaticSource(data, pixel_size=pixel_size) - rois = [Roi((1, 1), radius=0), - Roi((2, 2), radius=0), - Roi((1, 1), radius=1), - Roi((0, 1), radius=0) - ] + rois = [ + Roi((1, 1), radius=0), + Roi((2, 2), radius=0), + Roi((1, 1), radius=1), + Roi((0, 1), radius=0), + ] roi_processor = MultipleRoi(mock_source, rois=rois) roi_processor.trigger() result = roi_processor.read() expected_values = [5, 9, 5, 2] - assert all(np.isclose(r, e) for r, e in zip(result, expected_values)), \ - f"ROI average values are incorrect. Expected: {expected_values}, Got: {result}" + assert all( + np.isclose(r, e) for r, e in zip(result, expected_values) + ), f"ROI average values are incorrect. Expected: {expected_values}, Got: {result}" diff --git a/tests/test_scanning_microscope.py b/tests/test_scanning_microscope.py index f6a8c45..fe0335a 100644 --- a/tests/test_scanning_microscope.py +++ b/tests/test_scanning_microscope.py @@ -2,8 +2,10 @@ import numpy as np import pytest -pytest.importorskip('nidaqmx', - reason='nidaqmx is required for the ScanningMicroscope module, install with pip install nidaqmx') +pytest.importorskip( + "nidaqmx", + reason="nidaqmx is required for the ScanningMicroscope module, install with pip install nidaqmx", +) from ..openwfs.devices import ScanningMicroscope, Axis from ..openwfs.devices.galvo_scanner import InputChannel @@ -13,12 +15,18 @@ @pytest.mark.parametrize("start, stop", [(0.0, 1.0), (1.0, 0.0)]) def test_scan_axis(start, stop): """Tests if the Axis class generates the correct voltage sequences for stepping and scanning.""" - maximum_acceleration = 1 * u.V / u.ms ** 2 + maximum_acceleration = 1 * u.V / u.ms**2 scale = 440 * u.um / u.V v_min = -1.0 * u.V v_max = 2.0 * u.V - a = Axis(channel='Dev4/ao0', v_min=v_min, v_max=v_max, maximum_acceleration=maximum_acceleration, scale=scale) - assert a.channel == 'Dev4/ao0' + a = Axis( + channel="Dev4/ao0", + v_min=v_min, + v_max=v_max, + maximum_acceleration=maximum_acceleration, + scale=scale, + ) + assert a.channel == "Dev4/ao0" assert a.v_min == v_min assert a.v_max == v_max assert a.maximum_acceleration == maximum_acceleration @@ -36,19 +44,22 @@ def test_scan_axis(start, stop): assert np.isclose(step[-1], 2.0 * u.V if start == 0.0 else -1.0 * u.V) assert np.all(step >= v_min) assert np.all(step <= v_max) - acceleration = np.diff(np.diff(step)) * sample_rate ** 2 + acceleration = np.diff(np.diff(step)) * sample_rate**2 assert np.all(np.abs(acceleration) <= maximum_acceleration * 1.01) center = 0.5 * (start + stop) amplitude = 0.5 * (stop - start) # test clipping - assert np.allclose(step, a.step(center - 1.1 * amplitude, center + 1.1 * amplitude, sample_rate)) + assert np.allclose( + step, a.step(center - 1.1 * amplitude, center + 1.1 * amplitude, sample_rate) + ) # test scan. Note that we cannot use the full scan range because we need # some time to accelerate / decelerate sample_count = 10000 - scan, launch, land, linear_region = a.scan(center - 0.8 * amplitude, center + 0.8 * amplitude, sample_count, - sample_rate) + scan, launch, land, linear_region = a.scan( + center - 0.8 * amplitude, center + 0.8 * amplitude, sample_count, sample_rate + ) half_pixel = 0.8 * amplitude / sample_count # plt.plot(scan) # plt.show() @@ -56,48 +67,76 @@ def test_scan_axis(start, stop): assert linear_region.start == len(scan) - linear_region.stop assert np.isclose(scan[0], a.to_volt(launch)) assert np.isclose(scan[-1], a.to_volt(land)) - assert np.isclose(scan[linear_region.start], a.to_volt(center - 0.8 * amplitude + half_pixel)) - assert np.isclose(scan[linear_region.stop - 1], a.to_volt(center + 0.8 * amplitude - half_pixel)) + assert np.isclose( + scan[linear_region.start], a.to_volt(center - 0.8 * amplitude + half_pixel) + ) + assert np.isclose( + scan[linear_region.stop - 1], a.to_volt(center + 0.8 * amplitude - half_pixel) + ) speed = np.diff(scan[linear_region]) assert np.allclose(speed, speed[0]) # speed should be constant - acceleration = np.diff(np.diff(scan)) * sample_rate ** 2 + acceleration = np.diff(np.diff(scan)) * sample_rate**2 assert np.all(np.abs(acceleration) <= maximum_acceleration * 1.01) def make_scanner(bidirectional, direction, reference_zoom): scale = 440 * u.um / u.V sample_rate = 0.5 * u.MHz - input_channel = InputChannel(channel='Dev4/ai0', v_min=-1.0 * u.V, v_max=1.0 * u.V) - y_axis = Axis(channel='Dev4/ao0', v_min=-2.0 * u.V, v_max=2.0 * u.V, maximum_acceleration=10 * u.V / u.ms ** 2, - scale=scale) - x_axis = Axis(channel='Dev4/ao1', v_min=-2.0 * u.V, v_max=2.0 * u.V, maximum_acceleration=10 * u.V / u.ms ** 2, - scale=scale) - return ScanningMicroscope(bidirectional=bidirectional, sample_rate=sample_rate, resolution=1024, - input=input_channel, y_axis=y_axis, x_axis=x_axis, - test_pattern=direction, reference_zoom=reference_zoom) - - -@pytest.mark.parametrize("direction", ['horizontal', 'vertical']) + input_channel = InputChannel(channel="Dev4/ai0", v_min=-1.0 * u.V, v_max=1.0 * u.V) + y_axis = Axis( + channel="Dev4/ao0", + v_min=-2.0 * u.V, + v_max=2.0 * u.V, + maximum_acceleration=10 * u.V / u.ms**2, + scale=scale, + ) + x_axis = Axis( + channel="Dev4/ao1", + v_min=-2.0 * u.V, + v_max=2.0 * u.V, + maximum_acceleration=10 * u.V / u.ms**2, + scale=scale, + ) + return ScanningMicroscope( + bidirectional=bidirectional, + sample_rate=sample_rate, + resolution=1024, + input=input_channel, + y_axis=y_axis, + x_axis=x_axis, + test_pattern=direction, + reference_zoom=reference_zoom, + ) + + +@pytest.mark.parametrize("direction", ["horizontal", "vertical"]) @pytest.mark.parametrize("bidirectional", [False, True]) def test_scan_pattern(direction, bidirectional): """A unit test for scanning patterns.""" reference_zoom = 1.2 scanner = make_scanner(bidirectional, direction, reference_zoom) - assert np.allclose(scanner.extent, scanner._x_axis.scale * 4.0 * u.V / reference_zoom) + assert np.allclose( + scanner.extent, scanner._x_axis.scale * 4.0 * u.V / reference_zoom + ) # plt.imshow(scanner.read()) # plt.show() # check if returned pattern is correct - (y, x) = coordinate_range((scanner._resolution, scanner._resolution), - 10000 / reference_zoom, offset=(5000, 5000)) - full = scanner.read().astype('float32') - 0x8000 + (y, x) = coordinate_range( + (scanner._resolution, scanner._resolution), + 10000 / reference_zoom, + offset=(5000, 5000), + ) + full = scanner.read().astype("float32") - 0x8000 pixel_size = full[1, 1] - full[0, 0] - if direction == 'horizontal': + if direction == "horizontal": assert np.allclose(full, full[0, :]) # all rows should be the same - assert np.allclose(x, full, atol=0.2 * pixel_size) # some rounding due to quantization + assert np.allclose( + x, full, atol=0.2 * pixel_size + ) # some rounding due to quantization else: # all columns should be the same (note we need to keep the last dimension for correct broadcasting) assert np.allclose(full, full[:, 0:1]) @@ -119,15 +158,17 @@ def test_scan_pattern(direction, bidirectional): assert scanner.height == height assert scanner.data_shape == (height, width) - roi = scanner.read().astype('float32') - 0x8000 - assert np.allclose(full[top:(top + height), left:(left + width)], roi, atol=0.2 * pixel_size) + roi = scanner.read().astype("float32") - 0x8000 + assert np.allclose( + full[top : (top + height), left : (left + width)], roi, atol=0.2 * pixel_size + ) @pytest.mark.parametrize("bidirectional", [False, True]) def test_park_beam(bidirectional): """A unit test for parking the beam of a DAQ scanner.""" reference_zoom = 1.2 - scanner = make_scanner(bidirectional, 'horizontal', reference_zoom) + scanner = make_scanner(bidirectional, "horizontal", reference_zoom) # Park beam horizontally scanner.top = 3 @@ -138,7 +179,9 @@ def test_park_beam(bidirectional): img = scanner.read() assert img.shape == (2, 1) voltages = scanner._scan_pattern - assert np.allclose(voltages[1, :], voltages[1, 0]) # all voltages should be the same + assert np.allclose( + voltages[1, :], voltages[1, 0] + ) # all voltages should be the same # Park beam vertically scanner.width = 2 @@ -154,8 +197,13 @@ def test_park_beam(bidirectional): img = scanner.read() assert img.shape == (1, 1) voltages = scanner._scan_pattern - assert np.allclose(voltages[1, :], voltages[1, 0]) # all voltages should be the same - assert np.allclose(voltages[0, :], voltages[0, 0]) # all voltages should be the same + assert np.allclose( + voltages[1, :], voltages[1, 0] + ) # all voltages should be the same + assert np.allclose( + voltages[0, :], voltages[0, 0] + ) # all voltages should be the same + # test zooming # ps = scanner.pixel_size diff --git a/tests/test_simulation.py b/tests/test_simulation.py index 236ceba..5ae8443 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -21,8 +21,12 @@ def test_mock_camera_and_single_roi(): img = np.zeros((1000, 1000), dtype=np.int16) img[200, 300] = 39.39 # some random float src = Camera(StaticSource(img, 450 * u.nm)) - roi_detector = SingleRoi(src, pos=(200, 300), radius=0) # Only measure that specific point - assert roi_detector.read() == int(2 ** 16 - 1) # it should cast the array into some int + roi_detector = SingleRoi( + src, pos=(200, 300), radius=0 + ) # Only measure that specific point + assert roi_detector.read() == int( + 2**16 - 1 + ) # it should cast the array into some int @pytest.mark.parametrize("shape", [(1000, 1000), (999, 999)]) @@ -37,11 +41,13 @@ def test_microscope_without_magnification(shape): src = Camera(StaticSource(img, 400 * u.nm)) # construct microscope - sim = Microscope(source=src, magnification=1, numerical_aperture=1, wavelength=800 * u.nm) + sim = Microscope( + source=src, magnification=1, numerical_aperture=1, wavelength=800 * u.nm + ) cam = sim.get_camera() img = cam.read() - assert img[256, 256] == 2 ** 16 - 1 + assert img[256, 256] == 2**16 - 1 def test_microscope_and_aberration(): @@ -56,7 +62,13 @@ def test_microscope_and_aberration(): aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) - sim = Microscope(source=src, magnification=1, incident_field=slm.field, numerical_aperture=1, wavelength=800 * u.nm) + sim = Microscope( + source=src, + magnification=1, + incident_field=slm.field, + numerical_aperture=1, + wavelength=800 * u.nm, + ) without_aberration = sim.read()[256, 256] slm.set_phases(aberrations) @@ -78,10 +90,17 @@ def test_slm_and_aberration(): aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) * 0 slm.set_phases(-aberrations) - aberration = StaticSource(aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled) - - sim1 = Microscope(source=src, incident_field=slm.field, numerical_aperture=1.0, aberrations=aberration, - wavelength=800 * u.nm) + aberration = StaticSource( + aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled + ) + + sim1 = Microscope( + source=src, + incident_field=slm.field, + numerical_aperture=1.0, + aberrations=aberration, + wavelength=800 * u.nm, + ) sim2 = Microscope(source=src, numerical_aperture=1.0, wavelength=800 * u.nm) # We correlate the two. @@ -111,14 +130,19 @@ def test_slm_tilt(): slm = SLM(shape=(1000, 1000)) na = 1.0 - sim = Microscope(source=src, incident_field=slm.field, magnification=1, numerical_aperture=na, - wavelength=wavelength) + sim = Microscope( + source=src, + incident_field=slm.field, + magnification=1, + numerical_aperture=na, + wavelength=wavelength, + ) # introduce a tilted pupil plane # the input parameter to `tilt` corresponds to a shift 2.0/π the Abbe diffraction limit. shift = np.array((-24, 40)) step = wavelength / (np.pi * na) - slm.set_phases(tilt(1000, - shift * pixel_size / step)) + slm.set_phases(tilt(1000, -shift * pixel_size / step)) new_location = signal_location + shift @@ -135,7 +159,9 @@ def test_microscope_wavefront_shaping(caplog): # caplog.set_level(logging.DEBUG) aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) + np.pi - aberration = StaticSource(aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled) # note: incorrect scaling! + aberration = StaticSource( + aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled + ) # note: incorrect scaling! img = np.zeros((1000, 1000), dtype=np.int16) img[256, 256] = 100 @@ -148,13 +174,22 @@ def test_microscope_wavefront_shaping(caplog): slm = SLM(shape=(1000, 1000)) - sim = Microscope(source=src, incident_field=slm.field, numerical_aperture=1, aberrations=aberration, - wavelength=800 * u.nm) + sim = Microscope( + source=src, + incident_field=slm.field, + numerical_aperture=1, + aberrations=aberration, + wavelength=800 * u.nm, + ) cam = sim.get_camera(analog_max=100) - roi_detector = SingleRoi(cam, pos=signal_location, radius=0) # Only measure that specific point + roi_detector = SingleRoi( + cam, pos=signal_location, radius=0 + ) # Only measure that specific point - alg = StepwiseSequential(feedback=roi_detector, slm=slm, phase_steps=3, n_x=3, n_y=3) + alg = StepwiseSequential( + feedback=roi_detector, slm=slm, phase_steps=3, n_x=3, n_y=3 + ) t = alg.execute().t # test if the modes differ. The error causes them not to differ @@ -186,10 +221,36 @@ def test_mock_slm_lut_and_phase_response(): """ # === Test default lookup table and phase response === # Includes edge cases like rounding/wrapping: -0.501 -> 255, -0.499 -> 0 - input_phases_a = np.asarray( - (-1, -0.501, -0.499, 0, 1, 64, 128, 192, 255, 255.499, 255.501, 256, 257, 511, 512)) * 2 * np.pi / 256 - expected_output_phases_a = np.asarray( - (255, 255, 0, 0, 1, 64, 128, 192, 255, 255, 0, 0, 1, 255, 0)) * 2 * np.pi / 256 + input_phases_a = ( + np.asarray( + ( + -1, + -0.501, + -0.499, + 0, + 1, + 64, + 128, + 192, + 255, + 255.499, + 255.501, + 256, + 257, + 511, + 512, + ) + ) + * 2 + * np.pi + / 256 + ) + expected_output_phases_a = ( + np.asarray((255, 255, 0, 0, 1, 64, 128, 192, 255, 255, 0, 0, 1, 255, 0)) + * 2 + * np.pi + / 256 + ) slm1 = SLM(shape=(3, input_phases_a.shape[0])) slm1.set_phases(input_phases_a) assert np.all(np.abs(slm1.phases.read() - expected_output_phases_a) < 1e6) @@ -213,13 +274,22 @@ def test_mock_slm_lut_and_phase_response(): slm3 = SLM(shape=(3, 256)) slm3.lookup_table = lookup_table slm3.set_phases(linear_phase) - assert np.all(np.abs(slm3.phases.read() - inverse_phase_response_test_function(linear_phase, b, c, gamma)) < ( - 1.1 * np.pi / 256)) + assert np.all( + np.abs( + slm3.phases.read() + - inverse_phase_response_test_function(linear_phase, b, c, gamma) + ) + < (1.1 * np.pi / 256) + ) # === Test custom lookup table that counters custom synthetic phase response === - linear_phase_highres = np.arange(0, 2 * np.pi * 255.49 / 256, 0.25 * 2 * np.pi / 256) + linear_phase_highres = np.arange( + 0, 2 * np.pi * 255.49 / 256, 0.25 * 2 * np.pi / 256 + ) slm4 = SLM(shape=(3, linear_phase_highres.shape[0])) slm4.phase_response = phase_response slm4.lookup_table = lookup_table slm4.set_phases(linear_phase_highres) - assert np.all(np.abs(slm4.phases.read()[0] - linear_phase_highres) < (3 * np.pi / 256)) + assert np.all( + np.abs(slm4.phases.read()[0] - linear_phase_highres) < (3 * np.pi / 256) + ) diff --git a/tests/test_slm.py b/tests/test_slm.py index 5095a90..b4575d0 100644 --- a/tests/test_slm.py +++ b/tests/test_slm.py @@ -19,7 +19,7 @@ @pytest.fixture def slm() -> SLM: - slm = SLM(monitor_id=0, shape=(100, 200), pos=(20, 10), coordinate_system='full') + slm = SLM(monitor_id=0, shape=(100, 200), pos=(20, 10), coordinate_system="full") return slm @@ -30,7 +30,7 @@ def test_create_windowed(slm): assert slm.shape == (100, 200) assert slm.position == (20, 10) assert slm.transform == Transform.identity() - assert slm.coordinate_system == 'full' + assert slm.coordinate_system == "full" # check if frame buffer has correct size fb_texture = slm._frame_buffer._textures[Patch._PHASES_TEXTURE] @@ -110,7 +110,7 @@ def test_transform(slm): # now change the transform to 'short' to fit the pattern to a centered square, with the height of the # SLM. # Then check if the pattern is displayed correctly - slm.coordinate_system = 'short' # does not trigger an update + slm.coordinate_system = "short" # does not trigger an update assert np.all(slm.pixels.read() / 64 == pixels) slm.update() @@ -126,7 +126,7 @@ def test_transform(slm): # now change the transform to 'long' to fit the pattern to a centered square, with the width of the # SLM, causing part of the texture to be mapped outside the window. - slm.coordinate_system = 'long' # does not trigger an update + slm.coordinate_system = "long" # does not trigger an update assert np.all(slm.pixels.read() / 64 == pixels) slm.update() @@ -137,7 +137,7 @@ def test_transform(slm): assert np.allclose(pixels[:, 100:], 3) # test zooming the pattern - slm.coordinate_system = 'short' + slm.coordinate_system = "short" slm.transform = Transform.zoom(0.8) slm.update() @@ -153,8 +153,10 @@ def test_transform(slm): assert np.allclose(sub[20:, 40:], 3) -@pytest.mark.skip(reason="This test is skipped by default because it causes the screen to flicker, which may " - "affect people with epilepsy.") +@pytest.mark.skip( + reason="This test is skipped by default because it causes the screen to flicker, which may " + "affect people with epilepsy." +) def test_refresh_rate(): slm = SLM(1, latency=0, duration=0) refresh_rate = slm.refresh_rate @@ -171,14 +173,16 @@ def test_refresh_rate(): stop = time.time_ns() * u.ns del slm actual_refresh_rate = frame_count / (stop - start) - assert np.allclose(refresh_rate.to_value(u.Hz), actual_refresh_rate.to_value(u.Hz), rtol=1e-2) + assert np.allclose( + refresh_rate.to_value(u.Hz), actual_refresh_rate.to_value(u.Hz), rtol=1e-2 + ) def test_get_pixels(): width = 73 height = 99 slm = SLM(SLM.WINDOWED, shape=(height, width)) - slm.coordinate_system = 'full' # fill full screen exactly (anisotropic coordinates + slm.coordinate_system = "full" # fill full screen exactly (anisotropic coordinates pattern = np.random.uniform(size=(height, width)) * 2 * np.pi slm.set_phases(pattern) read_back = slm.pixels.read() @@ -257,8 +261,22 @@ def test_circular_geometry(slm): # read back the pixels and verify conversion to gray values pixels = np.rint(slm.pixels.read() / 256 * 70) - polar_pixels = cv2.warpPolar(pixels, (100, 40), (99.5, 99.5), 100, cv2.WARP_POLAR_LINEAR) - - assert np.allclose(polar_pixels[:, 3:24], np.repeat(np.flip(np.arange(0, 10)), 4).reshape((-1, 1)), atol=1) - assert np.allclose(polar_pixels[:, 27:47], np.repeat(np.flip(np.arange(10, 30)), 2).reshape((-1, 1)), atol=1) - assert np.allclose(polar_pixels[:, 53:97], np.repeat(np.flip(np.arange(30, 70)), 1).reshape((-1, 1)), atol=1) + polar_pixels = cv2.warpPolar( + pixels, (100, 40), (99.5, 99.5), 100, cv2.WARP_POLAR_LINEAR + ) + + assert np.allclose( + polar_pixels[:, 3:24], + np.repeat(np.flip(np.arange(0, 10)), 4).reshape((-1, 1)), + atol=1, + ) + assert np.allclose( + polar_pixels[:, 27:47], + np.repeat(np.flip(np.arange(10, 30)), 2).reshape((-1, 1)), + atol=1, + ) + assert np.allclose( + polar_pixels[:, 53:97], + np.repeat(np.flip(np.arange(30, 70)), 1).reshape((-1, 1)), + atol=1, + ) diff --git a/tests/test_utilities.py b/tests/test_utilities.py index becf964..f6766cc 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,13 +1,22 @@ -import numpy as np -from ..openwfs.utilities import set_pixel_size, get_pixel_size, place, Transform, project import astropy.units as u +import numpy as np + +from ..openwfs.utilities import ( + set_pixel_size, + get_pixel_size, + place, + Transform, + project, +) def test_to_matrix(): # Create a transform object - transform = Transform(transform=((1, 2), (3, 4)), - source_origin=(0.0, 0.0) * u.m, - destination_origin=(0.001, 0.002) * u.mm) + transform = Transform( + transform=((1, 2), (3, 4)), + source_origin=(0.0, 0.0) * u.m, + destination_origin=(0.001, 0.002) * u.mm, + ) # Define the expected output matrix for same input and output pixel sizes expected_matrix = ((1, 2, 1), (3, 4, 1)) @@ -26,23 +35,29 @@ def test_to_matrix(): src_center = np.array((0.5 * (src[1] - 1), 0.5 * (src[0] - 1), 1.0)) dst_center = np.array((0.5 * (dst[1] - 1), 0.5 * (dst[0] - 1))) transform = Transform() - result_matrix = transform.cv2_matrix(source_shape=src, - source_pixel_size=(1, 1), destination_shape=dst, - destination_pixel_size=(1, 1)) + result_matrix = transform.cv2_matrix( + source_shape=src, + source_pixel_size=(1, 1), + destination_shape=dst, + destination_pixel_size=(1, 1), + ) assert np.allclose(result_matrix @ src_center, dst_center) # Test center correction. The center of the source image should be mapped to the center of the destination image transform = Transform() # transform=((1, 2), (3, 4))) - result_matrix = transform.cv2_matrix(source_shape=src, - source_pixel_size=(0.5, 4) * u.um, destination_shape=dst, - destination_pixel_size=(1, 2) * u.um) + result_matrix = transform.cv2_matrix( + source_shape=src, + source_pixel_size=(0.5, 4) * u.um, + destination_shape=dst, + destination_pixel_size=(1, 2) * u.um, + ) assert np.allclose(result_matrix @ src_center, dst_center) # Also check openGL matrix (has y-axis flipped and extra row and column) expected_matrix = ((1, 2, 1), (3, 4, 2)) - transform = Transform(transform=((1, 2), (3, 4)), - source_origin=(0, 0), - destination_origin=(1, 2)) + transform = Transform( + transform=((1, 2), (3, 4)), source_origin=(0, 0), destination_origin=(1, 2) + ) result_matrix = transform.to_matrix((1, 1), (1, 1)) assert np.allclose(result_matrix, expected_matrix) @@ -96,16 +111,28 @@ def test_transform(): assert np.allclose(matrix, ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0))) # shift both origins by same distance - t0 = Transform(source_origin=-ps1 * (1.7, 2.2), destination_origin=-ps1 * (1.7, 2.2)) - dst0 = project(src, source_extent=ps1 * np.array(src.shape), transform=t0, out_extent=ps1 * np.array(src.shape), - out_shape=src.shape) + t0 = Transform( + source_origin=-ps1 * (1.7, 2.2), destination_origin=-ps1 * (1.7, 2.2) + ) + dst0 = project( + src, + source_extent=ps1 * np.array(src.shape), + transform=t0, + out_extent=ps1 * np.array(src.shape), + out_shape=src.shape, + ) assert np.allclose(dst0, src) # shift source by (1,2) pixel t1 = Transform(source_origin=-ps1 * (1, 2)) dst1a = place(src.shape, ps1, src, offset=ps1 * (1, 2)) - dst1b = project(src, source_extent=ps1 * np.array(src.shape), transform=t1, out_extent=ps1 * np.array(src.shape), - out_shape=src.shape) + dst1b = project( + src, + source_extent=ps1 * np.array(src.shape), + transform=t1, + out_extent=ps1 * np.array(src.shape), + out_shape=src.shape, + ) assert np.allclose(dst1a, dst1b) @@ -121,7 +148,8 @@ def test_inverse(): transform = Transform( transform=((0.1, 0.2), (-0.25, 0.33)), source_origin=(0.12, 0.15), - destination_origin=(0.23, 0.33)) + destination_origin=(0.23, 0.33), + ) vector = (0.3, 0.4) result = transform.apply(vector) diff --git a/tests/test_wfs.py b/tests/test_wfs.py index 5e422d7..f134287 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -5,8 +5,7 @@ from scipy.linalg import hadamard from scipy.ndimage import zoom -from ..openwfs.algorithms import StepwiseSequential, FourierDualReference, \ - DualReference +from ..openwfs.algorithms import StepwiseSequential, FourierDualReference, DualReference from ..openwfs.algorithms.troubleshoot import field_correlation from ..openwfs.algorithms.utilities import WFSController from ..openwfs.processors import SingleRoi @@ -16,7 +15,7 @@ @pytest.mark.parametrize("shape", [(4, 7), (10, 7), (20, 31)]) @pytest.mark.parametrize("noise", [0.0, 0.1]) -@pytest.mark.parametrize("algorithm", ['ssa', 'fourier']) +@pytest.mark.parametrize("algorithm", ["ssa", "fourier"]) def test_multi_target_algorithms(shape, noise: float, algorithm: str): """ Test the multi-target capable algorithms (SSA and Fourier dual ref). @@ -36,15 +35,28 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str): I_0 = np.mean(sim.read()) feedback = GaussianNoise(sim, std=I_0 * noise) - if algorithm == 'ssa': - alg = StepwiseSequential(feedback=feedback, slm=sim.slm, n_x=shape[1], n_y=shape[0], phase_steps=phase_steps) + if algorithm == "ssa": + alg = StepwiseSequential( + feedback=feedback, + slm=sim.slm, + n_x=shape[1], + n_y=shape[0], + phase_steps=phase_steps, + ) N = np.prod(shape) # number of input modes alg_fidelity = (N - 1) / N # SSA is inaccurate if N is low - signal = (N - 1) / N ** 2 # for estimating SNR + signal = (N - 1) / N**2 # for estimating SNR else: # 'fourier': - alg = FourierDualReference(feedback=feedback, slm=sim.slm, slm_shape=shape, k_radius=(np.min(shape) - 1) // 2, - phase_steps=phase_steps) - N = alg.phase_patterns[0].shape[2] + alg.phase_patterns[1].shape[2] # number of input modes + alg = FourierDualReference( + feedback=feedback, + slm=sim.slm, + slm_shape=shape, + k_radius=(np.min(shape) - 1) // 2, + phase_steps=phase_steps, + ) + N = ( + alg.phase_patterns[0].shape[2] + alg.phase_patterns[1].shape[2] + ) # number of input modes alg_fidelity = 1.0 # Fourier is accurate for any N signal = 1 / 2 # for estimating SNR. @@ -65,7 +77,10 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str): sim.slm.set_phases(-np.angle(result.t[:, :, b])) I_opt[b] = feedback.read()[b] t_correlation += abs(np.vdot(result.t[:, :, b], sim.t[:, :, b])) ** 2 - t_norm += abs(np.vdot(result.t[:, :, b], result.t[:, :, b]) * np.vdot(sim.t[:, :, b], sim.t[:, :, b])) + t_norm += abs( + np.vdot(result.t[:, :, b], result.t[:, :, b]) + * np.vdot(sim.t[:, :, b], sim.t[:, :, b]) + ) t_correlation /= t_norm # a correlation of 1 means optimal reconstruction of the N modulated modes, which may be less than the total number of inputs in the transmission matrix @@ -73,37 +88,55 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str): # Check the enhancement, noise fidelity and # the fidelity of the transmission matrix reconstruction - theoretical_noise_fidelity = signal / (signal + noise ** 2 / phase_steps) + theoretical_noise_fidelity = signal / (signal + noise**2 / phase_steps) enhancement = I_opt.mean() / I_0 - theoretical_enhancement = np.pi / 4 * theoretical_noise_fidelity * alg_fidelity * (N - 1) + 1 + theoretical_enhancement = ( + np.pi / 4 * theoretical_noise_fidelity * alg_fidelity * (N - 1) + 1 + ) estimated_enhancement = result.estimated_enhancement.mean() * alg_fidelity theoretical_t_correlation = theoretical_noise_fidelity * alg_fidelity - estimated_t_correlation = result.fidelity_noise * result.fidelity_calibration * alg_fidelity + estimated_t_correlation = ( + result.fidelity_noise * result.fidelity_calibration * alg_fidelity + ) tolerance = 2.0 / np.sqrt(M) print( - f"\nenhancement: \ttheoretical= {theoretical_enhancement},\testimated={estimated_enhancement},\tactual: {enhancement}") + f"\nenhancement: \ttheoretical= {theoretical_enhancement},\testimated={estimated_enhancement},\tactual: {enhancement}" + ) print( - f"t-matrix fidelity:\ttheoretical = {theoretical_t_correlation},\testimated = {estimated_t_correlation},\tactual = {t_correlation}") - print(f"noise fidelity: \ttheoretical = {theoretical_noise_fidelity},\testimated = {result.fidelity_noise}") + f"t-matrix fidelity:\ttheoretical = {theoretical_t_correlation},\testimated = {estimated_t_correlation},\tactual = {t_correlation}" + ) + print( + f"noise fidelity: \ttheoretical = {theoretical_noise_fidelity},\testimated = {result.fidelity_noise}" + ) print(f"comparing at relative tolerance: {tolerance}") - assert np.allclose(enhancement, theoretical_enhancement, rtol=tolerance), f""" + assert np.allclose( + enhancement, theoretical_enhancement, rtol=tolerance + ), f""" The SSA algorithm did not enhance the focus as much as expected. Theoretical {theoretical_enhancement}, got {enhancement}""" - assert np.allclose(estimated_enhancement, enhancement, rtol=tolerance), f""" + assert np.allclose( + estimated_enhancement, enhancement, rtol=tolerance + ), f""" The SSA algorithm did not estimate the enhancement correctly. Estimated {estimated_enhancement}, got {enhancement}""" - assert np.allclose(t_correlation, theoretical_t_correlation, rtol=tolerance), f""" + assert np.allclose( + t_correlation, theoretical_t_correlation, rtol=tolerance + ), f""" The SSA algorithm did not measure the transmission matrix correctly. Expected {theoretical_t_correlation}, got {t_correlation}""" - assert np.allclose(estimated_t_correlation, theoretical_t_correlation, rtol=tolerance), f""" + assert np.allclose( + estimated_t_correlation, theoretical_t_correlation, rtol=tolerance + ), f""" The SSA algorithm did not estimate the fidelity of the transmission matrix correctly. Expected {theoretical_t_correlation}, got {estimated_t_correlation}""" - assert np.allclose(result.fidelity_noise, theoretical_noise_fidelity, rtol=tolerance), f""" + assert np.allclose( + result.fidelity_noise, theoretical_noise_fidelity, rtol=tolerance + ), f""" The SSA algorithm did not estimate the noise correctly. Expected {theoretical_noise_fidelity}, got {result.fidelity_noise}""" @@ -121,7 +154,9 @@ def test_fourier2(): slm_shape = (1000, 1000) aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=7.5, phase_steps=3) + alg = FourierDualReference( + feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=7.5, phase_steps=3 + ) controller = WFSController(alg) controller.wavefront = WFSController.State.SHAPED_WAVEFRONT scaled_aberration = zoom(aberrations, np.array(slm_shape) / aberrations.shape) @@ -131,7 +166,9 @@ def test_fourier2(): @pytest.mark.skip("Not implemented") def test_fourier_microscope(): aberration_phase = skimage.data.camera() * ((2 * np.pi) / 255.0) + np.pi - aberration = StaticSource(aberration_phase, pixel_size=2.0 / np.array(aberration_phase.shape)) + aberration = StaticSource( + aberration_phase, pixel_size=2.0 / np.array(aberration_phase.shape) + ) img = np.zeros((1000, 1000), dtype=np.int16) signal_location = (250, 250) img[signal_location] = 100 @@ -139,13 +176,19 @@ def test_fourier_microscope(): src = StaticSource(img, 400 * u.nm) slm = SLM(shape=(1000, 1000)) - sim = Microscope(source=src, incident_field=slm.field, magnification=1, numerical_aperture=1, - aberrations=aberration, - wavelength=800 * u.nm) + sim = Microscope( + source=src, + incident_field=slm.field, + magnification=1, + numerical_aperture=1, + aberrations=aberration, + wavelength=800 * u.nm, + ) cam = sim.get_camera(analog_max=100) roi_detector = SingleRoi(cam, pos=(250, 250)) # Only measure that specific point - alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_radius=1.5, - phase_steps=3) + alg = FourierDualReference( + feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_radius=1.5, phase_steps=3 + ) controller = WFSController(alg) controller.wavefront = WFSController.State.FLAT_WAVEFRONT before = roi_detector.read() @@ -153,8 +196,12 @@ def test_fourier_microscope(): after = roi_detector.read() # imshow(controller._optimized_wavefront) print(after / before) - scaled_aberration = zoom(aberration_phase, np.array(slm_shape) / aberration_phase.shape) - assert_enhancement(slm, roi_detector, controller._result, np.exp(1j * scaled_aberration)) + scaled_aberration = zoom( + aberration_phase, np.array(slm_shape) / aberration_phase.shape + ) + assert_enhancement( + slm, roi_detector, controller._result, np.exp(1j * scaled_aberration) + ) def test_fourier_correction_field(): @@ -163,12 +210,19 @@ def test_fourier_correction_field(): """ aberrations = skimage.data.camera() * (2.0 * np.pi / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=3.0, - phase_steps=3) + alg = FourierDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=np.shape(aberrations), + k_radius=3.0, + phase_steps=3, + ) t = alg.execute().t t_correct = np.exp(1j * aberrations) - correlation = np.vdot(t, t_correct) / np.sqrt(np.vdot(t, t) * np.vdot(t_correct, t_correct)) + correlation = np.vdot(t, t_correct) / np.sqrt( + np.vdot(t, t) * np.vdot(t_correct, t_correct) + ) # TODO: integrate with other test cases, duplication assert abs(correlation) > 0.75 @@ -182,8 +236,13 @@ def test_phase_shift_correction(): """ aberrations = skimage.data.camera() * (2.0 * np.pi / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=1.5, - phase_steps=3) + alg = FourierDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=np.shape(aberrations), + k_radius=1.5, + phase_steps=3, + ) t = alg.execute().t # compute the phase pattern to optimize the intensity in target 0 @@ -200,7 +259,9 @@ def test_phase_shift_correction(): signal = sim.read() signals.append(signal) - assert np.std(signals) / np.mean(signals) < 0.001, f"""The response of SimulatedWFS is sensitive to a flat + assert ( + np.std(signals) / np.mean(signals) < 0.001 + ), f"""The response of SimulatedWFS is sensitive to a flat phase shift. This is incorrect behaviour""" @@ -219,15 +280,23 @@ def test_flat_wf_response_fourier(optimized_reference, step): aberrations[:, 2:] = 2.0 sim = SimulatedWFS(aberrations=aberrations.reshape((*aberrations.shape, 1))) - alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=np.shape(aberrations), k_radius=1.5, phase_steps=3, - optimized_reference=optimized_reference) + alg = FourierDualReference( + feedback=sim, + slm=sim.slm, + slm_shape=np.shape(aberrations), + k_radius=1.5, + phase_steps=3, + optimized_reference=optimized_reference, + ) t = alg.execute().t # test the optimized wavefront by checking if it has irregularities. measured_aberrations = np.squeeze(np.angle(t)) measured_aberrations += aberrations[0, 0] - measured_aberrations[0, 0] - assert np.allclose(measured_aberrations, aberrations, atol=0.02) # The measured wavefront is not flat. + assert np.allclose( + measured_aberrations, aberrations, atol=0.02 + ) # The measured wavefront is not flat. def test_flat_wf_response_ssa(): @@ -245,7 +314,9 @@ def test_flat_wf_response_ssa(): # Assert that the standard deviation of the optimized wavefront is below the threshold, # indicating that it is effectively flat - assert np.std(optimised_wf) < 0.001, f"Response flat wavefront not flat, std: {np.std(optimised_wf)}" + assert ( + np.std(optimised_wf) < 0.001 + ), f"Response flat wavefront not flat, std: {np.std(optimised_wf)}" def test_multidimensional_feedback_ssa(): @@ -267,7 +338,9 @@ def test_multidimensional_feedback_ssa(): after = sim.read() enhancement = after / before - assert enhancement[target] >= 3.0, f"""The SSA algorithm did not enhance focus as much as expected. + assert ( + enhancement[target] >= 3.0 + ), f"""The SSA algorithm did not enhance focus as much as expected. Expected at least 3.0, got {enhancement}""" @@ -290,11 +363,13 @@ def test_multidimensional_feedback_fourier(): after = sim.read() enhancement = after / before - assert enhancement[2, 1] >= 3.0, f"""The algorithm did not enhance the focus as much as expected. + assert ( + enhancement[2, 1] >= 3.0 + ), f"""The algorithm did not enhance the focus as much as expected. Expected at least 3.0, got {enhancement}""" -@pytest.mark.parametrize("type", ('plane_wave', 'hadamard')) +@pytest.mark.parametrize("type", ("plane_wave", "hadamard")) @pytest.mark.parametrize("shape", ((8, 8), (16, 4))) def test_custom_blind_dual_reference_ortho_split(type: str, shape): """Test custom blind dual reference with an orthonormal phase-only basis. @@ -302,24 +377,28 @@ def test_custom_blind_dual_reference_ortho_split(type: str, shape): do_debug = False N = shape[0] * (shape[1] // 2) modes_shape = (shape[0], shape[1] // 2, N) - if type == 'plane_wave': + if type == "plane_wave": # Create a full plane wave basis for one half of the SLM. modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) else: # type == 'hadamard': modes = hadamard(N).reshape(modes_shape) - mask = np.concatenate((np.zeros(modes_shape[0:2], dtype=bool), np.ones(modes_shape[0:2], dtype=bool)), axis=1) + mask = np.concatenate( + (np.zeros(modes_shape[0:2], dtype=bool), np.ones(modes_shape[0:2], dtype=bool)), + axis=1, + ) mode_set = np.concatenate((modes, np.zeros(shape=modes_shape)), axis=1) phases_set = np.angle(mode_set) if do_debug: # Plot the modes import matplotlib.pyplot as plt + plt.figure(figsize=(12, 7)) for m in range(N): plt.subplot(*modes_shape[0:1], m + 1) plt.imshow(np.angle(mode_set[:, :, m]), vmin=-np.pi, vmax=np.pi) - plt.title(f'm={m}') + plt.title(f"m={m}") plt.xticks([]) plt.yticks([]) plt.pause(0.1) @@ -327,24 +406,30 @@ def test_custom_blind_dual_reference_ortho_split(type: str, shape): # Create aberrations sim = SimulatedWFS(t=random_transmission_matrix(shape)) - alg = DualReference(feedback=sim, slm=sim.slm, - phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, - iterations=4) + alg = DualReference( + feedback=sim, + slm=sim.slm, + phase_patterns=(phases_set, np.flip(phases_set, axis=1)), + group_mask=mask, + iterations=4, + ) result = alg.execute() if do_debug: plt.figure() - plt.imshow(np.angle(sim.t), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('Aberrations') + plt.imshow(np.angle(sim.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") + plt.title("Aberrations") plt.figure() - plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('t') + plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") + plt.title("t") plt.colorbar() plt.show() - assert np.abs(field_correlation(sim.t, result.t)) > 0.99 # todo: find out why this is not higher + assert ( + np.abs(field_correlation(sim.t, result.t)) > 0.99 + ) # todo: find out why this is not higher def test_custom_blind_dual_reference_non_ortho(): @@ -357,7 +442,9 @@ def test_custom_blind_dual_reference_non_ortho(): N1 = 6 N2 = 3 M = N1 * N2 - mode_set_half = (1 / M) * (1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M))) + mode_set_half = (1 / M) * ( + 1j * np.eye(M).reshape((N1, N2, M)) * -np.ones(shape=(N1, N2, M)) + ) mode_set = np.concatenate((mode_set_half, np.zeros(shape=(N1, N2, M))), axis=1) phases_set = np.angle(mode_set) mask = np.concatenate((np.zeros((N1, N2)), np.ones((N1, N2))), axis=1) @@ -365,41 +452,52 @@ def test_custom_blind_dual_reference_non_ortho(): if do_debug: # Plot the modes import matplotlib.pyplot as plt + plt.figure(figsize=(12, 7)) for m in range(M): plt.subplot(N2, N1, m + 1) plt.imshow(phases_set[:, :, m], vmin=-np.pi, vmax=np.pi) - plt.title(f'm={m}') + plt.title(f"m={m}") plt.xticks([]) plt.yticks([]) plt.pause(0.01) - plt.suptitle('Phase of basis functions for one half') + plt.suptitle("Phase of basis functions for one half") # Create aberrations x = np.linspace(-1, 1, 1 * N1).reshape((1, -1)) y = np.linspace(-1, 1, 1 * N1).reshape((-1, 1)) - aberrations = (np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)) % (2 * np.pi) + aberrations = ( + np.sin(0.8 * np.pi * x) + * np.cos(1.3 * np.pi * y) + * (0.8 * np.pi + 0.4 * x + 0.4 * y) + ) % (2 * np.pi) aberrations[0:1, :] = 0 aberrations[:, 0:2] = 0 sim = SimulatedWFS(aberrations=aberrations) - alg = DualReference(feedback=sim, slm=sim.slm, - phase_patterns=(phases_set, np.flip(phases_set, axis=1)), group_mask=mask, - phase_steps=4, - iterations=4) + alg = DualReference( + feedback=sim, + slm=sim.slm, + phase_patterns=(phases_set, np.flip(phases_set, axis=1)), + group_mask=mask, + phase_steps=4, + iterations=4, + ) result = alg.execute() if do_debug: plt.figure() - plt.imshow(np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('Aberrations') + plt.imshow( + np.angle(np.exp(1j * aberrations)), vmin=-np.pi, vmax=np.pi, cmap="hsv" + ) + plt.title("Aberrations") plt.colorbar() plt.figure() - plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap='hsv') - plt.title('t') + plt.imshow(np.angle(result.t), vmin=-np.pi, vmax=np.pi, cmap="hsv") + plt.title("t") plt.colorbar() plt.show() From efd9b7e00c65e5b8930a7864ecb84aa068b03889 Mon Sep 17 00:00:00 2001 From: Ivo Vellekoop Date: Tue, 1 Oct 2024 15:39:21 +0200 Subject: [PATCH 15/15] added black dependency and instructions --- STYLEGUIDE.md | 4 +++- pyproject.toml | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/STYLEGUIDE.md b/STYLEGUIDE.md index 0143572..ed33cef 100644 --- a/STYLEGUIDE.md +++ b/STYLEGUIDE.md @@ -6,7 +6,9 @@ # General -- PyCharm autoformatting should be enabled to ensure correct formatting. +- all .py files MUST be formatted with the 'black' autoformatter. This can be done by installing the 'black' package, + and running `black .` in the root directory. black is automatically installed when the development dependencies are + included. # Tests diff --git a/pyproject.toml b/pyproject.toml index 4e78432..094e472 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,7 @@ optional = true scikit-image = ">=0.21.0" pytest = "~7.0.0" nidaq = "nidaqmx >=0.8.0" # we can test without the hardware, but still need the package +black = ">=24.0.0" # code formatter [tool.poetry.group.docs] optional = true