From 6187e5d82532f18688ec88300a9cc54d84109778 Mon Sep 17 00:00:00 2001 From: Jeroen Doornbos Date: Thu, 3 Oct 2024 15:09:36 +0200 Subject: [PATCH] ran black with line-length 120 --- docs/source/conf.py | 20 +-- examples/sample_microscope.py | 8 +- examples/slm_demo.py | 4 +- examples/troubleshooter_demo.py | 4 +- examples/wfs_demonstration_experimental.py | 8 +- openwfs/algorithms/basic_fourier.py | 4 +- .../algorithms/custom_iter_dual_reference.py | 76 ++++++---- openwfs/algorithms/dual_reference.py | 58 +++----- openwfs/algorithms/genetic.py | 4 +- openwfs/algorithms/ssa.py | 4 +- openwfs/algorithms/troubleshoot.py | 96 ++++--------- openwfs/algorithms/utilities.py | 20 +-- openwfs/core.py | 62 ++------- openwfs/devices/camera.py | 4 +- openwfs/devices/galvo_scanner.py | 131 +++++------------- openwfs/devices/nidaq_gain.py | 4 +- openwfs/devices/slm/geometry.py | 12 +- openwfs/devices/slm/patch.py | 32 ++--- openwfs/devices/slm/slm.py | 92 +++--------- openwfs/devices/slm/texture.py | 12 +- openwfs/plot_utilities.py | 64 ++++++--- openwfs/processors/processors.py | 38 ++--- openwfs/simulation/microscope.py | 26 +--- openwfs/simulation/mockdevices.py | 14 +- openwfs/simulation/slm.py | 20 +-- openwfs/simulation/transmission.py | 7 +- openwfs/utilities/patterns.py | 4 +- openwfs/utilities/utilities.py | 60 ++------ tests/test_algorithms_troubleshoot.py | 76 +++------- tests/test_camera.py | 12 +- tests/test_core.py | 16 +-- tests/test_processors.py | 7 +- tests/test_scanning_microscope.py | 36 ++--- tests/test_simulation.py | 58 +++----- tests/test_slm.py | 8 +- tests/test_utilities.py | 8 +- tests/test_wfs.py | 85 ++++-------- 37 files changed, 359 insertions(+), 835 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 060e523..9c927d7 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -34,13 +34,9 @@ # basic project information project = "OpenWFS" copyright = "2023-, Ivo Vellekoop, Daniël W. S. Cox, and Jeroen H. Doornbos, University of Twente" -author = ( - "Jeroen H. Doornbos, Daniël W. S. Cox, Tom Knop, Harish Sasikumar, Ivo M. Vellekoop" -) +author = "Jeroen H. Doornbos, Daniël W. S. Cox, Tom Knop, Harish Sasikumar, Ivo M. Vellekoop" release = "0.1.0rc2" -html_title = ( - "OpenWFS - a library for conducting and simulating wavefront shaping experiments" -) +html_title = "OpenWFS - a library for conducting and simulating wavefront shaping experiments" # \renewenvironment{sphinxtheindex}{\setbox0\vbox\bgroup\begin{theindex}}{\end{theindex}} # latex configuration @@ -167,23 +163,17 @@ def setup(app): def source_read(app, docname, source): if docname == "readme" or docname == "conclusion": if (app.builder.name == "latex") == (docname == "conclusion"): - source[0] = source[0].replace( - "%endmatter%", ".. include:: acknowledgements.rst" - ) + source[0] = source[0].replace("%endmatter%", ".. include:: acknowledgements.rst") else: source[0] = source[0].replace("%endmatter%", "") def builder_inited(app): if app.builder.name == "html": - exclude_patterns.extend( - ["conclusion.rst", "index_latex.rst", "index_markdown.rst"] - ) + exclude_patterns.extend(["conclusion.rst", "index_latex.rst", "index_markdown.rst"]) app.config.master_doc = "index" elif app.builder.name == "latex": - exclude_patterns.extend( - ["auto_examples/*", "index_markdown.rst", "index.rst", "api*"] - ) + exclude_patterns.extend(["auto_examples/*", "index_markdown.rst", "index.rst", "api*"]) app.config.master_doc = "index_latex" elif app.builder.name == "markdown": include_patterns.clear() diff --git a/examples/sample_microscope.py b/examples/sample_microscope.py index 9259594..a2daba7 100644 --- a/examples/sample_microscope.py +++ b/examples/sample_microscope.py @@ -41,9 +41,7 @@ # Code img = set_pixel_size( - np.maximum( - np.random.randint(-10000, 100, (img_size_y, img_size_x), dtype=np.int16), 0 - ), + np.maximum(np.random.randint(-10000, 100, (img_size_y, img_size_x), dtype=np.int16), 0), 60 * u.nm, ) src = StaticSource(img) @@ -76,7 +74,5 @@ mic.xy_stage.x = p * 1 * u.um mic.numerical_aperture = 1.0 * (p + 1) / p_limit # NA increases to 1.0 ax = grab_and_show(cam, ax) - plt.title( - f"NA: {mic.numerical_aperture}, δ: {mic.abbe_limit.to_value(u.um):2.2} μm" - ) + plt.title(f"NA: {mic.numerical_aperture}, δ: {mic.abbe_limit.to_value(u.um):2.2} μm") plt.pause(0.2) diff --git a/examples/slm_demo.py b/examples/slm_demo.py index aba50bf..bbb64c9 100644 --- a/examples/slm_demo.py +++ b/examples/slm_demo.py @@ -34,9 +34,7 @@ p4.phases = 1 p4.additive_blend = False -pf.phases = patterns.lens( - 100, f=1 * u.m, wavelength=0.8 * u.um, extent=(10 * u.mm, 10 * u.mm) -) +pf.phases = patterns.lens(100, f=1 * u.m, wavelength=0.8 * u.um, extent=(10 * u.mm, 10 * u.mm)) rng = np.random.default_rng() for n in range(200): random_data = rng.random([10, 10], np.float32) * 2.0 * np.pi diff --git a/examples/troubleshooter_demo.py b/examples/troubleshooter_demo.py index abf08d6..f836ccf 100644 --- a/examples/troubleshooter_demo.py +++ b/examples/troubleshooter_demo.py @@ -56,7 +56,5 @@ roi_background = SingleRoi(cam, radius=10) # Run WFS troubleshooter and output a report to the console -trouble = troubleshoot( - algorithm=alg, background_feedback=roi_background, frame_source=cam, shutter=shutter -) +trouble = troubleshoot(algorithm=alg, background_feedback=roi_background, frame_source=cam, shutter=shutter) trouble.report() diff --git a/examples/wfs_demonstration_experimental.py b/examples/wfs_demonstration_experimental.py index b2a16f2..44eb784 100644 --- a/examples/wfs_demonstration_experimental.py +++ b/examples/wfs_demonstration_experimental.py @@ -19,16 +19,12 @@ # constructs the actual slm for wavefront shaping, and a monitor window to display the current phase pattern slm = SLM(monitor_id=2, duration=2) -monitor = slm.clone( - monitor_id=0, pos=(0, 0), shape=(slm.shape[0] // 4, slm.shape[1] // 4) -) +monitor = slm.clone(monitor_id=0, pos=(0, 0), shape=(slm.shape[0] // 4, slm.shape[1] // 4)) # we are using a setup with an SLM that produces 2pi phase shift # at a gray value of 142 slm.lookup_table = range(142) -alg = FourierDualReference( - feedback=roi_detector, slm=slm, slm_shape=[800, 800], k_radius=7 -) +alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=[800, 800], k_radius=7) result = alg.execute() print(result) diff --git a/openwfs/algorithms/basic_fourier.py b/openwfs/algorithms/basic_fourier.py index 62bd394..ec1d8da 100644 --- a/openwfs/algorithms/basic_fourier.py +++ b/openwfs/algorithms/basic_fourier.py @@ -52,14 +52,14 @@ def __init__( self.k_step = k_step self._slm_shape = slm_shape group_mask = np.zeros(slm_shape, dtype=bool) - group_mask[:, slm_shape[1] // 2:] = True + group_mask[:, slm_shape[1] // 2 :] = True super().__init__( feedback=feedback, slm=slm, phase_patterns=None, group_mask=group_mask, phase_steps=phase_steps, - amplitude='uniform', + amplitude="uniform", iterations=iterations, optimized_reference=optimized_reference, analyzer=analyzer, diff --git a/openwfs/algorithms/custom_iter_dual_reference.py b/openwfs/algorithms/custom_iter_dual_reference.py index 73b083b..4d6ce4c 100644 --- a/openwfs/algorithms/custom_iter_dual_reference.py +++ b/openwfs/algorithms/custom_iter_dual_reference.py @@ -45,8 +45,16 @@ class IterativeDualReference: https://opg.optica.org/oe/ abstract.cfm?uri=oe-27-8-1167 """ - def __init__(self, feedback: Detector, slm: PhaseSLM, phase_patterns: tuple[nd, nd], group_mask: nd, - phase_steps: int = 4, iterations: int = 4, analyzer: Optional[callable] = analyze_phase_stepping): + def __init__( + self, + feedback: Detector, + slm: PhaseSLM, + phase_patterns: tuple[nd, nd], + group_mask: nd, + phase_steps: int = 4, + iterations: int = 4, + analyzer: Optional[callable] = analyze_phase_stepping, + ): """ Args: feedback: The feedback source, usually a detector that provides measurement data. @@ -79,8 +87,9 @@ def __init__(self, feedback: Detector, slm: PhaseSLM, phase_patterns: tuple[nd, self.masks = (~mask, mask) # masks[0] is True for group A, mask[1] is True for group B # Pre-compute the conjugate modes for reconstruction - self.modes = [np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in - range(2)] + self.modes = [ + np.exp(-1j * self.phase_patterns[side]) * np.expand_dims(self.masks[side], axis=2) for side in range(2) + ] def execute(self, capture_intermediate_results: bool = False, progress_bar=None) -> WFSResult: """ @@ -109,8 +118,10 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) # Prepare progress bar if progress_bar: - num_measurements = np.ceil(self.iterations / 2) * self.modes[0].shape[2] \ - + np.floor(self.iterations / 2) * self.modes[1].shape[2] + num_measurements = ( + np.ceil(self.iterations / 2) * self.modes[0].shape[2] + + np.floor(self.iterations / 2) * self.modes[1].shape[2] + ) progress_bar.total = num_measurements # Switch the phase sets back and forth multiple times @@ -119,8 +130,12 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) ref_phases = -np.angle(t_full) # use the best estimate so far to construct an optimized reference side_mask = self.masks[side] # Perform WFS experiment on one side, keeping the other side sized at the ref_phases - result = self._single_side_experiment(mod_phases=self.phase_patterns[side], ref_phases=ref_phases, - mod_mask=side_mask, progress_bar=progress_bar) + result = self._single_side_experiment( + mod_phases=self.phase_patterns[side], + ref_phases=ref_phases, + mod_mask=side_mask, + progress_bar=progress_bar, + ) # Compute transmission matrix for the current side and update # estimated transmission matrix @@ -139,23 +154,31 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) intermediate_results[it] = self.feedback.read() # Compute average fidelity factors - fidelity_noise = weighted_average(results_latest[0].fidelity_noise, - results_latest[1].fidelity_noise, results_latest[0].n, - results_latest[1].n) - fidelity_amplitude = weighted_average(results_latest[0].fidelity_amplitude, - results_latest[1].fidelity_amplitude, results_latest[0].n, - results_latest[1].n) - fidelity_calibration = weighted_average(results_latest[0].fidelity_calibration, - results_latest[1].fidelity_calibration, results_latest[0].n, - results_latest[1].n) - - result = WFSResult(t=t_full, - t_f=None, - n=self.modes[0].shape[2] + self.modes[1].shape[2], - axis=2, - fidelity_noise=fidelity_noise, - fidelity_amplitude=fidelity_amplitude, - fidelity_calibration=fidelity_calibration) + fidelity_noise = weighted_average( + results_latest[0].fidelity_noise, results_latest[1].fidelity_noise, results_latest[0].n, results_latest[1].n + ) + fidelity_amplitude = weighted_average( + results_latest[0].fidelity_amplitude, + results_latest[1].fidelity_amplitude, + results_latest[0].n, + results_latest[1].n, + ) + fidelity_calibration = weighted_average( + results_latest[0].fidelity_calibration, + results_latest[1].fidelity_calibration, + results_latest[0].n, + results_latest[1].n, + ) + + result = WFSResult( + t=t_full, + t_f=None, + n=self.modes[0].shape[2] + self.modes[1].shape[2], + axis=2, + fidelity_noise=fidelity_noise, + fidelity_amplitude=fidelity_amplitude, + fidelity_calibration=fidelity_calibration, + ) # TODO: document the t_set_all and results_all attributes result.t_set_all = t_set_all @@ -163,8 +186,7 @@ def execute(self, capture_intermediate_results: bool = False, progress_bar=None) result.intermediate_results = intermediate_results return result - def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, - progress_bar=None) -> WFSResult: + def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, progress_bar=None) -> WFSResult: """ Conducts experiments on one part of the SLM. diff --git a/openwfs/algorithms/dual_reference.py b/openwfs/algorithms/dual_reference.py index 3fd0283..03bda5b 100644 --- a/openwfs/algorithms/dual_reference.py +++ b/openwfs/algorithms/dual_reference.py @@ -89,9 +89,7 @@ def __init__( if iterations < 2: raise ValueError("The number of iterations must be at least 2.") if not optimized_reference and iterations != 2: - raise ValueError( - "When not using an optimized reference, the number of iterations must be 2." - ) + raise ValueError("When not using an optimized reference, the number of iterations must be 2.") self.slm = slm self.feedback = feedback @@ -108,7 +106,7 @@ def __init__( ~mask, mask, ) # self.masks[0] is True for group A, self.masks[1] is True for group B - self.amplitude = amplitude # Note: when 'uniform' is passed, the shape of self.masks[0] is used. + self.amplitude = amplitude # Note: when 'uniform' is passed, the shape of self.masks[0] is used. self.phase_patterns = phase_patterns @property @@ -121,19 +119,17 @@ def amplitude(self, value): self._amplitude = None return - if value == 'uniform': + if value == "uniform": self._amplitude = tuple( - (np.ones(shape=self._shape) / np.sqrt(self.masks[side].sum())).astype(np.float32) for side in range(2)) + (np.ones(shape=self._shape) / np.sqrt(self.masks[side].sum())).astype(np.float32) for side in range(2) + ) return if value[0].shape != self._shape or value[1].shape != self._shape: - raise ValueError( - "The amplitude and group mask must all have the same shape." - ) + raise ValueError("The amplitude and group mask must all have the same shape.") self._amplitude = value - @property def phase_patterns(self) -> tuple[nd, nd]: return self._phase_patterns @@ -148,26 +144,14 @@ def phase_patterns(self, value): if not self.optimized_reference: # find the modes in A and B that correspond to flat wavefronts with phase 0 try: - a0_index = next( - i - for i in range(value[0].shape[2]) - if np.allclose(value[0][:, :, i], 0) - ) - b0_index = next( - i - for i in range(value[1].shape[2]) - if np.allclose(value[1][:, :, i], 0) - ) + a0_index = next(i for i in range(value[0].shape[2]) if np.allclose(value[0][:, :, i], 0)) + b0_index = next(i for i in range(value[1].shape[2]) if np.allclose(value[1][:, :, i], 0)) self.zero_indices = (a0_index, b0_index) except StopIteration: - raise ( - "For multi-target optimization, the both sets must contain a flat wavefront with phase 0." - ) + raise ("For multi-target optimization, the both sets must contain a flat wavefront with phase 0.") if (value[0].shape[0:2] != self._shape) or (value[1].shape[0:2] != self._shape): - raise ValueError( - "The phase patterns and group mask must all have the same shape." - ) + raise ValueError("The phase patterns and group mask must all have the same shape.") self._phase_patterns = ( value[0].astype(np.float32), @@ -200,7 +184,7 @@ def _compute_cobasis(self): denotes the matrix inverse, and ⁺ denotes the Moore-Penrose pseudo-inverse. """ if self.phase_patterns is None: - raise('The phase_patterns must be set before computing the cobasis.') + raise ("The phase_patterns must be set before computing the cobasis.") cobasis = [None, None] for side in range(2): @@ -215,9 +199,7 @@ def _compute_cobasis(self): self._cobasis = cobasis - def execute( - self, *, capture_intermediate_results: bool = False, progress_bar=None - ) -> WFSResult: + def execute(self, *, capture_intermediate_results: bool = False, progress_bar=None) -> WFSResult: """ Executes the blind focusing dual reference algorithm and compute the SLM transmission matrix. capture_intermediate_results: When True, measures the feedback from the optimized wavefront after each iteration. @@ -237,9 +219,7 @@ def execute( # Initialize storage lists results_all = [None] * self.iterations # List to store all results - intermediate_results = np.zeros( - self.iterations - ) # List to store feedback from full patterns + intermediate_results = np.zeros(self.iterations) # List to store feedback from full patterns # Prepare progress bar if progress_bar: @@ -284,9 +264,7 @@ def execute( relative = results_all[0].t[self.zero_indices[0], ...] + np.conjugate( results_all[1].t[self.zero_indices[1], ...] ) - factor = (relative / np.abs(relative)).reshape( - (1, *self.feedback.data_shape) - ) + factor = (relative / np.abs(relative)).reshape((1, *self.feedback.data_shape)) t_full = self.compute_t_set(results_all[0].t, self.cobasis[0]) + self.compute_t_set( factor * results_all[1].t, self.cobasis[1] @@ -304,9 +282,7 @@ def execute( result.intermediate_results = intermediate_results return result - def _single_side_experiment( - self, mod_phases: nd, ref_phases: nd, mod_mask: nd, progress_bar=None - ) -> WFSResult: + def _single_side_experiment(self, mod_phases: nd, ref_phases: nd, mod_mask: nd, progress_bar=None) -> WFSResult: """ Conducts experiments on one part of the SLM. @@ -322,9 +298,7 @@ def _single_side_experiment( WFSResult: An object containing the computed SLM transmission matrix and related data. """ num_modes = mod_phases.shape[2] - measurements = np.zeros( - (num_modes, self.phase_steps, *self.feedback.data_shape) - ) + measurements = np.zeros((num_modes, self.phase_steps, *self.feedback.data_shape)) for m in range(num_modes): phases = ref_phases.copy() diff --git a/openwfs/algorithms/genetic.py b/openwfs/algorithms/genetic.py index 519aa54..7b0df1b 100644 --- a/openwfs/algorithms/genetic.py +++ b/openwfs/algorithms/genetic.py @@ -61,9 +61,7 @@ def __init__( self.elite_size = elite_size self.generations = generations self.generator = generator or np.random.default_rng() - self.mutation_count = round( - (population_size - elite_size) * np.prod(shape) * mutation_probability - ) + self.mutation_count = round((population_size - elite_size) * np.prod(shape) * mutation_probability) def _generate_random_phases(self, shape): return self.generator.random(size=shape, dtype=np.float32) * (2 * np.pi) diff --git a/openwfs/algorithms/ssa.py b/openwfs/algorithms/ssa.py index 4e3eeee..48368af 100644 --- a/openwfs/algorithms/ssa.py +++ b/openwfs/algorithms/ssa.py @@ -48,9 +48,7 @@ def execute(self) -> WFSResult: WFSResult: An object containing the computed transmission matrix and statistics. """ phase_pattern = np.zeros((self.n_y, self.n_x), "float32") - measurements = np.zeros( - (self.n_y, self.n_x, self.phase_steps, *self.feedback.data_shape) - ) + measurements = np.zeros((self.n_y, self.n_x, self.phase_steps, *self.feedback.data_shape)) for y in range(self.n_y): for x in range(self.n_x): diff --git a/openwfs/algorithms/troubleshoot.py b/openwfs/algorithms/troubleshoot.py index 4ae3cd4..4a82a15 100644 --- a/openwfs/algorithms/troubleshoot.py +++ b/openwfs/algorithms/troubleshoot.py @@ -47,9 +47,7 @@ def cnr(signal_with_noise: np.ndarray, noise: np.ndarray) -> np.float64: return signal_std(signal_with_noise, noise) / noise.std() -def contrast_enhancement( - signal_with_noise: np.ndarray, reference_with_noise: np.ndarray, noise: np.ndarray -) -> float: +def contrast_enhancement(signal_with_noise: np.ndarray, reference_with_noise: np.ndarray, noise: np.ndarray) -> float: """ Compute noise corrected contrast enhancement. The noise is assumed to be uncorrelated with the signal, such that var(measured) = var(signal) + var(noise). @@ -125,9 +123,7 @@ def frame_correlation(a: np.ndarray, b: np.ndarray) -> float: return np.mean(a * b) / (np.mean(a) * np.mean(b)) - 1 -def pearson_correlation( - a: np.ndarray, b: np.ndarray, noise_var: np.ndarray = 0.0 -) -> float: +def pearson_correlation(a: np.ndarray, b: np.ndarray, noise_var: np.ndarray = 0.0) -> float: """ Compute Pearson correlation. @@ -199,9 +195,7 @@ def plot(self): """ # Comparisons with first frame plt.figure() - plt.plot( - self.timestamps, self.pixel_shifts_first, ".-", label="image-shift (pix)" - ) + plt.plot(self.timestamps, self.pixel_shifts_first, ".-", label="image-shift (pix)") plt.title("Stability - Image shift w.r.t. first frame") plt.ylabel("Image shift (pix)") plt.xlabel("time (s)") @@ -219,17 +213,13 @@ def plot(self): plt.legend() plt.figure() - plt.plot( - self.timestamps, self.contrast_ratios_first, ".-", label="contrast ratio" - ) + plt.plot(self.timestamps, self.contrast_ratios_first, ".-", label="contrast ratio") plt.title("Stability - Contrast ratio with first frame") plt.xlabel("time (s)") # Comparisons with previous frame plt.figure() - plt.plot( - self.timestamps, self.pixel_shifts_prev, ".-", label="image-shift (pix)" - ) + plt.plot(self.timestamps, self.pixel_shifts_prev, ".-", label="image-shift (pix)") plt.title("Stability - Image shift w.r.t. previous frame") plt.ylabel("Image shift (pix)") plt.xlabel("time (s)") @@ -247,9 +237,7 @@ def plot(self): plt.legend() plt.figure() - plt.plot( - self.timestamps, self.contrast_ratios_prev, ".-", label="contrast ratio" - ) + plt.plot(self.timestamps, self.contrast_ratios_prev, ".-", label="contrast ratio") plt.title("Stability - Contrast ratio with previous frame") plt.xlabel("time (s)") @@ -292,22 +280,14 @@ def measure_setup_stability( # Compare with first frame pixel_shifts_first[n, :] = find_pixel_shift(first_frame, new_frame) correlations_first[n] = pearson_correlation(first_frame, new_frame) - correlations_disattenuated_first[n] = pearson_correlation( - first_frame, new_frame, noise_var=dark_var - ) - contrast_ratios_first[n] = contrast_enhancement( - new_frame, first_frame, dark_frame - ) + correlations_disattenuated_first[n] = pearson_correlation(first_frame, new_frame, noise_var=dark_var) + contrast_ratios_first[n] = contrast_enhancement(new_frame, first_frame, dark_frame) # Compare with previous frame pixel_shifts_prev[n, :] = find_pixel_shift(prev_frame, new_frame) correlations_prev[n] = pearson_correlation(prev_frame, new_frame) - correlations_disattenuated_prev[n] = pearson_correlation( - prev_frame, new_frame, noise_var=dark_var - ) - contrast_ratios_prev[n] = contrast_enhancement( - new_frame, prev_frame, dark_frame - ) + correlations_disattenuated_prev[n] = pearson_correlation(prev_frame, new_frame, noise_var=dark_var) + contrast_ratios_prev[n] = contrast_enhancement(new_frame, prev_frame, dark_frame) abs_timestamps[n] = time.perf_counter() # Save frame if requested @@ -330,9 +310,7 @@ def measure_setup_stability( ) -def measure_modulated_light_dual_phase_stepping( - slm: PhaseSLM, feedback: Detector, phase_steps: int, num_blocks: int -): +def measure_modulated_light_dual_phase_stepping(slm: PhaseSLM, feedback: Detector, phase_steps: int, num_blocks: int): """ Measure the ratio of modulated light with the dual phase stepping method. @@ -371,12 +349,8 @@ def measure_modulated_light_dual_phase_stepping( # Compute fidelity factor due to modulated light eps = 1e-6 # Epsilon term to prevent division by zero - m1_m2_ratio = (np.abs(f[0, 1]) ** 2 + eps) / ( - np.abs(f[1, 0]) ** 2 + eps - ) # Ratio of modulated intensities - fidelity_modulated = (1 + m1_m2_ratio) / ( - 1 + m1_m2_ratio + np.abs(f[0, 1]) ** 2 / np.abs(f[1, -1]) ** 2 - ) + m1_m2_ratio = (np.abs(f[0, 1]) ** 2 + eps) / (np.abs(f[1, 0]) ** 2 + eps) # Ratio of modulated intensities + fidelity_modulated = (1 + m1_m2_ratio) / (1 + m1_m2_ratio + np.abs(f[0, 1]) ** 2 / np.abs(f[1, -1]) ** 2) return fidelity_modulated @@ -410,9 +384,7 @@ def measure_modulated_light(slm: PhaseSLM, feedback: Detector, phase_steps: int) f = np.fft.fft(measurements) # Compute ratio of modulated light over total - fidelity_modulated = 0.5 * ( - 1.0 + np.sqrt(np.clip(1.0 - 4.0 * np.abs(f[1] / f[0]) ** 2, 0, None)) - ) + fidelity_modulated = 0.5 * (1.0 + np.sqrt(np.clip(1.0 - 4.0 * np.abs(f[1] / f[0]) ** 2, 0, None))) return fidelity_modulated @@ -495,9 +467,7 @@ def report(self, do_plots=True): print(f"fidelity_amplitude: {self.wfs_result.fidelity_amplitude.squeeze():.3f}") print(f"fidelity_noise: {self.wfs_result.fidelity_noise.squeeze():.3f}") print(f"fidelity_non_modulated: {self.fidelity_non_modulated:.3f}") - print( - f"fidelity_phase_calibration: {self.wfs_result.fidelity_calibration.squeeze():.3f}" - ) + print(f"fidelity_phase_calibration: {self.wfs_result.fidelity_calibration.squeeze():.3f}") print(f"fidelity_decorrelation: {self.fidelity_decorrelation:.3f}") print(f"expected enhancement: {self.expected_enhancement:.3f}") print(f"measured enhancement: {self.measured_enhancement:.3f}") @@ -505,9 +475,7 @@ def report(self, do_plots=True): print(f"=== Frame metrics ===") print(f"signal std, before: {self.frame_signal_std_before:.2f}") print(f"signal std, after: {self.frame_signal_std_after:.2f}") - print( - f"signal std, with shaped wavefront: {self.frame_signal_std_shaped_wf:.2f}" - ) + print(f"signal std, with shaped wavefront: {self.frame_signal_std_shaped_wf:.2f}") if self.dark_frame is not None: print(f"average offset (dark frame): {self.dark_frame.mean():.2f}") print(f"median offset (dark frame): {np.median(self.dark_frame):.2f}") @@ -515,9 +483,7 @@ def report(self, do_plots=True): print(f"frame repeatability: {self.frame_repeatability:.3f}") print(f"contrast to noise ratio before: {self.frame_cnr_before:.3f}") print(f"contrast to noise ratio after: {self.frame_cnr_after:.3f}") - print( - f"contrast to noise ratio with shaped wavefront: {self.frame_cnr_shaped_wf:.3f}" - ) + print(f"contrast to noise ratio with shaped wavefront: {self.frame_cnr_shaped_wf:.3f}") print(f"contrast enhancement: {self.frame_contrast_enhancement:.3f}") print(f"photobleaching ratio: {self.frame_photobleaching_ratio:.3f}") @@ -609,13 +575,9 @@ def troubleshoot( before_frame_2 = frame_source.read() # Frame metrics - trouble.frame_signal_std_before = signal_std( - trouble.before_frame, trouble.dark_frame - ) + trouble.frame_signal_std_before = signal_std(trouble.before_frame, trouble.dark_frame) trouble.frame_cnr_before = cnr(trouble.before_frame, trouble.dark_frame) - trouble.frame_repeatability = pearson_correlation( - trouble.before_frame, before_frame_2 - ) + trouble.frame_repeatability = pearson_correlation(trouble.before_frame, before_frame_2) if do_long_stability_test and do_frame_capture: logging.info("Run long stability test...") @@ -648,27 +610,17 @@ def troubleshoot( algorithm.slm.set_phases(-np.angle(trouble.wfs_result.t)) trouble.feedback_shaped_wf = algorithm.feedback.read() - trouble.measured_enhancement = ( - trouble.feedback_shaped_wf / trouble.average_background - ) + trouble.measured_enhancement = trouble.feedback_shaped_wf / trouble.average_background if do_frame_capture: trouble.shaped_wf_frame = frame_source.read() # Shaped wavefront frame # Frame metrics logging.info("Compute frame metrics...") - trouble.frame_signal_std_after = signal_std( - trouble.after_frame, trouble.dark_frame - ) - trouble.frame_signal_std_shaped_wf = signal_std( - trouble.shaped_wf_frame, trouble.dark_frame - ) - trouble.frame_cnr_after = cnr( - trouble.after_frame, trouble.dark_frame - ) # Frame CNR after - trouble.frame_cnr_shaped_wf = cnr( - trouble.shaped_wf_frame, trouble.dark_frame - ) # Frame CNR shaped wf + trouble.frame_signal_std_after = signal_std(trouble.after_frame, trouble.dark_frame) + trouble.frame_signal_std_shaped_wf = signal_std(trouble.shaped_wf_frame, trouble.dark_frame) + trouble.frame_cnr_after = cnr(trouble.after_frame, trouble.dark_frame) # Frame CNR after + trouble.frame_cnr_shaped_wf = cnr(trouble.shaped_wf_frame, trouble.dark_frame) # Frame CNR shaped wf trouble.frame_contrast_enhancement = contrast_enhancement( trouble.shaped_wf_frame, trouble.after_frame, trouble.dark_frame ) diff --git a/openwfs/algorithms/utilities.py b/openwfs/algorithms/utilities.py index 545cd1d..963c119 100644 --- a/openwfs/algorithms/utilities.py +++ b/openwfs/algorithms/utilities.py @@ -67,11 +67,7 @@ def __init__( self.fidelity_amplitude = np.atleast_1d(fidelity_amplitude) self.fidelity_calibration = np.atleast_1d(fidelity_calibration) self.estimated_enhancement = np.atleast_1d( - 1.0 - + (self.n - 1) - * self.fidelity_amplitude - * self.fidelity_noise - * self.fidelity_calibration + 1.0 + (self.n - 1) * self.fidelity_amplitude * self.fidelity_noise * self.fidelity_calibration ) self.intensity_offset = ( intensity_offset * np.ones(self.fidelity_calibration.shape) @@ -79,24 +75,16 @@ def __init__( else intensity_offset ) after = ( - np.sum(np.abs(t), tuple(range(self.axis))) ** 2 - * self.fidelity_noise - * self.fidelity_calibration + np.sum(np.abs(t), tuple(range(self.axis))) ** 2 * self.fidelity_noise * self.fidelity_calibration + intensity_offset ) self.estimated_optimized_intensity = np.atleast_1d(after) def __str__(self) -> str: noise_warning = "OK" if self.fidelity_noise > 0.5 else "WARNING low signal quality." - amplitude_warning = ( - "OK" - if self.fidelity_amplitude > 0.5 - else "WARNING uneven contribution of optical modes." - ) + amplitude_warning = "OK" if self.fidelity_amplitude > 0.5 else "WARNING uneven contribution of optical modes." calibration_fidelity_warning = ( - "OK" - if self.fidelity_calibration > 0.5 - else ("WARNING non-linear phase response, check " "lookup table.") + "OK" if self.fidelity_calibration > 0.5 else ("WARNING non-linear phase response, check " "lookup table.") ) return f""" Wavefront shaping results: diff --git a/openwfs/core.py b/openwfs/core.py index 3151216..570622b 100644 --- a/openwfs/core.py +++ b/openwfs/core.py @@ -71,25 +71,15 @@ def _start(self): else: logging.debug("switch to MOVING requested by %s.", self) - same_type = [ - device - for device in Device._devices - if device._is_actuator == self._is_actuator - ] - other_type = [ - device - for device in Device._devices - if device._is_actuator != self._is_actuator - ] + same_type = [device for device in Device._devices if device._is_actuator == self._is_actuator] + other_type = [device for device in Device._devices if device._is_actuator != self._is_actuator] # compute the minimum latency of same_type # for instance, when switching to 'measuring', this number tells us how long it takes before any of the # detectors actually starts a measurement. # If this is a positive number, we can make the switch to 'measuring' slightly _before_ # all actuators have stabilized. - latency = min( - [device.latency for device in same_type], default=0.0 * u.ns - ) # noqa - incorrect warning + latency = min([device.latency for device in same_type], default=0.0 * u.ns) # noqa - incorrect warning # wait until all devices of the other type have (almost) finished for device in other_type: @@ -104,11 +94,7 @@ def _start(self): # also store the time we expect the operation to finish # note: it may finish slightly earlier since (latency + duration) is a maximum value - self._end_time_ns = ( - time.time_ns() - + self.latency.to_value(u.ns) - + self.duration.to_value(u.ns) - ) + self._end_time_ns = time.time_ns() + self.latency.to_value(u.ns) + self.duration.to_value(u.ns) @property def latency(self) -> Quantity[u.ms]: @@ -196,9 +182,7 @@ def wait(self, up_to: Optional[Quantity[u.ms]] = None) -> None: while self.busy(): time.sleep(0.01) if time.time_ns() - start > timeout: - raise TimeoutError( - "Timeout in %s (tid %i)", self, threading.get_ident() - ) + raise TimeoutError("Timeout in %s (tid %i)", self, threading.get_ident()) else: time_to_wait = self._end_time_ns - time.time_ns() if up_to is not None: @@ -400,19 +384,10 @@ def __do_fetch(self, out_, *args_, **kwargs_): """Helper function that awaits all futures in the keyword argument list, and then calls _fetch""" try: if len(args_) > 0 or len(kwargs_) > 0: - logging.debug( - "awaiting inputs for %s (tid: %i).", self, threading.get_ident() - ) - awaited_args = [ - (arg.result() if isinstance(arg, Future) else arg) for arg in args_ - ] - awaited_kwargs = { - key: (arg.result() if isinstance(arg, Future) else arg) - for (key, arg) in kwargs_.items() - } - logging.debug( - "fetching data of %s ((tid: %i)).", self, threading.get_ident() - ) + logging.debug("awaiting inputs for %s (tid: %i).", self, threading.get_ident()) + awaited_args = [(arg.result() if isinstance(arg, Future) else arg) for arg in args_] + awaited_kwargs = {key: (arg.result() if isinstance(arg, Future) else arg) for (key, arg) in kwargs_.items()} + logging.debug("fetching data of %s ((tid: %i)).", self, threading.get_ident()) data = self._fetch(*awaited_args, **awaited_kwargs) data = set_pixel_size(data, self.pixel_size) assert data.shape == self.data_shape @@ -523,16 +498,10 @@ def coordinates(self, dimension: int) -> Quantity: Args: dimension: Dimension for which to return the coordinates. """ - unit = ( - u.dimensionless_unscaled - if self.pixel_size is None - else self.pixel_size[dimension] - ) + unit = u.dimensionless_unscaled if self.pixel_size is None else self.pixel_size[dimension] shape = np.ones_like(self.data_shape) shape[dimension] = self.data_shape[dimension] - return ( - np.arange(0.5, 0.5 + self.data_shape[dimension], 1.0).reshape(shape) * unit - ) + return np.arange(0.5, 0.5 + self.data_shape[dimension], 1.0).reshape(shape) * unit @final @property @@ -583,8 +552,7 @@ def __init__(self, *args, multi_threaded: bool): def trigger(self, *args, immediate=False, **kwargs): """Triggers all sources at the same time (regardless of latency), and schedules a call to `_fetch()`""" future_data = [ - (source.trigger(immediate=immediate) if source is not None else None) - for source in self._sources + (source.trigger(immediate=immediate) if source is not None else None) for source in self._sources ] return super().trigger(*future_data, *args, **kwargs) @@ -606,11 +574,7 @@ def duration(self) -> Quantity[u.ms]: Note that `latency` is allowed to vary over time for devices that can only be triggered periodically, so this `duration` may also vary over time. """ - times = [ - (source.duration, source.latency) - for source in self._sources - if source is not None - ] + times = [(source.duration, source.latency) for source in self._sources if source is not None] if len(times) == 0: return 0.0 * u.ms return max([duration + latency for (duration, latency) in times]) - min( diff --git a/openwfs/devices/camera.py b/openwfs/devices/camera.py index 4a23bb5..0030e90 100644 --- a/openwfs/devices/camera.py +++ b/openwfs/devices/camera.py @@ -71,9 +71,7 @@ def __init__( self._harvester.update() # open the camera, use the serial_number to select the camera if it is specified. - search_key = ( - {"serial_number": serial_number} if serial_number is not None else None - ) + search_key = {"serial_number": serial_number} if serial_number is not None else None self._camera = self._harvester.create(search_key=search_key) nodes = self._camera.remote_device.node_map diff --git a/openwfs/devices/galvo_scanner.py b/openwfs/devices/galvo_scanner.py index 667d681..2c1f84a 100644 --- a/openwfs/devices/galvo_scanner.py +++ b/openwfs/devices/galvo_scanner.py @@ -106,22 +106,16 @@ def maximum_scan_speed(self, linear_range: float): Quantity[u.V / u.s]: maximum scan speed """ # x = 0.5 · a · t² = 0.5 (v_max - v_min) · (1 - linear_range) - t_accel = np.sqrt( - (self.v_max - self.v_min) * (1 - linear_range) / self.maximum_acceleration - ) + t_accel = np.sqrt((self.v_max - self.v_min) * (1 - linear_range) / self.maximum_acceleration) hardware_limit = t_accel * self.maximum_acceleration # t_linear = linear_range · (v_max - v_min) / maximum_speed # t_accel = maximum_speed / maximum_acceleration # 0.5·t_linear == t_accel => 0.5·linear_range · (v_max-v_min) · maximum_acceleration = maximum_speed² - practical_limit = np.sqrt( - 0.5 * linear_range * (self.v_max - self.v_min) * self.maximum_acceleration - ) + practical_limit = np.sqrt(0.5 * linear_range * (self.v_max - self.v_min) * self.maximum_acceleration) return np.minimum(hardware_limit, practical_limit) - def step( - self, start: float, stop: float, sample_rate: Quantity[u.Hz] - ) -> Quantity[u.V]: + def step(self, start: float, stop: float, sample_rate: Quantity[u.Hz]) -> Quantity[u.V]: """ Generate a voltage sequence to move from `start` to `stop` in the fastest way possible. @@ -149,21 +143,13 @@ def step( # `a` is measured in volt/sample² a = self.maximum_acceleration / sample_rate**2 * np.sign(v_end - v_start) t_total = unitless(2.0 * np.sqrt((v_end - v_start) / a)) - t = np.arange( - np.ceil(t_total + 1e-6) - ) # add a small number to deal with case t=0 (start=end) + t = np.arange(np.ceil(t_total + 1e-6)) # add a small number to deal with case t=0 (start=end) v_accel = v_start + 0.5 * a * t[: len(t) // 2] ** 2 # acceleration part - v_decel = ( - v_end - 0.5 * a * (t_total - t[len(t) // 2 :]) ** 2 - ) # deceleration part + v_decel = v_end - 0.5 * a * (t_total - t[len(t) // 2 :]) ** 2 # deceleration part v_decel[-1] = v_end # fix last point because t may be > t_total due to rounding - return np.clip( - np.concatenate((v_accel, v_decel)), self.v_min, self.v_max - ) # noqa ignore incorrect type warning + return np.clip(np.concatenate((v_accel, v_decel)), self.v_min, self.v_max) # noqa ignore incorrect type warning - def scan( - self, start: float, stop: float, sample_count: int, sample_rate: Quantity[u.Hz] - ): + def scan(self, start: float, stop: float, sample_count: int, sample_rate: Quantity[u.Hz]): """ Generate a voltage sequence to scan with a constant velocity from start to stop, including acceleration and deceleration. @@ -203,13 +189,9 @@ def scan( # we start by constructing a sequence with a maximum acceleration. # This sequence may be up to 1 sample longer than needed to reach the scan speed. # This last sample is replaced by movement at a linear scan speed - a = ( - self.maximum_acceleration / sample_rate**2 * np.sign(scan_speed) - ) # V per sample² + a = self.maximum_acceleration / sample_rate**2 * np.sign(scan_speed) # V per sample² t_launch = np.arange(np.ceil(unitless(scan_speed / a))) # in samples - v_accel = ( - 0.5 * a * t_launch**2 - ) # last sample may have faster scan speed than needed + v_accel = 0.5 * a * t_launch**2 # last sample may have faster scan speed than needed if len(v_accel) > 1 and np.abs(v_accel[-1] - v_accel[-2]) > np.abs(scan_speed): v_accel[-1] = v_accel[-2] + scan_speed v_launch = v_start - v_accel[-1] - 0.5 * scan_speed # launch point @@ -255,9 +237,7 @@ def compute_scale( """ f_objective = reference_tube_lens / objective_magnification angle_to_displacement = f_objective / u.rad - return ( - (optical_deflection / galvo_to_pupil_magnification) * angle_to_displacement - ).to(u.um / u.V) + return ((optical_deflection / galvo_to_pupil_magnification) * angle_to_displacement).to(u.um / u.V) @staticmethod def compute_acceleration( @@ -284,9 +264,7 @@ def compute_acceleration( maximum_current (Quantity[u.A]): The maximum current that can be applied to the galvo mirror. """ - angular_acceleration = (torque_constant * maximum_current / rotor_inertia).to( - u.s**-2 - ) * u.rad + angular_acceleration = (torque_constant * maximum_current / rotor_inertia).to(u.s**-2) * u.rad return (angular_acceleration / optical_deflection).to(u.V / u.s**2) @@ -396,12 +374,8 @@ def __init__( self._resolution = int(resolution) self._roi_top = 0 # in pixels self._roi_left = 0 # in pixels - self._center_x = ( - 0.5 # in relative coordinates (relative to the full field of view) - ) - self._center_y = ( - 0.5 # in relative coordinates (relative to the full field of view) - ) + self._center_x = 0.5 # in relative coordinates (relative to the full field of view) + self._center_y = 0.5 # in relative coordinates (relative to the full field of view) self._delay = delay.to(u.us) self._reference_zoom = float(reference_zoom) self._zoom = 1.0 @@ -462,9 +436,7 @@ def _update(self): # Compute the retrace pattern for the slow axis # The scan starts at half a pixel after roi_bottom and ends half a pixel before roi_top - v_yr = self._y_axis.step( - roi_bottom - 0.5 * roi_scale, roi_top + 0.5 * roi_scale, self._sample_rate - ) + v_yr = self._y_axis.step(roi_bottom - 0.5 * roi_scale, roi_top + 0.5 * roi_scale, self._sample_rate) # Compute the scan pattern for the fast axis # The naive speed is the scan speed assuming one pixel per sample @@ -472,12 +444,8 @@ def _update(self): # (at least, without spending more time on accelerating and decelerating than the scan itself) # The user can set the scan speed relative to the maximum speed. # If this set speed is lower than naive scan speed, multiple samples are taken per pixel. - naive_speed = ( - (self._x_axis.v_max - self._x_axis.v_min) * roi_scale * self._sample_rate - ) - max_speed = ( - self._x_axis.maximum_scan_speed(1.0 / actual_zoom) * self._scan_speed_factor - ) + naive_speed = (self._x_axis.v_max - self._x_axis.v_min) * roi_scale * self._sample_rate + max_speed = self._x_axis.maximum_scan_speed(1.0 / actual_zoom) * self._scan_speed_factor if max_speed == 0.0: # this may happen if the ROI reaches to or beyond [0,1]. In this case, the mirror has no time to accelerate # TODO: implement an auto-adjust option instead of raising an error @@ -492,13 +460,9 @@ def _update(self): roi_left, roi_right, oversampled_width, self._sample_rate ) if self._bidirectional: - v_x_odd, _, _, _ = self._x_axis.scan( - roi_right, roi_left, oversampled_width, self._sample_rate - ) + v_x_odd, _, _, _ = self._x_axis.scan(roi_right, roi_left, oversampled_width, self._sample_rate) else: - v_xr = self._x_axis.step( - x_land, x_launch, self._sample_rate - ) # horizontal retrace + v_xr = self._x_axis.step(x_land, x_launch, self._sample_rate) # horizontal retrace v_x_even = np.concatenate((v_x_even, v_xr)) v_x_odd = v_x_even @@ -563,9 +527,7 @@ def _update(self): min_val=self._y_axis.v_min.to_value(u.V), max_val=self._y_axis.v_max.to_value(u.V), ) - self._write_task.timing.cfg_samp_clk_timing( - sample_rate, samps_per_chan=sample_count - ) + self._write_task.timing.cfg_samp_clk_timing(sample_rate, samps_per_chan=sample_count) # Configure the analog input task (one channel) self._read_task.ai_channels.add_ai_voltage_chan( @@ -574,18 +536,12 @@ def _update(self): max_val=self._input_channel.v_max.to_value(u.V), terminal_config=self._input_channel.terminal_configuration, ) - self._read_task.timing.cfg_samp_clk_timing( - sample_rate, samps_per_chan=sample_count - ) - self._read_task.triggers.start_trigger.cfg_dig_edge_start_trig( - self._write_task.triggers.start_trigger.term - ) + self._read_task.timing.cfg_samp_clk_timing(sample_rate, samps_per_chan=sample_count) + self._read_task.triggers.start_trigger.cfg_dig_edge_start_trig(self._write_task.triggers.start_trigger.term) delay = self._delay.to_value(u.s) if delay > 0.0: self._read_task.triggers.start_trigger.delay = delay - self._read_task.triggers.start_trigger.delay_units = ( - DigitalWidthUnits.SECONDS - ) + self._read_task.triggers.start_trigger.delay_units = DigitalWidthUnits.SECONDS self._writer = AnalogMultiChannelWriter(self._write_task.out_stream) self._valid = True @@ -624,22 +580,16 @@ def _raw_to_cropped(self, raw: np.ndarray) -> np.ndarray: # down sample along fast axis if needed if self._oversampling > 1: # remove samples if not divisible by oversampling factor - cropped = cropped[ - :, : (cropped.shape[1] // self._oversampling) * self._oversampling - ] + cropped = cropped[:, : (cropped.shape[1] // self._oversampling) * self._oversampling] cropped = cropped.reshape(cropped.shape[0], -1, self._oversampling) - cropped = np.round(np.mean(cropped, 2)).astype( - cropped.dtype - ) # todo: faster alternative? + cropped = np.round(np.mean(cropped, 2)).astype(cropped.dtype) # todo: faster alternative? # Change the data type into uint16 if necessary if cropped.dtype == np.int16: # add 32768 to go from -32768-32767 to 0-65535 cropped = cropped.view("uint16") + 0x8000 elif cropped.dtype != np.uint16: - raise ValueError( - f"Only int16 and uint16 data types are supported at the moment, got type {cropped.dtype}." - ) + raise ValueError(f"Only int16 and uint16 data types are supported at the moment, got type {cropped.dtype}.") if self._bidirectional: # note: requires the mask to be symmetrical cropped[1::2, :] = cropped[1::2, ::-1] @@ -653,24 +603,18 @@ def _fetch(self) -> np.ndarray: # noqa self._read_task.stop() self._write_task.stop() elif self._test_pattern == TestPatternType.HORIZONTAL: - raw = np.round( - self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) * 10000 - ).astype("int16") + raw = np.round(self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) * 10000).astype("int16") elif self._test_pattern == TestPatternType.VERTICAL: - raw = np.round( - self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) * 10000 - ).astype("int16") + raw = np.round(self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) * 10000).astype("int16") elif self._test_pattern == TestPatternType.IMAGE: if self._test_image is None: raise ValueError("No test image was provided for the image simulation.") # todo: cache the test image row = np.floor( - self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) - * (self._test_image.shape[0] - 1) + self._y_axis.to_pos(self._scan_pattern[0, :] * u.V) * (self._test_image.shape[0] - 1) ).astype("int32") column = np.floor( - self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) - * (self._test_image.shape[1] - 1) + self._x_axis.to_pos(self._scan_pattern[1, :] * u.V) * (self._test_image.shape[1] - 1) ).astype("int32") raw = self._test_image[row, column] else: @@ -683,13 +627,9 @@ def _fetch(self) -> np.ndarray: # noqa if self._preprocessor is None: preprocessed_raw = raw elif callable(self._preprocessor): - preprocessed_raw = self._preprocessor( - data=raw, sample_rate=self._sample_rate - ) + preprocessed_raw = self._preprocessor(data=raw, sample_rate=self._sample_rate) else: - raise TypeError( - f"Invalid type for {self._preprocessor}. Should be callable or None." - ) + raise TypeError(f"Invalid type for {self._preprocessor}. Should be callable or None.") return self._raw_to_cropped(preprocessed_raw) def close(self): @@ -718,9 +658,7 @@ def preprocessor(self): @preprocessor.setter def preprocessor(self, value: Optional[callable]): if not callable(value) and value is not None: - raise TypeError( - f"Invalid type for {self._preprocessor}. Should be callable or None." - ) + raise TypeError(f"Invalid type for {self._preprocessor}. Should be callable or None.") self._preprocessor = value @property @@ -729,10 +667,7 @@ def pixel_size(self) -> Quantity: # TODO: make extent a read-only attribute of Axis extent_y = (self._y_axis.v_max - self._y_axis.v_min) * self._y_axis.scale extent_x = (self._x_axis.v_max - self._x_axis.v_min) * self._x_axis.scale - return ( - Quantity(extent_y, extent_x) - / (self._reference_zoom * self._zoom * self._resolution) - ).to(u.um) + return (Quantity(extent_y, extent_x) / (self._reference_zoom * self._zoom * self._resolution)).to(u.um) @property def duration(self) -> Quantity[u.ms]: diff --git a/openwfs/devices/nidaq_gain.py b/openwfs/devices/nidaq_gain.py index b0cb23c..64c7c51 100644 --- a/openwfs/devices/nidaq_gain.py +++ b/openwfs/devices/nidaq_gain.py @@ -55,9 +55,7 @@ def check_overload(self): def on_reset(self, value): if value: with ni.Task() as task: - task.do_channels.add_do_chan( - self.port_do, line_grouping=LineGrouping.CHAN_FOR_ALL_LINES - ) + task.do_channels.add_do_chan(self.port_do, line_grouping=LineGrouping.CHAN_FOR_ALL_LINES) task.write([True]) time.sleep(1) task.write([False]) diff --git a/openwfs/devices/slm/geometry.py b/openwfs/devices/slm/geometry.py index b6f73a1..560257d 100644 --- a/openwfs/devices/slm/geometry.py +++ b/openwfs/devices/slm/geometry.py @@ -184,9 +184,7 @@ def circular( segments_inside = 0 total_segments = np.sum(segments_per_ring) for r in range(ring_count): - x_outside = ( - x * radii[r + 1] - ) # coordinates of the vertices at the outside of the ring + x_outside = x * radii[r + 1] # coordinates of the vertices at the outside of the ring y_outside = y * radii[r + 1] segments = segments_inside + segments_per_ring[r] vertices[r, 0, :, 0] = x_inside + center[1] @@ -194,8 +192,7 @@ def circular( vertices[r, 1, :, 0] = x_outside + center[1] vertices[r, 1, :, 1] = y_outside + center[0] vertices[r, :, :, 2] = ( - np.linspace(segments_inside, segments, edge_count + 1).reshape((1, -1)) - / total_segments + np.linspace(segments_inside, segments, edge_count + 1).reshape((1, -1)) / total_segments ) # tx x_inside = x_outside y_inside = y_outside @@ -206,9 +203,6 @@ def circular( # construct indices for a single ring, and repeat for all rings with the appropriate offset indices = Geometry.compute_indices_for_grid((1, edge_count)).reshape((1, -1)) - indices = ( - indices - + np.arange(ring_count).reshape((-1, 1)) * vertices.shape[1] * vertices.shape[2] - ) + indices = indices + np.arange(ring_count).reshape((-1, 1)) * vertices.shape[1] * vertices.shape[2] indices[:, -1] = 0xFFFF return Geometry(vertices.reshape((-1, 4)), indices.reshape(-1)) diff --git a/openwfs/devices/slm/patch.py b/openwfs/devices/slm/patch.py index 7c55840..5039ef6 100644 --- a/openwfs/devices/slm/patch.py +++ b/openwfs/devices/slm/patch.py @@ -112,9 +112,7 @@ def _draw(self): # perform the actual drawing glBindBuffer(GL.GL_ELEMENT_ARRAY_BUFFER, self._indices) glBindVertexBuffer(0, self._vertices, 0, 16) - glDrawElements( - GL.GL_TRIANGLE_STRIP, self._index_count, GL.GL_UNSIGNED_SHORT, None - ) + glDrawElements(GL.GL_TRIANGLE_STRIP, self._index_count, GL.GL_UNSIGNED_SHORT, None) def set_phases(self, values: ArrayLike, update=True): """ @@ -190,9 +188,7 @@ def __init__(self, slm, lookup_table: Sequence[int]): # is then processed as a whole (applying the software lookup table) and displayed on the screen. self._frame_buffer = glGenFramebuffers(1) - self.set_phases( - np.zeros(self.context.slm.shape, dtype=np.float32), update=False - ) + self.set_phases(np.zeros(self.context.slm.shape, dtype=np.float32), update=False) glBindFramebuffer(GL.GL_FRAMEBUFFER, self._frame_buffer) glFramebufferTexture2D( GL.GL_FRAMEBUFFER, @@ -205,9 +201,7 @@ def __init__(self, slm, lookup_table: Sequence[int]): raise Exception("Could not construct frame buffer") glBindFramebuffer(GL.GL_FRAMEBUFFER, 0) - self._textures.append( - Texture(self.context, GL.GL_TEXTURE_1D) - ) # create texture for lookup table + self._textures.append(Texture(self.context, GL.GL_TEXTURE_1D)) # create texture for lookup table self._lookup_table = None self.lookup_table = lookup_table self.additive_blend = False @@ -250,25 +244,15 @@ class VertexArray: # Since we have a fixed vertex format, we only need to bind the VertexArray once, and not bother with # updating, binding, or even deleting it def __init__(self): - self._vertex_array = glGenVertexArrays( - 1 - ) # no need to destroy explicitly, destroyed when window is destroyed + self._vertex_array = glGenVertexArrays(1) # no need to destroy explicitly, destroyed when window is destroyed glBindVertexArray(self._vertex_array) glEnableVertexAttribArray(0) glEnableVertexAttribArray(1) - glVertexAttribFormat( - 0, 2, GL.GL_FLOAT, GL.GL_FALSE, 0 - ) # first two float32 are screen coordinates - glVertexAttribFormat( - 1, 2, GL.GL_FLOAT, GL.GL_FALSE, 8 - ) # second two are texture coordinates + glVertexAttribFormat(0, 2, GL.GL_FLOAT, GL.GL_FALSE, 0) # first two float32 are screen coordinates + glVertexAttribFormat(1, 2, GL.GL_FLOAT, GL.GL_FALSE, 8) # second two are texture coordinates glVertexAttribBinding(0, 0) # use binding index 0 for both attributes - glVertexAttribBinding( - 1, 0 - ) # the attribute format can now be used with glBindVertexBuffer + glVertexAttribBinding(1, 0) # the attribute format can now be used with glBindVertexBuffer # enable primitive restart, so that we can draw multiple triangle strips with a single draw call glEnable(GL.GL_PRIMITIVE_RESTART) - glPrimitiveRestartIndex( - 0xFFFF - ) # this is the index we use to separate individual triangle strips + glPrimitiveRestartIndex(0xFFFF) # this is the index we use to separate individual triangle strips diff --git a/openwfs/devices/slm/slm.py b/openwfs/devices/slm/slm.py index d1a4750..ebdf900 100644 --- a/openwfs/devices/slm/slm.py +++ b/openwfs/devices/slm/slm.py @@ -122,9 +122,7 @@ def __init__( self._position = pos (default_shape, default_rate, _) = SLM._current_mode(monitor_id) self._shape = default_shape if shape is None else shape - self._refresh_rate = ( - default_rate if refresh_rate is None else refresh_rate.to_value(u.Hz) - ) + self._refresh_rate = default_rate if refresh_rate is None else refresh_rate.to_value(u.Hz) self._frame_buffer = None self._window = None self._globals = -1 @@ -159,9 +157,7 @@ def _assert_window_available(self, monitor_id) -> None: Exception: If a full screen SLM is already present on the target monitor. """ if monitor_id == SLM.WINDOWED: - if any( - [slm.monitor_id == 1 for slm in SLM._active_slms if slm is not self] - ): + if any([slm.monitor_id == 1 for slm in SLM._active_slms if slm is not self]): raise RuntimeError( f"Cannot create an SLM window because a full-screen SLM is already active on monitor 1" ) @@ -170,8 +166,7 @@ def _assert_window_available(self, monitor_id) -> None: # a full screen window on monitor 1 if there are already windowed SLMs. if any( [ - slm.monitor_id == monitor_id - or (monitor_id == 1 and slm.monitor_id == SLM.WINDOWED) + slm.monitor_id == monitor_id or (monitor_id == 1 and slm.monitor_id == SLM.WINDOWED) for slm in SLM._active_slms if slm is not self ] @@ -182,8 +177,7 @@ def _assert_window_available(self, monitor_id) -> None: ) if monitor_id > len(glfw.get_monitors()): raise IndexError( - f"Monitor {monitor_id} not found, only {len(glfw.get_monitors())} monitor(s) " - f"are connected." + f"Monitor {monitor_id} not found, only {len(glfw.get_monitors())} monitor(s) " f"are connected." ) @staticmethod @@ -223,11 +217,7 @@ def _on_resize(self): """ # create a new frame buffer, re-use the old one if one was present, otherwise use a default of range(256) # re-use the lookup table if possible, otherwise create a default one ranging from 0 to 255. - old_lut = ( - self._frame_buffer.lookup_table - if self._frame_buffer is not None - else range(256) - ) + old_lut = self._frame_buffer.lookup_table if self._frame_buffer is not None else range(256) self._frame_buffer = FrameBufferPatch(self, old_lut) glViewport(0, 0, self._shape[1], self._shape[0]) # tell openGL to wait for the vertical retrace when swapping buffers (it appears need to do this @@ -238,28 +228,22 @@ def _on_resize(self): (fb_width, fb_height) = glfw.get_framebuffer_size(self._window) fb_shape = (fb_height, fb_width) if self._shape != fb_shape: - warnings.warn( - f"Actual resolution {fb_shape} does not match requested resolution {self._shape}." - ) + warnings.warn(f"Actual resolution {fb_shape} does not match requested resolution {self._shape}.") self._shape = fb_shape - (current_size, current_rate, current_bit_depth) = SLM._current_mode( - self._monitor_id - ) + (current_size, current_rate, current_bit_depth) = SLM._current_mode(self._monitor_id) # verify that the bit depth is at least 8 bit if current_bit_depth < 8: warnings.warn( - f"Bit depth is less than 8 bits " - f"You may not be able to use the full phase resolution of your SLM." + f"Bit depth is less than 8 bits " f"You may not be able to use the full phase resolution of your SLM." ) # verify the refresh rate is correct # Then update the refresh rate to the actual value if int(self._refresh_rate) != current_rate: warnings.warn( - f"Actual refresh rate of {current_rate} Hz does not match set rate " - f"of {self._refresh_rate} Hz" + f"Actual refresh rate of {current_rate} Hz does not match set rate " f"of {self._refresh_rate} Hz" ) self._refresh_rate = current_rate @@ -273,29 +257,19 @@ def _init_glfw(): trouble if the user of our library also uses glfw for something else. """ glfw.init() - glfw.window_hint( - glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE - ) # Required on Mac. Doesn't hurt on Windows - glfw.window_hint( - glfw.OPENGL_FORWARD_COMPAT, glfw.TRUE - ) # Required on Mac. Useless on Windows + glfw.window_hint(glfw.OPENGL_PROFILE, glfw.OPENGL_CORE_PROFILE) # Required on Mac. Doesn't hurt on Windows + glfw.window_hint(glfw.OPENGL_FORWARD_COMPAT, glfw.TRUE) # Required on Mac. Useless on Windows glfw.window_hint(glfw.CONTEXT_VERSION_MAJOR, 4) # request at least opengl 4.2 glfw.window_hint(glfw.CONTEXT_VERSION_MINOR, 2) glfw.window_hint(glfw.FLOATING, glfw.TRUE) # Keep window on top glfw.window_hint(glfw.DECORATED, glfw.FALSE) # Disable window border - glfw.window_hint( - glfw.AUTO_ICONIFY, glfw.FALSE - ) # Prevent window minimization during task switch + glfw.window_hint(glfw.AUTO_ICONIFY, glfw.FALSE) # Prevent window minimization during task switch glfw.window_hint(glfw.FOCUSED, glfw.FALSE) glfw.window_hint(glfw.DOUBLEBUFFER, glfw.TRUE) - glfw.window_hint( - glfw.RED_BITS, 8 - ) # require at least 8 bits per color channel (256 gray values) + glfw.window_hint(glfw.RED_BITS, 8) # require at least 8 bits per color channel (256 gray values) glfw.window_hint(glfw.GREEN_BITS, 8) glfw.window_hint(glfw.BLUE_BITS, 8) - glfw.window_hint( - glfw.COCOA_RETINA_FRAMEBUFFER, glfw.FALSE - ) # disable retina multisampling on Mac (untested) + glfw.window_hint(glfw.COCOA_RETINA_FRAMEBUFFER, glfw.FALSE) # disable retina multisampling on Mac (untested) glfw.window_hint(glfw.SAMPLES, 0) # disable multisampling def _create_window(self): @@ -307,18 +281,10 @@ def _create_window(self): shared = other._window if other is not None else None # noqa: ok to use _window SLM._active_slms.add(self) - monitor = ( - glfw.get_monitors()[self._monitor_id - 1] - if self._monitor_id != SLM.WINDOWED - else None - ) + monitor = glfw.get_monitors()[self._monitor_id - 1] if self._monitor_id != SLM.WINDOWED else None glfw.window_hint(glfw.REFRESH_RATE, int(self._refresh_rate)) - self._window = glfw.create_window( - self._shape[1], self._shape[0], "OpenWFS SLM", monitor, shared - ) - glfw.set_input_mode( - self._window, glfw.CURSOR, glfw.CURSOR_HIDDEN - ) # disable cursor + self._window = glfw.create_window(self._shape[1], self._shape[0], "OpenWFS SLM", monitor, shared) + glfw.set_input_mode(self._window, glfw.CURSOR, glfw.CURSOR_HIDDEN) # disable cursor if monitor: # full screen mode glfw.set_gamma(monitor, 1.0) else: # windowed mode @@ -449,9 +415,7 @@ def update(self): """ with self._context: # first draw all patches into the frame buffer - glBindFramebuffer( - GL.GL_FRAMEBUFFER, self._frame_buffer._frame_buffer - ) # noqa - ok to access 'friend class' + glBindFramebuffer(GL.GL_FRAMEBUFFER, self._frame_buffer._frame_buffer) # noqa - ok to access 'friend class' glClear(GL.GL_COLOR_BUFFER_BIT) for patch in self.patches: patch._draw() # noqa - ok to access 'friend class' @@ -463,9 +427,7 @@ def update(self): glfw.poll_events() # process window messages if len(self._clones) > 0: - self._context.__exit__( - None, None, None - ) # release context before updating clones + self._context.__exit__(None, None, None) # release context before updating clones for clone in self._clones: with clone.slm._context: # noqa self._frame_buffer._draw() # noqa - ok to access 'friend class' @@ -563,24 +525,16 @@ def transform(self, value: Transform): else: scale_width = (width > height) == (self._coordinate_system == "short") if scale_width: - root_transform = Transform( - np.array(((1.0, 0.0), (0.0, height / width))) - ) + root_transform = Transform(np.array(((1.0, 0.0), (0.0, height / width)))) else: - root_transform = Transform( - np.array(((width / height, 0.0), (0.0, 1.0))) - ) + root_transform = Transform(np.array(((width / height, 0.0), (0.0, 1.0)))) transform = self._transform @ root_transform padded = transform.opencl_matrix() with self._context: glBindBuffer(GL.GL_UNIFORM_BUFFER, self._globals) - glBufferData( - GL.GL_UNIFORM_BUFFER, padded.size * 4, padded, GL.GL_STATIC_DRAW - ) - glBindBufferBase( - GL.GL_UNIFORM_BUFFER, 1, self._globals - ) # connect buffer to binding point 1 + glBufferData(GL.GL_UNIFORM_BUFFER, padded.size * 4, padded, GL.GL_STATIC_DRAW) + glBindBufferBase(GL.GL_UNIFORM_BUFFER, 1, self._globals) # connect buffer to binding point 1 @property def lookup_table(self) -> Sequence[int]: diff --git a/openwfs/devices/slm/texture.py b/openwfs/devices/slm/texture.py index 97c83f1..a05b674 100644 --- a/openwfs/devices/slm/texture.py +++ b/openwfs/devices/slm/texture.py @@ -28,9 +28,7 @@ def __init__(self, slm, texture_type=GL.GL_TEXTURE_2D): self.context = Context(slm) self.handle = glGenTextures(1) self.type = texture_type - self.synchronized = ( - False # self.data is not yet synchronized with texture in GPU memory - ) + self.synchronized = False # self.data is not yet synchronized with texture in GPU memory self._data_shape = None # current size of the texture, to see if we need to make a new texture or # overwrite the exiting one @@ -64,9 +62,7 @@ def set_data(self, value): with self.context: glBindTexture(self.type, self.handle) - glPixelStorei( - GL.GL_UNPACK_ALIGNMENT, 4 - ) # alignment is at least four bytes since we use float32 + glPixelStorei(GL.GL_UNPACK_ALIGNMENT, 4) # alignment is at least four bytes since we use float32 (internal_format, data_format, data_type) = ( GL.GL_R32F, GL.GL_RED, @@ -140,7 +136,5 @@ def set_data(self, value): def get_data(self): with self.context: data = np.empty(self._data_shape, dtype="float32") - glGetTextureImage( - self.handle, 0, GL.GL_RED, GL.GL_FLOAT, data.size * 4, data - ) + glGetTextureImage(self.handle, 0, GL.GL_RED, GL.GL_FLOAT, data.size * 4, data) return data diff --git a/openwfs/plot_utilities.py b/openwfs/plot_utilities.py index e11d1e8..7391aa9 100644 --- a/openwfs/plot_utilities.py +++ b/openwfs/plot_utilities.py @@ -67,7 +67,7 @@ def slope_step(a: nd, width: nd | float) -> nd: Returns: An array the size of a, with the result of the sloped step function. """ - return (a >= width) + a/width * (0 < a) * (a < width) + return (a >= width) + a / width * (0 < a) * (a < width) def linear_blend(a: nd, b: nd, blend: nd | float) -> nd: @@ -82,7 +82,7 @@ def linear_blend(a: nd, b: nd, blend: nd | float) -> nd: Returns: A linear combination of a and b, corresponding to the blend factor. a*blend + b*(1-blend) """ - return a*blend + b*(1-blend) + return a * blend + b * (1 - blend) def complex_to_rgb(array: nd, scale: float | nd | None = None, axis: int = 2) -> nd: @@ -133,7 +133,7 @@ def plot_scatter_field(x, y, array, scale, scatter_kwargs=None): Plot complex scattered data as RGB values. """ if scatter_kwargs is None: - scatter_kwargs = {'s': 80} + scatter_kwargs = {"s": 80} rgb = complex_to_rgb(array, scale, axis=1) plt.scatter(x, y, c=rgb, **scatter_kwargs) @@ -147,20 +147,26 @@ def complex_colorbar(scale, width_inverse: int = 15): z = amp * np.exp(1j * phase) rgb = complex_to_rgb(z, 1) ax = plt.subplot(1, width_inverse, width_inverse) - plt.imshow(rgb, aspect='auto', extent=(0, scale, -np.pi, np.pi)) + plt.imshow(rgb, aspect="auto", extent=(0, scale, -np.pi, np.pi)) # Ticks and labels - ax.set_yticks((-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi), ('$-\\pi$', '$-\\pi/2$', '0', '$\\pi/2$', '$\\pi$')) - ax.set_xlabel('amp.') - ax.set_ylabel('phase (rad)') + ax.set_yticks((-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi), ("$-\\pi$", "$-\\pi/2$", "0", "$\\pi/2$", "$\\pi$")) + ax.set_xlabel("amp.") + ax.set_ylabel("phase (rad)") ax.yaxis.tick_right() ax.yaxis.set_label_position("right") return ax -def complex_colorwheel(ax: Axes = None, shape: Tuple[int, int] = (100, 100), imshow_kwargs: dict = {}, - arrow_props: dict = {}, text_kwargs: dict = {}, amplitude_str: str = 'A', - phase_str: str = '$\\phi$'): +def complex_colorwheel( + ax: Axes = None, + shape: Tuple[int, int] = (100, 100), + imshow_kwargs: dict = {}, + arrow_props: dict = {}, + text_kwargs: dict = {}, + amplitude_str: str = "A", + phase_str: str = "$\\phi$", +): """ Create an rgb image for a colorwheel representing the complex unit circle. @@ -181,7 +187,7 @@ def complex_colorwheel(ax: Axes = None, shape: Tuple[int, int] = (100, 100), ims x = np.linspace(-1, 1, shape[1]).reshape(1, -1) y = np.linspace(-1, 1, shape[0]).reshape(-1, 1) - z = x + 1j*y + z = x + 1j * y rgb = complex_to_rgb(z, scale=1) step_width = 1.5 / shape[1] blend = np.expand_dims(slope_step(1 - np.abs(z) - step_width, width=step_width), axis=2) @@ -189,18 +195,32 @@ def complex_colorwheel(ax: Axes = None, shape: Tuple[int, int] = (100, 100), ims ax.imshow(rgba_wheel, extent=(-1, 1, -1, 1), **imshow_kwargs) # Add arrows with annotations - ax.annotate('', xy=(-0.98/np.sqrt(2),)*2, xytext=(0, 0), arrowprops={'color': 'white', 'width': 1.8, - 'headwidth': 5.0, 'headlength': 6.0, **arrow_props}) - ax.text(**{'x': -0.4, 'y': -0.8, 's': amplitude_str, 'color': 'white', 'fontsize': 15, **text_kwargs}) - ax.annotate('', xy=(0, 0.9), xytext=(0.9, 0), - arrowprops={'connectionstyle': 'arc3,rad=0.4', 'color': 'white', 'width': 1.8, 'headwidth': 5.0, - 'headlength': 6.0, **arrow_props}) - ax.text(**{'x': 0.1, 'y': 0.5, 's': phase_str, 'color': 'white', 'fontsize': 15, **text_kwargs}) + ax.annotate( + "", + xy=(-0.98 / np.sqrt(2),) * 2, + xytext=(0, 0), + arrowprops={"color": "white", "width": 1.8, "headwidth": 5.0, "headlength": 6.0, **arrow_props}, + ) + ax.text(**{"x": -0.4, "y": -0.8, "s": amplitude_str, "color": "white", "fontsize": 15, **text_kwargs}) + ax.annotate( + "", + xy=(0, 0.9), + xytext=(0.9, 0), + arrowprops={ + "connectionstyle": "arc3,rad=0.4", + "color": "white", + "width": 1.8, + "headwidth": 5.0, + "headlength": 6.0, + **arrow_props, + }, + ) + ax.text(**{"x": 0.1, "y": 0.5, "s": phase_str, "color": "white", "fontsize": 15, **text_kwargs}) # Hide axes spines and ticks ax.set_xticks([]) ax.set_yticks([]) - ax.spines['left'].set_visible(False) - ax.spines['right'].set_visible(False) - ax.spines['top'].set_visible(False) - ax.spines['bottom'].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["right"].set_visible(False) + ax.spines["top"].set_visible(False) + ax.spines["bottom"].set_visible(False) diff --git a/openwfs/processors/processors.py b/openwfs/processors/processors.py index d270a77..217cf34 100644 --- a/openwfs/processors/processors.py +++ b/openwfs/processors/processors.py @@ -17,9 +17,7 @@ class Roi: radius, mask type, and parameters specific to the mask type. """ - def __init__( - self, pos, radius=0.1, mask_type: str = "disk", waist=None, source_shape=None - ): + def __init__(self, pos, radius=0.1, mask_type: str = "disk", waist=None, source_shape=None): """ Initialize the Roi object. @@ -41,10 +39,7 @@ def __init__( round(pos[0] - radius) < 0 or round(pos[1] - radius) < 0 or source_shape is not None - and ( - round(pos[0] + radius) >= source_shape[0] - or round(pos[1] + radius) >= source_shape[1] - ) + and (round(pos[0] + radius) >= source_shape[0] or round(pos[1] + radius) >= source_shape[1]) ): raise ValueError("ROI does not fit inside source image") @@ -266,11 +261,7 @@ def __init__( """ super().__init__(source, multi_threaded=multi_threaded) self._data_shape = tuple(shape) if shape is not None else source.data_shape - self._pos = ( - np.array(pos) - if pos is not None - else np.zeros((len(self.data_shape),), dtype=int) - ) + self._pos = np.array(pos) if pos is not None else np.zeros((len(self.data_shape),), dtype=int) self._padding_value = padding_value @property @@ -302,15 +293,11 @@ def _fetch(self, image: np.ndarray) -> np.ndarray: # noqa src_end = np.minimum(self._pos + self._data_shape, image.shape).astype("int32") dst_start = np.maximum(-self._pos, 0).astype("int32") dst_end = dst_start + src_end - src_start - src_select = tuple( - slice(start, end) for (start, end) in zip(src_start, src_end) - ) + src_select = tuple(slice(start, end) for (start, end) in zip(src_start, src_end)) src = image.__getitem__(src_select) if any(dst_start != 0) or any(dst_end != self._data_shape): dst = np.zeros(self._data_shape) + self._padding_value - dst_select = tuple( - slice(start, end) for (start, end) in zip(dst_start, dst_end) - ) + dst_select = tuple(slice(start, end) for (start, end) in zip(dst_start, dst_end)) dst.__setitem__(dst_select, src) else: dst = src @@ -349,9 +336,7 @@ def mouse_callback(event, x, y, flags, _param): roi_size = np.minimum(x - roi_start[0], y - roi_start[1]) rect_image = image.copy() if mask_type == "square": - cv2.rectangle( - rect_image, roi_start, roi_start + roi_size, (0.0, 0.0, 255.0), 2 - ) + cv2.rectangle(rect_image, roi_start, roi_start + roi_size, (0.0, 0.0, 255.0), 2) else: cv2.circle( rect_image, @@ -403,9 +388,7 @@ def __init__( data_shape: Shape of the output. If omitted, the shape of the input data is used. multi_threaded: Whether to perform processing in a worker thread. """ - if (data_shape is not None and len(data_shape) != 2) or len( - source.data_shape - ) != 2: + if (data_shape is not None and len(data_shape) != 2) or len(source.data_shape) != 2: raise ValueError("TransformProcessor only supports 2-D data") if transform is None: transform = Transform() @@ -413,13 +396,10 @@ def __init__( # check if input and output pixel sizes are compatible dst_unit = transform.destination_unit(source.pixel_size.unit) if pixel_size is not None and not pixel_size.unit.is_equivalent(dst_unit): - raise ValueError( - "Pixel size unit does not match the unit of the transformed data" - ) + raise ValueError("Pixel size unit does not match the unit of the transformed data") if pixel_size is None and not source.pixel_size.unit.is_equivalent(dst_unit): raise ValueError( - "The transform changes the unit of the coordinates." - " An output pixel_size must be provided." + "The transform changes the unit of the coordinates." " An output pixel_size must be provided." ) self.transform = transform diff --git a/openwfs/simulation/microscope.py b/openwfs/simulation/microscope.py index 93cae59..15f7274 100644 --- a/openwfs/simulation/microscope.py +++ b/openwfs/simulation/microscope.py @@ -107,9 +107,7 @@ def __init__( if get_pixel_size(aberrations) is None: aberrations = StaticSource(aberrations) - super().__init__( - source, aberrations, incident_field, multi_threaded=multi_threaded - ) + super().__init__(source, aberrations, incident_field, multi_threaded=multi_threaded) self._magnification = magnification self._data_shape = data_shape if data_shape is not None else source.data_shape self.numerical_aperture = numerical_aperture @@ -153,9 +151,7 @@ def _fetch( source_pixel_size = get_pixel_size(source) target_pixel_size = self.pixel_size / self.magnification if np.any(source_pixel_size > target_pixel_size): - warnings.warn( - "The resolution of the specimen image is worse than that of the output." - ) + warnings.warn("The resolution of the specimen image is worse than that of the output.") # Note: there seems to be a bug (feature?) in `fftconvolve` that shifts the image by one pixel # when the 'same' option is used. To compensate for this feature, @@ -188,21 +184,13 @@ def _fetch( # Compute the field in the pupil plane # The aberrations and the SLM phase pattern are both mapped to the pupil plane coordinates - pupil_field = patterns.disk( - pupil_shape, radius=self.numerical_aperture, extent=pupil_extent - ) - pupil_area = np.sum( - pupil_field - ) # TODO (efficiency): compute area directly from radius + pupil_field = patterns.disk(pupil_shape, radius=self.numerical_aperture, extent=pupil_extent) + pupil_area = np.sum(pupil_field) # TODO (efficiency): compute area directly from radius # Project aberrations if aberrations is not None: # use default of 2.0 * NA for the extent of the aberration map if no pixel size is provided - aberration_extent = ( - (2.0 * self.numerical_aperture,) * 2 - if get_pixel_size(aberrations) is None - else None - ) + aberration_extent = (2.0 * self.numerical_aperture,) * 2 if get_pixel_size(aberrations) is None else None pupil_field = pupil_field * np.exp( 1.0j * project( @@ -290,8 +278,6 @@ def get_camera( if transform is None and data_shape is None and pixel_size is None: src = self else: - src = TransformProcessor( - self, data_shape=data_shape, pixel_size=pixel_size, transform=transform - ) + src = TransformProcessor(self, data_shape=data_shape, pixel_size=pixel_size, transform=transform) return Camera(src, **kwargs) diff --git a/openwfs/simulation/mockdevices.py b/openwfs/simulation/mockdevices.py index 79a3b83..5b6566b 100644 --- a/openwfs/simulation/mockdevices.py +++ b/openwfs/simulation/mockdevices.py @@ -43,11 +43,7 @@ def __init__( else: pixel_size = get_pixel_size(data) - if ( - pixel_size is not None - and (np.isscalar(pixel_size) or pixel_size.size == 1) - and data.ndim > 1 - ): + if pixel_size is not None and (np.isscalar(pixel_size) or pixel_size.size == 1) and data.ndim > 1: pixel_size = pixel_size.repeat(data.ndim) if multi_threaded is None: @@ -176,9 +172,7 @@ def _fetch(self, data) -> np.ndarray: # noqa if self.analog_max == 0.0: # auto scaling max_value = np.max(data) if max_value > 0.0: - data = data * ( - self.digital_max / max_value - ) # auto-scale to maximum value + data = data * (self.digital_max / max_value) # auto-scale to maximum value else: data = data * (self.digital_max / self.analog_max) @@ -186,9 +180,7 @@ def _fetch(self, data) -> np.ndarray: # noqa data = self._rng.poisson(data) if self._gaussian_noise_std > 0.0: - data = data + self._rng.normal( - scale=self._gaussian_noise_std, size=data.shape - ) + data = data + self._rng.normal(scale=self._gaussian_noise_std, size=data.shape) return np.clip(np.rint(data), 0, self.digital_max).astype("uint16") diff --git a/openwfs/simulation/slm.py b/openwfs/simulation/slm.py index eaa04b2..9cf682e 100644 --- a/openwfs/simulation/slm.py +++ b/openwfs/simulation/slm.py @@ -38,9 +38,7 @@ def _fetch(self, slm_phases: np.ndarray) -> np.ndarray: # noqa Updates the complex field output of the SLM. The output field is the sum of the modulated field and the non-modulated field. """ - return self.modulated_field_amplitude * ( - np.exp(1j * slm_phases) + self.non_modulated_field - ) + return self.modulated_field_amplitude * (np.exp(1j * slm_phases) + self.non_modulated_field) class _SLMTiming(Detector): @@ -202,9 +200,7 @@ def __init__( self._hardware_fields = PhaseToField( self._hardware_phases, field_amplitude, non_modulated_field_fraction ) # Simulates reading the field from the SLM - self._lookup_table = ( - None # index = input phase (scaled to -> [0, 255]), value = grey value - ) + self._lookup_table = None # index = input phase (scaled to -> [0, 255]), value = grey value self._first_update_ns = time.time_ns() self._back_buffer = np.zeros(shape, dtype=np.float32) @@ -214,12 +210,8 @@ def update(self): self._start() # wait for detectors to finish if self.refresh_rate > 0: # wait for the vertical retrace - time_in_frames = unitless( - (time.time_ns() - self._first_update_ns) * u.ns * self.refresh_rate - ) - time_to_next_frame = ( - np.ceil(time_in_frames) - time_in_frames - ) / self.refresh_rate + time_in_frames = unitless((time.time_ns() - self._first_update_ns) * u.ns * self.refresh_rate) + time_to_next_frame = (np.ceil(time_in_frames) - time_in_frames) / self.refresh_rate time.sleep(time_to_next_frame.tovalue(u.s)) # update the start time (this is also done in the actual SLM) self._start() @@ -235,9 +227,7 @@ def update(self): if self._lookup_table is None: grey_values = (256 * tx).astype(np.uint8) else: - lookup_index = (self._lookup_table.shape[0] * tx).astype( - np.uint8 - ) # index into lookup table + lookup_index = (self._lookup_table.shape[0] * tx).astype(np.uint8) # index into lookup table grey_values = self._lookup_table[lookup_index] self._hardware_timing.send(grey_values) diff --git a/openwfs/simulation/transmission.py b/openwfs/simulation/transmission.py index d3b3ec4..f0c7a12 100644 --- a/openwfs/simulation/transmission.py +++ b/openwfs/simulation/transmission.py @@ -50,12 +50,7 @@ def __init__( """ # transmission matrix (normalized so that the maximum transmission is 1) - self.t = ( - t - if t is not None - else np.exp(1.0j * aberrations) - / (aberrations.shape[0] * aberrations.shape[1]) - ) + self.t = t if t is not None else np.exp(1.0j * aberrations) / (aberrations.shape[0] * aberrations.shape[1]) self.slm = slm if slm is not None else SLM(self.t.shape[0:2]) super().__init__(self.slm.field, multi_threaded=multi_threaded) diff --git a/openwfs/utilities/patterns.py b/openwfs/utilities/patterns.py index 818c352..ca650fa 100644 --- a/openwfs/utilities/patterns.py +++ b/openwfs/utilities/patterns.py @@ -153,9 +153,7 @@ def propagation( # convert pupil coordinates to absolute k_x, k_y coordinates k_0 = 2.0 * np.pi / wavelength extent_k = Quantity(extent) * numerical_aperture * k_0 - k_z = np.sqrt( - np.maximum((refractive_index * k_0) ** 2 - r2_range(shape, extent_k), 0.0) - ) + k_z = np.sqrt(np.maximum((refractive_index * k_0) ** 2 - r2_range(shape, extent_k), 0.0)) return unitless(k_z * distance) diff --git a/openwfs/utilities/utilities.py b/openwfs/utilities/utilities.py index 3572e46..5102792 100644 --- a/openwfs/utilities/utilities.py +++ b/openwfs/utilities/utilities.py @@ -97,17 +97,11 @@ def __init__( ): self.transform = Quantity(transform if transform is not None else np.eye(2)) - self.source_origin = ( - Quantity(source_origin) if source_origin is not None else None - ) - self.destination_origin = ( - Quantity(destination_origin) if destination_origin is not None else None - ) + self.source_origin = Quantity(source_origin) if source_origin is not None else None + self.destination_origin = Quantity(destination_origin) if destination_origin is not None else None if source_origin is not None: - self.destination_unit( - self.source_origin.unit - ) # check if the units are consistent + self.destination_unit(self.source_origin.unit) # check if the units are consistent def destination_unit(self, src_unit: u.Unit) -> u.Unit: """Computes the unit of the output of the transformation, given the unit of the input. @@ -116,17 +110,11 @@ def destination_unit(self, src_unit: u.Unit) -> u.Unit: ValueError: If src_unit does not match the unit of the source_origin (if specified) or if dst_unit does not match the unit of the destination_origin (if specified). """ - if ( - self.source_origin is not None - and not self.source_origin.unit.is_equivalent(src_unit) - ): + if self.source_origin is not None and not self.source_origin.unit.is_equivalent(src_unit): raise ValueError("src_unit must match the units of source_origin.") dst_unit = (self.transform[0, 0] * src_unit).unit - if ( - self.destination_origin is not None - and not self.destination_origin.unit.is_equivalent(dst_unit) - ): + if self.destination_origin is not None and not self.destination_origin.unit.is_equivalent(dst_unit): raise ValueError("dst_unit must match the units of destination_origin.") return dst_unit @@ -150,9 +138,7 @@ def cv2_matrix( if self.source_origin is not None: source_origin += self.source_origin - destination_origin = ( - 0.5 * (np.array(destination_shape) - 1.0) * destination_pixel_size - ) + destination_origin = 0.5 * (np.array(destination_shape) - 1.0) * destination_pixel_size if self.destination_origin is not None: destination_origin += self.destination_origin @@ -173,19 +159,13 @@ def cv2_matrix( transform_matrix = transform_matrix[:, [1, 0, 2]] return transform_matrix - def to_matrix( - self, source_pixel_size: CoordinateType, destination_pixel_size: CoordinateType - ) -> np.ndarray: + def to_matrix(self, source_pixel_size: CoordinateType, destination_pixel_size: CoordinateType) -> np.ndarray: matrix = np.zeros((2, 3)) - matrix[0:2, 0:2] = unitless( - self.transform * source_pixel_size / destination_pixel_size - ) + matrix[0:2, 0:2] = unitless(self.transform * source_pixel_size / destination_pixel_size) if self.destination_origin is not None: matrix[0:2, 2] = unitless(self.destination_origin / destination_pixel_size) if self.source_origin is not None: - matrix[0:2, 2] -= unitless( - (self.transform @ self.source_origin) / destination_pixel_size - ) + matrix[0:2, 2] -= unitless((self.transform @ self.source_origin) / destination_pixel_size) return matrix def opencl_matrix(self) -> np.ndarray: @@ -261,11 +241,7 @@ def compose(self, other: "Transform"): """ transform = self.transform @ other.transform source_origin = other.source_origin - destination_origin = ( - self.apply(other.destination_origin) - if other.destination_origin is not None - else None - ) + destination_origin = self.apply(other.destination_origin) if other.destination_origin is not None else None return Transform(transform, source_origin, destination_origin) def _standard_input(self) -> Quantity: @@ -301,9 +277,7 @@ def place( """ out_extent = out_pixel_size * np.array(out_shape) transform = Transform(destination_origin=offset) - return project( - source, out_extent=out_extent, out_shape=out_shape, transform=transform, out=out - ) + return project(source, out_extent=out_extent, out_shape=out_shape, transform=transform, out=out) def project( @@ -335,9 +309,7 @@ def project( transform = transform if transform is not None else Transform() if out is not None: if out_shape is not None and out_shape != out.shape: - raise ValueError( - "out_shape and out.shape must match. Note that out_shape may be omitted" - ) + raise ValueError("out_shape and out.shape must match. Note that out_shape may be omitted") if out.dtype != source.dtype: raise ValueError("out and source must have the same dtype") out_shape = out.shape @@ -345,9 +317,7 @@ def project( if out_shape is None: raise ValueError("Either out_shape or out must be specified") if out_extent is None: - raise ValueError( - "Either out_extent or the pixel_size metadata of out must be specified" - ) + raise ValueError("Either out_extent or the pixel_size metadata of out must be specified") source_extent = source_extent if source_extent is not None else get_extent(source) source_ps = source_extent / np.array(source.shape) out_ps = out_extent / np.array(out_shape) @@ -388,9 +358,7 @@ def project( borderValue=(0.0,), ) if out is not None and out is not dst: - raise ValueError( - "OpenCV did not use the specified output array. This should not happen." - ) + raise ValueError("OpenCV did not use the specified output array. This should not happen.") out = dst return set_pixel_size(out, out_ps) diff --git a/tests/test_algorithms_troubleshoot.py b/tests/test_algorithms_troubleshoot.py index a5b157c..daae1a1 100644 --- a/tests/test_algorithms_troubleshoot.py +++ b/tests/test_algorithms_troubleshoot.py @@ -25,12 +25,8 @@ def test_signal_std(): a = np.random.rand(400, 400) b = np.random.rand(400, 400) assert signal_std(a, a) < 1e-6 # Test noise only - assert ( - np.abs(signal_std(a + b, b) - a.std()) < 0.005 - ) # Test signal+uncorrelated noise - assert ( - np.abs(signal_std(a + a, a) - np.sqrt(3) * a.std()) < 0.005 - ) # Test signal+correlated noise + assert np.abs(signal_std(a + b, b) - a.std()) < 0.005 # Test signal+uncorrelated noise + assert np.abs(signal_std(a + a, a) - np.sqrt(3) * a.std()) < 0.005 # Test signal+correlated noise def test_cnr(): @@ -41,12 +37,8 @@ def test_cnr(): b = np.random.randn(800, 800) cnr_gt = 3.0 # Ground Truth assert cnr(a, a) < 1e-6 # Test noise only - assert ( - np.abs(cnr(cnr_gt * a + b, b) - cnr_gt) < 0.01 - ) # Test signal+uncorrelated noise - assert ( - np.abs(cnr(cnr_gt * a + a, a) - np.sqrt((cnr_gt + 1) ** 2 - 1)) < 0.01 - ) # Test signal+correlated noise + assert np.abs(cnr(cnr_gt * a + b, b) - cnr_gt) < 0.01 # Test signal+uncorrelated noise + assert np.abs(cnr(cnr_gt * a + a, a) - np.sqrt((cnr_gt + 1) ** 2 - 1)) < 0.01 # Test signal+correlated noise def test_find_pixel_shift(): @@ -95,12 +87,8 @@ def test_field_correlation(): assert field_correlation(a, a) == 1.0 # Self-correlation assert field_correlation(2 * a, a) == 1.0 # Invariant under scalar-multiplication assert field_correlation(a, b) == 0.0 # Orthogonal arrays - assert ( - np.abs(field_correlation(a + b, b) - np.sqrt(0.5)) < 1e-10 - ) # Self+orthogonal array - assert ( - np.abs(field_correlation(b, c) - np.conj(field_correlation(c, b))) < 1e-10 - ) # Arguments swapped + assert np.abs(field_correlation(a + b, b) - np.sqrt(0.5)) < 1e-10 # Self+orthogonal array + assert np.abs(field_correlation(b, c) - np.conj(field_correlation(c, b))) < 1e-10 # Arguments swapped def test_frame_correlation(): @@ -171,9 +159,7 @@ def test_pearson_correlation_noise_compensated(): assert np.isclose(noise1.var(), noise2.var(), atol=2e-3) assert np.isclose(corr_AA, 1, atol=2e-3) assert np.isclose(corr_AB, 0, atol=2e-3) - A_spearman = 1 / np.sqrt( - (1 + noise1.var() / A1.var()) * (1 + noise2.var() / A2.var()) - ) + A_spearman = 1 / np.sqrt((1 + noise1.var() / A1.var()) * (1 + noise2.var() / A2.var())) assert np.isclose(corr_AA_with_noise, A_spearman, atol=2e-3) @@ -188,9 +174,7 @@ def test_fidelity_phase_calibration_ssa_noise_free(n_y, n_x, phase_steps, b, c, # Perfect SLM, noise-free aberrations = np.random.uniform(0.0, 2 * np.pi, (n_y, n_x)) sim = SimulatedWFS(aberrations=aberrations) - alg = StepwiseSequential( - feedback=sim, slm=sim.slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps - ) + alg = StepwiseSequential(feedback=sim, slm=sim.slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps) result = alg.execute() assert result.fidelity_calibration > 0.99 @@ -206,12 +190,8 @@ def test_fidelity_phase_calibration_ssa_noise_free(n_y, n_x, phase_steps, b, c, assert result.fidelity_calibration > 0.99 -@pytest.mark.parametrize( - "n_y, n_x, phase_steps, gaussian_noise_std", [(4, 4, 10, 0.2), (6, 6, 12, 1.0)] -) -def test_fidelity_phase_calibration_ssa_with_noise( - n_y, n_x, phase_steps, gaussian_noise_std -): +@pytest.mark.parametrize("n_y, n_x, phase_steps, gaussian_noise_std", [(4, 4, 10, 0.2), (6, 6, 12, 1.0)]) +def test_fidelity_phase_calibration_ssa_with_noise(n_y, n_x, phase_steps, gaussian_noise_std): """ Test estimation of phase calibration fidelity factor, with the SSA algorithm. With noise. """ @@ -222,7 +202,7 @@ def test_fidelity_phase_calibration_ssa_with_noise( aberration = StaticSource(aberration_phase, extent=2 * numerical_aperture) img = np.zeros((64, 64), dtype=np.int16) img[32, 32] = 250 - src = StaticSource(img, pixel_size = 500 * u.nm) + src = StaticSource(img, pixel_size=500 * u.nm) # SLM, simulation, camera, ROI detector slm = SLM(shape=(80, 80)) @@ -238,27 +218,19 @@ def test_fidelity_phase_calibration_ssa_with_noise( roi_detector = SingleRoi(cam, radius=0) # Only measure that specific point # Define and run WFS algorithm - alg = StepwiseSequential( - feedback=roi_detector, slm=slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps - ) + alg = StepwiseSequential(feedback=roi_detector, slm=slm, n_x=n_x, n_y=n_y, phase_steps=phase_steps) result_good = alg.execute() assert result_good.fidelity_calibration > 0.9 # SLM with incorrect phase response linear_phase = np.arange(0, 2 * np.pi, 2 * np.pi / 256) - slm.phase_response = phase_response_test_function( - linear_phase, b=0.05, c=0.6, gamma=1.5 - ) + slm.phase_response = phase_response_test_function(linear_phase, b=0.05, c=0.6, gamma=1.5) result_good = alg.execute() assert result_good.fidelity_calibration < 0.9 -@pytest.mark.parametrize( - "num_blocks, phase_steps, expected_fid, atol", [(10, 8, 1, 1e-6)] -) -def test_measure_modulated_light_dual_phase_stepping_noise_free( - num_blocks, phase_steps, expected_fid, atol -): +@pytest.mark.parametrize("num_blocks, phase_steps, expected_fid, atol", [(10, 8, 1, 1e-6)]) +def test_measure_modulated_light_dual_phase_stepping_noise_free(num_blocks, phase_steps, expected_fid, atol): """Test fidelity estimation due to amount of modulated light. Noise-free.""" # Perfect SLM, noise-free aberrations = np.random.uniform(0.0, 2 * np.pi, (20, 20)) @@ -275,9 +247,7 @@ def test_measure_modulated_light_dual_phase_stepping_noise_free( "num_blocks, phase_steps, gaussian_noise_std, atol", [(10, 6, 0.0, 1e-6), (6, 8, 2.0, 1e-3)], ) -def test_measure_modulated_light_dual_phase_stepping_with_noise( - num_blocks, phase_steps, gaussian_noise_std, atol -): +def test_measure_modulated_light_dual_phase_stepping_with_noise(num_blocks, phase_steps, gaussian_noise_std, atol): """Test fidelity estimation due to amount of modulated light. Can test with noise.""" # === Define mock hardware, perfect SLM === # Aberration and image source @@ -308,9 +278,7 @@ def test_measure_modulated_light_dual_phase_stepping_with_noise( "phase_steps, modulated_field_amplitude, non_modulated_field", [(6, 1.0, 0.0), (8, 0.5, 0.5), (8, 1.0, 0.25)], ) -def test_measure_modulated_light_noise_free( - phase_steps, modulated_field_amplitude, non_modulated_field -): +def test_measure_modulated_light_noise_free(phase_steps, modulated_field_amplitude, non_modulated_field): """Test fidelity estimation due to amount of modulated light. Noise-free.""" # Perfect SLM, noise-free aberrations = np.random.uniform(0.0, 2 * np.pi, (20, 20)) @@ -322,9 +290,7 @@ def test_measure_modulated_light_noise_free( sim = SimulatedWFS(aberrations=aberrations, slm=slm) # Measure the amount of modulated light (no non-modulated light present) - fidelity_modulated = measure_modulated_light( - slm=sim.slm, feedback=sim, phase_steps=phase_steps - ) + fidelity_modulated = measure_modulated_light(slm=sim.slm, feedback=sim, phase_steps=phase_steps) expected_fid = 1.0 / (1.0 + non_modulated_field**2) assert np.isclose(fidelity_modulated, expected_fid, rtol=0.1) @@ -341,7 +307,7 @@ def test_measure_modulated_light_dual_phase_stepping_with_noise( # Aberration and image source img = np.zeros((64, 64), dtype=np.int16) img[32, 32] = 100 - src = StaticSource(img, pixel_size= 200 * u.nm) + src = StaticSource(img, pixel_size=200 * u.nm) # SLM, simulation, camera, ROI detector slm = SLM( @@ -355,7 +321,5 @@ def test_measure_modulated_light_dual_phase_stepping_with_noise( # Measure the amount of modulated light (no non-modulated light present) expected_fid = 1.0 / (1.0 + non_modulated_field**2) - fidelity_modulated = measure_modulated_light( - slm=slm, feedback=roi_detector, phase_steps=phase_steps - ) + fidelity_modulated = measure_modulated_light(slm=slm, feedback=roi_detector, phase_steps=phase_steps) assert np.isclose(fidelity_modulated, expected_fid, rtol=0.1) diff --git a/tests/test_camera.py b/tests/test_camera.py index f635c90..2889a4a 100644 --- a/tests/test_camera.py +++ b/tests/test_camera.py @@ -38,16 +38,8 @@ def test_roi(camera, binning, top, left): # take care that the size will be a multiple of the increment, # and that setting the binning will round this number down camera.binning = binning - expected_width = ( - (original_shape[1] // binning) - // camera._nodes.Width.inc - * camera._nodes.Width.inc - ) - expected_height = ( - (original_shape[0] // binning) - // camera._nodes.Height.inc - * camera._nodes.Height.inc - ) + expected_width = (original_shape[1] // binning) // camera._nodes.Width.inc * camera._nodes.Width.inc + expected_height = (original_shape[0] // binning) // camera._nodes.Height.inc * camera._nodes.Height.inc assert camera.data_shape == (expected_height, expected_width) # check if setting the ROI works diff --git a/tests/test_core.py b/tests/test_core.py index a089f27..03e5278 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -85,9 +85,7 @@ def test_timing_detector(caplog, duration): def test_noise_detector(): - source = NoiseSource( - "uniform", data_shape=(10, 11, 20), low=-1.0, high=1.0, pixel_size=4 * u.um - ) + source = NoiseSource("uniform", data_shape=(10, 11, 20), low=-1.0, high=1.0, pixel_size=4 * u.um) data = source.read() assert data.shape == (10, 11, 20) assert np.min(data) >= -1.0 @@ -102,18 +100,12 @@ def test_noise_detector(): def test_mock_slm(): slm = SLM((4, 4)) slm.set_phases(0.5) - assert np.allclose( - slm.pixels.read(), round(0.5 * 256 / (2 * np.pi)), atol=0.5 / 256 - ) + assert np.allclose(slm.pixels.read(), round(0.5 * 256 / (2 * np.pi)), atol=0.5 / 256) discretized_phase = slm.phases.read() assert np.allclose(discretized_phase, 0.5, atol=1.1 * np.pi / 256) - assert np.allclose( - slm.field.read(), np.exp(1j * discretized_phase[0, 0]), rtol=2 / 256 - ) + assert np.allclose(slm.field.read(), np.exp(1j * discretized_phase[0, 0]), rtol=2 / 256) slm.set_phases(np.array(((0.1, 0.2), (0.3, 0.4))), update=False) - assert np.allclose( - slm.phases.read(), 0.5, atol=1.1 * np.pi / 256 - ) # slm.update() not yet called, so should be 0.5 + assert np.allclose(slm.phases.read(), 0.5, atol=1.1 * np.pi / 256) # slm.update() not yet called, so should be 0.5 slm.update() assert np.allclose( slm.phases.read(), diff --git a/tests/test_processors.py b/tests/test_processors.py index 591a2c5..6406fa2 100644 --- a/tests/test_processors.py +++ b/tests/test_processors.py @@ -8,8 +8,7 @@ @pytest.mark.skip( - reason="This is an interactive test: skip by default. TODO: actually test if the roi was " - "selected correctly." + reason="This is an interactive test: skip by default. TODO: actually test if the roi was " "selected correctly." ) def test_croppers(): img = sk.data.camera() @@ -53,9 +52,7 @@ def test_single_roi(x, y, radius, expected_avg): roi_processor.trigger() result = roi_processor.read() - assert np.isclose( - result, expected_avg - ), f"ROI average value is incorrect. Expected: {expected_avg}, Got: {result}" + assert np.isclose(result, expected_avg), f"ROI average value is incorrect. Expected: {expected_avg}, Got: {result}" def test_multiple_roi_simple_case(): diff --git a/tests/test_scanning_microscope.py b/tests/test_scanning_microscope.py index fe0335a..fadee2d 100644 --- a/tests/test_scanning_microscope.py +++ b/tests/test_scanning_microscope.py @@ -50,9 +50,7 @@ def test_scan_axis(start, stop): amplitude = 0.5 * (stop - start) # test clipping - assert np.allclose( - step, a.step(center - 1.1 * amplitude, center + 1.1 * amplitude, sample_rate) - ) + assert np.allclose(step, a.step(center - 1.1 * amplitude, center + 1.1 * amplitude, sample_rate)) # test scan. Note that we cannot use the full scan range because we need # some time to accelerate / decelerate @@ -67,12 +65,8 @@ def test_scan_axis(start, stop): assert linear_region.start == len(scan) - linear_region.stop assert np.isclose(scan[0], a.to_volt(launch)) assert np.isclose(scan[-1], a.to_volt(land)) - assert np.isclose( - scan[linear_region.start], a.to_volt(center - 0.8 * amplitude + half_pixel) - ) - assert np.isclose( - scan[linear_region.stop - 1], a.to_volt(center + 0.8 * amplitude - half_pixel) - ) + assert np.isclose(scan[linear_region.start], a.to_volt(center - 0.8 * amplitude + half_pixel)) + assert np.isclose(scan[linear_region.stop - 1], a.to_volt(center + 0.8 * amplitude - half_pixel)) speed = np.diff(scan[linear_region]) assert np.allclose(speed, speed[0]) # speed should be constant @@ -116,9 +110,7 @@ def test_scan_pattern(direction, bidirectional): """A unit test for scanning patterns.""" reference_zoom = 1.2 scanner = make_scanner(bidirectional, direction, reference_zoom) - assert np.allclose( - scanner.extent, scanner._x_axis.scale * 4.0 * u.V / reference_zoom - ) + assert np.allclose(scanner.extent, scanner._x_axis.scale * 4.0 * u.V / reference_zoom) # plt.imshow(scanner.read()) # plt.show() @@ -134,9 +126,7 @@ def test_scan_pattern(direction, bidirectional): if direction == "horizontal": assert np.allclose(full, full[0, :]) # all rows should be the same - assert np.allclose( - x, full, atol=0.2 * pixel_size - ) # some rounding due to quantization + assert np.allclose(x, full, atol=0.2 * pixel_size) # some rounding due to quantization else: # all columns should be the same (note we need to keep the last dimension for correct broadcasting) assert np.allclose(full, full[:, 0:1]) @@ -159,9 +149,7 @@ def test_scan_pattern(direction, bidirectional): assert scanner.data_shape == (height, width) roi = scanner.read().astype("float32") - 0x8000 - assert np.allclose( - full[top : (top + height), left : (left + width)], roi, atol=0.2 * pixel_size - ) + assert np.allclose(full[top : (top + height), left : (left + width)], roi, atol=0.2 * pixel_size) @pytest.mark.parametrize("bidirectional", [False, True]) @@ -179,9 +167,7 @@ def test_park_beam(bidirectional): img = scanner.read() assert img.shape == (2, 1) voltages = scanner._scan_pattern - assert np.allclose( - voltages[1, :], voltages[1, 0] - ) # all voltages should be the same + assert np.allclose(voltages[1, :], voltages[1, 0]) # all voltages should be the same # Park beam vertically scanner.width = 2 @@ -197,12 +183,8 @@ def test_park_beam(bidirectional): img = scanner.read() assert img.shape == (1, 1) voltages = scanner._scan_pattern - assert np.allclose( - voltages[1, :], voltages[1, 0] - ) # all voltages should be the same - assert np.allclose( - voltages[0, :], voltages[0, 0] - ) # all voltages should be the same + assert np.allclose(voltages[1, :], voltages[1, 0]) # all voltages should be the same + assert np.allclose(voltages[0, :], voltages[0, 0]) # all voltages should be the same # test zooming diff --git a/tests/test_simulation.py b/tests/test_simulation.py index ff04b15..87daf5a 100644 --- a/tests/test_simulation.py +++ b/tests/test_simulation.py @@ -20,13 +20,9 @@ def test_mock_camera_and_single_roi(): """ img = np.zeros((1000, 1000), dtype=np.int16) img[200, 300] = 39.39 # some random float - src = Camera(StaticSource(img, pixel_size= 450 * u.nm)) - roi_detector = SingleRoi( - src, pos=(200, 300), radius=0 - ) # Only measure that specific point - assert roi_detector.read() == int( - 2**16 - 1 - ) # it should cast the array into some int + src = Camera(StaticSource(img, pixel_size=450 * u.nm)) + roi_detector = SingleRoi(src, pos=(200, 300), radius=0) # Only measure that specific point + assert roi_detector.read() == int(2**16 - 1) # it should cast the array into some int @pytest.mark.parametrize("shape", [(1000, 1000), (999, 999)]) @@ -38,12 +34,10 @@ def test_microscope_without_magnification(shape): # construct input image img = np.zeros(shape, dtype=np.int16) img[256, 256] = 100 - src = Camera(StaticSource(img, pixel_size= 400 * u.nm)) + src = Camera(StaticSource(img, pixel_size=400 * u.nm)) # construct microscope - sim = Microscope( - source=src, magnification=1, numerical_aperture=1, wavelength=800 * u.nm - ) + sim = Microscope(source=src, magnification=1, numerical_aperture=1, wavelength=800 * u.nm) cam = sim.get_camera() img = cam.read() @@ -56,7 +50,7 @@ def test_microscope_and_aberration(): """ img = np.zeros((1000, 1000), dtype=np.int16) img[256, 256] = 100 - src = Camera(StaticSource(img, pixel_size= 400 * u.nm)) + src = Camera(StaticSource(img, pixel_size=400 * u.nm)) slm = SLM(shape=(512, 512)) @@ -84,15 +78,13 @@ def test_slm_and_aberration(): """ img = np.zeros((1000, 1000), dtype=np.int16) img[256, 256] = 100 - src = Camera(StaticSource(img, pixel_size= 400 * u.nm)) + src = Camera(StaticSource(img, pixel_size=400 * u.nm)) slm = SLM(shape=(512, 512)) aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) * 0 slm.set_phases(-aberrations) - aberration = StaticSource( - aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled - ) + aberration = StaticSource(aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled) sim1 = Microscope( source=src, @@ -125,7 +117,7 @@ def test_slm_tilt(): img[signal_location] = 100 pixel_size = 400 * u.nm wavelength = 750 * u.nm - src = Camera(StaticSource(img, pixel_size= pixel_size)) + src = Camera(StaticSource(img, pixel_size=pixel_size)) slm = SLM(shape=(1000, 1000)) @@ -159,9 +151,7 @@ def test_microscope_wavefront_shaping(caplog): # caplog.set_level(logging.DEBUG) aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) + np.pi - aberration = StaticSource( - aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled - ) # note: incorrect scaling! + aberration = StaticSource(aberrations, pixel_size=1.0 / 512 * u.dimensionless_unscaled) # note: incorrect scaling! img = np.zeros((1000, 1000), dtype=np.int16) img[256, 256] = 100 @@ -170,7 +160,7 @@ def test_microscope_wavefront_shaping(caplog): signal_location = (250, 200) img[signal_location] = 100 - src = StaticSource(img, pixel_size= 400 * u.nm) + src = StaticSource(img, pixel_size=400 * u.nm) slm = SLM(shape=(1000, 1000)) @@ -183,13 +173,9 @@ def test_microscope_wavefront_shaping(caplog): ) cam = sim.get_camera(analog_max=100) - roi_detector = SingleRoi( - cam, pos=signal_location, radius=0 - ) # Only measure that specific point + roi_detector = SingleRoi(cam, pos=signal_location, radius=0) # Only measure that specific point - alg = StepwiseSequential( - feedback=roi_detector, slm=slm, phase_steps=3, n_x=3, n_y=3 - ) + alg = StepwiseSequential(feedback=roi_detector, slm=slm, phase_steps=3, n_x=3, n_y=3) t = alg.execute().t # test if the modes differ. The error causes them not to differ @@ -246,10 +232,7 @@ def test_mock_slm_lut_and_phase_response(): / 256 ) expected_output_phases_a = ( - np.asarray((255, 255, 0, 0, 1, 64, 128, 192, 255, 255, 0, 0, 1, 255, 0)) - * 2 - * np.pi - / 256 + np.asarray((255, 255, 0, 0, 1, 64, 128, 192, 255, 255, 0, 0, 1, 255, 0)) * 2 * np.pi / 256 ) slm1 = SLM(shape=(3, input_phases_a.shape[0])) slm1.set_phases(input_phases_a) @@ -275,21 +258,14 @@ def test_mock_slm_lut_and_phase_response(): slm3.lookup_table = lookup_table slm3.set_phases(linear_phase) assert np.all( - np.abs( - slm3.phases.read() - - inverse_phase_response_test_function(linear_phase, b, c, gamma) - ) + np.abs(slm3.phases.read() - inverse_phase_response_test_function(linear_phase, b, c, gamma)) < (1.1 * np.pi / 256) ) # === Test custom lookup table that counters custom synthetic phase response === - linear_phase_highres = np.arange( - 0, 2 * np.pi * 255.49 / 256, 0.25 * 2 * np.pi / 256 - ) + linear_phase_highres = np.arange(0, 2 * np.pi * 255.49 / 256, 0.25 * 2 * np.pi / 256) slm4 = SLM(shape=(3, linear_phase_highres.shape[0])) slm4.phase_response = phase_response slm4.lookup_table = lookup_table slm4.set_phases(linear_phase_highres) - assert np.all( - np.abs(slm4.phases.read()[0] - linear_phase_highres) < (3 * np.pi / 256) - ) + assert np.all(np.abs(slm4.phases.read()[0] - linear_phase_highres) < (3 * np.pi / 256)) diff --git a/tests/test_slm.py b/tests/test_slm.py index b4575d0..7c6344b 100644 --- a/tests/test_slm.py +++ b/tests/test_slm.py @@ -173,9 +173,7 @@ def test_refresh_rate(): stop = time.time_ns() * u.ns del slm actual_refresh_rate = frame_count / (stop - start) - assert np.allclose( - refresh_rate.to_value(u.Hz), actual_refresh_rate.to_value(u.Hz), rtol=1e-2 - ) + assert np.allclose(refresh_rate.to_value(u.Hz), actual_refresh_rate.to_value(u.Hz), rtol=1e-2) def test_get_pixels(): @@ -261,9 +259,7 @@ def test_circular_geometry(slm): # read back the pixels and verify conversion to gray values pixels = np.rint(slm.pixels.read() / 256 * 70) - polar_pixels = cv2.warpPolar( - pixels, (100, 40), (99.5, 99.5), 100, cv2.WARP_POLAR_LINEAR - ) + polar_pixels = cv2.warpPolar(pixels, (100, 40), (99.5, 99.5), 100, cv2.WARP_POLAR_LINEAR) assert np.allclose( polar_pixels[:, 3:24], diff --git a/tests/test_utilities.py b/tests/test_utilities.py index f6766cc..948b55d 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -55,9 +55,7 @@ def test_to_matrix(): # Also check openGL matrix (has y-axis flipped and extra row and column) expected_matrix = ((1, 2, 1), (3, 4, 2)) - transform = Transform( - transform=((1, 2), (3, 4)), source_origin=(0, 0), destination_origin=(1, 2) - ) + transform = Transform(transform=((1, 2), (3, 4)), source_origin=(0, 0), destination_origin=(1, 2)) result_matrix = transform.to_matrix((1, 1), (1, 1)) assert np.allclose(result_matrix, expected_matrix) @@ -111,9 +109,7 @@ def test_transform(): assert np.allclose(matrix, ((1.0, 0.0, 0.0), (0.0, 1.0, 0.0))) # shift both origins by same distance - t0 = Transform( - source_origin=-ps1 * (1.7, 2.2), destination_origin=-ps1 * (1.7, 2.2) - ) + t0 = Transform(source_origin=-ps1 * (1.7, 2.2), destination_origin=-ps1 * (1.7, 2.2)) dst0 = project( src, source_extent=ps1 * np.array(src.shape), diff --git a/tests/test_wfs.py b/tests/test_wfs.py index d1f0734..2708da0 100644 --- a/tests/test_wfs.py +++ b/tests/test_wfs.py @@ -60,9 +60,7 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str): k_radius=(np.min(shape) - 1) // 2, phase_steps=phase_steps, ) - N = ( - alg.phase_patterns[0].shape[2] + alg.phase_patterns[1].shape[2] - ) # number of input modes + N = alg.phase_patterns[0].shape[2] + alg.phase_patterns[1].shape[2] # number of input modes alg_fidelity = 1.0 # Fourier is accurate for any N signal = 1 / 2 # for estimating SNR. @@ -83,10 +81,7 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str): sim.slm.set_phases(-np.angle(result.t[:, :, b])) I_opt[b] = feedback.read()[b] t_correlation += abs(np.vdot(result.t[:, :, b], sim.t[:, :, b])) ** 2 - t_norm += abs( - np.vdot(result.t[:, :, b], result.t[:, :, b]) - * np.vdot(sim.t[:, :, b], sim.t[:, :, b]) - ) + t_norm += abs(np.vdot(result.t[:, :, b], result.t[:, :, b]) * np.vdot(sim.t[:, :, b], sim.t[:, :, b])) t_correlation /= t_norm # a correlation of 1 means optimal reconstruction of the N modulated modes, which may be less than the total number of inputs in the transmission matrix @@ -96,14 +91,10 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str): # the fidelity of the transmission matrix reconstruction theoretical_noise_fidelity = signal / (signal + noise**2 / phase_steps) enhancement = I_opt.mean() / I_0 - theoretical_enhancement = ( - np.pi / 4 * theoretical_noise_fidelity * alg_fidelity * (N - 1) + 1 - ) + theoretical_enhancement = np.pi / 4 * theoretical_noise_fidelity * alg_fidelity * (N - 1) + 1 estimated_enhancement = result.estimated_enhancement.mean() * alg_fidelity theoretical_t_correlation = theoretical_noise_fidelity * alg_fidelity - estimated_t_correlation = ( - result.fidelity_noise * result.fidelity_calibration * alg_fidelity - ) + estimated_t_correlation = result.fidelity_noise * result.fidelity_calibration * alg_fidelity tolerance = 2.0 / np.sqrt(M) print( f"\nenhancement: \ttheoretical= {theoretical_enhancement},\testimated={estimated_enhancement},\tactual: {enhancement}" @@ -111,9 +102,7 @@ def test_multi_target_algorithms(shape, noise: float, algorithm: str): print( f"t-matrix fidelity:\ttheoretical = {theoretical_t_correlation},\testimated = {estimated_t_correlation},\tactual = {t_correlation}" ) - print( - f"noise fidelity: \ttheoretical = {theoretical_noise_fidelity},\testimated = {result.fidelity_noise}" - ) + print(f"noise fidelity: \ttheoretical = {theoretical_noise_fidelity},\testimated = {result.fidelity_noise}") print(f"comparing at relative tolerance: {tolerance}") assert np.allclose( @@ -160,9 +149,7 @@ def test_fourier2(): slm_shape = (1000, 1000) aberrations = skimage.data.camera() * ((2 * np.pi) / 255.0) sim = SimulatedWFS(aberrations=aberrations) - alg = FourierDualReference( - feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=7.5, phase_steps=3 - ) + alg = FourierDualReference(feedback=sim, slm=sim.slm, slm_shape=slm_shape, k_radius=7.5, phase_steps=3) controller = WFSController(alg) controller.wavefront = WFSController.State.SHAPED_WAVEFRONT scaled_aberration = zoom(aberrations, np.array(slm_shape) / aberrations.shape) @@ -172,9 +159,7 @@ def test_fourier2(): @pytest.mark.skip("Not implemented") def test_fourier_microscope(): aberration_phase = skimage.data.camera() * ((2 * np.pi) / 255.0) + np.pi - aberration = StaticSource( - aberration_phase, pixel_size=2.0 / np.array(aberration_phase.shape) - ) + aberration = StaticSource(aberration_phase, pixel_size=2.0 / np.array(aberration_phase.shape)) img = np.zeros((1000, 1000), dtype=np.int16) signal_location = (250, 250) img[signal_location] = 100 @@ -192,9 +177,7 @@ def test_fourier_microscope(): ) cam = sim.get_camera(analog_max=100) roi_detector = SingleRoi(cam, pos=(250, 250)) # Only measure that specific point - alg = FourierDualReference( - feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_radius=1.5, phase_steps=3 - ) + alg = FourierDualReference(feedback=roi_detector, slm=slm, slm_shape=slm_shape, k_radius=1.5, phase_steps=3) controller = WFSController(alg) controller.wavefront = WFSController.State.FLAT_WAVEFRONT before = roi_detector.read() @@ -202,12 +185,8 @@ def test_fourier_microscope(): after = roi_detector.read() # imshow(controller._optimized_wavefront) print(after / before) - scaled_aberration = zoom( - aberration_phase, np.array(slm_shape) / aberration_phase.shape - ) - assert_enhancement( - slm, roi_detector, controller._result, np.exp(1j * scaled_aberration) - ) + scaled_aberration = zoom(aberration_phase, np.array(slm_shape) / aberration_phase.shape) + assert_enhancement(slm, roi_detector, controller._result, np.exp(1j * scaled_aberration)) def test_fourier_correction_field(): @@ -226,9 +205,7 @@ def test_fourier_correction_field(): t = alg.execute().t t_correct = np.exp(1j * aberrations) - correlation = np.vdot(t, t_correct) / np.sqrt( - np.vdot(t, t) * np.vdot(t_correct, t_correct) - ) + correlation = np.vdot(t, t_correct) / np.sqrt(np.vdot(t, t) * np.vdot(t_correct, t_correct)) # TODO: integrate with other test cases, duplication assert abs(correlation) > 0.75 @@ -300,9 +277,7 @@ def test_flat_wf_response_fourier(optimized_reference, step): # test the optimized wavefront by checking if it has irregularities. measured_aberrations = np.squeeze(np.angle(t)) measured_aberrations += aberrations[0, 0] - measured_aberrations[0, 0] - assert np.allclose( - measured_aberrations, aberrations, atol=0.02 - ) # The measured wavefront is not flat. + assert np.allclose(measured_aberrations, aberrations, atol=0.02) # The measured wavefront is not flat. def test_flat_wf_response_ssa(): @@ -320,9 +295,7 @@ def test_flat_wf_response_ssa(): # Assert that the standard deviation of the optimized wavefront is below the threshold, # indicating that it is effectively flat - assert ( - np.std(optimised_wf) < 0.001 - ), f"Response flat wavefront not flat, std: {np.std(optimised_wf)}" + assert np.std(optimised_wf) < 0.001, f"Response flat wavefront not flat, std: {np.std(optimised_wf)}" def test_multidimensional_feedback_ssa(): @@ -412,7 +385,7 @@ def test_dual_reference_ortho_split(basis_str: str, shape): if basis_str == "plane_wave": # Create a full plane wave basis for one half of the SLM. modes = np.fft.fft2(np.eye(N).reshape(modes_shape), axes=(0, 1)) / np.sqrt(N) - elif basis_str == 'hadamard': + elif basis_str == "hadamard": modes = hadamard(N).reshape(modes_shape) / np.sqrt(N) else: raise f'Unknown type of basis "{basis_str}".' @@ -435,7 +408,7 @@ def test_dual_reference_ortho_split(basis_str: str, shape): plt.title(f"m={m}") plt.xticks([]) plt.yticks([]) - plt.suptitle('Basis') + plt.suptitle("Basis") plt.pause(0.01) # Create aberrations @@ -445,7 +418,7 @@ def test_dual_reference_ortho_split(basis_str: str, shape): feedback=sim, slm=sim.slm, phase_patterns=(phases_set, np.flip(phases_set, axis=1)), - amplitude='uniform', + amplitude="uniform", group_mask=mask, iterations=4, ) @@ -457,8 +430,8 @@ def test_dual_reference_ortho_split(basis_str: str, shape): for m in range(N): plt.subplot(*modes_shape[0:2], m + 1) plot_field(alg.cobasis[0][:, :, m]) - plt.title(f'{m}') - plt.suptitle('Cobasis') + plt.title(f"{m}") + plt.suptitle("Cobasis") plt.pause(0.01) plt.figure() @@ -495,7 +468,7 @@ def test_dual_reference_non_ortho_split(): N1 = 6 N2 = 3 M = N1 * N2 - mode_set_half = np.exp(2j*np.pi/3 * np.eye(M).reshape((N1, N2, M))) / np.sqrt(M) + mode_set_half = np.exp(2j * np.pi / 3 * np.eye(M).reshape((N1, N2, M))) / np.sqrt(M) mode_set = np.concatenate((mode_set_half, np.zeros(shape=(N1, N2, M))), axis=1) phases_set = np.angle(mode_set) mask = np.concatenate((np.zeros((N1, N2)), np.ones((N1, N2))), axis=1) @@ -508,7 +481,7 @@ def test_dual_reference_non_ortho_split(): for m in range(M): plt.subplot(N2, N1, m + 1) plot_field(mode_set[:, :, m]) - plt.title(f'm={m}') + plt.title(f"m={m}") plt.xticks([]) plt.yticks([]) plt.pause(0.01) @@ -517,11 +490,7 @@ def test_dual_reference_non_ortho_split(): # Create aberrations x = np.linspace(-1, 1, 1 * N1).reshape((1, -1)) y = np.linspace(-1, 1, 1 * N1).reshape((-1, 1)) - aberrations = ( - np.sin(0.8 * np.pi * x) - * np.cos(1.3 * np.pi * y) - * (0.8 * np.pi + 0.4 * x + 0.4 * y) - ) % (2 * np.pi) + aberrations = (np.sin(0.8 * np.pi * x) * np.cos(1.3 * np.pi * y) * (0.8 * np.pi + 0.4 * x + 0.4 * y)) % (2 * np.pi) aberrations[0:1, :] = 0 aberrations[:, 0:2] = 0 @@ -531,7 +500,7 @@ def test_dual_reference_non_ortho_split(): feedback=sim, slm=sim.slm, phase_patterns=(phases_set, np.flip(phases_set, axis=1)), - amplitude='uniform', + amplitude="uniform", group_mask=mask, phase_steps=4, iterations=4, @@ -547,23 +516,23 @@ def test_dual_reference_non_ortho_split(): for m in range(M): plt.subplot(N2, N1, m + 1) plot_field(alg.cobasis[0][:, :, m], scale=2) - plt.title(f'{m}') - plt.suptitle('Cobasis') + plt.title(f"{m}") + plt.suptitle("Cobasis") plt.pause(0.01) plt.figure() plt.imshow(abs(alg.gram), vmin=0, vmax=1) - plt.title('Gram matrix abs values') + plt.title("Gram matrix abs values") plt.colorbar() plt.figure() plt.subplot(1, 2, 1) plot_field(aberration_field) - plt.title('Aberrations') + plt.title("Aberrations") plt.subplot(1, 2, 2) plot_field(t_field) - plt.title('t') + plt.title("t") plt.show() assert np.abs(field_correlation(aberration_field, t_field)) > 0.999